# Committor functions

In this notebook we analyze the skill and the optimal projection patterns of the Gaussian approximation (GA) applied to the PlaSim dataset, and compare its skill with that of a convolutional neural network (CNN).

In [None]:
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
%matplotlib widget
# Global plotting defaults: larger fonts, and the default color cycle kept at hand.
plt.rcParams['font.size'] = 18
default_colors = [entry['color'] for entry in plt.rcParams['axes.prop_cycle']]


import pandas as pd
import xarray as xr
from scipy import sparse

from tqdm.notebook import tqdm

import sys
sys.path.append('../../../Climate-Learning/')

import general_purpose.utilities as ut
import general_purpose.cartopy_plots as cplt
import general_purpose.uplotlib as uplt
import general_purpose.tables as tbl

# log to stdout
import logging
# Use setLevel rather than assigning `.level` directly: setLevel also clears the
# loggers' cached isEnabledFor() results, so the new level is honoured everywhere.
logging.getLogger().setLevel(logging.INFO)
logging.getLogger().handlers = [logging.StreamHandler(sys.stdout)]
# NOTE(review): presumably the indentation string used by the project utilities' log output
ut.indentation_sep = '  '

# Project root relative to this notebook (used by the commented-out savefig paths).
HOME = '../../'

In [None]:
# Grid coordinates: 1-D longitude/latitude vectors expanded into 2-D grids
# (with the default 'xy' indexing rows follow lat and columns follow lon).
lon = np.load('../../lon.npy')
lat = np.load('../../lat.npy')
LON, LAT = np.meshgrid(lon, lat, indexing='xy')

## Figure 5

In [None]:
# Grid-point mask and the reshaper built from it; its inv_reshape maps flattened
# patterns back onto the grid for plotting.
mask = np.load('../mask.npy')
reshaper = ut.Reshaper(mask)
# Sparse matrix W, used below in the quadratic form proj @ W @ proj.
W = sparse.load_npz('W.npz')
# Quick look at the sizes involved.
(mask.shape, reshaper.surviving_coords, W.shape)

In [None]:
# Projection patterns M, selectable by regularization strength reg_c
# (file name: T=14, tau=0, 8000 years, fold 4).
Ms = xr.open_dataarray(
    'projection_patterns_T14_tau0_y8000_fold4.nc'
)
Ms

In [None]:
# GA skill restricted to the same years and fold as the projection patterns above.
ss = (
    xr.open_dataset('Skill_T14_tau0_percent5.nc')['skill_GA']
    .sel(years=Ms['years'], fold=Ms['fold'])
)
ss

In [None]:
# Projection patterns M for three regularization strengths epsilon (reg_c), titled
# "Too low", "Right amount of" and "Too high" regularization respectively.
sel = ss.sel(reg_c=[1e-7,1,1e2])

S = r'\mathcal{S}'

pretits = ['Too low', 'Right amount of', 'Too high']

for i,epsilon in enumerate(sel['reg_c'].data):
    proj = Ms.sel(reg_c=epsilon).data.squeeze()
    # pattern reshaped back onto the grid, drawn as three panels (titles below)
    fig = cplt.mfp(LON,LAT,reshaper.inv_reshape(proj), one_fig_layout=130, figsize=(15,5), fig_num=8+i,
                   titles=['Temperature', 'Geopotential', 'Soil moisture'],
                  )[0].get_figure()

    # Round the exponent instead of letting '%d' truncate it toward zero:
    # a log10 that comes out as -6.9999999 would otherwise be labelled 10^{-6}.
    eps_exp = int(round(np.log10(epsilon)))
    eps = '10^{%d}' % eps_exp
    # quadratic form proj^T W proj
    h2 = proj @ W @ proj
    fig.suptitle(fr"{pretits[i]} regularization: $\epsilon = {eps}$, $\sqrt{{H_2}} = {np.sqrt(h2):.0f}$, ${S} = {sel.sel(reg_c=epsilon).data.item():.3f}${' '*25}")

    # fig.savefig(f'{HOME}/M_eps1e{eps_exp}.pdf')

## Table 1

In [None]:
# Skill of GA and CNN (variables skill_GA, skill_CNN) for Table 1.
ds = xr.open_dataset(
    'Skill_T14_tau0_percent5.nc'
)
ds

In [None]:
years = [8000, 4000, 2000, 1000, 500, 200]
reg_c = [1e-2, 1e-1, 1, 10, 100]

# Average over folds, then keep only the rows/columns shown in the table.
dsm = (
    ds
    .mean('fold')
    .sel(years=years, reg_c=reg_c)
)
dsm

In [None]:
# Column labels 10^k for each regularization strength epsilon. The exponent is
# rounded rather than truncated by '%d' (log10 of a power of ten may come out
# a hair off the integer).
eps = [r'$10^{%d}$' % int(round(np.log10(e))) for e in reg_c]
# NOTE(review): 0.9 presumably converts total years to training years (one fold held out) -- TODO confirm
yr = tbl.frmt(0.9*dsm['years'].data, 0)
xlabel=r'$\epsilon$'
ylabel='years of training'

# Table 1a: fold-averaged GA skill
tbl.table(dsm['skill_GA'].data.T, eps, yr,
          cmap=plt.cm.summer,
          xlabel=xlabel,
          ylabel=ylabel,
          title='Normalized log score',
          text_digits=2)

# Table 1b: relative difference between GA and CNN skill
vals = 1 - dsm['skill_GA']/dsm['skill_CNN']
tit = r'$1 - \mathcal{S}/\mathcal{S}_{CNN}$'

# symmetric diverging color range centered at 0, with 10% headroom
mx = float(np.nanmax(np.abs(vals.data)))*1.1
print(mx)
norm = matplotlib.colors.TwoSlopeNorm(0, -mx, mx)

# The trailing ';' suppresses the cell output. The previous `_ = None` idiom had a
# side effect: once `_` is assigned by hand, IPython stops updating `_`/`__`/`___`.
tbl.table(vals.data.T, eps, yr,
          norm=norm, cmap=plt.cm.BrBG_r,
          xlabel=xlabel,
          ylabel=ylabel,
          text_digits=2,
          title=tit);

### Table 1a (tex)

In [None]:
# LaTeX source of Table 1a (fold-averaged GA skill), printed for copying.
# The result is kept in a named variable rather than `_`: assigning `_` by hand
# stops IPython from updating its output-history variables `_`/`__`/`___`.
tex_table_1a = tbl.tex_table(dsm['skill_GA'].data.T, eps, yr,
                             cmap=plt.cm.summer,
                             xlabel=xlabel,
                             ylabel=ylabel,
                             title='Normalized log score',
                             text_digits=2,
                             close_left=False
                            )
print(tex_table_1a)

### Table 1b (tex)

In [None]:
# LaTeX source of Table 1b (1 - S/S_CNN). Row labels are blank and ylabel is None,
# presumably so it can be placed next to Table 1a.
# Named variable instead of `_`, which would stop IPython updating `_`/`__`/`___`.
tex_table_1b = tbl.tex_table(vals.data.T,
                             eps,
                             [' ']*len(yr),
                             norm=norm, cmap=plt.cm.BrBG_r,
                             xlabel=xlabel,
                             ylabel=None,
                             text_digits=2,
                             close_left=False,
                             title=tit)
print(tex_table_1b)

## Figure 6

In [None]:
# GA skill as a function of percent / threshold a (variables 'a', 'percent', 'skill').
ds_GA = xr.open_dataset(
    'Skill-GA_T14_tau0_y8000.nc'
)
ds_GA

In [None]:
# CNN skill as a function of percent (variables 'percent', 'skill').
ds_CNN = xr.open_dataset(
    'Skill-CNN_T14_tau0_y8000.nc'
)
ds_CNN

In [None]:
def percent2a(p, percent=None, a=None):
    """Convert percentage(s) ``p`` into the corresponding value(s) of ``a``.

    Linear interpolation in a table of (percent, a) pairs.

    Parameters
    ----------
    p : float or array_like
        Percentage(s) to convert.
    percent, a : array_like, optional
        The interpolation table. Default to ``ds_GA['percent'].data`` and
        ``ds_GA['a'].data.squeeze()`` from this notebook.

    Returns
    -------
    float or ndarray
        Interpolated ``a``. As with ``np.interp``, values of ``p`` outside the
        tabulated range are clamped to the end points.
    """
    if percent is None:
        percent = ds_GA['percent'].data
    if a is None:
        a = ds_GA['a'].data.squeeze()
    percent = np.asarray(percent, dtype=float)
    a = np.asarray(a, dtype=float)
    # np.interp does not check that xp is increasing and silently returns wrong
    # values otherwise, so sort the table by percent first.
    order = np.argsort(percent, kind='stable')
    return np.interp(p, percent[order], a[order])

In [None]:
# Figure 6: GA and CNN skill versus a (bottom axis, in K), with the matching
# percentage p on a secondary top axis obtained through percent2a.
plt.close(6)
fig, ax = plt.subplots(num=6, figsize=(9,6))

uplt.errorband(ds_GA['a'].data.squeeze(), uplt.xr_avg(ds_GA['skill'], 'fold').data.squeeze(), marker=None, label='GA')
# CNN skill is indexed by percent: map it onto the GA's a axis
uplt.errorband(percent2a(ds_CNN['percent'].data), uplt.xr_avg(ds_CNN['skill'], 'fold').data.squeeze(), label='CNN', color='gray')

ax.legend(loc='lower left')
ax.set_xlabel(r'$a$ [K]')
ax.set_ylabel('Normalized log score')

ax2 = ax.secondary_xaxis('top')
pticks = np.array([5,1,0.2])
ax2.set_xticks(percent2a(pticks))
# ':g' gives '5', '1', '0.2' instead of the float reprs '5.0', '1.0', '0.2'
ax2.set_xticklabels([f'{p:g}' for p in pticks])
ax2.set_xlabel(r'$p$ [%]')

fig.tight_layout()

# fig.savefig(f'{HOME}/Svpercent.pdf')

## Table 4

In [None]:
# GA and CNN skill over (T, tau, fold) for Table 4.
ds = xr.open_dataset(
    'Skill_percent5_y8000_epsilon1.nc'
)
ds

In [None]:
T = [1, 7, 14, 30]
tau = [0, 5, 10, 15, 20, 30]

# Restrict to the table's (T, tau) grid, then average over folds.
dsm = (
    ds
    .sel(T=T, tau=tau)
    .mean('fold')
)
dsm

In [None]:
xlabel = r'$\tau$ [days]'
ylabel = r'$T$ [days]'

# Table 4a: fold-averaged GA skill on the (T, tau) grid, color scale capped at 0.6
tbl.table(dsm['skill_GA'].data, dsm['tau'].data, dsm['T'].data,
          cmap=plt.cm.summer, vmax=0.6,
          xlabel=xlabel, ylabel=ylabel, title='Normalized log score',
         )

# Table 4b: relative difference between GA and CNN skill
vals = 1 - dsm['skill_GA']/dsm['skill_CNN']

# symmetric diverging color range centered at 0, with 10% headroom
mx = float(np.nanmax(np.abs(vals.data)))*1.1
print(mx)
norm = matplotlib.colors.TwoSlopeNorm(0, -mx, mx)

# The trailing ';' suppresses the cell output. The previous `_ = None` idiom had a
# side effect: once `_` is assigned by hand, IPython stops updating `_`/`__`/`___`.
tbl.table(vals.data, dsm['tau'].data, dsm['T'].data,
          norm=norm, cmap=plt.cm.BrBG_r,
          xlabel=xlabel, ylabel=ylabel, title=r'$1 - \mathcal{S}/\mathcal{S}_{CNN}$',
         );

### Table 4a (tex)

In [None]:
# LaTeX source of Table 4a (fold-averaged GA skill over T and tau), printed for copying.
# Named variable instead of `_`, which would stop IPython updating `_`/`__`/`___`.
tex_table_4a = tbl.tex_table(dsm['skill_GA'].data, dsm['tau'].data, dsm['T'].data,
                             cmap=plt.cm.summer, vmax=0.6,
                             xlabel=xlabel,
                             ylabel=ylabel,
                             title='norm log score',
                             close_left=False)
print(tex_table_4a)

### Table 4b (tex)

In [None]:
# LaTeX source of Table 4b (1 - S/S_CNN). Row labels are blank and ylabel is None,
# presumably so it can be placed next to Table 4a.
# Named variable instead of `_`, which would stop IPython updating `_`/`__`/`___`.
tex_table_4b = tbl.tex_table(vals.data, dsm['tau'].data, [' ']*len(dsm['T']),
                             norm=norm, cmap=plt.cm.BrBG_r,
                             xlabel=xlabel,
                             ylabel=None,
                             title=r'$1 - \mathcal{S}/\mathcal{S}_{CNN}$',
                             close_left=False
                            )
print(tex_table_4b)

## Table 6

In [None]:
# GA skill (8000 years, all predictor fields), indexed by percent; only the
# 'skill' variable is kept, so `ds` is a DataArray from here on.
ds = (
    xr.open_dataset('Skill-GA_T14_tau0_y8000.nc')
    ['skill']
)
ds

In [None]:
# GA skill with Z as the only predictor field, 8000 years ('-Z' file suffix).
ds_Z = xr.open_dataarray(
    'Skill-GA_T14_tau0_percent5_y8000_epsilon1-Z.nc'
)
ds_Z

In [None]:
# GA skill trained on 80 years of data, all predictor fields.
ds_80 = xr.open_dataarray(
    'Skill-GA_percent5_y80_epsilonbest.nc'
)
ds_80

In [None]:
# GA skill trained on 80 years of data with Z as the only predictor field.
ds_80Z = xr.open_dataarray(
    'Skill-GA_percent5_y80_epsilonbest-Z.nc'
)
ds_80Z

In [None]:
# Fold-averaged GA skill for the four cells of Table 6 (T=14, tau=0, 5%):
# (8000 y, all fields), (8000 y, Z), (80 y, all fields), (80 y, Z).
table6_inputs = (
    ds.sel(percent=5),
    ds_Z,
    ds_80.sel(T=14, tau=0),
    ds_80Z.sel(T=14, tau=0),
)
s, s_Z, s_80, s_80Z = (uplt.xr_avg(da, 'fold').data.item() for da in table6_inputs)

In [None]:
# Table 6 as LaTeX. The 'uL' format spec presumably comes from the `uncertainties`
# package (value with error in LaTeX notation) -- TODO confirm what xr_avg returns.
table6_cells = tuple(f'${u:uL}$' for u in (s, s_Z, s_80, s_80Z))

tb = r'''
\begin{tabular}{c|cc}
    & \multicolumn{2}{c}{Predictor fields} \\
    years of data &  $T_\mathrm{2m}, Z, S$ & $Z$ \\
    \hline
    8000 & %s & %s \\
    80 & %s & %s
\end{tabular}
''' % table6_cells

print(tb)