# Example execution of MESMER-M workflow for multiple scenarios and ensemble members
Training and emulation of monthly local temperature from yearly local temperature. We use an example data set on a coarse (20° x 20°) grid.

Import libraries and check MESMER version:

In [None]:
# reload edited modules automatically and render matplotlib figures inline
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import importlib

import filefisher
import pandas
import xarray as xr
from filefinder import FileContainer

import mesmer
from mesmer.core._datatreecompat import map_over_datasets

# display the installed MESMER version
mesmer.__version__

## Calibrate emulator

### Configuration

In [None]:
# Gaspari-Cohn localisation radii to test: 1250-6250 in steps of 250, then
# 6500-8500 in steps of 500 (presumably km - TODO confirm against geodist_exact)
LOCALISATION_RADII = [*range(1250, 6251, 250), *range(6500, 8501, 500)]
# land-fraction threshold forwarded to `mesmer.mask.mask_ocean_fraction`
THRESHOLD_LAND = 1 / 3
# reference period used when computing anomalies
REF_PERIOD = slice("1850", "1900")

In [None]:
# define paths of the example data

# `import importlib` alone does not guarantee that the `resources` submodule is
# loaded; import it explicitly so `importlib.resources.files` is always available
import importlib.resources

model = "IPSL-CM6A-LR"
scenarios = ["ssp585", "ssp126"]

# the example data ships with the mesmer test suite, next to the installed package
TEST_DATA_PATH = importlib.resources.files("mesmer").parent / "tests" / "test-data"
cmip6_data_path = TEST_DATA_PATH / "calibrate-coarse-grid" / "cmip6-ng"

path_tas_mon = cmip6_data_path / "tas" / "mon" / "g025"
fN_hist_mon = path_tas_mon / f"tas_mon_{model}_historical_r1i1p1f1_g025.nc"
fN_proj_mon = path_tas_mon / f"tas_mon_{model}_ssp585_r1i1p1f1_g025.nc"

### Load Data for training the emulator

In [None]:
# directory and file-name templates of the cmip6-ng style archive
_CMIP_DIR_PATTERN = "{variable}/{time_res}/{resolution}"
_CMIP_FILE_PATTERN = "{variable}_{time_res}_{model}_{scenario}_{member}_{resolution}.nc"

CMIP_FILEFINDER = filefisher.FileFinder(
    path_pattern=cmip6_data_path / _CMIP_DIR_PATTERN,
    file_pattern=_CMIP_FILE_PATTERN,
)

In [None]:
# search keys shared by the scenario and the historical query (annual data)
_query_y = {"variable": "tas", "model": model, "resolution": "g025", "time_res": "ann"}

fc_scens_y = CMIP_FILEFINDER.find_files(scenario=scenarios, **_query_y)

# only get the historical members that are also in the future scenarios, but only once
unique_scen_members_y = fc_scens_y.df.member.unique()

fc_hist_y = CMIP_FILEFINDER.find_files(
    scenario="historical", member=unique_scen_members_y, **_query_y
)

# historical files first, followed by the scenario files
fc_all_y = FileContainer(pandas.concat([fc_hist_y.df, fc_scens_y.df]))
fc_all_y.df

In [None]:
fc_scens_m = CMIP_FILEFINDER.find_files(
    variable="tas", scenario=scenarios, model=model, resolution="g025", time_res="mon"
)

# only get the historical members that are also in the future scenarios, but only once
# (derive the members from the *monthly* scenario files, not from the yearly ones)
unique_scen_members_m = fc_scens_m.df.member.unique()

fc_hist_m = CMIP_FILEFINDER.find_files(
    variable="tas",
    scenario="historical",
    model=model,
    resolution="g025",
    time_res="mon",
    member=unique_scen_members_m,
)

# historical files first, followed by the scenario files
fc_all_m = FileContainer(pandas.concat([fc_hist_m.df, fc_scens_m.df]))
fc_all_m.df

In [None]:
tas_y = xr.DataTree()

scenarios_whist = [*scenarios, "historical"]

# one DataTree node per scenario (incl. historical); each node holds all members of
# that scenario, concatenated along a new "member" dimension
for scen in scenarios_whist:
    member_datasets = [
        xr.open_dataset(path, use_cftime=True)
        # drop unnecessary variables
        .drop_vars(["height", "time_bnds", "file_qf"], errors="ignore")
        # assign member-ID as coordinate
        .assign_coords({"member": meta["member"]})
        for path, meta in fc_all_y.search(scenario=scen)
    ]
    tas_y[scen] = xr.DataTree(xr.concat(member_datasets, dim="member"))

tas_y

In [None]:
tas_m = xr.DataTree()

scenarios_whist = [*scenarios, "historical"]

# one DataTree node per scenario (incl. historical); each node holds all members of
# that scenario, concatenated along a new "member" dimension
for scen in scenarios_whist:
    member_datasets = [
        xr.open_dataset(path, use_cftime=True)
        # drop unnecessary variables
        .drop_vars(["height", "time_bnds", "file_qf"], errors="ignore")
        # assign member-ID as coordinate
        .assign_coords({"member": meta["member"]})
        for path, meta in fc_all_m.search(scenario=scen)
    ]
    tas_m[scen] = xr.DataTree(xr.concat(member_datasets, dim="member"))

tas_m

### Preprocessing

Calculate anomalies w.r.t. the reference period:

In [None]:
# anomalies w.r.t. REF_PERIOD for the yearly and the monthly tree
# NOTE: despite the `ref_` prefix these hold whatever `calc_anomaly` returns
# (presumably the anomaly trees, not the reference climatology - confirm)
ref_y, ref_m = (
    mesmer.anomaly.calc_anomaly(tree, reference_period=REF_PERIOD)
    for tree in (tas_y, tas_m)
)

We only use land grid points and exclude Antarctica. The 3D data with dimensions `('time', 'lat', 'lon')` is stacked to 2D data with dimensions `('time', 'gridcell')`:

In [None]:
def mask_and_stack(ds, threshold_land):
    """Mask ocean and Antarctica, then stack lat/lon into a single ``gridcell`` dim.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset with ``lat`` and ``lon`` dimensions.
    threshold_land : float
        Land-fraction threshold forwarded to ``mesmer.mask.mask_ocean_fraction``.

    Returns
    -------
    xr.Dataset
        The masked data with lat/lon stacked by ``mesmer.grid.stack_lat_lon``.
    """
    land_only = mesmer.mask.mask_ocean_fraction(ds, threshold_land)
    without_antarctica = mesmer.mask.mask_antarctica(land_only)
    return mesmer.grid.stack_lat_lon(without_antarctica)

In [None]:
# mask and stack the anomalies computed in the previous step (`ref_y` / `ref_m`
# hold the output of `calc_anomaly`) rather than the absolute temperatures
tas_stacked_y = map_over_datasets(
    mask_and_stack, ref_y, kwargs={"threshold_land": THRESHOLD_LAND}
)
tas_stacked_m = map_over_datasets(
    mask_and_stack, ref_m, kwargs={"threshold_land": THRESHOLD_LAND}
)

In [None]:
# quick look at the first grid cell of the ssp585 scenario
first_gridcell_ssp585 = tas_stacked_y["ssp585"]["tas"].isel(gridcell=0)
first_gridcell_ssp585.plot(x="time")

### Fit the harmonic model

Fit the seasonal cycle with a harmonic model (Fourier regression) whose coefficients can vary with the
local annual mean temperature. This removes the annual mean and determines the optimal order and
the coefficients of the harmonic model.

In [None]:
def extract_da_and_call_func(func, *names):
    """Wrap ``func`` so that it is called with variables extracted from its inputs.

    The returned wrapper takes one positional argument per entry in ``names``. For
    each pair ``(name, arg)`` it passes ``arg[name]`` to ``func``, or ``arg``
    unchanged when ``name`` is ``None``. Keyword arguments are forwarded as-is. This
    lets functions that expect DataArrays be used with ``map_over_datasets``, which
    passes one Dataset per tree.

    Parameters
    ----------
    func : callable
        Function to call with the extracted arguments.
    *names : str or None
        One variable name (or ``None``) per positional argument of the wrapper.

    Returns
    -------
    callable
        ``inner(*args, **kwargs)``, which returns ``func(...)``'s result.

    Raises
    ------
    TypeError
        When the wrapper is called with a different number of positional
        arguments than there are ``names``.
    """
    # TODO: find a better solution for this

    def inner(*args, **kwargs):
        # explicit check instead of `assert`, which is stripped under `python -O`
        if len(args) != len(names):
            raise TypeError(
                f"expected {len(names)} positional argument(s) (one per name in "
                f"{names!r}), got {len(args)}"
            )

        selected = [
            arg if name is None else arg[name] for name, arg in zip(names, args)
        ]
        return func(*selected, **kwargs)

    return inner

In [None]:
# fit the harmonic model node by node, on the "tas" variable of the yearly and the
# monthly tree
_fit_harmonic_on_tas = extract_da_and_call_func(
    mesmer.stats.fit_harmonic_model, "tas", "tas"
)
harmonic_model_fit = map_over_datasets(
    _fit_harmonic_on_tas, tas_stacked_y, tas_stacked_m
)
harmonic_model_fit

In [None]:
def _avg_for_dtype(ds, dim):
    """Average every data variable of ``ds`` along ``dim`` with a dtype-aware estimator.

    Integer variables are reduced with the median using ``method="nearest"`` (so the
    result is one of the input values rather than an interpolated number). All other
    variables are reduced with the NaN-skipping mean.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset to reduce.
    dim : str
        Dimension to average over.

    Returns
    -------
    xr.Dataset
        ``ds`` reduced along ``dim``.
    """

    def avg_da(da, dim):
        # dtype.kind "i"/"u" matches integers of every width and signedness; a plain
        # `da.dtype == int` would only match the platform's default integer type
        if da.dtype.kind in "iu":
            return da.quantile(q=0.5, dim=dim, method="nearest", skipna=True)
        return da.mean(dim=dim, skipna=True)

    return ds.map(avg_da, dim=dim)


def _avg_ens_then_scen(dt, ens_dim="member"):
    """Average a DataTree first over ensemble members, then over scenarios.

    Every node is reduced along ``ens_dim``. The nodes are then collapsed into one
    Dataset along a new ``scenario`` dimension, which is reduced as well. Averaging
    in two steps keeps scenarios with many members from dominating the result.

    Parameters
    ----------
    dt : xr.DataTree
        Tree with one node per scenario, each having an ``ens_dim`` dimension.
    ens_dim : str, default "member"
        Name of the ensemble-member dimension.

    Returns
    -------
    xr.Dataset
        Member- and scenario-averaged parameters.
    """
    ens_mean = map_over_datasets(_avg_for_dtype, dt, kwargs={"dim": ens_dim})
    ds_ens_mean = mesmer.datatree.collapse_datatree_into_dataset(
        ens_mean, dim="scenario"
    )
    return _avg_for_dtype(ds_ens_mean, dim="scenario")

In [None]:
# average over ensemble members and scenarios
# do not average the predictions: drop them together with the time dim (which is only
# present on the predictions) to avoid NaNs from scenarios with different time
# ranges - only the coefficients of the averaged fit are used for the emulation
harmonic_model_fit_wo_preds = map_over_datasets(
    lambda ds: ds.drop_vars(["predictions", "time"]), harmonic_model_fit
)
harmonic_model_fit_scen_mean = _avg_ens_then_scen(harmonic_model_fit_wo_preds)
harmonic_model_fit_scen_mean

### Train the power transformer

The residuals are not necessarily symmetric, so we make them more normally distributed using a Yeo-Johnson
transformation. The parameter $\lambda$ is modelled with a logistic regression using
the local annual mean temperature as covariate.

In [None]:
def _subtract_hm_predictions(monthly_ds, hm_fit_ds):
    # residuals: monthly data minus the harmonic-model predictions of the same node
    return monthly_ds - hm_fit_ds.predictions


resids_after_hm = map_over_datasets(
    _subtract_hm_predictions, tas_stacked_m, harmonic_model_fit
)
resids_after_hm

In [None]:
# fit the Yeo-Johnson lambda coefficients node by node; the wrapper extracts the
# "tas" variable from the yearly tree and from the harmonic-model residuals before
# calling the fit (passing `extract_da_and_call_func` itself as the mapped function
# would only return a wrapper per node instead of fitting)
pt_coefficients = map_over_datasets(
    extract_da_and_call_func(mesmer.stats.fit_yeo_johnson_transform, "tas", "tas"),
    tas_stacked_y,
    resids_after_hm,
)
pt_coefficients

In [None]:
# average the power-transformer coefficients over members, then over scenarios
pt_coefficients_scen_mean = _avg_ens_then_scen(pt_coefficients, ens_dim="member")
pt_coefficients_scen_mean

In [None]:
def extract_das_and_call_yeo_johnson_transform(ds_y, ds_m, lambda_coeffs_ds):
    """Apply ``mesmer.stats.yeo_johnson_transform`` to the variables of three Datasets.

    Parameters
    ----------
    ds_y : xr.Dataset
        Yearly data; its ``tas`` variable is passed first.
    ds_m : xr.Dataset
        Monthly data; its ``tas`` variable is passed second.
    lambda_coeffs_ds : xr.Dataset
        Fitted coefficients; its ``lambda_coeffs`` variable is passed third.

    Returns
    -------
    Whatever ``mesmer.stats.yeo_johnson_transform`` returns.
    """
    return mesmer.stats.yeo_johnson_transform(
        ds_y["tas"], ds_m["tas"], lambda_coeffs_ds["lambda_coeffs"]
    )

In [None]:
# Yeo-Johnson transform of the harmonic-model residuals, node by node, with every
# node's own fitted lambda coefficients
yj_inputs = (tas_stacked_y, resids_after_hm, pt_coefficients)
transformed_hm_resids = map_over_datasets(
    extract_das_and_call_yeo_johnson_transform, *yj_inputs
)
transformed_hm_resids

### Fit cyclo-stationary AR(1) process

The monthly residuals are now assumed to follow a cyclo-stationary AR(1) process, where, e.g., the July residuals depend on those of June and the June residuals depend on those of May, with distinct parameters for each month.

In [None]:
def extract_da_and_fit_AR(ds_m, time_dim):
    """Fit ``mesmer.stats.fit_auto_regression_monthly`` to ``ds_m["transformed"]``.

    Parameters
    ----------
    ds_m : xr.Dataset
        Dataset holding the transformed monthly residuals as ``transformed``.
    time_dim : str
        Name of the time dimension, forwarded to the fit.
    """
    transformed = ds_m["transformed"]
    return mesmer.stats.fit_auto_regression_monthly(transformed, time_dim=time_dim)

In [None]:
# fit the cyclo-stationary AR(1) process on the transformed residuals of every node
ar1_kwargs = {"time_dim": "time"}
AR1_fit = map_over_datasets(
    extract_da_and_fit_AR, transformed_hm_resids, kwargs=ar1_kwargs
)
AR1_fit

In [None]:
# average the AR(1) parameters over members, then over scenarios
AR1_fit_scen_mean = _avg_ens_then_scen(AR1_fit, ens_dim="member")
AR1_fit_scen_mean

### Find localized empirical covariance

Finally, we determine the localized empirical spatial covariance for each month separately:

In [None]:
# pairwise distances between the (stacked) grid cells, taken from the historical node
hist_stacked_y = tas_stacked_y.historical
geodist = mesmer.geospatial.geodist_exact(hist_stacked_y.lon, hist_stacked_y.lat)

# one Gaspari-Cohn localisation matrix per candidate radius
phi_gc_localizer = mesmer.stats.gaspari_cohn_correlation_matrices(
    geodist, localisation_radii=LOCALISATION_RADII
)

In [None]:
AR1_residuals = map_over_datasets(lambda ds: ds["residuals"], AR1_fit)
weights = mesmer.weighted.equal_scenario_weights_from_datatree(AR1_residuals)

# collapse the per-scenario nodes into single Datasets with a "scenario" dimension
AR1_residuals_ds, weights_ds = (
    mesmer.datatree.collapse_datatree_into_dataset(tree, dim="scenario")
    for tree in (AR1_residuals, weights)
)

monthly_resids = AR1_residuals_ds.residuals.groupby("time.month")
monthly_weights = weights_ds.weights.groupby("time.month")


def _flatten_samples(da):
    # merge scenario/member/time into one "sample" dim; drop samples containing NaN
    flat = da.stack(sample=("scenario", "member", "time"), create_index=False)
    return flat.dropna(dim="sample")


# localized empirical covariance, estimated separately for each calendar month
localized_ecov = [
    mesmer.stats.find_localized_empirical_covariance(
        _flatten_samples(monthly_resids[mon]),
        _flatten_samples(monthly_weights[mon]),
        phi_gc_localizer,
        dim="sample",
        k_folds=30,
    )
    for mon in range(1, 13)
]

month = xr.DataArray(range(1, 13), dims="month")
localized_ecov = xr.concat(localized_ecov, dim=month)

### Saving

### time coordinate
We extract the original time coordinate so that we can validate our results later on. If the final emulations do not need to be aligned with the original data, this step can be omitted; the time coordinate can then be generated later, for example with:


```python
monthly_time = xr.cftime_range("1850-01-01", "2100-12-31", freq="MS", calendar="gregorian")
monthly_time = xr.DataArray(monthly_time, dims="time", coords={"time": monthly_time})
```

In [None]:
# extract and save time coordinate: historical followed by ssp585, along "time"
hist_time, scen_time = (tas_stacked_m[scen].time for scen in ("historical", "ssp585"))
m_time = xr.concat([hist_time, scen_time], dim="time")

# TODO
# save the parameters to a file
# harmonic_model_fit
# pt_coefficients
# AR1_fit
# localized_ecov
# m_time

## Make emulations

To generate emulations the workflow of the calibration is reversed, using the estimated parameters from above. Here, we use the same local annual mean temperatures to force the emulations, but temperatures from other models, scenarios, ensemble members or emulated annual local temperatures can be used as well.

In [None]:
# # Re-import necessary libraries
# import numpy as np
# import pandas as pd
# import matplotlib.pyplot as plt
# import xarray as xr

### Configuration

In [None]:
# parameters
# NR_EMUS: number of realisations (passed as `n_realisations` below)
# BUFFER: passed as `buffer` to `draw_auto_regression_monthly`
#         (presumably burn-in steps that are discarded - TODO confirm)
NR_EMUS, BUFFER = 10, 20

#### Random number seed

The `seed` determines the initial state of the random number generator. To avoid generating the same noise for different models and scenarios, a different seed is required for each individual pairing. For reproducibility, the seed needs to be the same for any subsequent draw of the same emulator. To avoid human-chosen standard seeds (e.g. `0`, `1234`), it is recommended to also generate the seeds randomly and save them for later use, e.g. with:

```python
import secrets
secrets.randbits(128)
```

In [None]:
# random but constant - keep it fixed so repeated draws of this emulator reproduce
SEED = 234_361_146_192_407_661_971_285_321_853_135_632_294

### Load data needed for emulations

In [None]:
# TODO
# load the parameters from a file
# in this example notebook we directly use the calibration from above

In [None]:
# TODO
# load yearly temperature
# in this example we are using the original yearly temperature for demonstration

### Preprocessing

In [None]:
# preprocess tas
# ref = tas_y.sel(time=REF_PERIOD).mean("time", keep_attrs=True)
# tas_y = tas_y - ref
# tas_stacked_y = mask_and_stack(tas_y, threshold_land=THRESHOLD_LAND)

In [None]:
# get the original grid for transforming back later
# the root of the DataTree holds no data (every scenario is a child node, see the
# loading cells), so the lat/lon grid has to be read from a scenario node
grid_orig = ref_y["historical"].to_dataset()[["lat", "lon"]]

### Generate emulations

In [None]:
# yearly predictor: first ensemble member, historical followed by ssp585
yearly_predictor = xr.concat(
    [tas_stacked_y[scen].tas.isel(member=0) for scen in ("historical", "ssp585")],
    dim="time",
)

In [None]:
# parameters calibrated above (averaged over members and scenarios)
hm_coeffs = harmonic_model_fit_scen_mean.coeffs
localized_covariance = localized_ecov.localized_covariance
pt_lambda_coeffs = pt_coefficients_scen_mean.lambda_coeffs

# 1) monthly signal from the harmonic model, driven by the yearly predictor
monthly_harmonic_emu = mesmer.stats.predict_harmonic_model(
    yearly_predictor, hm_coeffs, m_time
)

# 2) variability around 0 drawn from the cyclo-stationary AR(1) model
local_variability_transformed = mesmer.stats.draw_auto_regression_monthly(
    AR1_fit_scen_mean,
    localized_covariance,
    time=m_time,
    n_realisations=NR_EMUS,
    seed=SEED,
    buffer=BUFFER,
)

# 3) undo the power transformation of the variability
local_variability_inverted = mesmer.stats.inverse_yeo_johnson_transform(
    yearly_predictor,
    local_variability_transformed,
    pt_lambda_coeffs,
)

# 4) emulations = harmonic signal + back-transformed local variability
emulations = monthly_harmonic_emu + local_variability_inverted.inverted

In [None]:
# map the stacked "gridcell" dimension back onto the original lat/lon grid
emulations_unstacked = mesmer.grid.unstack_lat_lon_and_align(
    emulations,
    grid_orig,
)

In [None]:
# map of the first realisation at time index 3011
first_realisation = emulations_unstacked.isel(realisation=0)
first_realisation.isel(time=3011).plot()

### Saving and/or Analysis

In [None]:
# TODO
# save