## Plot maps of the `TMP2m` (2-meter air temperature) trend for SHiELD and ERA5 and their corresponding ACE2 models

In [None]:
from collections import namedtuple

import xarray as xr
import numpy as np
from matplotlib import pyplot as plt
from cartopy import crs as ccrs

from utils import get_beaker_dataset_variables
from constants import WANDB_ID_FILE

import yaml
# Render figures at high resolution by default.
plt.rcParams.update({'figure.dpi': 300})
# Keep variable attributes (e.g. long_name, units) through xarray operations.
xr.set_options(keep_attrs=True)

In [None]:
# Lookup table from run name (e.g. ACE2_SHIELD_RUN below) to the ID used to
# fetch that run's outputs. Decode explicitly as UTF-8 (the YAML default) so
# the result does not depend on the platform's locale encoding.
with open(WANDB_ID_FILE, "r", encoding="utf-8") as f:
    wandb_ids = yaml.safe_load(f)


In [None]:
# Metadata describing each variable evaluated in this notebook.
EvalVar = namedtuple("EvalVar", "name long_name units")

EVAL_VARS = [
    EvalVar(name="TMP2m", long_name="2-meter\nair temperature", units="K"),
]
# Just the variable names, as requested from the stored datasets.
DS_VARS = [var_name for var_name, _long_name, _units in EVAL_VARS]

In [None]:
# Run names (keys into wandb_ids) of the two ACE2 runs compared in this notebook.
ACE2_SHIELD_RUN = "shield-amip-1deg-ace2-inference-81yr-IC0"
ACE2_ERA5_RUN = "era5-co2-81yr-RS2-IC0-monthly-output"
# Label -> run ID; the labels are used as the ``dataset`` coordinate later.
WANDB_RUNS = {
    label: wandb_ids[run_name]
    for label, run_name in (
        ("ACE2_SHIELD", ACE2_SHIELD_RUN),
        ("ACE2_ERA5", ACE2_ERA5_RUN),
    )
}

In [None]:
def compute_slope_intercept(prediction, years):
    """Least-squares straight-line fit of ``prediction`` against ``years``.

    Returns the ``np.polyfit`` coefficient array, highest degree first,
    i.e. ``[slope, intercept]`` for 1-D inputs.
    """
    return np.polyfit(x=years, y=prediction, deg=1)

def _describe_trend(trend, source, start_year, end_year):
    """Set ``long_name``/``units`` on ``trend`` in place, derived from ``source``."""
    # Fall back gracefully when the source variable lacks the usual attributes.
    long_name = source.attrs.get("long_name", source.name)
    units = source.attrs.get("units")
    trend.attrs["long_name"] = f"{start_year}-{end_year} trend of {long_name}"
    trend.attrs["units"] = f"{units} / decade" if units is not None else "1 / decade"


def compute_decadal_trend(prediction, years, **ufunc_kwargs):
    """Fit a straight line along ``year`` at every point; return it per decade.

    Args:
        prediction: xarray Dataset (or DataArray) with a ``year`` dimension.
        years: values of the independent variable along ``year`` (e.g. the
            ``year`` coordinate of ``prediction``).
        **ufunc_kwargs: extra keyword arguments forwarded to
            ``xr.apply_ufunc``. ``vectorize`` defaults to True and may be
            overridden here.

    Returns:
        ``(decadal_trend, intercept)``: ten times the fitted slope, with
        ``long_name``/``units`` attributes describing the fitted year range,
        and the fitted intercept (the line's value at ``years == 0``).
    """
    # Merge rather than pass vectorize=True alongside **ufunc_kwargs, which
    # raised TypeError when a caller supplied its own ``vectorize``.
    apply_kwargs = {"vectorize": True, **ufunc_kwargs}
    result = xr.apply_ufunc(
        compute_slope_intercept,
        prediction,
        years,
        input_core_dims=[("year",), ("year",)],
        output_core_dims=[("degree",)],
        **apply_kwargs,
    )
    # np.polyfit orders coefficients from the highest degree: [slope, intercept].
    decadal_trend = 10 * result.isel(degree=0)
    intercept = result.isel(degree=1)

    year_values = np.asarray(years)
    # min/max rather than first/last so the label is right even if unsorted.
    start_year, end_year = year_values.min(), year_values.max()
    if isinstance(decadal_trend, xr.DataArray):
        _describe_trend(decadal_trend, prediction, start_year, end_year)
    else:
        for name in decadal_trend.data_vars:
            _describe_trend(decadal_trend[name], prediction[name], start_year, end_year)
    return decadal_trend, intercept

In [None]:
# download annual- and global-mean time series from beaker
datasets = []
for name, run in WANDB_RUNS.items():
    tmp = get_beaker_dataset_variables(run, 'monthly_mean_predictions.nc', DS_VARS)
    prediction_dataset = tmp.squeeze().isel(time=slice(None, -3)).groupby('valid_time.year').mean()
    tmp = get_beaker_dataset_variables(run, 'monthly_mean_target.nc', DS_VARS)
    target_dataset = tmp.squeeze().isel(time=slice(None, -3)).groupby('valid_time.year').mean()
    dims = {"dataset": [name]}
    datasets.append(prediction_dataset.expand_dims(dims | {"source": ["prediction"]}))
    datasets.append(target_dataset.expand_dims(dims | {"source": ["target"]}))
annual_ds = xr.merge(datasets)

# fixing issue where units are in long name
annual_ds['TMP2m'].attrs['long_name'] = "2m air temperature"


In [None]:
# for computing global means
weights = np.cos(np.deg2rad(annual_ds.lat))

# compute spatial maps of decadal trends in each dataset
trend_start_year = 1940
trend_end_year = 2020
tmp = annual_ds.sel(year=slice(trend_start_year, trend_end_year))
prediction_trends, _ = compute_decadal_trend(tmp, tmp.year)

# compute global- and annual-mean series
global_annual_ds = annual_ds.weighted(weights).mean(dim=['lat', 'lon'])

In [None]:
global_annual_ds["TMP2m"].plot(col="dataset", hue="source")

In [None]:
import os

# Faceted Robinson maps of the TMP2m decadal trend: one row per source
# (prediction/target), one column per dataset.
plotme = prediction_trends.TMP2m
fg = plotme.plot(
    col="dataset",
    row="source",
    vmin=-0.5,
    vmax=0.5,
    cmap="RdBu_r",
    transform=ccrs.PlateCarree(),
    subplot_kws=dict(projection=ccrs.Robinson(central_longitude=180)),
)
fg.set_titles(template="")
# Titles keyed by each facet's (source, dataset) coordinate values, so a title
# always follows its data regardless of the order the facets are laid out in.
titles = {
    ("prediction", "ACE2_ERA5"): "ACE2-ERA5",
    ("prediction", "ACE2_SHIELD"): "ACE2-SHiELD",
    ("target", "ACE2_ERA5"): "ERA5",
    ("target", "ACE2_SHIELD"): "SHiELD reference",
}
for ax, facet in zip(fg.axs.flat, fg.name_dicts.flat):
    if facet is None:  # unused grid position
        continue
    global_mean = plotme.sel(facet).weighted(weights).mean().item()
    # Raw string: "\," is a LaTeX thin space, not a Python escape sequence.
    global_mean_str = rf"{global_mean:0.2f} K$\,$/$\,$decade"
    title = titles[facet["source"], facet["dataset"]]
    ax.set_title(title + "\n" + global_mean_str, fontsize=9)
    ax.coastlines(linewidth=0.5, color="grey")

fig = fg.fig
fig.set_size_inches(6, 3.7)
# savefig does not create missing directories.
os.makedirs("figures", exist_ok=True)
fig.savefig("figures/climate_skill_1deg_shield_and_era5_2m_trends.png", dpi=300, bbox_inches="tight", transparent=True)