## `ConvLSTM` Exploration.
* The aim of this notebook is to **train** and **validate** a `ConvLSTM()`, exploring different setups.

In [None]:
import os
from pathlib import Path
from typing import Any, Dict, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from livelossplot import PlotLosses
from matplotlib.colors import ListedColormap
from scipy import io
from torch.utils.data import DataLoader
from tqdm import tqdm

from rainnow.src.convlstm_trainer import save_checkpoint, train, validate, create_eval_loader
from rainnow.src.datasets import IMERGDataset
from rainnow.src.plotting import plot_training_val_loss, plot_predicted_sequence
from rainnow.src.loss import CBLoss, LPIPSMSELoss
from rainnow.src.models.conv_lstm import ConvLSTMModel
from rainnow.src.normalise import PreProcess
from rainnow.src.utilities.loading import load_imerg_datamodule_from_config
from rainnow.src.utilities.utils import (
    get_device,
    transform_0_1_to_minus1_1,
    transform_minus1_1_to_0_1,
)
from rainnow.src.plotting import plot_a_sequence

#### `helpers.`

In [None]:
# ** DIR helpers **
# workspace root: defaults to the original hard-coded location, but can be redirected
# without editing the notebook by exporting RAINNOW_BASE_PATH before starting the kernel.
# a trailing "/" is stripped so the joined paths below never contain "//".
BASE_PATH = os.environ.get("RAINNOW_BASE_PATH", "/teamspace/studios/this_studio").rstrip("/")

CKPT_BASE_PATH = f"{BASE_PATH}/DYffcast/rainnow/results/"
CONFIGS_BASE_PATH = f"{BASE_PATH}/DYffcast/rainnow/src/dyffusion/configs/"

CKPT_DIR = "checkpoints"
CKPT_CFG_NAME = "hparams.yaml"
DATAMODULE_CONFIG_NAME = "imerg_precipitation.yaml"
# whether or not to get last.ckpt or to get the "best model" ckpt (the other one in the folder).
GET_LAST = False

# ** Dataset Params **
REVERSE_PROBABILITY = 0.2  # X% chance to flip the sequence during training.

# ** Dataloader Params **
BATCH_SIZE = 8
NUM_WORKERS = 0

# number of frames fed to the model / number of frames it is asked to predict.
INPUT_SEQUENCE_LENGTH = 4
OUTPUT_SEQUENCE_LENGTH = 1

# ** plotting helpers **
# the rain colour map is stored in a .mat file under the "Cmap_rain" key.
_cmap_file = f"{BASE_PATH}/DYffcast/rainnow/src/utilities/cmaps/colormap.mat"
cmap = io.loadmat(_cmap_file)
rain_cmap = ListedColormap(cmap["Cmap_rain"])

# shared matplotlib settings reused by the plotting cells.
global_params = {"font.size": 8}
plt_params = dict(wspace=0.1, hspace=0.15)
ylabel_params = dict(ha="right", va="bottom", labelpad=1, fontsize=7.5)

# ** get device **
device = get_device()

#### `Datasets & DataLoaders`

In [None]:
# config overrides applied on top of the datamodule's YAML file.
datamodule_overrides = {
    # "boxes": ["1,0"],
    "boxes": ["0,0", "1,0", "2,0", "2,1"],
    "window": 1,
    "horizon": 8,
    "prediction_horizon": 8,
    "sequence_dt": 1,
}
datamodule = load_imerg_datamodule_from_config(
    cfg_base_path=CONFIGS_BASE_PATH,
    cfg_name=DATAMODULE_CONFIG_NAME,
    overrides=datamodule_overrides,
)

# get the .data_<split>
for _stage in ("train", "validate"):
    datamodule.setup(_stage)

In [None]:
# create the datasets (both splits use the same sequence / target lengths).
_window_kwargs = {
    "sequence_length": INPUT_SEQUENCE_LENGTH,
    "target_length": OUTPUT_SEQUENCE_LENGTH,
}
train_dataset = IMERGDataset(datamodule, "train", **_window_kwargs)
val_dataset = IMERGDataset(datamodule, "validate", **_window_kwargs)

# instantiate the dataloaders (only the training loader is shuffled).
_loader_kwargs = {"batch_size": BATCH_SIZE, "num_workers": NUM_WORKERS}
train_loader = DataLoader(dataset=train_dataset, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(dataset=val_dataset, shuffle=False, **_loader_kwargs)

In [None]:
# test train dataset: a single sample must be an (inputs, target) pair with the expected shapes.
X_dummy, y_dummy = train_dataset[0]

_expected_X_size = torch.Size([INPUT_SEQUENCE_LENGTH, 1, 128, 128])
_expected_y_size = torch.Size([OUTPUT_SEQUENCE_LENGTH, 1, 128, 128])
assert X_dummy.size() == _expected_X_size, f"unexpected input shape: {tuple(X_dummy.size())}"
assert y_dummy.size() == _expected_y_size, f"unexpected target shape: {tuple(y_dummy.size())}"

# kept as the last expression of the cell so the shapes are actually displayed.
X_dummy.size(), y_dummy.size()

In [None]:
# grab the first batch from the training loader and show the tensor shapes.
dummy_inputs, dummy_target = dummy_batch = next(iter(train_loader))
(dummy_inputs.shape, dummy_target.shape)

In [None]:
# inputs and target are joined along dim=1 once, then the first two batch elements are plotted.
dummy_sequence = torch.cat([dummy_inputs, dummy_target], dim=1)
for _batch_idx in (0, 1):
    plot_a_sequence(
        X=dummy_sequence,
        b=_batch_idx,
        global_params=global_params,
        plot_params={"cmap": rain_cmap},
        layout_params=plt_params,
    )

In [None]:
# ** instantiate the preprocessor obj **
# the normalisation settings come from the datamodule.
_norm_hparams = datamodule.normalization_hparams
pprocessor = PreProcess(
    percentiles=_norm_hparams["percentiles"],
    minmax=_norm_hparams["min_max"],
)

####  `Training and Val:`

##### `Loss Function(s)`

In [None]:
## ** loss function **
# alternatives that were tried:
# criterion = nn.MSELoss(reduction='mean')
# criterion = nn.L1Loss(reduction='mean')
# criterion = nn.BCELoss(reduction="mean")
# criterion = CBLoss(beta=1.0, data_preprocessor_obj=pprocessor)
_lpips_mse_kwargs = dict(
    alpha=0.6,
    model_name="alex",  # trains better with 'alex' - https://github.com/richzhang/PerceptualSimilarity.
    reduction="mean",
    gamma=1.0,
    mse_type="cb",
    # extra keyword arguments, originally passed through a ** dict.
    beta=1,
    data_preprocessor_obj=pprocessor,
)
criterion = LPIPSMSELoss(**_lpips_mse_kwargs).to(device)

##### `Instantiate a ConvLSTM()`

In [None]:
# ** instantiate a new model **
INPUT_DIMS = (1, 128, 128)  # C, H, W
OUTPUT_CHANNELS = 1
HIDDEN_CHANNELS = [128, 128]
NUM_LAYERS = 2
KERNEL_SIZE = (5, 5)
CELL_DROPOUT = 0.4  # 0.2
OUTPUT_ACTIVATION = nn.Tanh()  # alternative: nn.Sigmoid()

# all constructor arguments in one place so the setup is easy to tweak and inspect.
_convlstm_kwargs = dict(
    input_sequence_length=INPUT_SEQUENCE_LENGTH,
    output_sequence_length=OUTPUT_SEQUENCE_LENGTH,
    input_dims=INPUT_DIMS,
    hidden_channels=HIDDEN_CHANNELS,
    output_channels=OUTPUT_CHANNELS,
    num_layers=NUM_LAYERS,
    kernel_size=KERNEL_SIZE,
    output_activation=OUTPUT_ACTIVATION,
    apply_batchnorm=True,
    cell_dropout=CELL_DROPOUT,
    bias=True,
    device=device,
)
model = ConvLSTMModel(**_convlstm_kwargs).to(device)
model

In [None]:
# load in a checkpoint: restore previously saved weights, then switch to train() mode.
# NOTE(review): torch.load unpickles the file - only load checkpoints from a trusted source.
model_save_path = "convlstm-abcd1234.pt"  # (old best) LCB
_ckpt_weights = torch.load(model_save_path, map_location=torch.device(device))["model_state_dict"]
model.load_state_dict(state_dict=_ckpt_weights)
del _ckpt_weights  # don't keep a second copy of the weights alive in the kernel.
model.train()

In [None]:
## ** training params 101 **
n_epochs = 12
optimizer = torch.optim.AdamW(model.parameters(), lr=7e-5, weight_decay=1e-2)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.1, patience=5)
use_liveloss = True

# model ckpt save path.
model_save_path = f"conv_lstm_hc_{''.join([str(i) for i in HIDDEN_CHANNELS])}_ks_{KERNEL_SIZE[0]}_oa_{str(OUTPUT_ACTIVATION)[:-2]}_dummy.pt"

liveloss = PlotLosses()
train_losses, val_losses = [], []
# val_loss stays None when validation is skipped (falsy val_loader); previously the name was
# simply undefined in that case and the checkpoint logic below raised a NameError.
val_loss = None
# lowest monitored loss seen so far - decides when the 'best' checkpoint is overwritten.
best_loss = float("inf")
with tqdm(total=len(train_loader) * n_epochs, desc="Training", unit="batch") as tepoch:
    for i in range(n_epochs):
        logs = {}
        train_loss = train(
            model=model,
            optimizer=optimizer,
            data_loader=train_loader,
            criterion=criterion,
            tepoch=tepoch,
            curr_epoch=i,
            n_epochs=n_epochs,
            device=device,
            log_training=True,
        )
        train_losses.append(train_loss)
        logs["loss"] = train_loss

        if val_loader:
            val_loss = validate(
                model=model,
                data_loader=val_loader,
                device=device,
                criterion=criterion,
            )
            val_losses.append(val_loss)
            logs["val_loss"] = val_loss

        if use_liveloss:
            liveloss.update(logs)
            liveloss.send()

        # model selection uses the val loss when there is one, otherwise the train loss.
        # "<=" keeps the original tie behaviour (an equal loss still overwrites the checkpoint).
        monitored_loss = train_loss if val_loss is None else val_loss
        is_best = monitored_loss <= best_loss
        if is_best:
            best_loss = monitored_loss

        # save 'best' checkpoint.
        # NOTE(review): val_loss is None here when validation is skipped - confirm that
        # save_checkpoint accepts that.
        if model_save_path and is_best:
            save_checkpoint(
                model_weights=model.state_dict(),
                optimizer_info=optimizer.state_dict(),
                model_save_path=model_save_path,
                epoch=i,
                train_loss=train_loss,
                val_loss=val_loss,
            )

        # NOTE(review): the LR schedule reacts to the *training* loss while checkpoints are
        # selected on the validation loss - kept as-is, confirm this is intended.
        if scheduler:
            scheduler.step(train_losses[-1])

    # save the last epoch (skipped when no save path is set or no epoch was run).
    if model_save_path and train_losses:
        save_checkpoint(
            model_weights=model.state_dict(),
            optimizer_info=optimizer.state_dict(),
            model_save_path=model_save_path[:-3] + "_last.pt",
            epoch=i,
            train_loss=train_loss,
            val_loss=val_loss,
        )

#### `create Predict dataloader`

In [None]:
# make the "predict" split available on the datamodule.
datamodule.setup("predict")

# same sequence / target lengths as the train and validation datasets.
_predict_window = {
    "sequence_length": INPUT_SEQUENCE_LENGTH,
    "target_length": OUTPUT_SEQUENCE_LENGTH,
}
predict_dataset = IMERGDataset(datamodule, "predict", **_predict_window)

# fixed order (no shuffling).
predict_loader = DataLoader(
    dataset=predict_dataset,
    batch_size=6,
    num_workers=NUM_WORKERS,
    shuffle=False,
)

#### `Eval predictions`

In [None]:
# ** create the eval dataloader **
eval_loader, _ = create_eval_loader(
    data_loader=predict_loader, horizon=8, input_sequence_length=4, img_dims=(128, 128)
)

#### `Load in 'best' model`

In [None]:
# load in the checkpoint from in-line NB training & set to eval() mode.
# the path is rebuilt here (same pattern as the training cell) so this cell does not rely on
# the value of model_save_path left behind by an earlier cell.
model_save_path = f"conv_lstm_hc_{''.join([str(i) for i in HIDDEN_CHANNELS])}_ks_{KERNEL_SIZE[0]}_oa_{str(OUTPUT_ACTIVATION)[:-2]}_dummy.pt"

# other checkpoints that can be swapped in:
# ** BCE **
# model_save_path = "conv_lstm_k34y14il.pt"  # BCE.
# model_save_path = "convlstm-a8kwo8jx.pt"   # old.

# ** LCB **
# model_save_path = "convlstm-abcd1234.pt" # (old best) LCB
# model_save_path = "conv_lstm_f1dlb0m7.pt"  # LCB 30E LCB
# model_save_path = "conv_lstm_9h86gnt5.pt" # 10E large batch size LCB
# model_save_path = "conv_lstm_wwj6eryz.pt" # recent 10E.
# model_save_path = "conv_lstm_i93g6pzb.pt"

# NOTE(review): torch.load unpickles the file - only load checkpoints from a trusted source.
_best_weights = torch.load(model_save_path, map_location=torch.device(device))["model_state_dict"]
model.load_state_dict(state_dict=_best_weights)
del _best_weights  # don't keep a second copy of the weights alive in the kernel.
model.eval()

In [None]:
# autoregressive roll-out: every prediction is appended to the input window (dropping the
# oldest frame) and fed back into the model, once per target step.
results = []
# the activation does not change during the loop, so check it once.
_rescale_tanh_output = isinstance(model.output_activation, nn.Tanh)
with torch.no_grad():
    for X, target in eval_loader:
        # add a leading batch dimension and move the inputs to the model's device.
        _input = X.clone().unsqueeze(0).to(device)

        predictions = {}
        for step in range(1, target.size(0) + 1):
            pred = model(_input)
            if _rescale_tanh_output:
                # map the tanh output range [-1, 1] back to [0, 1].
                pred = transform_minus1_1_to_0_1(pred)

            predictions[f"t{step}"] = pred.squeeze(0)

            # slide the window: drop the oldest frame, append the newest prediction.
            _input = torch.cat([_input[:, 1:, ...], pred], dim=1)

        results.append([target, predictions])

In [None]:
def _plot_truth_vs_preds(truth, preds, channel, imshow_params, figsize):
    """Draw ground-truth frames (top row) above the predictions (bottom row).

    Args:
        truth: tensor indexed as ``truth[step, channel, ...]``; ``truth.size(0)`` sets the
            number of columns.
        preds: mapping of label -> tensor indexed as ``pred[0, channel, :, :]``; entries are
            drawn left to right in insertion order, one column each.
        channel: index of the channel to show.
        imshow_params: keyword arguments forwarded to every ``imshow`` call.
        figsize: figure size handed to ``plt.subplots``.

    Returns:
        The ``(fig, axs)`` pair; ``axs`` is always a 2-D array.
    """
    n_steps = truth.size(0)
    # squeeze=False keeps axs 2-D even for a single column, so axs[row, col] always works.
    fig, axs = plt.subplots(2, n_steps, figsize=figsize, squeeze=False)
    # ground truth in the 1st row; labelled from t+1 so each column lines up with the
    # prediction keys below (the first prediction is "t1").
    for j in range(n_steps):
        axs[0, j].imshow(truth[j, channel, ...], **imshow_params)
        axs[0, j].set_title(f"Ground Truth: x_t+{j + 1}")
    # predictions in the 2nd row.
    for col, (key, pred) in enumerate(preds.items()):
        axs[1, col].imshow(pred[0, channel, :, :], **imshow_params)
        axs[1, col].set_title(f"{key}")
    for ax in axs.flatten():
        ax.axis("off")
    fig.tight_layout()
    return fig, axs


c = 0
figsize = (14, 4)
plot_params = {"cmap": rain_cmap, "vmin": 0.1, "vmax": 18}

idx = 32  # 19 is the paper case study image.

targets, predictions = results[idx]
targets = targets.cpu().detach()
predictions = {k: v.cpu().detach() for k, v in predictions.items()}

# get raw scale.
target_reversed = pprocessor.reverse_processing(targets)
preds_reversed = {k: pprocessor.reverse_processing(v) for k, v in predictions.items()}

# apply the shared font settings *before* any figure is created so the titles pick them up.
plt.rcParams.update(global_params)

# normalised scale.
fig, axs = _plot_truth_vs_preds(targets, predictions, c, {"cmap": rain_cmap}, figsize)

# raw scale.
fig, axs = _plot_truth_vs_preds(target_reversed, preds_reversed, c, plot_params, figsize)

### END OF SCRIPT.