# Running an ensemble forecast

This example shows how to run an ensemble of 10 forecasts using the `SpeedyEns` class. 

For this example, we will execute these steps:

* Create an ensemble of models starting from a reference atmosphere. The same boundary conditions are used
  for each member.
* Perturb the initial temperature field with uncorrelated Gaussian noise with a standard deviation of 0.01 K.
  Note that these perturbations are neither physical nor optimal with respect to error growth! They are only used
  to show how to use the `SpeedyEns` class.
* Run the ensemble forecast for two months (January and February 1980), with the first month considered the "spinup" period. Hence, the output from the first month is discarded.
* Compute the growth of the ensemble spread over time for the temperature and the U wind.
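The perturbation step can be pictured with plain NumPy before turning to pySPEEDY. The sketch below is illustrative only: the grid size (8 levels, 48 latitudes, 96 longitudes), the constant 250 K reference field, and the seeded generator are assumptions made for the example, not values read from the model. Only the 0.01 standard deviation matches the `np.random.normal` call used in the code further down.

```python
import numpy as np

# Hypothetical grid size, chosen for illustration only (not read from pySPEEDY).
n_lev, n_lat, n_lon = 8, 48, 96
n_members = 10
noise_std = 0.01  # Same standard deviation as the np.random.normal call used below.

rng = np.random.default_rng(seed=0)  # Seeded so the sketch is reproducible.

# Toy "reference" temperature field (K), identical for every member.
t_ref = np.full((n_lev, n_lat, n_lon), 250.0)

# Each member receives its own independent (uncorrelated) Gaussian perturbation.
t_members = np.stack(
    [t_ref + rng.normal(0.0, noise_std, t_ref.shape) for _ in range(n_members)]
)

print(t_members.shape)
# The sample standard deviation of all perturbations should be close to noise_std.
print((t_members - t_ref).std())
```

Seeding a `np.random.default_rng` generator, rather than calling `np.random.normal` directly, is one way to make the ensemble reproducible between runs.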

Let's get started. 

### Google Colab fix

***IMPORTANT***

If you are running this notebook in Google Colab, uncomment and execute
the following lines to install pySPEEDY and its dependencies.


```python
# !apt-get install libproj-dev proj-data proj-bin
# !apt-get install libgeos-dev libnetcdf-dev libnetcdff-dev

# !pip uninstall --yes shapely
# !pip install shapely --no-binary shapely
# !pip install cartopy
```

Temporary fix for https://github.com/SciTools/cartopy/issues/1869 proposed by @rcomer

```python
# !wget https://raw.githubusercontent.com/SciTools/cartopy/master/tools/cartopy_feature_download.py
# !python cartopy_feature_download.py physical
```

Now, install pySPEEDY:

```python
# !pip install -v git+https://github.com/aperezhortal/pySPEEDY.git
```

End of Colab setup.

### Running the ensemble forecast

Let's run the ensemble forecast, keeping the output once a day (except during the spinup period).
To store the forecast in memory, we will use the `ModelCheckpoint` callback.
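Why does `interval=36` mean "once a day"? Assuming, as the comment in the code below states, that the model takes 36 time steps per day, here is a quick check of the implied time step and of the run length for the dates used in this example:

```python
from datetime import datetime

# Dates taken from the code cell below.
start_date = datetime(1980, 1, 1)
end_date = datetime(1980, 2, 29)
spinup_date = datetime(1980, 2, 1)
steps_per_day = 36  # Assumption from the "36 -> once per day" comment below.

# Length of one model time step implied by 36 steps per day.
dt_seconds = 24 * 3600 / steps_per_day
print(dt_seconds / 60)  # minutes per time step

# Nominal number of integration steps for the whole run (1980 is a leap year).
total_days = (end_date - start_date).days
total_steps = total_days * steps_per_day
print(total_days, total_steps)

# Days in the spinup period, whose output is discarded.
spinup_days = (spinup_date - start_date).days
print(spinup_days)
```

The number of daily outputs that `ModelCheckpoint` actually keeps after `spinup_date` depends on how the callback treats the interval endpoints, so it is deliberately not computed here.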

```python
%%time
from datetime import datetime

import numpy as np
from pyspeedy import SpeedyEns
from pyspeedy.callbacks import DiagnosticCheck, ModelCheckpoint

# Definitions
number_of_members = 10
start_date = datetime(1980, 1, 1)  # Simulation start date (datetime object).
end_date = datetime(1980, 2, 29)  # Simulation end date.
spinup_date = datetime(1980, 2, 1)  # End of spinup period.

# Create an ensemble of speedy model instances.
model_ens = SpeedyEns(
    number_of_members,
    start_date=start_date,  # Simulation start date (datetime object).
    end_date=end_date,  # Simulation end date.
)
# At this point, each ensemble member contains an "empty" (uninitialized) model state.
# Let's initialize them.
# To do that, we iterate over the members, set the boundary conditions, and add a random perturbation.
for member in model_ens:
    # Set the default boundary conditions derived from the ERA reanalysis.
    member.set_bc()
    # Add a perturbation to the temperature field in grid space (not in spectral space).
    member["t_grid"] += np.random.normal(0.0, 0.01, member["t_grid"].shape)
    # The prognostic variables used for the model integration are in spectral space,
    # so convert all the grid variables (temperature among them) to spectral space.
    member.grid2spectral()

# Et voilà, our ensemble is initialized.

# Now, initialize the callback that stores the forecast for the selected variables.
# The data (an xarray Dataset) are stored in the "dataframe" attribute of the
# ModelCheckpoint instance.
model_checkpoints = ModelCheckpoint(
    interval=36,  # Number of time steps between stored outputs. 36 -> once per day.
    verbose=True,  # Print progress messages.
    variables=None,  # Variables to store. If None, store the most commonly used variables.
    spinup_date=spinup_date,  # End of spinup period.
)

# We also add a callback that runs a diagnostic check every 160 time steps,
# verifying that some diagnostic values are within range.
# If the check fails, an exception is raised and the model stops.
diag_checks = DiagnosticCheck(interval=160)

# Run the ensemble, passing our callbacks.
# IMPORTANT: this runs the ensemble forecast in parallel, one member per thread.
model_ens.run(callbacks=[model_checkpoints, diag_checks])
# After the run, each member's state contains the values from the last integration step.
```
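Ensemble members do not exchange information during the forecast, which is what makes running one member per thread possible. The sketch below mimics that layout with the standard library's `concurrent.futures.ThreadPoolExecutor`. The `run_member` function is a hypothetical stand-in (a chaotic logistic map, not SPEEDY), and this is not how pySPEEDY implements `SpeedyEns.run` internally.

```python
from concurrent.futures import ThreadPoolExecutor

import numpy as np


def run_member(member_id, n_steps=50):
    """Hypothetical stand-in for the integration of one ensemble member."""
    rng = np.random.default_rng(member_id)  # Independent perturbation per member.
    state = 0.5 + rng.normal(0.0, 0.01)  # Perturbed initial condition in (0, 1).
    for _ in range(n_steps):
        # Chaotic logistic map used as a toy "model"; it keeps the state in [0, 1].
        state = 3.9 * state * (1.0 - state)
    return state


number_of_members = 10
# Members do not interact, so each one can be integrated in its own thread.
with ThreadPoolExecutor(max_workers=number_of_members) as executor:
    final_states = list(executor.map(run_member, range(number_of_members)))

print(np.std(final_states, ddof=1))  # Ensemble spread at the final step.
```

For pure-Python work like this toy, the global interpreter lock prevents real speed-ups from threads; the sketch only illustrates the structure of an embarrassingly parallel ensemble.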

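The time series of the selected variables are stored in the `dataframe` attribute of the `model_checkpoints` object that we passed as a callback to the model run. Despite the attribute's name, it holds an xarray `Dataset` (hence the `dim=`, `isel`, and `attrs` usage below).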

```python
# Inspect the dataset stored in the ModelCheckpoint callback.
model_checkpoints.dataframe
```

## Growth of the ensemble spread over time

Let's compute the time series of the ensemble spread for temperature and the U wind.
The spread is computed following [Fortin et al. (2014)](https://journals.ametsoc.org/view/journals/hydr/15/4/jhm-d-14-0008_1.xml):

\begin{equation}
spr_{\phi} = \sqrt{ N^{-1} \sum\limits_{i=1}^N
\left[ (M-1)^{-1}
\sum\limits_{m=1}^M \left( \phi_m(i) -\overline{ \phi(i) } \right)^2
\right]}
\end{equation}

where:

* $\phi$ denotes a variable (e.g. U or temperature).
* $\phi_m(i)$ is its value at grid point $i$ for the $m^{th}$ member.
* the overbar indicates the ensemble average.
* $M$ is the number of ensemble members and $N$ the number of grid points.
* $\sum\limits_{m=1}^M$: summation over the ensemble members.
* $N^{-1} \sum\limits_{i=1}^N$: average over the grid points.

That is, the spread is computed by first computing the (unbiased) variance over the ensemble at each grid point, then averaging these variances over all grid points, and finally taking the square root.
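The same formula in NumPy, as a minimal sketch on synthetic data (the `ensemble_spread` helper and the random field are hypothetical, for illustration only). The field is laid out as (members, grid points). Both NumPy's and xarray's `var` default to `ddof=0` (division by $M$), so `ddof=1` has to be passed to reproduce the $(M-1)^{-1}$ factor.

```python
import numpy as np


def ensemble_spread(field):
    """Spread spr_phi of an ensemble field with shape (M, N).

    Axis 0 holds the M members and axis 1 the N grid points.
    """
    # Unbiased ensemble variance at each grid point: (M-1)^-1 * sum_m (phi_m - mean)^2.
    var_per_point = field.var(axis=0, ddof=1)
    # Average the variances over the grid points, then take the square root.
    return np.sqrt(var_per_point.mean())


rng = np.random.default_rng(42)
field = rng.normal(0.0, 2.0, size=(10, 500))  # 10 members, 500 grid points.
print(ensemble_spread(field))  # Should be close to 2.0, the std used to draw the samples.
```

With an xarray dataset such as the one produced by `ModelCheckpoint`, the equivalent reduction is `ds.var(dim="ens", ddof=1)` followed by a mean over the spatial dimensions and a square root.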

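```python
ens_dataset = model_checkpoints.dataframe

# Ensemble variance at each grid point (ddof=1 -> divide by M-1, as in the formula),
# averaged over all grid points (levels, latitudes, longitudes), then square-rooted.
spr_ds = (
    ens_dataset.var(dim="ens", ddof=1)
    .mean(dim=["lev", "lat", "lon"])
    .map(np.sqrt)
)
# Copy the variable attributes (e.g. units) from ens_dataset.
for var in spr_ds:
    spr_ds[var].attrs.update(**ens_dataset[var].attrs)
```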

```python
spr_ds["u"]
```

```python
import matplotlib.pyplot as plt

fig, axs = plt.subplots(1, 2, figsize=(10, 4), dpi=300)
labels = ["U wind", "Temperature"]
for ax, var, label in zip(axs, ["u", "t"], labels):
    spr_ds[var].plot(ax=ax)
    ax.set_title(f"Spread growth for {label}")
    units = spr_ds[var].attrs["units"]
    ax.set_ylabel(f"$spr_{label[0]}$ [{units}]")
    ax.set_xticks(ax.get_xticks()[:-1])  # Remove the last tick, which looks cluttered.
```

Now let's plot the spread map of the surface temperature field at the first stored time (the end of the spinup period) and at the end of the simulation.
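A spread map is the ensemble standard deviation at each grid point, i.e. the square root of the per-point variance in the formula above, before any averaging over the grid. This toy NumPy sketch (random data on a hypothetical 48 x 96 latitude-longitude grid) also shows how the map relates to the domain-averaged spread $spr_\phi$: averaging the squared map and taking the square root recovers it.

```python
import numpy as np

rng = np.random.default_rng(0)
# Toy ensemble: 10 members on a hypothetical 48 x 96 (lat x lon) grid.
field = rng.normal(0.0, 1.0, size=(10, 48, 96))

# Spread map: unbiased (ddof=1) ensemble standard deviation at each grid point.
spread_map = field.std(axis=0, ddof=1)

# Averaging the per-point variances (spread_map**2) over the grid and taking the
# square root gives the domain-averaged spread spr_phi.
domain_spread = np.sqrt((spread_map**2).mean())
print(spread_map.shape, domain_spread)
```

Since the standard deviation already includes the square root, no further square root should be applied to it when building the map.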

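```python
# Spread at each grid point: ensemble standard deviation (ddof=1, consistent with the formula above).
# The standard deviation is already the square root of the variance, so no extra square root is taken.
spr_ds = ens_dataset.std(dim="ens", ddof=1).isel(lev=0)  # Keep the first level (surface).
# Copy the variable attributes (e.g. units) from ens_dataset.
for var in spr_ds:
    spr_ds[var].attrs.update(**ens_dataset[var].attrs)
```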

```python
import cartopy.crs as ccrs
import matplotlib.pyplot as plt
import numpy as np
from cartopy.feature import OCEAN
from cartopy.util import add_cyclic_point

fig, axs = plt.subplots(
    2, 1, subplot_kw=dict(projection=ccrs.PlateCarree()), figsize=(10, 8)
)

lat = spr_ds["lat"].values
titles = ["first stored time (end of spinup)", "end of simulation"]

for ax, time_idx, when in zip(axs, [0, -1], titles):
    ax.set_title(rf"Surface temperature spread [$^\circ$C], {when}")
    ax.set_global()  # Set global extent.
    ax.coastlines()  # Add coastlines.
    ax.add_feature(OCEAN)  # Add oceans.

    # Order the array as (lat, lon) so that axis=1 is longitude.
    data_to_plot = spr_ds["t"].isel(time=time_idx).transpose("lat", "lon")

    # Copy the longitude=0 degrees data to longitude=360 to have continuous plots.
    cyclic_data, cyclic_lon = add_cyclic_point(
        data_to_plot.values, coord=spr_ds["lon"].values, axis=1
    )
    cs = ax.pcolormesh(
        cyclic_lon,
        lat,
        cyclic_data,
        transform=ccrs.PlateCarree(),
        cmap="jet",
        shading="auto",
    )
    fig.colorbar(cs, ax=ax, label=r"Spread [$^\circ$C]")
_ = plt.subplots_adjust(wspace=0.05)
```
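For readers without cartopy, the idea behind `add_cyclic_point` is to repeat the first longitude column at 360 degrees so that the plot has no gap at the edge of the map. A NumPy-only sketch, assuming equally spaced longitudes starting at 0 degrees and data ordered as (lat, lon); the `add_cyclic_column` name is hypothetical:

```python
import numpy as np


def add_cyclic_column(data, lon):
    """Hypothetical NumPy-only equivalent of appending a cyclic longitude.

    Repeats the first longitude column of `data` (shape (nlat, nlon)) at
    lon[0] + 360 so that a global plot closes the gap at the map edge.
    """
    cyclic_data = np.concatenate([data, data[:, :1]], axis=1)
    cyclic_lon = np.append(lon, lon[0] + 360.0)
    return cyclic_data, cyclic_lon


lon = np.arange(0.0, 360.0, 3.75)  # 96 equally spaced longitudes.
data = np.arange(48 * 96, dtype=float).reshape(48, 96)
cyclic_data, cyclic_lon = add_cyclic_column(data, lon)
print(cyclic_data.shape, cyclic_lon[-1])
```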