# Analysis
Load the results files and run the analysis on them.

# Functions

## Plotting

In [None]:
import matplotlib.pyplot as plt
from tigramite import plotting as tp
import numpy as np

def plot_links(results,
               save_name = None,
               figsize = (16, 16)
              ):
    """
    Plot the graph of the variables that take part in at least one link.

    This function is copied from the basic tutorial, but it may not be
    generalizable: It had issues with output from pcmciplus

    Parameters
    ----------
    results : dict
        Must provide the keys "link_matrix", "var_names" and "links".
        "links" maps each child index to a list of parents; the first
        element of every parent entry is used as the parent's index.
    save_name : optional
        Forwarded unchanged to ``tp.plot_graph`` as ``save_name``.
    figsize : tuple
        Forwarded unchanged to ``tp.plot_graph`` as ``figsize``.
    """

    link_matrix = results["link_matrix"]
    var_names = np.array(results["var_names"])
    # NOTE: results['val_matrix'] is deliberately not passed to plot_graph.

    links = results["links"]
    linked_variables = set() # Avoids duplicates
    for child, parents_list in links.items():
        if len(parents_list) > 0:
            linked_variables.add(child)
            for parent in parents_list:
                linked_variables.add(parent[0])

    if len(linked_variables) != 0:
        # A set has no guaranteed iteration order, so sort explicitly to get
        # a deterministic, ascending selection of rows/columns and names.
        linked_variables = sorted(linked_variables)
    else:
        # No links at all: fall back to showing every variable.
        linked_variables = list(range(len(links)))

    tp.plot_graph(
        figsize = figsize,
        # Select the same indices on the first two axes of the link matrix.
        link_matrix = link_matrix[linked_variables][:,linked_variables],
        var_names = var_names[linked_variables],
        link_colorbar_label = 'cross-MCI',
        node_colorbar_label = 'auto-MCI',
        arrow_linewidth = 5,
        node_size = 0.15,
        save_name = save_name,
        show_colorbar = False
    ); plt.show()

# Main program

## Specifications

In [None]:
import sys
import getopt
import yaml
import time
import datetime
from pathlib import Path

import utils.utils as utils
from utils.constants import SPCAM_Vars, DATA_FOLDER, ANCIL_FILE
# from utils.constants import PLOT_FILE_PATTERN
from utils.constants import experiment

In [None]:
# Command line arguments (the script name itself is dropped).
argv           = sys.argv[1:]
# When running interactively, override argv with one of these instead:
# argv           = ['-c', 'cfg_pipeline.yml']
# argv           = ['-c', 'cfg_lon120.yml']

# Bound up front so a missing -c is detected below instead of surfacing
# later as a NameError when the config file is opened.
yml_cfgFilenm  = None
try:
    opts, args = getopt.getopt(argv,"hc:a",["cfg_file=","add="])
except getopt.GetoptError:
    print ('pipeline.py -c [cfg_file] -a [add]')
    sys.exit(2)
for opt, arg in opts:
    if opt == '-h':
        print ('pipeline.py -c [cfg_file]')
        sys.exit()
    elif opt in ("-c", "--cfg_file"):
        yml_cfgFilenm = arg
    elif opt in ("-a", "--add"):
        # Accepted but currently ignored.
        pass

if yml_cfgFilenm is None:
    # The configuration file is mandatory: show the usage and stop.
    print ('pipeline.py -c [cfg_file] -a [add]')
    sys.exit(2)

# YAML config file
# The context manager closes the file as soon as it has been parsed.
# NOTE(review): if the config file can come from an untrusted source,
# yaml.safe_load is the safer choice than FullLoader.
with open(yml_cfgFilenm) as yml_cfgFile:
    yml_cfg       = yaml.load(yml_cfgFile, Loader=yaml.FullLoader)

# Load specifications
# analysis = yml_cfg['analysis']
analysis = "single" # Only single is supported right now
spcam_parents     = yml_cfg['spcam_parents']
spcam_children    = yml_cfg['spcam_children']
pc_alphas         = yml_cfg['pc_alphas']
region            = yml_cfg['region']
lim_levels        = yml_cfg['lim_levels']
target_levels     = yml_cfg['target_levels']
verbosity         = yml_cfg['verbosity']
output_folder     = yml_cfg['output_folder']
plots_folder = yml_cfg['plots_folder']
# The output pattern is looked up per analysis type.
output_file_pattern = yml_cfg['output_file_pattern'][analysis]
plot_file_pattern = yml_cfg['plot_file_pattern']
# When False, plots that already exist on disk are not regenerated.
overwrite         = False

# Make sure the destination folder for the plots exists.
Path(plots_folder).mkdir(parents=True, exist_ok=True)

In [None]:
## Region / Gridpoints
# A region of False in the configuration selects the default box below.
if region is False:
    region = [[-90, 90], [0, -.5]]  # All
gridpoints = utils.get_gridpoints(region)

## Children levels (parents includes all)
# Explicit target levels win; they are derived from lim_levels only when
# none were configured and a level range was given.
if target_levels is False and lim_levels is not False:
    target_levels = utils.get_levels(lim_levels)

In [None]:
## Model's grid
ancil_path = Path(DATA_FOLDER, ANCIL_FILE)
levels, latitudes, longitudes = utils.read_ancilaries(ancil_path)

## Latitude / Longitude indexes
# One index per gridpoint; element 0 of a gridpoint is matched against the
# latitudes and element 1 against the longitudes.
idx_lats = [
    utils.find_closest_value(latitudes, point[0]) for point in gridpoints
]
idx_lons = [
    utils.find_closest_longitude(longitudes, point[1]) for point in gridpoints
]

## Level indexes (children & parents)
# Every entry is a [level value, level index] pair. Parents use all levels.
parents_idx_levs = [
    [round(lev, 2), i_lev] for i_lev, lev in enumerate(levels)
]  # All
# Children are limited to the target levels when those are configured.
children_idx_levs = (
    parents_idx_levs
    if target_levels is False
    else [
        [lev, utils.find_closest_value(levels, lev)] for lev in target_levels
    ]
)

In [None]:
## Variables
# Keep only the variables whose label was requested in the configuration,
# then split them by their type: "in" are parents, "out" are children.
spcam_vars_include = spcam_parents + spcam_children
var_list = [
    candidate
    for candidate in SPCAM_Vars
    if candidate.label in spcam_vars_include
]
var_parents = [selected for selected in var_list if selected.type == "in"]
var_children = [selected for selected in var_list if selected.type == "out"]

## Aggregate results

In [None]:
def get_parents_from_links(links):
    """
    Return a boolean mask of the variables that appear as a parent.

    Parameters
    ----------
    links : dict
        Maps each child to a list of parents. The first element of every
        parent entry is taken as the parent's variable index.

    Returns
    -------
    list of bool
        One entry per key in ``links``; entry ``i`` is True when index ``i``
        is the parent of at least one child.
    """
    # Only parents are collected. The previous version also tried to add an
    # undefined name `child`, which raised NameError unless a global with
    # that name happened to exist.
    parent_indices = {
        parent[0]
        for parents_list in links.values()
        for parent in parents_list
    }
    return [i in parent_indices for i in range(len(links))]

In [None]:
import numpy as np

KEY_PATTERN = "{var_name}-{level}"

# aggregated_results[key][pc_alpha] is a list holding one parents mask for
# every results file that was found for that key.
aggregated_results = dict()
for child in var_children:
    print(f"Variable: {child.name}")
    for lat, lon in gridpoints:

        # 2D variables get a single pseudo-level and are keyed by name only;
        # 3D variables are keyed per level inside the loop below.
        if child.dimensions == 2:
            child_levels = [[levels[-1], 0]]
            key = child.name
        elif child.dimensions == 3:
            child_levels = children_idx_levs

        for level in child_levels:
            if child.dimensions == 3:
                key = KEY_PATTERN.format(
                    var_name=child.name,
                    level=round(level[0], 2),
                )

            results_file = utils.generate_results_filename(
                child,
                level[1],
                lat,
                lon,
                experiment,
                output_file_pattern,
                output_folder,
            )
            if not results_file.is_file():
                print(f"File {results_file} not found, skipping.")
                continue

            # The key is stored only once a results file has been loaded.
            aggregated_pc_alpha = aggregated_results.get(key, dict())
            results = utils.load_results(results_file)
            for pc_alpha, alpha_result in results.items():
                if len(alpha_result) == 0:
                    continue  # Nothing stored for this pc_alpha
                parents = get_parents_from_links(alpha_result["links"])
                aggregated_pc_alpha.setdefault(pc_alpha, list()).append(parents)
                # TODO Very uncomfortable way to obtain this data. Should be metadata
                var_names = alpha_result["var_names"]
            aggregated_results[key] = aggregated_pc_alpha

In [None]:
thresholds = [.5, .6, .7, .8, .9, 1]
var_names = np.array(var_names)

# For every aggregated key and pc_alpha, print the variables that were found
# as a parent in at least `threshold` of the aggregated results.
for child, result in aggregated_results.items():
    print(child)
    for pc_alpha, parents_matrix in result.items():
        # One row per aggregated result, one column per variable.
        parents_matrix = np.array(parents_matrix)
        n_results = parents_matrix.shape[0]
        parents_percent = parents_matrix.sum(axis = 0) / n_results
        print(pc_alpha)
        for threshold in thresholds:
            parents = [
                idx
                for idx, share in enumerate(parents_percent)
                if share >= threshold
            ]
            print(f"Threshold {threshold}:\t{var_names[parents]}")


## Plots

In [None]:
# Load data
# For every gridpoint, child variable and level: load the stored results and
# plot the links found for each pc_alpha, timing each stage.
len_grid = len(gridpoints)
t_start = time.time()
for i_grid, (lat, lon) in enumerate(gridpoints):
    idx_lat = idx_lats[i_grid]
    idx_lon = idx_lons[i_grid]
    
    t_start_gridpoint = time.time()
    print(f"Gridpoint {i_grid+1}/{len_grid}:"
          + f"lat={lat} ({idx_lat}), lon={lon} ({idx_lon})")
    for child in var_children:
        print(f"Variable: {child.name}")
        if child.dimensions == 2:
            child_levels = [[levels[-1],0]]
        elif child.dimensions == 3:
            child_levels = children_idx_levs
        for level in child_levels:
            # Build the file name first and load from it, exactly as the
            # aggregation step does; results that were never produced are
            # skipped instead of aborting the whole run.
            results_file = utils.generate_results_filename(
                    child, level[1], lat, lon, experiment,
                    output_file_pattern, output_folder)
            if not results_file.is_file():
                print(f"File {results_file} not found, skipping.")
                continue
            results = utils.load_results(results_file)

            # Levels are reported 1-based in messages and file names.
            print(f"Plotting links for {child.name} at level {level[1]+1}")
            t_before_plot_links = time.time()
            # Plotting
            for pc_alpha, alpha_result in results.items():
                print(f"pc_alpha = {pc_alpha}")
                plot_file = Path(plots_folder, plot_file_pattern.format(
                        var_name = child.name,
                        level = level[1]+1,
                        lat = int(lat),
                        lon = int(lon),
                        pc_alpha = pc_alpha,
                        experiment = experiment
                ))
                if not overwrite and plot_file.is_file():
                    print(f"Found file {plot_file}, skipping.")
                    continue # Ignore this result
                if len(alpha_result) > 0:
                    print(f"Plotting to {plot_file}")
                    plot_links(alpha_result, save_name = plot_file)
                else:
                    print("Results are empty, skipping.")
            time_plot_links = datetime.timedelta(
                    seconds = time.time() - t_before_plot_links
            )
            print(f"Plotted. Time: {time_plot_links}")

    time_point = datetime.timedelta(
            seconds = time.time() - t_start_gridpoint)
    print(f"All links in gridpoint plotted. Time: {time_point}")
total_time = datetime.timedelta(seconds = time.time() - t_start)
print(f"Execution complete. Total time: {total_time}")
    