{{ prolog }}

# Calibrating MESMER on multiple scenarios
This tutorial shows how to calibrate the parameters for MESMER on an example dataset of coarsely regridded ESM output for multiple climate change scenarios. We calibrate the parameters for MESMER using three scenarios: a historical, a low emission (SSP1-2.6), and a high emission (SSP5-8.5) scenario, where SSP5-8.5 includes several ensemble members. You can find the basics of the MESMER approach in Beusch et al. ([2020](https://doi.org/10.5194/ESD-11-139-2020)) and the multi-scenario approach in Beusch et al. ([2022](https://doi.org/10.5194/gmd-15-2085-2022)). Training MESMER consists of four steps:

- **global trend**: compute the global temperature trend, including the volcanic influence on historical trends
- **global variability**: estimate the parameters to generate global variability
- **local trend**: estimate parameters to translate global mean temperature (including global variability) into local temperature
- **local variability**: estimate parameters needed to generate local variability

This example can be extended to more scenarios, ensemble members and higher resolution data. See also the mesmer calibration test in *tests/integration/*.

In [None]:
import pathlib

import cartopy.crs as ccrs
import filefisher
import matplotlib.pyplot as plt
import xarray as xr

import mesmer

## Load data

MESMER expects a specific data format. The data from each scenario should be a node (or group) of an `xr.DataTree` (more on this below), e.g.:

```
<xarray.DataTree>
Group: /
├── Group: /historical
|    ...
├── Group: /ssp126
|    ...
```

Each scenario is an `xr.Dataset` with 4 dimensions: `member`, `time`, `lat`, `lon`. Below we show one way to load data such that it conforms to the desired format. We load data from the cmip6-ng ("new generation") repository. This data has undergone a small reformatting from the original CMIP6 archive. For the sake of computational speed, we also load data that has been regridded to a coarse resolution. Loading the data can be adapted to the data format you are most used to, as long as the final output has the desired format.

---

MESMER is Earth System Model specific, aiming to reproduce the behaviour of one ESM. Here we train on the CMIP6 output of the model IPSL-CM6A-LR.

In [None]:
model = "IPSL-CM6A-LR"

We use the library [*filefisher*](https://github.com/mpytools/filefisher) to search for all files in the cmip6-ng archive for the model and scenarios we want to use. Filefisher searches through paths for given file patterns and returns all paths matching the pattern, so that you can load the files in the next step.

Here, we want to find all files that have data for annual near-surface temperature (`"tas"`) for the selected model and the future scenarios ssp126 and ssp585. Next, we search for the historical data that match the members found for the two future scenarios.

In [None]:
# mesmer provides example data under "./data/cmip6-ng"
cmip_data_path = mesmer.example_data.cmip6_ng_path(relative=True)

CMIP_FILEFINDER = filefisher.FileFinder(
    path_pattern=cmip_data_path / "{variable}/{time_res}/{resolution}",
    file_pattern="{variable}_{time_res}_{model}_{scenario}_{member}_{resolution}.nc",
)
CMIP_FILEFINDER

Search the data for ssp126 and ssp585; we find one and two ensemble members, respectively:

In [None]:
scenarios = ["ssp126", "ssp585"]

keys = {"variable": "tas", "model": model, "resolution": "g025", "time_res": "ann"}

fc_scens = CMIP_FILEFINDER.find_files(scenario=scenarios, keys=keys)
fc_scens.df

We also need to find the same ensemble members in the historical data, such that we end up with five files we need to load:

In [None]:
# get the historical members that are also in the future scenarios, but only once
members = fc_scens.df.member.unique()

fc_hist = CMIP_FILEFINDER.find_files(scenario="historical", member=members, keys=keys)

fc_all = fc_hist.concat(fc_scens)
fc_all.df

Now we load all the files we found into a ``DataTree``, a data structure provided by [xarray](https://docs.xarray.dev/en/stable/index.html). It is a container to hold xarray `Dataset` objects that are not alignable. This is useful for us since we have historical and future data, which have different time coordinates. Moreover, the scenarios may also have different numbers of members (e.g., SSP1-2.6 has only one). Thus, we store the data of each scenario in a `Dataset` with all its ensemble members along a `member` dimension. Then we store each scenario dataset as its own node of one `DataTree`. The `DataTree` allows us to perform computations on each of the scenarios separately.

We define a helper function to load the data from the cmip6_ng example data repository:

In [None]:
def load_data(filecontainer):

    out = xr.DataTree()

    scenarios = filecontainer.df.scenario.unique().tolist()

    # load data for each scenario
    for scen in scenarios:
        files = filecontainer.search(scenario=scen)

        # load all members for a scenario
        members = []
        for fN, meta in files.items():
            time_coder = xr.coders.CFDatetimeCoder(use_cftime=True)
            ds = xr.open_dataset(fN, decode_times=time_coder)
            # drop unnecessary variables
            ds = ds.drop_vars(["height", "time_bnds", "file_qf"], errors="ignore")
            # assign member-ID as coordinate
            ds = ds.assign_coords({"member": meta["member"]})
            members.append(ds)

        # create a Dataset that holds each member along the member dimension
        scen_data = xr.concat(members, dim="member")
        # put the scenario dataset into the DataTree
        out[scen] = xr.DataTree(scen_data)

    return out

In [None]:
dt = load_data(fc_all)
dt

This results in the data format discussed above. You can examine it by clicking on `Groups` above. 

---

We will need some configuration parameters in the following:
1. ``THRESHOLD_LAND``: the land fraction above which a grid point is considered a land grid point.
2. ``REFERENCE_PERIOD``: we will work not with absolute temperature values but with temperature anomalies w.r.t. a reference period.

In [None]:
THRESHOLD_LAND = 1 / 3
REFERENCE_PERIOD = slice("1850", "1900")

## Calculate anomalies

In [None]:
# calculate anomalies w.r.t. the reference period
tas_anom = mesmer.anomaly.calc_anomaly(dt, reference_period=REFERENCE_PERIOD)
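
Conceptually, an anomaly is a value minus its mean over the reference period, taken separately for each member (and grid point). Because the SSP runs start in 2015 and contain no 1850-1900 data, their reference mean has to come from the historical run of the same member. The NumPy sketch below assumes exactly that and uses made-up temperatures; it illustrates the idea and is not mesmer's implementation:

```python
import numpy as np

years_hist = np.arange(1850, 2015)
years_ssp = np.arange(2015, 2101)

rng = np.random.default_rng(0)
# hypothetical global-mean temperatures (K) of one member
tas_hist = 287.0 + 0.005 * (years_hist - 1850) + rng.normal(0, 0.1, years_hist.size)
tas_ssp = tas_hist[-1] + 0.03 * (years_ssp - 2014)

# reference period 1850-1900 inclusive, analogous to slice("1850", "1900")
in_ref = (years_hist >= 1850) & (years_hist <= 1900)
ref_mean = tas_hist[in_ref].mean()

# subtract the same historical reference mean from both periods
anom_hist = tas_hist - ref_mean
anom_ssp = tas_ssp - ref_mean
```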

## Global mean

In [None]:
# calculate global mean
tas_globmean = mesmer.weighted.global_mean(tas_anom)
tas_globmean
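
The global mean is an area-weighted average. On a regular lat-lon grid, the cell area is proportional to the cosine of latitude, so a minimal sketch, assuming cos(lat) weights and a field without missing values (mesmer's `global_mean` may treat weights and NaNs differently), looks like this:

```python
import numpy as np

lat = np.arange(-88.75, 90, 2.5)  # hypothetical 2.5 degree cell centres
lon = np.arange(1.25, 360, 2.5)

# toy field: warm at the equator, cold at the poles, dims (lat, lon)
field = 30 * np.cos(np.deg2rad(lat))[:, np.newaxis] * np.ones((lat.size, lon.size))

# area weight of each latitude band, broadcast to the full grid
weights = np.cos(np.deg2rad(lat))
w2d = np.broadcast_to(weights[:, np.newaxis], field.shape)

global_mean = (field * w2d).sum() / w2d.sum()
unweighted_mean = field.mean()  # overweights the small polar cells
```

The unweighted mean is too cold here because it gives the many small high-latitude cells the same weight as the large tropical ones.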

To estimate the **global trend** and **global variability** we split the global mean temperature signal into a trend and variability component:
$T_{t}^{glob} = T_{t}^{glob,\,trend} + T_{t}^{glob,\,var}$.
The trend component is further split into a smooth and volcanic component:
$T_{t}^{glob,\,trend} = T_{t}^{glob,\,smooth} + T_{t}^{glob,\,volc}$.


## "Smooth" and volcanic components of the global temperature
The volcanic contributions to the global mean temperature of the historical period have to be accounted for before estimating the linear regression of local temperature on global mean temperature. We therefore estimate them and add them to the smooth trend component:

1. Calculate $T_{t}^{glob,\,smooth}$ using a lowess smoother, with 50 time steps:

In [None]:
# mean over members before smoothing
tas_globmean_ensmean = tas_globmean.mean(dim="member")

n_steps = 50

tas_globmean_smoothed = mesmer.stats.lowess(
    tas_globmean_ensmean,
    dim="time",
    n_steps=n_steps,
    use_coords=False,
)
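
The idea behind the smoother can be sketched as a local linear regression with tricube weights over the `n_steps` nearest time steps (LOWESS without robustness iterations), using index positions rather than time coordinates, which is what we assume `use_coords=False` means. The helper `lowess_sketch` is hypothetical and only illustrates the technique; mesmer's `lowess` may differ in details such as the bandwidth definition:

```python
import numpy as np

def lowess_sketch(y, n_steps):
    """Local linear fit at each point using its n_steps nearest neighbours."""
    x = np.arange(y.size, dtype=float)  # index positions, not coordinates
    smoothed = np.empty_like(y, dtype=float)
    for i, x0 in enumerate(x):
        dist = np.abs(x - x0)
        h = np.sort(dist)[n_steps - 1]  # bandwidth: n_steps-th nearest point
        w = np.clip(1 - (dist / h) ** 3, 0, None) ** 3  # tricube weights
        # weighted least squares for intercept and slope around x0
        A = np.column_stack([np.ones_like(x), x - x0])
        sw = np.sqrt(w)
        coef, *_ = np.linalg.lstsq(A * sw[:, None], y * sw, rcond=None)
        smoothed[i] = coef[0]  # fitted value at x0
    return smoothed

rng = np.random.default_rng(1)
t = np.arange(251)
trend = 1e-4 * t**2  # hypothetical smooth warming
y = trend + rng.normal(0, 0.1, t.size)
y_smooth = lowess_sketch(y, n_steps=50)
```

Because each local fit is linear, a straight line passes through the smoother unchanged, while noise is strongly damped.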

In [None]:
# plot historical
f, ax = plt.subplots()

h0, *_ = tas_globmean["historical"].tas.plot.line(ax=ax, x="time", color="grey", lw=1)
a2, *_ = tas_globmean_smoothed["historical"].tas.plot.line(ax=ax, x="time", lw=2)

ax.legend([h0, a2], ["Ensemble members", "Smooth ensemble mean"])

2. Fit the parameter of the volcanic contribution on the historical period only, using the residuals of all available ensemble members relative to the smoothed ensemble mean. The future scenarios do not have volcanic contributions.

In [None]:
hist_tas_residuals = tas_globmean["historical"] - tas_globmean_smoothed["historical"]

# fit volcanic influence
volcanic_params = mesmer.volc.fit_volcanic_influence(hist_tas_residuals.tas)

volcanic_params.aod
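
A minimal sketch of this fit, assuming the volcanic signal is a single coefficient times a stratospheric aerosol optical depth (AOD) time series, estimated by least squares without intercept on the residuals of all members pooled together. The AOD series and the coefficient below are synthetic; mesmer uses its own AOD data and its fitting details may differ. The last lines also show the superposition of the next step:

```python
import numpy as np

rng = np.random.default_rng(2)
n_years, n_members = 165, 2

# synthetic AOD: zero except for a few decaying eruption spikes
aod = np.zeros(n_years)
for start, peak in [(33, 0.15), (113, 0.10), (141, 0.12)]:
    aod[start:start + 5] += peak * np.exp(-np.arange(5) / 1.5)

true_coef = -1.5  # made-up cooling (K) per unit AOD
# residuals per member: volcanic signal plus noise, dims (member, time)
resids = true_coef * aod + rng.normal(0, 0.05, (n_members, n_years))

# pool members: repeat the AOD predictor once per member
x = np.tile(aod, n_members)
y = resids.ravel()

# least squares through the origin: coef = <x, y> / <x, x>
coef = x @ y / (x @ x)

# superimpose the fitted influence on a (here flat) smoothed trend
smooth = np.zeros(n_years)
smooth_with_volc = smooth + coef * aod
```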

3. Superimpose the volcanic influence on the historical time series. Because the historical data is treated as its own scenario, the smoothed time series show discontinuities at the boundary between the historical and future periods. However, this is not relevant for fitting the parameters hereafter.

In [None]:
# superimpose the volcanic forcing on historical data
tas_globmean_smoothed["historical"] = mesmer.volc.superimpose_volcanic_influence(
    tas_globmean_smoothed["historical"],
    volcanic_params,
)

In [None]:
# plot global mean time series
f, ax = plt.subplots()

# plot unsmoothed global means
tas_globmean["historical"].tas.plot.line(
    ax=ax, lw=1, x="time", color="0.5", add_legend=False
)
tas_globmean["ssp126"].tas.plot.line(
    ax=ax, lw=1, x="time", color="#6baed6", add_legend=False
)
tas_globmean["ssp585"].tas.plot.line(
    ax=ax, lw=1, x="time", color="#fc9272", add_legend=False
)

# plot smoothed global means including volcanic influence for historical
tas_globmean_smoothed["historical"].tas.plot.line(
    ax=ax, lw=1.5, x="time", color="0.1", label="historical"
)
tas_globmean_smoothed["ssp126"].tas.plot.line(
    ax=ax, lw=1.5, x="time", color="#08519c", label="ssp126"
)
tas_globmean_smoothed["ssp585"].tas.plot.line(
    ax=ax, lw=1.5, x="time", color="#de2d26", label="ssp585"
)

# histend = tas_globmean["historical"].time.isel(time=-1).item()
# ax.axvline(histend, color="0.4")
ax.axhline(0, color="0.1", lw=0.5)

ax.set_title("")
plt.legend(loc="upper left")

4. Calculate the residuals w.r.t. the smoothed time series, i.e., remove the smoothed global mean (including the volcanic influence) from the anomalies.

In [None]:
tas_globmean_resids = tas_globmean - tas_globmean_smoothed
# rename to tas_resids
tas_globmean_resids = mesmer.datatree.map_over_datasets(
    lambda ds: ds.rename({"tas": "tas_resids"}), tas_globmean_resids
)

In [None]:
# plot residuals
h0, *_ = tas_globmean_resids["historical"].tas_resids.plot.line(
    x="time", color="0.5", lw=1, add_legend=False
)
h1, *_ = tas_globmean_resids["ssp126"].tas_resids.plot.line(
    x="time", color="#08519c", lw=1, add_legend=False
)
h2, *_ = tas_globmean_resids["ssp585"].tas_resids.plot.line(
    x="time", color="#de2d26", lw=1, add_legend=False
)

plt.title("Residuals")
plt.axhline(0, lw=1, color="0.1")

plt.legend([h0, h1, h2], ["historical", "ssp126", "ssp585"])

## Global variability

In this step, we fit an autoregressive (AR) process of order $p$ to the residual global mean temperature to estimate the global variability:

$$T_{t}^{glob,\,var} = \alpha_0 + \sum\limits_{k=1}^{p} \alpha_k \cdot T_{t-k}^{glob,\,var} + \varepsilon_t,\ \varepsilon_t \sim \mathcal{N}(0, \sigma)$$


We first estimate the order of the AR process and then fit the parameters. Internally, we fit the parameters for each member and then average first over the parameters of each scenario and then over all scenarios to arrive at a single set of parameters.

In [None]:
ar_order = mesmer.stats.select_ar_order_scen_ens(
    tas_globmean_resids, dim="time", ens_dim="member", maxlag=12, ic="bic"
)

global_ar_params = mesmer.stats.fit_auto_regression_scen_ens(
    tas_globmean_resids, dim="time", ens_dim="member", lags=ar_order
)

global_ar_params = global_ar_params.drop_vars("nobs")
global_ar_params
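
The order selection and the parameter averaging can be sketched with ordinary least squares on synthetic AR(1) series: fit AR(p) with an intercept for p = 1 to 12 on a common sample, keep the p with the lowest BIC, combine the per-series orders with a median, and then average the fitted parameters first within and then across scenarios. The OLS estimator, this BIC form, and the median are assumptions of the sketch; mesmer's functions may differ in these details:

```python
import numpy as np

def fit_ar_ols(x, p, start):
    """OLS fit of x_t = a0 + sum_k a_k x_{t-k} + e_t for t >= start."""
    y = x[start:]
    cols = [np.ones_like(y)] + [x[start - k:x.size - k] for k in range(1, p + 1)]
    X = np.column_stack(cols)
    coeffs, *_ = np.linalg.lstsq(X, y, rcond=None)
    return coeffs, y - X @ coeffs

def select_order_bic(x, maxlag):
    bics = []
    for p in range(1, maxlag + 1):
        _, resid = fit_ar_ols(x, p, start=maxlag)  # same sample for every p
        n = resid.size
        bics.append(n * np.log(resid @ resid / n) + (p + 1) * np.log(n))
    return int(np.argmin(bics)) + 1

def simulate_ar1(n, a0, a1, sigma, rng):
    x = np.zeros(n)
    for t in range(1, n):
        x[t] = a0 + a1 * x[t - 1] + rng.normal(0, sigma)
    return x

rng = np.random.default_rng(3)
# hypothetical residuals: {scenario: list of member time series}
series = {
    "historical": [simulate_ar1(165, 0.0, 0.5, 0.1, rng) for _ in range(2)],
    "ssp585": [simulate_ar1(86, 0.0, 0.5, 0.1, rng) for _ in range(2)],
}

orders = [select_order_bic(x, maxlag=12) for xs in series.values() for x in xs]
order = int(np.median(orders))

# fit per member, average within each scenario, then across scenarios
scen_means = [
    np.mean([fit_ar_ols(x, order, start=order)[0] for x in xs], axis=0)
    for xs in series.values()
]
global_params = np.mean(scen_means, axis=0)  # [intercept, a_1, ..., a_p]
```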

## Local forced response
Now we need to estimate how the global trend translates into a local forced response. This is done using a linear regression with the global trend and the global variability as predictors:

$T_{s,t}^{resp} = \beta_s^{int} + \beta_s^{trend} \cdot T_t^{glob,\,trend} + \beta_s^{var} \cdot T_t^{glob,\,var}$

To this end, we stack all values (members, scenarios) into a single dataset. The only important thing is that predictor and target values that belong together stay together.

Before computing the coefficients we need to prepare the local temperature data:

1. Mask out ocean grid points (i.e., keep only grid points where the land fraction is larger than `THRESHOLD_LAND`)
2. Mask out Antarctica
3. Convert the data from a 2D lat-lon grid to a 1D grid by stacking it and removing all gridcells that were previously masked out.
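
The three preparation steps can be sketched with NumPy on a toy grid: keep grid points whose land fraction exceeds the threshold and that lie north of Antarctica, then flatten the remaining points into a single `gridcell` axis. The random land fractions and the 60°S cutoff for Antarctica are assumptions for illustration; mesmer's masks use their own land-sea data and region definitions. The last lines show why the original grid must be kept: it is needed to unstack the data again.

```python
import numpy as np

THRESHOLD_LAND = 1 / 3
lat = np.arange(-87.5, 90, 5.0)  # toy 5 degree grid
lon = np.arange(2.5, 360, 5.0)
n_time = 3

rng = np.random.default_rng(4)
land_frac = rng.uniform(0, 1, (lat.size, lon.size))   # hypothetical land fraction
tas = rng.normal(0, 1, (n_time, lat.size, lon.size))  # dims: time, lat, lon

LAT, LON = np.meshgrid(lat, lon, indexing="ij")
keep = (land_frac > THRESHOLD_LAND) & (LAT > -60)     # land, not Antarctica

# stack: (time, lat, lon) -> (time, gridcell), keeping only unmasked points
tas_stacked = tas[:, keep]
gridcell_lat, gridcell_lon = LAT[keep], LON[keep]

# unstack onto the original grid, filling masked points with NaN
tas_unstacked = np.full(tas.shape, np.nan)
tas_unstacked[:, keep] = tas_stacked
```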



Before stacking, we extract the original grid. We need to save this together with the parameters to later be able to reconstruct the original grid from the gridpoints.

In [None]:
# extract original grid
grid_orig = tas_anom["historical"].ds[["lat", "lon"]]

In [None]:
def mask_and_stack(dt, threshold_land):
    dt = mesmer.mask.mask_ocean_fraction(dt, threshold_land)
    dt = mesmer.mask.mask_antarctica(dt)
    dt = mesmer.grid.stack_lat_lon(dt)
    return dt


# mask and stack the data
tas_stacked = mask_and_stack(tas_anom, THRESHOLD_LAND)

In [None]:
tas_stacked["ssp585"].tas.isel(member=1).plot()

We have now converted the 3D field (with dimensions lat, lon, and time) to a 2D field (with dimensions gridcell and time).

We create a new `DataTree` from all predictors, here the smoothed global mean and its residuals. We could add more predictors, e.g., the squared temperatures or the ocean heat uptake:

In [None]:
predictors = mesmer.datatree.merge([tas_globmean_smoothed, tas_globmean_resids])

target = tas_stacked.copy()

In the linear regression, we want to weight the values of the different scenarios equally, i.e., we do not want scenarios with more members (here ssp585) to be overrepresented in the linear regression parameters. Thus, we generate weights that weight each value by the inverse of the number of members in its scenario, so $w_{scen,\,mem,\,t} = 1 / n_{mem,\,scen}$. We currently do not account for the different number of time steps (historical vs. future scenarios):

In [None]:
# create weights
weights = mesmer.weighted.equal_scenario_weights_from_datatree(tas_stacked)
weights
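
The weighting rule $w = 1/n_{mem}$ can be checked by hand. A sketch with hypothetical member and time-step counts (one member for ssp126 and two for ssp585 as in this example, two historical members, annual data for 1850-2014 and 2015-2100):

```python
n_members = {"historical": 2, "ssp126": 1, "ssp585": 2}
n_time = {"historical": 165, "ssp126": 86, "ssp585": 86}

# one weight per member, identical for all time steps of that scenario
weights = {scen: 1 / n for scen, n in n_members.items()}

# total weight of a scenario = n_members * n_time * weight = n_time
total_weight = {
    scen: n_members[scen] * n_time[scen] * weights[scen] for scen in n_members
}
```

Each future scenario ends up with the same total weight, while the historical period still counts more because it has more time steps, which is the limitation mentioned above.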

We pool the different scenarios, ensemble members and timesteps into one sample dimension (containing time, member, and scenario as coordinates) for the linear regression. We want one `DataArray` per predictor and the target such that each sample of the predictor variables aligns with the corresponding sample of the target:

In [None]:
predictors_pooled, target_pooled, weights_pooled = (
    mesmer.datatree.broadcast_and_pool_scen_ens(predictors, target, weights)
)

target_pooled
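
Pooling flattens scenario, member, and time into one `sample` axis, applying the same ordering to predictors, target, and weights so that every sample stays aligned. A NumPy sketch with hypothetical shapes and random values; mesmer keeps scenario, member, and time as coordinates on `sample`, which the plain `labels` list only imitates here:

```python
import numpy as np

rng = np.random.default_rng(5)
n_cells = 4
shapes = {"historical": (2, 165), "ssp126": (1, 86), "ssp585": (2, 86)}  # (member, time)

pred_parts, target_parts, weight_parts, labels = [], [], [], []
for scen, (n_mem, n_t) in shapes.items():
    glob = rng.normal(size=(n_mem, n_t))            # one global predictor
    local = rng.normal(size=(n_mem, n_t, n_cells))  # target per grid cell
    # flatten (member, time) in the same order for every array
    pred_parts.append(glob.reshape(-1))
    target_parts.append(local.reshape(-1, n_cells))
    weight_parts.append(np.full(n_mem * n_t, 1 / n_mem))
    labels += [(scen, m, t) for m in range(n_mem) for t in range(n_t)]

predictor_pooled = np.concatenate(pred_parts)  # (sample,)
target_pooled = np.concatenate(target_parts)   # (sample, cell)
weights_pooled = np.concatenate(weight_parts)  # (sample,)
```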

In the linear regression, the predictors for each sample are used for every gridpoint of the target. We can now fit the linear regression:

In [None]:
local_lin_reg = mesmer.stats.LinearRegression()

local_lin_reg.fit(
    predictors=predictors_pooled,
    target=target_pooled.tas,
    dim="sample",
    weights=weights_pooled.weights,
)

local_forced_response_params = local_lin_reg.params
local_forced_response_params
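
Per grid cell, the fit is a weighted least-squares problem with the same predictor matrix for every cell. A NumPy sketch, assuming an intercept plus one coefficient per predictor and synthetic data with known coefficients; scaling the rows by the square root of the weights turns the weighted problem into an ordinary one that can be solved for all cells at once:

```python
import numpy as np

rng = np.random.default_rng(6)
n_sample, n_cells = 588, 3

tas_trend = np.linspace(0, 4, n_sample)     # hypothetical T^{glob,trend}
tas_var = rng.normal(0, 0.15, n_sample)     # hypothetical T^{glob,var}
weights = rng.choice([0.5, 1.0], n_sample)  # e.g. 1 / n_members

# made-up coefficients per cell: rows = intercept, beta_trend, beta_var
beta_true = np.array([[0.1, -0.2, 0.0],
                      [1.2, 0.8, 1.5],
                      [0.9, 1.1, 0.4]])
X = np.column_stack([np.ones(n_sample), tas_trend, tas_var])
target = X @ beta_true + rng.normal(0, 0.05, (n_sample, n_cells))

# weighted least squares: scale rows by sqrt(w), then solve all cells at once
sw = np.sqrt(weights)
beta_hat, *_ = np.linalg.lstsq(X * sw[:, None], target * sw[:, None], rcond=None)
residuals = target - X @ beta_hat
```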

In [None]:
data_vars = (
    "intercept",
    "tas",
    "tas_resids",
)

f, axs = plt.subplots(
    3, 1, sharex=True, sharey=True, subplot_kw={"projection": ccrs.Robinson()}
)
axs = axs.flatten()

for ax, data_var in zip(axs, data_vars):

    da = local_forced_response_params[data_var]
    da = mesmer.grid.unstack_lat_lon_and_align(da, grid_orig)

    h = da.plot(
        ax=ax,
        label=data_var,
        robust=True,
        center=0,
        extend="both",
        add_colorbar=False,
        transform=ccrs.PlateCarree(),
    )

    ax.set_extent((-180, 180, -60, 85), ccrs.PlateCarree())
    cbar = plt.colorbar(h, ax=ax, extend="both", pad=0.025)  # , shrink=0.7)
    ax.set(title=data_var, xlabel="", ylabel="", xticks=[], yticks=[])
    ax.coastlines()

## Local variability
Next, we fit the parameters for the local AR(1) process with a spatially correlated noise term, which is used to emulate local variability:

$\eta_{s,\,t} = \gamma_{0,\,s} + \gamma_{1,\,s} \cdot \eta_{s,\,t-1} + \nu_{s,\,t}, \ \nu_{s,\,t} \sim \mathcal{N}(0, \Sigma_{\nu}(r))$

The first component, which contains the AR parameters ($\gamma_{0,\,s} + \gamma_{1,\,s} \cdot \eta_{s,\,t-1}$), ensures temporal correlation of the local variability, whereas the noise term $\nu_{s,\,t}$ ensures spatial consistency. The covariance matrix $\Sigma_{\nu}(r)$ is estimated on the whole grid and represents the spatial correlation of temperatures between the different grid points.
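
To see what the two components do, the process can be simulated for three grid points: each point follows its own AR(1) recursion, while the innovations $\nu_t$ are drawn jointly from $\mathcal{N}(0, \Sigma_\nu)$ via a Cholesky factor. All parameter values are made up. The sketch then recovers $\gamma_{1,\,s}$ per grid point by least squares, a simplified stand-in for the per-grid-point AR(1) fit below (without mesmer's averaging over members and scenarios):

```python
import numpy as np

rng = np.random.default_rng(7)
n_time = 5000
gamma0 = np.array([0.0, 0.0, 0.0])
gamma1 = np.array([0.3, 0.5, 0.7])
# made-up innovation covariance: neighbouring points strongly correlated
sigma_nu = np.array([[1.0, 0.8, 0.3],
                     [0.8, 1.0, 0.6],
                     [0.3, 0.6, 1.0]])
chol = np.linalg.cholesky(sigma_nu)

eta = np.zeros((n_time, 3))
for t in range(1, n_time):
    nu = chol @ rng.standard_normal(3)          # spatially correlated innovation
    eta[t] = gamma0 + gamma1 * eta[t - 1] + nu  # separate AR(1) per grid point

# refit gamma_1 per grid point with OLS (intercept + lag-1 term)
gamma1_hat = []
for s in range(3):
    X = np.column_stack([np.ones(n_time - 1), eta[:-1, s]])
    coef, *_ = np.linalg.lstsq(X, eta[1:, s], rcond=None)
    gamma1_hat.append(coef[1])
gamma1_hat = np.array(gamma1_hat)
```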

### Estimate the AR parameters

First we need to compute the residuals after the linear regression.

In [None]:
resids = local_lin_reg.residuals(predictors, target)

The local AR(1) process is estimated on the individual scenarios, but the covariance is estimated on the pooled residuals; therefore, we need the residuals in both forms.

In [None]:
resids_pooled = mesmer.datatree.pool_scen_ens(resids)

In [None]:
# fit the AR(1) process
local_ar = mesmer.stats.fit_auto_regression_scen_ens(
    resids,
    ens_dim="member",
    dim="time",
    lags=1,
)

local_ar

### Estimate covariance matrix
For the covariance matrix of the white noise, we first estimate the empirical covariance matrix of the grid cells' values and then localize it using the Gaspari-Cohn function. This function decreases towards 0 for larger distances and becomes exactly 0 for distances of twice the so-called localisation radius or more. This procedure is also called regularization. It ensures that grid points that are far away from each other are not spuriously correlated. Such spurious correlations can arise when the covariance matrix is rank deficient, as in our case, because we estimate the covariance from data that has more grid cells than time steps.

The localisation radius is a parameter that needs to be calibrated: we find the best localisation radius by k-fold cross-validation over several candidate radii, selecting the radius with the lowest negative log-likelihood.
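
The Gaspari-Cohn function is a fifth-order piecewise rational function of the scaled distance $r = d / L$, where $L$ is the localisation radius. The NumPy sketch below implements the standard form from Gaspari and Cohn (1999), which equals 1 at $r = 0$ and is exactly 0 for $r \ge 2$; we assume, but do not guarantee, that mesmer's implementation matches it in every detail. Localization is then an element-wise (Schur) product with the empirical covariance:

```python
import numpy as np

def gaspari_cohn(r):
    """Gaspari-Cohn taper evaluated at scaled distance r = distance / radius."""
    r = np.abs(np.asarray(r, dtype=float))
    out = np.zeros_like(r)

    inner = r <= 1
    x = r[inner]
    out[inner] = 1 - 5/3 * x**2 + 5/8 * x**3 + 1/2 * x**4 - 1/4 * x**5

    outer = (r > 1) & (r < 2)
    x = r[outer]
    out[outer] = (4 - 5 * x + 5/3 * x**2 + 5/8 * x**3 - 1/2 * x**4
                  + 1/12 * x**5 - 2 / (3 * x))
    return out  # stays 0 for r >= 2

# localize a toy covariance matrix with a localisation radius of 5000 km
dist_km = np.array([[0, 3000, 12000], [3000, 0, 9000], [12000, 9000, 0]], float)
emp_cov = np.array([[1.0, 0.6, 0.2], [0.6, 1.0, 0.3], [0.2, 0.3, 1.0]])
localized = emp_cov * gaspari_cohn(dist_km / 5000)
```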

1. Prepare the distance matrix, i.e., the pairwise distances between the grid points in km.

In [None]:
grid_stacked = resids["historical"].ds[["lat", "lon"]]
geodist = mesmer.geospatial.geodist_exact(grid_stacked.lon, grid_stacked.lat)

# plot
f, ax = plt.subplots()
geodist.plot(ax=ax, cmap="Blues")

ax.set_aspect("equal")
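
`geodist_exact` computes distances on the ellipsoid. For intuition, a spherical great-circle (haversine) approximation, a deliberately simpler swap-in that differs from the exact geodesic distance by a fraction of a percent, can be written as follows; the helper name `geodist_haversine` is hypothetical:

```python
import numpy as np

EARTH_RADIUS_KM = 6371.0  # mean Earth radius (spherical approximation)

def geodist_haversine(lon, lat):
    """Pairwise great-circle distances (km) between points given in degrees."""
    lon, lat = np.deg2rad(lon), np.deg2rad(lat)
    dlon = lon[:, None] - lon[None, :]
    dlat = lat[:, None] - lat[None, :]
    a = (np.sin(dlat / 2) ** 2
         + np.cos(lat[:, None]) * np.cos(lat[None, :]) * np.sin(dlon / 2) ** 2)
    return 2 * EARTH_RADIUS_KM * np.arcsin(np.sqrt(np.clip(a, 0, 1)))

# equator at 0 and 90 degrees east, and the North Pole
lon = np.array([0.0, 90.0, 0.0])
lat = np.array([0.0, 0.0, 90.0])
dist = geodist_haversine(lon, lat)
```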

2. Prepare the localizer(s) to regularize the covariance matrix

In [None]:
phi_gc_localizer = mesmer.stats.gaspari_cohn_correlation_matrices(
    geodist, range(5_000, 15_001, 500)
)

# plot one
f, ax = plt.subplots()
phi_gc_localizer[5000].plot(ax=ax, cmap="Blues")

ax.set_aspect("equal")

3. Get the weights: we reuse the weights from the local trend regression

In [None]:
# reusing weights from local trend regression

4. Find the best localization radius and localize the empirical covariance matrix

In [None]:
dim = "sample"
k_folds = 15

localized_ecov = mesmer.stats.find_localized_empirical_covariance(
    resids_pooled.residuals,
    weights_pooled.weights,
    phi_gc_localizer,
    dim,
    k_folds=k_folds,
)

localized_ecov
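
The radius selection can be sketched as k-fold cross-validation: for each candidate radius, estimate the weighted empirical covariance on the training folds, localize it, and score the held-out samples with the Gaussian negative log-likelihood; the radius with the lowest total wins. The contiguous folds, the zero-mean assumption for the held-out residuals, and the 1D toy data with fewer samples than points are simplifications, so treat this as an outline rather than mesmer's exact procedure:

```python
import numpy as np

def gaspari_cohn(r):
    r = np.abs(r)
    out = np.zeros_like(r)
    a, b = r <= 1, (r > 1) & (r < 2)
    x = r[a]
    out[a] = 1 - 5/3 * x**2 + 5/8 * x**3 + x**4 / 2 - x**5 / 4
    x = r[b]
    out[b] = 4 - 5*x + 5/3 * x**2 + 5/8 * x**3 - x**4 / 2 + x**5 / 12 - 2 / (3 * x)
    return out

def neg_loglik(data, cov):
    """Negative log-likelihood of the rows of data under N(0, cov)."""
    sign, logdet = np.linalg.slogdet(cov)
    if sign <= 0:
        return np.inf  # not positive definite: reject this radius
    maha = np.einsum("ij,ij->i", data, np.linalg.solve(cov, data.T).T)
    return 0.5 * np.sum(logdet + maha + data.shape[1] * np.log(2 * np.pi))

def find_radius(data, weights, geodist, radii, k_folds):
    idx = np.arange(data.shape[0])
    folds = np.array_split(idx, k_folds)
    scores = {}
    for radius in radii:
        taper = gaspari_cohn(geodist / radius)
        total = 0.0
        for test_idx in folds:
            train = np.setdiff1d(idx, test_idx)
            ecov = np.cov(data[train], rowvar=False, aweights=weights[train])
            total += neg_loglik(data[test_idx], ecov * taper)
        scores[radius] = total
    best = min(scores, key=scores.get)
    ecov = np.cov(data, rowvar=False, aweights=weights)
    return best, ecov, ecov * gaspari_cohn(geodist / best)

rng = np.random.default_rng(8)
pos_km = np.arange(40) * 500.0  # 40 points on a line, 500 km apart
geodist = np.abs(pos_km[:, None] - pos_km[None, :])
true_cov = np.exp(-geodist / 1500.0)
data = rng.multivariate_normal(np.zeros(40), true_cov, size=30)  # 30 < 40
weights = np.ones(30)

best, ecov, localized = find_radius(
    data, weights, geodist, [250, 1000, 4000, 16000], k_folds=5
)
```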

In [None]:
f, axs = plt.subplots(1, 2, sharey=True, constrained_layout=True)

opt = dict(vmin=0, vmax=1.5, cmap="Blues", add_colorbar=False)

ax = axs[0]
localized_ecov.covariance.plot(ax=ax, **opt)
ax.set_aspect("equal")
ax.set_title("Empirical covariance")

ax = axs[1]
localized_ecov.localized_covariance.plot(ax=ax, **opt)
ax.set_aspect("equal")
ax.set_title("Localized empirical covariance")
ax.set_ylabel("")
plt.show()

5. Adjust the regularized covariance matrix

   Lastly, we need to adjust the localized covariance matrix using the AR(1) parameters, since the variance of the time series we observe is larger than the variance of the driving white noise process. Read more about this in "Statistical Analysis in Climate Research" by von Storch and Zwiers (1999, reprinted 2003).

In [None]:
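# store the adjusted covariance together with the other covariance parameters,
# so that it is included when the parameters are saved below
localized_ecov["localized_covariance_adjusted"] = mesmer.stats.adjust_covariance_ar1(
    localized_ecov.localized_covariance, local_ar.coeffs
)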
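
For a single AR(1) series $x_t = \gamma_1 x_{t-1} + \nu_t$, the stationary variance is $\sigma_\nu^2 / (1 - \gamma_1^2)$, so the innovation variance is the observed variance times $1 - \gamma_1^2$. One common way to apply this to a covariance matrix scales entry $(i, j)$ by $\sqrt{1 - \gamma_{1,i}^2}\,\sqrt{1 - \gamma_{1,j}^2}$, which is exact on the diagonal. The sketch below uses that form; whether `adjust_covariance_ar1` uses exactly this expression is an assumption, and the helper name is only borrowed for illustration:

```python
import numpy as np

def adjust_covariance_ar1(cov, gamma1):
    """Scale an observed covariance towards the covariance of the AR(1) innovations."""
    reduction = np.sqrt(1 - gamma1**2)
    return cov * np.outer(reduction, reduction)

# check the diagonal against a long simulated AR(1) series
rng = np.random.default_rng(9)
gamma1, sigma_nu = 0.6, 1.0
x = np.zeros(50_000)
for t in range(1, x.size):
    x[t] = gamma1 * x[t - 1] + rng.normal(0, sigma_nu)

observed_var = x.var()  # close to sigma_nu**2 / (1 - gamma1**2) = 1.5625
adjusted = adjust_covariance_ar1(np.array([[observed_var]]), np.array([gamma1]))

# two grid points with different AR(1) coefficients
example = adjust_covariance_ar1(np.array([[2.0, 1.0], [1.0, 2.0]]), np.array([0.6, 0.8]))
```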

## Saving the parameters

Finally, we have calibrated all needed parameters and can save them. We can use filefisher to conveniently create file names and save the parameters.

In [None]:
# define the output path relative to this notebook (the folder is created below)
param_path = pathlib.Path("./output/calibrated_parameters/")

In [None]:
PARAM_FILEFINDER = filefisher.FileFinder(
    path_pattern=param_path / "{esm}_{scen}",
    file_pattern="params_{module}_{esm}_{scen}.nc",
)

scen_str = "-".join(scenarios)

folder = PARAM_FILEFINDER.create_path_name(esm=model, scen=scen_str)
pathlib.Path(folder).mkdir(exist_ok=True, parents=True)

params = {
    "volcanic": volcanic_params,
    "global-variability": global_ar_params,
    "local-trends": local_lin_reg,
    "local-variability": local_ar,
    "covariance": localized_ecov,
    "grid-orig": grid_orig,
}


save_files = False  # we don't save them here in the example
if save_files:

    for module, param in params.items():

        filename = PARAM_FILEFINDER.create_full_name(
            module=module,
            esm=model,
            scen=scen_str,
        )

        param.to_netcdf(filename)

To use the calibrated parameters for emulation, see the tutorials on emulating one or multiple scenarios in the Tutorials section.