### Nudged training data surface biases

Make figures demonstrating surface biases in the N2F nudged training runs (with and without prescribed surface radiation and precipitation).

In [None]:
import xarray as xr
import numpy as np
import fsspec
from matplotlib import pyplot as plt
import matplotlib
matplotlib.rcParams.update({'font.size': 14})
from vcm.catalog import catalog
from vcm.fv3.metadata import standardize_fv3_diagnostics
from dask.diagnostics import ProgressBar
import os

In [None]:
# Output locations of the runs being compared, keyed by case name.
# The insertion order here is the order of the 'case' coordinate once the
# per-case datasets are concatenated.
CASES = {
    'free_running': 'gs://vcm-ml-experiments/2021-04-13/baseline-physics-run-20160801.010000-start-rad-step-1800s',
    'nudged': 'gs://vcm-ml-experiments/2021-04-13-n2f-c3072/3-hrly-ave-control-30-min-rad-timestep-shifted-start-tke-edmf',
    'nudged_prescribed_sfc_rad_precip': 'gs://vcm-ml-experiments/2021-04-13-n2f-c3072/3-hrly-ave-rad-precip-setting-30-min-rad-timestep-shifted-start-tke-edmf',
}

# Directory the figures are written to.
OUTPUTDIR = '.'

In [None]:
# Maps variable names as they appear in the run diagnostics (keys) to the
# display names (values) they are stored under after open_diags(); the display
# names are later used as plot titles and in output file names.
plot_vars = {
    'total_precip_to_surface': 'precipitation rate',
    'LHTFLsfc': 'latent heat flux',
    'SHTFLsfc': 'sensible heat flux',
    'DSWRFsfc': 'sfc. downward SW',
    'DLWRFsfc': 'sfc. downward LW'
}

In [None]:
# Grid-cell area and land/sea mask, read from the 'grid/c48' and
# 'landseamask/c48' catalog entries; used below to form area-weighted means
# restricted to cells where the mask equals 1.0.
area = catalog['grid/c48'].to_dask()['area']
mask = catalog['landseamask/c48'].to_dask()['land_sea_mask']

def land_mean(da, mask, area):
    """Return the area-weighted mean of ``da`` over cells where ``mask == 1.0``.

    Args:
        da: data to average; must carry the dimensions that
            ``weighted_mean`` reduces over.
        mask: array compared against 1.0 to select the cells that contribute
            (presumably 1.0 marks land -- confirm against the mask's source).
        area: per-cell weights; set to NaN wherever ``mask`` is not 1.0.

    Returns:
        The weighted mean, carrying a copy of ``da``'s attributes.
    """
    land_area = area.where(mask == 1.0)
    return weighted_mean(da, land_area).assign_attrs(da.attrs)
    
def weighted_mean(var, weights):
    """Return the weighted mean of ``var`` over the 'x', 'y' and 'tile' dimensions.

    The result is ``sum(weights * var) / sum(weights)``, where both sums run
    only over cells in which ``var`` and ``weights`` are both non-NaN. NaN
    weights therefore act as a mask, and missing values in ``var`` are
    excluded from the normalization as well as from the numerator, so they
    do not bias the mean.

    Args:
        var: data to average; must have 'x', 'y' and 'tile' dimensions.
        weights: weights broadcastable against ``var``; NaN marks cells to
            exclude.

    Returns:
        ``var`` reduced over 'x', 'y' and 'tile'. Slices with no valid cell
        evaluate to NaN (0 / 0).
    """
    dims = ['x', 'y', 'tile']
    valid = var.notnull() & weights.notnull()
    # sum() skips NaN, so cells masked out of either operand contribute 0 here
    weighted_sum = (weights * var).sum(dim=dims)
    # normalize by the weight of exactly those cells that entered the numerator
    total_weight = weights.where(valid).sum(dim=dims)
    return weighted_sum / total_weight

In [None]:
# Factor that converts a per-second rate into a per-day rate.
SECONDS_PER_DAY = 86400


def _total_precip_to_surface(ds: xr.Dataset) -> xr.DataArray:
    """Build the 'total_precip_to_surface' variable from a diagnostics dataset.

    Multiplies ``ds.total_precipitation_rate`` by ``SECONDS_PER_DAY``, names
    the result 'total_precip_to_surface', and replaces its attributes with a
    fixed ``long_name`` and units of 'mm/day'.

    NOTE(review): labelling the product as mm/day presumes the input rate is
    in mm/s (kg/m^2/s) -- confirm against the diagnostics metadata.
    """
    per_day = (ds.total_precipitation_rate * SECONDS_PER_DAY).rename(
        "total_precip_to_surface"
    )
    per_day.attrs = {
        "long_name": "total precip to surface (max(PRATE-<dQ2>-<nQ2>, 0))",
        "units": "mm/day",
    }
    return per_day



In [None]:
def open_diags(path, rename):
    """Open a run's diagnostics and return only the requested variables, renamed.

    Reads the consolidated zarr stores 'sfc_dt_atmos.zarr' and 'diags.zarr'
    found under ``path``, passes each through ``standardize_fv3_diagnostics``,
    and merges them together with the derived 'total_precip_to_surface'
    variable computed from the 'diags.zarr' dataset.

    Args:
        path: directory or URL of the run output.
        rename: mapping from variable name in the merged dataset to the name
            it should have in the returned dataset.

    Returns:
        An ``xr.Dataset`` holding one variable per entry of ``rename``.
    """
    def _open_standardized(zarr_name):
        # one consolidated zarr store under ``path``, standardized
        store = fsspec.get_mapper(os.path.join(path, zarr_name))
        return standardize_fv3_diagnostics(xr.open_zarr(store, consolidated=True))

    surface = _open_standardized('sfc_dt_atmos.zarr')
    diagnostics = _open_standardized('diags.zarr')
    merged = xr.merge([surface, diagnostics, _total_precip_to_surface(diagnostics)])

    return xr.Dataset({new: merged[old] for old, new in rename.items()})


def add_units(ds):
    """Set a 'units' attribute on every data variable of ``ds`` and return ``ds``.

    Variables whose name contains 'precipitation' get 'mm/day'; every other
    variable gets 'W/m^2'. The attributes are updated in place on ``ds``;
    other existing attributes are left untouched.
    """
    for name in ds.data_vars:
        units = 'mm/day' if 'precipitation' in name else 'W/m^2'
        ds[name].attrs['units'] = units
    return ds


# Open every case's diagnostics, reduce them to daily means (bins labelled by
# their left edge), and tag each dataset with its case name so the cases can
# be stacked along a new 'case' dimension.
datasets = []
for case, url in CASES.items():
    print(case)
    ds = open_diags(url, plot_vars).resample(time="1D", label="left").mean()
    datasets.append(ds.assign_coords({'case': case}))
run_ds_2d = xr.concat(datasets, dim='case')

# Reduce each variable with land_mean (area-weighted mean where mask == 1.0),
# then attach units.
run_ds = {}
for var in run_ds_2d.data_vars:
    run_ds[var] = land_mean(run_ds_2d[var], mask, area)
run_ds = add_units(xr.Dataset(run_ds))

# Load the reduced values into memory, reporting dask progress.
with ProgressBar():
    run_ds = run_ds.load()

In [None]:
# Maps variable names in the verification dataset (keys) to the display names
# (values) used for the run datasets, so the two can be concatenated.
# 'precipitation rate' maps to itself because add_precipitation_rate() already
# creates it under that name.
verif_rename = {
    'precipitation rate': 'precipitation rate',
    'LHTFLsfc': 'latent heat flux',
    'SHTFLsfc': 'sensible heat flux',
    'DSWRFsfc': 'sfc. downward SW',
    'DLWRFsfc': 'sfc. downward LW'
}

def rename_verif(ds, rename):
    """Return a new dataset holding only the variables named in ``rename``.

    Args:
        ds: dataset to select from.
        rename: mapping from variable name in ``ds`` to the name it should
            have in the returned dataset.

    Returns:
        An ``xr.Dataset`` with one variable per entry of ``rename``.
    """
    return xr.Dataset({new: ds[old] for old, new in rename.items()})

def add_precipitation_rate(ds):
    """Return a new dataset with 'PRATEsfc' replaced by 'precipitation rate'.

    'precipitation rate' is 'PRATEsfc' multiplied by 86400, with its
    attributes replaced by a ``long_name`` and units of 'mm/day'. The input
    dataset is not modified.

    NOTE(review): the 'mm/day' label presumes 'PRATEsfc' is in mm/s
    (kg/m^2/s) -- confirm against the verification dataset's metadata.

    Args:
        ds: dataset containing a 'PRATEsfc' variable.

    Returns:
        A dataset with 'precipitation rate' added and 'PRATEsfc' dropped.
    """
    seconds_per_day = 86400
    precipitation_rate = ds['PRATEsfc'] * seconds_per_day
    precipitation_rate.attrs = {'long_name': 'precipitation rate', 'units': 'mm/day'}
    # assign() returns a new dataset, so the caller's ``ds`` is left untouched
    return ds.assign({'precipitation rate': precipitation_rate}).drop_vars('PRATEsfc')


# Build the verification dataset with the same processing as the runs:
# standardize, derive the precipitation rate, average to daily means (bins
# labelled by their left edge), and keep only the renamed plotting variables.
verif_ds_2d = standardize_fv3_diagnostics(catalog['40day_c48_gfsphysics_15min_may2020'].to_dask())
verif_ds_2d = add_precipitation_rate(verif_ds_2d)
verif_ds_2d = verif_ds_2d.resample(time="1D", label="left").mean()
verif_ds_2d = rename_verif(verif_ds_2d, verif_rename)
# Reduce each variable with land_mean (area-weighted mean where mask == 1.0),
# then attach units.
verif_ds = {}
for var in verif_ds_2d.data_vars:
    verif_ds[var] = land_mean(verif_ds_2d[var], mask, area)
verif_ds = add_units(xr.Dataset(verif_ds))

# Load the reduced values into memory, reporting dask progress.
with ProgressBar():
    verif_ds = verif_ds.load()

In [None]:
# Append the verification data as a fourth 'case'. join='inner' keeps only the
# coordinate values (e.g. times) that the runs and the verification share.
combined_ds = xr.concat([run_ds, verif_ds.assign_coords({'case': 'verification'})], dim='case', join='inner')

In [None]:
# Legend entries, listed in the order the cases are expected to appear along
# the 'case' dimension (the three CASES entries, then the verification).
labels = ['baseline', 'nudged', 'nudged w/ prescribed rad.+precip.', 'coarsened fine-res verif.']

def case_comparison_ts(da):
    """Plot one variable as a time series with one line per case and save it as a PNG.

    The figure is written to ``OUTPUTDIR`` as '<da.name>_land_bias_TS.png'.
    Only the 'sfc. downward SW' panel carries a legend, built from the
    module-level ``labels`` list; the other panels are drawn without one.

    Args:
        da: DataArray with 'case' and 'time' dimensions. Its name is used as
            the title and in the file name; its 'units' attribute (if any)
            labels the y axis. Needs at least four time values because the
            x-axis upper limit is the fourth-from-last time.
    """
    # Size and resolution are fixed at creation so that tight_layout() below
    # is computed for the geometry the figure is actually saved with.
    fig, ax = plt.subplots(1, 1, figsize=(5, 6), dpi=150)
    if da.name == 'sfc. downward SW':
        h = da.plot(ax=ax, hue='case', lw=2)
        # replace the automatic legend with the human-readable case labels
        ax.legend(h, labels)
    else:
        da.plot(ax=ax, hue='case', lw=2, add_legend=False)
    ax.set_xlabel(None)
    ax.set_ylabel(da.attrs.get('units'))
    # the upper limit leaves the last three time values out of view
    ax.set_xlim([da.time.values[0], da.time.values[-4]])
    ax.set_ylim([0, np.nanmax(da.values)*1.1])
    ax.set_title(da.name)
    ax.grid(axis='y')
    # rotate tick labels on this axes explicitly rather than via pyplot's
    # "current axes" state
    ax.tick_params(axis='x', labelrotation=45)
    fig.tight_layout()
    fig.savefig(os.path.join(OUTPUTDIR, f"{da.name}_land_bias_TS.png"), bbox_inches='tight')

In [None]:
# Draw and save one time-series figure per variable.
for var in combined_ds.data_vars:
    case_comparison_ts(combined_ds[var])