## 1. Load associated libraries

In [None]:
# Uncomment the lines below to apply causal discovery.
# Make sure that Tigramite is installed in your environment.
# NOTE: the causal-discovery cell (section 6) calls pp.DataFrame, ParCorr and
# PCMCI, so it cannot run until these imports are enabled.
# NOTE(review): newer Tigramite releases appear to expose ParCorr from
# tigramite.independence_tests.parcorr -- confirm against the installed version.
################### Start load tigramite ####################
#import tigramite
#from tigramite import data_processing as pp
#from tigramite import plotting as tp
#from tigramite.pcmci import PCMCI
#from tigramite.independence_tests import ParCorr
################### End load tigramite ####################
from matplotlib import pyplot as plt
import numpy as np
import os
from datetime import datetime as dt
# Timestamp of this run; its date (formatted DDMMYYYY) names the output sub-folders.
today = dt.now()
############################Load helper functions##########################################
# Project-local helper applied to every variable in section 5.1.
from detrending import detrending
############################End of helper functions##########################################

## 2. Define associated parameters

In [None]:
# Models included in the analysis
model_names = [
    "ACCESS-CM2",
    "ACCESS-ESM1-5",
    "BCC-CSM2-MR",
    "CAMS-CSM1-0",
    "CanESM5",
    "CMCC-CM2-SR5",
    "CMCC-ESM2",
    "EC-Earth3",
    "EC-Earth3-Veg",
    "EC-Earth3-Veg-LR",
    "GFDL-CM4",
    "GFDL-ESM4",
    "INM-CM4-8",
    "INM-CM5-0",
    "IPSL-CM6A-LR",
    "MIROC6",
    "MPI-ESM1-2-HR",
    "MPI-ESM1-2-LR",
    "MRI-ESM2-0",
]
# Variables used in the analysis (their order fixes the column order of the data array)
variables = [
    "PV",
    "TAS",
    "vflux",
    "Sib-SLP",
    "Ural-SLP",
    "Aleut-SLP",
    "NAO",
    "U",
    "BK-SIC",
    "Ok-SIC",
]
# Maximum time lag, in months
max_timelag = 5
# Seasons for which masking is applied
masking_list = ["OND", "DJF", "JFM"]
# One or several pc_alpha values, kept in a list so that sensitivity tests can
# loop over them (e.g. add 0.02, 0.05)
pc_alpha_list = [0.01]

## 3. Create output folders
The output is sorted into the following folders:

`dictionaries`  - Dictionary with the output from the application of PCMCI+; <br>
`txt`           - Summary of matching causal and contemporaneous links between observations and CMIP6 models, saved as txt files; <br>
`CG`            - Causal graphs from the analysed models; <br>
`Summary_plots` - Summary of causal and contemporaneous links from the analysed models.

In [None]:
base_folder = "/path/to/the/output/folder/"
# All results of a run are filed under the run date, formatted DDMMYYYY
run_stamp = today.strftime("%d%m%Y")

# Folders that are only split by date
for dated_dir in ("dictionaries/", "txt/"):
    os.makedirs(base_folder + dated_dir + run_stamp + "/", exist_ok=True)

# Folders that are further split by masking season and pc_alpha value
for season in masking_list:
    for alpha in pc_alpha_list:
        run_subpath = run_stamp + '/' + season + '/' + str(alpha)
        # CG stands for Causal Graphs
        for plot_dir in ("CG/", "Summary_plots/"):
            os.makedirs(base_folder + plot_dir + run_subpath, exist_ok=True)
#################End prepare folders for the output ####################

## 4. Load variables 
We suggest loading the data from observations and CMIP6 models into a dictionary with the following structure:

In [None]:
# Template of the expected input structure: one entry per data source
# (observations plus each model), each mapping a variable name to its values.
# Every data source must use the same variable names, because the later cells
# index each source with the same list of variables.
dictionary = {
    "OBS": {"variable1": "values", "variable2": "values", "variableN": "values"},
    "Model1": {"variable1": "values", "variable2": "values", "variableN": "values"},
    "ModelN": {"variable1": "values", "variable2": "values", "variableN": "values"},
}

## 5. Prepare data for Tigramite
5.1. Detrend each variable

In [None]:
# Replace each analysed variable of every data source by its detrended version
for source_series in dictionary.values():
    for name in variables:
        source_series[name] = detrending(source_series[name])

5.2. Construct an array from detrended data for causal discovery

In [None]:
# Define the **time** variable before running this cell (it is not defined in
# the cells shown here); only its length is used, as the number of time steps.
# Array layout: (data source, time step, variable)
data = np.zeros((len(dictionary), len(time), len(variables)))
for i, source in enumerate(dictionary):
    # Iterate over `variables`: it sizes the last axis above, and the name
    # `actors` used previously is not defined in the cells shown here.
    for j, var in enumerate(variables):
        data[i, :, j] = dictionary[source][var]
M, T, N = data.shape
print("# Models   Data Length   No of Actors")
print( '   ',  M, '     ' ,T, '           ', N)

## 6. Calculate Causal Graph for each data source. 
The routine below can be used to reproduce Fig. 3, S1, S4 of Galytska et al., 2022, JGR

In [None]:
# Do you want to plot the original causal graph from each data source?
# (only takes effect once the plotting code further below is uncommented)
plot_Causal_Graphs = True
# Do you want to save the dictionary with the output of causal discovery?
save_orig_dict = True

# Requires `month` to be defined beforehand (it is not defined in the cells
# shown here); presumably the month number (1-12) of each time step -- confirm.

# For each season: months strictly between these two bounds are masked out,
# i.e. OND masks months 1-9, DJF masks months 3-11, JFM masks months 4-12.
masked_month_bounds = {"OND": (0, 10), "DJF": (2, 12), "JFM": (3, 13)}

# The causal graphs from different data sources will be saved into a new dictionary,
# nested as dict_networks[data source][masking][pc_alpha]['results']
dict_networks = {}
for masking in masking_list:
    if masking not in masked_month_bounds:
        # Fail loudly instead of silently reusing the mask of a previous season
        raise ValueError("Unknown masking season: " + repr(masking))
    print ('Applying mask', masking)
    lower, upper = masked_month_bounds[masking]
    # 1 stands for data which will be masked, 0 for data which is kept
    data_masking = np.where(np.logical_and(month > lower, month < upper), 1, 0)

    # The same time mask is applied to every data source and every variable
    data_mask = np.zeros(data.shape)
    for m in range(0, M):
        for n in range(0, N):
            data_mask[m, :, n] = data_masking

    for i, key in enumerate(dictionary.keys()):
        for pc_element in pc_alpha_list:
            print ("Calculating pc_alpha = ", pc_element)

            # Create the nested entry for this (data source, masking, pc_alpha) run
            run_entry = dict_networks.setdefault(key, {}).setdefault(masking, {}).setdefault(pc_element, {})

            # Pass the mask as a boolean array (True values will be masked).
            # The mask is converted here rather than by reassigning dataframe.mask
            # afterwards, because Tigramite may keep the mask in its own internal
            # container, which such a reassignment would overwrite.
            dataframe = pp.DataFrame(data[i, :, :], var_names = variables, mask = data_mask[i, :, :].astype(bool))
            parcorr = ParCorr(significance='analytic', mask_type = 'y')
            pcmci = PCMCI(dataframe=dataframe, cond_ind_test=parcorr, verbosity=0)
            results = pcmci.run_pcmciplus(tau_max=max_timelag, pc_alpha=pc_element)
            # p-values corrected with fdr_method='fdr_bh'; computed here but not stored in dict_networks
            q_matrix = pcmci.get_corrected_pvalues(p_matrix=results['p_matrix'], tau_max=max_timelag, fdr_method='fdr_bh')

            run_entry.setdefault ('results', results)
            # Uncomment code below to plot causal graphs

#            if plot_Causal_Graphs:
#                tp.plot_graph(
#                               val_matrix=results['val_matrix'],
#                               graph=results['graph'],
#                               var_names=variables,
#                               figsize = (8,8),
#                               node_pos = {'x': np.array([ 6.0,  6.3, 6.0, 10, 9.4, 3.8, 1.5, 3.1, 7.3, 7.9]),
#                                           'y': np.array([ 10.0, 3.9, 7.0, 5.2, 0.8, 7.0, 0.0,5.5, 2.4, 7.2])},
#                               node_size=0.7,
#                               link_colorbar_label='cross-MCI',
#                               node_colorbar_label='auto-MCI',
#                               node_label_size = 13,
#                               link_label_fontsize = 13,
#                               arrow_linewidth = 11,
#                       save_name = base_folder + "CG/" + today.strftime('%d%m%Y') + '/'+ masking + '/' + str(pc_element) + '/'+ key + '_pcalpha_'+str (pc_element)+"_" +str(len(variables))+'actors_'+masking+'.png'),
#                plt.show()

            if save_orig_dict:
                # Re-saved after every run, so partial results survive an interruption.
                # np.save pickles the dictionary; reload it with
                # np.load(path, allow_pickle=True).item()
                # The file name uses len(variables): `var_names` is not defined in the cells shown here.
                np.save(base_folder + "dictionaries/" + today.strftime('%d%m%Y') + '/'+ 'dict_causal_graphs_'+ str(len(variables))+'actors.npy', dict_networks)
