You can run this notebook in a [live session](https://binder.pangeo.io/v2/gh/pangeo-data/climpred/main?urlpath=lab/tree/docs/source/bias_removal.ipynb) [<img src="https://mybinder.org/badge_logo.svg" alt='binder badge'>](https://binder.pangeo.io/v2/gh/pangeo-data/climpred/main?urlpath=lab/tree/docs/source/bias_removal.ipynb) or view it [on Github](https://github.com/pangeo-data/climpred/blob/main/docs/source/bias_removal.ipynb).

#### EGU live demo

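This notebook demonstrates bias removal in `climpred` using the NMME Niño3.4 SST hindcasts and OIv2 observations.

- multi-model: all models are processed at once, vectorized with `xarray` without explicit loops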

In [None]:
# linting
#%load_ext nb_black
#%load_ext lab_black

In [None]:
import warnings

import matplotlib.pyplot as plt
import xarray as xr

import climpred
from climpred import HindcastEnsemble

warnings.simplefilter("ignore")

v = "sst"  # variable name used throughout the notebook

In [None]:
initialized = climpred.tutorial.load_dataset("NMME_hindcast_Nino34_sst")
obs = climpred.tutorial.load_dataset("NMME_OIv2_Nino34_sst")

hindcast = climpred.HindcastEnsemble(initialized).add_observations(obs)
hindcast

In [None]:
hindcast.sel(model="GFDL-CM2p5-FLOR-A06").plot()

In [None]:
with xr.set_options(display_style="text"):
    print(hindcast)


## Additive mean bias removal

Typically, bias depends on lead time and should therefore also be removed as a function of `lead`.
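The idea behind lead-dependent additive mean bias removal can be sketched in plain NumPy. This is a minimal illustration under simplifying assumptions, not `climpred`'s implementation: forecasts and verifying observations are assumed to be already aligned on a hypothetical `(model, init, lead)` grid, and the toy data are made up. The bias is the mean of `forecast - obs` over `init`, one value per model and lead, and broadcasting handles all models at once.

```python
import numpy as np

# Hypothetical toy data: 3 models, 5 inits, 4 leads (all values are made up)
rng = np.random.default_rng(0)
n_model, n_init, n_lead = 3, 5, 4
obs = rng.normal(26.0, 1.0, size=(n_init, n_lead))  # verifying obs per (init, lead)
lead_bias = np.array([0.5, 0.2, -0.1, -0.4])  # bias that changes with lead
model_offset = np.array([0.0, 0.3, -0.3])[:, None, None]  # extra bias per model
noise = rng.normal(0.0, 0.1, size=(n_model, n_init, n_lead))
forecast = obs + lead_bias + model_offset + noise


def additive_mean_bias(forecast, obs):
    """Mean of (forecast - obs) over the init axis -> shape (model, lead)."""
    # obs broadcasts against the leading model axis, so no loop over models
    return (forecast - obs).mean(axis=1)


def remove_additive_mean_bias(forecast, obs):
    """Subtract the lead-dependent mean bias from every init."""
    bias = additive_mean_bias(forecast, obs)
    return forecast - bias[:, None, :]  # re-insert the init axis


corrected = remove_additive_mean_bias(forecast, obs)
```

After correction, the mean error over `init` is zero for every model and lead, which is exactly what an in-sample additive correction guarantees.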

In [None]:
bias = hindcast.verify(
    metric="additive_bias", comparison="e2o", dim=[], alignment="same_verifs"
)

bias[v].plot(col="model")

In [None]:
# group bias by seasonality
seasonality = climpred.options.OPTIONS["seasonality"]
seasonality

In [None]:
bias.groupby(f"init.{seasonality}").mean()[v].plot(col="model")

An initial warm bias develops into a cold bias, especially in winter.
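Grouping by `init.month` averages the bias over all initializations that share a calendar month. A NumPy sketch of that grouping, assuming `seasonality` is `"month"` and using hypothetical random bias values for a single model (24 monthly inits, 3 leads):

```python
import numpy as np

# Hypothetical bias for one model: 24 consecutive monthly inits x 3 leads
rng = np.random.default_rng(1)
init_month = np.tile(np.arange(1, 13), 2)  # calendar month of each init
bias = rng.normal(0.0, 0.2, size=(24, 3))


def groupby_month_mean(bias, init_month):
    """Average bias over inits sharing the same calendar month -> (12, lead)."""
    months = np.arange(1, 13)
    return np.stack([bias[init_month == m].mean(axis=0) for m in months])


seasonal_bias = groupby_month_mean(bias, init_month)
```

Each row of `seasonal_bias` is the lead-dependent bias for one initialization month, which is what the faceted plot above shows per model.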

## `train_test_split`

{cite:t}`Risbey2021` demonstrate how important a clean separation of a `train` and a `test` period is for bias reduction. 

The following `train_test_split` options are implemented in `climpred`:

- `unfair`: completely overlapping `train` and `test` (climpred default)
- `unfair-cv`: overlapping `train` and `test` except for current `init`, which is [left out](https://en.wikipedia.org/wiki/Cross-validation_(statistics)#Leave-one-out_cross-validation) (set `cv="LOO"`)
- `fair`: no overlap between `train` and `test` (recommended)
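The difference between `unfair` and `unfair-cv` with leave-one-out can be illustrated for a single lead with a minimal NumPy sketch. The function names are hypothetical and only mimic the idea of the two options, not `climpred`'s code: with `unfair`, the init being corrected contributes to its own bias estimate; with leave-one-out, it does not.

```python
import numpy as np


def bias_unfair(forecast, obs):
    """unfair: bias estimated from all inits, including the one corrected."""
    return forecast - np.mean(forecast - obs)


def bias_unfair_loo(forecast, obs):
    """unfair-cv (LOO): bias for init i excludes init i itself."""
    err = forecast - obs
    # mean of the other n - 1 errors, computed without an explicit loop
    loo_bias = (err.sum() - err) / (err.size - 1)
    return forecast - loo_bias


def rmse(forecast, obs):
    return np.sqrt(np.mean((forecast - obs) ** 2))


# Hypothetical example: 4 inits at one lead
forecast = np.array([1.0, 2.0, 3.0, 4.0])
obs = np.zeros(4)
unfair = bias_unfair(forecast, obs)
loo = bias_unfair_loo(forecast, obs)
```

In this toy case the leave-one-out correction yields a larger RMSE than the `unfair` one, illustrating how in-sample bias estimation can make corrected skill look better than it would be on independent data.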

In [None]:
# fair: estimate the bias over train_time/train_init and drop these indices from the hindcast
hindcast.remove_bias(
    how="additive_mean",
    alignment="same_verifs",
    train_test_split="fair",
    train_time=slice("1982", "1998"),
).sel(model="GFDL-CM2p5-FLOR-A06").plot()
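The `fair` split can be sketched the same way for one lead: the bias is estimated on the training years only and applied to the remaining years, while the training years are dropped from the output. The function `remove_bias_fair` and its arguments are hypothetical and only illustrate the idea; `climpred` works on labelled `init`/`time` coordinates rather than a plain year array.

```python
import numpy as np


def remove_bias_fair(years, forecast, obs, train_start, train_end):
    """Estimate additive bias on train years, apply it to the test years only.

    Returns the test years and their corrected forecasts; train years are dropped.
    """
    years = np.asarray(years)
    forecast = np.asarray(forecast, dtype=float)
    obs = np.asarray(obs, dtype=float)
    train = (years >= train_start) & (years <= train_end)
    bias = np.mean(forecast[train] - obs[train])
    test = ~train
    return years[test], forecast[test] - bias


# Hypothetical example: train on 1982-1983, correct 1984-1985
test_years, corrected = remove_bias_fair(
    np.arange(1982, 1986), [1.0, 1.0, 5.0, 5.0], np.zeros(4), 1982, 1983
)
```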

In [None]:
import seaborn as sns

sns.set_palette("husl", initialized.model.size)  # one color per model

In [None]:
metric_kwargs = dict(
    metric="rmse",
    alignment="same_verifs",
    dim="init",
    comparison="e2o",
    skipna=True,
)
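With `metric="rmse"`, `dim="init"` and `skipna=True`, the skill is a root-mean-square error over initializations, one value per lead, ignoring missing values. A NumPy sketch of that reduction on a hypothetical `(init, lead)` array, assuming forecasts and observations are already aligned:

```python
import numpy as np


def rmse_over_init(forecast, obs):
    """RMSE along the init axis (axis 0), one value per lead; NaNs are skipped."""
    return np.sqrt(np.nanmean((forecast - obs) ** 2, axis=0))


# Hypothetical example: 2 inits x 2 leads, with one missing value
forecast = np.array([[1.0, 2.0], [3.0, np.nan]])
obs = np.zeros((2, 2))
skill = rmse_over_init(forecast, obs)
```

The second lead uses only the one available init, just as `skipna=True` drops the missing pair instead of propagating NaN.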

In [None]:
# different train_test_split methods to compare
train_test_split = ["unfair", "unfair-cv", "fair"]
# baseline without bias removal, verified over 1982-1998
skill_train_test_split = [
    hindcast.sel(time=slice("1982", "1998")).verify(**metric_kwargs)
]
skill_train_test_split.append(
    hindcast.sel(time=slice("1982", "1998"))
    .remove_bias(
        how="additive_mean",
        alignment=metric_kwargs["alignment"],
        train_test_split="unfair",
    )
    .verify(**metric_kwargs)
)
skill_train_test_split.append(
    hindcast.sel(time=slice("1982", "1998"))
    .remove_bias(
        how="additive_mean",
        alignment=metric_kwargs["alignment"],
        train_test_split="unfair-cv",
        cv="LOO", # leave-one-out
    )
    .verify(**metric_kwargs)
)

skill_train_test_split.append(
    hindcast.remove_bias(
        how="additive_mean",
        alignment=metric_kwargs["alignment"],
        train_test_split="fair",
        train_time=slice("1982", "1998"),
    ).verify(**metric_kwargs)
)

skill_train_test_split = xr.concat(skill_train_test_split, "train_test_split")[
    v
].assign_coords(train_test_split=["None"] + train_test_split)

In [None]:
skill_train_test_split.plot(hue="model", col="train_test_split", x="lead")
plt.ylim([0, 1.4])
plt.suptitle(
    f"NMME Nino3.4 SST {metric_kwargs['metric'].upper()} for different bias correction train_test splits",
    y=1.0,
)

plt.savefig(
    "NMME_nino34_bias_correction_train_test_splits.png", bbox_inches="tight", dpi=300
)

## Comparison of methods `how`

In [None]:
methods = [
    "additive_mean",
    # "multiplicative_std",
    "DetrendedQuantileMapping",
    "EmpiricalQuantileMapping",
    # "PrincipalComponents",
    # "LOCI",
    "QuantileDeltaMapping",
    "Scaling",
    "modified_quantile",
    "basic_quantile",
    # "gamma_mapping",
    # "normal_mapping",
]
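Several of these methods are quantile-mapping variants. The core idea of empirical quantile mapping can be sketched in NumPy: estimate quantiles of the model (`hist`) and the observations (`ref`) over the training period, then map each new model value (`sim`) through that transfer function. This is a simplified textbook version with hypothetical names, not the `xclim.sdba` implementation, which adds grouping, interpolation, and adjustment options.

```python
import numpy as np


def empirical_quantile_mapping(hist, ref, sim, n_quantiles=50):
    """Map sim values through the quantile transfer function hist -> ref."""
    q = np.linspace(0.0, 1.0, n_quantiles)
    hist_q = np.quantile(hist, q)  # model quantiles in the training period
    ref_q = np.quantile(ref, q)  # observed quantiles in the training period
    # locate each sim value among the model quantiles, read off the obs value there
    return np.interp(sim, hist_q, ref_q)


# Hypothetical example: the model is a constant 2 degrees too warm
ref = np.arange(100.0)
hist = ref + 2.0
adjusted = empirical_quantile_mapping(hist, ref, np.array([10.0, 50.0]))
```

For a pure shift, quantile mapping reduces to additive bias removal; it differs when the model distribution also has the wrong spread or shape.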


In [None]:
# xclim.sdba requires pint units
hindcast._datasets["initialized"][v].attrs["units"] = "C"
hindcast._datasets["observations"][v].attrs["units"] = "C"

In [None]:
metric_kwargs["alignment"] = "same_inits"

In [None]:
metric_kwargs["reference"] = ["climatology", "persistence"]
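The reference forecasts are built from observations only, which is why they are unaffected by bias correction of the initialized forecasts. A simplified NumPy sketch of both references for a hypothetical observed series, assuming leads start at 1 and using the overall observed mean as climatology; `climpred`'s own definitions (lead conventions, seasonal climatology) may differ.

```python
import numpy as np


def verifying_obs(obs, n_lead):
    """Observations verifying leads 1..n_lead for each init t."""
    n = obs.size - n_lead
    return np.stack([obs[t + 1 : t + 1 + n_lead] for t in range(n)])


def persistence_forecast(obs, n_lead):
    """Persist the value observed at init time t for all leads."""
    n = obs.size - n_lead
    return np.stack([np.full(n_lead, obs[t]) for t in range(n)])


def climatology_forecast(obs, n_lead):
    """Forecast the observed mean for every init and lead."""
    n = obs.size - n_lead
    return np.full((n, n_lead), obs.mean())


def rmse(forecast, verif):
    return np.sqrt(((forecast - verif) ** 2).mean(axis=0))


# Hypothetical observed series and 2 leads
obs = np.array([0.0, 1.0, 2.0, 3.0, 4.0])
verif = verifying_obs(obs, 2)
pers_rmse = rmse(persistence_forecast(obs, 2), verif)
clim_rmse = rmse(climatology_forecast(obs, 2), verif)
```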

In [None]:
skill_bias_reduction = [hindcast.sel(init=slice("1999", None)).verify(**metric_kwargs)]
for method in methods:
    skill_bias_reduction.append(
        hindcast.remove_bias(
            how=method,
            alignment=metric_kwargs["alignment"],
            train_test_split="fair",
            train_init=slice("1982", "1998"),
        ).verify(**metric_kwargs)
    )
skill_bias_reduction = xr.concat(skill_bias_reduction, "bias_correction")[
    v
].assign_coords(bias_correction=["None"] + methods)

In [None]:
# reference forecasts are unaffected by bias_correction
refs = skill_bias_reduction.drop_sel(skill="initialized").isel(
    bias_correction=0, model=0, drop=True
)
refs.plot(hue="skill")

In [None]:
fg = skill_bias_reduction.sel(skill="initialized").plot(
    hue="model", col="bias_correction", x="lead", col_wrap=4
)
plt.ylim([0, 1.4])
plt.suptitle(
    f"NMME Nino3.4 SST {metric_kwargs['metric'].upper()} for different bias reduction methods",
    y=1.0,
)
for ax in fg.axes.flatten():
    # reference skill in gray, plotted against the lead coordinate
    ax.plot(refs.lead, refs.sel(skill="persistence"), color="gray", ls=":")
    ax.plot(refs.lead, refs.sel(skill="climatology"), color="gray", ls="--")
plt.savefig("NMME_nino34_bias_correction.png", bbox_inches="tight", dpi=300)

#### For how many lead months is the initialized forecast better than the reference forecasts?

In [None]:
sns.set_palette("husl", skill_bias_reduction.skill.size)
skill_bias_reduction.sel(model="GFDL-CM2p5-FLOR-A06").plot(
    col="bias_correction", hue="skill", col_wrap=4
)

In [None]:
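# count the leads at which "initialized" (index 0 along skill) has the lowest RMSE,
# per model and bias correction method
skill_bias_reduction.where(skill_bias_reduction.argmin("skill") == 0).notnull().sum(
    "lead"
).sel(skill="initialized", drop=True).astype(int).to_dataframe().unstack(0)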
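The counting logic above boils down to an `argmin` over the skill dimension followed by a count over leads. A NumPy sketch on a hypothetical RMSE table (rows: initialized, climatology, persistence; columns: leads), assuming lower RMSE is better and the initialized forecast is the first row:

```python
import numpy as np

# Hypothetical RMSE values: rows = skill, columns = lead
rmse = np.array(
    [
        [0.2, 0.4, 0.7, 0.9],  # initialized
        [0.8, 0.8, 0.8, 0.8],  # climatology
        [0.1, 0.5, 0.9, 1.2],  # persistence
    ]
)


def leads_initialized_best(rmse, initialized_row=0):
    """Number of leads at which the initialized forecast has the lowest RMSE."""
    return int((rmse.argmin(axis=0) == initialized_row).sum())


n_best = leads_initialized_best(rmse)
```

Here persistence wins at the first lead and climatology at the last, so the initialized forecast is best at two leads.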