### Nudging tendency transects

Make transects of nudging and physics tendencies from a coarse-resolution nudged run and compare them to the fine-resolution apparent sources.

In [None]:
%run predictive_mapper.py
import os
import intake
import xarray as xr
import numpy as np
import fsspec
import yaml
from datetime import timedelta
from dask.diagnostics import ProgressBar
from matplotlib import rc, pyplot as plt
import matplotlib.animation as animation
rc('animation', html='html5')
import matplotlib as mpl
from IPython.display import HTML
import fv3fit
from fv3fit._shared import EnsembleModel
import loaders
from vcm.catalog import catalog as CATALOG
from vcm.safe import get_variables
from vcm.calc import thermo
from vcm.fv3.metadata import standardize_fv3_diagnostics
from vcm import encode_time, interpolate_unstructured, convert_timestamps
from loaders.mappers import GeoMapper
from typing import Mapping, Sequence, Union, Tuple
import string
import warnings
warnings.filterwarnings('ignore')

In [None]:
# Use small fonts (base size 8, tick/axis labels 'small') for all figures drawn below.
mpl.rcParams.update(**{
    'font.size': 8,
    'xtick.labelsize': 'small',
    'ytick.labelsize': 'small',
    'axes.labelsize': 'small',
})

In [None]:
# Root of the nudge-to-fine (C3072 target) run whose zarr stores are loaded below.
N2F_TRAIN_TEST_DATA_URL = 'gs://vcm-ml-experiments/2021-04-13-n2f-c3072/3-hrly-ave-rad-precip-setting-30-min-rad-timestep-shifted-start-tke-edmf-3-hrly-ave-physics-tendencies'
# Trained models to compare, keyed by their (mathtext) panel label. Keys containing
# 'RF' are loaded directly with fv3fit.load; the others are loaded as ensembles.
N2F_MODEL_URL = {
    '$TqR$-RF':'gs://vcm-ml-experiments/2021-06-21-nudge-to-c3072-dq1-dq2-only/rf/trained_models/postphysics_ML_tendencies',
    '$TqR$-NN':'gs://vcm-ml-experiments/2021-05-11-nudge-to-c3072-corrected-winds/nn-ensemble-model/trained_models/dq1-dq2',
} 
# Variable-name templates kept in the transect datasets; '{var}' is filled in with
# the plotted variable (e.g. 'specific_humidity').
PLOT_VARS=[
    'pressure',
    '{var}_tendency_due_to_nudging',
    '{var}',
    '{var}_tendency_due_to_ML'
]
# State variables read from state_after_timestep.zarr. The layer thicknesses and
# surface geopotential are used to derive pressure and height.
STATE_VARS = [
    'pressure_thickness_of_atmospheric_layer',
    'vertical_thickness_of_atmospheric_layer',
    'surface_geopotential',
    'specific_humidity',
    'air_temperature',
]
# Colour limits for Q1 plots (not referenced in the cells shown here).
Q1_SCALE = dict(vmin=-5, vmax=5)
# 3 hours in seconds. tendency_units multiplies tendencies by this, presumably
# turning per-second rates into per-3-hour amounts — confirm the input units.
TIMESTEP_SECONDS = 10800
# Not referenced in the cells shown here.
SECONDS_PER_DAY = 86400
# Grams per kilogram, used to express specific-humidity tendencies in g/kg.
G_PER_KG = 1000
# Time key (encode_time format, YYYYMMDD.HHMMSS) of the snapshot that is plotted.
INSTANTANEOUS_TIMESTAMP = '20160904.143000'
# Directory that figures are saved into.
OUTDIR = 'figures'

In [None]:
def tendency_units(ds, timestep_seconds=None, g_per_kg=None):
    """Return a copy of ``ds`` with every data variable scaled to per-timestep units.

    Every data variable is multiplied by ``timestep_seconds``. Variables whose name
    contains "specific_humidity" are additionally multiplied by ``g_per_kg``
    (kg/kg -> g/kg). The input dataset is not modified.

    Args:
        ds: dataset of tendencies (presumably per-second rates — confirm upstream).
        timestep_seconds: multiplier applied to every variable; defaults to the
            module-level ``TIMESTEP_SECONDS``.
        g_per_kg: extra multiplier for specific-humidity variables; defaults to
            the module-level ``G_PER_KG``.

    Returns:
        A shallow copy of ``ds`` containing the scaled variables.
    """
    if timestep_seconds is None:
        timestep_seconds = TIMESTEP_SECONDS
    if g_per_kg is None:
        g_per_kg = G_PER_KG
    # Work on a shallow copy: callers pass datasets they may reuse (e.g. ones
    # returned by a mapper), and scaling them in place would compound the
    # conversion if the same object were converted twice.
    converted = ds.copy()
    for var in ds.data_vars:
        scaled = ds[var] * timestep_seconds
        if "specific_humidity" in var:
            scaled = scaled * g_per_kg
        converted[var] = scaled
    return converted

In [None]:
def time_average(ds, freq='3H'):
    """Average ``ds`` over ``freq`` time bins.

    The bins are shifted with ``base=1`` and their labels are moved 90 minutes later.
    For the default '3H' this presumably puts the bin edges at 01:00, 04:00, ... and
    labels each bin at its midpoint (02:30, 05:30, ...) — confirm.

    NOTE(review): ``base``/``loffset`` are deprecated in newer pandas/xarray
    releases (``offset=`` plus shifting the time coordinate replaces them).
    """
    label_shift = timedelta(minutes=90)
    binned = ds.resample({'time': freq}, base=1, loffset=label_shift)
    return binned.mean()

In [None]:
# Load the nudge-to-fine run:
# - nudging and physics tendencies, converted by tendency_units;
# - state variables, restricted to STATE_VARS and time-averaged.
# join='inner' keeps only the coordinate values shared by all three datasets.


def _open_run_zarr(store_name):
    """Open the consolidated zarr store ``store_name`` under N2F_TRAIN_TEST_DATA_URL as a dask dataset."""
    store_url = os.path.join(N2F_TRAIN_TEST_DATA_URL, store_name)
    return intake.open_zarr(store_url, consolidated=True).to_dask()


nudging_tendencies = tendency_units(
    standardize_fv3_diagnostics(_open_run_zarr('nudging_tendencies.zarr'))
)
physics_tendencies = tendency_units(
    standardize_fv3_diagnostics(_open_run_zarr('physics_tendencies.zarr'))
)
states = time_average(
    standardize_fv3_diagnostics(
        get_variables(_open_run_zarr('state_after_timestep.zarr'), STATE_VARS)
    )
)
run_dataset = xr.merge([nudging_tendencies, physics_tendencies, states], join='inner')

In [None]:
def add_heights_and_pressures(ds: xr.Dataset) -> xr.Dataset:
    """Return ``ds`` with pressure and height at layer centers and interfaces added.

    New variables, each carrying ``long_name`` and ``units`` attributes:
        pressure           -- layer-center pressure (thermo result * 0.01, labelled hPa)
        pressure_interface -- layer-interface pressure (thermo result * 0.01, labelled hPa)
        height             -- layer-center height (labelled m)
        height_interface   -- layer-interface height (labelled m)

    The 0.01 factor presumably converts Pa to hPa — confirm against vcm.calc.thermo.

    Requires the variables ``pressure_thickness_of_atmospheric_layer``,
    ``vertical_thickness_of_atmospheric_layer`` and ``surface_geopotential``.
    Layer centers use dimension ``z``; interface quantities are built with
    ``dim_outer='zb'``.
    """
    # Scale *before* attaching attributes. xarray arithmetic drops attrs by default,
    # so ``0.01 * da.assign_attrs(...)`` would lose long_name/units on the result.
    pressure = 0.01 * thermo.pressure_at_midpoint_log(
        ds['pressure_thickness_of_atmospheric_layer'],
        dim='z'
    )
    pressure_interface = 0.01 * thermo.pressure_at_interface(
        ds['pressure_thickness_of_atmospheric_layer'],
        dim_center='z',
        dim_outer='zb'
    )
    height = thermo.height_at_midpoint(
        ds['vertical_thickness_of_atmospheric_layer'],
        ds['surface_geopotential'],
        dim='z'
    )
    height_interface = thermo.height_at_interface(
        ds['vertical_thickness_of_atmospheric_layer'],
        ds['surface_geopotential'],
        dim_center='z',
        dim_outer='zb'
    )

    # All four are computed from the input variables only, so one assign suffices.
    return ds.assign({
        'pressure': pressure.assign_attrs({
            'long_name': "pressure at layer center",
            'units': 'hPa'
        }),
        'pressure_interface': pressure_interface.assign_attrs({
            'long_name': "pressure at layer interface",
            'units': 'hPa'
        }),
        'height': height.assign_attrs({
            'long_name': "height at layer center",
            'units': 'm'
        }),
        'height_interface': height_interface.assign_attrs({
            'long_name': "height at layer interface",
            'units': 'm'
        }),
    })

In [None]:
# Add pressure and height at layer centers and interfaces, derived from the layer thicknesses.
run_dataset = add_heights_and_pressures(run_dataset)

In [None]:
# open grid and merge, overriding grid vars in fortran diagnostics because repeated values in time mess with the kdtree
# (compat='override' takes each shared variable from the first dataset listed,
# i.e. from grid_c48; the kdtree is presumably the one built by interpolate_unstructured — confirm)
grid_c48 = standardize_fv3_diagnostics(CATALOG["grid/c48"].to_dask())
run_dataset = xr.merge([grid_c48, run_dataset], compat='override')

In [None]:
# load the models for making ML predictions


def _load_ensemble_model(url):
    """Build an EnsembleModel by reading ``ensemble_model.yaml`` under ``url`` directly.

    This bypasses fv3fit's own ensemble loading, which the author notes has a bug.
    Each path listed under ``models`` is loaded with ``fv3fit.load``, and the members
    are combined with the configured ``reduction``.
    """
    with fsspec.open(os.path.join(url, 'ensemble_model.yaml'), "r") as config_file:
        ensemble_config = yaml.safe_load(config_file)
    members = [fv3fit.load(member_path) for member_path in ensemble_config["models"]]
    return EnsembleModel(members, ensemble_config["reduction"])


models = {}
for ml_type, url in N2F_MODEL_URL.items():
    # keys containing 'RF' are single models; everything else is an ensemble
    models[ml_type] = fv3fit.load(url) if 'RF' in ml_type else _load_ensemble_model(url)

In [None]:
# Variables passed to the nudge-to-fine loader (loaders.mappers.open_nudge_to_fine)
# as its ``nudging_variables`` keyword argument.
nudging_variables = [
    "air_temperature",
    "specific_humidity",
    "x_wind",
    "y_wind",
    "pressure_thickness_of_atmospheric_layer"
]

# NOTE(review): PredictiveMapper is not defined in the cells shown. It presumably
# comes from `%run predictive_mapper.py` at the top — confirm. It is given the
# models, the run data URL, the loader function and its kwargs, and the C48 grid.
predictive_mapper = PredictiveMapper(
    models,
    N2F_TRAIN_TEST_DATA_URL,
    loaders.mappers.open_nudge_to_fine,
    {'nudging_variables': nudging_variables},
    grid_c48
)

In [None]:
class TendencyCrossSectionMapper(GeoMapper):
    """Time-keyed mapping of nudged-run data plus ML-predicted tendencies on a transect.

    Keys are the ``encode_time`` strings of ``dataset.time``. Indexing with one key
    returns that time. Indexing with a slice of keys returns the matching range of
    times, taken in sorted key order (stop exclusive, step honoured). Each returned
    dataset keeps only ``primary_var`` and ``other_vars``. It is interpolated onto
    ``transect`` with ``interpolate_unstructured`` and loaded into memory.
    """

    def __init__(
        self,
        dataset: xr.Dataset,
        preditive_mapping: Mapping[str, xr.Dataset],
        transect: Mapping[str, xr.DataArray],
        primary_var: str,
        other_vars: Sequence[str],
        title: str,
        units: str='kg/kg/s',
        scale: float=25
    ):
        """Store the inputs and build the key -> time lookup.

        Args:
            dataset: nudged-run dataset with a ``time`` dimension.
            preditive_mapping: time key -> dataset of predicted tendencies, converted
                with ``tendency_units`` on access. The parameter keeps its original
                spelling for keyword callers.
            transect: sample-point coordinates passed to ``interpolate_unstructured``
                (e.g. the 'lat'/'lon' dict from ``meridional_transect``).
            primary_var: main variable; always kept in returned datasets.
            other_vars: additional variable names kept in returned datasets.
            title: stored as ``self.title`` for the plotting code.
            units: stored as ``self.units`` (colorbar units label).
            scale: stored as ``self.scale``; not used within this class.
        """
        self._dataset = dataset
        self._predictive = preditive_mapping
        self.primary_var = primary_var
        self._other_vars = other_vars
        self._transect = transect
        self.title = title
        self.units = units
        self.scale = scale

        times = self._dataset.time.values.tolist()
        time_strings = [encode_time(time) for time in times]
        self._time_lookup = dict(zip(time_strings, times))

    def __getitem__(self, key: Union[str, slice]) -> xr.Dataset:
        """Return transect data for one time key, or for a slice of time keys.

        Raises:
            TypeError: if ``key`` is neither a ``str`` nor a ``slice``.
        """
        if isinstance(key, str):
            ds = self._get_timestep(key)
        elif isinstance(key, slice):
            ds = self._get_timeslice(key)
        else:
            raise TypeError(
                f"{type(self).__name__} keys must be str or slice, got {type(key).__name__}"
            )
        # unpack so that any Sequence (list, tuple, ...) of extra names works
        ds = get_variables(ds, [self.primary_var, *self._other_vars])
        with ProgressBar():
            print(f'Loading data for {key}')
            return ds.load()

    def _get_timestep(self, key: str) -> xr.Dataset:
        """Merge nudged-run data and predicted tendencies at ``key``, then interpolate onto the transect."""
        predicted = tendency_units(self._predictive[key])
        nudged = (
            self._dataset
            .sel(time=self._time_lookup[key])
            .drop_vars('time').expand_dims({'time': [key]})
        )
        # compat='override': for variables in both datasets, keep the nudged-run copy
        ds = xr.merge([nudged, predicted], compat='override')
        return interpolate_unstructured(ds, self._transect).expand_dims({'time': ds.time})

    def _get_timeslice(self, key_slice: slice) -> xr.Dataset:
        """Merge nudged-run data and predicted tendencies for a slice of keys, then interpolate onto the transect."""
        timesteps = sorted(self._time_lookup)
        start = timesteps.index(key_slice.start) if key_slice.start is not None else None
        end = timesteps.index(key_slice.stop) if key_slice.stop is not None else None
        timestep_subset = timesteps[start:end:key_slice.step]
        predicted = tendency_units(xr.concat(
            [self._predictive[timestep] for timestep in timestep_subset],
            dim='time'
        ))
        # Select by label rather than position: the positions above refer to the
        # *sorted* keys, which match dataset order only if dataset.time is sorted.
        nudged = (
            self._dataset
            .sel(time=[self._time_lookup[timestep] for timestep in timestep_subset])
            .assign_coords({'time': timestep_subset})
        )
        ds = xr.merge([nudged, predicted], compat='override')
        return interpolate_unstructured(ds, self._transect)

    def keys(self):
        """Return the available time keys (``encode_time`` strings)."""
        return self._time_lookup.keys()

In [None]:
def _transect_frame(time, ds, axes, var, pcolor_kwargs, title, yscale):

    LEVELS, UNITS_CONV, NUM, LETTER = (
        (range(0, 50, 5), G_PER_KG, '2', 'q')
        if var == 'specific_humidity' else (range(200, 300, 20), 1, '1', 'T')
    )
    ds = ds.sel(time=time)
    
    ax0 = axes[0]
    ax0.clear()
    ax0.pcolormesh(
        ds['transect'],
        ds['pressure'],
        ds[f'{var}_tendency_due_to_nudging'],
        shading='nearest',
        **pcolor_kwargs
    )
    ax0.set_xlabel('latitude')
    ax0.set_ylim([1e3, 2e2])
    ax0.set_ylabel('pressure [hPa]')
    ax0.set_facecolor('olive')
    ax0.set_title(f'a) nudging ($\Delta Q_{{{LETTER}}})$')
    derivations = [derivation for derivation in ds.derivation.values if derivation != 'target']
    for i, derivation in enumerate(derivations):
        axi = axes[1 + i]
        axi.clear()
        axi.pcolormesh(
            ds['transect'],
            ds['pressure'],
            ds[f'{var}_tendency_due_to_ML'].sel(derivation=derivation),
            shading='nearest',
            **pcolor_kwargs
        )
        axi.invert_yaxis()
        axi.set_ylim([1e3, 2e2])
        axi.tick_params(labelleft=False)
        axi.set_facecolor('olive')
        axi.set_title(f'{string.ascii_lowercase[1 + i]}) {derivation} ($\Delta Q_{{{LETTER}}}^{{ML}}$)')
        axi.set_xlabel('latitude')

In [None]:
def get_frame_axes(pcolor_kwargs, primary_var, units, time):
    """Create a figure with three side-by-side panels (shared x) and one colorbar.

    The colorbar is drawn below all three panels from a standalone ScalarMappable.
    It uses ``pcolor_kwargs['cmap']`` and a Normalize built from
    ``pcolor_kwargs['vmin']``/``['vmax']``, so it does not depend on any plotted
    artist. Its label names ``primary_var``, ``time`` and ``units``.

    Returns:
        ``(fig, axes)``, where ``axes`` is a list of the three Axes.
    """
    fig = plt.figure()
    panel_grid = fig.subplots(1, 3, sharex=True)
    axes = list(panel_grid.ravel())

    color_norm = mpl.colors.Normalize(vmin=pcolor_kwargs['vmin'], vmax=pcolor_kwargs['vmax'])
    scalar_mappable = plt.cm.ScalarMappable(cmap=pcolor_kwargs['cmap'], norm=color_norm)
    # give the bare mappable an (empty) data array so colorbar accepts it
    # (commonly required on older matplotlib versions)
    scalar_mappable.set_array([])

    colorbar_label = f"{primary_var.replace('_', ' ')} tendency at {time} [{units}]"
    fig.colorbar(
        scalar_mappable,
        ax=axes,
        label=colorbar_label,
        location='bottom',
        aspect=50,
    )
    return fig, axes

In [None]:
def instantaneous_transect(transect_mapper, time, fig_size=(10, 8), pcolor_kwargs=None):
    """Plot the nudging and ML tendency transects at one time and save them as EPS.

    The figure is saved to
    ``{OUTDIR}/Figure_4_{title with spaces -> underscores}_{time}.eps``.
    ``OUTDIR`` is created if it does not exist.

    Args:
        transect_mapper: mapper indexed by time key and exposing ``primary_var``,
            ``units``, ``title`` and ``scale`` (e.g. a TendencyCrossSectionMapper).
        time: time key to plot.
        fig_size: figure size in inches (any sequence of two numbers).
        pcolor_kwargs: optional ``pcolormesh`` keyword arguments. 'vmin', 'vmax'
            and 'cmap' default to -2, 2 and 'seismic'. The caller's dict is not
            modified.
    """
    # build a fresh dict so the caller's pcolor_kwargs is never mutated
    pcolor_kwargs = {'vmin': -2, 'vmax': 2, 'cmap': 'seismic', **(pcolor_kwargs or {})}

    fig, axes = get_frame_axes(pcolor_kwargs, transect_mapper.primary_var, transect_mapper.units, time)

    ds = transect_mapper[time]
    _transect_frame(time, ds, axes, transect_mapper.primary_var, pcolor_kwargs, transect_mapper.title, transect_mapper.scale)

    fig.set_size_inches(fig_size)
    # savefig does not create missing directories
    os.makedirs(OUTDIR, exist_ok=True)
    fig.savefig(f"{OUTDIR}/Figure_4_{transect_mapper.title.replace(' ', '_')}_{time}.eps", bbox_inches='tight', facecolor='white')

In [None]:
def meridional_transect(lon, lat_start, lat_stop, lat_res):
    """Build sample points along a line of constant longitude.

    Latitudes come from ``np.arange(lat_start, lat_stop, lat_res)``, so ``lat_stop``
    is excluded. Every point has longitude ``lon``.

    Returns:
        dict with 'lat' and 'lon' DataArrays along a ``transect`` dimension whose
        coordinate values are the latitudes.
    """
    latitudes = np.arange(lat_start, lat_stop, lat_res)
    longitudes = np.ones_like(latitudes) * lon
    return {
        name: xr.DataArray(values, dims=['transect'], coords={'transect': latitudes})
        for name, values in (('lat', latitudes), ('lon', longitudes))
    }

def zonal_transect(lat, lon_start, lon_stop, lon_res):
    """Build sample points along a line of constant latitude.

    Longitudes come from ``np.arange(lon_start, lon_stop, lon_res)``, so ``lon_stop``
    is excluded. Every point has latitude ``lat``.

    Returns:
        dict with 'lat' and 'lon' DataArrays along a ``transect`` dimension whose
        coordinate values are the longitudes.
    """
    longitudes = np.arange(lon_start, lon_stop, lon_res)
    latitudes = np.ones_like(longitudes) * lat
    return {
        name: xr.DataArray(values, dims=['transect'], coords={'transect': longitudes})
        for name, values in (('lat', latitudes), ('lon', longitudes))
    }

In [None]:
# Meridional transect at longitude 0, latitudes 0..41 in steps of 1.
sahara_transect = meridional_transect(0, 0, 42, 1)
var = 'specific_humidity'
# NOTE(review): VAR_MAPPING is not defined in the cells shown; it presumably comes
# from `%run predictive_mapper.py` — confirm. The PLOT_VARS templates only use
# {var}, so the `num` value does not affect the resulting names.
sahara_moisture_tendency_xs_mapper = TendencyCrossSectionMapper(
    run_dataset,
    predictive_mapper,
    sahara_transect,
    var,
    [var_template.format(var=var, num=VAR_MAPPING[var]) for var_template in PLOT_VARS],
    'Sahara moisture transect',
    units='g/kg/3-hr'
)

In [None]:
# Plot the Sahara moisture transect at INSTANTANEOUS_TIMESTAMP and save it as EPS under OUTDIR.
instantaneous_transect(sahara_moisture_tendency_xs_mapper, INSTANTANEOUS_TIMESTAMP, fig_size=[7.6, 3.74])

In [None]:
# this converts matplotlib eps files to a more manageable size:
# EPS -> PDF (epstopdf), PDF -> EPS (pdftops -eps), then delete the intermediate PDF.
# NOTE(review): the file name is hard-coded. It must match the name written by
# instantaneous_transect (OUTDIR, mapper title with spaces -> underscores,
# INSTANTANEOUS_TIMESTAMP).

!epstopdf figures/Figure_4_Sahara_moisture_transect_20160904.143000.eps
!pdftops -eps figures/Figure_4_Sahara_moisture_transect_20160904.143000.pdf
!rm figures/Figure_4_Sahara_moisture_transect_20160904.143000.pdf