### time-mean dQ2 target vs RF/NN-ens prediction maps 

In [None]:
import fv3viz
import xarray as xr
import fsspec
import intake
import numpy as np
from dataclasses import dataclass
from typing import Sequence, Mapping, Tuple
import os
import string
from vcm.catalog import catalog as CATALOG
from vcm.fv3.metadata import standardize_fv3_diagnostics
from matplotlib import pyplot as plt
import matplotlib
matplotlib.rcParams.update({'font.size': 8})
from cartopy import crs as ccrs
import cftime
from dask.diagnostics import ProgressBar

In [None]:
@dataclass
class OutputVariable:
    """Describe one diagnostic variable read from an offline-diagnostics file.

    ``file_name`` is the key used to look the variable up in the opened
    diagnostics dataset, ``long_name`` and ``units`` are attached to the
    loaded arrays as attributes, and ``ds_name`` optionally overrides the
    name under which the variable is stored after loading.
    """

    file_name: str
    long_name: str
    units: str
    ds_name: str = None

    @property
    def name(self):
        """Return ``ds_name`` when it is set (truthy), otherwise ``file_name``."""
        return self.ds_name or self.file_name

@dataclass
class MLOfflineDiags:
    """Offline diagnostics belonging to one named ML configuration."""
    # Label of the run; later used as the value of the ``derivation``
    # coordinate for this run's predictions.
    name: str
    # (root path, {model sub-directory under that root: variables to read
    # from that model's ``offline_diagnostics.nc``}).
    offline_diags: Tuple[str, Mapping[str, Sequence[OutputVariable]]]

In [None]:
# Diagnostic variables compared in this notebook: (name in the diagnostics
# file, human-readable label, units).
# The labels that embed LaTeX are raw strings so that ``\D`` is not parsed as
# an (invalid) string escape sequence; the resulting text is unchanged.
dQ1 = OutputVariable('column_integrated_dQ1', r'column heating ($\Delta Q_T$)', 'W/m^2')
dQ2 = OutputVariable('column_integrated_dQ2', r'column moistening ($\Delta Q_q$)', 'mm/day')
downward_shortwave = OutputVariable('override_for_time_adjusted_total_sky_downward_shortwave_flux_at_surface', 'downward shortwave', 'W/m^2')
downward_longwave = OutputVariable('override_for_time_adjusted_total_sky_downward_longwave_flux_at_surface', 'downward longwave', 'W/m^2')
net_shortwave = OutputVariable('override_for_time_adjusted_total_sky_net_shortwave_flux_at_surface', 'net shortwave', 'W/m^2')

# One entry per ML configuration: a display label plus the offline-diagnostics
# root and, for each model sub-directory under that root, the variables whose
# target/prediction fields are read from it.
RUNS = [
    MLOfflineDiags(
        name='$TqR$-RF',
        offline_diags=(
            'gs://vcm-ml-experiments/2021-06-21-nudge-to-c3072-dq1-dq2-only/rf/offline_diags',
            {
                'postphysics_ML_tendencies': [dQ1, dQ2],
                'prephysics_ML_surface_flux': [
                    downward_shortwave,
                    downward_longwave,
                    net_shortwave,
                ],
            },
        ),
    ),
    MLOfflineDiags(
        name='$TqR$-NN',
        offline_diags=(
            'gs://vcm-ml-experiments/2021-05-11-nudge-to-c3072-corrected-winds/nn-ensemble-model/offline_diags',
            {
                'dq1-dq2': [dQ1, dQ2],
                'surface-rad-rectified': [
                    downward_shortwave,
                    downward_longwave,
                    net_shortwave,
                ],
            },
        ),
    ),
]

# Number of seconds in one day (24 h * 60 min * 60 s).
SECONDS_PER_DAY = 24 * 60 * 60
# Directory that the figures produced by this notebook are saved into.
OUTDIR = 'figures'

In [None]:
def open_diags(path):
    """Open the dataset file at ``path`` through fsspec and load it eagerly.

    The load runs inside a dask ``ProgressBar`` context.

    Args:
        path: location handed to ``fsspec.open`` (opened in binary read mode).

    Returns:
        The result of ``xr.open_dataset(...).load()`` for that file.
    """
    with fsspec.open(path, 'rb') as remote_file:
        with ProgressBar():
            loaded = xr.open_dataset(remote_file).load()
    return loaded

def time_average(ds, time_slice):
    """Average ``ds`` over the ``time`` dimension within ``time_slice``.

    Args:
        ds: object supporting ``sel``/``mean`` along a ``time`` dimension.
        time_slice: selector passed to ``ds.sel(time=...)``.

    Returns:
        The time mean of the selection, after calling ``compute()`` on it.
    """
    selected = ds.sel(time=time_slice)
    averaged = selected.mean(dim='time')
    return averaged.compute()

# Build one prediction Dataset per run (predictions plus prediction-minus-target
# biases) and a single collection of target fields shared by all runs.
pred_datasets = []
target_ds = {}
for run in RUNS:
    offline_root, model_mapping = run.offline_diags
    run_pred_ds = {}
    for model, variables in model_mapping.items():
        path = os.path.join(offline_root, model, 'offline_diagnostics.nc')
        print(path)
        diags = open_diags(path)
        for variable in variables:
            attrs = {'long_name': variable.long_name, 'units': variable.units}
            # The target is stored only once: the first file that provides a
            # variable defines the target used for every run's bias.
            if variable.name not in target_ds:
                target = diags[variable.file_name].sel(derivation='target')
                target_ds[variable.name] = target.drop_vars('derivation').assign_attrs(attrs)
            predicted = (
                diags[variable.file_name]
                .sel(derivation='predict')
                .drop_vars('derivation')
            )
            bias = predicted - target_ds[variable.name]
            run_pred_ds[variable.name + '_bias'] = bias.assign_attrs(attrs)
            run_pred_ds[variable.name] = predicted.assign_attrs(attrs)
    # Runs that do not provide a radiative flux get NaN placeholders shaped
    # like their dQ1 field, so every run's Dataset has the same variables.
    for variable in [downward_shortwave, downward_longwave, net_shortwave]:
        if variable.name in run_pred_ds:
            continue
        for key in (variable.name, variable.name + '_bias'):
            run_pred_ds[key] = xr.full_like(run_pred_ds['column_integrated_dQ1'], np.nan)

    run_ds = xr.Dataset(run_pred_ds).expand_dims({'derivation': [run.name]})
    pred_datasets.append(run_ds)
pred_ds = xr.concat(pred_datasets, dim='derivation')
target_ds = xr.Dataset(target_ds).expand_dims({'derivation': ['target']})

In [None]:
# Combine predictions and targets into one Dataset, then fix the order along
# ``derivation``: target first, followed by the two ML configurations.
ds = xr.merge([pred_ds, target_ds])
ds = ds.reindex({'derivation': ['target', '$TqR$-RF', '$TqR$-NN']})

In [None]:
def weighted_mean(ds, weights):
    """Return the ``weights``-weighted mean of ``ds`` over ``x``, ``y`` and ``tile``.

    Computed as ``sum(ds * weights) / sum(weights)`` along those dimensions.
    """
    horizontal_dims = ['x', 'y', 'tile']
    weighted_total = (ds * weights).sum(dim=horizontal_dims)
    total_weight = weights.sum(dim=horizontal_dims)
    return weighted_total / total_weight

def weighted_rms(ds, weights):
    """Return the ``weights``-weighted root-mean-square of ``ds``.

    The weighted mean of ``ds**2`` is taken over ``x``, ``y`` and ``tile``
    and then passed through ``np.sqrt``.
    """
    horizontal_dims = ['x', 'y', 'tile']
    weighted_squares = ((ds ** 2) * weights).sum(dim=horizontal_dims)
    total_weight = weights.sum(dim=horizontal_dims)
    return np.sqrt(weighted_squares / total_weight)

def weighted_mse(ds, weights):
    """Return the ``weights``-weighted mean of ``ds**2`` over ``x``, ``y`` and ``tile``.

    When ``ds`` holds errors (e.g. a bias field) this is the weighted mean
    squared error.
    """
    horizontal_dims = ['x', 'y', 'tile']
    weighted_squares = ((ds ** 2) * weights).sum(dim=horizontal_dims)
    total_weight = weights.sum(dim=horizontal_dims)
    return weighted_squares / total_weight

def weighted_variance(ds, weights):
    """Return the ``weights``-weighted variance of ``ds`` over ``x``, ``y`` and ``tile``.

    The deviation is measured from ``weighted_mean(ds, weights)`` and its
    square is averaged with the same weights.
    """
    horizontal_dims = ['x', 'y', 'tile']
    deviation = ds - weighted_mean(ds, weights)
    weighted_squares = ((deviation ** 2) * weights).sum(dim=horizontal_dims)
    total_weight = weights.sum(dim=horizontal_dims)
    return weighted_squares / total_weight

In [None]:
# Grid dataset from the 'grid/c48' catalog entry. Its 'area' variable is used
# as the weights for the weighted statistics, and the grid is merged with the
# data before plotting.
grid = CATALOG['grid/c48'].to_dask()
# Keyword arguments forwarded to ``fv3viz.mappable_var``: the names of the
# cell-center / cell-interface dimensions and the dimensions of each
# coordinate variable.
MAPPABLE_VAR_KWARGS = dict(
    coord_x_center="x",
    coord_y_center="y",
    coord_x_outer="x_interface",
    coord_y_outer="y_interface",
    coord_vars=dict(
        lonb=["y_interface", "x_interface", "tile"],
        latb=["y_interface", "x_interface", "tile"],
        lon=["y", "x", "tile"],
        lat=["y", "x", "tile"],
    ),
)

In [None]:
# Figure 6: maps of column-integrated dQ2 for the target and each ML
# configuration, one panel per ``derivation`` value.
var = 'column_integrated_dQ2'
# Area-weighted statistics shown in the panel titles.
means = weighted_mean(ds[var], grid['area'])
rmses = weighted_rms(ds[var + '_bias'], grid['area'])
_, _, _, cbar, fg = fv3viz.plot_cube(
    fv3viz.mappable_var(xr.merge([grid, ds]), var, **MAPPABLE_VAR_KWARGS),
    col='derivation',
    col_wrap=3,
)
# Label panels a), b), c), ...; the target panel shows only the mean because
# an RMSE against itself is not meaningful.
for i, (ax, case, mean, rmse) in enumerate(zip(fg.axes.flatten(), ds.derivation.values, means.values, rmses.values)):
    if 'target' not in case:
        ax.set_title(f"{string.ascii_lowercase[i]}) {case}\nmean: {mean:.2f}, rmse: {rmse:.2f}")
    else:
        ax.set_title(f"{string.ascii_lowercase[i]}) {case}\nmean: {mean:.2f}")
cbar.set_ticks(np.linspace(-3, 3, 7))
cbar.set_label('column moistening\n' + r'$\langle \Delta Q_q \rangle$, [mm/day]')
fg.fig.set_size_inches([7.6, 1.5])
# savefig does not create missing directories, so make sure OUTDIR exists
# before writing into it.
os.makedirs(OUTDIR, exist_ok=True)
fg.fig.savefig(f'{OUTDIR}/Figure_6_column_integrated_dQ2_offline_time_mean.eps', bbox_inches='tight')

In [None]:
def add_total_downward(ds):
    """Add total downward surface radiation (shortwave + longwave) to ``ds``.

    Two variables are written into ``ds`` in place,
    ``downward_total_radiative_flux_at_surface`` and its ``_bias``
    counterpart, each the sum of the corresponding downward shortwave and
    longwave variables.

    Args:
        ds: dataset containing the downward shortwave/longwave flux
            variables and their ``_bias`` versions.

    Returns:
        The same ``ds`` object, with the two variables added.
    """
    shortwave = 'override_for_time_adjusted_total_sky_downward_shortwave_flux_at_surface'
    longwave = 'override_for_time_adjusted_total_sky_downward_longwave_flux_at_surface'
    total_name = 'downward_total_radiative_flux_at_surface'

    # Both sums are computed before anything is written back into ``ds``.
    total_flux = ds[shortwave] + ds[longwave]
    total_flux = total_flux.assign_attrs(
        {'long_name': 'total downwelling surface radiation ', 'units': '$W/m^2$'}
    )
    total_flux_bias = ds[shortwave + '_bias'] + ds[longwave + '_bias']
    total_flux_bias = total_flux_bias.assign_attrs(
        {'long_name': 'total downwelling surface radiation bias', 'units': '$W/m^2$'}
    )

    ds[total_name] = total_flux
    ds[total_name + '_bias'] = total_flux_bias
    return ds

# Attach the total downward surface radiation and its bias to the merged
# dataset (add_total_downward also modifies ``ds`` in place).
ds = add_total_downward(ds)

In [None]:
# Figure 9: maps of the bias in total downward surface radiation for the two
# ML configurations (the target is excluded from the selection).
var = 'downward_total_radiative_flux_at_surface_bias'
ds_pred = ds.sel({'derivation': ['$TqR$-RF', '$TqR$-NN']})
# Area-weighted mean and RMS of the bias, shown in the panel titles.
means = weighted_mean(ds_pred[var], grid['area'])
rmss = weighted_rms(ds_pred[var], grid['area'])

_, _, _, _, fg = fv3viz.plot_cube(
    fv3viz.mappable_var(xr.merge([grid, ds_pred]), var, **MAPPABLE_VAR_KWARGS),
    col = 'derivation'
)
for i, (ax, case, mean, rms) in enumerate(zip(fg.axes.flatten(), ds_pred.derivation.values, means.values, rmss.values)):
    ax.set_title(f"{string.ascii_lowercase[i]}) {case}\nmean:{mean:.1f}, rms: {rms:0.1f}")
fg.fig.set_size_inches([7.6, 2])
# savefig does not create missing directories, so make sure OUTDIR exists
# before writing into it.
os.makedirs(OUTDIR, exist_ok=True)
fg.fig.savefig(
    f'{OUTDIR}/Figure_9_downward_total_radiative_flux_at_surface_bias_offline_time_mean.eps',
    bbox_inches='tight'
)

In [None]:
# R^2-style skill score per ML configuration: 1 - (weighted MSE of the bias) /
# (weighted variance of the field).
# NOTE(review): the denominator is the variance of the *predicted* total
# downward flux (``ds_pred``), not of the target field - confirm this is the
# intended definition.
total_rad_r_squared = 1 - (
    weighted_mse(ds_pred['downward_total_radiative_flux_at_surface_bias'], grid['area'])
    / weighted_variance(ds_pred['downward_total_radiative_flux_at_surface'], grid['area'])
)
total_rad_r_squared.values

In [None]:
# Shrink the EPS files written by matplotlib to a more manageable size.
# Each figure is converted EPS -> PDF (epstopdf) and back PDF -> EPS
# (pdftops -eps), then the intermediate PDF is removed.
# NOTE(review): pdftops is given no output name, so it presumably writes the
# new .eps next to the PDF, replacing the original - confirm.

!epstopdf figures/Figure_6_column_integrated_dQ2_offline_time_mean.eps
!pdftops -eps figures/Figure_6_column_integrated_dQ2_offline_time_mean.pdf
!rm figures/Figure_6_column_integrated_dQ2_offline_time_mean.pdf

!epstopdf figures/Figure_9_downward_total_radiative_flux_at_surface_bias_offline_time_mean.eps
!pdftops -eps figures/Figure_9_downward_total_radiative_flux_at_surface_bias_offline_time_mean.pdf
!rm figures/Figure_9_downward_total_radiative_flux_at_surface_bias_offline_time_mean.pdf