### Prognostic surface radiation evaluation

Compare physics- and ML-derived surface radiative fluxes against the coarsened high-resolution verification data across several prognostic runs.

In [None]:
import intake
import xarray as xr
import numpy as np
import fsspec
import os
from typing import Optional
import cftime
from dask.diagnostics import ProgressBar
from matplotlib import pyplot as plt
import matplotlib
from cycler import cycler
matplotlib.rcParams['axes.prop_cycle'] = cycler(color=['r', 'b', 'k'])
from vcm.catalog import catalog as CATALOG
from vcm.fv3.metadata import standardize_fv3_diagnostics
import fv3viz

In [None]:
# Runs to evaluate, as keyword arguments for ``Fv3gfsRun``.
# The verification entry (``verif=True``, a catalog key rather than a URL)
# must come first: later runs compute their biases against it.
RUN_CONFIGS = [
    dict(
        name='coarsened-fine-res',
        path='40day_c48_gfsphysics_15min_may2020',
        verif=True,
    ),
    dict(
        name='baseline',
        path='gs://vcm-ml-experiments/2021-04-13/baseline-physics-run-20160805-start-rad-step-1800s',
    ),
    dict(
        name='NN-ensemble-NN-rad',
        path='gs://vcm-ml-experiments/2021-05-11-nudge-to-c3072-corrected-winds/nn-ensemble-model/initial_conditions_runs/20160805.000000_with_RRTMG',
    ),
    dict(
        name='NN-ensemble-RF-rad',
        path='gs://vcm-ml-experiments/2021-05-11-nudge-to-c3072-corrected-winds/nn-ensemble-model/initial_conditions_runs_rf_rad/20160805.000000',
    ),
    dict(
        name='RF',
        path='gs://vcm-ml-experiments/2021-05-11-nudge-to-c3072-corrected-winds/rf/initial_conditions_runs/20160805.000000_with_RRTMG',
    ),
]

In [None]:
# Static C48 fields shared by every run, loaded lazily from the catalog.
GRID = CATALOG['grid/c48'].to_dask()
AREA = GRID['area']  # cell areas, used as weights for land averages
MASK = (
    CATALOG['landseamask/c48']
    .to_dask()
    ['land_sea_mask']
)  # compared against 1.0 to select land points
# Surface flux diagnostics to evaluate: for each flux, the ``_from_RRTMG``
# variant is listed immediately before the flux itself.
SFC_FLUX_VARS = [
    flux + suffix
    for flux in ('DLWRFsfc', 'DSWRFsfc', 'NSWRFsfc')
    for suffix in ('_from_RRTMG', '')
]
# Dimension/coordinate names handed to ``fv3viz.mappable_var``.
MAPPABLE_VAR_KWARGS = dict(
    coord_x_center="x",
    coord_y_center="y",
    coord_x_outer="x_interface",
    coord_y_outer="y_interface",
    coord_vars=dict(
        lonb=["y_interface", "x_interface", "tile"],
        latb=["y_interface", "x_interface", "tile"],
        lon=["y", "x", "tile"],
        lat=["y", "x", "tile"],
    ),
)
# Directory (relative to the working directory) that figures are saved into.
OUTDIR = 'figures'

In [None]:
def weighted_average(da, weights, dims):
    """Return the ``weights``-weighted mean of ``da`` over ``dims``.

    Weights are discarded wherever ``da`` is missing, so NaNs in ``da`` are
    excluded from the denominator as well as the numerator instead of pulling
    the average toward zero. This relies on xarray's default NaN-skipping
    ``sum`` for floating-point data, which also lets NaN weights (e.g. points
    outside a mask) drop out.
    """
    weights = weights.where(da.notnull())
    return (da * weights).sum(dim=dims) / weights.sum(dim=dims)


def weighted_rms(da, weights, dims):
    """Return the ``weights``-weighted root-mean-square of ``da`` over ``dims``."""
    return np.sqrt(weighted_average(da**2, weights, dims))

In [None]:
def time_mean(ds, area, mask, dims=None):
    """Time-average every variable in ``ds`` and summarise it over land.

    For each variable ``v`` the result holds ``v_time_mean`` (the time mean,
    attrs kept) plus its area-weighted land mean ``v_time_land_mean`` and land
    RMS ``v_time_land_rms``. Land points are those where ``mask == 1.0``.

    Parameters
    ----------
    ds : dataset whose variables have a ``time`` dimension.
    area : cell areas used as weights.
    mask : land-sea mask; ``1.0`` marks land.
    dims : spatial dimensions reduced for the land statistics; defaults to
        ``['x', 'y', 'tile']``.
    """
    if dims is None:
        dims = ['x', 'y', 'tile']
    land = mask == 1.0
    land_area = area.where(land)
    summary = {}
    for var in ds:
        var_time_mean = ds[var].mean(dim='time', keep_attrs=True)
        land_da = var_time_mean.where(land)
        summary[var + '_time_mean'] = var_time_mean
        summary[var + '_time_land_mean'] = weighted_average(land_da, land_area, dims=dims)
        summary[var + '_time_land_rms'] = weighted_rms(land_da, land_area, dims=dims)
    return xr.Dataset(summary)

In [None]:
def land_rms(ds, area, mask, dims=None):
    """Return the area-weighted RMS over land of every variable in ``ds``.

    Output variables are named ``<var>_rms`` and carry the input variable's
    attrs. Land points are those where ``mask == 1.0``. Only ``dims``
    (default ``['x', 'y', 'tile']``) are reduced; any other dimension of the
    inputs is kept.
    """
    if dims is None:
        dims = ['x', 'y', 'tile']
    land = mask == 1.0
    land_area = area.where(land)
    land_rmse = {}
    for var in ds:
        da = ds[var].where(land)
        land_rmse[var + '_rms'] = weighted_rms(da, land_area, dims).assign_attrs(**ds[var].attrs)
    return xr.Dataset(land_rmse)

In [None]:
def land_min_mean_max(ds, area, mask, dims=None):
    """Return land minimum, area-weighted mean and maximum of each variable.

    Land points are those where ``mask == 1.0``; ``dims`` (default
    ``['x', 'y', 'tile']``) are reduced. The three statistics are stacked
    along a new ``metric`` dimension labelled ``'min'``, ``'mean'``, ``'max'``,
    and each variable keeps the attrs of its input.
    """
    if dims is None:
        dims = ['x', 'y', 'tile']
    land = mask == 1.0
    land_area = area.where(land)
    land_min, land_mean, land_max = {}, {}, {}
    for var in ds:
        da = ds[var].where(land)
        land_min[var] = da.min(dim=dims, keep_attrs=True)
        land_mean[var] = weighted_average(da, land_area, dims).assign_attrs(**ds[var].attrs)
        land_max[var] = da.max(dim=dims, keep_attrs=True)
    metric = xr.DataArray(['min', 'mean', 'max'], dims='metric', name='metric')
    return xr.concat(
        [xr.Dataset(stats) for stats in (land_min, land_mean, land_max)],
        dim=metric,
    )

In [None]:
class Fv3gfsRun:
    """Surface radiative fluxes and their land statistics for one FV3GFS run.

    Parameters
    ----------
    name : label used in plot legends and figure filenames.
    path : a ``CATALOG`` key when ``verif`` is True, otherwise the run's
        output directory, which must contain ``sfc_dt_atmos.zarr``.
    verif : whether this run is the verification dataset. Verification runs
        never receive ``*_bias`` variables.
    """

    def __init__(self, name: str, path: str, verif: bool = False):
        self.name = name
        self._path = path
        self.verif = verif
        self.flux_data = self._open_fluxes()
        # Filled by ``compute_land_stats``; ``None`` until then, so early
        # access gives ``None`` rather than an AttributeError.
        self.stats: Optional[xr.Dataset] = None
        # Kept for backward compatibility; mirrors ``stats`` once computed.
        self.land_flux_stats: Optional[xr.Dataset] = None

    def _open_fluxes(self):
        """Open the surface diagnostics and return the derived flux dataset.

        Data are standardized with ``standardize_fv3_diagnostics`` and
        resampled to 3-hourly (nearest value, right-labelled) so that all runs
        share the same time axis.
        """
        if self.verif:
            ds = CATALOG[self._path].to_dask()
        else:
            path = os.path.join(self._path, 'sfc_dt_atmos.zarr')
            ds = intake.open_zarr(path, consolidated=True).to_dask()
        ds = standardize_fv3_diagnostics(ds).resample({'time': '3H'}, label='right').nearest()
        return self._derive_fluxes(ds)

    def add_biases(self, verif_ds: xr.Dataset, flux_varnames=SFC_FLUX_VARS):
        """Merge ``<var>_bias`` (run minus verification) variables into ``flux_data``.

        ``*_from_RRTMG`` variables are compared against the verification
        variable without that suffix. If this run lacks ``var``, its bias is
        filled with NaN, so a missing diagnostic is not reported as a perfect
        (zero) bias. Does nothing for the verification run itself.
        """
        if self.verif:
            return
        biases = {}
        for var in flux_varnames:
            verif_da = verif_ds[var.replace('_from_RRTMG', '')]
            if var in self.flux_data.data_vars:
                run_da = self.flux_data[var]
                bias = run_da - verif_da
                source_attrs = run_da.attrs
            else:
                bias = xr.full_like(verif_da, np.nan)
                source_attrs = verif_da.attrs
            biases[var + '_bias'] = bias.assign_attrs({
                'long_name': source_attrs.get('long_name', var) + ' bias',
                'units': source_attrs.get('units'),
            })
        self.flux_data = xr.merge([self.flux_data, xr.Dataset(biases)])

    def compute_land_stats(self, area: xr.DataArray = AREA, mask: xr.DataArray = MASK):
        """Compute land statistics of ``flux_data`` and load them into ``self.stats``.

        Raw flux variables get land min/mean/max series. ``*_bias`` variables
        get a land RMS series plus time-mean maps with their land mean and RMS.
        ``land_flux_stats`` is set to the same loaded dataset.
        """
        data_vars = list(self.flux_data.data_vars)
        # List comprehensions keep the dataset's variable order; a set
        # difference would make it vary between interpreter sessions.
        bias_vars = [var for var in data_vars if '_bias' in var]
        min_mean_max_vars = [var for var in data_vars if '_bias' not in var]
        stats = xr.merge([
            land_min_mean_max(self.flux_data[min_mean_max_vars], area, mask),
            land_rms(self.flux_data[bias_vars], area, mask),
            time_mean(self.flux_data[bias_vars], area, mask)
        ], join='inner')

        with ProgressBar():
            self.stats = stats.load()
        self.land_flux_stats = self.stats

    @staticmethod
    def _derive_fluxes(ds, flux_varnames=SFC_FLUX_VARS):
        """Derive net shortwave fluxes and return the ``flux_varnames`` present.

        ``NSWRFsfc = DSWRFsfc - USWRFsfc`` is always derived. The
        ``_from_RRTMG`` counterpart is derived only when both of its inputs
        exist. Names in ``flux_varnames`` that are absent are skipped. ``ds``
        itself is not modified.
        """
        net_sw_attrs = {
            'long_name': 'net shortwave radiative flux at surface',
            'units': 'W/m^2'
        }
        derived = {
            'NSWRFsfc': (ds['DSWRFsfc'] - ds['USWRFsfc']).assign_attrs(net_sw_attrs)
        }
        if 'DSWRFsfc_from_RRTMG' in ds.data_vars and 'USWRFsfc_from_RRTMG' in ds.data_vars:
            derived['NSWRFsfc_from_RRTMG'] = (
                ds['DSWRFsfc_from_RRTMG'] - ds['USWRFsfc_from_RRTMG']
            ).assign_attrs(net_sw_attrs)
        ds = ds.assign(derived)
        return xr.Dataset({var: ds[var] for var in flux_varnames if var in ds.data_vars})

In [None]:
# Open every configured run, compute its biases against the verification run
# and its land statistics. The verification run must precede the others.
runs = []
verif_ds = None
for run_config in RUN_CONFIGS:
    print(run_config['name'])
    fv3gfs_run = Fv3gfsRun(**run_config)
    if fv3gfs_run.verif:
        verif_ds = fv3gfs_run.flux_data
    elif verif_ds is None:
        raise ValueError(
            f"Run {fv3gfs_run.name!r} needs a verification run (verif=True) "
            "listed before it in RUN_CONFIGS"
        )
    else:
        fv3gfs_run.add_biases(verif_ds)
    fv3gfs_run.compute_land_stats()
    runs.append(fv3gfs_run)

In [None]:
def plot_min_mean_max(da, label, ax, alpha=0.25):
    """Plot the mean of ``da`` as a line with a shaded min-to-max envelope.

    ``da`` must have a ``metric`` dimension labelled ``'min'``, ``'mean'`` and
    ``'max'`` (as built by ``land_min_mean_max``) and a ``time`` coordinate.
    """
    lines = da.sel(metric='mean').plot(ax=ax, label=label)
    # Shade with the mean line's own colour: matplotlib keeps separate colour
    # cycles for lines and filled patches, so relying on the cycle would only
    # match by coincidence.
    ax.fill_between(
        da.time.values,
        da.sel(metric='min'),
        da.sel(metric='max'),
        alpha=alpha,
        color=lines[0].get_color(),
    )

In [None]:
def plot_var(run_list, varname):
    """Plot land min/mean/max time series of ``varname`` for each run.

    A run's ``<varname>_from_RRTMG`` piggyback diagnostic is drawn too when
    present. The figure is saved to
    ``OUTDIR/prognostic-<varname>-drift-<last run name>.png``, and ``OUTDIR``
    is created if needed.

    Raises
    ------
    ValueError : if ``run_list`` is empty.
    """
    if not run_list:
        raise ValueError("run_list must contain at least one run")
    fig, ax = plt.subplots(1, 1)
    for fv3gfs_run in run_list:
        plot_min_mean_max(fv3gfs_run.stats[varname], fv3gfs_run.name, ax)
        rrtmg_var = varname + '_from_RRTMG'
        if rrtmg_var in fv3gfs_run.stats.data_vars:
            plot_min_mean_max(fv3gfs_run.stats[rrtmg_var], fv3gfs_run.name + ' (piggyback physics)', ax)
    # Units and the filename come explicitly from the last run rather than
    # from the loop variable leaking out of the loop above.
    last_run = run_list[-1]
    ax.set_xlim([cftime.DatetimeJulian(2016, 8, 5, 0, 0, 0), cftime.DatetimeJulian(2016, 9, 10, 0, 0, 0)])
    ax.set_ylabel(last_run.stats[varname].attrs.get('units', 'W/m^2'))
    ax.set_title(varname)
    ax.legend()
    ax.grid()
    fig.set_size_inches([8, 5])
    fig.set_dpi(100)
    fig.tight_layout()
    os.makedirs(OUTDIR, exist_ok=True)
    fig.savefig(f"{OUTDIR}/prognostic-{varname}-drift-{last_run.name}.png", facecolor='w', bbox_inches='tight')

In [None]:
# Verification (runs[0]) vs. baseline (runs[1]).
plot_var([runs[0], runs[1]], varname='DSWRFsfc')

In [None]:
# Verification (runs[0]) vs. NN-ensemble-NN-rad (runs[2]).
plot_var([runs[0], runs[2]], varname='DSWRFsfc')

In [None]:
# Verification (runs[0]) vs. NN-ensemble-RF-rad (runs[3]).
plot_var([runs[0], runs[3]], varname='DSWRFsfc')

In [None]:
# Verification (runs[0]) vs. RF (runs[4]).
plot_var([runs[0], runs[4]], varname='DSWRFsfc')

In [None]:
def rms_plot(run_list, varname):
    """Plot the land RMS bias of ``varname`` over time for each run.

    Each run contributes two curves: ``<varname>_bias_rms`` (the run's own
    flux) and ``<varname>_from_RRTMG_bias_rms`` (piggybacked RRTMG). The
    colour/linestyle cycle distinguishes two runs. The figure is saved to
    ``OUTDIR/prognostic-<varname>-rms-<run names joined by '-'>.png``, and
    ``OUTDIR`` is created if needed.

    Raises
    ------
    ValueError : if ``run_list`` is empty.
    """
    if not run_list:
        raise ValueError("run_list must contain at least one run")
    # Use the ``varname`` argument, not a global loop variable.
    bias_rms_name = varname + '_bias_rms'
    rrtmg_bias_rms_name = varname + '_from_RRTMG_bias_rms'
    fig, ax = plt.subplots(1, 1)
    ax.set_prop_cycle(cycler(color=['r', 'r', 'b', 'b'], ls=['-', '--', '-', '--']))
    for fv3gfs_run in run_list:
        stats = fv3gfs_run.stats
        ax.plot(stats.time, stats[bias_rms_name], label=fv3gfs_run.name)
        ax.plot(stats.time, stats[rrtmg_bias_rms_name], label=fv3gfs_run.name + ' (piggyback physics)', lw=1)
    ax.set_xlim([cftime.DatetimeJulian(2016, 8, 5, 0, 0, 0), cftime.DatetimeJulian(2016, 9, 10, 0, 0, 0)])
    ax.set_ylabel(run_list[-1].stats[bias_rms_name].attrs.get('units'))
    ax.set_title(f'RMSE of {varname}')
    ax.legend()
    ax.grid()
    fig.set_size_inches([8, 5])
    fig.set_dpi(100)
    fig.tight_layout()
    # Same name as before for two runs, and no IndexError for a single run.
    run_names = '-'.join(fv3gfs_run.name for fv3gfs_run in run_list)
    os.makedirs(OUTDIR, exist_ok=True)
    fig.savefig(f"{OUTDIR}/prognostic-{varname}-rms-{run_names}.png", facecolor='w', bbox_inches='tight')

In [None]:
# for var in ['DLWRFsfc', 'DSWRFsfc', 'NSWRFsfc']:
#     rms_plot([runs[ind] for ind in [1, 3]], var)

In [None]:
# for var in ['DLWRFsfc', 'DSWRFsfc', 'NSWRFsfc']:
#     rms_plot([runs[ind] for ind in [2, 3]], var)

In [None]:
# for var in ['DLWRFsfc', 'DSWRFsfc', 'NSWRFsfc']:
#     rms_plot([runs[ind] for ind in [3, 4]], var)

In [None]:
def _stack_bias_variants(stats, varname, suffix):
    """Stack the actual and piggyback-RRTMG bias statistics of ``varname``.

    Selects ``<varname>_bias<suffix>`` and ``<varname>_from_RRTMG_bias<suffix>``
    from ``stats`` and stacks them along a ``variable`` dimension labelled
    ``'actual'`` and ``'piggyback physics'``; the result is named ``varname``.
    """
    names = [varname + '_bias' + suffix, varname + '_from_RRTMG_bias' + suffix]
    return (
        stats[names]
        .to_array(name=varname)
        .assign_coords({'variable': ['actual', 'piggyback physics']})
    )


def time_mean_plot(run_list, varname, grid=GRID, vmax=75):
    """Map time-mean biases of ``varname``: one row per run, actual vs. piggyback.

    Each panel is titled with its land mean and land RMS bias. The figure is
    saved to ``OUTDIR/prognostic-<varname>-time-mean.png``, and ``OUTDIR`` is
    created if needed.

    Parameters
    ----------
    run_list : runs whose ``stats`` hold the ``*_bias_time_mean``,
        ``*_bias_time_land_mean`` and ``*_bias_time_land_rms`` variables.
    varname : flux name without suffixes, e.g. ``'DSWRFsfc'``.
    grid : grid dataset merged in to supply the coordinates fv3viz needs.
    vmax : passed through to ``fv3viz.plot_cube`` (default 75, the value
        previously hard-coded).
    """
    run_dim = xr.DataArray([run.name for run in run_list], dims='run', name='run')
    da = xr.concat(
        [
            _stack_bias_variants(run.stats, varname, '_time_mean')
            .assign_attrs(**run.stats[varname + '_bias_time_mean'].attrs)
            for run in run_list
        ],
        dim=run_dim,
    )
    mean_da = xr.concat(
        [_stack_bias_variants(run.stats, varname, '_time_land_mean') for run in run_list],
        dim=run_dim,
    )
    rms_da = xr.concat(
        [_stack_bias_variants(run.stats, varname, '_time_land_rms') for run in run_list],
        dim=run_dim,
    )
    _, _, _, _, fg = fv3viz.plot_cube(
        fv3viz.mappable_var(xr.merge([da, grid]), varname, **MAPPABLE_VAR_KWARGS),
        col='variable',
        row='run',
        vmax=vmax
    )
    # NOTE(review): the per-axes titles set below overwrite these; presumably
    # kept for the FacetGrid row labels (run names) — confirm before removing.
    fg.set_titles('{value}')
    for i, row in enumerate(fg.axes):
        for j, ax in enumerate(row):
            name_dict = fg.name_dicts[i, j]
            mean = mean_da.sel(name_dict).item()
            rms = rms_da.sel(name_dict).item()
            ax.set_title(f"{name_dict['variable']}\nland mean: {mean:4.1f}, rms: {rms:4.1f}")
    fg.fig.set_size_inches([12, 13])
    fg.fig.set_dpi(100)
    os.makedirs(OUTDIR, exist_ok=True)
    fg.fig.savefig(f"{OUTDIR}/prognostic-{varname}-time-mean.png", facecolor='w', bbox_inches='tight')

In [None]:
# Time-mean bias maps for every non-verification run (runs[0] is the verification).
for var in ('DSWRFsfc', 'DLWRFsfc'):
    time_mean_plot(run_list=runs[1:], varname=var)