## `ConvLSTM` Evaluation 101
* The aim of this notebook is to evaluate trained ConvLSTM models on the IMERG test split.
* Two evaluations are run: (1) metrics over the entire auto-regressively predicted sequence (MSE, LPIPS, and CSI at three rain thresholds), and (2) CSI and LPIPS as a function of lead time (t1 ... t8).

In [1]:
import os
from pathlib import Path
from typing import Any, Dict, List, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import xarray as xr
import xskillscore as xs
from livelossplot import PlotLosses
from matplotlib.colors import ListedColormap
from scipy import io
from torch.nn import L1Loss, MSELoss
from torch.utils.data import DataLoader
from torchmetrics.image import LearnedPerceptualImagePatchSimilarity as LPIPS
from torchmetrics.regression import CriticalSuccessIndex
from tqdm import tqdm

from rainnow.src.conv_lstm_utils import (
    IMERGDataset,
    create_eval_loader,
    plot_predicted_sequence,
    save_checkpoint,
    train,
    validate,
)
from rainnow.src.loss import CBLoss, LPIPSMSELoss
from rainnow.src.models.conv_lstm import ConvLSTMModel
from rainnow.src.normalise import PreProcess
from rainnow.src.utilities.loading import load_imerg_datamodule_from_config
from rainnow.src.utilities.utils import (
    get_device,
    transform_0_1_to_minus1_1,
    transform_minus1_1_to_0_1,
)

#### `helpers`

In [2]:
# ** plotting helpers **
# rain colormap stored in the repo (path is relative to the notebook's working directory).
cmap = io.loadmat("../../src/utilities/cmaps/colormap.mat")
rain_cmap = ListedColormap(cmap["Cmap_rain"])

# shared matplotlib kwargs: rcParams-style font settings, subplot spacing, and y-label styling.
global_params = {"font.size": 8}
plt_params = dict(wspace=0.1, hspace=0.15)
ylabel_params = dict(ha="right", va="bottom", labelpad=1, fontsize=7.5)

# ** get device **
device = get_device()

No GPU available! (device = cpu)


In [3]:
# ** DIR helpers **
# The rainnow package root is read from $RAINNOW_ROOT if set. Otherwise it is assumed to be two levels above
# the notebook's working directory (TODO confirm if the notebook is moved). This replaces per-machine
# hard-coded absolute paths (local vs. teamspace), so the cell no longer needs editing per environment.
# Paths are resolved to absolute strings with a trailing separator, the same form as the old values.
RAINNOW_ROOT = Path(os.environ.get("RAINNOW_ROOT", os.path.join("..", ".."))).resolve()
CKPT_BASE_PATH = os.path.join(RAINNOW_ROOT, "results", "")
CONFIGS_BASE_PATH = os.path.join(RAINNOW_ROOT, "src", "dyffusion", "configs", "")

CKPT_DIR = "checkpoints"
# NOTE(review): not referenced by the evaluation cells in this notebook.
CKPT_CFG_NAME = "hparams.yaml"
DATAMODULE_CONFIG_NAME = "imerg_precipitation.yaml"
# whether or not to get last.ckpt or to get the "best model" ckpt (the other one in the folder).
# NOTE(review): not referenced by the evaluation cells in this notebook, which load <ckpt_id>.pt directly.
GET_LAST = False

# ** Dataloader Params **
BATCH_SIZE = 12
NUM_WORKERS = 0

# ** Sequence Params **: number of context frames fed to the model, and frames predicted per model call.
INPUT_SEQUENCE_LENGTH = 4
OUTPUT_SEQUENCE_LENGTH = 1

#### `Instantiate + Load in the datamodule`

In [4]:
# ** load the IMERG datamodule from its config + set up the test split **
# config overrides: which (i, j) boxes to use, plus the window / horizon / dt settings.
imerg_overrides = dict(
    boxes=["0,0", "1,0", "2,0", "2,1"],
    window=1,
    horizon=8,
    prediction_horizon=8,
    sequence_dt=1,
)
datamodule = load_imerg_datamodule_from_config(
    cfg_base_path=CONFIGS_BASE_PATH, cfg_name=DATAMODULE_CONFIG_NAME, overrides=imerg_overrides
)
datamodule.setup("test")

[2024-08-29 10:49:29][imerg_precipitation.py][INFO] --> training, validation & test using 4 (i, j) boxes: ['0,0', '1,0', '2,0', '2,1'].
[2024-08-29 10:49:29][imerg_precipitation.py][INFO] --> test data split: [202307010000, 202401010000]


[2024-08-29 10:49:35][torch_datasets.py][INFO] --> creating TEST tensor dataset.
[2024-08-29 10:49:36][normalise.py][INFO] --> pprocessing w/ percentiles (1st, 99th): [0.0, 5.670000076293945],  (min, max): [0.0, 3.23434630590838]
[2024-08-29 10:49:36][abstract_datamodule.py][INFO] -->  Dataset test size: 984


#### `Create the test_dataset + test_loader`

In [5]:
# create the test dataset (presumably windows of INPUT_SEQUENCE_LENGTH input frames + OUTPUT_SEQUENCE_LENGTH
# target frames from the datamodule's test split) and an unshuffled loader over it.
test_dataset = IMERGDataset(
    datamodule,
    "test",
    sequence_length=INPUT_SEQUENCE_LENGTH,
    target_length=OUTPUT_SEQUENCE_LENGTH,
)
loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, shuffle=False)
test_loader = DataLoader(dataset=test_dataset, **loader_kwargs)

#### `Instantiate the preprocessor object`

In [6]:
# ** instantiate the preprocesser obj **
# built from the datamodule's own normalisation stats, so raw values (e.g. CSI thresholds) can be mapped
# onto the same scale as the data.
norm_hparams = datamodule.normalization_hparams
pprocessor = PreProcess(minmax=norm_hparams["min_max"], percentiles=norm_hparams["percentiles"])

[2024-08-29 10:49:36][normalise.py][INFO] --> pprocessing w/ percentiles (1st, 99th): [0.0, 5.670000076293945],  (min, max): [0.0, 3.23434630590838]


#### `Get Metrics`

In [7]:
# instantiate metrics.
# LPIPS w/ normalize=True: inputs are expected in [0, 1] and are rescaled to [-1, 1] internally.
lpips = LPIPS(reduction="mean", normalize=True).to(device)
mse = MSELoss(reduction="mean")

# CSI thresholds (low / mid / heavy rain) in the raw data's units.
# See NB: imerg_rainfall_classes.ipynb for rain classes + distributions.
csi_nodes = [2, 10, 18]
# the data is preprocessed, so the thresholds must be mapped onto the same scale before use.
normed_csi_nodes = pprocessor.apply_preprocessing(np.array(csi_nodes))
csi2, csi10, csi18 = (CriticalSuccessIndex(threshold=thr).to(device) for thr in normed_csi_nodes)

#### `Evaluation Metrics (entire predictions)`

In [8]:
# ** create the eval dataloader **
# Presumably each eval sample is INPUT_SEQUENCE_LENGTH context frames + `horizon` target frames
# (4 + 8 = 12, which matches the logged sample dims). The horizon of 8 mirrors the datamodule override.
# The context length comes from the config constant, so it cannot drift from the model's input length.
eval_loader, _ = create_eval_loader(
    data_loader=test_loader,
    horizon=8,
    input_sequence_length=INPUT_SEQUENCE_LENGTH,
    img_dims=(128, 128),  # (H, W) of each frame.
)

** eval loader (INFO) **
Num samples = 1228 w/ dims: torch.Size([12, 1, 128, 128])



In [9]:
# ConvLSTM params (make sure that they match up with the model checkpoint).
NUM_LAYERS = 2
HIDDEN_CHANNELS = [128] * NUM_LAYERS  # presumably one hidden width per ConvLSTM layer.
KERNEL_SIZE = (5, 5)
INPUT_DIMS = (1, 128, 128)  # C, H, W
OUTPUT_CHANNELS = 1
CELL_DROPOUT = 0.15
OUTPUT_ACTIVATION = nn.Tanh()

In [12]:
# ** evaluate 101: whole-sequence metrics per checkpoint **
# For each ckpt: rebuild the ConvLSTM, load its weights, and roll it out auto-regressively over the full
# target horizon for every eval sample. The whole predicted sequence is then scored against the whole
# target sequence.
# Metrics are accumulated sample-by-sample during the rollout, rather than first storing every
# [target, predictions] pair, so memory no longer grows with the number of eval samples.
# NOTE(review): per-sample CSI scores are averaged with equal weight. torchmetrics' CSI may return 0 when
#   hits + misses + false alarms == 0 (no pixel above the threshold in either field). If so, samples without
#   rain above a threshold pull the mean down — confirm; aggregating the contingency counts may be preferable.
eval_metrics = {}
for ckpt_id in [
    "convlstm-abcd1234",  # original model ckpt: conv_lstm_hc_128128_ks_5_oa_Tanh.pt
    "convlstm-a8kwo8jx",
]:
    # model ckpt path: <CKPT_BASE_PATH>/<ckpt_id>/<CKPT_DIR>/<ckpt_id>.pt
    ckpt_id_path = Path(CKPT_BASE_PATH) / ckpt_id / CKPT_DIR / f"{ckpt_id}.pt"
    print(f"Loading model ckpt from {ckpt_id_path}.")

    # instantiate a new ConvLSTM model (params must match the ones the ckpt was trained with).
    model = ConvLSTMModel(
        input_sequence_length=INPUT_SEQUENCE_LENGTH,
        output_sequence_length=OUTPUT_SEQUENCE_LENGTH,
        input_dims=INPUT_DIMS,
        hidden_channels=HIDDEN_CHANNELS,
        output_channels=OUTPUT_CHANNELS,
        num_layers=NUM_LAYERS,
        kernel_size=KERNEL_SIZE,
        output_activation=OUTPUT_ACTIVATION,
        apply_batchnorm=True,
        cell_dropout=CELL_DROPOUT,
        bias=True,
        device=device,
    ).to(device)

    # load in the checkpoint + set to eval() mode.
    ckpt = torch.load(ckpt_id_path, map_location=torch.device(device))
    model.load_state_dict(state_dict=ckpt["model_state_dict"])
    model.eval()

    # a Tanh head outputs [-1, 1]; such predictions are mapped to [0, 1] before scoring / feeding back.
    rescale_output = isinstance(model.output_activation, nn.Tanh)

    # running sums of each per-sample metric (reset for each ckpt id).
    totals = {
        name: torch.zeros((), device=device)
        for name in ["MSE", "lpips", "csi2", "csi10", "csi18"]
    }
    n_samples = 0
    with torch.no_grad():
        for X, target in tqdm(eval_loader, total=len(eval_loader), desc=f"Evaluating model {ckpt_id}"):
            targets = target.to(device)
            _input = X.clone().unsqueeze(0).to(device)  # add a leading batch dim of 1.
            step_preds = []
            for _ in range(targets.size(0)):
                pred = model(_input)  # predict the next frame.
                if rescale_output:
                    pred = transform_minus1_1_to_0_1(pred)
                step_preds.append(pred.squeeze(0))
                # auto-regressive rollout: drop the oldest input frame and append the new prediction.
                _input = torch.concat([_input[:, 1:, ...], pred], dim=1)

            # concat the per-step predictions to get the entire predicted sequence.
            pred_seq = torch.cat(step_preds, dim=0)

            totals["MSE"] += mse(pred_seq, targets)
            # LPIPS needs 3-channel inputs: broadcast the single channel to 3 and clamp to [0, 1]
            # (lpips was built with normalize=True, i.e. it expects [0, 1] inputs).
            totals["lpips"] += lpips(
                torch.clamp(pred_seq.expand(-1, 3, -1, -1), 0, 1),
                torch.clamp(targets.expand(-1, 3, -1, -1), 0, 1),
            )
            # csi score at different thresholds.
            totals["csi2"] += csi2(pred_seq, targets)  # low rain.
            totals["csi10"] += csi10(pred_seq, targets)  # mid rain.
            totals["csi18"] += csi18(pred_seq, targets)  # heavy rain.
            n_samples += 1

    if n_samples == 0:
        raise ValueError(f"eval_loader yielded no samples; cannot evaluate {ckpt_id}.")
    # mean of each metric over all eval samples.
    eval_metrics[ckpt_id] = {name: total.item() / n_samples for name, total in totals.items()}

# collect the per-ckpt means into a (ckpt x metric) table and display it.
df_results = pd.DataFrame(eval_metrics).T
df_results[["MSE", "lpips", "csi2", "csi10", "csi18"]]

Loading model ckpt from /teamspace/studios/this_studio/irp-ds423/rainnow/results/convlstm-abcd1234/checkpoints/convlstm-abcd1234.pt.


Evaluating model convlstm-abcd1234:   0%|          | 0/1228 [00:00<?, ?it/s]

Evaluating model convlstm-abcd1234: 100%|██████████| 1228/1228 [03:28<00:00,  5.89it/s]


Loading model ckpt from /teamspace/studios/this_studio/irp-ds423/rainnow/results/convlstm-a8kwo8jx/checkpoints/convlstm-a8kwo8jx.pt.


Evaluating model convlstm-a8kwo8jx: 100%|██████████| 1228/1228 [03:33<00:00,  5.74it/s]


Unnamed: 0,MSE,lpips,csi2,csi10,csi18
convlstm-abcd1234,0.008475,0.271863,0.13965,0.027456,0.010403
convlstm-a8kwo8jx,0.002414,0.38815,0.051441,0.045088,0.025306


#### `Eval Metrics (CSI + LPIPS per lead time t)`

In [13]:
# ** evaluate per lead time **
# For each ckpt: rebuild + load the ConvLSTM and roll it out auto-regressively. Every lead time t1..tN is
# scored separately (CSI at the three thresholds + LPIPS) and averaged over all eval samples.
# Scores are accumulated during the rollout, rather than storing every [target, predictions] pair first.
# NOTE(review): per-sample CSI scores are averaged with equal weight. torchmetrics' CSI may return 0 when
#   hits + misses + false alarms == 0. If so, samples without rain above a threshold pull the mean down — confirm.
per_t_metric_names = ["csi2_t", "csi10_t", "csi18_t", "lpips_t"]
eval_metrics_per_t = {}
for ckpt_id in [
    "convlstm-abcd1234",  # original model ckpt: conv_lstm_hc_128128_ks_5_oa_Tanh.pt
    "convlstm-a8kwo8jx",
]:
    # model ckpt path: <CKPT_BASE_PATH>/<ckpt_id>/<CKPT_DIR>/<ckpt_id>.pt
    ckpt_id_path = Path(CKPT_BASE_PATH) / ckpt_id / CKPT_DIR / f"{ckpt_id}.pt"
    print(f"Loading model ckpt from {ckpt_id_path}.")

    # instantiate a new ConvLSTM model (params must match the ones the ckpt was trained with).
    model = ConvLSTMModel(
        input_sequence_length=INPUT_SEQUENCE_LENGTH,
        output_sequence_length=OUTPUT_SEQUENCE_LENGTH,
        input_dims=INPUT_DIMS,
        hidden_channels=HIDDEN_CHANNELS,
        output_channels=OUTPUT_CHANNELS,
        num_layers=NUM_LAYERS,
        kernel_size=KERNEL_SIZE,
        output_activation=OUTPUT_ACTIVATION,
        apply_batchnorm=True,
        cell_dropout=CELL_DROPOUT,
        bias=True,
        device=device,
    ).to(device)

    # load in the checkpoint + set to eval() mode.
    ckpt = torch.load(ckpt_id_path, map_location=torch.device(device))
    model.load_state_dict(state_dict=ckpt["model_state_dict"])
    model.eval()

    # a Tanh head outputs [-1, 1]; such predictions are mapped to [0, 1] before scoring / feeding back.
    rescale_output = isinstance(model.output_activation, nn.Tanh)

    # running (num_metrics, horizon) sum of per-sample scores; the horizon is taken from each sample's
    # targets rather than from a variable leaked out of a loop.
    totals = None
    n_samples = 0
    with torch.no_grad():
        for X, target in tqdm(eval_loader, total=len(eval_loader), desc=f"Evaluating model {ckpt_id}"):
            targets = target.to(device)
            horizon_len = targets.size(0)
            _input = X.clone().unsqueeze(0).to(device)  # add a leading batch dim of 1.
            scores = torch.zeros(len(per_t_metric_names), horizon_len, device=device)
            for t in range(horizon_len):
                pred = model(_input)  # predict t+1
                if rescale_output:
                    pred = transform_minus1_1_to_0_1(pred)

                # predicted / target frame at lead time t+1 (assumed (C=1, H, W) — TODO confirm).
                pred_t = pred.squeeze(0)[0]
                target_t = targets[t]
                # torchmetrics argument order is (preds, target).
                scores[0, t] = csi2(pred_t, target_t)
                scores[1, t] = csi10(pred_t, target_t)
                scores[2, t] = csi18(pred_t, target_t)
                # LPIPS needs (1, 3, H, W) inputs in [0, 1]: broadcast the single channel to 3 and clamp.
                scores[3, t] = lpips(
                    torch.clamp(pred_t.expand(1, 3, -1, -1), 0, 1),
                    torch.clamp(target_t.expand(1, 3, -1, -1), 0, 1),
                )

                # update the inputs with the last pred (auto-regressive rollout).
                _input = torch.concat([_input[:, 1:, ...], pred], dim=1)

            totals = scores if totals is None else totals + scores
            n_samples += 1

    if totals is None:
        raise ValueError(f"eval_loader yielded no samples; cannot evaluate {ckpt_id}.")
    # normalise the scores: mean over eval samples, one 1-D (horizon,) tensor per metric.
    eval_metrics_per_t[ckpt_id] = dict(zip(per_t_metric_names, totals / n_samples))

Loading model ckpt from /teamspace/studios/this_studio/irp-ds423/rainnow/results/convlstm-abcd1234/checkpoints/convlstm-abcd1234.pt.


Evaluating model convlstm-abcd1234:   0%|          | 2/1228 [00:00<03:00,  6.78it/s]

Evaluating model convlstm-abcd1234: 100%|██████████| 1228/1228 [03:30<00:00,  5.83it/s]


Loading model ckpt from /teamspace/studios/this_studio/irp-ds423/rainnow/results/convlstm-a8kwo8jx/checkpoints/convlstm-a8kwo8jx.pt.


Evaluating model convlstm-a8kwo8jx: 100%|██████████| 1228/1228 [03:29<00:00,  5.87it/s]


In [14]:
# ** tabulate the per-lead-time metrics **
# one DataFrame per metric: rows = ckpt ids, columns = lead times t1..t{horizon}.
horizon = 8
lead_time_cols = [f"t{i+1}" for i in range(horizon)]


def per_t_table(
    metrics_per_t: Dict[str, Dict[str, Any]], metric_key: str, columns: List[str]
) -> pd.DataFrame:
    """Build a (ckpt x lead time) table for one per-lead-time metric.

    Args:
        metrics_per_t: ckpt id -> {metric key -> 1-D sequence of per-lead-time values}.
        metric_key: which metric to tabulate, e.g. "csi2_t" or "lpips_t".
        columns: column labels, one per lead time.

    Returns:
        DataFrame indexed by ckpt id (anything after a final "." dropped, e.g. a ".pt" suffix)
        with one float column per lead time.

    Raises:
        ValueError: if a ckpt's metric does not have exactly len(columns) values.
    """
    rows = {}
    for ckpt_id, metrics in metrics_per_t.items():
        values = [float(v) for v in metrics[metric_key]]
        if len(values) != len(columns):
            raise ValueError(
                f"{ckpt_id}/{metric_key}: expected {len(columns)} lead times, got {len(values)}."
            )
        rows[ckpt_id.rsplit(".", 1)[0]] = values
    # build each table in one constructor call instead of growing an empty frame row by row.
    return pd.DataFrame.from_dict(rows, orient="index", columns=columns)


df_csi2 = per_t_table(eval_metrics_per_t, "csi2_t", lead_time_cols)
df_csi10 = per_t_table(eval_metrics_per_t, "csi10_t", lead_time_cols)
df_csi18 = per_t_table(eval_metrics_per_t, "csi18_t", lead_time_cols)
df_lpips = per_t_table(eval_metrics_per_t, "lpips_t", lead_time_cols)

In [15]:
# per-lead-time CSI at the low-rain threshold (csi_nodes[0]); rows = ckpt ids, columns = t1..t8.
df_csi2

Unnamed: 0,t1,t2,t3,t4,t5,t6,t7,t8
convlstm-abcd1234,0.410601,0.256862,0.177228,0.132327,0.10248,0.089132,0.079248,0.07024
convlstm-a8kwo8jx,0.18438,0.080505,0.047988,0.03193,0.020539,0.014534,0.010713,0.007031


In [16]:
# per-lead-time CSI at the mid-rain threshold (csi_nodes[1]); rows = ckpt ids, columns = t1..t8.
df_csi10

Unnamed: 0,t1,t2,t3,t4,t5,t6,t7,t8
convlstm-abcd1234,0.197339,0.085128,0.047181,0.030801,0.021636,0.016262,0.01378,0.011649
convlstm-a8kwo8jx,0.157186,0.076,0.040655,0.022134,0.011795,0.007272,0.005036,0.002909


In [17]:
# per-lead-time CSI at the heavy-rain threshold (csi_nodes[-1]); rows = ckpt ids, columns = t1..t8.
df_csi18

Unnamed: 0,t1,t2,t3,t4,t5,t6,t7,t8
convlstm-abcd1234,0.120482,0.045346,0.020362,0.012246,0.008627,0.006378,0.00501,0.00413
convlstm-a8kwo8jx,0.09231,0.042377,0.022197,0.01164,0.006171,0.004029,0.002798,0.001618


In [18]:
# per-lead-time LPIPS distance (lower = more perceptually similar); rows = ckpt ids, columns = t1..t8.
df_lpips

Unnamed: 0,t1,t2,t3,t4,t5,t6,t7,t8
convlstm-abcd1234,0.112273,0.179575,0.228382,0.269373,0.305903,0.336107,0.362349,0.381351
convlstm-a8kwo8jx,0.314548,0.371742,0.391685,0.397862,0.402602,0.406572,0.41063,0.411008


### END OF SCRIPT.