# Offline diagnostics for the CASTLE single-output networks following the Rasp et al. (2018) architecture

# Stats Computation

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys
import os

# 0 = all messages are logged (default behavior)
# 1 = INFO messages are not printed
# 2 = INFO and WARNING messages are not printed
# 3 = INFO, WARNING, and ERROR messages are not printed
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1' 

In [3]:
# Put the parent directory of the working directory on sys.path (once), so
# that modules located there can be imported from this notebook.
module_path = os.path.abspath(os.pardir)
if not (module_path in sys.path):
    sys.path.append(module_path)

In [4]:
import tensorflow as tf

print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))

Num GPUs Available:  1


In [5]:
from utils.setup import SetupDiagnostics
from neural_networks.load_models import load_models, get_save_plot_folder, load_single_model
from neural_networks.model_diagnostics import ModelDiagnostics
from pathlib import Path
from utils.variable import Variable_Lev_Metadata
import pickle

Using `tqdm.autonotebook.tqdm` in notebook mode. Use `tqdm.tqdm` instead to force console mode (e.g. in jupyter console)


## Load trained CASTLE models

In [6]:
# Active configuration: command-line style arguments handed to SetupDiagnostics
# and the directory into which the pickled statistics are written.
argv  = ["-c", "../output_castle/training_28_custom_mirrored_functional/cfg_castle_training_run_1.yml"]
plot_dir = Path("../output_castle/training_28_custom_mirrored_functional/plots_offline_evaluation/run_1/stats_test/")

# Alternative configurations; uncomment one argv/plot_dir pair to use it instead.
# argv  = ["-c", "../output_castle/eval_nando/single_nn/cfg_single_nn_diagnostics.yml"]
# plot_dir = Path("../output_castle/eval_nando/single_nn/plots_offline_evaluation/stats")

# argv  = ["-c", "../output_castle/eval_nando/causal_single_nn/cfg_causal_single_nn_diagnostics.yml"]
# plot_dir = Path("../output_castle/eval_nando/causal_single_nn/plots_offline_evaluation/stats")

# Create the output directory (and any missing parents); no error if it already exists.
Path(plot_dir).mkdir(parents=True, exist_ok=True)

In [7]:
setup = SetupDiagnostics(argv)

### Load only the model for one variable

In [8]:
var_name = "tphystnd-691.39"

In [15]:
# Load only the network for `var_name`. `single_nn = True` tells the cells
# below to use `models` directly instead of indexing it with `model_key`.
models = load_single_model(setup, var_name)
single_nn = True


Load model: /work/bd1179/b309247/pycharm_projects/iglesias-suarez2yxx_spuriouslinks/output_castle/training_28_custom_mirrored_functional/models_castle/run_1/castleNN/r1.0-a1.0-b1.0-l1.0-mirrored/hl_256_256_256_256_256_256_256_256_256-act_ReLU-e_18/1_20_model.keras


### Load all models

In [30]:
# Load all networks. This rebinds `models` and `single_nn`, so run either this
# cell or the single-variable cell, depending on what should be evaluated.
# With `single_nn = False` the cells below index `models` with `model_key`.
models = load_models(setup, skip_causal_phq=True)
single_nn = False


Load model: /work/bd1179/b309172/analysis/usmile/causality_convection/python/causalnncam/models_arch-rasp_thrs-opt-mse6/CausalSingleNN/a0.01-toptimized-latwts/hl_256_256_256_256_256_256_256_256_256-act_LeakyReLU-e_18/1_0_model.h5

Load model: /work/bd1179/b309172/analysis/usmile/causality_convection/python/causalnncam/models_arch-rasp_thrs-opt-mse6/CausalSingleNN/a0.01-toptimized-latwts/hl_256_256_256_256_256_256_256_256_256-act_LeakyReLU-e_18/1_1_model.h5

Load model: /work/bd1179/b309172/analysis/usmile/causality_convection/python/causalnncam/models_arch-rasp_thrs-opt-mse6/CausalSingleNN/a0.01-toptimized-latwts/hl_256_256_256_256_256_256_256_256_256-act_LeakyReLU-e_18/1_2_model.h5

Load model: /work/bd1179/b309172/analysis/usmile/causality_convection/python/causalnncam/models_arch-rasp_thrs-opt-mse6/CausalSingleNN/a0.01-toptimized-latwts/hl_256_256_256_256_256_256_256_256_256-act_LeakyReLU-e_18/1_3_model.h5

Load model: /work/bd1179/b309172/analysis/usmile/causality_convection/pytho

In [10]:
model_key = setup.nn_type

In [16]:
# Show which output variables have a loaded model.
# Note: the keys are variables, not strings.
if setup.nn_type == "CausalSingleNN":
    # Causal networks are nested by pc_alpha and threshold; use the first of each.
    print(models[model_key][setup.pc_alphas[0]][setup.thresholds[0]].keys())
elif single_nn:
    print(models.keys())
else:
    print(models[model_key].keys())

dict_keys(['tphystnd-691.39'])


## Compute Stats

In [17]:
# `model_type` is not an attribute that the setup code provides; it is attached
# here with the same value as `setup.nn_type` (held in `model_key`).
# NOTE(review): presumably read by ModelDiagnostics - confirm.
setup.model_type = model_key

In [18]:
# Build the diagnostics object from the model dictionary that matches the
# network type and the way the models were loaded.
if setup.nn_type == "CausalSingleNN":
    # Causal networks are nested by pc_alpha and threshold; use the first of each.
    md = ModelDiagnostics(
        setup=setup,
        models=models[model_key][setup.pc_alphas[0]][setup.thresholds[0]],
    )
elif single_nn:
    md = ModelDiagnostics(setup=setup, models=models)
else:
    md = ModelDiagnostics(setup=setup, models=models[model_key])

In [19]:
# Collect the keys (output variables) of the model dictionary in use;
# run_compute_stats() iterates over them.
if setup.nn_type == "CausalSingleNN":
    dict_keys = models[model_key][setup.pc_alphas[0]][setup.thresholds[0]].keys()
elif single_nn:
    dict_keys = models.keys()
else:
    dict_keys = models[model_key].keys()

### Single Variable

In [20]:
def save_single_stats():
    """Pickle the current ``md.stats`` into ``plot_dir``.

    Takes no arguments: the notebook-level variables ``i_time``, ``n_time``,
    ``var``, ``plot_dir`` and ``md`` are read when the function is called.
    The file is named ``{i_time}_{n_time}_{var}.p``.
    """
    out_path = Path(plot_dir) / f"{i_time}_{n_time}_{var}.p"
    with out_path.open("wb") as handle:
        pickle.dump(md.stats, handle)

In [21]:
var = Variable_Lev_Metadata.parse_var_name(var_name)

In [26]:
# Time selection passed to compute_stats: the mode and the number of time samples.
i_time = "range" # this has to be range for it to work
n_time = 1440

# Compute the statistics for the single variable `var`; the result is read
# back from `md.stats` and displayed as the cell output.
md.compute_stats(i_time, var, nTime=n_time)
md.stats


Computing horizontal stats for variable tphystnd-691.39

Test batch size = 8192.
Time samples: 1440


{'hor_tsqmean': 6.202147166636002e-09,
 'hor_tmean': 5.2422255662459223e-08,
 'hor_mse': 5.150198453787545e-09,
 'hor_tvar': 6.202144418543113e-09,
 'hor_r2': 0.1696100402967835}

In [26]:
save_single_stats()

In [27]:
def load_single_stats():
    """Read back the statistics pickled by ``save_single_stats``.

    Takes no arguments: the notebook-level variables ``i_time``, ``n_time``,
    ``var`` and ``plot_dir`` are read when the function is called, and the
    file ``{i_time}_{n_time}_{var}.p`` inside ``plot_dir`` is opened.

    Returns
    -------
    The unpickled object stored in that file.
    """
    in_path = Path(plot_dir) / f"{i_time}_{n_time}_{var}.p"
    # pickle can execute arbitrary code while loading: only read trusted files.
    with in_path.open("rb") as handle:
        return pickle.load(handle)

In [28]:
var_stats = load_single_stats()

In [29]:
var_stats

{'hor_tsqmean': 1.9285970748918967e-14,
 'hor_tmean': 3.8232942259022e-08,
 'hor_mse': 1.3778499503189315e-14,
 'hor_tvar': 1.7824212875137258e-14,
 'hor_r2': 0.226978515140562}

### All Variables

In [27]:
def get_save_str(idx_time, num_time=False):
    """Build the label used as sub-directory name for a time selection.

    Parameters
    ----------
    idx_time : int or str
        An integer (including int subclasses and other ``numbers.Integral``
        types such as NumPy integers) yields ``"step-<idx_time>"``.
        A string yields ``"<idx_time>-<num_time>"`` when ``num_time`` is
        truthy and ``"<idx_time>-all"`` otherwise.
    num_time : int or bool, optional
        Number of time samples. Any falsy value (the default ``False``,
        ``None`` or ``0``) selects the ``"-all"`` suffix. Ignored when
        ``idx_time`` is an integer.

    Returns
    -------
    str
        The label for the given time selection.

    Raises
    ------
    ValueError
        If ``idx_time`` is neither an integer nor a string. Booleans are
        rejected as well, although ``bool`` is a subclass of ``int``.
    """
    import numbers

    # bool is an int subclass; keep rejecting it explicitly.
    is_step_index = isinstance(idx_time, numbers.Integral) and not isinstance(idx_time, bool)

    if is_step_index:
        return f"step-{idx_time}"

    if isinstance(idx_time, str):
        suffix = num_time if num_time else "all"
        return f"{idx_time}-{suffix}"

    raise ValueError(f"Unknown value for idx_time: {idx_time}")
        

In [37]:
# Takes no parameters: uses variables that are set in Notebook cells!!
def run_compute_stats():
    """Compute the stats for every variable in ``dict_keys`` and pickle them.

    Reads the notebook-level variables ``plot_dir``, ``i_time``, ``n_time``,
    ``dict_keys`` and ``md``. For each variable ``md.compute_stats`` is
    called and the resulting ``md.stats`` is stored under ``str(var)``.
    The collected dictionary is written to ``hor_stats.p`` inside
    ``plot_dir/<get_save_str(i_time, num_time=n_time)>``, which is created
    if it does not exist.
    """
    save_dir = Path(plot_dir) / get_save_str(i_time, num_time=n_time)
    save_dir.mkdir(parents=True, exist_ok=True)

    per_var_stats = {}
    for var in dict_keys:
        print(f"\n\n---- Variable {var}")
        md.compute_stats(i_time, var, nTime=n_time)
        per_var_stats[str(var)] = md.stats

    with open(save_dir / "hor_stats.p", "wb") as out_file:
        pickle.dump(per_var_stats, out_file)
    

#### Time step range 

In [38]:
i_time = "range"
n_time = 1440

run_compute_stats()



---- Variable tphystnd-3.64

Computing horizontal stats for variable tphystnd-3.64

Test batch size = 8192.
Time samples: 1440


---- Variable tphystnd-7.59

Computing horizontal stats for variable tphystnd-7.59

Test batch size = 8192.
Time samples: 1440


---- Variable tphystnd-14.36

Computing horizontal stats for variable tphystnd-14.36

Test batch size = 8192.
Time samples: 1440


---- Variable tphystnd-24.61

Computing horizontal stats for variable tphystnd-24.61

Test batch size = 8192.
Time samples: 1440


---- Variable tphystnd-38.27

Computing horizontal stats for variable tphystnd-38.27

Test batch size = 8192.
Time samples: 1440


---- Variable tphystnd-54.6

Computing horizontal stats for variable tphystnd-54.6

Test batch size = 8192.
Time samples: 1440


---- Variable tphystnd-72.01

Computing horizontal stats for variable tphystnd-72.01

Test batch size = 8192.
Time samples: 1440


---- Variable tphystnd-87.82

Computing horizontal stats for variable tphystnd-87.82

T

divide by zero encountered in divide


Time samples: 1440


---- Variable flnt

Computing horizontal stats for variable flnt

Test batch size = 8192.


divide by zero encountered in divide


Time samples: 1440


---- Variable flns

Computing horizontal stats for variable flns

Test batch size = 8192.
Time samples: 1440


---- Variable prect

Computing horizontal stats for variable prect

Test batch size = 8192.
Time samples: 1440


### Load and display

In [29]:
stats_dir= Path("../output_castle/training_28_custom_mirrored_functional/plots_offline_evaluation/run_1/stats/")


In [30]:
i_time = "range"
n_time = 1440

In [31]:
# Load the per-variable statistics dictionary ("hor_stats.p") for the chosen
# time selection from `stats_dir`.
# NOTE(review): run_compute_stats() writes below `plot_dir` (.../stats_test/),
# whereas `stats_dir` points to .../stats/ - confirm the intended file is read.
dict_file = Path(stats_dir, get_save_str(i_time, num_time=n_time), "hor_stats.p")
# pickle can execute arbitrary code while loading: only read trusted files.
with open(dict_file, "rb") as f:
    stats_dict = pickle.load(f)

In [35]:
# Print every statistic of one variable; `var` is the variable parsed from
# `var_name` in the "Single Variable" section.
for s, v in stats_dict[str(var)].items():
    print(f"{s}: {v}")

hor_tsqmean: 6.202147166636002e-09
hor_tmean: 5.2422255662459223e-08
hor_mse: 5.150198453787545e-09
hor_tvar: 6.202144418543113e-09
hor_r2: 0.1696100402967835


In [40]:
# Print every statistic of every variable stored in the loaded dictionary.
for key, items in stats_dict.items():
    print(f"\nStats for variable {key}:")
    for s, v in items.items():
        print(f"{s}: {v}")


Stats for variable tphystnd-3.64:
hor_tsqmean: 2.310132993799386e-09
hor_tmean: 3.025768035469561e-06
hor_mse: 2.2490026345719144e-12
hor_tvar: 2.300977721594917e-09
hor_r2: 0.9990225882617355

Stats for variable tphystnd-7.59:
hor_tsqmean: 1.2246476527137705e-09
hor_tmean: 5.799085646043226e-07
hor_mse: 2.4385635412691245e-12
hor_tvar: 1.224311358770469e-09
hor_r2: 0.9980082161912489

Stats for variable tphystnd-14.36:
hor_tsqmean: 1.504981780182242e-09
hor_tmean: -1.072999353602802e-05
hor_mse: 4.3493885464898324e-10
hor_tvar: 1.3898490188990388e-09
hor_r2: 0.6870603578268396

Stats for variable tphystnd-24.61:
hor_tsqmean: 1.1085736701045974e-09
hor_tmean: 2.0490010612237182e-05
hor_mse: 2.1718089037521427e-10
hor_tvar: 6.887331352150051e-10
hor_r2: 0.6846661220860009

Stats for variable tphystnd-38.27:
hor_tsqmean: 9.87209322563402e-11
hor_tmean: -1.082172218870208e-06
hor_mse: 9.995964414050198e-13
hor_tvar: 9.754983554504574e-11
hor_r2: 0.9897529664112715

Stats for variable tph

In [None]:
# Drop the notebook's reference to the loaded models. They are bound to
# `models`; the name `castle_models` is never defined in this notebook.
del models

In [None]:
# Drop the notebook's reference to the diagnostics object. It is bound to
# `md`; the name `castle_md` is never defined in this notebook.
del md