# Forecast evaluation

This notebook evaluates the machine-learning-based S2S precipitation forecasts against the ECMWF baseline and a monthly precipitation climatology.

The evaluation expects the forecast results to be organized in separate folders, each containing the forecasts in files separated by initialization date. The files should follow the filename pattern ``*_YYYYMMDD_00_00.nc`` where ``YYYYMMDD`` specifies the initialization date of the forecast, i.e., the first day for which a forecast is produced. Each file should contain the predicted daily precipitation accumulations on the MERRA grid as a three-dimensional field with dimensions ``step``, ``latitude``, ``longitude``.


The ``DATA_PATH`` variable defined in the cell below should be set to the path containing the ``training_data`` and ``test_data`` folders of the 2S2-Precip dataset.


In [None]:
from pathlib import Path
import numpy as np
import xarray as xr
import matplotlib.pyplot as plt

# Root folder of the dataset. The cells below read from its ``training_data``,
# ``test_data`` and ``results`` sub-folders; adjust this to the local copy.
DATA_PATH = Path("/edata1/simon/chimp/s2s/")

## Helper functions

The code below defines helper functions to load forecast results for a given initialization time.

In [None]:
from pathlib import Path
from typing import Dict
from pansat.time import to_datetime

def find_file(date: np.datetime64, path: Path) -> Path:
    """
    Locate the single result file that belongs to a given date.

    Files are matched against the pattern ``*_YYYYMMDD_HH_MM.nc`` built from
    the given date.

    Args:
        date: The initialization date.
        path: The folder containing the files.

    Return:
        A path object pointing to the one file in the folder matching the date.

    Raises:
        RuntimeError: If no file, or more than one file, matches the date.
    """
    folder = Path(path)
    init_time = to_datetime(date)
    stamp = init_time.strftime("%Y%m%d_%H_%M")
    matches = sorted(folder.glob(f"*_{stamp}.nc"))

    if not matches:
        raise RuntimeError(
            f"Didn't find any file in '{folder}' for the given date {init_time}."
        )
    if len(matches) != 1:
        raise RuntimeError(
            f"Found multiple files in '{folder}' for the given date {init_time}."
        )
    (match,) = matches
    return match

def load_reference_data(start_date: np.datetime64, path: Path, n_days: int = 28):
    """
    Load the reference data for forecasts initialized at a given date.

    One file per day is located with ``find_file`` and the daily datasets are
    concatenated along a new ``step`` dimension.

    Args:
        start_date: The initialization time of the forecast.
        path: Path to the directory containing the reference data.
        n_days: The number of consecutive days for which to load the reference data
             starting with the start date.

    Return:
        An xarray.Dataset containing the reference data fields. Its ``step``
        entry holds the lead time of each day relative to ``start_date`` and its
        ``time`` entry holds ``start_date`` itself.

    Raises:
        RuntimeError: Propagated from ``find_file`` when the file for one of the
            requested days is missing or ambiguous.
    """
    steps = np.arange(0, n_days).astype("timedelta64[D]")
    ref_data = []
    for date in start_date + steps:
        data = xr.load_dataset(find_file(date, path))
        # Negative precipitation values are replaced by NaN. The assignment
        # writes into the array held by the loaded dataset, so 'data' itself
        # is modified.
        precip = data.precipitation.data
        precip[precip < 0.0] = np.nan
        ref_data.append(data)
    ref_data = xr.concat(ref_data, "step")
    ref_data["step"] = steps.astype("timedelta64[ns]")
    ref_data["time"] = start_date.astype("datetime64[ns]")
    return ref_data
    

### Load or calculate climatology

In [None]:
import numpy as np
import xarray as xr
from tqdm import tqdm
from chimp.areas import MERRA

def calculate_precip_climatology(precip_data: Path) -> xr.Dataset:
    """
    Calculate a monthly precipitation climatology from daily accumulation files.

    All files in the given folder matching ``*_20??????_??_??.nc`` are read and
    their valid (finite, non-negative) precipitation values are averaged
    separately for each calendar month.

    Args:
        precip_data: A path object pointing to the reference precipitation data from which
             to compute the climatology.

    Return:
        An xarray.Dataset containing the monthly climatology with dimensions
        ``month`` (1 to 12), ``latitude`` and ``longitude``. Grid cells without
        any valid sample for a month are NaN.

    Raises:
        RuntimeError: If the folder contains no matching files.
    """
    precip_data = Path(precip_data)

    files = sorted(precip_data.glob("*_20??????_??_??.nc"))
    if not files:
        raise RuntimeError(
            f"Didn't find any daily precipitation files in '{precip_data}'."
        )

    # The accumulators use the shape of the MERRA[8] grid from chimp.areas.
    # NOTE(review): presumably this is the grid of the daily files -- confirm.
    tot_precip = np.zeros((12,) + MERRA[8].shape)
    counts = np.zeros((12,) + MERRA[8].shape)

    print("Calculating climatology:")
    for path in tqdm(files):
        # File names end in '_YYYYMMDD_HH_MM.nc', so the date is the third
        # token from the end. Counting from the end keeps this correct when
        # the file prefix itself contains underscores.
        month = int(path.stem.split("_")[-3][4:6])

        with xr.open_dataset(path) as data:

            data = data.transpose("latitude", "longitude")
            lons = data.longitude.data
            lats = data.latitude.data

            precipitation = data.precipitation.data
            valid = np.isfinite(precipitation) & (precipitation >= 0)
            tot_precip[month - 1][valid] += precipitation[valid]
            counts[month - 1] += valid.astype("float32")

    # Coordinates are taken from the last file read; all files are expected
    # to share the same grid.
    climatology = xr.Dataset({
        "month": (("month",), np.arange(12) + 1),
        "longitude": (("longitude",), lons),
        "latitude": (("latitude",), lats),
        "precipitation": (("month", "latitude", "longitude"), tot_precip / counts)
    })
    return climatology
                                           

Below, a monthly precipitation climatology is read from a file called ``climatology.nc`` or, if no such file is available, calculated from the training data of the dataset and written to that file.

In [None]:
# The climatology is cached in the working directory: it is computed from the
# training data and saved only when no cached file exists yet.
climatology_file = Path("climatology.nc")
if not climatology_file.exists():
    climatology = calculate_precip_climatology(
        DATA_PATH / "training_data" / "daily_precip"
    )
    climatology.to_netcdf(climatology_file)
else:
    climatology = xr.load_dataset(climatology_file)

### Visualize predicted weekly accumulations

The functions defined below visualize the baseline and model predictions together with the reference data.

In [None]:
from typing import Dict
from pathlib import Path
from cartopy import crs as ccrs
from matplotlib.gridspec import GridSpec
from matplotlib.colors import LogNorm, Normalize
from pansat.time import to_datetime

def plot_predictions(
    date: np.datetime64,
    reference_data_path: Path,
    prediction_paths: Dict[str, Path],
) -> plt.Figure:
    """
    Plot weekly mean precipitation of the reference data and several forecasts.

    The figure has one row for the reference data followed by one row per entry
    in ``prediction_paths`` and one column for each of the four forecast weeks.
    All panels share a linear color scale from 0 to 10 and a single colorbar.

    Args:
        date: The initialization date of the forecasts to visualize.
        reference_data_path: The path containing the reference data.
        prediction_paths: A dictionary mapping dataset names to folder containing the corresponding
            forecast results.

    Return:
        A matplotlib.Figure object containing the visualization of the forecasts.

    Raises:
        RuntimeError: Propagated from ``find_file`` when a reference or
            forecast file for the given date is missing or ambiguous.
    """
    p_height = 2
    p_width = 4
    n_rows = 1 + len(prediction_paths)

    fig = plt.figure(figsize=(4 * p_width, n_rows * p_height))
    gs = GridSpec(n_rows, 6, width_ratios=[0.1] + [1.0] * 4 + [0.1])

    # NOTE(review): only 27 days of reference data are loaded, so the fourth
    # 7-day mean presumably covers fewer days than the first three -- confirm
    # whether 28 was intended.
    ref_data = load_reference_data(date, reference_data_path, n_days=27).resample(step="7D").mean()
    ref_data = ref_data.transpose("step", "latitude", "longitude")
    # The forecast panels are drawn on the reference coordinates, i.e. the
    # forecasts are expected to be on the same grid.
    lons = ref_data.longitude.data
    lats = ref_data.latitude.data

    norm = Normalize(0, 10)
    crs = ccrs.Robinson()
    pc = ccrs.PlateCarree()
    cmap = "Blues"

    for week in range(4):
        ax = fig.add_subplot(gs[0, week + 1], projection=crs)
        # Keep the mappable so that the colorbar can be drawn even when no
        # predictions are given.
        m = ax.pcolormesh(lons, lats, ref_data.precipitation[{"step": week}].data, transform=pc, norm=norm, cmap=cmap)
        ax.set_title(f"Week {week + 1}", loc="center")
        ax.coastlines(color="grey")

    # Narrow first column holding the rotated row label.
    ax = fig.add_subplot(gs[0, 0])
    ax.set_axis_off()
    ax.text(0, 0, s="Reference", ha="center", va="center", rotation=90, fontsize=16)
    ax.set_ylim(-2, 2)

    for row, (name, path) in enumerate(prediction_paths.items()):

        prediction_file = find_file(date, path)
        predictions = xr.load_dataset(prediction_file)
        predictions = predictions.resample(step="7D").mean()

        for week in range(4):
            ax = fig.add_subplot(gs[1 + row, week + 1], projection=crs)
            m = ax.pcolormesh(lons, lats, predictions.precipitation[{"step": week}].data, transform=pc, norm=norm, cmap=cmap)
            ax.coastlines(color="grey")

        ax = fig.add_subplot(gs[1 + row, 0])
        ax.set_axis_off()
        ax.text(0, 0, s=name, ha="center", va="center", rotation=90, fontsize=16)
        ax.set_ylim(-2, 2)

    # All panels use the same norm and colormap, so any mappable can drive
    # the shared colorbar in the last column.
    cax = fig.add_subplot(gs[:, -1])
    plt.colorbar(m, cax=cax, label="Mean daily accumulated precipitation [mm]")

    date = to_datetime(date)
    fig.suptitle(date.strftime("%Y-%m-%d"), fontsize=16)
    return fig

In [None]:
plt.style.use("../../windset.mplstyle")

# Plot the reference data and all forecasts initialized on 2020-01-09 and
# save the figure next to the notebook. The dictionary keys are used as row
# labels in the figure.
fig = plot_predictions(
    np.datetime64("2020-01-09"),
    DATA_PATH / "test_data" / "daily_precip",
    {
        "ECMWF": DATA_PATH / "test_data" / "s2s_ecmwf",
        "UKMO": DATA_PATH / "test_data" / "s2s_ukmo",
        "ML Forecast": DATA_PATH / "results" / "resnext_new",
    }
)
fig.savefig("example_predictions.png", bbox_inches="tight", dpi=200)

In [None]:
from cartopy import crs as ccrs
from matplotlib.gridspec import GridSpec
from matplotlib.colors import LogNorm
from pansat.time import to_datetime

def plot_predictions_by_day(
    date: np.datetime64,
    reference_data_path: Path,
    prediction_paths: Dict[str, Path],
    n_days: int = 4
):
    """
    Plot daily precipitation of the reference data and several forecasts.

    The figure has one row for the reference data followed by one row per entry
    in ``prediction_paths`` and one column for each of the first ``n_days``
    forecast days. Only latitudes strictly between -60 and 60 degrees are
    shown. All panels share a logarithmic color scale from 0.1 to 100 and each
    forecast panel is titled with its correlation coefficient with respect to
    the reference data of the same day.

    Args:
        date: The initialization date of the forecasts to visualize.
        reference_data_path: The path containing the reference data.
        prediction_paths: A dictionary mapping dataset names to folder containing the corresponding
            forecast results.
        n_days: The number of forecast days to plot.

    Return:
        A matplotlib.Figure object containing the visualization of the forecasts.

    Raises:
        RuntimeError: Propagated from ``find_file`` when a reference or
            forecast file for the given date is missing or ambiguous.
    """
    p_height = 4
    p_width = 8
    n_rows = 1 + len(prediction_paths)

    fig = plt.figure(figsize=(n_days * p_width, n_rows * p_height))
    gs = GridSpec(n_rows, n_days + 2, width_ratios=[0.1] + [1.0] * n_days + [0.1])

    # Only the days that are actually plotted need to be loaded.
    ref_data = load_reference_data(date, reference_data_path, n_days=n_days)
    ref_data = ref_data.transpose("step", "latitude", "longitude")
    lats = ref_data.latitude.data
    lat_mask = (lats > -60) * (lats < 60)
    ref_data = ref_data[{"latitude": lat_mask}]
    # The forecast panels are drawn on the (masked) reference coordinates,
    # i.e. the forecasts are expected to be on the same grid.
    lons = ref_data.longitude.data
    lats = ref_data.latitude.data

    norm = LogNorm(1e-1, 1e2)

    crs = ccrs.Robinson()
    pc = ccrs.PlateCarree()
    cmap = "Blues"

    for day in range(n_days):
        ax = fig.add_subplot(gs[0, day + 1], projection=crs)
        # Keep the mappable so that the colorbar can be drawn even when no
        # predictions are given.
        m = ax.pcolormesh(lons, lats, ref_data.precipitation[{"step": day}].data, transform=pc, norm=norm, cmap=cmap)
        ax.set_title(f"Day {day + 1}")
        ax.coastlines(color="grey")

    # Narrow first column holding the rotated row label.
    ax = fig.add_subplot(gs[0, 0])
    ax.set_axis_off()
    ax.text(0, 0, s="Reference", ha="center", va="center", rotation=90, fontsize=16)
    ax.set_ylim(-2, 2)

    for row, (name, path) in enumerate(prediction_paths.items()):

        prediction_file = find_file(date, path)
        predictions = xr.load_dataset(prediction_file).resample(step="1D").mean()
        # Forecast files name their latitude dimension either 'latitude' or 'y'.
        if "latitude" in predictions:
            predictions = predictions[{"latitude": lat_mask}]
        else:
            predictions = predictions[{"y": lat_mask}]

        for day in range(n_days):

            ax = fig.add_subplot(gs[1 + row, day + 1], projection=crs)
            m = ax.pcolormesh(lons, lats, predictions.precipitation[{"step": day}].data, transform=pc, norm=norm, cmap=cmap)

            # Correlation over all grid cells where both fields are finite.
            precip_ref = ref_data.precipitation[{"step": day}].data
            precip_pred = predictions.precipitation[{"step": day}].data
            valid = np.isfinite(precip_ref) * np.isfinite(precip_pred)
            corr = np.corrcoef(precip_ref[valid], precip_pred[valid])[0, 1]
            ax.set_title(f"Corr. coef.: {corr}")
            ax.coastlines(color="grey")

        ax = fig.add_subplot(gs[1 + row, 0])
        ax.set_axis_off()
        ax.text(0, 0, s=name, ha="center", va="center", rotation=90, fontsize=16)
        ax.set_ylim(-2, 2)

    # All panels use the same norm and colormap, so any mappable can drive
    # the shared colorbar in the last column.
    cax = fig.add_subplot(gs[:, -1])
    plt.colorbar(m, cax=cax, label="Daily accumulated precipitation [mm]")
    return fig

## Numerical evaluation

In [None]:
from pathlib import Path
from chimp.utils import get_date
from datetime import datetime


def evaluate_forecasts(
    climatology: xr.Dataset,
    reference_data_path: Path,
    forecast_paths: Dict[str, Path],
    weekly=False
):
    """
    Calculate forecast bias, MSE and correlation coefficient and their spatial
    distributions for several provided forecast results.

    The evaluation covers all initialization dates for which a file exists in
    the reference data folder and in every file-based forecast folder.
    Initialization dates for which the reference data cannot be loaded are
    skipped. Only latitudes strictly between -60 and 60 degrees are evaluated
    and only grid cells where the reference is finite and non-negative and the
    prediction is finite contribute to the statistics. If a forecast contains a
    ``precipitation_em`` variable it is evaluated instead of ``precipitation``.

    Args:
        climatology: Not used by this function; kept for backwards
            compatibility. To evaluate a climatology, pass it as a value in
            ``forecast_paths``.
        reference_data_path: The path containing the reference data.
        forecast_paths: A dictionary mapping forecast names to either a folder
            containing the forecast files or an xarray.Dataset with a ``month``
            dimension, which is sampled at the month of each forecast step.
        weekly: If True, reference data and forecasts are averaged over 7-day
            periods and 4 weekly steps are evaluated instead of 28 daily steps.

    Return:
        A dictionary mapping each forecast name to an xarray.Dataset holding the
        per-grid-cell fields ``bias``, ``correlation_coef`` and
        ``mean_squared_error`` as well as their cos(latitude)-weighted
        aggregates ``bias_tot``, ``correlation_coef_tot`` and
        ``mean_squared_error_tot`` for every forecast step.

    Raises:
        RuntimeError: If the reference data could not be loaded for any of the
            candidate initialization dates.
    """
    n_steps = 4 if weekly else 28
    file_pattern = "*_????????_??_??.nc"

    # Restrict the evaluation to initialization dates that are present in the
    # reference data and in every forecast that is read from files.
    ref_dates = {
        get_date(ref_file.name)
        for ref_file in Path(reference_data_path).glob(file_pattern)
    }
    for path in forecast_paths.values():
        if isinstance(path, xr.Dataset):
            continue
        ref_dates &= {
            get_date(forecast_file.name)
            for forecast_file in Path(path).glob(file_pattern)
        }
    ref_dates = sorted(ref_dates)

    results = {}

    for name, path in forecast_paths.items():

        # Running sums of the first and second moments. They are allocated
        # once the size of the evaluated grid is known.
        pred_sum = None
        pred2_sum = None
        target_sum = None
        target2_sum = None
        predtarget_sum = None
        counts = None

        for date in tqdm(ref_dates):

            # find_file raises a RuntimeError when the reference file of any
            # of the required days is missing; such dates are skipped.
            try:
                ref_data = load_reference_data(date, reference_data_path)
            except RuntimeError:
                continue
            if weekly:
                ref_data = ref_data.resample(step="7D").mean()
            ref_data = ref_data.transpose("step", "latitude", "longitude")

            steps = ref_data.step.data
            lons = ref_data.longitude.data
            lats = ref_data.latitude.data
            lat_mask = (lats > -60) * (lats < 60)
            lats = lats[lat_mask]
            ref_data = ref_data[{"latitude": lat_mask}]

            times = ref_data.time + ref_data.step
            months = times.dt.month

            if isinstance(path, xr.Dataset):
                # Monthly dataset: pick the field of the month in which each
                # forecast step falls.
                predictions = path
                predictions = predictions[{"month": months - 1}]
            else:
                prediction_file = find_file(date, path)
                predictions = xr.load_dataset(prediction_file)

            if weekly:
                predictions = predictions.resample(step="7D").mean()

            # Forecast files name their latitude dimension either 'latitude'
            # or 'y'.
            if "latitude" in predictions:
                predictions = predictions[{"latitude": lat_mask}]
            else:
                predictions = predictions[{"y": lat_mask}]

            if "precipitation_em" in predictions:
                precip_pred_all = predictions.precipitation_em.data
            else:
                precip_pred_all = predictions.precipitation.data
            precip_ref_all = ref_data.precipitation.data

            if counts is None:
                shape = (n_steps, lats.size, lons.size)
                pred_sum = np.zeros(shape)
                pred2_sum = np.zeros(shape)
                target_sum = np.zeros(shape)
                target2_sum = np.zeros(shape)
                predtarget_sum = np.zeros(shape)
                counts = np.zeros(shape)

            for step in range(n_steps):
                target = precip_ref_all[step]
                pred = precip_pred_all[step]

                valid = np.isfinite(target) * (target >= 0.0) * np.isfinite(pred)

                pred_sum[step] += np.where(valid, pred, 0.0)
                pred2_sum[step] += np.where(valid, pred ** 2, 0.0)
                target_sum[step] += np.where(valid, target, 0.0)
                target2_sum[step] += np.where(valid, target ** 2, 0.0)
                predtarget_sum[step] += np.where(valid, pred * target, 0.0)
                counts[step] += valid.astype("float32")

        if counts is None:
            raise RuntimeError(
                f"Couldn't load reference data for any initialization date to evaluate '{name}'."
            )

        # Per-grid-cell statistics from the accumulated moments. Cells without
        # any valid sample have a count of zero and evaluate to NaN.
        pred_mean = pred_sum / counts
        target_mean = target_sum / counts
        pred_var = pred2_sum / counts - pred_mean ** 2
        target_var = target2_sum / counts - target_mean ** 2
        # mean((pred - target)^2) expanded in terms of the accumulated sums.
        mse = (pred2_sum - 2.0 * predtarget_sum + target2_sum) / counts
        corr = (predtarget_sum / counts - pred_mean * target_mean) / np.sqrt(pred_var * target_var)
        bias = pred_mean - target_mean

        # Aggregate statistics per step, weighting every sample by the cosine
        # of its latitude.
        weights = np.cos(np.deg2rad(lats))[None, :, None]

        step_counts = (weights * counts).sum(axis=(1, 2))

        pred_mean_tot = (weights * pred_sum).sum(axis=(1, 2)) / step_counts
        target_mean_tot = (weights * target_sum).sum(axis=(1, 2)) / step_counts
        pred_var_tot = (weights * pred2_sum).sum(axis=(1, 2)) / step_counts - pred_mean_tot ** 2
        target_var_tot = (weights * target2_sum).sum(axis=(1, 2)) / step_counts - target_mean_tot ** 2
        mse_tot = (weights * (pred2_sum - 2.0 * predtarget_sum + target2_sum)).sum(axis=(1, 2)) / step_counts
        corr_tot = ((weights * predtarget_sum).sum(axis=(1, 2)) / step_counts - pred_mean_tot * target_mean_tot) / np.sqrt(pred_var_tot * target_var_tot)
        bias_tot = pred_mean_tot - target_mean_tot

        # Coordinates are those of the last evaluated initialization date.
        results[name] = xr.Dataset({
            "latitude": (("latitude",), lats),
            "longitude": (("longitude",), lons),
            "step": (("step",), steps),
            "bias": (("step", "latitude", "longitude"), bias),
            "correlation_coef": (("step", "latitude", "longitude"), corr),
            "mean_squared_error": (("step", "latitude", "longitude"), mse),
            "bias_tot": (("step",), bias_tot),
            "correlation_coef_tot": (("step",), corr_tot),
            "mean_squared_error_tot": (("step",), mse_tot)
        })

    return results


In [None]:
# Echo the dataset root used by the evaluation below.
DATA_PATH

In [None]:
# Evaluate the monthly climatology, the two S2S baselines and the ML forecast
# against the daily reference data of the test period. The dictionary keys
# are the names under which the metrics are looked up in the plotting cell
# below. With weekly=False the metrics are computed per daily forecast step.
results = evaluate_forecasts(
    climatology,
    DATA_PATH / "test_data" / "daily_precip",
    {
        "Climatology": climatology,
        "ECMWF": DATA_PATH / "test_data" / "s2s_ecmwf",
        "UKMO": DATA_PATH / "test_data" / "s2s_ukmo",
        "resnext-16": DATA_PATH / "results" / "resnext_new"
    },
    weekly=False
)

In [None]:
import matplotlib as mpl
from matplotlib.gridspec import GridSpec
mpl.style.use("../../windset.mplstyle")

# Three metric panels plus a narrow fourth column holding the shared legend.
fig = plt.figure(figsize=(16, 4))
gs = GridSpec(1, 4, width_ratios=[1.0, 1.0, 1.0, 0.3], wspace=0.3)
# Lead times in days matching the 28 daily steps of the evaluation results.
lead_time = np.arange(1, 29)

# (a) Bias. The ECMWF and UKMO curves are not shown in this panel. The ML
# forecast color is pinned to the third cycle color so that it matches
# panels (b) and (c) and the shared legend, in which ECMWF and UKMO take the
# first two colors.
ax = fig.add_subplot(gs[0, 0])
ax.plot(lead_time, results["resnext-16"].bias_tot, c="C2")
ax.plot(lead_time, results["Climatology"].bias_tot, c="k", ls="--")
ax.set_xlabel("Lead time [d]")
ax.set_ylabel("Bias [mm]")
ax.set_title("(a) Bias", loc="left")
ax.set_ylim(-0.5, 0.5)
ax.set_xlim(0, 28)

# (b) Correlation coefficient.
ax = fig.add_subplot(gs[0, 1])
ax.plot(lead_time, results["ECMWF"].correlation_coef_tot, label="ecmwf")
ax.plot(lead_time, results["UKMO"].correlation_coef_tot, label="UK")
ax.plot(lead_time, results["resnext-16"].correlation_coef_tot, label="Neural network forecast")
ax.plot(lead_time, results["Climatology"].correlation_coef_tot, label="Climatology", c="k", ls="--")
ax.set_xlabel("Lead time [d]")
ax.set_ylabel("Correlation coef.",)
ax.set_title("(b) Correlation coef.", loc="left")
ax.set_ylim(0.0, 1.0)
ax.set_xlim(0, 28)

# (c) Mean squared error. The line handles of this panel are collected for
# the legend shared by all panels.
ax = fig.add_subplot(gs[0, 2])
handles = []
handles += ax.plot(lead_time, results["Climatology"].mean_squared_error_tot, label="Monthly climatology", ls="--", c="k")
handles += ax.plot(lead_time, results["ECMWF"].mean_squared_error_tot, label="ECMWF")
handles += ax.plot(lead_time, results["UKMO"].mean_squared_error_tot, label="UKMO")
handles += ax.plot(lead_time, results["resnext-16"].mean_squared_error_tot, label="ML forecast")
ax.set_xlabel("Lead time [d]")
ax.set_ylabel("Mean squared error [mm$^2$]")
ax.set_title("(c) Mean squared error", loc="left")
ax.set_xlim(0, 28)
ax.set_ylim(50, 130)

# Legend in its own, otherwise empty, axes.
ax = fig.add_subplot(gs[0, 3])
ax.set_axis_off()
ax.legend(handles=handles, loc="center")

fig.savefig("ltpf_baseline_metrics.pdf", bbox_inches="tight")