# Performance of the hierarchy

In this notebook we analyze the performance of the different models in the hierarchy. We do so for different training dataset lengths and regularization coefficients.

In [None]:
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
# interactive figures (needs the ipympl package)
%matplotlib widget
# default font size for the figures of this notebook
matplotlib.rc('font', size=18)
# colors of the default matplotlib cycle; indexed below to give each model its own color
default_colors = plt.rcParams['axes.prop_cycle'].by_key()['color']

import xarray as xr
import pandas as pd

import sys
# make the `general_purpose` package importable
# (assumes the Climate-Learning repository is checked out in the parent folder -- TODO confirm)
sys.path.append('../Climate-Learning/')

import general_purpose.uplotlib as uplt
import general_purpose.tables as tbl

# folder where the figures are saved
HOME = './'

In [None]:
# models of the hierarchy; load_data reads one skill file per entry
hierarchy = ['GA', 'IINN', 'ScatNet', 'CNN']

# NOTE: `dataset` and `year_suffix` are not only passed to load_data: the cells below also read them
# as globals (figure title, y-limits, choice of the Pareto optimum)
dataset = 'Va' # validation
# dataset = 'Te' # test

year_suffix = '' # 800 years of training
# year_suffix = '80y-' # 80 years of training

def load_data(dataset, year_suffix):
    """Load the skill files of every model listed in `hierarchy`.

    For each model `mo` the file `{dataset}-{year_suffix}Skill_{mo}.nc` is read
    from the current working directory.

    Parameters
    ----------
    dataset : str
        File name prefix, 'Va' (validation) or 'Te' (test).
    year_suffix : str
        '' for the 800 year training set, '80y-' for the 80 year one.

    Returns
    -------
    skills : dict
        Model name -> dataset as stored in the file (still with the 'fold' dimension).
    skills_av : dict
        Model name -> `uplt.xr_avg(ds, 'fold')`, i.e. the skills averaged over the folds.
    """
    # load_dataset reads the whole file into memory and closes it right away, whereas open_dataset
    # keeps a file handle open for every dataset; since this function is called several times in
    # this notebook, the handles would otherwise pile up.
    skills = {mo: xr.load_dataset(f'{dataset}-{year_suffix}Skill_{mo}.nc') for mo in hierarchy}
    skills_av = {mo: uplt.xr_avg(ds, 'fold') for mo, ds in skills.items()}  # take the average over the 5 folds
    return skills, skills_av

skills, skills_av = load_data(dataset, year_suffix)

## Pareto plots

Pareto plots are useful when we want to optimize two things at once. In our case we want the skills to be as high as possible and the projection patterns as smooth as possible (low $H_2$).

This means finding an optimum of the regularization coefficient of GA and IINN.
In the Pareto plots, ScatNet and CNN will have constant skill, since they have neither a regularization coefficient nor a projection pattern.

For GA and IINN we can regularize the projection pattern by penalizing its L2 norm or its H2 norm. In the paper we use the second, which gives better results in principle and in practice, but here we show both.

In [None]:
# One figure per skill metric: skill versus H2 norm for the regularized models (GA, IINN),
# with the CNN and ScatNet skills drawn as horizontal bands.
linestyles = [None, 'dashed']  # one line style per regularization type

for panel, (metric_name, cnn_skill) in enumerate(skills_av['CNN'].data_vars.items()):
    fig_num = 13 + panel
    plt.close(fig_num)  # close a possibly existing figure with this number before recreating it
    fig, ax = plt.subplots(figsize=(8,6), num=fig_num)

    for reg_idx, reg in enumerate(skills_av['GA']['regularization']):
        # small multiplicative shift along x, different for each regularization type
        x_shift = 1 + 0.01*(2*panel + reg_idx)
        for color_idx, (model, model_skills) in enumerate(skills_av.items()):
            if model not in ('CNN', 'ScatNet'):
                skill_curve = model_skills[metric_name].sel(regularization=reg, drop=True)
                h2_curve = model_skills['h2'].sel(regularization=reg, drop=True)
                uplt.plot(h2_curve.values*x_shift, skill_curve.values,
                          linestyle=linestyles[reg_idx], color=default_colors[color_idx],
                          label=f'{model} ({reg.values})')

    ax.set_xscale('log')
    # freeze the current x-limits, so that the bands spanning them do not stretch the axis
    ax.set_xlim(*ax.get_xlim())
    scatnet_skill = skills_av['ScatNet'][metric_name]
    uplt.errorband(ax.get_xlim(), [cnn_skill.values.item()]*2, color='gray', label='CNN')
    uplt.errorband(ax.get_xlim(), [scatnet_skill.values.item()]*2, color='purple', label='ScatNet', band_alpha=0.2)

    y_range = (0.15, 0.35) if year_suffix == '' else (0, 0.5)
    ax.set_ylim(*y_range)

    # the legend is the same in every figure: show it only once
    if panel == 1:
        ax.legend()

    ax.set_xlabel(r'$H_2$')
    ax.set_ylabel(metric_name)
    ax.grid(axis='y')

    fig.tight_layout()

    # fig.savefig(f'{HOME}/pareto-{year_suffix}{dataset}-{metr_name}.pdf')

### Make it a single figure

In [None]:
def pareto_plot(skills_av, save=False, dataset=None, year_suffix=None):
    """Draw the Pareto plots of all skill metrics side by side in a single figure.

    Each panel shows, for one metric, the skill of GA and IINN against the H2 norm of
    their projection pattern (one curve per model and regularization type), together
    with the CNN and ScatNet skills as horizontal bands.

    Parameters
    ----------
    skills_av : dict
        Model name -> fold-averaged skills, as returned by `load_data`.
    save : bool, optional
        If True, save the figure to `{HOME}/pareto-{year_suffix}{dataset}-all.pdf`.
    dataset : str, optional
        'Va' or 'Te'; used for the title and the file name.
        If None, the notebook-level variable `dataset` is used.
    year_suffix : str, optional
        '' or '80y-'; used for the y-limits, the title and the file name.
        If None, the notebook-level variable `year_suffix` is used.
    """
    # Fall back to the notebook-level configuration, so that `pareto_plot(skills_av)` keeps working
    if dataset is None:
        dataset = globals()['dataset']
    if year_suffix is None:
        year_suffix = globals()['year_suffix']

    # NOTE: this changes the global font size, so figures created afterwards are affected as well
    matplotlib.rc('font', size=24)
    plt.close(13)

    # one panel per metric (3 in the current data), each 8x8 inches
    metrics = list(skills_av['CNN'].data_vars.items())
    n_panels = len(metrics)
    fig, axs = plt.subplots(1, n_panels, num=13, figsize=(8*n_panels, 8), squeeze=False)
    axs = axs[0]  # squeeze=False always returns a 2D array, also with a single panel

    # the legend is the same in every panel: show it only in the second one (or the only one)
    legend_panel = min(1, n_panels - 1)

    linestyles = [None, 'dashed']  # one line style per regularization type
    for j, (metr_name, metr) in enumerate(metrics):
        ax = axs[j]
        for i, reg in enumerate(skills_av['GA']['regularization']):
            for h, (tech, d) in enumerate(skills_av.items()):
                if tech in ['CNN', 'ScatNet']:
                    continue  # no regularization: drawn as horizontal bands below
                sel = d[metr_name].sel(regularization=reg, drop=True)
                h2 = d['h2'].sel(regularization=reg, drop=True)
                # small multiplicative shift along x, different for each regularization type
                uplt.plot(h2.values*(1 + 0.01*(2*j + i)),
                          sel.values, linestyle=linestyles[i], ax=ax, color=default_colors[h], label=f'{tech} ({reg.values})')

        ax.set_xscale('log')
        # freeze the current x-limits, so that the bands spanning them do not stretch the axis
        ax.set_xlim(*ax.get_xlim())
        uplt.errorband(ax.get_xlim(), [metr.values.item()]*2, ax=ax, color='gray', label='CNN')
        uplt.errorband(ax.get_xlim(), [skills_av['ScatNet'][metr_name].values.item()]*2, ax=ax, color='purple', label='ScatNet', band_alpha=0.2)

        if year_suffix == '':
            ax.set_ylim(0.15,0.35)
        else:
            ax.set_ylim(0,0.5)

        if j == legend_panel:
            ax.legend()

        ax.set_xlabel(r'$H_2$')
        ax.set_ylabel(metr_name)

        ax.grid(axis='y')

    title = ('Test' if dataset == 'Te' else 'Validation') + ' skills when training on ' + ('64' if year_suffix else '640') + ' years'
    fig.suptitle(title)

    fig.tight_layout()

    if save:
        fig.savefig(f'{HOME}/pareto-{year_suffix}{dataset}-all.pdf')


pareto_plot(skills_av, dataset=dataset, year_suffix=year_suffix)

### Save all figures at once

In [None]:
# pareto_plot reads the notebook-level `dataset` and `year_suffix`, so the loop has to rebind them.
# Remember the configured values and restore them at the end: otherwise the cells below
# (table at the Pareto optimum, fold-wise check) would silently work on '80y-' / 'Te'.
configured_dataset, configured_year_suffix = dataset, year_suffix
try:
    for year_suffix in ['', '80y-']:
        for dataset in ['Va', 'Te']:
            print(f'{year_suffix}{dataset}')
            # use a separate name, so that the `skills` and `skills_av` of the configured case are not overwritten
            skills_av_case = load_data(dataset, year_suffix)[1]
            pareto_plot(skills_av_case, save=True)
finally:
    dataset, year_suffix = configured_dataset, configured_year_suffix

### Table at the pareto optimum

The pareto optimum for GA and IINN is the value of the regularization coefficient $\epsilon$ that gives the highest validation skill while also having an $H_2$ norm as small as possible.

Since there is a broad plateau, we choose the best value by eye, rather than doing some precise optimization.

In [None]:
# Regularization coefficient chosen for each regularized model; it depends on the training set length
pareto_optimum = {'GA': 100, 'IINN': 0.1} if year_suffix == '' else {'GA': 1000, 'IINN': 1}

# Build a table with one row per model and one column per metric
metrics = list(skills_av['CNN'].data_vars.keys())
skill_rows = {}
for model, model_skills in skills_av.items():
    at_optimum = model_skills
    if model in pareto_optimum:
        # regularized models: keep only the skills at the chosen coefficient, with gradient regularization
        at_optimum = model_skills.sel(reg_c=pareto_optimum[model], regularization='gradient', drop=True)
    skill_rows[model] = {metric: at_optimum[metric].data.item() for metric in metrics}

df = pd.DataFrame(skill_rows).T
df

In [None]:
# The trailing semicolon hides the return value without assigning to `_`,
# which IPython uses for the last output and stops updating once it is overwritten
tbl.table(df, cmap=None);

In [None]:
# Same table as LaTeX source (presumably a string -- TODO confirm), printed so that it can be copied.
# A proper name is used instead of `_`, which IPython reserves for the last output.
tex_source = tbl.tex_table(df, cmap=None, xlabel='Metric', ylabel='Model', leading_indentation=2, close_left=False, close_top=False, use_midrule=False)
print(tex_source)

### Check skills fold-wise

In [None]:
# Skills of the IINN at its Pareto optimum (gradient regularization), before averaging over the folds
skills['IINN'].sel({'regularization': 'gradient', 'reg_c': pareto_optimum['IINN']})