# MESMER Tutorial
This is a tutorial for the use of MESMER. Here we demonstrate how MESMER is able to produce Earth System Model-specific, spatio-temporally correlated temperature field realizations, taking annual global mean near surface temperature as input and emulating annual local near surface land temperature. This tutorial shows how to calibrate the parameters for MESMER on an example dataset of coarsely regridded ESM output and how to emulate new realisations.

In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import pathlib

import cartopy.crs as ccrs
import matplotlib.pyplot as plt
import xarray as xr
from filefisher import FileFinder

import mesmer

## Calibrating MESMER

We calibrate the parameters for MESMER using three scenarios: a historical, a low emission (SSP1-2.6), and a high emission (SSP5-8.5) scenario, each including several ensemble members. Training MESMER consists of four parts:

- **global trend**: compute the global temperature trend, including the volcanic influence on historical trends
- **global variability**: estimate the parameters to generate global variability
- **local trend**: estimate parameters to translate global mean temperature (including global variability) into locally resolved temperature
- **local variability**: estimate parameters needed to generate local variability

You can find the basics of this approach in: Beusch et al. ([2020](https://doi.org/10.5194/ESD-11-139-2020)): "*Emulating Earth system model temperatures with MESMER: From global mean temperature trajectories to grid-point-level realizations on land.*" Earth System Dynamics.

### Load data

MESMER expects a specific data format. Data from each scenario should be a node (or group) in an `xr.DataTree` (more on this below), e.g.:

```
<xarray.DataTree>
Group: /
├── Group: /historical
|    ...
├── Group: /ssp126
|    ...
```

Each node should hold an `xr.Dataset` with four dimensions: member, time, lat, lon. Below we show one way to load data such that it conforms to the desired format. We load data from the cmip6-ng ("new generation") repository. This data has undergone a small reformatting from the original CMIP6 archive. For the sake of computational speed, we also load data which has been regridded to a coarse resolution. Loading the data can be adapted to the data format you are most used to, as long as the final output has the desired format.

---

MESMER is Earth System Model specific, aiming to reproduce to some extent the behaviour of one ESM. Here we train on the CMIP6 output of the model IPSL-CM6A-LR.

In [None]:
model = "IPSL-CM6A-LR"

We use the library [filefisher](https://github.com/mpytools/filefisher) to search all files in the cmip6-ng archive for the model and scenarios we want to use. Filefisher searches through paths for given file patterns and returns all paths matching the pattern, such that you can load the files in the next step.

Here, we want to find all files that have data for annual near surface temperature (tas) for the chosen model and the future scenarios ssp126 and ssp585. Next, we search for the historical data that match the members found for the two future scenarios.

In [None]:
cmip_path = mesmer.example_data.cmip6_ng_path(relative=True)

CMIP_FILEFINDER = FileFinder(
    path_pattern=cmip_path / "{variable}/{time_res}/{resolution}",
    file_pattern="{variable}_{time_res}_{model}_{scenario}_{member}_{resolution}.nc",
)

Search data for ssp126 and ssp585 - we find one and two ensemble members, respectively:

In [None]:
scenarios = ["ssp126", "ssp585"]

fc_scens = CMIP_FILEFINDER.find_files(
    variable="tas", scenario=scenarios, model=model, resolution="g025", time_res="ann"
)


fc_scens.df

We also need to find the same ensemble members in the historical data, such that we end up with five files to load:

In [None]:
# get the historical members that are also in the future scenarios, but only once
members = fc_scens.df.member.unique()

fc_hist = CMIP_FILEFINDER.find_files(
    variable="tas",
    scenario="historical",
    model=model,
    resolution="g025",
    time_res="ann",
    member=members,
)

fc_all = fc_hist.concat(fc_scens)
fc_all.df

Now we load all the files we found into a ``DataTree``. ``DataTree`` is a data structure provided by [xarray](https://docs.xarray.dev/en/stable/index.html).

Essentially, ``DataTree`` is a container to hold several xarray `Dataset`s whose data variables are not alignable. This is useful for us since the historical and future data have different time coordinates. Moreover, the scenarios may have different numbers of members (e.g. ssp126, which only has one). Thus, we store the data of each scenario in an `xarray.Dataset` holding all its ensemble members along a `member` dimension, and then store all the scenario datasets in one `DataTree`. The `DataTree` allows us to perform computations on each of the datasets in a readable way.

In [None]:
dt = xr.DataTree()

scenarios_incl_hist = ["historical"] + scenarios

time_coder = xr.coders.CFDatetimeCoder(use_cftime=True)

# load data for each scenario
for scen in scenarios_incl_hist:
    files = fc_all.search(scenario=scen)

    # load all members for a scenario
    members = []
    for fN, meta in files.items():
        ds = xr.open_dataset(fN, decode_times=time_coder)
        # drop unnecessary variables
        ds = ds.drop_vars(["height", "time_bnds", "file_qf"], errors="ignore")
        # assign member-ID as coordinate
        ds = ds.assign_coords({"member": meta["member"]})
        members.append(ds)

    # create a Dataset that holds each member along the member dimension
    scen_data = xr.concat(members, dim="member")
    # put the scenario dataset into the DataTree
    dt[scen] = xr.DataTree(scen_data)

dt

This results in the data format discussed above. You can examine it by clicking on `Groups` above.

---

We will need some configuration parameters in the following:
1. ``THRESHOLD_LAND``: land fraction above which a grid point is considered a land grid point.
2. ``REFERENCE_PERIOD``: we work not with absolute temperature values but with temperature anomalies w.r.t. this reference period.

In [None]:
THRESHOLD_LAND = 1 / 3
REFERENCE_PERIOD = slice("1850", "1900")

### Calculate anomalies
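
A minimal NumPy sketch of what an anomaly with respect to a reference period means, using one synthetic time series (all values hypothetical). It does not reproduce `mesmer.anomaly.calc_anomaly`; in particular, the future scenarios do not cover 1850-1900, so their reference has to come from the historical run, which this single-series sketch does not show.

```python
import numpy as np

# Hypothetical annual global-mean temperatures (K), 1850-1910, for illustration only
years = np.arange(1850, 1911)
rng = np.random.default_rng(0)
tas = 287.0 + 0.01 * (years - 1850) + rng.normal(0.0, 0.1, years.size)

# Reference period 1850-1900, both ends inclusive like slice("1850", "1900") in xarray
in_ref = (years >= 1850) & (years <= 1900)
reference = tas[in_ref].mean()

# Anomalies are deviations from the reference-period mean
tas_anom = tas - reference
```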

In [None]:
# calculate anomalies w.r.t. the reference period
tas_anom = mesmer.anomaly.calc_anomaly(dt, reference_period=REFERENCE_PERIOD)

### Volcanic contributions
The volcanic contributions to the global mean temperature trend of the historical period have to be removed in order to estimate the linear regression of global mean temperature to local temperature.

1. Calculate the global mean, average it over the ensemble members, and smooth the result with a LOWESS smoother spanning 50 time steps.

In [None]:
# calculate global mean
tas_globmean = mesmer.weighted.global_mean(tas_anom)
tas_globmean

In [None]:
# mean over members before smoothing

tas_globmean_ensmean = tas_globmean.mean(dim="member")

n_steps = 50
tas_globmean_smoothed = mesmer.stats.lowess(
    tas_globmean_ensmean,
    dim="time",
    n_steps=n_steps,
    use_coords=False,
)

In [None]:
# plot historical

f, ax = plt.subplots()

h0, *_ = tas_globmean["historical"].tas.plot.line(ax=ax, x="time", color="grey", lw=1)
a2, *_ = tas_globmean_smoothed["historical"].tas.plot.line(ax=ax, x="time", lw=2)

ax.legend([h0, a2], ["Ensemble members", "Smooth ensemble mean"])

2. Fit the parameter of the volcanic contributions only on the historical data, using the residuals of all available ensemble members with respect to the smoothed ensemble mean. The future scenarios do not have volcanic contributions.
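
To make the fitting step concrete, here is a hedged sketch assuming the volcanic influence is linear in stratospheric aerosol optical depth (AOD) with no intercept, which is what a single `aod` coefficient suggests. The AOD series, residuals and coefficient are synthetic and hypothetical, and MESMER's own implementation may differ in details.

```python
import numpy as np

rng = np.random.default_rng(1)
n_members, n_time = 3, 165

# Hypothetical aerosol optical depth (AOD) with two eruption-like spikes
aod = np.zeros(n_time)
aod[30:33] = [0.15, 0.08, 0.03]
aod[140:143] = [0.12, 0.06, 0.02]

# Synthetic residuals: volcanic cooling of -2 K per unit AOD plus internal variability
true_coef = -2.0
resid = true_coef * aod + rng.normal(0.0, 0.05, (n_members, n_time))

# Pool all members and fit resid ~ coef * aod by least squares without an intercept
x = np.tile(aod, n_members)
y = resid.ravel()
coef = (x @ y) / (x @ x)

# Superimpose the fitted volcanic signal on a smooth historical trend
smooth = np.linspace(0.0, 1.2, n_time)
smooth_with_volc = smooth + coef * aod
```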

In [None]:
hist_tas_residuals = tas_globmean["historical"] - tas_globmean_smoothed["historical"]

# fit volcanic influence
volcanic_params = mesmer.volc.fit_volcanic_influence(hist_tas_residuals.tas)

volcanic_params.aod

3. Superimpose the volcanic influence on the smoothed historical time series. Note that because the historical data is handled as its own scenario, we encounter discontinuities at the boundary between the historical and future periods. However, this is not relevant for fitting the parameters hereafter.

In [None]:
# superimpose the volcanic forcing on historical data
tas_globmean_smoothed["historical"] = mesmer.volc.superimpose_volcanic_influence(
    tas_globmean_smoothed["historical"],
    volcanic_params,
)

In [None]:
# some plotting
f, ax = plt.subplots()

# plot unsmoothed global means
tas_globmean["historical"].tas.plot.line(
    ax=ax, lw=1, x="time", color="grey", add_legend=False
)
tas_globmean["ssp126"].tas.plot.line(
    ax=ax, lw=1, x="time", color="#6baed6", add_legend=False
)
tas_globmean["ssp585"].tas.plot.line(
    ax=ax, lw=1, x="time", color="pink", add_legend=False
)

# plot smoothed global means including volcanic influence for historical
tas_globmean_smoothed["historical"].tas.plot.line(
    ax=ax, lw=1, x="time", color="black", label="historical"
)
tas_globmean_smoothed["ssp126"].tas.plot.line(
    ax=ax, lw=1, x="time", color="blue", label="ssp126"
)
tas_globmean_smoothed["ssp585"].tas.plot.line(
    ax=ax, lw=1, x="time", color="red", label="ssp585"
)


histend = tas_globmean["historical"].time.isel(time=-1).item()
ax.axvline(histend, color="0.4")
ax.axhline(0, color="0.1", lw=0.5)

ax.set_title("")
plt.legend(loc="upper left")

4. Calculate the residuals, i.e. remove the smoothed global mean (including the volcanic influence) from the global mean anomalies.

In [None]:
tas_globmean_resids = tas_globmean - tas_globmean_smoothed

In [None]:
# plot residuals
h0, *_ = tas_globmean_resids["historical"].tas.plot.line(
    x="time", color="grey", lw=1, add_legend=False
)
h1, *_ = tas_globmean_resids["ssp126"].tas.plot.line(
    x="time", color="blue", lw=1, add_legend=False
)
h2, *_ = tas_globmean_resids["ssp585"].tas.plot.line(
    x="time", color="red", lw=1, add_legend=False
)


plt.title("Residuals")
plt.axhline(0, lw=1, color="0.1")
histend = tas_globmean_resids["historical"].time.isel(time=-1).item()
plt.axvline(histend, color="0.1", lw=2)

plt.legend([h0, h1, h2], ["historical", "ssp126", "ssp585"])
plt.show()

### Global variability

In this step we want to fit an AR process for estimating global variability, taking in the residual global mean temperature as follows:

$T_{t}^{glob, var} = \alpha_0 + \sum\limits_{k=1}^{k=p} \alpha_k \cdot T_{t-k}^{glob, var} + \varepsilon_t,\ \varepsilon_t \sim \mathcal{N}(0, \sigma)$

We first estimate the order of the AR process and then fit the parameters. Internally, we fit the parameters for each member and then average first over the parameters of each scenario and then over all scenarios to arrive at a single set of parameters.
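
A NumPy sketch of the two steps: choosing the AR order by the Bayesian information criterion (BIC) and fitting the coefficients by ordinary least squares, followed by the two-stage averaging over members and scenarios described above. The helper names `fit_ar_ols` and `select_ar_order` are hypothetical, the data are synthetic, and MESMER may use a different estimator internally.

```python
import numpy as np


def fit_ar_ols(x, p, start):
    """OLS fit of x_t = a_0 + sum_k a_k * x_{t-k} + e_t using samples t >= start."""
    y = x[start:]
    columns = [np.ones(y.size)] + [x[start - k : x.size - k] for k in range(1, p + 1)]
    X = np.column_stack(columns)
    coeffs, *_ = np.linalg.lstsq(X, y, rcond=None)
    resid = y - X @ coeffs
    return coeffs, resid @ resid / y.size


def select_ar_order(x, maxlag):
    """Choose p in 0..maxlag by BIC, evaluated on a common estimation sample."""
    n_obs = x.size - maxlag
    bic = []
    for p in range(maxlag + 1):
        _, sigma2 = fit_ar_ols(x, p, start=maxlag)
        bic.append(n_obs * np.log(sigma2) + (p + 1) * np.log(n_obs))
    return int(np.argmin(bic))


# Synthetic AR(1) series standing in for one member's global-mean residuals
rng = np.random.default_rng(2)
x = np.zeros(2000)
for t in range(1, x.size):
    x[t] = 0.6 * x[t - 1] + rng.normal(0.0, 0.1)

order = select_ar_order(x, maxlag=4)
coeffs, sigma2 = fit_ar_ols(x, order, start=order)

# Hypothetical per-member AR(1) coefficients: average over members, then over scenarios
member_coeffs = {"ssp126": [0.58], "ssp585": [0.61, 0.63]}
scenario_means = [np.mean(v) for v in member_coeffs.values()]
combined = float(np.mean(scenario_means))
```

Averaging per scenario first gives each scenario the same influence on the final parameters, regardless of how many members it has.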

In [None]:
ar_order = mesmer.stats.select_ar_order_scen_ens(
    tas_globmean_resids, dim="time", ens_dim="member", maxlag=12, ic="bic"
)

global_ar_params = mesmer.stats.fit_auto_regression_scen_ens(
    tas_globmean_resids, dim="time", ens_dim="member", lags=ar_order
)

global_ar_params.drop_vars("nobs")
# TODO: drop nobs internally

### Local forced response
Now we need to estimate how the global trend translates into a local forced response. This is done using a linear regression of the global trend and the global variability as predictors:

$T_{s,t}^{resp} = \beta_s^{trend} \cdot T_t^{glob, trend} + \beta_s^{int} + \beta_s^{var} \cdot T_t^{glob, var}$

To this end, we stack all values (members, scenarios) into a single dataset; the only important requirement is that predictor and target values stay paired.
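
For a single grid cell, this regression can be sketched as weighted least squares on synthetic data. The coefficients, predictors and weights are hypothetical; MESMER solves the same kind of problem for every grid cell via `mesmer.stats.LinearRegression`, whose internals are not reproduced here.

```python
import numpy as np

rng = np.random.default_rng(3)
n_samples = 300

# Hypothetical predictors for one grid cell, already stacked over scenario, member and time
t_trend = np.linspace(0.0, 4.0, n_samples)  # smoothed global mean incl. volcanic influence
t_var = rng.normal(0.0, 0.3, n_samples)  # global-mean residuals (global variability)
weights = rng.uniform(0.5, 1.5, n_samples)  # stand-in for the scenario weights

# Synthetic local temperatures from known coefficients plus local noise
beta_trend, beta_int, beta_var = 1.4, 0.2, 0.9
target = beta_trend * t_trend + beta_int + beta_var * t_var + rng.normal(0.0, 0.3, n_samples)

# Weighted least squares: scale each row by sqrt(weight), then solve ordinary least squares
X = np.column_stack([t_trend, np.ones(n_samples), t_var])
sqrt_w = np.sqrt(weights)
coef, *_ = np.linalg.lstsq(X * sqrt_w[:, None], target * sqrt_w, rcond=None)
```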

We need to prepare the local temperature data:

1. Mask out ocean grid points (where the land fraction is below `THRESHOLD_LAND`)
2. Mask out Antarctica
3. Convert the data from a 2D lat-lon grid to a 1D grid by stacking it and removing all gridcells that were previously masked out.



Before stacking, we extract the original grid. We need to save this together with the parameters to later be able to reconstruct the original grid from the gridpoints.
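
The masking and stacking steps, together with the reconstruction of the original grid, can be sketched with plain NumPy on a toy grid. The land fractions are random, and the 60°S cutoff for Antarctica and the strict `>` comparison with the threshold are assumptions of this sketch rather than documented MESMER behaviour.

```python
import numpy as np

# Hypothetical coarse grid: 4 latitudes x 5 longitudes
lat = np.array([-75.0, -30.0, 15.0, 60.0])
lon = np.array([0.0, 72.0, 144.0, 216.0, 288.0])
rng = np.random.default_rng(4)
land_fraction = rng.uniform(0.0, 1.0, (lat.size, lon.size))
field = rng.normal(0.0, 1.0, (lat.size, lon.size))  # one time step of temperature

THRESHOLD_LAND = 1 / 3

# Keep land points (land fraction above the threshold) and drop Antarctica (lat < -60)
lat2d, _ = np.meshgrid(lat, lon, indexing="ij")
keep = (land_fraction > THRESHOLD_LAND) & (lat2d >= -60.0)

# Stack: flatten the 2D grid and keep only the retained cells -> 1D "gridcell" vector
flat_index = np.flatnonzero(keep)
gridcells = field.ravel()[flat_index]

# Unstack: scatter the 1D vector back onto the original grid, NaN elsewhere
restored = np.full(field.size, np.nan)
restored[flat_index] = gridcells
restored = restored.reshape(field.shape)
```

The retained flat indices play the role of the saved original grid: without them the 1D vector cannot be mapped back to lat/lon.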

In [None]:
# extract original grid
grid_orig = tas_anom["historical"].ds[["lat", "lon"]]

In [None]:
def mask_and_stack(dt, threshold_land):
    dt = mesmer.mask.mask_ocean_fraction(dt, threshold_land)
    dt = mesmer.mask.mask_antarctica(dt)
    dt = mesmer.grid.stack_lat_lon(dt)
    return dt


# mask and stack the data
tas_stacked = mask_and_stack(tas_anom, THRESHOLD_LAND)

In [None]:
tas_stacked["ssp585"].tas.isel(member=1).plot()

We have now converted the 3D field (with dimensions lat, lon, and time) to a 2D field (with dimensions gridcell and time).

In [None]:
# create weights
weights = mesmer.weighted.equal_scenario_weights_from_datatree(tas_stacked)

In [None]:
weights

We create a new `DataTree` from all predictors, here the smoothed global mean and its residuals. We could add more predictors similarly (e.g. the squared temperatures or the ocean heat uptake):

In [None]:
predictors = xr.DataTree.from_dict(
    {"tas": tas_globmean_smoothed, "tas_resids": tas_globmean_resids}
)
# predictors["tas2"] = tas_globmean_smoothed ** 2

target = tas_stacked.copy()

To estimate the forced response, we use a linear regression. For this we need to stack the different scenarios and ensemble members for the predictors, target, and weights:

In [None]:
# TODO: update after #633
predictors_stacked, target_stacked, weights_stacked = (
    mesmer.datatree.stack_datatrees_for_linear_regression(
        predictors, target, weights, stacking_dims=["member", "time"]
    )
)

target_stacked

Instead of a `DataTree` over the scenarios, the scenario, member, and time dimensions are now stacked into a single `sample` dimension. We can now fit the linear regression:

In [None]:
local_lin_reg = mesmer.stats.LinearRegression()

local_lin_reg.fit(
    predictors=predictors_stacked,
    target=target_stacked.tas,
    dim="sample",
    weights=weights_stacked.weights,
)

local_forced_response_params = local_lin_reg.params
local_forced_response_params

In [None]:
data_vars = (
    "intercept",
    "tas",
    "tas_resids",
)

f, axs = plt.subplots(
    3, 1, sharex=True, sharey=True, subplot_kw={"projection": ccrs.PlateCarree()}
)
axs = axs.flatten()

for ax, data_var in zip(axs, data_vars):

    da = local_forced_response_params[data_var]
    da = mesmer.grid.unstack_lat_lon_and_align(da, grid_orig)

    h = da.plot(
        ax=ax, label=data_var, robust=True, center=0, extend="both", add_colorbar=False
    )

    ax.set_extent((-180, 180, -60, 85), ccrs.PlateCarree())
    cbar = plt.colorbar(h, ax=ax, extend="both", pad=0.025)  # , shrink=0.7)
    ax.set(title=data_var, xlabel="", ylabel="", xticks=[], yticks=[])
    ax.coastlines()

### Local variability
Now we need to fit the parameters for the AR(1) process with a spatially correlated noise term used to emulate local variability:

$\eta_{s,t} = \gamma_{0, s} + \gamma_{1, s} \cdot \eta_{s, t-1} + \nu_{s,t}, \ \nu_{s,t} \sim \mathcal{N}(0, \Sigma_{\nu}(r))$

The AR component ensures temporal correlation of the local variability, whereas the noise term ensures spatial consistency. The covariance matrix is estimated on the whole grid and represents the spatial correlation of temperatures between the different grid points.
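
A sketch of the model above on three synthetic grid cells: it simulates the AR(1) process with spatially correlated innovations and then recovers $\gamma_{0,s}$ and $\gamma_{1,s}$ independently at each cell by least squares on the lagged series. All parameter values are hypothetical, $\gamma_{0,s}$ is set to zero in the simulation, and this is not MESMER's fitting routine.

```python
import numpy as np

rng = np.random.default_rng(5)
n_time, n_cells = 3000, 3

# Hypothetical true parameters: per-cell AR(1) coefficients and a spatial noise covariance
gamma1 = np.array([0.2, 0.5, 0.7])
sigma_nu = np.array([[1.0, 0.6, 0.3], [0.6, 1.0, 0.6], [0.3, 0.6, 1.0]]) * 0.04

# Simulate eta_t = gamma1 * eta_{t-1} + nu_t with spatially correlated nu_t (gamma0 = 0)
nu = rng.multivariate_normal(np.zeros(n_cells), sigma_nu, size=n_time)
eta = np.zeros((n_time, n_cells))
for t in range(1, n_time):
    eta[t] = gamma1 * eta[t - 1] + nu[t]

# Fit gamma0 and gamma1 separately at each grid cell by OLS on the lagged series
est = []
for s in range(n_cells):
    X = np.column_stack([np.ones(n_time - 1), eta[:-1, s]])
    coef, *_ = np.linalg.lstsq(X, eta[1:, s], rcond=None)
    est.append(coef)
est = np.array(est)  # columns: gamma0, gamma1
```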

#### Estimate the AR parameters

First we need to compute the residuals after the linear regression:

In [None]:
tas_stacked_residuals = local_lin_reg.residuals(
    predictors=predictors_stacked, target=target_stacked.tas
).T

tas_stacked_residuals.plot()

# unstack the residuals
tas_un_stacked_residuals = tas_stacked_residuals.set_index(
    sample=("time", "member", "scenario")
).unstack("sample")

In [None]:
# put each scenario into a datatree again, get rid of superfluous time steps and/or members
dt_resids = xr.DataTree()
for scenario in tas_un_stacked_residuals.scenario.values:
    dt_resids[scenario] = xr.DataTree(
        tas_un_stacked_residuals.to_dataset()
        .sel(scenario=scenario)
        .dropna("member", how="all")
        .dropna("time")
        .drop_vars("scenario")
    )

In [None]:
local_ar = mesmer.stats.fit_auto_regression_scen_ens(
    dt_resids,
    ens_dim="member",
    dim="time",
    lags=1,
)

local_ar

#### Estimate covariance matrix
For the covariance matrix of the white noise, we first estimate the empirical covariance matrix of the gridcell values and then localize it using the Gaspari-Cohn function. This function goes to 0 for distances bigger than the so-called localisation radius. This is also called regularization. It ensures that grid points that are far away from each other do not correlate. Such spurious correlations arise when the covariance matrix is estimated from too few samples, which also leaves it rank deficient; in our case, this happens because we estimate the covariance on data that has more gridcells than time steps.
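
Below is a sketch of the fifth-order Gaspari-Cohn function (Gaspari and Cohn, 1999) and of localization as an element-wise (Schur) product with the empirical covariance. In this form the function vanishes for distances beyond twice the length scale `c`; how MESMER maps its localisation radius onto `c` is an assumption not shown here. The distances and covariances are hypothetical.

```python
import numpy as np


def gaspari_cohn(r):
    """Fifth-order piecewise rational Gaspari-Cohn function of r = distance / c."""
    r = np.abs(np.asarray(r, dtype=float))
    out = np.zeros_like(r)
    inner = r < 1
    ri = r[inner]
    out[inner] = 1 - 5 / 3 * ri**2 + 5 / 8 * ri**3 + 1 / 2 * ri**4 - 1 / 4 * ri**5
    outer = (r >= 1) & (r < 2)
    ro = r[outer]
    out[outer] = (
        4 - 5 * ro + 5 / 3 * ro**2 + 5 / 8 * ro**3 - 1 / 2 * ro**4 + 1 / 12 * ro**5
        - 2 / (3 * ro)
    )
    return out  # zero for r >= 2


# Hypothetical distances (km) between three grid points and their empirical covariance
dist = np.array([[0.0, 1500.0, 9000.0], [1500.0, 0.0, 7500.0], [9000.0, 7500.0, 0.0]])
emp_cov = np.array([[1.0, 0.5, 0.2], [0.5, 1.0, 0.1], [0.2, 0.1, 1.0]])

c = 4000.0  # length scale in km (assumed); correlations vanish beyond 2 * c
phi = gaspari_cohn(dist / c)

# Localize by element-wise product: damps distant, likely spurious, covariances
localized = emp_cov * phi
```

If the localizer matrix is positive semi-definite, the Schur product theorem guarantees that the localized matrix is too.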

The localisation radius is a parameter that needs to be calibrated: we find the best localisation radius by cross-validating several radii using the negative log-likelihood.
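
The radius selection can be sketched as k-fold cross-validation: for each candidate, estimate the localized covariance on the training folds and score the held-out samples with the Gaussian negative log-likelihood. To keep the block short, a simple exponential taper stands in for the Gaspari-Cohn localizer and the sample weights are ignored; both are simplifications relative to `find_localized_empirical_covariance`, and the data are synthetic.

```python
import numpy as np


def gaussian_nll(x, cov):
    """Mean negative log-likelihood of the rows of x under N(0, cov)."""
    sign, logdet = np.linalg.slogdet(cov)
    if sign <= 0:
        return np.inf  # not positive definite
    quad = np.einsum("ij,ij->i", x, np.linalg.solve(cov, x.T).T)
    return 0.5 * np.mean(quad + logdet + x.shape[1] * np.log(2 * np.pi))


def select_radius(data, localizers, k_folds):
    """Pick the localizer with the lowest k-fold cross-validated negative log-likelihood."""
    folds = np.array_split(np.arange(data.shape[0]), k_folds)
    scores = {}
    for radius, phi in localizers.items():
        total = 0.0
        for test_idx in folds:
            train = np.delete(data, test_idx, axis=0)
            cov = np.cov(train, rowvar=False) * phi  # localized empirical covariance
            total += gaussian_nll(data[test_idx], cov)
        scores[radius] = total
    return min(scores, key=scores.get), scores


# Hypothetical grid points along a line (km) and synthetic samples (rows = samples)
rng = np.random.default_rng(7)
positions = np.arange(20) * 500.0
dist = np.abs(positions[:, None] - positions[None, :])
data = rng.multivariate_normal(np.zeros(20), np.exp(-dist / 2000.0), size=60)

# Exponential taper as a simple stand-in for the Gaspari-Cohn localizer
localizers = {r: np.exp(-dist / r) for r in (1000.0, 2000.0, 4000.0, 8000.0)}
best, scores = select_radius(data, localizers, k_folds=15)
```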

1. Prepare the distance matrix - the distance between the gridpoints in km.

In [None]:
dt = dt_resids["historical"].ds
geodist = mesmer.geospatial.geodist_exact(dt.lon, dt.lat)

# plot
f, ax = plt.subplots()
geodist.plot(ax=ax, cmap="Blues")

ax.set_aspect("equal")

2. Prepare the localizer(s) to regularize the covariance matrix.

In [None]:
phi_gc_localizer = mesmer.stats.gaspari_cohn_correlation_matrices(
    geodist, range(5_000, 15_001, 500)
)

# plot one
f, ax = plt.subplots()
phi_gc_localizer[5000].plot(ax=ax, cmap="Blues")

ax.set_aspect("equal")

3. Weights: we reuse the weights from the local trend regression.

In [None]:
# reusing weights from local trend regression

4. Find the best localization radius and localize the empirical covariance matrix.

In [None]:
dim = "sample"
k_folds = 15

localized_ecov = mesmer.stats.find_localized_empirical_covariance(
    tas_stacked_residuals, weights_stacked.weights, phi_gc_localizer, dim, k_folds
)

localized_ecov

In [None]:
f, axs = plt.subplots(1, 2, sharey=True, constrained_layout=True)

opt = dict(vmin=0, vmax=1.5, cmap="Blues", add_colorbar=False)

ax = axs[0]
localized_ecov.covariance.plot(ax=ax, **opt)
ax.set_aspect("equal")
ax.set_title("Empirical covariance")

ax = axs[1]
localized_ecov.localized_covariance.plot(ax=ax, **opt)
ax.set_aspect("equal")
ax.set_title("Localized empirical covariance")
ax.set_ylabel("")
plt.show()

5. Adjust the regularized covariance matrix

Lastly, we need to adjust the localized covariance matrix using the AR(1) parameters, since the variance of the time series we observe is larger than the variance of the driving white noise process. Read more about this in "Statistical Analysis in Climate Research" by von Storch and Zwiers (1999, reprinted 2003).
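
A quick numerical check of why the adjustment is needed: for a stationary AR(1) process with coefficient $a$ and white-noise variance $\sigma_\nu^2$, the observed variance is $\sigma_\nu^2 / (1 - a^2)$. The sketch also writes down the exact relation for a covariance matrix when the innovations are white in time, $\Sigma_{\nu,ij} = (1 - a_i a_j)\,\Sigma_{\eta,ij}$. `mesmer.stats.adjust_covariance_ar1` may use a different scaling, and the helper `innovation_covariance` is hypothetical.

```python
import numpy as np

rng = np.random.default_rng(6)

# Long AR(1) series with coefficient a and white-noise (innovation) variance sigma2_nu
a, sigma2_nu, n = 0.6, 0.04, 100_000
nu = rng.normal(0.0, np.sqrt(sigma2_nu), n)
eta = np.zeros(n)
for t in range(1, n):
    eta[t] = a * eta[t - 1] + nu[t]

# The observed variance is inflated: var(eta) = sigma2_nu / (1 - a**2)
recovered_sigma2_nu = (1 - a**2) * eta.var()


def innovation_covariance(cov_eta, coeffs):
    """Exact AR(1) relation Sigma_nu[i, j] = (1 - a_i * a_j) * Sigma_eta[i, j]."""
    coeffs = np.asarray(coeffs, dtype=float)
    return cov_eta * (1.0 - np.outer(coeffs, coeffs))
```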

In [None]:
localized_covariance_adjusted = mesmer.stats.adjust_covariance_ar1(
    localized_ecov.localized_covariance, local_ar.coeffs
)

### Saving the parameters

Finally, we have calibrated all needed parameters and can save them. We can use filefisher to nicely create file names and save the parameters.

In [None]:
# define path relative to this notebook & create folder
param_path = pathlib.Path("./output/tas/multi_scen_multi_ens")
param_path.mkdir(exist_ok=True, parents=True)

In [None]:
PARAM_FILEFINDER = FileFinder(
    path_pattern=param_path / "test-params/{module}/",
    file_pattern="params_{module}_{esm}_{scen}.nc",
)

scen_str = "-".join(scenarios)

volcanic_file = PARAM_FILEFINDER.create_full_name(
    module="volcanic",
    esm=model,
    scen=scen_str,
)
global_ar_file = PARAM_FILEFINDER.create_full_name(
    module="global-variability",
    esm=model,
    scen=scen_str,
)
local_forced_file = PARAM_FILEFINDER.create_full_name(
    module="local-trends",
    esm=model,
    scen=scen_str,
)
local_ar_file = PARAM_FILEFINDER.create_full_name(
    module="local-variability",
    esm=model,
    scen=scen_str,
)
localized_ecov_file = PARAM_FILEFINDER.create_full_name(
    module="covariance",
    esm=model,
    scen=scen_str,
)

save_files = False  # we don't save them here in the example
if save_files:
    # save the parameters
    volcanic_params.to_netcdf(volcanic_file)
    global_ar_params.to_netcdf(global_ar_file)
    local_lin_reg.to_netcdf(local_forced_file)
    local_ar.to_netcdf(local_ar_file)
    localized_ecov.to_netcdf(localized_ecov_file)

# TODO: save the original grid

We clear the namespace here so that the emulation below starts from scratch.

In [None]:
%reset -f

## Emulating near surface temperature on land

In [None]:
import pathlib

import cartopy.crs as ccrs
import matplotlib.pyplot as plt
import xarray as xr
from filefisher import FileFinder

import mesmer
from mesmer.core._datatreecompat import map_over_datasets

### Load forcing data
Now one can use any global mean temperature trajectory to draw gridded realisations. For this example we want to create emulations for ssp126 and ssp585 to compare the emulations to the actual ESM output. Here we concatenate historical and future runs to create a continuous timeseries.

In [None]:
model = "IPSL-CM6A-LR"
scenarios = ["ssp126", "ssp585"]

# some configuration parameters
THRESHOLD_LAND = 1 / 3

REFERENCE_PERIOD = slice("1850", "1900")

HIST_PERIOD = slice("1850", "2014")

In [None]:
# find the files - use the same relative paths as above
cmip_path = mesmer.example_data.cmip6_ng_path(relative=True)
param_path = pathlib.Path("./output/tas/multi_scen_multi_ens")

CMIP_FILEFINDER = FileFinder(
    path_pattern=cmip_path / "{variable}/{time_res}/{resolution}",
    file_pattern="{variable}_{time_res}_{model}_{scenario}_{member}_{resolution}.nc",
)

fc_scens = CMIP_FILEFINDER.find_files(
    variable="tas", scenario=scenarios, model=model, resolution="g025", time_res="ann"
)

members = fc_scens.df.member.unique()

fc_hist = CMIP_FILEFINDER.find_files(
    variable="tas",
    scenario="historical",
    model=model,
    resolution="g025",
    time_res="ann",
    member=members,
)

In [None]:
def _get_hist(meta, fc_hist):

    meta_hist = meta | {"scenario": "historical"}

    fc = fc_hist.search(**meta_hist)

    if len(fc) == 0:
        raise FileNotFoundError("no hist file found")
    if len(fc) != 1:
        raise ValueError("more than one hist file found")

    fN, meta_hist = fc[0]

    return fN, meta_hist


def load_hist(meta, fc_hist):
    fN, __ = _get_hist(meta, fc_hist)
    return xr.open_dataset(fN, decode_times=xr.coders.CFDatetimeCoder(use_cftime=True))


def load_hist_scen_continuous(fc_hist, fc_scens):
    dt = xr.DataTree()
    for scen in fc_scens.df.scenario.unique():
        files = fc_scens.search(scenario=scen)

        members = []

        for fN, meta in files.items():

            try:
                hist = load_hist(meta, fc_hist)
            except FileNotFoundError:
                continue

            proj = xr.open_dataset(fN, decode_times=xr.coders.CFDatetimeCoder(use_cftime=True))

            ds = xr.combine_by_coords(
                [hist, proj],
                combine_attrs="override",
                data_vars="minimal",
                compat="override",
                coords="minimal",
            )

            ds = ds.drop_vars(("height", "time_bnds", "file_qf"), errors="ignore")

            ds = mesmer.grid.wrap_to_180(ds)

            # assign member-ID as coordinate
            ds = ds.assign_coords({"member": meta["member"]})

            members.append(ds)

        # create a Dataset that holds each member along the member dimension
        scen_data = xr.concat(members, dim="member")
        # put the scenario dataset into the DataTree
        dt[scen] = xr.DataTree(scen_data)
    return dt

In [None]:
tas = load_hist_scen_continuous(fc_hist, fc_scens)
ref = tas.sel(time=REFERENCE_PERIOD).mean("time")
tas_anom = tas - ref
tas_globmean = mesmer.weighted.global_mean(tas_anom)

tas_globmean_ensmean = tas_globmean.mean(dim="member")
tas_globmean_forcing = mesmer.stats.lowess(
    tas_globmean_ensmean,
    dim="time",
    n_steps=30,
    use_coords=False,
)
time = tas_globmean_forcing["ssp126"].time

### Load the parameters

In [None]:
# param_path was defined above (same relative path as in the calibration)

PARAM_FILEFINDER = FileFinder(
    path_pattern=param_path / "test-params/{module}/",
    file_pattern="params_{module}_{esm}_{scen}.nc",
)
scen_str = "-".join(scenarios)

In [None]:
all_modules = [
    "volcanic",
    "global-variability",
    "local-trends",
    "local-variability",
    "covariance",
]
param_files = PARAM_FILEFINDER.find_files(module=all_modules, esm=model, scen=scen_str)

params = xr.DataTree()

for module in all_modules:
    params[module] = xr.DataTree(
        xr.open_dataset(param_files.search(module=module).paths.pop()), name=module
    )

### Define seeds for global and local variability 
If we want reproducible results, we need to set a seed for the random samples of global and local variability. Here, we set the seed to a chosen number, but for automated generation of seeds, e.g. for several ESMs, we recommend using the `secrets` module from the standard library.
Then you would generate a seed using:

```python
import secrets

xr.Dataset(data_vars={"seed": secrets.randbits(64)})
```
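
A small illustration of why storing the seed is enough to reproduce a draw, using NumPy's `default_rng` as an assumed stand-in; how MESMER turns the stored seed into a random generator internally is not shown here.

```python
import secrets

import numpy as np

# Fixed seed: the same seed always reproduces the same random draws
seed = 981
draws_a = np.random.default_rng(seed).normal(size=5)
draws_b = np.random.default_rng(seed).normal(size=5)

# Automated seed: draw a fresh 64-bit integer once, then store it alongside the output
new_seed = secrets.randbits(64)
draws_c = np.random.default_rng(new_seed).normal(size=5)
```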

In [None]:
seed_global_variability = xr.DataTree.from_dict(
    {
        "ssp126": xr.Dataset(data_vars={"seed": 981}),
        "ssp585": xr.Dataset(data_vars={"seed": 314}),
    }
)
seed_local_variability = xr.DataTree.from_dict(
    {
        "ssp126": xr.Dataset(data_vars={"seed": 272}),
        "ssp585": xr.Dataset(data_vars={"seed": 42}),
    }
)

### Make emulations

In [None]:
# some settings
n_realisations = 10

buffer_global_variability = 50
buffer_local_variability = 20

#### 1. Adding the volcanic influence to the smooth global mean forcing
This is optional, depending on whether you want to reproduce the past accurately. It is necessary when we want to evaluate the performance of our emulator against ESM or observational data, but might not be necessary for more abstract research questions.

In [None]:
tas_globmean_forcing = mesmer.volc.superimpose_volcanic_influence(
    tas_globmean_forcing,
    params["volcanic"].ds,
    hist_period=HIST_PERIOD,
)

In [None]:
tas_globmean_forcing["ssp126"].to_dataset().tas.plot()
tas_globmean_forcing["ssp585"].to_dataset().tas.plot()

#### 2. Compute global variability
Draw samples from an AR process with the calibrated parameters.
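
A sketch of what drawing from an AR(p) process with a buffer means: the process starts from zeros, and the first `buffer` steps are discarded so that the retained samples no longer depend on that arbitrary start. The function `draw_ar` and its parameters (intercept, coefficients, innovation standard deviation) are hypothetical and do not mirror the signature of `draw_auto_regression_uncorrelated`.

```python
import numpy as np


def draw_ar(intercept, coeffs, sigma, n_time, n_realisations, buffer, seed):
    """Draw AR(p) realisations, discarding the first `buffer` steps as spin-up."""
    rng = np.random.default_rng(seed)
    coeffs = np.asarray(coeffs, dtype=float)
    p = coeffs.size
    n_total = n_time + buffer
    # The first p columns hold the (zero) initial state
    out = np.zeros((n_realisations, n_total + p))
    eps = rng.normal(0.0, sigma, (n_realisations, n_total))
    for t in range(n_total):
        lagged = out[:, t : t + p][:, ::-1]  # x_{t-1}, ..., x_{t-p}
        out[:, t + p] = intercept + lagged @ coeffs + eps[:, t]
    return out[:, p + buffer :]


samples = draw_ar(0.0, [0.5], 0.1, n_time=251, n_realisations=10, buffer=50, seed=981)
```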

In [None]:
global_variability = mesmer.stats.draw_auto_regression_uncorrelated(
    params["global-variability"].ds,
    realisation=n_realisations,
    time=time,
    seed=seed_global_variability,
    buffer=buffer_global_variability,
)
global_variability = map_over_datasets(
    lambda ds: ds.rename({"samples": "tas"}), global_variability
)

#### 3. Compute local forced response
Apply linear regression using the global mean forcing and the global variability as predictors. Optionally, you can also add other variables to the predictors like ocean heat content or squared global mean temperature.
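
The split into forced response and variability below relies on the linear prediction being a sum of separate terms, so dropping terms with `exclude` partitions the full prediction. A NumPy sketch with plain dictionaries as hypothetical stand-ins for the parameter and predictor datasets:

```python
import numpy as np

# Hypothetical regression coefficients for two grid cells
params = {
    "intercept": np.array([0.1, -0.2]),
    "tas": np.array([1.3, 0.8]),
    "tas_resids": np.array([0.9, 1.1]),
}

# Hypothetical predictors over four time steps
predictors = {
    "tas": np.array([0.0, 0.5, 1.0, 1.5]),
    "tas_resids": np.array([0.05, -0.1, 0.0, 0.2]),
}


def predict(params, predictors, exclude=()):
    """Linear prediction (time x gridcell), skipping the terms named in `exclude`."""
    n_time = len(next(iter(predictors.values())))
    out = np.zeros((n_time, params["intercept"].size))
    if "intercept" not in exclude:
        out += params["intercept"]
    for name, values in predictors.items():
        if name not in exclude:
            out += np.outer(values, params[name])
    return out


forced = predict(params, predictors, exclude={"tas_resids"})
from_global_var = predict(params, predictors, exclude={"tas", "intercept"})
full = predict(params, predictors)
```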

In [None]:
predictors = xr.DataTree.from_dict(
    {
        "ssp126": xr.DataTree.from_dict(
            {
                "tas": tas_globmean_forcing["ssp126"],
                "tas_resids": global_variability["ssp126"],
            }
        ),
        "ssp585": xr.DataTree.from_dict(
            {
                "tas": tas_globmean_forcing["ssp585"],
                "tas_resids": global_variability["ssp585"],
            }
        ),
    }
)

lr = mesmer.stats.LinearRegression()
lr.params = params["local-trends"].ds

# uses ``exclude`` to split the linear response
local_forced_response = xr.DataTree()
local_variability_from_global_var = xr.DataTree()

for scen in predictors.children:
    # local forced response driven by the smoothed global mean (incl. intercept)
    local_forced_response[scen] = xr.DataTree(
        lr.predict(predictors[scen], exclude={"tas_resids"}).rename("tas").to_dataset()
    )

    # local variability part driven by global variability
    local_variability_from_global_var[scen] = xr.DataTree(
        lr.predict(predictors[scen], exclude={"tas", "intercept"})
        .rename("tas")
        .to_dataset()
    )

#### 4. Compute local variability
We compute the local variability by applying an AR(1) process to ensure consistency in time and adding spatially correlated innovations at each time step to get spatially coherent random samples at each gridpoint.
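
A sketch of the same idea for the local variability, assuming per-cell AR(1) coefficients and a given positive-definite innovation covariance: spatially correlated innovations are obtained by multiplying standard normal draws with a Cholesky factor of the covariance. The function `draw_ar1_correlated` is hypothetical and is not the implementation behind `draw_auto_regression_correlated`.

```python
import numpy as np


def draw_ar1_correlated(gamma0, gamma1, cov, n_time, n_realisations, buffer, seed):
    """Draw AR(1) realisations per grid cell with spatially correlated innovations."""
    rng = np.random.default_rng(seed)
    gamma0 = np.asarray(gamma0, dtype=float)
    gamma1 = np.asarray(gamma1, dtype=float)
    n_cells = gamma1.size
    chol = np.linalg.cholesky(cov)  # cov = chol @ chol.T
    eta = np.zeros((n_realisations, n_cells))
    out = np.empty((n_realisations, n_time, n_cells))
    for t in range(n_time + buffer):
        # z ~ N(0, I) mapped through the Cholesky factor gives nu ~ N(0, cov)
        nu = rng.standard_normal((n_realisations, n_cells)) @ chol.T
        eta = gamma0 + gamma1 * eta + nu
        if t >= buffer:  # discard the spin-up steps
            out[:, t - buffer] = eta
    return out


cov = np.array([[1.0, 0.7], [0.7, 1.0]])
samples = draw_ar1_correlated(
    np.zeros(2), [0.3, 0.6], cov, n_time=251, n_realisations=10, buffer=20, seed=272
)
```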

In [None]:
local_variability = mesmer.stats.draw_auto_regression_correlated(
    params["local-variability"].ds,
    params["covariance"].localized_covariance,
    time=time,
    realisation=n_realisations,
    seed=seed_local_variability,
    buffer=buffer_local_variability,
)
local_variability = map_over_datasets(
    lambda ds: ds.rename({"samples": "tas"}), local_variability
)

#### 5. Add everything together

In [None]:
local_variability_total = local_variability_from_global_var + local_variability
emulations = local_forced_response + local_variability_total

### Saving emulations
We recommend saving the emulations together with the seeds used for emulating.

In [None]:
for scen in emulations:
    local_seed = seed_local_variability[scen].seed.rename("seed_local_variability")
    global_seed = seed_global_variability[scen].seed.rename("seed_global_variability")
    emulations[scen] = xr.DataTree(
        xr.merge([emulations[scen].ds, local_seed, global_seed])
    )

emu_path = pathlib.Path("./output/tas/multi_scen_multi_ens")
# emulations.to_netcdf(emu_path / f"emulations_{model}_ssp126-ssp585.nc")

## Some example plots

In [None]:
grid_orig = tas_anom["ssp126"].ds[["lat", "lon"]]
spatial_emu_126 = mesmer.grid.unstack_lat_lon_and_align(
    emulations["ssp126"].tas, grid_orig
)
spatial_emu_585 = mesmer.grid.unstack_lat_lon_and_align(
    emulations["ssp585"].tas, grid_orig
)

f, axs = plt.subplots(3, 1, subplot_kw={"projection": ccrs.Robinson()})

opt = dict(cmap="Reds", transform=ccrs.PlateCarree(), vmin=0, vmax=15, extend="max")
spatial_emu_126.mean("realisation").sel(time="2100").plot(ax=axs[0], **opt)
spatial_emu_585.mean("realisation").sel(time="2100").plot(ax=axs[1], **opt)

diff = spatial_emu_585 - spatial_emu_126
diff.mean("realisation").sel(time="2100").plot(
    ax=axs[2], cmap="RdBu_r", transform=ccrs.PlateCarree(), center=0
)

axs[0].set_title("ssp126 2100")
axs[1].set_title("ssp585 2100")
axs[2].set_title("Difference")

for ax in axs:
    ax.coastlines()
    ax.set_global()

In [None]:
# plot global means
globmean_126 = mesmer.weighted.global_mean(spatial_emu_126)
globmean_585 = mesmer.weighted.global_mean(spatial_emu_585)

globmean_126_smoothed = mesmer.stats.lowess(
    globmean_126.mean("realisation"), dim="time", n_steps=50, use_coords=False
)
globmean_585_smoothed = mesmer.stats.lowess(
    globmean_585.mean("realisation"), dim="time", n_steps=50, use_coords=False
)

f, ax = plt.subplots()
globmean_585.plot.line(x="time", ax=ax, add_legend=False, color="lightblue")
globmean_126.plot.line(x="time", ax=ax, add_legend=False, color="pink")

globmean_585_smoothed.plot.line(x="time", ax=ax, color="blue", label="ssp585")
globmean_126_smoothed.plot.line(x="time", ax=ax, color="red", label="ssp126")

plt.legend()
plt.show()

In [None]:
esm_ssp585 = tas_anom["ssp585"].ds


def mask(ds, threshold_land):
    ds = mesmer.mask.mask_ocean_fraction(ds, threshold_land)
    ds = mesmer.mask.mask_antarctica(ds)
    return ds


esm_ssp585 = mask(esm_ssp585, THRESHOLD_LAND)

In [None]:
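# stack to 1D samples and drop masked (NaN) grid points before comparing distributions
esm_ssp585 = esm_ssp585.tas.stack(sample=("time", "lat", "lon", "member")).dropna("sample")
emu_ssp585 = spatial_emu_585.stack(sample=("time", "lat", "lon", "realisation"))
emu_ssp585 = emu_ssp585.dropna("sample")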

In [None]:
import statsmodels.api as sm

sm.qqplot_2samples(esm_ssp585, emu_ssp585, line="45")
plt.xlabel("ESM")
plt.ylabel("Emulation")
plt.show()