### prognostic-ic-ensemble

Plots weather-forecast-scale (first 10 days) RMSEs and spatial means for initial-condition ensembles of prognostic runs with different ML configurations (random forests, neural networks) and a baseline run, plus maps of the mean precipitation bias.

In [None]:
import xarray as xr
from matplotlib import pyplot as plt
import matplotlib
from cycler import cycler
# Notebook-wide plot styling: larger default font, and a fixed colour cycle
# (red, blue, magenta, black) that successive lines / filled bands on each
# axis draw from — presumably one colour per run_type entry.
matplotlib.rcParams['font.size'] = 12
matplotlib.rcParams['axes.prop_cycle'] = cycler(color=['r', 'b', 'm', 'k'])
import fv3viz
import fsspec
from vcm.catalog import catalog as CATALOG
import os
import numpy as np
from vcm.fv3.metadata import standardize_fv3_diagnostics
from dask.diagnostics import ProgressBar

In [None]:
# Start times (YYYYMMDD.HHMMSS) of the initial-condition ensemble members:
# 00Z every 8 days from 2016-08-05 through 2016-08-29.
INITIAL_CONDITIONS_ENSEMBLE = [f'201608{day:02d}.000000' for day in range(5, 30, 8)]
# Run-directory templates (bucket paths without the gs:// scheme) per run type;
# ``{ic}`` is filled with the initial-condition timestamp (the loading cell uses
# only its date part for the baseline).
RUN_TYPE_ROOT_TEMPLATE = dict([
    ('baseline',
     'vcm-ml-experiments/2021-04-13/baseline-physics-run-{ic:s}-start-rad-step-1800s'),
    ('RF (tendencies and radiation)',
     'vcm-ml-experiments/2021-05-11-nudge-to-c3072-corrected-winds/rf/'
     'initial_conditions_runs/{ic:s}'),
    ('NN-ensemble (tendencies), RF (radiation)',
     'vcm-ml-experiments/2021-05-11-nudge-to-c3072-corrected-winds/nn-ensemble-model/'
     'initial_conditions_runs_rf_rad/{ic:s}'),
])
# Where the computed prognostic-run diagnostics live, and their file names.
PROGNOSTIC_RUN_DIAGS_ROOT = 'gs://vcm-ml-archive/prognostic_run_diags/'
DIAGS_NC = 'diags.nc'
DIAGS_ZARR = 'diags.zarr'
# Diagnostics dropped from the loaded ensemble datasets (ignored when absent).
DROP_VARS = [
    f'{prefix}_pressure_level_zonal_time_mean' for prefix in ('dqu', 'dq2', 'dq1', 'dqv')
]
# Spatial-mean time series shown in the drift plots.
DRIFT_VARS = [
    'total_precip_to_surface_spatial_mean_physics_land',
    'pwat_spatial_mean_dycore_global',
    'tmplowest_spatial_mean_dycore_global',
    'lhtflsfc_spatial_mean_physics_land',
    'column_integrated_pq1_spatial_mean_physics_global',
    'column_integrated_q1_spatial_mean_physics_global',
    'column_integrated_pq2_spatial_mean_physics_global',
    'column_integrated_q2_spatial_mean_physics_global',
]
# Precipitation-rate variable name in the prognostic runs and in the verification data.
PRECIP_RATE = 'total_precipitation_rate'
PHYSICS_PRECIP = 'PRATEsfc'
# Catalog entry holding the verification dataset.
CATALOG_KEY = '40day_c48_gfsphysics_15min_may2020'
SECONDS_PER_DAY = 24 * 60 * 60
# Output directory for saved figures.
OUTDIR = 'figures'

In [None]:
def abs_time_to_lead_time(ds):
    """Replace the absolute ``time`` dimension of ``ds`` with ``lead_time``.

    Lead time is the elapsed time since the first entry of ``ds.time``,
    truncated to whole minutes and expressed as a float number of days
    (``units: days``). ``time`` becomes a non-dimension coordinate via
    ``swap_dims`` and is then dropped.
    """
    elapsed = ds.time - ds.time[0]
    elapsed_days = elapsed.astype('timedelta64[m]') / np.timedelta64(1440, 'm')
    lead_time = xr.DataArray(elapsed_days, dims=['time'], attrs={'units': 'days'})
    return (
        ds.assign_coords({'lead_time': lead_time})
        .swap_dims({'time': 'lead_time'})
        .drop_vars('time')
    )


def precipitation_bias(run_precip, verif_precip):
    """Return the precipitation-rate bias of a run relative to verification.

    The difference ``run_precip - verif_precip`` is taken after xarray aligns
    the two inputs on their shared coordinates. Its ``time`` dimension is then
    converted to forecast lead time with ``abs_time_to_lead_time``.

    Args:
        run_precip: precipitation rate from the (coarse) prognostic run.
        verif_precip: verification (fine) precipitation rate to subtract.

    Returns:
        DataArray named ``precipitation_bias`` with ``lead_time`` in place of
        ``time``, scaled by ``SECONDS_PER_DAY`` and labelled in mm/d.
    """
    # Use the verification passed in by the caller, not any module-level global.
    bias = abs_time_to_lead_time(run_precip - verif_precip).rename('precipitation_bias')
    # Assumes the inputs are in kg/m^2/s (i.e. mm/s of water) — TODO confirm —
    # so scaling by seconds per day yields mm/d.
    bias = SECONDS_PER_DAY * bias
    return bias.assign_attrs({
        'long_name': 'precipitation bias (coarse minus fine)',
        'units': 'mm/d'
    })

In [None]:
# Verification precipitation rate from the catalog dataset, renamed to the prognostic
# runs' precipitation variable name so the two can be differenced directly.
precip_verif = (
    standardize_fv3_diagnostics(CATALOG[CATALOG_KEY].to_dask())[PHYSICS_PRECIP]
    .rename(PRECIP_RATE)
)

# Coordinates of the new dimensions created by concatenating ensemble members
# and run types; built once and reused for every concat below.
ensemble_dim = xr.DataArray(INITIAL_CONDITIONS_ENSEMBLE, dims=['ensemble'], name='ensemble')
run_type_dim = xr.DataArray(list(RUN_TYPE_ROOT_TEMPLATE.keys()), dims=['run_type'], name='run_type')

run_type_datasets = []
precip_bias_dataarrays = []
for run_type, path_template in RUN_TYPE_ROOT_TEMPLATE.items():
    ensemble_datasets = []
    ensemble_precip_bias_dataarrays = []
    for ic in INITIAL_CONDITIONS_ENSEMBLE:
        # Baseline run directories are named with the date part (YYYYMMDD) only.
        ic_path = path_template.format(ic=ic[:8] if run_type == 'baseline' else ic)
        # Diagnostics are archived under a flat name: the run path with '/' -> '-'.
        nc_path = os.path.join(
            PROGNOSTIC_RUN_DIAGS_ROOT,
            ic_path.replace('/', '-'),
            DIAGS_NC
        )
        zarr_path = os.path.join('gs://', ic_path, DIAGS_ZARR)
        print(f"opening {nc_path}")
        # Load fully into memory, then close both the xarray dataset and the remote file.
        with fsspec.open(nc_path, 'rb') as f, xr.open_dataset(f, engine="h5netcdf") as member_raw:
            member_ds = abs_time_to_lead_time(member_raw.load())
        ensemble_datasets.append(member_ds)
        print(f"opening {zarr_path}")
        member_precip = xr.open_zarr(fsspec.get_mapper(zarr_path), consolidated=True)[PRECIP_RATE]
        member_precip_bias = precipitation_bias(member_precip, precip_verif)
        ensemble_precip_bias_dataarrays.append(member_precip_bias)
    # join='inner' keeps only the coordinate values common to all ensemble members.
    run_type_datasets.append(
        xr.concat(ensemble_datasets, dim=ensemble_dim, join='inner')
        .drop_vars(DROP_VARS, errors='ignore')
    )
    precip_bias_dataarrays.append(
        xr.concat(ensemble_precip_bias_dataarrays, dim=ensemble_dim)
    )
ds = xr.concat(run_type_datasets, dim=run_type_dim)
precip_bias = xr.concat(precip_bias_dataarrays, dim=run_type_dim)

In [None]:
def add_verification(ds):
    """Append a ``verification`` entry along ``run_type`` to ``ds``.

    The verification entry starts as an all-NaN copy of the ``baseline`` slice.
    For each name in ``DRIFT_VARS`` whose ``mean_bias`` counterpart is also in
    ``ds``, it is filled with ``spatial_mean - mean_bias`` taken from the
    baseline run. Drift variables lacking either piece stay NaN.
    """
    verification = xr.full_like(ds.sel(run_type='baseline'), np.nan)
    for mean_name in DRIFT_VARS:
        bias_name = mean_name.replace('spatial_mean', 'mean_bias')
        if mean_name not in ds or bias_name not in ds:
            # Best effort: skip diagnostics that were not computed for these runs.
            continue
        difference = ds[mean_name] - ds[bias_name]
        verification[mean_name] = difference.sel(run_type='baseline')
    verification = verification.assign_coords({'run_type': ['verification']})
    return xr.concat([ds, verification], dim='run_type')

In [None]:
def ensemble_metrics(ds):
    """Reduce ``ds`` over ``ensemble`` to its min, median and max.

    Each statistic is re-expanded along a length-1 ``ensemble`` dimension
    labelled 'min', 'median' or 'max', with attributes kept, and the three are
    merged. NOTE(review): the final order of these labels along ``ensemble`` is
    whatever ``xr.merge``'s alignment produces — confirm before relying on
    positional indexing downstream.
    """
    with xr.set_options(keep_attrs=True):
        summaries = [
            getattr(ds, statistic)(dim='ensemble').expand_dims({'ensemble': [statistic]})
            for statistic in ('min', 'median', 'max')
        ]
    return xr.merge(summaries)

In [None]:
# Keep only lead times up to 10 days so all runs are compared over the same
# forecast window, then add the verification entry and summarize the ensemble.
ds = add_verification(ds.sel(lead_time=slice(None, 10.0)))
ensemble_metrics_ds = ensemble_metrics(ds)

In [None]:
# Global RMSE diagnostics, relabelled "<long_name> RMSE [<units>]" for panel titles
# (falling back to the variable-name prefix when there is no long_name).
RMS_VARS = [var for var in ensemble_metrics_ds.data_vars if 'rms_global' in var]
rename_dict = {
    var: (
        f"{ensemble_metrics_ds[var].attrs.get('long_name', var.split('_rms_global')[0])} "
        f"RMSE [{ensemble_metrics_ds[var].attrs.get('units', '')}]"
    )
    for var in RMS_VARS
}
ds_rms = (
    ensemble_metrics_ds[RMS_VARS]
    .rename(rename_dict)
    .to_array(dim='variable_names', name='rms_variables')
    .to_dataset()
    # Plot every configured run type; the 'verification' entry is excluded because
    # add_verification leaves it NaN for these variables. Deriving the list from
    # RUN_TYPE_ROOT_TEMPLATE keeps it in sync when run types are added.
    .sel({'run_type': list(RUN_TYPE_ROOT_TEMPLATE)})
)

In [None]:
def plot_ensemble(arr, lead_time, ax=None):
    """Plot one line per run type with a shaded band around it.

    Args:
        arr: 3-D array indexed as [statistic, run_type, lead_time] — presumably
            the ensemble min/median/max from ``ensemble_metrics``. Row 1 is drawn
            as lines; rows 0 and 2 bound the shaded band for each run type.
        lead_time: x values shared by all lines and bands.
        ax: axes to draw on; defaults to the current axes.

    Returns:
        The list of line handles returned by ``ax.plot``.
    """
    target = plt.gca() if ax is None else ax
    handles = target.plot(lead_time, arr[1].T, lw=1.5)
    for lower, upper in zip(arr[0], arr[2]):
        target.fill_between(lead_time, lower, upper, alpha=0.25)
    target.grid()
    return handles

In [None]:
def plot_all_vars(ds, varname, fig_size=(9, 12)):
    """Facet-plot every entry of ``ds[varname]`` against lead time and save it.

    One panel per ``variable_names`` value (3 per row, independent y axes).
    ``plot_ensemble`` draws one line plus a shaded band per run type in each
    panel. A run-type legend goes on the first panel. The figure is saved as
    ``{OUTDIR}/prognostic-ic-{varname}.png``.

    Args:
        ds: dataset with ``variable_names`` and ``run_type`` dimensions, a
            ``lead_time`` coordinate (days) and data variable ``varname``.
        varname: data variable to plot; also used in the output file name.
        fig_size: figure (width, height) in inches. Any 2-sequence works.
    """
    fg = xr.plot.FacetGrid(data=ds, col='variable_names', col_wrap=3, sharey=False)
    fg.map(plot_ensemble, varname, lead_time=ds.lead_time)
    ax = fg.axes[0, 0]
    n_lines = ds.sizes['run_type']
    # plot_ensemble draws one line per run type first. Use those lines as legend
    # handles explicitly instead of slicing ax.get_children(), whose ordering of
    # lines vs. collections differs between matplotlib versions.
    ax.legend(ax.lines[:n_lines], ds.run_type.values, fontsize='x-small', loc=4)
    # Lead-time axis fixed to the 10-day window used throughout the notebook.
    ax.set_xticks(np.arange(11))
    ax.set_xlim([0, 10])
    fg.set_titles(template='{value}', maxchar=50)
    fg.set_ylabels('')
    fg.set_xlabels('Forecast lead time [days]')
    fg.fig.set_size_inches(fig_size)
    fg.fig.set_dpi(150)
    fg.fig.tight_layout()
    # savefig does not create missing directories.
    os.makedirs(OUTDIR, exist_ok=True)
    fg.fig.savefig(f'{OUTDIR}/prognostic-ic-{varname}.png', bbox_inches='tight', facecolor='white')

In [None]:
# Ensemble RMSE time series for each configured run type.
plot_all_vars(ds_rms, varname='rms_variables')

In [None]:
# Drift diagnostics, labelled by variable name with the spatial-mean source
# suffix and any "_global" removed, followed by their units.
rename_dict = {
    var: (
        var.replace('_spatial_mean_dycore', '')
        .replace('_spatial_mean_physics', '')
        .replace('_global', '')
        + f" [{ensemble_metrics_ds[var].attrs.get('units', '')}]"
    )
    for var in DRIFT_VARS
}
ds_drift = (
    ensemble_metrics_ds[DRIFT_VARS]
    .rename(rename_dict)
    .to_array(dim='variable_names', name='drift_variables')
    .to_dataset()
)

In [None]:
# Drift time series, including the verification entry, on a wider figure.
plot_all_vars(ds_drift, varname='drift_variables', fig_size=(12, 9))

In [None]:
# Average the precipitation bias over lead times up to 10 days, keeping its attrs.
precip_bias = (
    precip_bias
    .sel(lead_time=slice(None, 10.0))
    .mean(dim='lead_time', keep_attrs=True)
)

In [None]:
# C48 grid variables (cell-center/corner lon/lat and cell area) from the catalog.
GRID = CATALOG['grid/c48'].to_dask()
# Dimension names passed to fv3viz.mappable_var: cell-center and cell-corner
# (interface) dims, and the dimension order for each grid coordinate variable.
MAPPABLE_VAR_KWARGS = dict(
    coord_x_center="x",
    coord_y_center="y",
    coord_x_outer="x_interface",
    coord_y_outer="y_interface",
    coord_vars={
        **{corner: ["y_interface", "x_interface", "tile"] for corner in ("lonb", "latb")},
        **{center: ["y", "x", "tile"] for center in ("lon", "lat")},
    },
)

def _var_rms(bias, area):
    """Return the area-weighted RMS of ``bias`` over the ``x``, ``y``, ``tile`` dims.

    Weights are ``area`` divided by its overall mean. When ``area`` spans exactly
    those three dims and nothing is missing, this equals
    sqrt(sum(area * bias**2) / sum(area)).
    """
    normalized_area = area / area.mean()
    mean_square = ((bias**2) * normalized_area).mean(dim=['x', 'y', 'tile'])
    return np.sqrt(mean_square)

def _var_mean(bias, area):
    """Return the area-weighted mean of ``bias`` over the ``x``, ``y``, ``tile`` dims.

    Weights are ``area`` divided by its overall mean. When ``area`` spans exactly
    those three dims and nothing is missing, this equals
    sum(area * bias) / sum(area).
    """
    normalized_area = area / area.mean()
    weighted = normalized_area * bias
    return weighted.mean(dim=['x', 'y', 'tile'])

In [None]:
var = 'precipitation_bias'
# Attach the C48 grid so fv3viz can map the time-mean bias.
precip_bias_plottable = fv3viz.mappable_var(
    xr.merge([precip_bias, GRID]), var, **MAPPABLE_VAR_KWARGS
)
cell_area = GRID['area']
# Evaluate eagerly (showing a dask progress bar): the mappable bias, plus its
# area-weighted RMS and mean for every remaining non-horizontal index.
with ProgressBar():
    precip_bias_plottable = precip_bias_plottable.load()
    bias_field = precip_bias_plottable[var]
    rms_ = _var_rms(bias_field, cell_area).load()
    mean_ = _var_mean(bias_field, cell_area).load()

In [None]:
# Maps of the lead-time-mean precipitation bias: one row per ensemble member and
# one column per run type. Each panel title reports its area-weighted RMS and mean.
_, _, _, _, fg = fv3viz.plot_cube(
    precip_bias_plottable,
    col='run_type',
    row='ensemble',
    vmin=-10
)
fg.set_titles('{value}')
for i, row in enumerate(fg.axes):
    for j, ax in enumerate(row):
        # Select the panel's statistics by dimension name: facet rows follow
        # `ensemble` and columns follow `run_type`. The dimension order of
        # rms_/mean_ is inherited from the bias array and need not be
        # (ensemble, run_type), so positional .values[i, j] is unsafe.
        panel = {'ensemble': i, 'run_type': j}
        # Only the top row carries the run-type name in its title.
        run_type = precip_bias_plottable.run_type.isel({'run_type': j}).item() if i == 0 else ''
        ax.set_title(
            f"{run_type}\nrms: {rms_.isel(panel).item():3.2f}, "
            f"bias: {mean_.isel(panel).item():3.2f}"
        )
fg.fig.set_size_inches([12, 10])
fg.fig.set_dpi(150)
# savefig does not create missing directories.
os.makedirs(OUTDIR, exist_ok=True)
fg.fig.savefig(f'{OUTDIR}/prognostic-ic-{var}.png', bbox_inches='tight', facecolor='white')