## DYffusion `Forecaster`, $F_{\theta}$, Evaluation.
* The aim of this notebook is to evaluate trained `DYffusion` models using an N-member ensemble of sampled forecasts (N = `num_ensemble`).

* The notebook evaluates the entire test set on `LPIPS`, `MSE`, `CRPS`, `SSR` and `CSI`.

* Each `DYffusion` model is also evaluated per timestep (lead time) to get roll-out metrics (`CSI` and `LPIPS`).

In [1]:
import os
from pathlib import Path
from typing import Any, Dict, List, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from IPython.display import clear_output
from matplotlib.colors import ListedColormap
from scipy import io
from torch.nn import MSELoss
from torch.utils.data import DataLoader
from torchmetrics.image import LearnedPerceptualImagePatchSimilarity as LPIPS
from torchmetrics.regression import CriticalSuccessIndex
from tqdm import tqdm

from rainnow.src.dyffusion.utilities.evaluation import (
    evaluate_ensemble_crps,
    evaluate_ensemble_spread_skill_ratio,
)
from rainnow.src.normalise import PreProcess
from rainnow.src.utilities.loading import (
    get_model_ckpt_path,
    load_imerg_datamodule_from_config,
    load_model_state_dict,
)
from rainnow.src.utilities.utils import get_device
from rainnow.src.utilities.instantiators import instantiate_multi_horizon_dyffusion_model

  register_pytree_node(


#### `helpers.`

In [None]:
# ** DIR helpers **
# NOTE(review): this (unexecuted, In [None]) cell is duplicated by the config cell further down,
# which re-assigns the same names when it is run.
BASE_PATH = "/teamspace/studios/this_studio"
_RAINNOW_ROOT = f"{BASE_PATH}/DYffcast/rainnow"

CKPT_BASE_PATH = f"{_RAINNOW_ROOT}/results/"
CONFIGS_BASE_PATH = f"{_RAINNOW_ROOT}/src/dyffusion/configs/"

CKPT_DIR, CKPT_CFG_NAME = "checkpoints", "hparams.yaml"
DATAMODULE_CONFIG_NAME = "imerg_precipitation.yaml"
# True -> load last.ckpt; False -> load the "best model" ckpt (the other file in the checkpoint folder).
GET_LAST = False

# ** Dataloader Params **
# NOTE: BATCH_SIZE is re-assigned from the datamodule's hparams once the datamodule is loaded.
BATCH_SIZE, NUM_WORKERS = 6, 0
INPUT_SEQUENCE_LENGTH, OUTPUT_SEQUENCE_LENGTH = 4, 1

# ** plotting helpers **
# (relative alternative path: "../../src/utilities/cmaps/colormap.mat")
cmap = io.loadmat(f"{_RAINNOW_ROOT}/src/utilities/cmaps/colormap.mat")
rain_cmap = ListedColormap(cmap["Cmap_rain"])
global_params = {"font.size": 8}  # optionally also "font.family": "Times New Roman".
plt_params = {"wspace": 0.1, "hspace": 0.15}
ylabel_params = {"ha": "right", "va": "bottom", "labelpad": 1, "fontsize": 7.5}

# ** get device **
device = get_device()

#### `helpers.`

In [3]:
# ** DIR helpers **
BASE_PATH = "/teamspace/studios/this_studio"
_RAINNOW_ROOT = f"{BASE_PATH}/DYffcast/rainnow"

CKPT_BASE_PATH = f"{_RAINNOW_ROOT}/results/"
CONFIGS_BASE_PATH = f"{_RAINNOW_ROOT}/src/dyffusion/configs/"

CKPT_DIR, CKPT_CFG_NAME = "checkpoints", "hparams.yaml"
DATAMODULE_CONFIG_NAME = "imerg_precipitation.yaml"
# True -> load last.ckpt; False -> load the "best model" ckpt (the other file in the checkpoint folder).
GET_LAST = False

# ** Dataloader Params **
# NOTE: BATCH_SIZE is re-assigned from the datamodule's hparams once the datamodule is loaded.
BATCH_SIZE, NUM_WORKERS = 6, 0

# ** plotting helpers **
# (relative alternative path: "../../src/utilities/cmaps/colormap.mat")
cmap = io.loadmat(f"{_RAINNOW_ROOT}/src/utilities/cmaps/colormap.mat")
rain_cmap = ListedColormap(cmap["Cmap_rain"])
global_params = {"font.size": 8}  # optionally also "font.family": "Times New Roman".
plt_params = {"wspace": 0.1, "hspace": 0.15}
ylabel_params = {"ha": "right", "va": "bottom", "labelpad": 1, "fontsize": 7.5}

# ** get device **
device = get_device()

#### `Instantiate + Load in the datamodule.`

In [4]:
# ** load the IMERG datamodule from its config, with evaluation overrides **
datamodule = load_imerg_datamodule_from_config(
    cfg_base_path=CONFIGS_BASE_PATH,
    cfg_name=DATAMODULE_CONFIG_NAME,
    overrides={
        "boxes": ["0,0", "1,0", "2,0", "2,1"],
        "window": 1,
        "horizon": 8,
        "prediction_horizon": 8,
        "sequence_dt": 1,
    },
)

datamodule.setup("test")

# set up the test dataloader.
# the batch size is taken from the datamodule's hparams (this overrides BATCH_SIZE from the config cell);
# shuffle=False keeps the sample order deterministic across evaluation runs.
BATCH_SIZE = datamodule.hparams["batch_size"]
test_dataloader = DataLoader(
    datamodule._data_test,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,  # configured in the config cell (previously defined but never used).
)

[2024-08-29 15:33:02][imerg_precipitation.py][INFO] --> training, validation & test using 4 (i, j) boxes: ['0,0', '1,0', '2,0', '2,1'].
[2024-08-29 15:33:02][imerg_precipitation.py][INFO] --> test data split: [202307010000, 202401010000]


[2024-08-29 15:33:08][torch_datasets.py][INFO] --> creating TEST tensor dataset.
[2024-08-29 15:33:08][normalise.py][INFO] --> pprocessing w/ percentiles (1st, 99th): [0.0, 5.670000076293945],  (min, max): [0.0, 3.23434630590838]
[2024-08-29 15:33:09][abstract_datamodule.py][INFO] -->  Dataset test size: 984


#### `Instantiate the preprocessor object`

In [5]:
# ** instantiate the preprocessor obj from the datamodule's normalisation stats **
_norm_hparams = datamodule.normalization_hparams
pprocessor = PreProcess(
    minmax=_norm_hparams["min_max"],
    percentiles=_norm_hparams["percentiles"],
)

[2024-08-29 15:33:09][normalise.py][INFO] --> pprocessing w/ percentiles (1st, 99th): [0.0, 5.670000076293945],  (min, max): [0.0, 3.23434630590838]


#### `Set up Eval metrics.`

In [6]:
# ** instantiate the evaluation metrics **
# normalize=True: LPIPS takes inputs in [0, 1] and rescales them to [-1, 1] internally.
lpips = LPIPS(reduction="mean", normalize=True).to(device)
mse = MSELoss(reduction="mean")

# CSI thresholds on the un-normalised scale: low / mid / heavy rain.
csi_nodes = [2, 10, 18]
# the metrics are applied to preprocessed data, so the thresholds must be mapped onto the same scale.
# See NB: imerg_rainfall_classes.ipynb for rain classes + distributions.
normed_csi_nodes = pprocessor.apply_preprocessing(np.array(csi_nodes))
csi2, csi10, csi18 = (
    CriticalSuccessIndex(threshold=normed_csi_nodes[pos]).to(device) for pos in (0, 1, -1)
)

In [7]:
# ** human-readable descriptions of the trained DYffusion checkpoints, keyed by checkpoint id **
ckpt_dict = dict(
    [
        ("dyffusion-daftvdwg", "LCB(α=.6) | ALL boxes | Full Sequence | 20 Epochs | lr=3e-4"),
        ("dyffusion-fyxpjp65", "L1 | ALL boxes | Full Sequence | 20 Epochs | lr=3e-4"),
    ]
)

#### `Evaluation (Sequence & t preds)`

In [1]:
# set eval params.
horizon = datamodule.hparams.horizon
num_ensemble = 10
# NOTE(review): ensemble sampling is stochastic and no seed is set anywhere in this notebook;
# call torch.manual_seed(...) before sampling if reproducible ensembles are needed.

# For every checkpoint: (1) sample `num_ensemble` forecasts per test input, (2) compute
# sequence-level metrics and export them to a .csv, (3) compute per-timestep CSI + LPIPS
# (roll-out metrics) and export them to a .csv.
eval_sequence_metrics, eval_metrics_per_t = {}, {}
with torch.no_grad():
    for ckpt_id in [
        "dyffusion-daftvdwg",
        # "dyffusion-fyxpjp65",
    ]:
        # ** instantiate a DYffusion model **
        model = instantiate_multi_horizon_dyffusion_model(
            ckpt_id=ckpt_id,
            ckpt_base_path=CKPT_BASE_PATH,
            diffusion_config_overrides={
                # NOTE(review): hardcoded and not derived from BASE_PATH ("irp-ds423" vs "DYffcast");
                # confirm this points at the interpolator checkpoints.
                "interpolator_checkpoint_base_path": "/teamspace/studios/this_studio/irp-ds423/rainnow/results/"
                # "interpolator_checkpoint_base_path": "/Users/ds423/git_uni/irp-ds423/rainnow/results/interpolation_experiments"
            },
        )
        # load in model checkpoint (last vs. best is controlled by GET_LAST in the config cell).
        ckpt_path = get_model_ckpt_path(ckpt_id=ckpt_id, ckpt_base_path=CKPT_BASE_PATH, get_last=GET_LAST)
        state_dict = load_model_state_dict(ckpt_path=ckpt_path, device=device)
        # load in weights and biases from the checkpoint model.
        model._model.load_state_dict(state_dict)
        # set model into eval mode.
        model.eval()
        model._model.eval()
        # a Tanh output layer yields values in [-1, 1]; such preds are mapped to [0, 1] below.
        # the check is loop-invariant, so it is done once per checkpoint.
        rescale_tanh_output = isinstance(model.model_final_layer_activation, nn.Tanh)

        # -----------------------------------------------------------------------------------
        # ** 1. create eval dataset. Get preds and targets dims: (N, S, C, H, W) **
        # per batch, store {tn: mean(ens preds at tn)} and {tn: ens preds at tn}.
        mu_all_preds, ens_all_preds, eval_targets_list = [], [], []
        for X in tqdm(test_dataloader, total=len(test_dataloader), desc=f"Evaluating model {ckpt_id}"):
            inputs = X["dynamics"].clone()
            # get initial condition (first frame of the sequence).
            x0 = inputs[:, 0, :, :, :]

            # the targets are the remaining frames; store them.
            targets = inputs[:, 1:, :, :, :]
            eval_targets_list.append(targets)

            # make n ensemble sampling predictions for each timestep, t.
            t_ens_preds = {f"t{t+1}_preds": [] for t in range(horizon)}
            for n in range(num_ensemble):
                print(f"ensemble {n}")
                preds = model.model.sample(initial_condition=x0)
                for k, pred in preds.items():
                    # handle data range if required.
                    if rescale_tanh_output:
                        pred = (pred + 1) / 2
                    # store on CPU: bounds accelerator memory over the whole test set and allows the
                    # ensemble tensors to be converted with .numpy() for the SSR computation below.
                    t_ens_preds[k].append(pred.cpu())
            # concatenate the ensemble members along dim 1 and take their mean -> {t: mean_prediction}.
            mu_preds, ens_preds = {}, {}
            for k, v in t_ens_preds.items():
                ens_preds[k] = torch.cat(v, dim=1)
                mu_preds[k] = ens_preds[k].mean(dim=1, keepdim=True)
            mu_all_preds.append(mu_preds)
            ens_all_preds.append(ens_preds)

            # clear the output from the cell after each batch; DYffusion's own tqdm clogs the UI.
            clear_output(wait=True)
        clear_output(wait=True)

        # assemble into a single predictions dataset for easier eval.
        # dataset w/ dims: (N, S, C, H, W). S is the number of prediction timesteps.
        eval_preds_list = [
            torch.cat(list(mu_t_preds.values()), dim=1).unsqueeze(-3) for mu_t_preds in mu_all_preds
        ]
        eval_preds = torch.cat(eval_preds_list, dim=0)
        eval_targets = torch.cat(eval_targets_list, dim=0)
        assert eval_preds.size() == eval_targets.size()

        # also get a dataset for the ensemble metrics: dim 1 = ensemble members, dim 2 = timesteps
        # (see the asserts below). assumes each sampled pred is (B, 1, H, W) — TODO confirm.
        eval_ens_preds_list = [
            torch.stack(list(ens_t_preds.values())).permute(1, 2, 0, 3, 4).unsqueeze(-3)
            for ens_t_preds in ens_all_preds
        ]
        eval_ens_preds = torch.cat(eval_ens_preds_list, dim=0)
        assert eval_ens_preds.size(1) == num_ensemble
        assert eval_ens_preds.size(2) == horizon
        # -----------------------------------------------------------------------------------

        # -----------------------------------------------------------------------------------
        # ** 2. get sequence metrics. Compare pred sequence to target sequence. **
        n_samples = eval_preds.size(0)
        l2_score = 0
        lpips_score = 0
        csi2_score = 0  # low rain.
        csi10_score = 0  # mid rain.
        csi18_score = 0  # heavy rain.
        # probabilistic metrics.
        crps_score = 0
        raw_crps_score = 0
        ssr_score = 0
        for idx in range(n_samples):
            # get preds and targets.
            _pred = eval_preds[idx, ...]
            _ens_preds = eval_ens_preds[idx, ...]
            _target = eval_targets[idx, ...]
            # move the deterministic pred/target to the metric device once, not once per metric.
            pred_d, target_d = _pred.to(device), _target.to(device)

            # get metrics.
            l2_score += mse(pred_d, target_d)
            # lpips score. Inputs need to have 3 channels (the single channel is repeated), clamped to [0, 1].
            lpips_score += lpips(
                torch.clamp(pred_d.expand(-1, 3, -1, -1), 0, 1),
                torch.clamp(target_d.expand(-1, 3, -1, -1), 0, 1),
            )
            # csi score at different thresholds.
            csi2_score += csi2(pred_d, target_d)
            csi10_score += csi10(pred_d, target_d)
            csi18_score += csi18(pred_d, target_d)
            # probabilistic metrics (computed on the CPU tensors).
            crps_score += evaluate_ensemble_crps(ensemble_predictions=_ens_preds, targets=_target)
            raw_crps_score += evaluate_ensemble_crps(
                ensemble_predictions=pprocessor.reverse_processing(_ens_preds),
                targets=pprocessor.reverse_processing(_target),
            )
            ssr_score += evaluate_ensemble_spread_skill_ratio(
                ensemble_predictions=_ens_preds.cpu().numpy(), targets=_target.cpu().numpy()
            )

        eval_sequence_metrics[ckpt_id] = {
            "MSE": l2_score.item() / n_samples,
            "lpips": lpips_score.item() / n_samples,
            "csi2": csi2_score.item() / n_samples,
            "csi10": csi10_score.item() / n_samples,
            "csi18": csi18_score.item() / n_samples,
            "crps": crps_score / n_samples,
            "raw_crps": raw_crps_score / n_samples,
            "ssr": ssr_score / n_samples,
        }
        # -----------------------------------------------------------------------------------

        # create df, format it and export it to a .csv (re-written after each checkpoint).
        df_results = pd.DataFrame(eval_sequence_metrics).T
        df_results.to_csv("dyffusion_eval_metrics_v2.csv")

        # -----------------------------------------------------------------------------------
        # ** 3. get timestep metrics. Compare pred t sequence to target t sequence**
        n_timesteps = eval_preds.size(1)
        # csi stores.
        csi2_score_t = torch.zeros(n_timesteps, device=device)
        csi10_score_t = torch.zeros(n_timesteps, device=device)
        csi18_score_t = torch.zeros(n_timesteps, device=device)
        # perceptual loss stores.
        lpips_score_t = torch.zeros(n_timesteps, device=device)

        for idx in range(n_samples):
            # get preds and targets on the metric device.
            pred_d = eval_preds[idx, ...].to(device)
            target_d = eval_targets[idx, ...].to(device)

            for t in range(n_timesteps):
                # (preds, target) argument order, consistent with the sequence metrics above.
                csi2_score_t[t] += csi2(pred_d[t], target_d[t])
                csi10_score_t[t] += csi10(pred_d[t], target_d[t])
                csi18_score_t[t] += csi18(pred_d[t], target_d[t])
                lpips_score_t[t] += lpips(
                    torch.clamp(pred_d[t].expand(1, 3, -1, -1), 0, 1),
                    torch.clamp(target_d[t].expand(1, 3, -1, -1), 0, 1),
                )
        # normalise the scores (mean over test samples).
        eval_metrics_per_t[ckpt_id] = {
            "csi2_t": csi2_score_t / n_samples,
            "csi10_t": csi10_score_t / n_samples,
            "csi18_t": csi18_score_t / n_samples,
            "lpips_t": lpips_score_t / n_samples,
        }

        # create df, format it and export it to a .csv (re-written after each checkpoint).
        df_results_per_t = pd.DataFrame(eval_metrics_per_t).T
        df_results_per_t.to_csv("dyffusion_eval_metrics_per_time_v2.csv")
        # -----------------------------------------------------------------------------------

In [9]:
# sequence-level metrics, one row per evaluated checkpoint.
# NOTE(review): the saved output below is for dyffusion-fyxpjp65, while the evaluation cell currently
# only evaluates dyffusion-daftvdwg (fyxpjp65 is commented out); re-run to refresh the output.
df_results

Unnamed: 0,MSE,crps,csi10,csi18,csi2,lpips,raw_crps,ssr
dyffusion-fyxpjp65,0.001498,0.009707,0.067546,0.023004,0.227119,0.32259,0.275957,0.045327


#### `Get per-timestep (t) metrics for CSI + LPIPS`

In [10]:
# ** collect the per-timestep CSI + LPIPS scores into one DataFrame per metric **
# rows: evaluated checkpoint ids; columns: t1..t{horizon}.
t_columns = [f"t{i+1}" for i in range(horizon)]

# per-checkpoint dict of single-row DataFrames, one per metric (kept for interactive inspection).
df_all = {
    model_name: {
        metric: pd.DataFrame(
            data=[[float(v) for v in values]], columns=[f"t{j+1}" for j in range(len(values))]
        )
        for metric, values in metrics.items()
    }
    for model_name, metrics in eval_metrics_per_t.items()
}


def _per_t_frame(metric_key: str) -> pd.DataFrame:
    """Build a float DataFrame for one per-timestep metric.

    Args:
        metric_key: key into each checkpoint's entry of `eval_metrics_per_t` (e.g. "csi2_t").

    Returns:
        DataFrame with one row per checkpoint id and one column per timestep (`t_columns`).
        Built in one shot rather than by enlarging an empty frame row by row.
    """
    return pd.DataFrame(
        data=[[float(v) for v in metrics[metric_key]] for metrics in eval_metrics_per_t.values()],
        index=list(eval_metrics_per_t.keys()),
        columns=t_columns,
        dtype=float,
    )


df_csi2 = _per_t_frame("csi2_t")
df_csi10 = _per_t_frame("csi10_t")
df_csi18 = _per_t_frame("csi18_t")
df_lpips = _per_t_frame("lpips_t")

In [11]:
# per-timestep CSI at the low-rain threshold (csi_nodes[0] = 2 before normalisation); rows: checkpoints.
df_csi2

Unnamed: 0,t1,t2,t3,t4,t5,t6,t7,t8
dyffusion-fyxpjp65,0.509026,0.33503,0.259788,0.199946,0.15968,0.133388,0.114014,0.095876


In [12]:
# per-timestep CSI at the mid-rain threshold (csi_nodes[1] = 10 before normalisation); rows: checkpoints.
df_csi10

Unnamed: 0,t1,t2,t3,t4,t5,t6,t7,t8
dyffusion-fyxpjp65,0.250203,0.089917,0.049971,0.027504,0.017135,0.012819,0.010001,0.005476


In [13]:
# per-timestep CSI at the heavy-rain threshold (csi_nodes[-1] = 18 before normalisation); rows: checkpoints.
df_csi18

Unnamed: 0,t1,t2,t3,t4,t5,t6,t7,t8
dyffusion-fyxpjp65,0.098247,0.020488,0.007073,0.002692,0.001974,0.001286,0.000923,0.000232


In [14]:
# per-timestep LPIPS (lower is more perceptually similar); rows: checkpoints.
df_lpips

Unnamed: 0,t1,t2,t3,t4,t5,t6,t7,t8
dyffusion-fyxpjp65,0.129959,0.244934,0.304729,0.348849,0.373035,0.386031,0.39239,0.403022


### END OF SCRIPT.