### Offline validation tendency profiles

In [None]:
import os
import xarray as xr
import numpy as np
import fsspec
from matplotlib import pyplot as plt
import matplotlib
# Use a 14 pt default font for every figure produced by this notebook.
matplotlib.rc('font', size=14)

In [None]:
# Offline-diagnostics netCDF files for each experiment case. A case listing
# more than one file has those files merged into a single dataset on load.
CASES = {
    'tquvr_rf': [
        'gs://vcm-ml-experiments/2021-05-11-nudge-to-c3072-corrected-winds/rf/offline_diags/postphysics_ML_tendencies/offline_diagnostics.nc',
    ],
    'tqr_rf': [
        'gs://vcm-ml-experiments/2021-06-21-nudge-to-c3072-dq1-dq2-only/rf/offline_diags/postphysics_ML_tendencies/offline_diagnostics.nc',
    ],
    'tquvr_nn': [
        'gs://vcm-ml-experiments/2021-05-11-nudge-to-c3072-corrected-winds/nn-ensemble-model/offline_diags/dq1-dq2/offline_diagnostics.nc',
        'gs://vcm-ml-experiments/2021-05-11-nudge-to-c3072-corrected-winds/nn-ensemble-model/offline_diags/dqu-dqv/offline_diagnostics.nc',
    ],
}

# Directory (relative to the working directory) that figures are saved into.
OUTPUTDIR = 'figures'

In [None]:
# Variables retained from each case's diagnostics: the tendency fields...
KEEP_VARS = ['dQ1', 'dQ2', 'dQu', 'dQv']
# ...followed by their per-pressure-level MSE diagnostics: every
# '-predict_vs_target' entry first, then every '-mean_vs_target' entry.
KEEP_VARS = KEEP_VARS + [
    f'pressure_level-mse-{tendency}-{comparison}'
    for comparison in ('predict_vs_target', 'mean_vs_target')
    for tendency in KEEP_VARS
]


def open_diags(path):
    """Read the netCDF diagnostics file at ``path`` fully into memory.

    ``path`` is anything ``fsspec.open`` accepts (e.g. a ``gs://`` URL).
    ``xr.load_dataset`` loads all data and then closes the xarray backend, so
    no backend handle keeps referring to the fsspec stream after the ``with``
    block closes it.

    Returns:
        An in-memory ``xr.Dataset``.
    """
    with fsspec.open(path, 'rb') as f:
        return xr.load_dataset(f)

# Load every case, give all cases the same set of KEEP_VARS, and stack them
# along a new "cases" dimension.
datasets = []
for case, paths in CASES.items():
    print(case)
    if len(paths) != 1:
        # Multi-file case: combine the files; for a variable present in more
        # than one file, compat='override' keeps the first file's copy.
        ds = xr.merge([open_diags(path) for path in paths], compat='override')
    else:
        ds = open_diags(paths[0]).assign_coords({'cases': case})
    # Cases lacking dQu/dQv output get all-NaN stand-ins shaped like the
    # matching dQ1 variable, so that ds[KEEP_VARS] succeeds for every case.
    for missing_name, template_name in (
        ('dQu', 'dQ1'),
        ('dQv', 'dQ1'),
        ('pressure_level-mse-dQu-predict_vs_target', 'pressure_level-mse-dQ1-predict_vs_target'),
        ('pressure_level-mse-dQv-predict_vs_target', 'pressure_level-mse-dQ1-predict_vs_target'),
        ('pressure_level-mse-dQu-mean_vs_target', 'pressure_level-mse-dQ1-predict_vs_target'),
        ('pressure_level-mse-dQv-mean_vs_target', 'pressure_level-mse-dQ1-predict_vs_target'),
    ):
        if missing_name not in ds:
            ds[missing_name] = xr.full_like(ds[template_name], np.nan)
    datasets.append(ds[KEEP_VARS].assign_coords({'cases': case}))
diags = xr.concat(datasets, dim="cases")

In [None]:
# Tendency variable name -> units string, used as the x-axis label of that
# variable's mean-profile plot.
R2_VARS = {
    'dQ2': 'kg/kg/s',
    'dQ1': 'K/s',
    'dQu': 'm/s^2',
    'dQv': 'm/s^2'
}


def add_R2(ds, variables=None):
    """Add per-pressure-level R^2 variables to ``ds`` in place.

    For each tendency name ``var`` this sets::

        ds[f"pressure_level-R2-{var}"] = 1 - predict_mse / mean_mse

    where ``predict_mse`` is ``ds[f"pressure_level-mse-{var}-predict_vs_target"]``
    and ``mean_mse`` is ``ds[f"pressure_level-mse-{var}-mean_vs_target"]``.

    Args:
        ds: Dataset (or any mapping whose values support arithmetic) holding
            the MSE variables; it is modified in place.
        variables: Iterable of tendency names to process. Defaults to the
            keys of ``R2_VARS``.

    Returns:
        The same ``ds`` object, so calls can be chained/reassigned.

    Raises:
        KeyError: if a required MSE variable is missing from ``ds``.
    """
    if variables is None:
        variables = R2_VARS
    for var in variables:
        predict_mse = ds[f'pressure_level-mse-{var}-predict_vs_target']
        mean_mse = ds[f'pressure_level-mse-{var}-mean_vs_target']
        ds[f"pressure_level-R2-{var}"] = 1.0 - predict_mse / mean_mse
    return ds

# Attach the pressure_level-R2-* variables to the combined diagnostics.
diags = add_R2(ds=diags)

In [None]:
def plot_R2_cases(da):
    """Plot R^2 vertical profiles for every case and save them as a PNG.

    Args:
        da: R^2 DataArray named ``pressure_level-R2-<var>`` with a
            ``pressure`` coordinate and a ``cases`` dimension; one line is
            drawn per case. Only the panel whose name contains ``dQ2`` gets a
            legend.

    The figure is written to ``OUTPUTDIR/<da.name>-R2.png``; ``OUTPUTDIR`` is
    created if it does not exist yet.
    """
    # Legend labels keyed by case name, so the legend stays correct even if
    # the order of the cases changes.
    case_labels = {
        'tquvr_rf': '$TquvR$-RF',
        'tqr_rf': '$TqR$-RF',
        'tquvr_nn': '$TquvR$-NN',
    }
    # NOTE(review): assumes the 'pressure' coordinate is in Pa (surface
    # ~1e5 Pa); convert to hPa so the tick values match the axis label.
    da_hpa = da.assign_coords(pressure=da['pressure'] / 100.0)

    # Set size/dpi up front so tight_layout works on the final figure size.
    fig, ax = plt.subplots(1, 1, figsize=(6, 6), dpi=150)
    if 'dQ2' in da.name:
        lines = da_hpa.plot(y='pressure', hue='cases', yincrease=False, ax=ax, lw=2)
        labels = [case_labels.get(str(case), str(case)) for case in da_hpa['cases'].values]
        ax.legend(lines, labels, loc=1)
    else:
        da_hpa.plot(y='pressure', hue='cases', yincrease=False, ax=ax, lw=2, add_legend=False)
    ax.set_xlim([0, 0.5])
    ax.set_ylim([1000.0, 0])
    ax.set_ylabel('pressure [hPa]')
    ax.set_xlabel('$R^{2}$')
    ax.set_title(da.name.split('R2-')[1])
    ax.grid(axis='x')
    fig.tight_layout()
    os.makedirs(OUTPUTDIR, exist_ok=True)
    fig.savefig(os.path.join(OUTPUTDIR, f"{da.name}-R2.png"), bbox_inches='tight')

In [None]:
# One R^2-profile figure per tendency variable.
for r2_name in (f"pressure_level-R2-{v}" for v in R2_VARS):
    plot_R2_cases(diags[r2_name])

In [None]:
def plot_dQ_profile_cases(da, units):
    """Plot predicted vs. target tendency profiles by model level and save.

    Draws the 'predict' profile of each of the three cases plus one 'target'
    profile, taken from the 'tquvr_rf' case. The figure is saved as
    ``OUTPUTDIR/<da.name>-mean-profile.png``; ``OUTPUTDIR`` is created if
    missing.

    Args:
        da: DataArray with 'derivation' ('predict'/'target'), 'cases' and 'z'
            dimensions (the caller passes a single domain's slice).
        units: Text for the x-axis label, e.g. 'K/s'.
    """
    # Set size/dpi up front so tight_layout works on the final figure size.
    fig, ax = plt.subplots(1, 1, figsize=(6, 6), dpi=150)
    # Zero-tendency reference line spanning the full height of the axes.
    ax.axvline(0, color='k', linestyle='-')
    hrw, = da.sel(derivation='predict', cases='tquvr_rf').plot(y='z', yincrease=False, ax=ax, lw=2)
    hr, = da.sel(derivation='predict', cases='tqr_rf').plot(y='z', yincrease=False, ax=ax, lw=2)
    hn, = da.sel(derivation='predict', cases='tquvr_nn').plot(y='z', yincrease=False, ax=ax, lw=2)
    ht, = da.sel(derivation='target', cases='tquvr_rf').plot(y='z', yincrease=False, ax=ax, lw=2)
    ax.legend([ht, hrw, hr, hn], ['target', '$TquvR$-RF', '$TqR$-RF', '$TquvR$-NN'], loc=2)
    # Fit the y-range to the levels actually present rather than assuming
    # levels 1..79 (a 0-based 'z' would otherwise lose its first level);
    # the largest level index is drawn at the bottom.
    levels = da['z']
    ax.set_ylim([float(levels.max()), float(levels.min())])
    ax.set_ylabel('model level')
    ax.set_xlabel(units)
    ax.set_title(da.name)
    ax.grid(axis='x')
    fig.tight_layout()
    os.makedirs(OUTPUTDIR, exist_ok=True)
    fig.savefig(os.path.join(OUTPUTDIR, f"{da.name}-mean-profile.png"), bbox_inches='tight')

In [None]:
# One global-average mean-profile figure per tendency variable, labelled
# with that variable's units.
for tendency in R2_VARS:
    global_profile = diags[tendency].sel(domain='global_average')
    plot_dQ_profile_cases(global_profile, R2_VARS[tendency])