# Optimal input of CNN

In this notebook we use the post-hoc explainability method known as _optimal input_ or _backward optimization_.

This method can be used as a local (i.e. data-point-wise) explainability method, though here we focus on aggregating the results to obtain dataset-level explanations.

In [None]:
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
%matplotlib widget
# Default font size for the figures; several plotting cells below override it
matplotlib.rc('font', size=18)
# Colors of the active matplotlib property cycle
default_colors = plt.rcParams['axes.prop_cycle'].by_key()['color']

import xarray as xr

import sys
# Extend the import path so that the `general_purpose` package can be found
# (presumably it lives in the Climate-Learning repository -- TODO confirm the relative location)
sys.path.append('../../Climate-Learning/')

import general_purpose.uplotlib as uplt
import general_purpose.cartopy_plots as cplt
import general_purpose.utilities as ut

# Directory prefix used by the `fig.savefig` calls in this notebook
HOME = '../'

In [None]:
# Longitude / latitude axes shared by every map in this notebook
lon, lat = np.load('../common/lon.npy'), np.load('../common/lat.npy')
# Coordinate grids passed to the cartopy plotting helpers ('xy' is numpy's default indexing)
LON, LAT = np.meshgrid(lon, lat, indexing='xy')

In [None]:
# Keyword arguments reused by every `cplt.mfp` call below (two-panel geo plots)

mfp_kwargs = {
    'one_fig_layout': 120,
    'figsize': (10, 5),
    # one projection per panel (alternative for the first panel: cplt.ccrs.Mollweide())
    'projections': [
        cplt.ccrs.Orthographic(central_latitude=90),
        cplt.ccrs.PlateCarree(),
    ],
    # one extent per panel; presumably None means no restriction for the first panel -- see `cplt.mfp`
    'extents': [None, (-5, 10, 39, 55)],
    'titles': ['Geopotential height', 'Soil moisture'],
}
# one colormap per panel, in the same order as `titles`
cmaps = ['RdBu_r', 'BrBG']

## Example of non-regularized optimization

In [None]:
# History of a non-regularized optimization run; the cells below read its
# 'predicted mean', 'predicted std', 'input' and 'step-coarse' variables
ds = xr.open_dataset('non-regularized-history.nc')
ds

In [None]:
# Pair each predicted mean with its std (presumably as values with uncertainty -- see `uplt.ufloatify`)
pred_mean = ds['predicted mean'].data
pred_std = ds['predicted std'].data
Y_pred = uplt.ufloatify(pred_mean, pred_std)

# Start from a fresh figure 1
plt.close(1)
fig, ax = plt.subplots(num=1, figsize=(9, 6))

# Model prediction as a function of the optimization step
steps = np.arange(Y_pred.shape[0])
uplt.errorband(steps, Y_pred)

ax.set_xlabel('Step')
ax.set_ylabel('Model prediction')

fig.tight_layout()

In [None]:
# Input fields saved along the optimization; the first axis is indexed together with 'step-coarse'
optim = ds['input'].data

for i in range(optim.shape[0]):
    # One figure per saved step (figure numbers 2, 3, ...).
    # The result is bound to a regular name rather than `_`, which IPython uses
    # for the output of the last executed cell.
    axes = cplt.mfp(LON, LAT, optim[i], **mfp_kwargs, cmaps=cmaps, fig_num=2+i)
    fig = axes[0].get_figure()
    step = ds['step-coarse'].data[i]

    # Raw f-string: '\h' is not a valid escape sequence, so writing the LaTeX
    # command '\hat' in a regular string literal triggers an invalid-escape warning.
    title = rf'Step {step}: $\hat{{A}} = {Y_pred[step]:uL}$ K'
    if i == 0:
        # only the first plotted step also reports the true value stored in the dataset attributes
        title += f'; ($A = {ds.attrs["A true of seed"]:.2f}$ K)'

    fig.suptitle(title)
    fig.tight_layout()
    # fig.savefig(f'{HOME}/non-reg-optim-{step}.pdf')

## Optimization over the full test dataset

The data is obtained by running the script `all_optimal_input.py`.

In [None]:
# Optimal-input results; the dataset has a `fold` dimension that is averaged out in the next cell
dss = xr.open_dataset('optimal-input.nc')
dss

In [None]:
# Average over the `fold` dimension; swap in the commented line to look at a single fold instead
# ds = dss.sel(fold=0, drop=True)
ds = dss.mean('fold')
ds

In [None]:
# Boolean mask of the entries where the std of the optimal input is strictly positive;
# the plotting cells use `reshaper` to set everything outside this mask to nan
physical_mask = ds.optimal_input_std.data > 0
reshaper = ut.Reshaper(physical_mask,fill_value=np.nan)
# Sanity check on the mask size (the expected count is presumably specific to this dataset -- TODO confirm)
assert reshaper.surviving_coords == 27424

In [None]:
# Summary of the model output over all optimal inputs (presumably mean with uncertainty -- see `uplt.avg`)
uplt.avg(ds.output.data)

In [None]:
# Heatwave amplitude data; its values above the threshold are compared with the model outputs in the histograms below
A_te = xr.open_dataarray('../common/A_te.nc')
A_te

In [None]:
# Scalar threshold selecting the tail of the heatwave amplitudes (drawn as $a_5$ in the histograms)
threshold = np.load('../common/threshold.npy').item()
threshold

In [None]:
matplotlib.rc('font', size=22)  # larger font for the wide figure
plt.close(1)
fig, ax = plt.subplots(num=1, figsize=(18, 6))

# Bins of width 0.1 from the integer part of the threshold up to 16
bins = np.arange(int(threshold), 16, 0.1)
hist_style = dict(bins=bins, density=True, alpha=0.5)

# Tail of the data (amplitudes at or above the threshold) versus the model output on the optimal inputs
amplitudes = A_te.data
ax.hist(amplitudes[amplitudes >= threshold], label='tail of heatwave data', **hist_style)
ax.hist(ds.output.data, label=r'$\mu_{CNN}(S)$', **hist_style)
ax.axvline(threshold, color='red', label=r'$a_5$')
ax.legend()

ax.set_xlabel('heatwave amplitude [K]')

fig.tight_layout()

# fig.savefig(f'{HOME}/CNN-OI-hist.pdf')

### Composite of optimal inputs

In [None]:
# Figure 2: mean of the optimal inputs
cplt.mfp(LON, LAT, ds.optimal_input_mean.data, **mfp_kwargs, cmaps=cmaps, fig_num=2)
# cplt.mfp(LON, LAT, ds.optimal_input_std.data, **mfp_kwargs, cmaps='Greys', fig_num=3, colorbar='disabled', vmin=0)

# Figure 4: standard deviation, colour scale centred on 1.
# The reshape / inv_reshape round trip sets to nan the values outside the
# physical mask (soil moisture outside France).
std_flat = reshaper.reshape(ds.optimal_input_std.data)
std_ = reshaper.inv_reshape(std_flat)
norm = matplotlib.colors.TwoSlopeNorm(vmin=0., vcenter=1, vmax=3)
cplt.mfp(LON, LAT, std_, fig_num=4, cmaps='RdGy_r', norm=norm, extend='max', **mfp_kwargs)

# Figure 5: kurtosis, colour scale centred on 0, masked in the same way
kurt_flat = reshaper.reshape(ds.optimal_input_kurtosis.data)
kurt_ = reshaper.inv_reshape(kurt_flat)
norm = matplotlib.colors.TwoSlopeNorm(vmin=-2, vcenter=0, vmax=2)
cplt.mfp(LON, LAT, kurt_, fig_num=5, cmaps='PuOr', norm=norm, extend='both', **mfp_kwargs)

#### Make it a single figure

In [None]:
matplotlib.rc('font', size=18)
# options forwarded unchanged to every `cplt.geo_plotter` call
kw = {
    'mode': 'pcolormesh',
    'greenwich': True,
    'draw_gridlines': False,
    'draw_labels': False,
}

plt.close(1)
fig = plt.figure(num=1, figsize=(10, 10))


# Stack the fields along the last axis: the means first, then the standard
# deviations with the values outside the physical mask set to nan
std_masked = reshaper.inv_reshape(reshaper.reshape(ds.optimal_input_std.data))
projs = np.concatenate([ds.optimal_input_mean.data, std_masked], axis=-1)

# 2x2 grid: panels 0-1 show the mean, panels 2-3 the std;
# the column picks projection, title and (for the mean) colormap
for i in range(4):
    col = i % 2
    field = projs[..., i]
    ax = fig.add_subplot(2, 2, i + 1, projection=mfp_kwargs['projections'][col])

    if i < 2:
        # symmetric colour scale around zero, spanning the largest absolute value
        _mx = np.max(np.abs(field))
        _norm = matplotlib.colors.TwoSlopeNorm(vmin=-_mx, vcenter=0., vmax=_mx)
        cmap, extend = cmaps[col], 'both'
        title = mfp_kwargs['titles'][col] + ' (mean)'
    else:
        # fixed colour scale centred on 1
        _norm = matplotlib.colors.TwoSlopeNorm(vmin=0, vcenter=1., vmax=3)
        cmap, extend = 'RdGy_r', 'max'
        title = mfp_kwargs['titles'][col] + ' (std)'

    cplt.geo_plotter(ax, LON, LAT, field, cmap=cmap, norm=_norm, title=title, extend=extend, **kw)

    if col:
        # restrict the second-column map to the extent configured in mfp_kwargs
        ax.set_extent(mfp_kwargs['extents'][1])


fig.tight_layout(w_pad=0)

# fig.savefig(f'{HOME}/CNN-OI.pdf')

## Orthogonal Optimal Input

Here we analyze what happens when we add a regularization term that forces the optimization to move orthogonally to the direction of the Gaussian Approximation projection pattern.

The data is obtained by running the script `all_optimal_input_orth.py`.

In [None]:
# Results of the orthogonal optimization; the cells below read its `output`, `ga_output`,
# `optimal_input_mean`, `optimal_input_std` and `optimal_input_kurtosis` variables
dsos = xr.open_dataset('orth-optimal-input.nc')
dsos

In [None]:
# Average over the `fold` dimension; swap in the commented line to look at a single fold instead
# dso = dsos.sel(fold=0, drop=True)
dso = dsos.mean('fold')

In [None]:
# Same summary as for the standard optimization, now on the orthogonal run
uplt.avg(dso.output.data) # here output is mu_CNN - mu_GA (the histogram below adds `ga_output` back to recover mu_CNN)

In [None]:
matplotlib.rc('font', size=22)  # larger font for the wide figure
plt.close(1)
fig, ax = plt.subplots(num=1, figsize=(18, 6))

# Bins of width 0.1 from the integer part of the threshold up to 16
bins = np.arange(int(threshold), 16, 0.1)
hist_style = dict(bins=bins, density=True, alpha=0.5)

# Tail of the data, then the CNN output (stored difference plus the GA output), then the difference itself
amplitudes = A_te.data
diff_output = dso.output.data
ax.hist(amplitudes[amplitudes >= threshold], label='tail of heatwave data', **hist_style)
ax.hist(diff_output + dso.ga_output.data, label=r'$\mu_{CNN}(S)$', **hist_style)
ax.hist(diff_output, label=r'$\mu_{CNN}(S) - \mu_{GA}(S)$', **hist_style)
ax.axvline(threshold, color='red', label=r'$a_5$')
ax.legend()

ax.set_xlabel('heatwave amplitude [K]')

fig.tight_layout()

# fig.savefig(f'{HOME}/CNN-OI-orth-hist.pdf')

### Composite

In [None]:
# Figure 2: mean of the orthogonal optimal inputs
cplt.mfp(LON, LAT, dso.optimal_input_mean.data, **mfp_kwargs, cmaps=cmaps, fig_num=2)
# cplt.mfp(LON, LAT, dso.optimal_input_std.data, **mfp_kwargs, cmaps='Greys', fig_num=3, colorbar='disabled', vmin=0)

# Figure 4: standard deviation, colour scale centred on 1.
# The reshape / inv_reshape round trip sets to nan the values outside the
# physical mask (soil moisture outside France).
std_flat = reshaper.reshape(dso.optimal_input_std.data)
std_ = reshaper.inv_reshape(std_flat)
norm = matplotlib.colors.TwoSlopeNorm(vmin=0., vcenter=1, vmax=3)
cplt.mfp(LON, LAT, std_, fig_num=4, cmaps='RdGy_r', norm=norm, extend='max', **mfp_kwargs)

# Figure 5: kurtosis, colour scale centred on 0, masked in the same way
kurt_flat = reshaper.reshape(dso.optimal_input_kurtosis.data)
kurt_ = reshaper.inv_reshape(kurt_flat)
norm = matplotlib.colors.TwoSlopeNorm(vmin=-2, vcenter=0, vmax=2)
cplt.mfp(LON, LAT, kurt_, fig_num=5, cmaps='PuOr', norm=norm, extend='both', **mfp_kwargs)

In [None]:
matplotlib.rc('font', size=18)
# options forwarded unchanged to every `cplt.geo_plotter` call
kw = {
    'mode': 'pcolormesh',
    'greenwich': True,
    'draw_gridlines': False,
    'draw_labels': False,
}

plt.close(1)
fig = plt.figure(num=1, figsize=(10, 10))


# Stack the fields along the last axis: the means first, then the standard
# deviations with the values outside the physical mask set to nan
std_masked = reshaper.inv_reshape(reshaper.reshape(dso.optimal_input_std.data))
projs = np.concatenate([dso.optimal_input_mean.data, std_masked], axis=-1)

# 2x2 grid: panels 0-1 show the mean, panels 2-3 the std;
# the column picks projection, title and (for the mean) colormap
for i in range(4):
    col = i % 2
    field = projs[..., i]
    ax = fig.add_subplot(2, 2, i + 1, projection=mfp_kwargs['projections'][col])

    if i < 2:
        # symmetric colour scale around zero, spanning the largest absolute value
        _mx = np.max(np.abs(field))
        _norm = matplotlib.colors.TwoSlopeNorm(vmin=-_mx, vcenter=0., vmax=_mx)
        cmap, extend = cmaps[col], 'both'
        title = mfp_kwargs['titles'][col] + ' (mean)'
    else:
        # fixed colour scale centred on 1
        _norm = matplotlib.colors.TwoSlopeNorm(vmin=0, vcenter=1., vmax=3)
        cmap, extend = 'RdGy_r', 'max'
        title = mfp_kwargs['titles'][col] + ' (std)'

    cplt.geo_plotter(ax, LON, LAT, field, cmap=cmap, norm=_norm, title=title, extend=extend, **kw)

    if col:
        # restrict the second-column map to the extent configured in mfp_kwargs
        ax.set_extent(mfp_kwargs['extents'][1])


fig.tight_layout(w_pad=0)

# fig.savefig(f'{HOME}/CNN-OI-orth.pdf')

### Make it a single figure with the two

(Combine standard and orthogonal optimization)

In [None]:
# options forwarded unchanged to every `cplt.geo_plotter` call
kw = {
    'mode': 'pcolormesh',
    'greenwich': True,
    'draw_gridlines': False,
    'draw_labels': False,
}

plt.close(1)
fig = plt.figure(num=1, figsize=(20, 10))


# Stack along the last axis, in panel order:
# means (standard, orthogonal), then masked stds (standard, orthogonal)
mean_fields = [d.optimal_input_mean.data for d in (ds, dso)]
std_fields = [reshaper.inv_reshape(reshaper.reshape(d.optimal_input_std.data)) for d in (ds, dso)]
projs = np.concatenate(mean_fields + std_fields, axis=-1)

# 2x4 grid: top row = means, bottom row = stds;
# the parity of the panel index picks projection, title and (for the mean) colormap
for i in range(8):
    col = i % 2
    field = projs[..., i]
    ax = fig.add_subplot(2, 4, i + 1, projection=mfp_kwargs['projections'][col])

    if i < 4:
        # symmetric colour scale around zero, spanning the largest absolute value
        _mx = np.max(np.abs(field))
        _norm = matplotlib.colors.TwoSlopeNorm(vmin=-_mx, vcenter=0., vmax=_mx)
        cmap, extend = cmaps[col], 'both'
        title = mfp_kwargs['titles'][col] + ' (mean)'
    else:
        # fixed colour scale centred on 1
        _norm = matplotlib.colors.TwoSlopeNorm(vmin=0, vcenter=1., vmax=3)
        cmap, extend = 'RdGy_r', 'max'
        title = mfp_kwargs['titles'][col] + ' (std)'

    cplt.geo_plotter(ax, LON, LAT, field, cmap=cmap, norm=_norm, title=title, extend=extend, **kw)

    if col:
        # restrict the odd-indexed maps to the extent configured in mfp_kwargs
        ax.set_extent(mfp_kwargs['extents'][1])


# The run of spaces pushes the two headings over the left and right halves of the figure
fig.suptitle('Optimal input' + ' '*60 + 'Orthogonal optimal input')

fig.tight_layout(w_pad=0)

fig.savefig(f'{HOME}/CNN-OI-all.pdf')

## Roughness and L2 norm of the real snapshots

In [None]:
# Dataset providing the `l2` and `roughness` variables plotted below.
# NOTE(review): this rebinds `ds`, which the earlier sections use for the fold-averaged
# optimal-input dataset; re-running those cells after this one would use the wrong data.
ds = xr.open_dataset('../Te-l2-roughness.nc')
ds

In [None]:
l2s = ds.l2.data

plt.close(1)
fig, ax = plt.subplots(num=1, figsize=(9, 6))

ax.hist(l2s, bins=50, density=True)

# Dashed vertical markers: 5% / 95% quantiles (one shared legend entry), mean and median
q05, q95 = np.quantile(l2s, [0.05, 0.95])
dashed = dict(linestyle='dashed')
ax.axvline(q05, color='red', label='5% and 95%', **dashed)
ax.axvline(q95, color='red', **dashed)
ax.axvline(np.mean(l2s), color='black', label='mean', **dashed)
ax.axvline(np.median(l2s), color='lime', label='median', **dashed)

ax.set_xlabel('L2 norm')
ax.legend()

fig.tight_layout()

# fig.savefig(f'{HOME}/l2-hist.pdf')

In [None]:
roughs = ds.roughness.data

# Start from a fresh figure 2
plt.close(2)
fig, ax = plt.subplots(num=2, figsize=(9, 6))

ax.hist(roughs, bins=50, density=True)

# Dashed vertical markers: 5% / 95% quantiles (one shared legend entry), mean and median
ax.axvline(np.quantile(roughs, 0.05), color='red', linestyle='dashed', label='5% and 95%')
ax.axvline(np.quantile(roughs, 0.95), color='red', linestyle='dashed')
ax.axvline(np.mean(roughs), color='black', linestyle='dashed', label='mean')
ax.axvline(np.median(roughs), color='lime', linestyle='dashed', label='median')

# Raw string: '\s' is not a valid escape sequence, so the LaTeX command '\sqrt'
# must not be written in a regular string literal (invalid-escape warning)
ax.set_xlabel(r'roughness ($\sqrt{H_2}$)')
ax.legend()

fig.tight_layout()

# fig.savefig(f'{HOME}/h2-hist.pdf')

### Make it a single figure

In [None]:
matplotlib.rc('font', size=22)  # larger font for the wide two-panel figure
plt.close(1)
fig, axs = plt.subplots(1, 2, num=1, figsize=(18, 6))

# One (data, x-label) pair per panel. The LaTeX label is a raw string because
# '\s' is not a valid escape sequence in a regular string literal.
panels = [
    (l2s, 'L2 norm'),
    (roughs, r'roughness ($\sqrt{H_2}$)'),
]

for ax, (values, xlabel) in zip(axs, panels):
    ax.hist(values, bins=50, density=True)

    # Dashed vertical markers: 5% / 95% quantiles (one shared legend entry), mean and median
    ax.axvline(np.quantile(values, 0.05), color='red', linestyle='dashed', label='5% and 95%')
    ax.axvline(np.quantile(values, 0.95), color='red', linestyle='dashed')
    ax.axvline(np.mean(values), color='black', linestyle='dashed', label='mean')
    ax.axvline(np.median(values), color='lime', linestyle='dashed', label='median')

    ax.set_xlabel(xlabel)
    ax.legend()

fig.tight_layout()

fig.savefig(f'{HOME}/l2-h2-hist.pdf')