### Prognostic IC ensemble RMSE and drifts

Plotting weather-forecast-scale RMSEs and means for initial-condition and NN random-seed ensembles of N2F-trained prognostic runs.

In [None]:
import xarray as xr
from matplotlib import pyplot as plt
import matplotlib
from cycler import cycler
# Notebook-wide plot styling: base font size and the colour cycle used for
# successive lines/patches on an axes.
matplotlib.rcParams.update({
    'font.size': 12,
    'axes.prop_cycle': cycler(color=['gray', 'b', 'orange', 'k']),
})
import fv3viz
import fsspec
import json
from vcm.catalog import catalog as CATALOG
import os
import numpy as np
from vcm.fv3.metadata import standardize_fv3_diagnostics
from vcm import open_remote_nc
from dask.diagnostics import ProgressBar
import warnings

In [None]:
# Labels of the initial conditions that make up each ensemble; they are
# substituted into the ``{ic}`` placeholder of the path templates below.
INITIAL_CONDITIONS_ENSEMBLE = [
    '20160805.000000',
    '20160813.000000',
    '20160821.000000',
    '20160829.000000',
]

# Key of the no-ML run; it is also the run used to derive verification values.
BASELINE_NAME = 'baseline-no-ML'

# Run type -> run-directory template with an ``{ic}`` placeholder.
RUN_TYPE_ROOT_TEMPLATE = {
    BASELINE_NAME: 'vcm-ml-experiments/2021-04-13/baseline-physics-run-{ic:s}-start-rad-step-1800s',
    'temperature-moisture-RF': 'vcm-ml-experiments/2021-05-11-nudge-to-c3072-corrected-winds/control-dq1-dq2-rf/initial_conditions_runs/{ic:s}',
    'temperature-moisture-winds-RF': 'vcm-ml-experiments/2021-05-11-nudge-to-c3072-corrected-winds/control-rf/prognostic_run_tendencies_only_ics/{ic:s}',
    'temperature-moisture-winds-prescribed-sfc-RF': 'vcm-ml-experiments/2021-05-11-nudge-to-c3072-corrected-winds/rf/initial_conditions_runs/{ic:s}',
    'temperature-moisture-winds-prescribed-sfc-NN-ensemble': 'vcm-ml-experiments/2021-05-11-nudge-to-c3072-corrected-winds/nn-ensemble-model/initial_conditions_runs/{ic:s}',
    'temperature-moisture-prescribed-sfc-RF': 'vcm-ml-experiments/2021-06-21-nudge-to-c3072-dq1-dq2-only/rf/initial_conditions_runs/{ic:s}',
    'temperature-moisture-prescribed-sfc-NN-ensemble': 'vcm-ml-experiments/2021-06-21-nudge-to-c3072-dq1-dq2-only/nn-ensemble-model/initial_conditions_runs/{ic:s}',
}

# Location and file names of the per-run diagnostics.
PROGNOSTIC_RUN_DIAGS_ROOT = 'gs://vcm-ml-archive/prognostic_run_diags/'
DIAGS_NC = 'diags.nc'
DIAGS_ZARR = 'diags.zarr'

# Variables removed (when present) after concatenating ensemble members.
DROP_VARS = [
    'dqu_pressure_level_zonal_time_mean',
    'dq2_pressure_level_zonal_time_mean',
    'dq1_pressure_level_zonal_time_mean',
    'dqv_pressure_level_zonal_time_mean',
]

# Spatial-mean diagnostics plotted as drifts against lead time.
DRIFT_VARS = [
    'total_precip_to_surface_spatial_mean_physics_land',
    'pwat_spatial_mean_dycore_global',
    'tmplowest_spatial_mean_dycore_global',
    'lhtflsfc_spatial_mean_physics_land',
    'column_integrated_pq1_spatial_mean_physics_global',
    'column_integrated_q1_spatial_mean_physics_global',
    'column_integrated_pq2_spatial_mean_physics_global',
    'column_integrated_q2_spatial_mean_physics_global',
]

# Precipitation variable names in the run output and in the verification data.
PRECIP_RATE = 'total_precipitation_rate'
PHYSICS_PRECIP = 'PRATEsfc'
# Catalog entry used as the precipitation verification dataset.
CATALOG_KEY = '40day_c48_gfsphysics_15min_may2020'
SECONDS_PER_DAY = 86400

# Output directories for figures and metric tables.
OUTDIR = 'figures'
TABLES = 'tables'

In [None]:
def abs_time_to_lead_time(ds):
    """Re-index ``ds`` by forecast lead time instead of absolute time.

    The lead time is the offset of each ``time`` value from the first one,
    cast to minute resolution and expressed in days. The result has a
    ``lead_time`` dimension (attrs ``units: days``) in place of ``time``,
    and the ``time`` coordinate is dropped.

    Args:
        ds: xarray object with a ``time`` dimension coordinate.

    Returns:
        The same data with ``time`` swapped for ``lead_time``.
    """
    elapsed = ds.time - ds.time[0]
    # 1440 minutes per day: convert the minute-resolution offsets to days.
    elapsed_days = elapsed.astype('timedelta64[m]') / np.timedelta64(1440, 'm')
    lead_time_coord = xr.DataArray(elapsed_days, dims=['time'], attrs={'units': 'days'})
    with_lead_time = ds.assign_coords({'lead_time': lead_time_coord})
    return with_lead_time.swap_dims({'time': 'lead_time'}).drop_vars('time')


def precipitation_bias(run_precip, verif_precip):
    """Compute the precipitation bias of a run against verification data.

    Args:
        run_precip: precipitation rate of the run, with a ``time`` coordinate.
        verif_precip: verification precipitation rate to subtract from the run.

    Returns:
        DataArray named ``precipitation_bias`` indexed by ``lead_time``:
        ``(run_precip - verif_precip) * SECONDS_PER_DAY``, labelled with
        units ``mm/d`` (presumably the inputs are per-second rates in
        kg/m^2/s -- confirm against the data sources).
    """
    # Subtract the verification passed in by the caller; the previous code
    # ignored this argument and read the module-level ``precip_verif``.
    difference = run_precip - verif_precip
    precip_bias = abs_time_to_lead_time(difference).rename('precipitation_bias')
    precip_bias = SECONDS_PER_DAY * precip_bias
    precip_bias = precip_bias.assign_attrs({
        'long_name': 'precipitation bias (coarse minus fine)',
        'units': 'mm/d'
    })
    return precip_bias

In [None]:
# Verification precipitation: the catalog dataset is standardized first, then
# its physics precipitation variable is renamed to match the run output.
verif_diags = standardize_fv3_diagnostics(CATALOG[CATALOG_KEY].to_dask())
precip_verif = verif_diags[PHYSICS_PRECIP].rename('total_precipitation_rate')

In [None]:
# For every run type, open the diagnostics of each initial-condition member and
# stack the members along a new 'ensemble' dimension.
run_type_datasets = []
precip_bias_dataarrays = []
# Coordinate labelling each member by its initial-condition string.
ensemble_coord = xr.DataArray(INITIAL_CONDITIONS_ENSEMBLE, dims=['ensemble'], name='ensemble')
for run_type, path_template in RUN_TYPE_ROOT_TEMPLATE.items():
    ensemble_datasets = []
    ensemble_precip_bias_dataarrays = []
    for ic in INITIAL_CONDITIONS_ENSEMBLE:
        # The baseline template is filled with only the first 8 characters
        # of the initial-condition label.
        ic_label = ic[:8] if run_type == BASELINE_NAME else ic
        ic_path = path_template.format(ic=ic_label)
        # Diagnostics netCDFs are stored under the run path with '/' replaced by '-'.
        nc_path = os.path.join(PROGNOSTIC_RUN_DIAGS_ROOT, ic_path.replace('/', '-'), DIAGS_NC)
        zarr_path = os.path.join('gs://', ic_path, DIAGS_ZARR)

        print(f"opening {nc_path}")
        fs, *_ = fsspec.get_fs_token_paths(nc_path)
        ensemble_datasets.append(abs_time_to_lead_time(open_remote_nc(fs, nc_path)))

        print(f"opening {zarr_path}")
        store = fsspec.get_mapper(zarr_path)
        member_precip = xr.open_zarr(store, consolidated=True)[PRECIP_RATE]
        ensemble_precip_bias_dataarrays.append(precipitation_bias(member_precip, precip_verif))

    # join='inner' keeps only coordinate values shared by all members.
    stacked_members = xr.concat(ensemble_datasets, dim=ensemble_coord, join='inner')
    run_type_datasets.append(stacked_members.drop_vars(DROP_VARS, errors='ignore'))
    precip_bias_dataarrays.append(
        xr.concat(ensemble_precip_bias_dataarrays, dim=ensemble_coord)
    )

In [None]:
# Stack the per-run-type results along a 'run_type' dimension labelled by the
# keys of RUN_TYPE_ROOT_TEMPLATE (same order as they were loaded).
run_type_coord = xr.DataArray(list(RUN_TYPE_ROOT_TEMPLATE.keys()), dims=['run_type'], name='run_type')
ds = xr.concat(run_type_datasets, dim=run_type_coord)
precip_bias = xr.concat(precip_bias_dataarrays, dim=run_type_coord)

In [None]:
def add_verification(ds):
    """Append a 'verification' entry along ``run_type``.

    The entry starts as an all-NaN copy of the baseline run. For each name in
    DRIFT_VARS it is filled with the baseline spatial mean minus the matching
    ``mean_bias`` diagnostic (presumably bias = run - verification, so the
    difference recovers the verification mean -- confirm). Variables whose
    spatial-mean or bias diagnostic is missing from ``ds`` stay NaN.

    Args:
        ds: dataset with a ``run_type`` dimension containing BASELINE_NAME.

    Returns:
        ``ds`` concatenated with the verification entry along ``run_type``.
    """
    verification = xr.full_like(ds.sel(run_type=BASELINE_NAME), np.nan)
    for name in DRIFT_VARS:
        bias_name = name.replace('spatial_mean', 'mean_bias')
        try:
            verification[name] = (ds[name] - ds[bias_name]).sel(run_type=BASELINE_NAME)
        except KeyError:
            # Deliberate best effort: keep the NaN placeholder for this variable.
            continue
    verification = verification.assign_coords({'run_type': ['verification']})
    return xr.concat([ds, verification], dim='run_type')

In [None]:
def ensemble_metrics(ds):
    """Summarize ``ds`` across its ``ensemble`` dimension.

    The minimum, mean, maximum and standard deviation over ``ensemble`` are
    each given a new length-1 ``ensemble`` dimension labelled 'min', 'mean',
    'max' and 'std', then merged into one dataset. Attributes are kept
    through the reductions.

    Args:
        ds: dataset with an ``ensemble`` dimension.

    Returns:
        Merged dataset whose ``ensemble`` coordinate holds the statistic labels.
    """
    # (label, reduction) pairs, listed in the order they are merged.
    reductions = (
        ('min', ds.min),
        ('mean', ds.mean),
        ('max', ds.max),
        ('std', ds.std),
    )
    with xr.set_options(keep_attrs=True):
        summaries = [
            reduce(dim='ensemble').expand_dims({'ensemble': [label]})
            for label, reduce in reductions
        ]
    return xr.merge(summaries)

In [None]:
# Only use the first 10 days of each forecast for consistency, then append
# the 'verification' entry along run_type.
ds = add_verification(ds.sel(lead_time=slice(None, 10.0)))

In [None]:
# Compute min/mean/max/std across the ensemble. RuntimeWarnings are silenced
# during the reductions (presumably triggered by all-NaN placeholders --
# confirm).
with warnings.catch_warnings():
    warnings.simplefilter("ignore", category=RuntimeWarning)
    ensemble_metrics_ds = ensemble_metrics(ds)

# Keep exactly the configured run types, in configuration order. Any other
# run_type label (e.g. 'verification') is dropped by the reindex. The list is
# derived from RUN_TYPE_ROOT_TEMPLATE rather than repeated here so the two
# cannot drift apart.
ensemble_metrics_ds = ensemble_metrics_ds.reindex({
    'run_type': list(RUN_TYPE_ROOT_TEMPLATE)
})

In [None]:
# Bar-plot the ensemble mean (with std error bars) of selected RMS metrics for
# every run type, and dump the same numbers to a JSON table.
# NOTE(review): the labels say "day5" but the values are selected at
# lead_time=3.0 days -- confirm which is intended. The labels are left
# unchanged because they are written to the JSON table as-is.
fig, axes = plt.subplots(3, 1)
ablation_rms_metrics = []
for i, var in enumerate(['h500_rms_global', 'tmp850_rms_global', 'tmp200_rms_global']):
    panel = axes.flatten()[i]
    units = ensemble_metrics_ds[var].attrs.get('units', '')
    # Ensemble mean / std of this metric at the chosen lead time, per run type.
    ic_mean = ensemble_metrics_ds[var].sel(ensemble='mean', lead_time=3.0)
    ic_std = ensemble_metrics_ds[var].sel(ensemble='std', lead_time=3.0)

    var_mean_metrics = {'variable': f"{var}_day5_ic_mean", 'units': units}
    var_std_metrics = {'variable': f"{var}_day5_ic_std", 'units': units}
    for run_type in RUN_TYPE_ROOT_TEMPLATE:
        var_mean_metrics[run_type] = ic_mean.sel(run_type=run_type).item()
        var_std_metrics[run_type] = ic_std.sel(run_type=run_type).item()
    ablation_rms_metrics.append(var_mean_metrics)
    ablation_rms_metrics.append(var_std_metrics)

    ic_mean.to_series().plot.bar(ax=panel, yerr=ic_std.to_series(), capsize=4)
    if i < 2:
        # Only the bottom panel keeps its run-type tick labels.
        panel.set_xticklabels('')
    panel.set_ylabel(f"{var}_day5_ic_mean")
    panel.grid(axis='y')
fig.set_size_inches(10, 10)
# Use the shared OUTDIR constant instead of a hard-coded directory name.
fig.savefig(f'{OUTDIR}/rms_bar.png', bbox_inches='tight')
with open(f"{TABLES}/ablation_rms_metrics.json", 'w') as f:
    json.dump(ablation_rms_metrics, f)

In [None]:
# Gather every global RMS diagnostic into one array along 'variable_names',
# labelled "<long_name> RMSE [<units>]", for the three plotted run types.
RMS_VARS = [name for name in ensemble_metrics_ds.data_vars if 'rms_global' in name]
rename_dict = {}
for name in RMS_VARS:
    attrs = ensemble_metrics_ds[name].attrs
    # Fall back to the variable-name prefix when there is no long_name.
    label = attrs.get('long_name', name.split('_rms_global')[0])
    units = attrs.get('units', '')
    rename_dict[name] = f"{label} RMSE [{units}]"
rms_array = ensemble_metrics_ds[RMS_VARS].rename(rename_dict).to_array(
    dim='variable_names',
    name='prognostic_ic_ensemble_rms_variables',
)
ds_rms = rms_array.to_dataset().sel({
    'run_type': [
        BASELINE_NAME,
        'temperature-moisture-winds-prescribed-sfc-RF',
        'temperature-moisture-winds-prescribed-sfc-NN-ensemble',
    ]
})

In [None]:
def plot_ensemble(arr, lead_time, ax=None):
    """Plot a central line and a shaded band for each series in ``arr``.

    Args:
        arr: array indexed as ``[statistic, series, lead_time]``. Index 1
            along the first axis is drawn as lines; indices 0 and 2 bound
            the shaded band.
        lead_time: x values shared by all series.
        ax: axes to draw on; defaults to the current axes.

    Returns:
        The handles returned by ``ax.plot`` for the central lines.
    """
    target = plt.gca() if ax is None else ax
    band_edge_a, central, band_edge_b = arr[0], arr[1], arr[2]
    # Transpose so each column (one per series) becomes a line.
    handles = target.plot(lead_time, central.T, lw=1.5)
    for edge_a, edge_b in zip(band_edge_a, band_edge_b):
        target.fill_between(lead_time, edge_a, edge_b, alpha=0.25)
    target.grid()
    return handles

In [None]:
def plot_all_vars(ds, varname, fig_size=(9, 12)):
    """Plot ``ds[varname]`` against lead time, one panel per ``variable_names`` entry.

    Each panel is drawn by ``plot_ensemble``. A legend is added to the first
    panel only, and the figure is saved to ``{OUTDIR}/{varname}.png``.

    Args:
        ds: dataset with ``variable_names``, ``run_type`` and ``lead_time``
            dimensions. The legend labels are hard-coded for three run types
            (baseline, random forests, NN-ensemble, in that order).
        varname: name of the data variable in ``ds`` to plot; also used as
            the output file stem.
        fig_size: figure size in inches, ``(width, height)``.
    """
    fg = xr.plot.FacetGrid(data=ds, col='variable_names', col_wrap=3, sharey=False)
    fg.map(plot_ensemble, varname, lead_time=ds.lead_time)
    ax = fg.axes[0, 0]
    n_lines = ds.sizes['run_type']
    # Take the legend handles from the plotted lines explicitly. Slicing
    # ax.get_children() is fragile: its ordering of lines vs. filled bands
    # differs between matplotlib versions.
    ax.legend(
        ax.get_lines()[:n_lines],
        ['baseline', 'random forests', 'NN-ensemble'],
        fontsize='x-small', loc=4
    )
    # One tick per day over the 10-day forecast window.
    ax.set_xticks(np.arange(11))
    ax.set_xlim([0, 10])
    fg.set_titles(template='{value}', maxchar=50)
    fg.set_ylabels('')
    fg.set_xlabels('Forecast lead time [days]')
    fg.fig.set_size_inches(fig_size)
    fg.fig.set_dpi(150)
    fg.fig.tight_layout()
    fg.fig.savefig(f'{OUTDIR}/{varname}.png', bbox_inches='tight', facecolor='white')

In [None]:
# Plot every global RMSE diagnostic at the default figure size.
plot_all_vars(ds=ds_rms, varname='prognostic_ic_ensemble_rms_variables')

In [None]:
# Same layout as the full RMSE array, restricted to three headline metrics.
RMS_SUBSET_VARS = ['h500_rms_global', 'pressfc_rms_global', 'tmp850_rms_global']
rename_dict = {}
for name in RMS_SUBSET_VARS:
    attrs = ensemble_metrics_ds[name].attrs
    # Fall back to the variable-name prefix when there is no long_name.
    label = attrs.get('long_name', name.split('_rms_global')[0])
    units = attrs.get('units', '')
    rename_dict[name] = f"{label} RMSE [{units}]"
rms_subset_array = ensemble_metrics_ds[RMS_SUBSET_VARS].rename(rename_dict).to_array(
    dim='variable_names',
    name='prognostic_ic_ensemble_rms_subset_variables',
)
ds_rms_subset = rms_subset_array.to_dataset().sel({
    'run_type': [
        BASELINE_NAME,
        'temperature-moisture-winds-prescribed-sfc-RF',
        'temperature-moisture-winds-prescribed-sfc-NN-ensemble',
    ]
})

In [None]:
# Plot the three headline RMSE metrics as a single wide row.
plot_all_vars(
    ds=ds_rms_subset,
    varname='prognostic_ic_ensemble_rms_subset_variables',
    fig_size=[12, 3.5],
)

In [None]:
# Gather the drift diagnostics into one array along 'variable_names' for the
# three plotted run types. Labels are the variable names with the
# spatial-mean / global qualifiers stripped, followed by the units.
# (A stray no-op ``ensemble_metrics_ds[DRIFT_VARS]`` expression that used to
# open this cell has been removed; its result was discarded.)
rename_dict = {}
for var in DRIFT_VARS:
    rename_dict[var] = f"{var.replace('_spatial_mean_dycore', '').replace('_spatial_mean_physics', '').replace('_global', '')} [{ensemble_metrics_ds[var].attrs.get('units', '')}]"
ds_drift = (
    ensemble_metrics_ds[DRIFT_VARS]
    .rename(rename_dict)
    .to_array(dim='variable_names', name='prognostic_ic_ensemble_drift_variables')
    .to_dataset()
    .sel({'run_type': [
        BASELINE_NAME,
        'temperature-moisture-winds-prescribed-sfc-RF',
        'temperature-moisture-winds-prescribed-sfc-NN-ensemble'
    ]})
)

In [None]:
# Plot the drift diagnostics on a 12 x 9 inch figure.
plot_all_vars(
    ds=ds_drift,
    varname='prognostic_ic_ensemble_drift_variables',
    fig_size=[12, 9],
)

In [None]:
# Take the mean precipitation bias over the 10-day forecast period for the
# three plotted run types, then relabel them with short display names.
plotted_run_types = [
    BASELINE_NAME,
    'temperature-moisture-winds-prescribed-sfc-RF',
    'temperature-moisture-winds-prescribed-sfc-NN-ensemble',
]
precip_bias_time_mean = (
    precip_bias
    .sel(lead_time=slice(None, 10.0), run_type=plotted_run_types)
    .mean(dim='lead_time', keep_attrs=True)
    .assign_coords({'run_type': ['baseline', 'random forests', 'NN-ensemble']})
)

In [None]:
# Grid dataset from the catalog; merged with the bias below and used for its
# 'area' variable.
GRID = CATALOG['grid/c48'].to_dask()
# Keyword arguments forwarded to fv3viz.mappable_var: the names of the
# cell-center / cell-interface dimensions and the dimensions of each
# coordinate variable.
MAPPABLE_VAR_KWARGS = dict(
    coord_x_center="x",
    coord_y_center="y",
    coord_x_outer="x_interface",
    coord_y_outer="y_interface",
    coord_vars=dict(
        lonb=["y_interface", "x_interface", "tile"],
        latb=["y_interface", "x_interface", "tile"],
        lon=["y", "x", "tile"],
        lat=["y", "x", "tile"],
    ),
)

def _var_rms(bias, area):
    """Return the area-weighted root-mean-square of ``bias`` over x, y and tile.

    Weights are ``area`` divided by its mean.
    """
    relative_area = area / area.mean()
    weighted_square = (bias ** 2) * relative_area
    return np.sqrt(weighted_square.mean(dim=['x', 'y', 'tile']))

def _var_mean(bias, area):
    """Return the area-weighted mean of ``bias`` over x, y and tile.

    Weights are ``area`` divided by its mean.
    """
    relative_area = area / area.mean()
    weighted_bias = relative_area * bias
    return weighted_bias.mean(dim=['x', 'y', 'tile'])

In [None]:
# Attach the grid to the time-mean bias so it can be mapped, then load the
# field and its area-weighted RMS / mean into memory with a progress bar.
var = 'precipitation_bias'
precip_bias_with_grid = xr.merge([precip_bias_time_mean, GRID])
precip_bias_plottable = fv3viz.mappable_var(
    precip_bias_with_grid, var, **MAPPABLE_VAR_KWARGS
)
with ProgressBar():
    precip_bias_plottable = precip_bias_plottable.load()
    grid_area = GRID['area']
    rms_ = _var_rms(precip_bias_plottable[var], grid_area).load()
    mean_ = _var_mean(precip_bias_plottable[var], grid_area).load()

In [None]:
# Map the time-mean precipitation bias: one column per run type, one row per
# ensemble member, each panel titled with its area-weighted RMS and mean bias.
_, _, _, _, fg = fv3viz.plot_cube(
    precip_bias_plottable,
    col='run_type',
    row='ensemble',
    vmin=-10
)
fg.set_titles('{value}')
for i, row in enumerate(fg.axes):
    for j, ax in enumerate(row):
        # Only the top row shows the run-type name.
        run_type = precip_bias_plottable.run_type.isel({'run_type': j}).item() if i == 0 else ''
        # Rows are ensemble members and columns are run types. Select by
        # dimension name so the statistics match the panel regardless of the
        # dimension order of rms_ / mean_.
        panel = {'ensemble': i, 'run_type': j}
        panel_rms = rms_.isel(panel).item()
        panel_mean = mean_.isel(panel).item()
        ax.set_title(f"{run_type}\nrms: {panel_rms:3.2f}, bias: {panel_mean:3.2f}")
fg.fig.set_size_inches([12, 10])
fg.fig.set_dpi(150)
fg.fig.savefig(f'{OUTDIR}/prognostic_ic_ensemble_precipitation_bias.png', bbox_inches='tight', facecolor='white')