In [None]:
import numpy as np
import xarray as xr
import fsspec
import matplotlib.pyplot as plt
import fv3viz
import os
from fv3net.diagnostics.prognostic_run.load_run_data import load_physics, load_grid, load_verification
from fv3net.diagnostics.prognostic_run import get_verification_entries
from fv3net.diagnostics.prognostic_run.derived_variables import physics_variables
from vcm.catalog import catalog
from dask.diagnostics import ProgressBar
# High-resolution figures for the cubed-sphere maps produced below.
plt.rcParams.update({'figure.dpi': 300})
# Keep variable attributes (units, long_name, ...) through the arithmetic in later cells.
xr.set_options(keep_attrs=True)

In [None]:
# Keyword arguments for fv3viz.mappable_var describing this dataset's dimension
# names: cell-center dims ("x", "y"), cell-boundary ("outer") dims
# ("x_interface", "y_interface"), and the dims of each lon/lat coordinate.
MAPPABLE_VAR_KWARGS = dict(
    coord_x_center="x",
    coord_y_center="y",
    coord_x_outer="x_interface",
    coord_y_outer="y_interface",
    coord_vars={
        # boundary (corner) coordinates
        **{name: ["y_interface", "x_interface", "tile"] for name in ("lonb", "latb")},
        # center coordinates
        **{name: ["y", "x", "tile"] for name in ("lon", "lat")},
    },
)

def global_mean(da, w, dims=('tile', 'x', 'y')):
    """Weighted mean of ``da`` over ``dims``, ignoring missing values.

    Parameters
    ----------
    da : xarray.DataArray
        Field to average. NaN points are excluded from both the weighted sum
        and the total weight.
    w : xarray.DataArray
        Weights (in this notebook, the grid-cell ``area``).
    dims : sequence of str
        Dimensions to reduce over; defaults to the horizontal dims
        ``('tile', 'x', 'y')``.

    Returns
    -------
    xarray.DataArray
        The weighted mean carrying ``da``'s attrs. It is NaN where ``da`` has no
        valid points along ``dims``, rather than a misleading 0.
    """
    # .sum() skips NaNs in the numerator, so only count the weight of cells
    # where da is defined; otherwise partially-missing fields are biased to 0.
    valid_weights = w.where(da.notnull())
    weighted_sum = (da * w).sum(dims)
    # min_count=1 makes the total weight NaN (not 0) when nothing is valid.
    total_weight = valid_weights.sum(dims, min_count=1)
    gm = weighted_sum / total_weight
    return gm.assign_attrs(da.attrs)

GRID_VARS = ["lon", "lat", "lonb", "latb", "area"]

# Variables kept from each run (and from the verification) after deriving the
# physics variables; anything else is dropped before the runs are concatenated.
REQUIRED_VARIABLES = [
    # surface radiative fluxes: applied, and as computed by RRTMG
    "DLWRFsfc",
    "DLWRFsfc_from_RRTMG",
    "DSWRFsfc",
    "DSWRFsfc_from_RRTMG",
    "ULWRFsfc",
    "USWRFsfc",
    "USWRFsfc_from_RRTMG",
    # top-of-atmosphere radiative fluxes
    "ULWRFtoa",
    "USWRFtoa",
    "DSWRFtoa",
    # column-integrated dQ / nQ tendencies
    "column_integrated_dQ1",
    "column_integrated_dQ2",
    "column_integrated_nQ1",
    "column_integrated_nQ2",
    # grid coordinates / cell area
    *GRID_VARS,
]

### Load physics diagnostics for each run and for the fine-res verification

In [None]:
# Prognostic-run output locations keyed by the label used in the figures.
# Runs are concatenated (and plotted) in this order, followed by 'fine-res'.
urls = dict([
    ("baseline", "gs://vcm-ml-experiments/2021-04-13/baseline-physics-run-20160805-start-rad-step-1800s"),
    ("nudged", "gs://vcm-ml-experiments/2021-04-13-n2f-c3072/3-hrly-ave-rad-precip-setting-30-min-rad-timestep-shifted-start-tke-edmf"),
    ("$TqR$-RF", "gs://vcm-ml-experiments/2021-06-21-nudge-to-c3072-dq1-dq2-only/rf/initial_conditions_runs/20160805.000000_with_RRTMG"),
    ("$TquvR$-RF", "gs://vcm-ml-experiments/2021-05-11-nudge-to-c3072-corrected-winds/rf/initial_conditions_runs/20160805.000000_with_RRTMG"),
    ("$TqR$-NN", "gs://vcm-ml-experiments/2021-06-21-nudge-to-c3072-dq1-dq2-only/nn-ensemble-model/initial_conditions_runs/20160805.000000_with_RRTMG"),
])
# Analysis window passed to .sel(time=slice(start_time, end_time)) for every
# run and for the verification data.
start_time, end_time = "2016-08-05T00:15:00", "2016-09-10T00:00:00"

In [None]:
def _add_rrtmg_flux_copies(ds):
    """Fill in any missing ``*_from_RRTMG`` surface flux from its applied counterpart.

    For the baseline run and the verification the applied surface fluxes are
    the same as the RRTMG ones, so e.g. ``DLWRFsfc`` is copied to
    ``DLWRFsfc_from_RRTMG`` when the latter is absent. ``ds`` is modified in
    place and also returned.
    """
    for varname in ['DLWRFsfc_from_RRTMG', 'DSWRFsfc_from_RRTMG', 'USWRFsfc_from_RRTMG']:
        if varname in ds:
            continue
        shortname = varname.split('_')[0]
        ds[varname] = ds[shortname]
        # Fall back to the variable name if the source has no long_name.
        long_name = ds[shortname].attrs.get('long_name', shortname)
        ds[varname].attrs['long_name'] = long_name + ' due to RRTMG'
    return ds


def _keep_required(ds):
    """Return ``ds`` without the data variables not listed in REQUIRED_VARIABLES."""
    return ds.drop_vars([v for v in ds.data_vars if v not in REQUIRED_VARIABLES])


time_window = slice(start_time, end_time)

# The verification data and the grid do not depend on the run, so load and
# preprocess them once instead of once per loop iteration.
verif_entries = get_verification_entries('40day_may2020', catalog)['physics']
verif = load_verification(verif_entries, catalog)
verif = _keep_required(_add_rrtmg_flux_copies(physics_variables(verif.sel(time=time_window))))
grid = load_grid(catalog)

run_datasets = []
for run, url in urls.items():
    run_physics = physics_variables(load_physics(url, catalog).sel(time=time_window))
    run_physics = _add_rrtmg_flux_copies(run_physics)
    run_datasets.append(_keep_required(run_physics).assign_coords(experiment=run))

# The grid variables are supplied through the verification entry and become
# coordinates of the combined dataset.
run_datasets.append(xr.merge([verif, grid]).assign_coords(experiment='fine-res'))
diags = xr.concat(run_datasets, dim='experiment')
diags = diags.set_coords(GRID_VARS)

### Compute column-integrated MSE tendencies $\langle dQ_m \rangle$ and $\langle nQ_m \rangle$

In [None]:
# Latent heat of vaporization (J/kg): converts the column moistening into an
# energy flux so it can be added to the column heating.
L = 2.5e6
# The moistening is divided by this, so it is presumably per day (e.g. mm/day)
# — TODO confirm units of column_integrated_dQ2/nQ2.
seconds_per_day = 86400


def _column_mse_tendency(heating, moistening, long_name):
    """Return ``heating + L * moistening / seconds_per_day`` labelled with ``long_name``.

    With keep_attrs=True the sum would otherwise inherit ``heating``'s
    long_name (e.g. that of dQ1), which mislabels the MSE tendency.
    """
    mse_tendency = heating + L * moistening / seconds_per_day
    return mse_tendency.assign_attrs(long_name=long_name)


diags['column_integrated_dQm'] = _column_mse_tendency(
    diags.column_integrated_dQ1,
    diags.column_integrated_dQ2,
    'column MSE tendency from dQ1 and dQ2',
)
diags['column_integrated_nQm'] = _column_mse_tendency(
    diags.column_integrated_nQ1,
    diags.column_integrated_nQ2,
    'column MSE tendency from nQ1 and nQ2',
)
diags['column_integrated_dQm_or_nQm'] = diags['column_integrated_dQm'] + diags['column_integrated_nQm']
diags['column_integrated_dQm_or_nQm'].attrs['long_name'] = 'column MSE tendency due to nudging or ML'

### Compute net radiative fluxes and the implied net downward TOA flux

In [None]:
# Each entry: (output variable, long_name, recipe evaluated on the current diags).
# Recipes run in order, so later ones can read fluxes stored by earlier ones.
_net_flux_recipes = [
    (
        'net_downward_toa_flux',
        'net downward TOA radiative flux due to RRTMG',
        lambda ds: ds.DSWRFtoa - ds.USWRFtoa - ds.ULWRFtoa,
    ),
    (
        'net_downward_sfc_rad_flux_RRTMG',
        'net downward surface radiative flux due to RRTMG',
        lambda ds: ds.DLWRFsfc_from_RRTMG + ds.DSWRFsfc_from_RRTMG - ds.USWRFsfc_from_RRTMG - ds.ULWRFsfc,
    ),
    (
        'net_downward_sfc_rad_flux_applied',
        'net downward surface radiative flux felt by surface',
        lambda ds: ds.DLWRFsfc + ds.DSWRFsfc - ds.USWRFsfc - ds.ULWRFsfc,
    ),
    (
        # difference between what the surface felt and what RRTMG computed
        'net_downward_flux_due_to_sfc_radiation_fix',
        'implied downward flux due to predicting or prescribing surface radiation',
        lambda ds: ds.net_downward_sfc_rad_flux_applied - ds.net_downward_sfc_rad_flux_RRTMG,
    ),
    (
        'implied_downward_toa_flux',
        'implied net downward TOA flux',
        lambda ds: ds.net_downward_toa_flux + ds.column_integrated_dQm_or_nQm + ds.net_downward_flux_due_to_sfc_radiation_fix,
    ),
]
for _name, _long_name, _recipe in _net_flux_recipes:
    diags[_name] = _recipe(diags)
    diags[_name].attrs['long_name'] = _long_name

In [None]:
# Derived diagnostics to map; one figure is drawn per entry, in this order.
variables_of_interest = [
    "net_downward_flux_due_to_sfc_radiation_fix",
    "net_downward_toa_flux",
    "column_integrated_dQm_or_nQm",
    "implied_downward_toa_flux",
]

### Plot time means over the first 7 days of each run

In [None]:
plot_start_time = '2016-08-05T00:15:00'
plot_end_time = '2016-08-12T00:00:00'
# Average the first week of every run in time and load the result into memory;
# ProgressBar reports the dask computation triggered by .compute().
with ProgressBar():
    _fields = variables_of_interest + GRID_VARS
    _first_week = diags[_fields].sel(time=slice(plot_start_time, plot_end_time))
    diags_time_mean_7day = _first_week.mean('time').compute()

In [None]:
# One row of cubed-sphere maps per variable (one panel per experiment), each
# panel titled with its area-weighted global mean.
experiments = diags_time_mean_7day.experiment.values
for variable in variables_of_interest:
    field = diags_time_mean_7day[variable]
    # Weight with the already-computed area so the global means are evaluated
    # in memory once, instead of re-running a lazy graph for every panel title.
    gm = global_mean(field, diags_time_mean_7day.area)
    mv = fv3viz.mappable_var(diags_time_mean_7day, variable, **MAPPABLE_VAR_KWARGS)
    fg = fv3viz.plot_cube(mv, col='experiment')[-1]
    for ax, experiment in zip(fg.axes[0], experiments):
        # Look the mean up by experiment label rather than by position.
        value = float(gm.sel(experiment=experiment))
        ax.set_title(f'{experiment} ({value:.1f} W/m**2)')
    fg.fig.set_size_inches(24, 3)