### nudged training surface flux bias maps

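Compare how prescribing radiation and precipitation in the nudged run affects time-mean surface turbulent (sensible and latent heat) flux biases relative to the verification dataset. Surface downwelling radiative flux biases of the nudging-only run are also shown.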

In [None]:
import xarray as xr
import numpy as np
import intake
import cftime
from matplotlib import pyplot as plt
from vcm.catalog import catalog
from vcm.fv3.metadata import standardize_fv3_diagnostics
import fv3viz as viz
import os
from dask.diagnostics import ProgressBar
import warnings

In [None]:
# paths and experiments
#
# CASES maps a short case label (used as the `cases` coordinate value and as the
# default panel title) to the URL of that run's output directory.
CASES = {
    'baseline': 'gs://vcm-ml-experiments/2021-04-13/baseline-physics-run-20160801.010000-start-rad-step-1800s',
    'nudging-only': 'gs://vcm-ml-experiments/2021-04-13-n2f-c3072/3-hrly-ave-control-30-min-rad-timestep-shifted-start-tke-edmf',
    'nudging-prescribed-rad-precip': 'gs://vcm-ml-experiments/2021-04-13-n2f-c3072/3-hrly-ave-rad-precip-setting-30-min-rad-timestep-shifted-start-tke-edmf'
}

# catalog entries for the verification dataset
verification_dycore_entry = '40day_c48_atmos_8xdaily_may2020'
verification_physics_entry = '40day_c48_gfsphysics_15min_may2020'

# directory that all figures are saved to; create it up front so the
# `plt.savefig` calls in later cells do not fail when it does not exist yet
output_dir = 'figures'
os.makedirs(output_dir, exist_ok=True)

In [None]:
# load run datasets
#
# For every case, the dycore, physics and diagnostics zarr stores are each
# standardized and sampled hourly (nearest value, right-labelled bins). They are
# then merged, stripped of `drop_vars`, and tagged with a `cases` coordinate.
# Finally the cases are concatenated with an inner join and only times after
# `analysis_start` are kept.

drop_vars = [
    'column_heating_nudge',
    'column_moistening_nudge',
    'column_mass_tendency_nudge',
    'net_moistening_due_to_nudging',
    'net_heating_due_to_nudging',
    'net_mass_tendency_due_to_nudging',
    'column_integrated_dQu',
    'column_integrated_dQv',
    'net_heating',
    'net_moistening',
    'total_precip',
    'total_precipitation_rate',
    'USWRFsfc_from_RRTMG',
    'DSWRFsfc_from_RRTMG',
    'DLWRFsfc_from_RRTMG'
]


def _open_hourly_run_store(run_url, store):
    """Open zarr ``store`` under ``run_url`` as standardized diagnostics sampled hourly."""
    standardized = standardize_fv3_diagnostics(
        intake.open_zarr(os.path.join(run_url, store)).to_dask()
    )
    return standardized.resample(time="1H", label="right").nearest()


run_datasets = []
for case, url in CASES.items():
    print(f"{case}: {url}")
    per_store = [
        _open_hourly_run_store(url, store)
        for store in ('atmos_dt_atmos.zarr', 'sfc_dt_atmos.zarr', 'diags.zarr')
    ]
    merged = xr.merge(per_store).drop_vars(names=drop_vars, errors='ignore')
    run_datasets.append(merged.assign_coords(cases=[case]))
run_dataset = xr.concat(run_datasets, dim='cases', join='inner')

# keep only output strictly after this time
analysis_start = cftime.DatetimeJulian(2016, 8, 5, 1, 0, 0, 0)
run_dataset = run_dataset.loc[{'time': run_dataset.time > analysis_start}]

In [None]:
# load verification dataset and subsample to hourly
#
# Both catalog entries are standardized and sampled hourly (nearest value,
# right-labelled bins), the same way as the run stores. `TB` is then renamed to
# `TMPlowest`, presumably to match the run-side variable name.


def _open_hourly_verification(entry):
    """Open catalog ``entry`` as standardized diagnostics sampled hourly."""
    standardized = standardize_fv3_diagnostics(catalog[entry].to_dask())
    return standardized.resample(time="1H", label="right").nearest()


verification_dycore = _open_hourly_verification(verification_dycore_entry)
verification_physics = _open_hourly_verification(verification_physics_entry)
verification_dataset = (
    xr.merge([verification_dycore, verification_physics])
    .rename({'TB': 'TMPlowest'})
)

In [None]:
# use common times for both runs and verif
#
# Keep verification times that are strictly after the first run time AND
# present in the run dataset. Bounding only the start, as before, made
# `run_dataset.sel(time=common_time)` raise KeyError whenever the verification
# had times the run lacks (e.g. when it extends past the end of the run).
after_run_start = verification_dataset.time > run_dataset.time[0]
in_run = verification_dataset.time.isin(run_dataset.time.values)
common_time = verification_dataset.time.loc[after_run_start & in_run]
run_dataset = run_dataset.sel(time=common_time)
verification_dataset = verification_dataset.sel(time=common_time)

In [None]:
def add_total_downward_surface_radiative_flux(ds: xr.Dataset) -> xr.Dataset:
    """Return ``ds`` with an added ``DRFsfc`` variable.

    ``DRFsfc`` is the total downward radiative flux at the surface: the sum of
    ``DSWRFsfc`` and ``DLWRFsfc``, labelled with units ``W/m^2``.

    A new dataset is returned and the caller's ``ds`` is left unmodified (the
    previous ``ds[...] = ...`` assignment mutated the input in place).

    Raises:
        KeyError: if ``DSWRFsfc`` or ``DLWRFsfc`` is missing from ``ds``.
    """
    total = (ds['DSWRFsfc'] + ds['DLWRFsfc']).assign_attrs({
        'long_name': 'total downward radiative flux at surface',
        'units': 'W/m^2'
    })
    return ds.assign(DRFsfc=total)

In [None]:
# derive the total downward surface radiative flux for the runs and the verification
run_dataset, verification_dataset = [
    add_total_downward_surface_radiative_flux(dataset)
    for dataset in (run_dataset, verification_dataset)
]

In [None]:
# surface fluxes (turbulent first, then radiative) whose time means are added
# to a dataset as `<name>_time_mean` variables by `add_time_means`
TIME_MEANS = [
    *('LHTFLsfc', 'SHTFLsfc'),            # latent and sensible heat flux
    *('DRFsfc', 'DSWRFsfc', 'DLWRFsfc'),  # total, shortwave, longwave downward radiation
]

def time_mean_variable(da: xr.DataArray) -> xr.DataArray:
    """Average ``da`` over its ``time`` dimension, carrying over ``da.attrs``.

    Keeping the attributes (e.g. ``units``) matters because the plotting cells
    read them back from the ``*_time_mean`` variables.
    """
    return da.mean(dim=['time'], keep_attrs=True)

def add_time_means(ds: xr.Dataset) -> xr.Dataset:
    """Attach a ``<name>_time_mean`` variable for every name in ``TIME_MEANS``.

    Names that are absent from ``ds`` are reported with a printed message and
    skipped. If none are present, ``ds`` itself is returned unchanged.
    """
    means = {}
    for name in TIME_MEANS:
        if name not in ds:
            print(f"Variable could not be computed: {name}")
            continue
        means[f'{name}_time_mean'] = time_mean_variable(ds[name])
    return ds.assign(means) if means else ds

In [None]:
# attach the time-mean flux fields to the runs and the verification
run_dataset, verification_dataset = [
    add_time_means(dataset) for dataset in (run_dataset, verification_dataset)
]

In [None]:
# open grid: the C48 grid merged with the C48 land-sea mask
grid_c48 = xr.merge(
    [
        standardize_fv3_diagnostics(catalog[entry].to_dask())
        for entry in ("grid/c48", "landseamask/c48")
    ]
)

# keyword arguments for `viz.mappable_var`: the centre and outer (interface)
# horizontal dimension names, and the dims of the lon/lat coordinate variables
_CENTER_DIMS = ("y", "x", "tile")
_OUTER_DIMS = ("y_interface", "x_interface", "tile")
MAPPABLE_VAR_KWARGS = dict(
    coord_x_center="x",
    coord_y_center="y",
    coord_x_outer="x_interface",
    coord_y_outer="y_interface",
    coord_vars={
        "lonb": list(_OUTER_DIMS),
        "latb": list(_OUTER_DIMS),
        "lon": list(_CENTER_DIMS),
        "lat": list(_CENTER_DIMS),
    },
)

In [None]:
def prognostic_time_mean_maps(ds, var, title=None, plot_cube_kwargs=None, rms_mean=None, col_wrap=3, fig_size=(14, 4), dpi=96):
    """Plot ``ds`` with ``viz.plot_cube``, one panel per entry of ``ds.cases``.

    Args:
        ds: dataset passed to ``viz.plot_cube``. It is faceted into columns
            over its ``cases`` dimension (presumably already prepared with
            ``viz.mappable_var`` -- TODO confirm).
        var: name of the plotted variable. Not used by this function; kept
            for call-site compatibility.
        title: if given, used as the first title line of every panel instead
            of the case name.
        plot_cube_kwargs: extra keyword arguments for ``viz.plot_cube``.
        rms_mean: optional mapping of per-case statistics, indexed in the same
            order as ``ds.cases``. Recognised keys are ``'rmse'`` (shown as
            "RMS") and ``'bias'`` (shown as "mean").
        col_wrap: maximum number of panel columns.
        fig_size: figure size in inches, ``(width, height)``.
        dpi: figure resolution.

    Previously ``title`` was silently ignored when no statistics were given,
    and an ``rmse``-only mapping produced no title at all. Every panel now
    gets a title, plus a statistics line when any statistic is available.
    """
    plot_cube_kwargs = plot_cube_kwargs or {}
    rms_mean = rms_mean or {}
    fig, axes, _, _, facetgrid = viz.plot_cube(ds, col='cases', col_wrap=col_wrap, **plot_cube_kwargs)
    # clear the facet grid's default titles; each panel gets its own below
    facetgrid.set_titles(template='')
    panels = axes.flatten()
    for i, case in enumerate(ds.cases.values):
        stats = []
        if 'rmse' in rms_mean:
            stats.append(f"RMS: {rms_mean['rmse'][i]:3.1f}")
        if 'bias' in rms_mean:
            stats.append(f"mean: {rms_mean['bias'][i]:3.1f}")
        lines = [str(title or case)]
        if stats:
            lines.append(", ".join(stats))
        panels[i].set_title("\n".join(lines))
    fig.set_size_inches(fig_size)
    fig.set_dpi(dpi)
    
def _var_rms(ds, area):
    """Area-weighted root-mean-square of ``ds`` over ``x``, ``y`` and ``tile``.

    Cells where ``ds`` is missing are excluded from both the weighted sum and
    the total weight. The previous form normalised the weights by the mean
    area over *all* cells, which gives a wrong mean when ``ds`` has gaps. For
    gap-free fields the result is unchanged (up to rounding).
    """
    dims = ['x', 'y', 'tile']
    valid_area = area.where(ds.notnull())
    return np.sqrt(((ds**2) * area).sum(dim=dims) / valid_area.sum(dim=dims))

def _var_mean(ds, area):
    """Area-weighted mean of ``ds`` over ``x``, ``y`` and ``tile``.

    Cells where ``ds`` is missing are excluded from both the weighted sum and
    the total weight. The previous form normalised the weights by the mean
    area over *all* cells, which gives a wrong mean when ``ds`` has gaps. For
    gap-free fields the result is unchanged (up to rounding).
    """
    dims = ['x', 'y', 'tile']
    valid_area = area.where(ds.notnull())
    return (ds * area).sum(dim=dims) / valid_area.sum(dim=dims)

def _var_land_mean(ds, area, mask):
    """Area-weighted mean of ``ds`` over cells where ``mask == 1.0``.

    Reduces over ``x``, ``y`` and ``tile``. ``mask == 1.0`` is treated as land,
    per this function's name (NOTE(review): confirm the mask encoding). Land
    cells where ``ds`` is missing no longer contribute their area to the
    denominator, which previously pulled the mean toward zero.
    """
    dims = ['x', 'y', 'tile']
    land_area = area.where(mask == 1.0)
    valid_land_area = land_area.where(ds.notnull())
    return (ds * land_area).sum(dim=dims) / valid_land_area.sum(dim=dims)

In [None]:
# time-mean sensible heat flux bias (run minus verification, i.e. coarse - fine)
# for the two nudged cases; each panel reports the land-only area-weighted mean
var = 'SHTFLsfc_time_mean'
# difference only the plotted variable instead of every variable shared by the
# two datasets, so nothing else enters the merge with the grid
bias_da = (run_dataset[var] - verification_dataset[var]).sel(
    cases=['nudging-only', 'nudging-prescribed-rad-precip']
)
diff_ds = viz.mappable_var(xr.merge([bias_da, grid_c48]), var, **MAPPABLE_VAR_KWARGS)
with ProgressBar():
    diff_ds = diff_ds.load()
    # only the land mean is shown on the panels, so it is the only statistic computed
    land_mean_ = _var_land_mean(diff_ds[var], grid_c48['area'], grid_c48['land_sea_mask']).load()

diff_ds[var] = diff_ds[var].assign_attrs({
    'long_name': 'sensible heat flux bias, coarse - fine',
    'units': run_dataset[var].attrs["units"]
})
prognostic_time_mean_maps(
    diff_ds,
    var,
    rms_mean=dict(bias=land_mean_.values),
    plot_cube_kwargs=dict(vmin=-50, vmax=50),
    fig_size=[12, 3.5],
    col_wrap=2,
    dpi=150
)
plt.savefig(f"{output_dir}/sensible_heat_flux_bias_time_mean.png", facecolor='white', bbox_inches='tight')
print(land_mean_)

In [None]:
# time-mean latent heat flux bias (run minus verification, i.e. coarse - fine)
# for the two nudged cases; each panel reports the land-only area-weighted mean
var = 'LHTFLsfc_time_mean'
# difference only the plotted variable instead of every variable shared by the
# two datasets, so nothing else enters the merge with the grid
bias_da = (run_dataset[var] - verification_dataset[var]).sel(
    cases=['nudging-only', 'nudging-prescribed-rad-precip']
)
diff_ds = viz.mappable_var(xr.merge([bias_da, grid_c48]), var, **MAPPABLE_VAR_KWARGS)
with ProgressBar():
    diff_ds = diff_ds.load()
    # only the land mean is shown on the panels, so it is the only statistic computed
    land_mean_ = _var_land_mean(diff_ds[var], grid_c48['area'], grid_c48['land_sea_mask']).load()

diff_ds[var] = diff_ds[var].assign_attrs({
    'long_name': 'latent heat flux bias, coarse - fine',
    'units': run_dataset[var].attrs["units"]
})
prognostic_time_mean_maps(
    diff_ds,
    var,
    rms_mean=dict(bias=land_mean_.values),
    plot_cube_kwargs=dict(vmin=-50, vmax=50),
    fig_size=[12, 3.5],
    col_wrap=2,
    dpi=150
)
plt.savefig(f"{output_dir}/latent_heat_flux_bias_time_mean.png", facecolor='white', bbox_inches='tight')
print(land_mean_)

In [None]:
# instantaneous surface downwelling shortwave bias (verification minus run, i.e.
# fine - coarse) for the nudging-only case at one hourly snapshot
var = 'DSWRFsfc'
i_time = 202  # index into the common hourly time axis
# difference only the plotted variable at the chosen snapshot
snapshot = (verification_dataset[var] - run_dataset[var]).isel(time=i_time).sel(cases=['nudging-only'])
diff_ds = viz.mappable_var(xr.merge([snapshot, grid_c48]), var, **MAPPABLE_VAR_KWARGS)

diff_ds[var] = diff_ds[var].assign_attrs({
    'long_name': 'surface downwelling shortwave bias, fine - coarse',
    'units': 'W/m^2'
})

with ProgressBar():
    diff_ds = diff_ds.load()
    mean_ = _var_mean(diff_ds[var], grid_c48['area']).load()
    land_mean_ = _var_land_mean(diff_ds[var], grid_c48['area'], grid_c48['land_sea_mask']).load()

prognostic_time_mean_maps(
    diff_ds,
    var,
    rms_mean=dict(bias=mean_.values),
    col_wrap=1,
    fig_size=[7, 4],
    dpi=150,
    title=run_dataset.time.isel(time=i_time).item().strftime('%Y%m%d.%H%M%S')
)
plt.savefig(f"{output_dir}/shortwave_bias_instantaneous.png", facecolor='white', bbox_inches='tight')
print(land_mean_.item())

In [None]:
# instantaneous surface downwelling longwave bias (verification minus run, i.e.
# fine - coarse) for the nudging-only case at one hourly snapshot
var = 'DLWRFsfc'
i_time = 202  # index into the common hourly time axis
# difference only the plotted variable at the chosen snapshot
snapshot = (verification_dataset[var] - run_dataset[var]).isel(time=i_time).sel(cases=['nudging-only'])
diff_ds = viz.mappable_var(xr.merge([snapshot, grid_c48]), var, **MAPPABLE_VAR_KWARGS)

diff_ds[var] = diff_ds[var].assign_attrs({
    'long_name': 'surface downwelling longwave bias, fine - coarse',
    'units': 'W/m^2'
})

with ProgressBar():
    diff_ds = diff_ds.load()
    mean_ = _var_mean(diff_ds[var], grid_c48['area']).load()
    land_mean_ = _var_land_mean(diff_ds[var], grid_c48['area'], grid_c48['land_sea_mask']).load()

prognostic_time_mean_maps(
    diff_ds,
    var,
    rms_mean=dict(bias=mean_.values),
    col_wrap=1,
    fig_size=[7, 4],
    dpi=150,
    title=run_dataset.time.isel(time=i_time).item().strftime('%Y%m%d.%H%M%S')
)
plt.savefig(f"{output_dir}/longwave_bias_instantaneous.png", facecolor='white', bbox_inches='tight')
print(land_mean_.item())

In [None]:
# time-mean surface downwelling total radiation bias (verification minus run,
# i.e. fine - coarse) for the nudging-only case
var = 'DRFsfc_time_mean'
# difference only the plotted variable instead of every shared variable
bias_da = (verification_dataset[var] - run_dataset[var]).sel(cases=['nudging-only'])
diff_ds = viz.mappable_var(
    xr.merge([grid_c48, bias_da], compat='override'),
    var,
    **MAPPABLE_VAR_KWARGS
)
diff_ds[var] = diff_ds[var].assign_attrs({
    'long_name': 'surface downwelling total radiation bias, fine - coarse',
    'units': run_dataset[var].attrs["units"]
})
with ProgressBar():
    diff_ds = diff_ds.load()
    mean_ = _var_mean(diff_ds[var], grid_c48['area']).load()
    land_mean_ = _var_land_mean(diff_ds[var], grid_c48['area'], grid_c48['land_sea_mask']).load()

# Title the panel with the averaging window. It previously showed a single
# snapshot timestamp (via `i_time` from another cell), which mislabels a time mean.
stamp_fmt = '%Y%m%d.%H%M%S'
mean_window = (
    f"time mean {run_dataset.time[0].item().strftime(stamp_fmt)}"
    f" - {run_dataset.time[-1].item().strftime(stamp_fmt)}"
)
prognostic_time_mean_maps(
    diff_ds,
    var,
    rms_mean=dict(bias=mean_.values),
    fig_size=[7, 4],
    col_wrap=1,
    dpi=150,
    title=mean_window
)
plt.savefig(f"{output_dir}/total_downwelling_bias_time_mean.png", facecolor='white', bbox_inches='tight')
print(land_mean_.item())

In [None]:
# time-mean surface downwelling longwave bias (verification minus run, i.e.
# fine - coarse) for the nudging-only case
var = 'DLWRFsfc_time_mean'
# difference only the plotted variable instead of every shared variable
bias_da = (verification_dataset[var] - run_dataset[var]).sel(cases=['nudging-only'])
diff_ds = viz.mappable_var(
    xr.merge([grid_c48, bias_da], compat='override'),
    var,
    **MAPPABLE_VAR_KWARGS
)
diff_ds[var] = diff_ds[var].assign_attrs({
    'long_name': 'surface downwelling longwave bias, fine - coarse',
    'units': run_dataset[var].attrs["units"]
})
with ProgressBar():
    diff_ds = diff_ds.load()
    mean_ = _var_mean(diff_ds[var], grid_c48['area']).load()
    land_mean_ = _var_land_mean(diff_ds[var], grid_c48['area'], grid_c48['land_sea_mask']).load()

# Title the panel with the averaging window. It previously showed a single
# snapshot timestamp (via `i_time` from another cell), which mislabels a time mean.
stamp_fmt = '%Y%m%d.%H%M%S'
mean_window = (
    f"time mean {run_dataset.time[0].item().strftime(stamp_fmt)}"
    f" - {run_dataset.time[-1].item().strftime(stamp_fmt)}"
)
prognostic_time_mean_maps(
    diff_ds,
    var,
    rms_mean=dict(bias=mean_.values),
    fig_size=[7, 4],
    col_wrap=1,
    dpi=150,
    title=mean_window
)
plt.savefig(f"{output_dir}/longwave_bias_time_mean.png", facecolor='white', bbox_inches='tight')
print(land_mean_.item())

In [None]:
# time-mean surface downwelling shortwave bias (verification minus run, i.e.
# fine - coarse) for the nudging-only case
var = 'DSWRFsfc_time_mean'
# difference only the plotted variable instead of every shared variable
bias_da = (verification_dataset[var] - run_dataset[var]).sel(cases=['nudging-only'])
diff_ds = viz.mappable_var(
    xr.merge([grid_c48, bias_da], compat='override'),
    var,
    **MAPPABLE_VAR_KWARGS
)
diff_ds[var] = diff_ds[var].assign_attrs({
    'long_name': 'surface downwelling shortwave bias, fine - coarse',
    'units': run_dataset[var].attrs["units"]
})
with ProgressBar():
    diff_ds = diff_ds.load()
    mean_ = _var_mean(diff_ds[var], grid_c48['area']).load()
    land_mean_ = _var_land_mean(diff_ds[var], grid_c48['area'], grid_c48['land_sea_mask']).load()

# Title the panel with the averaging window. It previously showed a single
# snapshot timestamp (via `i_time` from another cell), which mislabels a time mean.
stamp_fmt = '%Y%m%d.%H%M%S'
mean_window = (
    f"time mean {run_dataset.time[0].item().strftime(stamp_fmt)}"
    f" - {run_dataset.time[-1].item().strftime(stamp_fmt)}"
)
prognostic_time_mean_maps(
    diff_ds,
    var,
    rms_mean=dict(bias=mean_.values),
    fig_size=[7, 4],
    col_wrap=1,
    dpi=150,
    title=mean_window
)
plt.savefig(f"{output_dir}/shortwave_bias_time_mean.png", facecolor='white', bbox_inches='tight')
print(land_mean_.item())