# SingleNNs Diagnostics

This script runs a number of diagnostics that evaluate the offline performance (i.e., on test data) of both SingleNNs and CausalSingleNNs.

### Setup

In [None]:
import sys
from utils.setup import SetupDiagnostics

# Arguments for SetupDiagnostics. They are hard-coded so the notebook can run
# interactively; when running as a script, use the commented line instead.
# argv  = sys.argv[1:]
argv  = ["-c", "./nn_config/cfg_SingleNNs_Diagnostics.yml"]

# Build the diagnostics configuration from the `-c` config file (presumably
# parsed by SetupDiagnostics). Later cells read their settings from `setup`.
setup = SetupDiagnostics(argv)


Both the attributes and the methods of **setup** can be inspected with, for example:\
`dir(setup)`\
`setup.__dict__` (or `setup.__dict__.keys()` and `setup.__dict__.values()`)

### Load Neural Networks

In [None]:
from pathlib import Path
def get_path(setup, model_type, *, pc_alpha=None, threshold=None):
    """Return the directory holding a model, derived from its metadata.

    Layout::

        <nn_output_path>/<model_type>[/a<pc_alpha>-t<threshold>]/hl_<layers>-act_<activation>-e_<epochs>

    The ``a...-t...`` level is only inserted when ``model_type`` is
    ``"CausalSingleNN"``. ``<layers>`` is ``str(setup.hidden_layers)`` with
    ", " turned into "_" and square brackets removed (e.g. ``[32, 64]`` ->
    ``32_64``).
    """
    layers = str(setup.hidden_layers).replace(", ", "_")
    for bracket in ("[", "]"):
        layers = layers.replace(bracket, "")

    parts = [setup.nn_output_path, model_type]
    if model_type == "CausalSingleNN":
        parts.append(f"a{pc_alpha}-t{threshold}/")
    parts.append(f"hl_{layers}-act_{setup.activation}-e_{setup.epochs}/")
    return Path(*parts)

In [None]:
from utils.variable import Variable_Lev_Metadata
def get_filename(setup, output):
    """Return the file stem ``"<var position>_<level index>"`` of a saved model.

    ``<var position>`` is the index of ``output.var`` in
    ``setup.output_order``; variables without a level (``level_idx is None``)
    use level index 0.
    """
    var_position = setup.output_order.index(output.var)
    level = output.level_idx
    return f"{var_position}_{0 if level is None else level}"

In [None]:
def get_save_plot_folder(setup, model_type, output, *, pc_alpha=None, threshold=None):
    """Return the ``diagnostics`` sub-folder of a model's directory.

    ``output`` is accepted for call-site compatibility but is not used: all
    outputs of one model configuration share the same plot folder.
    """
    model_dir = get_path(setup, model_type, pc_alpha=pc_alpha, threshold=threshold)
    return model_dir / 'diagnostics'

In [None]:
from tensorflow.keras.models import load_model
def get_model(setup, output, model_type, *, pc_alpha=None, threshold=None):
    """Load one trained model and the indices of the inputs it uses.

    The model is read (uncompiled) from ``<folder>/<filename>_model.h5`` and
    the input mask from ``<folder>/<filename>_input_list.txt``. Here
    ``folder`` comes from ``get_path`` and ``filename`` from ``get_filename``.

    Returns
    -------
    (model, input_indices)
        ``input_indices`` are the 0-based line numbers of the input-list file
        whose integer value is non-zero.
    """
    model_dir = get_path(setup, model_type, pc_alpha=pc_alpha, threshold=threshold)
    stem = get_filename(setup, output)

    model_file = Path(model_dir, f"{stem}_model.h5")
    print(f"Load model: {model_file}")
    model = load_model(model_file, compile=False)

    # One 0/1 flag per line; keep the positions of the enabled inputs.
    with open(Path(model_dir, f"{stem}_input_list.txt")) as mask_file:
        input_indices = [position for position, flag in enumerate(mask_file) if int(flag)]

    return model, input_indices

In [None]:
def get_var_list(setup, target_vars):
    """Expand variables into the flat list of names used as model keys.

    2-D variables contribute their bare ``name``. 3-D variables contribute one
    ``"<name>-<level rounded to 2 decimals>"`` entry per level. The levels come
    from ``setup.parents_idx_levs`` for input variables (``type == 'in'``)
    and from ``setup.children_idx_levs`` otherwise; each entry is a
    ``(level, index)`` pair. Variables of any other dimensionality are
    skipped.
    """
    # There's enough info to build a Variable_Lev_Metadata list here;
    # however, that would be better done as part of a bigger reorganization.
    names = []
    for variable in target_vars:
        if variable.dimensions == 2:
            names.append(variable.name)
        elif variable.dimensions == 3:
            levels = (
                setup.parents_idx_levs if variable.type == 'in'
                else setup.children_idx_levs
            )
            names.extend(f"{variable.name}-{round(lev, 2)}" for lev, _ in levels)
    return names

In [None]:
from utils.variable import Variable_Lev_Metadata
import collections
def load_models(setup):
    """Load every NN model requested by ``setup``.

    Returns
    -------
    collections.defaultdict(dict)
        * ``models['SingleNN'][output]`` when ``setup.do_single_nn``.
        * ``models['CausalSingleNN'][pc_alpha][threshold][output]`` when
          ``setup.do_causal_single_nn``, for every value in
          ``setup.pc_alphas`` and ``setup.thresholds``.

        Each leaf is the ``(model, input_indices)`` tuple from ``get_model``.
        ``output`` is the ``Variable_Lev_Metadata`` parsed from the names in
        ``get_var_list(setup, setup.spcam_outputs)``.
    """
    models = collections.defaultdict(dict)
    output_names = get_var_list(setup, setup.spcam_outputs)

    if setup.do_single_nn:
        for name in output_names:
            target = Variable_Lev_Metadata.parse_var_name(name)
            models['SingleNN'][target] = get_model(
                setup, target, 'SingleNN', pc_alpha=None, threshold=None
            )

    if setup.do_causal_single_nn:
        for pc_alpha in setup.pc_alphas:
            by_threshold = {}
            models['CausalSingleNN'][pc_alpha] = by_threshold
            for threshold in setup.thresholds:
                by_output = {}
                by_threshold[threshold] = by_output
                for name in output_names:
                    target = Variable_Lev_Metadata.parse_var_name(name)
                    by_output[target] = get_model(
                        setup,
                        target,
                        'CausalSingleNN',
                        pc_alpha=pc_alpha,
                        threshold=threshold,
                    )

    return models

In [None]:
# Nested dict of (model, input_indices) tuples keyed by model type (see load_models).
models = load_models(setup=setup)

### Model Diagnostics

In [None]:
from utils.variable                        import Variable_Lev_Metadata
from neural_networks.data_generator        import build_valid_generator
from neural_networks.cbrain.utils          import load_pickle
import numpy                               as     np
import matplotlib.pyplot                   as     plt

# Matplotlib colormap used for each output variable's prediction/truth panels.
# tphystnd and phq use the diverging 'coolwarm' map (they are plotted with
# limits symmetric around zero); the remaining variables use sequential maps.
cThemes = dict(
    tphystnd='coolwarm',
    phq='coolwarm',
    fsns='Reds',
    flns='Reds',
    fsnt='Reds',
    flnt='Reds',
    prect='PuBu',
)

class ModelDiagnostics():
    """Offline diagnostics comparing NN predictions with SPCAM truth.

    Parameters
    ----------
    setup : SetupDiagnostics
        Configuration object. This class reads ``spcam_inputs`` and
        ``input_order_list`` and passes ``setup`` on to
        ``build_valid_generator``.
    models : dict
        Mapping from output variable (``Variable_Lev_Metadata``) to the
        ``(model, input_indices)`` tuple returned by ``get_model``.
    nlat, nlon, nlev : int
        Grid sizes. They are used to reshape flat columns into maps and to
        size latitude-level cross-sections.
    ntime : int
        Stored as ``self.ntime``; not used by any method of this class.
    """

    def __init__(self, setup, models, nlat=64, nlon=128, nlev=30, ntime=48):
        self.nlat, self.nlon, self.nlev = nlat, nlon, nlev
        self.ntime                      = ntime
        self.ngeo                       = nlat * nlon
        self.setup                      = setup
        self.models                     = models

    def reshape_ngeo(self, x):
        """Reshape a ``(ngeo, ...)`` array to ``(nlat, nlon, -1)``."""
        return x.reshape(self.nlat, self.nlon, -1)

    def get_output_var_idx(self, var):
        """Return the positions of ``var`` among the generator's output variables."""
        var_idxs = self.valid_gen.norm_ds.var_names[self.valid_gen.output_idxs]
        var_idxs = np.where(var_idxs == var)[0]
        return var_idxs

    def get_truth_pred(self, itime, var, nTime=False):
        """Return truth and prediction for output ``var`` on the lat-lon grid.

        Parameters
        ----------
        itime : int or 'mean'
            Index of the validation-generator batch to evaluate, or ``'mean'``
            to average over the first ``nTime`` batches.
        var : Variable_Lev_Metadata
            Output variable; also the key into ``self.models``.
        nTime : int or falsy
            Number of batches averaged when ``itime == 'mean'``. A falsy value
            means "all batches of the generator".

        Returns
        -------
        (truth, pred) : tuple of np.ndarray
            Both after ``output_transform.inverse_transform``, reshaped to
            ``(nlat, nlon, -1)``.

        Raises
        ------
        ValueError
            If ``itime`` is neither an integer nor ``'mean'``.
        """
        # Validate first: otherwise an invalid value only surfaces later as an
        # UnboundLocalError on `truth`/`pred`, after the generator was built.
        is_index = isinstance(itime, (int, np.integer))
        if not (is_index or itime == 'mean'):
            raise ValueError(f"itime must be an int or 'mean', got {itime!r}")

        input_list  = get_var_list(self.setup, self.setup.spcam_inputs)
        self.inputs = sorted(
            [Variable_Lev_Metadata.parse_var_name(p) for p in input_list],
            key=lambda x: self.setup.input_order_list.index(x),
        )
        self.input_vars_dict  = ModelDiagnostics._build_vars_dict(self.inputs)

        self.output = Variable_Lev_Metadata.parse_var_name(var)
        self.output_vars_dict = ModelDiagnostics._build_vars_dict([self.output])

        self.valid_gen = build_valid_generator(
            self.input_vars_dict,
            self.output_vars_dict,
            self.setup
        )
        with self.valid_gen as valid_gen:

            model, inputs = self.models[var]

            if is_index:
                X, truth = valid_gen[itime]
                pred = model.predict_on_batch(X[:, inputs])
            else:  # itime == 'mean'
                if not nTime:
                    nTime = len(self.valid_gen)
                # Requires a single output column per grid point: (ngeo, 1).
                truth = np.zeros([nTime, self.ngeo, 1])
                pred  = np.zeros([nTime, self.ngeo, 1])
                for iTime in range(nTime):
                    X_tmp, truth[iTime, :] = valid_gen[iTime]
                    pred[iTime, :] = model.predict_on_batch(X_tmp[:, inputs])
                truth = np.mean(truth, axis=0)
                pred  = np.mean(pred, axis=0)

            # Inverse transform
            truth = valid_gen.output_transform.inverse_transform(truth)
            pred  = valid_gen.output_transform.inverse_transform(pred)

            # Looked up while the generator is still open.
            # NOTE(review): the original did this after leaving the `with`;
            # confirm whether __exit__ closes norm_ds.
            var_idxs = self.get_output_var_idx(var.var.ds_name)

        truth = truth[:, var_idxs]
        pred  = pred[:, var_idxs]

        return self.reshape_ngeo(truth), self.reshape_ngeo(pred)

    @staticmethod
    def _drop_singleton_level(t, p):
        """Drop a trailing length-1 axis from both arrays (single-level maps)."""
        if t.shape[2] == 1 and p.shape[2] == 1:
            t = t.reshape(t.shape[:2])
            p = p.reshape(p.shape[:2])
        return t, p

    @staticmethod
    def _figure_path(save, name, itime):
        """Return the file path of a saved diagnostics figure."""
        return Path(save) / f"{name}_map_time-{itime}.png"

    # Plotting functions
    def plot_double_xy(
        self,
        itime,
        var,
        nTime=None,
        save=None,
        diff=None,
        **kwargs
    ):
        """Plot a lat-lon map of prediction vs. truth for output ``var``.

        ``itime`` and ``nTime`` are forwarded to ``get_truth_pred``. ``save``,
        ``diff`` and ``**kwargs`` are forwarded to ``plot_slices``. A saved
        file is named after ``var``, so different levels of the same variable
        do not overwrite each other.
        """
        varname = var.var.value

        t, p = self.get_truth_pred(itime, var, nTime=nTime)
        t, p = self._drop_singleton_level(t, p)

        return self.plot_slices(
            t, p, itime,
            varname=varname, save=save, diff=diff, save_name=var, **kwargs
        )

    def plot_double_yz(
        self,
        itime,
        ilon,
        var,
        varkeys,
        nTime=None,
        save=None,
        diff=None,
        **kwargs
    ):
        """Plot a level-latitude cross-section of prediction vs. truth.

        Parameters
        ----------
        itime : int or 'mean'
            Forwarded to ``get_truth_pred``.
        ilon : int or 'mean'
            Longitude index to slice, or ``'mean'`` for the zonal mean.
        var : Variable_Lev_Metadata
            Any level of the variable; only used for its name.
        varkeys : iterable of Variable_Lev_Metadata
            The levels to stack; each is placed at row ``level_idx``.

        Raises
        ------
        ValueError
            If ``ilon`` is neither an integer nor ``'mean'``.
        """
        if not (isinstance(ilon, (int, np.integer)) or ilon == 'mean'):
            raise ValueError(f"ilon must be an int or 'mean', got {ilon!r}")

        varname = var.var.value

        # Allocate array
        truth = np.zeros([self.nlev, self.nlat])
        pred  = np.zeros([self.nlev, self.nlat])
        for level_var in varkeys:
            # Equals the single entry _build_vars_dict would record for this
            # 3-D variable, without depending on how its ds_name is spelled.
            iLev = level_var.level_idx
            t, p = self.get_truth_pred(itime, level_var, nTime=nTime)
            t, p = self._drop_singleton_level(t, p)
            if ilon == 'mean':
                truth[iLev, :] = np.mean(t, axis=1)
                pred[iLev, :]  = np.mean(p, axis=1)
            else:
                truth[iLev, :] = t[:, ilon]
                pred[iLev, :]  = p[:, ilon]

        return self.plot_slices(
            truth, pred, itime,
            varname=varname, save=save, diff=diff, save_name=varname, **kwargs
        )

    def plot_slices(
        self,
        t,
        p,
        itime,
        title='',
        unit='',
        varname='',
        save=None,
        diff=None,
        save_name=None,
        **kwargs
    ):
        """Plot prediction, truth and optionally their difference side by side.

        Parameters
        ----------
        t, p : np.ndarray
            2-D truth and prediction fields, passed to ``imshow``.
        itime : int or str
            Only used in the saved file name.
        title, unit : str
            Figure title and colorbar label.
        varname : str
            Selects the colormap (``cThemes``, 'viridis' if unlisted) and the
            colour limits.
        save : str or Path, optional
            Directory to save the figure in; nothing is written when None.
        diff : bool, optional
            When truthy, add a third "Prediction - SPCAM" panel.
        save_name : optional
            Prefix of the saved file name; defaults to ``varname``.
        **kwargs
            Forwarded to ``Axes.imshow``.

        Returns
        -------
        (fig, axes)
        """
        # `diff` is treated as a boolean, so diff=False also hides the
        # difference panel (the old `diff == None` test did not).
        n_slices  = 3 if diff else 2
        fig, axes = plt.subplots(1, n_slices, figsize=(12, 5))

        vmin = np.min([np.min(p), np.min(t)])
        vmax = np.max([np.max(p), np.max(t)])
        if varname in ['tphystnd', 'phq']:
            # Zero-centred limits at half the largest magnitude.
            vlim = np.max([np.abs(vmin), np.abs(vmax)]) / 2.
            vmin, vmax = -vlim, vlim
        elif varname in ['fsns', 'fsnt', 'prect']:
            vmin = 0

        cmap = cThemes.get(varname, 'viridis')
        cmap_diff = 'coolwarm'

        # The difference gets its own zero-centred limits. Sharing the field
        # limits (e.g. vmin=0 for fsns/prect) would clip every negative error.
        field_diff = p - t
        dlim = np.max(np.abs(field_diff))

        panels = [
            (p, 'Prediction', cmap, vmin, vmax),
            (t, 'SPCAM', cmap, vmin, vmax),
            (field_diff, 'Prediction - SPCAM', cmap_diff, -dlim, dlim),
        ]
        # zip stops after n_slices panels.
        for ax, (field, label, panel_cmap, lo, hi) in zip(axes, panels):
            image = ax.imshow(field, vmin=lo, vmax=hi, cmap=panel_cmap, **kwargs)
            cb = fig.colorbar(image, ax=ax, orientation='horizontal')
            cb.set_label(unit)
            ax.set_title(label)

        fig.suptitle(title)
        if save is not None:
            Path(save).mkdir(parents=True, exist_ok=True)
            name = varname if save_name is None else save_name
            fig.savefig(self._figure_path(save, name, itime))
        return fig, axes

    @staticmethod
    def _build_vars_dict(list_variables):
        """ Convert the given list of Variable_Lev_Metadata into a
        dictionary to be used on the data generator.

        Parameters
        ----------
        list_variables : list(Variable_Lev_Metadata)
            List of variables to be converted to the dictionary format
            used by the data generator

        Returns
        -------
        vars_dict : dict{str : list(int)}
            Dictionary of the form {ds_name : list of levels}, where
            "ds_name" is the name of the variable as stored in the
            dataset, and "list of levels" a list containing the indices
            of the levels of that variable to use, or None for 2D
            variables.
        """
        vars_dict = dict()
        for variable in list_variables:
            ds_name = variable.var.ds_name  # Name used in the dataset
            if variable.var.dimensions == 2:
                vars_dict[ds_name] = None
            elif variable.var.dimensions == 3:
                levels = vars_dict.get(ds_name, list())
                levels.append(variable.level_idx)
                vars_dict[ds_name] = levels
        return vars_dict


### Cross-section plots

In [None]:
# Level-latitude cross-sections (itime=100, ilon=100) for every 3-D output
# variable of the SingleNN models; each variable is plotted once.
model_type = 'SingleNN'
md = ModelDiagnostics(setup = setup, models=models[model_type])
vars_to_ploted = []
for var in models[model_type].keys():
    if var.var.value not in vars_to_ploted and var.var.dimensions == 3:
        print(var.var.value)
        outPath = get_save_plot_folder(setup, model_type, var.var.value)
        # All levels of this variable. Compare names exactly: a substring test
        # against str(v) could also pick up other variables containing the name.
        var_keys = [v for v in models[model_type] if v.var.value == var.var.value]
        md.plot_double_yz(100, 100, var, var_keys, nTime=False, diff=True)
#        md.plot_double_yz('mean', 'mean', var, var_keys, nTime=30, diff=True)
        vars_to_ploted.append(var.var.value)
        plt.show()

In [None]:
# Level-latitude cross-sections (itime=100, ilon=100) for every 3-D output
# variable, for each (pc_alpha, threshold) combination of CausalSingleNN models.
model_type = 'CausalSingleNN'
for pc_alpha in models[model_type].keys():
    print(f"pc_alpha: {pc_alpha}")
    for threshold in models[model_type][pc_alpha].keys():
        print(f"threshold: {threshold}")
        causal_models = models[model_type][pc_alpha][threshold]
        md = ModelDiagnostics(setup = setup, models=causal_models)
        # Reset for every combination: each one is a different set of models,
        # so every variable has to be plotted again (a single list shared by
        # all combinations only ever plotted the first one).
        vars_to_ploted = []
        for var in causal_models.keys():
            if var.var.value not in vars_to_ploted and var.var.dimensions == 3:
                print(var.var.value)
                outPath = get_save_plot_folder(
                    setup, model_type, var.var.value,
                    pc_alpha=pc_alpha, threshold=threshold,
                )
                # Compare names exactly rather than by substring of str(v).
                var_keys = [v for v in causal_models if v.var.value == var.var.value]
                md.plot_double_yz(100, 100, var, var_keys, nTime=False, diff=True)
#                md.plot_double_yz('mean', 'mean', var, var_keys, nTime=30, diff=True)
                vars_to_ploted.append(var.var.value)
                plt.show()

### Map plots

In [None]:
# Lat-lon maps (validation batch 100) of prediction, truth and their
# difference for every SingleNN output.
model_type = 'SingleNN'
md = ModelDiagnostics(setup=setup, models=models[model_type])
for var in models[model_type]:
    print(var)
    outPath = get_save_plot_folder(setup, model_type, var)
    # Figures are only displayed. Add `save=outPath` to also write them to
    # disk, or use itime='mean' to plot the time average instead of batch 100.
    md.plot_double_xy(100, var, nTime=False, diff=True)
    plt.show()

In [None]:
# Lat-lon maps (validation batch 100) for every output of each
# (pc_alpha, threshold) combination of CausalSingleNN models.
model_type = 'CausalSingleNN'
for pc_alpha in models[model_type].keys():
    print(f"pc_alpha: {pc_alpha}")
    for threshold in models[model_type][pc_alpha].keys():
        print(f"threshold: {threshold}")
        md = ModelDiagnostics(setup = setup, models=models[model_type][pc_alpha][threshold])
        for var in models[model_type][pc_alpha][threshold].keys():
            print(f"variable: {var}\n")
            # Same options as the SingleNN maps, so the error panel is shown too.
            md.plot_double_xy(100, var, nTime=False, diff=True)
            plt.show()