In [None]:
# Plotting: interactive ipympl figures (`%matplotlib widget`) and a larger default font
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
%matplotlib widget
matplotlib.rc('font', size=18)
# colors of matplotlib's default property cycle
# NOTE(review): not used in the cells visible here
default_colors = plt.rcParams['axes.prop_cycle'].by_key()['color']

import xarray as xr

import sys
# make the `general_purpose` helpers of the Climate-Learning repository importable
# (the path is relative to the notebook's working directory)
sys.path.append('../../Climate-Learning/')

import general_purpose.uplotlib as uplt
import general_purpose.cartopy_plots as cplt
import general_purpose.utilities as ut  # NOTE(review): `ut` is not used in the cells visible here

# Directory the figures are written to via fig.savefig(f'{HOME}/...');
# the resulting double slash in those paths is harmless
HOME = '../'

In [None]:
# Longitude/latitude coordinates of the grid and their 2-D meshes for plotting
# (assuming 1-D `lon` and `lat`, LON and LAT have shape (len(lat), len(lon)))
lon, lat = [np.load(path) for path in ('../common/lon.npy', '../common/lat.npy')]
LON, LAT = np.meshgrid(lon, lat, indexing='xy')

In [None]:
# Shared options for cplt.mfp: one panel per field, a north-polar orthographic view for the
# first field and a PlateCarree view restricted to `extents[1]` for the second
mfp_kwargs = {
    'one_fig_layout': 120,
    'figsize': (10, 5),
    'projections': [
        cplt.ccrs.Orthographic(central_latitude=90),
        cplt.ccrs.PlateCarree(),
    ],
    'extents': [None, (-5, 10, 39, 55)],
    'titles': ['Geopotential height', 'Soil moisture'],
}
# one colormap per field, in the same order as `titles`
cmaps = ['RdBu_r', 'BrBG']

## Optimal input of the CNN

In [None]:
# Optimal-input results for the CNN; the cells below read its `output`,
# `optimal_input_mean` and `optimal_input_std` variables.
# load_dataset reads the file into memory and closes it, whereas open_dataset would keep
# the netCDF handle open for the whole kernel session.
ds = xr.load_dataset('optimal-input.nc')
ds

In [None]:
# Summary of the CNN outputs mu_CNN(S) on the optimal inputs, as computed by uplt.avg
uplt.avg(ds['output'].data)

In [None]:
# Heatwave amplitudes A (presumably of the test set, given the `_te` suffix); compared
# against `threshold` in the histograms below.
# load_dataarray reads the data into memory and closes the file instead of keeping it open.
A_te = xr.load_dataarray('../common/A_te.nc')
A_te

In [None]:
# Amplitude threshold a_5 (drawn as the red vertical line in the histograms);
# `.item()` turns the stored one-element array into a plain Python scalar
threshold = np.load('../common/threshold.npy').item()
threshold

In [None]:
# Upper tail of the heatwave amplitudes (A >= a_5) vs. the CNN prediction mu_CNN(S)
# on the optimal inputs
plt.close(1)
fig, ax = plt.subplots(num=1, figsize=(18,6))

tail = A_te.data[A_te.data >= threshold]

# 0.1 K bins on an integer-anchored grid spanning at least [floor(threshold), 16), widened
# to cover every plotted sample. With a fixed range, samples outside it were silently
# dropped and the densities renormalized over the in-range samples only.
# When all data lie in [floor(threshold), 15.9] the bins are exactly the original ones.
bin_width = 0.1
all_values = np.concatenate([np.ravel(tail), np.ravel(ds.output.data)])
lower, upper = np.floor(threshold), 16
if all_values.size:
    # min/max against a NaN leave the current bound untouched
    lower = min(lower, np.floor(np.nanmin(all_values)))
    upper = max(upper, np.nanmax(all_values) + bin_width)
bins = np.arange(lower, upper, bin_width)

ax.hist(tail, bins=bins, alpha=0.5, density=True, label='tail of heatwave data')
ax.hist(ds.output.data, bins=bins, density=True, alpha=0.5, label= r'$\mu_{CNN}(S)$')
ax.axvline(threshold, color='red', label=r'$a_5$')
ax.legend()

ax.set_xlabel('heatwave amplitude [K]')

fig.tight_layout()

fig.savefig(f'{HOME}/CNN-OI-hist.pdf')

### Composite of optimal inputs

In [None]:
# Mean of the optimal inputs, one colormap per field (figure 2)
cplt.mfp(LON, LAT, ds.optimal_input_mean.data, cmaps=cmaps, fig_num=2, **mfp_kwargs)
# Standard deviation of the optimal inputs on a grey scale starting at 0, no colorbar (figure 3)
cplt.mfp(LON, LAT, ds.optimal_input_std.data, cmaps='Greys', vmin=0, colorbar='disabled', fig_num=3, **mfp_kwargs)

### Make it a single figure

In [None]:
# 2x2 figure of the optimal input: top row = mean, bottom row = std;
# left column = first field (Geopotential height), right column = second field (Soil moisture)
kw = dict(
    mode='pcolormesh',
    greenwich=True,
    draw_gridlines=False,
    draw_labels=False,
)

plt.close(1)
fig = plt.figure(num=1, figsize=(10, 10))

# stack mean and std along the last axis; assuming 2 fields per array,
# panels 0-1 are the means and panels 2-3 the stds
projs = np.concatenate([ds.optimal_input_mean.data, ds.optimal_input_std.data], axis=-1)

for i in range(4):
    field = i % 2  # column index -> which projection/title/colormap to use
    ax = fig.add_subplot(2, 2, i + 1, projection=mfp_kwargs['projections'][field])
    panel = projs[..., i]

    if i >= 2:
        # std panels: plain grey scale anchored at zero
        _norm = None
        vmin = 0
        cmap = 'Greys'
        title = mfp_kwargs['titles'][field] + ' (std)'
    else:
        # mean panels: diverging scale symmetric around zero
        _mx = np.max(np.abs(panel))
        _norm = matplotlib.colors.TwoSlopeNorm(vcenter=0., vmin=-_mx, vmax=_mx)
        vmin = None
        cmap = cmaps[field]
        title = mfp_kwargs['titles'][field] + ' (mean)'

    cplt.geo_plotter(ax, LON, LAT, panel, cmap=cmap, norm=_norm, title=title, vmin=vmin, **kw)

    if field == 1:
        # restrict the second field's panels to its regional extent
        ax.set_extent(mfp_kwargs['extents'][1])

fig.tight_layout(w_pad=0)

fig.savefig(f'{HOME}/CNN-OI.pdf')

## Orthogonal optimal input

In [None]:
# Orthogonal optimal-input results; here `output` stores mu_CNN - mu_GA and
# `ga_output` stores mu_GA.
# load_dataset reads the file into memory and closes it instead of keeping the handle open.
dso = xr.load_dataset('orth-optimal-input.nc')
dso

In [None]:
# Summary (via uplt.avg) of the stored output, which here is mu_CNN - mu_GA
uplt.avg(dso['output'].data)

In [None]:
# Upper tail of the heatwave amplitudes (A >= a_5) vs. the CNN prediction on the orthogonal
# optimal inputs, both as mu_CNN(S) and as the residual mu_CNN(S) - mu_GA(S)
plt.close(1)
fig, ax = plt.subplots(num=1, figsize=(18,6))

tail = A_te.data[A_te.data >= threshold]
mu_cnn = dso.output.data + dso.ga_output.data  # `output` stores mu_CNN - mu_GA
mu_residual = dso.output.data

# 0.1 K bins on an integer-anchored grid spanning at least [floor(threshold), 16), widened
# to cover every plotted sample. With a fixed range, samples outside it (e.g. a residual
# below a_5) were silently dropped and the densities renormalized over the in-range
# samples only. When all data lie in [floor(threshold), 15.9] the bins are the original ones.
bin_width = 0.1
all_values = np.concatenate([np.ravel(v) for v in (tail, mu_cnn, mu_residual)])
lower, upper = np.floor(threshold), 16
if all_values.size:
    # min/max against a NaN leave the current bound untouched
    lower = min(lower, np.floor(np.nanmin(all_values)))
    upper = max(upper, np.nanmax(all_values) + bin_width)
bins = np.arange(lower, upper, bin_width)

ax.hist(tail, bins=bins, density=True, alpha=0.5, label='tail of heatwave data')
ax.hist(mu_cnn, bins=bins, density=True, alpha=0.5, label= r'$\mu_{CNN}(S)$')
ax.hist(mu_residual, bins=bins, density=True, alpha=0.5, label= r'$\mu_{CNN}(S) - \mu_{GA}(S)$')
ax.axvline(threshold, color='red', label=r'$a_5$')
ax.legend()

ax.set_xlabel('heatwave amplitude [K]')

fig.tight_layout()

fig.savefig(f'{HOME}/CNN-OI-orth-hist.pdf')

### Composite of orthogonal optimal inputs

In [None]:
# 2x2 figure of the orthogonal optimal input: top row = mean, bottom row = std;
# columns follow the field order of mfp_kwargs['titles']
kw = {
    'mode': 'pcolormesh',
    'greenwich': True,
    'draw_gridlines': False,
    'draw_labels': False,
}

plt.close(1)
fig = plt.figure(num=1, figsize=(10, 10))

# mean fields first, then std fields, stacked along the last axis
projs = np.concatenate([dso.optimal_input_mean.data, dso.optimal_input_std.data], axis=-1)

for i in range(4):
    row, field = divmod(i, 2)
    ax = fig.add_subplot(2, 2, i + 1, projection=mfp_kwargs['projections'][field])

    if row == 0:
        # mean: zero-centred diverging scale with symmetric limits
        _mx = np.max(np.abs(projs[..., i]))
        _norm = matplotlib.colors.TwoSlopeNorm(vcenter=0., vmin=-_mx, vmax=_mx)
        cmap, vmin = cmaps[field], None
        title = mfp_kwargs['titles'][field] + ' (mean)'
    else:
        # std: grey scale starting at zero, default normalization
        _norm, cmap, vmin = None, 'Greys', 0
        title = mfp_kwargs['titles'][field] + ' (std)'

    cplt.geo_plotter(ax, LON, LAT, projs[..., i], cmap=cmap, norm=_norm, title=title, vmin=vmin, **kw)

    if field:
        ax.set_extent(mfp_kwargs['extents'][1])

fig.tight_layout(w_pad=0)

# fig.savefig(f'{HOME}/CNN-OI-orth.pdf')

## Combine the optimal and orthogonal optimal inputs in a single figure

In [None]:
# 2x4 figure: top row = means, bottom row = stds; columns 0-1 show the optimal input,
# columns 2-3 the orthogonal optimal input (fields alternate within each pair)
kw = dict(
    mode='pcolormesh',
    greenwich=True,
    draw_gridlines=False,
    draw_labels=False,
)

plt.close(1)
fig = plt.figure(num=1, figsize=(20, 10))

# panel order (row-major): OI mean, orthogonal OI mean, OI std, orthogonal OI std
projs = np.concatenate(
    [
        ds.optimal_input_mean.data,
        dso.optimal_input_mean.data,
        ds.optimal_input_std.data,
        dso.optimal_input_std.data,
    ],
    axis=-1,
)

for i in range(8):
    field = i % 2  # selects projection, title and colormap
    ax = fig.add_subplot(2, 4, i + 1, projection=mfp_kwargs['projections'][field])
    panel = projs[..., i]

    is_mean = i < 4
    if not is_mean:
        _norm, vmin, cmap = None, 0, 'Greys'
        title = mfp_kwargs['titles'][field] + ' (std)'
    else:
        _mx = np.max(np.abs(panel))
        _norm = matplotlib.colors.TwoSlopeNorm(vcenter=0., vmin=-_mx, vmax=_mx)
        vmin, cmap = None, cmaps[field]
        title = mfp_kwargs['titles'][field] + ' (mean)'

    cplt.geo_plotter(ax, LON, LAT, panel, cmap=cmap, norm=_norm, title=title, vmin=vmin, **kw)

    if field == 1:
        ax.set_extent(mfp_kwargs['extents'][1])

# the run of blanks pushes the two headings over the left and right halves of the figure
fig.suptitle('Optimal input' + ' '*60 + 'Orthogonal optimal input')

fig.tight_layout(w_pad=0)

# fig.savefig(f'{HOME}/CNN-OI-all.pdf')

## Roughness and L2 norm of the real snapshots

In [None]:
# L2 norm (`l2`) and roughness (`roughness`) of the real snapshots.
# load_dataset reads the file into memory and closes it instead of keeping the handle open.
# NOTE(review): this rebinds `ds`, which the earlier cells use for 'optimal-input.nc';
# re-running those cells after this one will presumably fail on `ds.output` /
# `ds.optimal_input_*`. A distinct name would avoid the hidden-state trap.
ds = xr.load_dataset('../Te-l2-roughness.nc')
ds

In [None]:
# Distribution of the L2 norm of the real snapshots with reference statistics
l2s = ds.l2.data

plt.close(1)
fig, ax = plt.subplots(num=1, figsize=(9, 6))
ax.hist(l2s, bins=50, density=True)

q_low = np.quantile(l2s, 0.05)
q_high = np.quantile(l2s, 0.95)
# both percentile lines share the single '5% and 95%' legend entry
ax.axvline(q_low, color='red', linestyle='dashed', label='5% and 95%')
ax.axvline(q_high, color='red', linestyle='dashed')
ax.axvline(np.mean(l2s), color='black', linestyle='dashed', label='mean')
ax.axvline(np.median(l2s), color='lime', linestyle='dashed', label='median')

ax.set_xlabel('L2 norm')
ax.legend()

fig.tight_layout()

# fig.savefig(f'{HOME}/l2-hist.pdf')

In [None]:
# Distribution of the roughness (sqrt(H_2)) of the real snapshots with reference statistics
roughs = ds.roughness.data

plt.close(2)
fig, ax = plt.subplots(num=2, figsize=(9, 6))

ax.hist(roughs, bins=50, density=True)

# both percentile lines share the single '5% and 95%' legend entry
ax.axvline(np.quantile(roughs, 0.05), color='red', linestyle='dashed', label='5% and 95%')
ax.axvline(np.quantile(roughs, 0.95), color='red', linestyle='dashed')
ax.axvline(np.mean(roughs), color='black', linestyle='dashed', label='mean')
ax.axvline(np.median(roughs), color='lime', linestyle='dashed', label='median')

# raw string: in a plain literal '\s' is an invalid escape sequence
# (DeprecationWarning, and a SyntaxWarning from Python 3.12 on)
ax.set_xlabel(r'roughness ($\sqrt{H_2}$)')
ax.legend()

fig.tight_layout()

# fig.savefig(f'{HOME}/h2-hist.pdf')

### Make it a single figure

In [None]:
# Side-by-side distributions of the L2 norm and the roughness of the real snapshots,
# each annotated with its 5th/95th percentiles, mean and median
plt.close(1)
fig, axs = plt.subplots(1, 2, num=1, figsize=(18,6))

# raw string for the roughness label: in a plain literal '\s' is an invalid escape
# sequence (DeprecationWarning, and a SyntaxWarning from Python 3.12 on)
panels = [
    (axs[0], l2s, 'L2 norm'),
    (axs[1], roughs, r'roughness ($\sqrt{H_2}$)'),
]

# identical annotation for both panels, drawn left to right
for hist_ax, values, xlabel in panels:
    hist_ax.hist(values, bins=50, density=True)

    # both percentile lines share the single '5% and 95%' legend entry
    hist_ax.axvline(np.quantile(values, 0.05), color='red', linestyle='dashed', label='5% and 95%')
    hist_ax.axvline(np.quantile(values, 0.95), color='red', linestyle='dashed')
    hist_ax.axvline(np.mean(values), color='black', linestyle='dashed', label='mean')
    hist_ax.axvline(np.median(values), color='lime', linestyle='dashed', label='median')

    hist_ax.set_xlabel(xlabel)
    hist_ax.legend()

fig.tight_layout()

fig.savefig(f'{HOME}/l2-h2-hist.pdf')