In [None]:
# this cell creates a set of files used by the examples below. On readthedocs, this cell is hidden.
import numpy as np
import pandas as pd
pd.plotting.register_matplotlib_converters()
import xarray as xr
from scipy.interpolate import interp1d
from mpl_toolkits.mplot3d import Axes3D

# Daily time vector spanning 4 years (2000-2003)
times = pd.date_range('2000-01-01', '2003-12-31', freq='D')
# Synthetic temperature: a seasonal cosine cycle ranging from -18 to 18 degC
tas = xr.DataArray(-18 * np.cos(2 * np.pi * times.dayofyear / 365),
                   dims=('time',), coords={'time': times}, name='tas',
                   attrs={'units': 'degC', 'standard_name': 'air_temperature',
                          'long_name': 'Mean air temperature at surface'})

# Write 10 members. Each member adds smooth noise to the seasonal cycle: 43 uniform random
# nodes in [0, 1), quadratically interpolated onto the time axis and scaled by 20, so the
# resulting temperature stays roughly between -18 and 38.
# A dedicated seeded generator makes the files identical on every run without touching
# numpy's global random state.
member_rng = np.random.RandomState(0)
noise_nodes = np.arange(43)
noise_axis = np.linspace(0, 42, tas.size)
for i in range(10):
    smooth_noise = interp1d(noise_nodes, member_rng.random_sample((43,)), kind='quadratic')(noise_axis)
    tasi = tas + 20 * smooth_noise
    tasi.name = 'tas'
    # Re-attach the variable metadata (xarray arithmetic does not keep attrs by default),
    # then tag the member with its own title.
    tasi.attrs.update(tas.attrs)
    tasi.attrs['title'] = f'tas of member {i:02d}'
    tasi.to_netcdf(f'ens_tas_m{i}.nc')

# Create 'toy' criteria selection data: for every criterion, 100 realizations of a
# normally distributed delta at each of two horizons.
np.random.seed(0)  # fixed seed so the ensemble-reduction examples are reproducible


def _toy_criterion(horizon_params, n_realizations=100):
    """Draw normal samples for each horizon and stack them along a ``horizon`` dimension.

    Parameters
    ----------
    horizon_params : sequence of (label, loc, scale)
        Horizon label with the mean and standard deviation of its distribution.
        Samples are drawn in the order given.
    n_realizations : int
        Number of samples drawn per horizon.

    Returns
    -------
    xr.DataArray
        Array with dimensions (horizon, realization).
    """
    per_horizon = [
        xr.DataArray(np.random.normal(loc=loc, scale=scale, size=n_realizations),
                     dims=['realization']).assign_coords(horizon=label)
        for label, loc, scale in horizon_params
    ]
    return xr.concat(per_horizon, dim='horizon')


# The call order below fixes the order in which the random stream is consumed.
ds_crit = xr.Dataset()
ds_crit['delta_annual_tavg'] = _toy_criterion([('2041-2070', 3, 1.5), ('2071-2100', 5.34, 2)])
ds_crit['delta_annual_prtot'] = _toy_criterion([('2041-2070', 5, 5), ('2071-2100', 10, 8)])
ds_crit['delta_JJA_prtot'] = _toy_criterion([('2041-2070', 0, 3), ('2071-2100', 2, 4)])

Ensembles
=========

An important aspect of climate models is that they are run multiple times with some initial perturbations to see how they replicate the natural variability of the climate. Through [xclim.ensembles](../api.rst#module-xclim.ensembles), xclim provides an easy interface to compute ensemble statistics on different members. Most methods perform checks and conversion on top of simpler `xarray` methods, providing an easier interface to use.

### create_ensemble
Our first step is to create an ensemble. This method takes a list of files defining the same variables over the same coordinates and concatenates them into one dataset with an added dimension `realization`.

Using `xarray`, a very simple way of creating an ensemble dataset would be:
```python
xr.open_mfdataset(files, concat_dim='realization', combine='nested')
```

However, this is only successful when the dimensions of all the files are identical AND the calendar type of each netCDF file is the same.

xclim's `create_ensemble()` method overcomes these constraints by selecting the time period common to all files and assigning a standard calendar type to the dataset.

<div class="alert alert-info">

Input netcdf files still require equal spatial dimension size (e.g. lon, lat dimensions). <br>
If input data contains multiple cftime calendar types they must not be at daily frequency.

</div>

Given files all named `ens_tas_m[member number].nc`, we use `glob` to get a list of all those files.

In [None]:
import glob
import xclim as xc
import xarray as xr
# Set display to HTML sytle (for fancy output)
xr.set_options(display_style='html', display_width=50)

import matplotlib.pyplot as plt
%matplotlib inline

from xclim import ensembles

ens = ensembles.create_ensemble(glob.glob('ens_tas_m*.nc')).load()
ens.close()

In [None]:
# Matplotlib 3.6 renamed its bundled seaborn styles to 'seaborn-v0_8-*' and 3.8 removed the
# old names; use whichever name this installation provides.
dark_style = 'seaborn-dark' if 'seaborn-dark' in plt.style.available else 'seaborn-v0_8-dark'
plt.style.use(dark_style)
plt.rcParams['figure.figsize'] = (13, 5)
# One line per ensemble member
ens.tas.plot(hue='realization')

In [None]:
# Attributes of the first dataset to be opened are copied to the final output
ens['tas']

### Ensemble statistics
Beyond creating an ensemble dataset, the `xclim.ensembles` module contains functions for calculating statistics between realizations.

**Ensemble mean, standard-deviation, max & min**

In the example below we use xclim's `ensemble_mean_std_max_min()` to calculate statistics across the 10 realizations in our test dataset. Output variables are created by combining the original variable name `tas` with an additional suffix indicating the statistic calculated on the realization dimension: `_mean`, `_stdev`, `_min`, `_max`.

The resulting output now contains 4 derived variables from the original single variable in our ensemble dataset.

In [None]:
# Compute the ensemble statistics from the member dataset; the bare name on the last line
# displays the resulting dataset as the cell output.
ens_stats = ensembles.ensemble_mean_std_max_min(ens)
ens_stats

### Ensemble percentiles

Here we use xclim's `ensemble_percentiles()` to calculate percentile values across the 10 realizations.
The output now has a `percentiles` dimension instead of `realization`. Split variables can be created instead by specifying `split=True` (the variable name `tas` will be appended with `_p{x}`). Compared to numpy's `percentile()` and xarray's `quantile()`, this method handles datasets with invalid values more efficiently, as well as chunking along the realization dimension (which is automatic when dask arrays are used).

In [None]:
# Percentile levels to evaluate; split=False keeps them on a single `percentiles` dimension
percentile_levels = [15, 50, 85]
ens_perc = ensembles.ensemble_percentiles(
    ens,
    values=percentile_levels,
    split=False,
)
ens_perc

In [None]:
fig, ax = plt.subplots()

# Shaded envelopes: full min-max range and the 15th-85th percentile band
ax.fill_between(ens_stats.time.values, ens_stats.tas_min, ens_stats.tas_max,
                alpha=0.3, label='Min-Max')
ax.fill_between(ens_perc.time.values, ens_perc.tas.sel(percentiles=15), ens_perc.tas.sel(percentiles=85),
                alpha=0.5, label='Perc. 15-85')

# Give the lines explicit colours that differ from the two bands (which take the first two
# cycle colours), rather than advancing the axes' private colour cycler.
ax.plot(ens_stats.time.values, ens_stats.tas_mean, linewidth=2, color='C2', label='Mean')
ax.plot(ens_perc.time.values, ens_perc.tas.sel(percentiles=50), linewidth=2, color='C3', label='Median')

ax.set_xlabel('time')
ax.set_ylabel('tas (degC)')
ax.legend()

### Ensemble Reduction techniques
`xclim.ensembles` provides means of reducing the number of candidates in a sample to get a reasonable and representative spread of outcomes using a reduced number of candidates. By reducing the number of realizations in a strategic manner, we can significantly reduce the number of realizations that are considered, while maintaining a statistical representation of the original dataset. This is particularly useful when computation power or time is a factor.

**Selection Criteria** 
For this example we use "pretend" (randomly generated) delta values. In this case we have a total of 6 criteria (3 delta variables with 2 horizons). We want to reduce the number of runs while still covering the uncertainty in the 6 criteria.

In [None]:
# Display the toy criteria dataset (3 delta variables, each with 2 horizons x 100 realizations)
ds_crit

In [None]:
# Matplotlib 3.6 renamed its bundled seaborn styles to 'seaborn-v0_8-*' and 3.8 removed the
# old names; use whichever name this installation provides.
dark_style = 'seaborn-dark' if 'seaborn-dark' in plt.style.available else 'seaborn-v0_8-dark'
plt.style.use(dark_style)
fig = plt.figure(figsize=(9, 9))
ax = fig.add_subplot(111, projection='3d')

# One point cloud per horizon, with the three delta variables as the three axes
for h in ds_crit.horizon:
    ds_h = ds_crit.sel(horizon=h)  # select the horizon once instead of once per axis
    ax.scatter(ds_h.delta_annual_tavg, ds_h.delta_annual_prtot, ds_h.delta_JJA_prtot,
               label=f"delta {h.values}")

ax.set_xlabel('delta_annual_tavg (C)')
ax.set_ylabel('delta_annual_prtot (%)')
ax.set_zlabel('delta_JJA_prtot (%)')
ax.legend()

Ensemble reduction techniques require a 2D array with dimensions `criteria` and `realization`.

In [None]:
# Create a 2D xr.DataArray containing the criteria values: one entry along `criteria` for
# every (horizon, variable) pair, ordered horizon first and variable second.
# A single concat replaces the None sentinel and the repeated pairwise concatenation.
crit = xr.concat(
    [ds_crit[v].sel(horizon=h) for h in ds_crit.horizon for v in ds_crit.data_vars],
    dim='criteria',
)
crit.name = 'criteria'
crit.shape

#### **K-Means reduce ensemble**

The `kmeans_reduce_ensemble` function works by grouping realizations into sub-groups based on the provided criteria and retaining a representative `realization` per sub-group.
* Example using `method = n_clusters`
* The function returns the `ids` of the selected realizations
* In this example we reduce the original 100 realizations to n=25


In [None]:
# Reduce the ensemble with a fixed number of clusters; random_state makes the result repeatable
ids, cluster, fig_data = ensembles.kmeans_reduce_ensemble(
    data=crit,
    method={'n_clusters': 25},
    random_state=42,
    make_graph=True,
)
# Criteria dataset restricted to the selected realizations
ds_crit.isel(realization=ids)

We can now compare the distribution of the selection versus the original ensemble

In [None]:
# Matplotlib 3.6 renamed its bundled seaborn styles to 'seaborn-v0_8-*' and 3.8 removed the
# old names; use whichever name this installation provides.
dark_style = 'seaborn-dark' if 'seaborn-dark' in plt.style.available else 'seaborn-v0_8-dark'
plt.style.use(dark_style)
fig = plt.figure(figsize=(9, 9))
ax = fig.add_subplot(111, projection='3d')

for n, h in enumerate(ds_crit.horizon):
    ds_h = ds_crit.sel(horizon=h)
    # `ids` are positional indices, so use isel like the other cells that subset by ids.
    ds_h_sel = ds_h.isel(realization=ids)
    colour = f'C{n}'  # same colour for a horizon's full ensemble and its selection

    # Full ensemble drawn faintly in the background, for comparison with the selection
    ax.scatter(ds_h.delta_annual_tavg, ds_h.delta_annual_prtot, ds_h.delta_JJA_prtot,
               color=colour, alpha=0.15, label=f"delta {h.values} - all")
    ax.scatter(ds_h_sel.delta_annual_tavg, ds_h_sel.delta_annual_prtot, ds_h_sel.delta_JJA_prtot,
               color=colour, label=f"delta {h.values} - selected")

ax.set_xlabel('delta_annual_tavg (C)')
ax.set_ylabel('delta_annual_prtot (%)')
ax.set_zlabel('delta_JJA_prtot (%)')
ax.legend()

The function optionally produces a data dictionary for figure production of the associated R² profile.
The function `ensembles.plot_rsqprofile` provides plotting for evaluating the proportion of total variance in climate realizations that is covered by the selection. In this case ~78% of the total variance in the original ensemble is covered by the selection.

In [None]:
# Plot the R² profile from the figure data returned by kmeans_reduce_ensemble(make_graph=True)
ensembles.plot_rsqprofile(fig_data)

Alternatively, we can use `method = {'rsq_cutoff': float}` or `method = {'rsq_optimize': None}`
* For example, with `rsq_cutoff` we instead find the number of realizations needed to cover the provided R² value

In [None]:
# Same reduction, but the number of clusters is determined from an R² cutoff of 0.85
ids1, cluster1, fig_data1 = ensembles.kmeans_reduce_ensemble(
    data=crit,
    method={'rsq_cutoff': 0.85},
    random_state=42,
    make_graph=True,
)
ensembles.plot_rsqprofile(fig_data1)
# Criteria dataset restricted to the realizations selected with the cutoff
ds_crit.isel(realization=ids1)