# Example execution of MESMER-M workflow
Training and emulation of monthly local temperature from yearly local temperature.

# Setup

In [18]:
import datetime as dt
import importlib
import importlib.resources

import xarray as xr

import mesmer

In [19]:
# some parameters

# name of the model whose output is used; interpolated into the file names below
model = "IPSL-CM6A-LR"

# test data directory, located relative to the installed mesmer package
TEST_DATA_PATH = importlib.resources.files("mesmer").parent.joinpath(
    "tests", "test-data"
)
cmip6_data_path = TEST_DATA_PATH.joinpath("calibrate-coarse-grid", "cmip6-ng")

# annual mean temperature: historical run and ssp585 projection
path_tas_ann = cmip6_data_path.joinpath("tas", "ann", "g025")
fN_hist_ann = path_tas_ann.joinpath(f"tas_ann_{model}_historical_r1i1p1f1_g025.nc")
fN_proj_ann = path_tas_ann.joinpath(f"tas_ann_{model}_ssp585_r1i1p1f1_g025.nc")

# monthly mean temperature: same two experiments
path_tas_mon = cmip6_data_path.joinpath("tas", "mon", "g025")
fN_hist_mon = path_tas_mon.joinpath(f"tas_mon_{model}_historical_r1i1p1f1_g025.nc")
fN_proj_mon = path_tas_mon.joinpath(f"tas_mon_{model}_ssp585_r1i1p1f1_g025.nc")

# candidate localisation radii handed to the Gaspari-Cohn localizer further down:
# 1250 ... 6250 in steps of 250, then 6500 ... 8500 in steps of 500
# (presumably in km - TODO confirm against the mesmer documentation)
LOCALISATION_RADII = [*range(1250, 6251, 250), *range(6500, 8501, 500)]

# threshold passed to the ocean-fraction mask in `mask_and_stack`
THRESHOLD_LAND = 1 / 3

# years over which the reference mean for the anomalies is computed
ref_period = slice("1850", "1900")

## Load Data for training the emulator

In [20]:
# keyword arguments shared by both reads: the historical and projection files
# are combined by their coordinates, and the "height" and "file_qf" variables
# are dropped on opening
open_kwargs = dict(
    combine="by_coords",
    use_cftime=True,
    combine_attrs="override",
    data_vars="minimal",
    compat="override",
    coords="minimal",
    drop_variables=["height", "file_qf"],
)

# yearly temperature (historical + projection), loaded into memory
tas_y = xr.open_mfdataset([fN_hist_ann, fN_proj_ann], **open_kwargs).load()

# monthly temperature (historical + projection), loaded into memory
tas_m = xr.open_mfdataset([fN_hist_mon, fN_proj_mon], **open_kwargs).load()

## Preprocessing

In [21]:
# express both datasets as anomalies relative to their own time mean
# over the reference period

# yearly data
ref_y = tas_y.sel(time=ref_period).mean("time", keep_attrs=True)
tas_y = tas_y - ref_y

# monthly data
ref_m = tas_m.sel(time=ref_period).mean("time", keep_attrs=True)
tas_m = tas_m - ref_m

In [22]:
def mask_and_stack(ds, threshold_land):
    """Mask a gridded dataset and stack its lat/lon dimensions.

    Applies, in this order, ``mesmer.mask.mask_ocean_fraction`` (with
    ``threshold_land``), ``mesmer.mask.mask_antarctica`` and
    ``mesmer.grid.stack_lat_lon``.

    Parameters
    ----------
    ds
        Data to process; passed to ``mesmer.mask.mask_ocean_fraction``.
    threshold_land
        Threshold forwarded to ``mesmer.mask.mask_ocean_fraction``.

    Returns
    -------
    The result of ``mesmer.grid.stack_lat_lon`` applied to the masked data.
    """
    land_only = mesmer.mask.mask_ocean_fraction(ds, threshold_land)
    without_antarctica = mesmer.mask.mask_antarctica(land_only)
    return mesmer.grid.stack_lat_lon(without_antarctica)

In [23]:
# mask and stack the yearly and the monthly anomalies with the same land threshold
tas_stacked_y, tas_stacked_m = (
    mask_and_stack(ds, threshold_land=THRESHOLD_LAND) for ds in (tas_y, tas_m)
)

## Fit the harmonic model

In [24]:
# fit the harmonic model using the stacked yearly and monthly temperature;
# its `predictions` and `coeffs` are used further down
harmonic_model_fit = mesmer.stats.fit_harmonic_model(
    tas_stacked_y["tas"], tas_stacked_m["tas"]
)

## Train the power transformer

In [25]:
# what remains of the monthly data after removing the harmonic model predictions
resids_after_hm = tas_stacked_m - harmonic_model_fit.predictions

# estimate the Yeo-Johnson coefficients from the residuals and the yearly temperature
pt_coefficients = mesmer.stats.fit_yeo_johnson_transform(
    resids_after_hm["tas"], tas_stacked_y["tas"]
)

# apply the fitted transformation to the same residuals
transformed_hm_resids = mesmer.stats.yeo_johnson_transform(
    resids_after_hm["tas"], pt_coefficients, tas_stacked_y["tas"]
)

## Fit cyclo-stationary AR(1) process

In [26]:
# fit AR(1) parameters
# The fit is done on the power-transformed harmonic-model residuals along the
# "time" dimension. The returned object is reused below: its `residuals` feed
# the covariance estimation and the object itself is passed to the draw function.
AR1_fit = mesmer.stats.fit_auto_regression_monthly(
    transformed_hm_resids.transformed, time_dim="time"
)

In [27]:
# work out covariance matrix

# distances between the stacked grid points
geodist = mesmer.geospatial.geodist_exact(tas_stacked_y["lon"], tas_stacked_y["lat"])

# Gaspari-Cohn correlation matrices, one set for the candidate localisation radii
phi_gc_localizer = mesmer.stats.gaspari_cohn_correlation_matrices(
    geodist, localisation_radii=LOCALISATION_RADII
)

# uniform weights, shaped like the AR(1) residuals of the first grid cell
weights = xr.ones_like(AR1_fit.residuals.isel(gridcell=0)).rename("weights")

# localized empirical covariance of the AR(1) residuals along "time"
# NOTE(review): the trailing 30 is presumably the number of cross-validation
# folds - confirm against the mesmer documentation
localized_ecov = mesmer.stats.find_localized_empirical_covariance_monthly(
    AR1_fit.residuals, weights, phi_gc_localizer, "time", 30
)

In [28]:
# TODO
# save the parameters to a file

## Make emulations

In [29]:
# parameters for drawing the emulations
seed = 0  # passed as `seed` to the draw function
buffer = 20  # passed as `buffer` to the draw function
nr_emus = 10  # number of realisations to draw
# ref_period = slice("1850", "1900")

In [30]:
# TODO 
# load the parameters from a file
# here we use the data from above

In [31]:
# TODO 
# load yearly temperature
# here we are using the original yearly temperature for demonstration

In [32]:
# preprocess tas
# ref = tas_y.sel(time=ref_period).mean("time", keep_attrs=True)
# tas_y = tas_y - ref
# tas_stacked_y = mask_and_stack(tas_y, threshold_land=THRESHOLD_LAND)

In [33]:
# keep the original grid for transforming back later
# Only "lat" and "lon" are selected from the unstacked reference dataset;
# this grid is what the emulations are unstacked and aligned to at the end.
grid_orig = ref_y[["lat", "lon"]]

In [34]:
# reuse the time coordinate of the monthly training data for the emulations,
# so that the results can be validated against the original data
m_time = tas_stacked_m["time"]

In [35]:
# step 1: monthly data predicted by the harmonic model from the yearly temperature
monthly_harmonic_emu = mesmer.stats.predict_harmonic_model(
    tas_stacked_y["tas"], harmonic_model_fit.coeffs, m_time
)

# step 2: draw realisations of variability around 0 from the AR(1) model
local_variability_transformed = mesmer.stats.draw_auto_regression_monthly(
    AR1_fit,
    localized_ecov.localized_covariance,
    time=m_time,
    n_realisations=nr_emus,
    buffer=buffer,
    seed=seed,
)

# step 3: undo the power transformation of the drawn variability
local_variability_inverted = mesmer.stats.inverse_yeo_johnson_transform(
    local_variability_transformed, pt_coefficients, tas_stacked_y["tas"]
)

# step 4: the emulation is the harmonic-model signal plus the local variability
emulations = monthly_harmonic_emu + local_variability_inverted.inverted

In [36]:
# unstack to original grid
# The stacked emulations are unstacked and aligned to `grid_orig`, the lat/lon
# grid kept from the data before masking and stacking.
emulations_unstacked = mesmer.grid.unstack_lat_lon_and_align(emulations, grid_orig)

In [37]:
# TODO
# save