# Offline diagnostics for the CASTLE single-output networks following the Rasp et al. (2018) architecture

# Stats Computation

In [1]:
# Reload imported modules automatically before each cell executes, so edits to
# the project's Python modules take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import sys
import os

# TensorFlow C++ log level. Set it before TensorFlow is imported (cell below),
# otherwise it has no effect on the already-loaded library.
# 0 = all messages are logged (default behavior)
# 1 = INFO messages are not printed
# 2 = INFO and WARNING messages are not printed
# 3 = INFO, WARNING, and ERROR messages are not printed
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'

In [3]:
# Make the parent of the notebook directory importable so the project packages
# (`utils`, `neural_networks`) can be imported in the next cells.
module_path = os.path.abspath(os.pardir)
if module_path not in sys.path:
    sys.path.append(module_path)

In [4]:
import tensorflow as tf

# Sanity check of the environment: how many GPUs TensorFlow can see.
print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))

Num GPUs Available:  4


In [5]:
from utils.setup import SetupDiagnostics
from neural_networks.load_models import load_models, get_save_plot_folder, load_single_model
from neural_networks.model_diagnostics import ModelDiagnostics
from pathlib import Path
from utils.variable import Variable_Lev_Metadata
import pickle
import gc

IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html


## Load trained CASTLE models

In [6]:
# Paths of the training run to evaluate. `argv` mimics the command line
# expected by SetupDiagnostics; `plot_dir` receives the pickled statistics.
PROJECT_ROOT = Path.cwd().parent.resolve()
base_dir = os.path.join(
    PROJECT_ROOT,
    "output_castle",
    "training_40_gumbel_softmax_single_output_spars0.1",
)

cfg = os.path.join(base_dir, "cfg_gumbel_softmax_single_output.yml")
argv = ["-c", cfg]
plot_dir = Path(base_dir) / "plots_offline_evaluation" / "stats"

In [8]:
# Alternative experiments: uncomment ONE argv/plot_dir pair to evaluate it
# instead of the training run configured in the previous cell. This cell must
# run after that one so the assignments here take precedence.
# argv  = ["-c", "../output_castle/eval_nando/single_nn/cfg_single_nn_diagnostics.yml"]
# plot_dir = Path("../output_castle/eval_nando/single_nn/plots_offline_evaluation/stats")

# argv  = ["-c", "../output_castle/eval_nando/causal_single_nn/cfg_causal_single_nn_diagnostics.yml"]
# plot_dir = Path("../output_castle/eval_nando/causal_single_nn/plots_offline_evaluation/stats")


In [9]:
# Create the output directory for the pickled statistics (no-op if it exists).
Path(plot_dir).mkdir(parents=True, exist_ok=True)

In [10]:
# Build the diagnostics setup from the command-line-style `argv`
# (the config file is passed via "-c").
setup = SetupDiagnostics(argv)


Set leaky relu alpha to 0.3



### Load only the model for one variable

In [11]:
# Output variable whose model is loaded in the next cell.
var_name = "tphystnd-691.39"

In [15]:
# Load only the model for `var_name`. `single_nn` records the loading mode;
# later cells branch on it to find the per-variable models in `models`.
# NOTE: the "Load all models" cell assigns the same names; whichever of the two
# runs last determines `models` and `single_nn`.
# NOTE(review): the saved output below refers to a different training run
# (training_28_...) than `base_dir`; re-run to refresh it.
models = load_single_model(setup, var_name)
single_nn = True


Load model: /work/bd1179/b309247/pycharm_projects/iglesias-suarez2yxx_spuriouslinks/output_castle/training_28_custom_mirrored_functional/models_castle/run_1/castleNN/r1.0-a1.0-b1.0-l1.0-mirrored/hl_256_256_256_256_256_256_256_256_256-act_ReLU-e_18/1_20_model.keras


### Load all models

In [12]:
# Load the models for all output variables. `skip_causal_phq=True` is passed
# through to `load_models` (presumably skips some causal phq models -- verify
# in neural_networks.load_models).
models = load_models(setup, skip_causal_phq=True)
single_nn = False


Load model: /p/project/icon-a-ml/kuehbacher1/projects/iglesias-suarez2yxx_spuriouslinks/output_castle/training_40_gumbel_softmax_single_output_spars0.1/models/GumbelSoftmaxSingleOutputModel/lspar0.1/hl_256_256_256_256_256_256_256_256_256-act_LeakyReLU_0.3-e_50/1_0_model.keras


Lambda sparsity = 0.1


Load model: /p/project/icon-a-ml/kuehbacher1/projects/iglesias-suarez2yxx_spuriouslinks/output_castle/training_40_gumbel_softmax_single_output_spars0.1/models/GumbelSoftmaxSingleOutputModel/lspar0.1/hl_256_256_256_256_256_256_256_256_256-act_LeakyReLU_0.3-e_50/1_1_model.keras


Lambda sparsity = 0.1


Load model: /p/project/icon-a-ml/kuehbacher1/projects/iglesias-suarez2yxx_spuriouslinks/output_castle/training_40_gumbel_softmax_single_output_spars0.1/models/GumbelSoftmaxSingleOutputModel/lspar0.1/hl_256_256_256_256_256_256_256_256_256-act_LeakyReLU_0.3-e_50/1_2_model.keras


Lambda sparsity = 0.1


Load model: /p/project/icon-a-ml/kuehbacher1/projects/iglesias-suarez2yxx_spuriouslinks/ou

In [13]:
# Top-level key used below to index `models` for this network type.
model_key = setup.nn_type

In [14]:
# Note: keys are variables not strings
# List the output variables that have a model. Where the per-variable dict sits
# inside `models` depends on the network type and on how it was loaded.
if setup.nn_type == "CausalSingleNN":
    print(models[model_key][setup.pc_alphas[0]][setup.thresholds[0]].keys())
elif single_nn:
    print(models.keys())
else:
    print(models[model_key].keys())

dict_keys(['tphystnd-3.64', 'tphystnd-7.59', 'tphystnd-14.36', 'tphystnd-24.61', 'tphystnd-38.27', 'tphystnd-54.6', 'tphystnd-72.01', 'tphystnd-87.82', 'tphystnd-103.32', 'tphystnd-121.55', 'tphystnd-142.99', 'tphystnd-168.23', 'tphystnd-197.91', 'tphystnd-232.83', 'tphystnd-273.91', 'tphystnd-322.24', 'tphystnd-379.1', 'tphystnd-445.99', 'tphystnd-524.69', 'tphystnd-609.78', 'tphystnd-691.39', 'tphystnd-763.4', 'tphystnd-820.86', 'tphystnd-859.53', 'tphystnd-887.02', 'tphystnd-912.64', 'tphystnd-936.2', 'tphystnd-957.49', 'tphystnd-976.33', 'tphystnd-992.56', 'phq-3.64', 'phq-7.59', 'phq-14.36', 'phq-24.61', 'phq-38.27', 'phq-54.6', 'phq-72.01', 'phq-87.82', 'phq-103.32', 'phq-121.55', 'phq-142.99', 'phq-168.23', 'phq-197.91', 'phq-232.83', 'phq-273.91', 'phq-322.24', 'phq-379.1', 'phq-445.99', 'phq-524.69', 'phq-609.78', 'phq-691.39', 'phq-763.4', 'phq-820.86', 'phq-859.53', 'phq-887.02', 'phq-912.64', 'phq-936.2', 'phq-957.49', 'phq-976.33', 'phq-992.56', 'fsnt', 'fsns', 'flnt', 'fl

## Compute Stats

In [15]:
# SetupDiagnostics does not set `model_type`, but it is apparently read later
# (presumably by ModelDiagnostics -- verify); mirror `nn_type`, which holds
# the same value.
setup.model_type = model_key

In [16]:
# Wrap the per-variable models in a ModelDiagnostics object, selecting the
# right level of the nested `models` structure.
if setup.nn_type == "CausalSingleNN":
    md = ModelDiagnostics(setup=setup,
                          models=models[model_key][setup.pc_alphas[0]][setup.thresholds[0]])
elif single_nn:
    md = ModelDiagnostics(setup=setup, models=models)
else:
    md = ModelDiagnostics(setup=setup, models=models[model_key])

In [17]:
# Output variables iterated over in the "All Variables" section.
if setup.nn_type == "CausalSingleNN":
    dict_keys = models[model_key][setup.pc_alphas[0]][setup.thresholds[0]].keys()
elif single_nn:
    dict_keys = models.keys()
else:
    dict_keys = models[model_key].keys()

### Single Variable

In [18]:
def save_single_stats():
    """Pickle ``md.stats`` to ``<plot_dir>/<i_time>_<n_time>_<var>.p``.

    Takes no parameters: reads the notebook globals ``i_time``, ``n_time``,
    ``var``, ``plot_dir`` and ``md``.
    """
    out_path = os.path.join(plot_dir, "{}_{}_{}.p".format(i_time, n_time, var))
    with open(out_path, "wb") as handle:
        pickle.dump(md.stats, handle)

In [27]:
# Variable evaluated in this section. `var` is the parsed metadata object passed
# to `md.compute_stats`; its string form is also used in the pickle file name.
var_name = "tphystnd-820.86" # alternatives: "tphystnd-820.86", "tphystnd-691.39"
var = Variable_Lev_Metadata.parse_var_name(var_name)

In [28]:
# Time selection passed to compute_stats. The author notes only "range" works
# here; `n_time` is presumably the number of time samples (the log below
# reports "Time samples: 1440").
i_time = "range" # this has to be range for it to work
n_time = 1440

md.compute_stats(i_time, var, nTime=n_time)
md.stats


Computing horizontal stats for variable tphystnd-820.86

Test batch size = 8192.

Opening dataset /p/project/icon-a-ml/kuehbacher1/projects/iglesias-suarez2yxx_spuriouslinks/SPCAM_tb_preproc/2021_09_02_TEST_For_Nando.nc

opening as h5py /p/project/icon-a-ml/kuehbacher1/projects/iglesias-suarez2yxx_spuriouslinks/SPCAM_tb_preproc/2021_09_02_TEST_For_Nando.nc
Time samples: 1440


{'hor_tsqmean': 8.378718249991333e-09,
 'hor_tmean': -2.7195295818880438e-06,
 'hor_mse': 4.714877110976323e-09,
 'hor_tvar': 8.371322408844569e-09,
 'hor_r2': 0.4367822811369796}

In [29]:
# Persist the stats just computed for `var`.
save_single_stats()

In [30]:
def load_single_stats():
    """Load the stats pickled by ``save_single_stats`` for the current globals.

    Reads the notebook globals ``i_time``, ``n_time``, ``var`` and ``plot_dir``
    to rebuild the file name ``<i_time>_<n_time>_<var>.p``. Only use on files
    written by this notebook: unpickling can execute arbitrary code.
    """
    stats_file = os.path.join(plot_dir, "{}_{}_{}.p".format(i_time, n_time, var))
    with open(stats_file, "rb") as handle:
        return pickle.load(handle)

In [31]:
# Reload the pickled stats to check the save/load round trip.
var_stats = load_single_stats()

In [32]:
# Print each statistic of the reloaded variable, one per line.
for key, value in var_stats.items():
    print(f"{key}: {value}  ")

hor_tsqmean: 8.378718249991333e-09  
hor_tmean: -2.7195295818880438e-06  
hor_mse: 4.714877110976323e-09  
hor_tvar: 8.371322408844569e-09  
hor_r2: 0.4367822811369796  


**tphystnd-691.39**  
hor_tsqmean: 6.202147166636002e-09  
hor_tmean: 5.2422255662459223e-08  
hor_mse: 1.9602593852212597e-09  
hor_tvar: 6.202144418543113e-09  
hor_r2: 0.68393844887576

**tphystnd-820.86**  
hor_tsqmean: 8.378718249991333e-09  
hor_tmean: -2.7195295818880438e-06  
hor_mse: 4.714877110976323e-09  
hor_tvar: 8.371322408844569e-09  
hor_r2: 0.4367822811369796

### All Variables

In [10]:
def get_save_str(idx_time, num_time=False):
    """Build the sub-directory name used for one stats run.

    Args:
        idx_time: Either an ``int`` time-step index or a string time selector
            such as ``"range"``.
        num_time: Number of time steps; only used when ``idx_time`` is a
            string. A falsy value (the default ``False``) means "all".

    Returns:
        ``"step-<idx_time>"`` for an int, otherwise ``"<idx_time>-<num_time>"``
        or ``"<idx_time>-all"``.

    Raises:
        ValueError: If ``idx_time`` is neither an int nor a str.
    """
    # Exact type checks on purpose: bool (a subclass of int) is rejected.
    if type(idx_time) is int:
        return f"step-{idx_time}"
    if type(idx_time) is str:
        if num_time:
            return f"{idx_time}-{num_time}"
        return f"{idx_time}-all"
    raise ValueError(f"Unknown value for idx_time: {idx_time}")
        

In [20]:
# Not function parameters, uses variables that are set in Notebook cells!!
def run_compute_stats():
    """Compute horizontal stats for every variable in ``dict_keys`` and pickle them.

    Reads the notebook globals ``plot_dir``, ``i_time``, ``n_time``,
    ``dict_keys`` and ``md``. The result is a dict ``{str(var): stats}``
    written to ``<plot_dir>/<get_save_str(i_time, n_time)>/hor_stats.p``.

    If the loop is interrupted or fails (e.g. ``KeyboardInterrupt`` during a
    long run), the stats computed so far are written to
    ``hor_stats_partial.p`` in the same directory and the exception is
    re-raised. Finished work is kept, and an existing complete
    ``hor_stats.p`` is not overwritten by partial results.
    """
    import copy

    save_dir = Path(plot_dir, get_save_str(i_time, num_time=n_time))
    save_dir.mkdir(parents=True, exist_ok=True)

    stats_dict = dict()
    try:
        for var in dict_keys:
            print(f"\n\n---- Variable {var}")
            md.compute_stats(i_time, var, nTime=n_time)
            # Shallow copy so a later compute_stats call cannot alter an entry
            # already stored, in case md.stats is updated in place (not
            # visible from this notebook).
            stats_dict[str(var)] = copy.copy(md.stats)

            gc.collect()
    except BaseException:
        if stats_dict:
            with open(os.path.join(save_dir, "hor_stats_partial.p"), "wb") as f:
                pickle.dump(stats_dict, f)
        raise

    with open(os.path.join(save_dir, "hor_stats.p"), "wb") as f:
        pickle.dump(stats_dict, f)
    

#### Time step range 

In [21]:
# Compute and save stats for all variables in `dict_keys` using the "range"
# time selection (run_compute_stats reads i_time/n_time from these globals).
print(f"\nBase directory: {base_dir}")

i_time = "range"
n_time = 1440

run_compute_stats()


Base directory: /p/project/icon-a-ml/kuehbacher1/projects/iglesias-suarez2yxx_spuriouslinks/output_castle/training_40_gumbel_softmax_single_output_spars0.1


---- Variable tphystnd-3.64

Computing horizontal stats for variable tphystnd-3.64

Test batch size = 8192.

Opening dataset /p/project/icon-a-ml/kuehbacher1/projects/iglesias-suarez2yxx_spuriouslinks/SPCAM_tb_preproc/2021_09_02_TEST_For_Nando.nc

opening as h5py /p/project/icon-a-ml/kuehbacher1/projects/iglesias-suarez2yxx_spuriouslinks/SPCAM_tb_preproc/2021_09_02_TEST_For_Nando.nc
Time samples: 1440


---- Variable tphystnd-7.59

Computing horizontal stats for variable tphystnd-7.59

Test batch size = 8192.

Opening dataset /p/project/icon-a-ml/kuehbacher1/projects/iglesias-suarez2yxx_spuriouslinks/SPCAM_tb_preproc/2021_09_02_TEST_For_Nando.nc

opening as h5py /p/project/icon-a-ml/kuehbacher1/projects/iglesias-suarez2yxx_spuriouslinks/SPCAM_tb_preproc/2021_09_02_TEST_For_Nando.nc
Time samples: 1440


---- Variable tphystnd-14.3

divide by zero encountered in divide
divide by zero encountered in double_scalars




---- Variable phq-7.59

Computing horizontal stats for variable phq-7.59

Test batch size = 8192.

Opening dataset /p/project/icon-a-ml/kuehbacher1/projects/iglesias-suarez2yxx_spuriouslinks/SPCAM_tb_preproc/2021_09_02_TEST_For_Nando.nc

opening as h5py /p/project/icon-a-ml/kuehbacher1/projects/iglesias-suarez2yxx_spuriouslinks/SPCAM_tb_preproc/2021_09_02_TEST_For_Nando.nc
Time samples: 1440


---- Variable phq-14.36

Computing horizontal stats for variable phq-14.36

Test batch size = 8192.


divide by zero encountered in divide
divide by zero encountered in double_scalars



Opening dataset /p/project/icon-a-ml/kuehbacher1/projects/iglesias-suarez2yxx_spuriouslinks/SPCAM_tb_preproc/2021_09_02_TEST_For_Nando.nc

opening as h5py /p/project/icon-a-ml/kuehbacher1/projects/iglesias-suarez2yxx_spuriouslinks/SPCAM_tb_preproc/2021_09_02_TEST_For_Nando.nc
Time samples: 1440


---- Variable phq-24.61

Computing horizontal stats for variable phq-24.61

Test batch size = 8192.

Opening dataset /p/project/icon-a-ml/kuehbacher1/projects/iglesias-suarez2yxx_spuriouslinks/SPCAM_tb_preproc/2021_09_02_TEST_For_Nando.nc

opening as h5py /p/project/icon-a-ml/kuehbacher1/projects/iglesias-suarez2yxx_spuriouslinks/SPCAM_tb_preproc/2021_09_02_TEST_For_Nando.nc
Time samples: 1440


---- Variable phq-38.27

Computing horizontal stats for variable phq-38.27

Test batch size = 8192.

Opening dataset /p/project/icon-a-ml/kuehbacher1/projects/iglesias-suarez2yxx_spuriouslinks/SPCAM_tb_preproc/2021_09_02_TEST_For_Nando.nc

opening as h5py /p/project/icon-a-ml/kuehbacher1/projects/igle

KeyboardInterrupt: 

### Load and display

In [13]:
# Directory holding previously computed stats to display.
# NOTE(review): this points at eval_nando/causal_single_nn, not at the
# `plot_dir` of the run configured at the top; use the commented line to
# display the stats of that run instead.
stats_dir = os.path.join(PROJECT_ROOT, "output_castle", "eval_nando", "causal_single_nn", "plots_offline_evaluation", "stats") 
# stats_dir = plot_dir

In [14]:
# Must match the values used when the stats were computed: they select the
# sub-directory (via get_save_str) that the next cell reads from.
i_time = "range"
n_time = 1440

In [15]:
# Load the {variable name: stats} dict written by run_compute_stats.
# Only load pickles produced by this workflow: unpickling can execute code.
dict_file = Path(stats_dir) / get_save_str(i_time, num_time=n_time) / "hor_stats.p"
with dict_file.open("rb") as f:
    stats_dict = pickle.load(f)

In [17]:
# Display the resolved path of the loaded stats file.
dict_file

PosixPath('/p/project/icon-a-ml/kuehbacher1/projects/iglesias-suarez2yxx_spuriouslinks/output_castle/eval_nando/causal_single_nn/plots_offline_evaluation/stats/range-1440/hor_stats.p')

In [16]:
# Print every statistic for every variable in the loaded stats dict.
for key, items in stats_dict.items():
    print(f"\n\nStats for variable {key}:")
    for s, v in items.items():
        print(f"{s}: {v}")



Stats for variable tphystnd-3.64:
hor_tsqmean: 2.310132993799386e-09
hor_tmean: 3.025768035469561e-06
hor_mse: 2.2490026345719144e-12
hor_tvar: 2.300977721594917e-09
hor_r2: 0.9990225882617355


Stats for variable tphystnd-7.59:
hor_tsqmean: 1.2246476527137705e-09
hor_tmean: 5.799085646043226e-07
hor_mse: 2.4385635412691245e-12
hor_tvar: 1.224311358770469e-09
hor_r2: 0.9980082161912489


Stats for variable tphystnd-14.36:
hor_tsqmean: 1.504981780182242e-09
hor_tmean: -1.072999353602802e-05
hor_mse: 4.3493885464898324e-10
hor_tvar: 1.3898490188990388e-09
hor_r2: 0.6870603578268396


Stats for variable tphystnd-24.61:
hor_tsqmean: 1.1085736701045974e-09
hor_tmean: 2.0490010612237182e-05
hor_mse: 2.1718089037521427e-10
hor_tvar: 6.887331352150051e-10
hor_r2: 0.6846661220860009


Stats for variable tphystnd-38.27:
hor_tsqmean: 9.87209322563402e-11
hor_tmean: -1.082172218870208e-06
hor_mse: 9.995964414050198e-13
hor_tvar: 9.754983554504574e-11
hor_r2: 0.9897529664112715


Stats for variab

In [None]:
# Show the loaded stats of the single variable `var`. Requires the "Single
# Variable" section (which defines `var`) to have run, and `str(var)` to be a
# key of the loaded dict.
for s, v in stats_dict[str(var)].items():
    print(f"{s}: {v}")

In [34]:
# Drop the reference to the loaded models so their memory can be reclaimed.
del models

In [35]:
# Drop the diagnostics object too; it was constructed with the models and
# presumably still references them.
del md

In [37]:
# Run garbage collection now so the released objects are freed immediately.
gc.collect()

1622