## `ConvLSTM` (`one-shot`) Inference.
* The aim of this notebook is to run inference on trained `ConvLSTM()` models that predict the entire target sequence in one pass.

* Note: the final comparison against **DYffusion** uses a `next-step` **ConvLSTM** instead, because it is better aligned with DYffusion's setup; this notebook only covers the `one-shot` variant.

In [None]:
import os
from pathlib import Path
from typing import Any, Dict, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from matplotlib.colors import ListedColormap
from scipy import io
from torch.utils.data import DataLoader

from rainnow.src.convlstm_trainer import create_eval_loader, save_checkpoint, train, validate
from rainnow.src.datasets import IMERGDataset
from rainnow.src.models.conv_lstm import ConvLSTMModel
from rainnow.src.normalise import PreProcess
from rainnow.src.plotting import plot_predicted_sequence, plot_training_val_loss
from rainnow.src.utilities.loading import load_imerg_datamodule_from_config
from rainnow.src.utilities.utils import get_device, transform_minus1_1_to_0_1

#### `helpers`

In [None]:
# ** DIR helpers **
# Root of the workspace. Defaults to the original hosted-studio path; set the
# RAINNOW_BASE_PATH environment variable to run this notebook from another machine.
BASE_PATH = os.environ.get("RAINNOW_BASE_PATH", "/teamspace/studios/this_studio")

CKPT_BASE_PATH = f"{BASE_PATH}/DYffcast/rainnow/results/"
CONFIGS_BASE_PATH = f"{BASE_PATH}/DYffcast/rainnow/src/dyffusion/configs/"

CKPT_DIR = "checkpoints"
CKPT_CFG_NAME = "hparams.yaml"
DATAMODULE_CONFIG_NAME = "imerg_precipitation.yaml"
# whether or not to get last.ckpt or to get the "best model" ckpt (the other one in the folder).
GET_LAST = False

# ** Dataloader Params **
BATCH_SIZE = 4
NUM_WORKERS = 0

# number of input frames fed to the model and number of target frames it predicts.
INPUT_SEQUENCE_LENGTH = 8
OUTPUT_SEQUENCE_LENGTH = 8

# ** plotting helpers **
# custom rain colour map, stored under the "Cmap_rain" key of a MATLAB .mat file.
cmap = io.loadmat(f"{BASE_PATH}/DYffcast/rainnow/src/utilities/cmaps/colormap.mat")
rain_cmap = ListedColormap(cmap["Cmap_rain"])
global_params = {"font.size": 8}
plt_params = {"wspace": 0.1, "hspace": 0.15}
ylabel_params = {"ha": "right", "va": "bottom", "labelpad": 1, "fontsize": 7.5}

# ** get device **
device = get_device()

#### `Datasets & Dataloaders`

In [None]:
# Config overrides applied on top of DATAMODULE_CONFIG_NAME.
# NOTE(review): window (1) + horizon (15) = 16 frames, which appears to equal
# INPUT_SEQUENCE_LENGTH + OUTPUT_SEQUENCE_LENGTH (8 + 8) — confirm against IMERGDataset.
datamodule_overrides = {
    "boxes": ["1,0"],
    "window": 1,
    "horizon": 15,
    "prediction_horizon": 15,
    "sequence_dt": 1,
}
datamodule = load_imerg_datamodule_from_config(
    cfg_base_path=CONFIGS_BASE_PATH,
    cfg_name=DATAMODULE_CONFIG_NAME,
    overrides=datamodule_overrides,
)

# get the .data_<split> for every split used by the datasets below.
for split_name in ("validate", "test", "predict"):
    datamodule.setup(split_name)

In [None]:
# create one dataset per split; all splits share the same input/target lengths.
_window_kwargs = dict(sequence_length=INPUT_SEQUENCE_LENGTH, target_length=OUTPUT_SEQUENCE_LENGTH)
val_dataset, test_dataset, predict_dataset = (
    IMERGDataset(datamodule, split, **_window_kwargs) for split in ("validate", "test", "predict")
)

# instantiate the dataloaders (shuffle=False so the batch order is fixed between runs).
_loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, shuffle=False)
val_loader, test_loader, predict_loader = (
    DataLoader(dataset=ds, **_loader_kwargs) for ds in (val_dataset, test_dataset, predict_dataset)
)

In [None]:
# ** instantiate the preprocessor obj **
# Built from the datamodule's own normalisation statistics, presumably so that
# reverse_processing() undoes exactly the normalisation applied to the data — confirm.
norm_hparams = datamodule.normalization_hparams
pprocessor = PreProcess(
    percentiles=norm_hparams["percentiles"],
    minmax=norm_hparams["min_max"],
)

#### `Instantiate a ConvLSTM()`

In [None]:
# Trained one-shot ConvLSTM runs: checkpoint id -> short run description.
# The abbreviations appear to be: I/T = input/target length, hs = hidden channels,
# ks = kernel size, dp = dropout, E = epochs, lr = learning rate (confirm with the training logs).
ckpt_ids = dict(
    [
        ("convlstm-r9m0o2ho", "I8:T8 | hs=(64, 64), ks=(5, 5), dp=.3 [40E, lr=3e-4]"),
        ("convlstm-jtwm5nt0", "I8:T8 | hs=(64, 64, 64), ks=(5, 5), dp=.3 [40E, lr=3e-4]"),
        ("convlstm-evzddqer", "I8:T8 | hs=(32, 32, 32), ks=(5, 5), dp=.3 [40E, lr=1e-4]"),
        ("convlstm-5tbjfvkc", "I8:T8 | hs=(128, 64), ks=(3, 3), dp=.15 [20E, lr=3e-4]"),
        ("convlstm-3bg6j99s", "I8:T8 | hs=(128, 64), ks=(5, 5), dp=.3 [20E, lr=3e-4]"),
        ("convlstm-ivqxk14e", "I8:T8 | hs=(128, 64), ks=(5, 5), dp=.15 [20E, lr=3e-4]"),
    ]
)

In [None]:
# ** load in checkpoint **
ckpt_id = "convlstm-jtwm5nt0"
# layout: <CKPT_BASE_PATH>/convlstm_experiments/<ckpt_id>/<CKPT_DIR>/<ckpt_id>.pt
model_save_path = Path(CKPT_BASE_PATH) / "convlstm_experiments" / ckpt_id / CKPT_DIR / f"{ckpt_id}.pt"

In [None]:
# ** instantiate a new model **
# These hyper-parameters must match the run that produced the checkpoint (see
# ckpt_ids[ckpt_id]); load_state_dict() is strict by default and raises on any mismatch.
KERNEL_SIZE = (5, 5)
INPUT_DIMS = (1, 128, 128)  # C, H, W
OUTPUT_CHANNELS = 1
HIDDEN_CHANNELS = [64, 64, 64]
# assumes one ConvLSTM layer per HIDDEN_CHANNELS entry (3 for this run).
NUM_LAYERS = len(HIDDEN_CHANNELS)
# matches "dp=.3" in the run description; presumably inactive after model.eval() — confirm.
CELL_DROPOUT = 0.3
OUTPUT_ACTIVATION = nn.Sigmoid()

model = ConvLSTMModel(
    input_sequence_length=INPUT_SEQUENCE_LENGTH,
    output_sequence_length=OUTPUT_SEQUENCE_LENGTH,
    input_dims=INPUT_DIMS,
    hidden_channels=HIDDEN_CHANNELS,
    output_channels=OUTPUT_CHANNELS,
    num_layers=NUM_LAYERS,
    kernel_size=KERNEL_SIZE,
    output_activation=OUTPUT_ACTIVATION,
    apply_batchnorm=True,
    cell_dropout=CELL_DROPOUT,
    bias=True,
    device=device,
)

model = model.to(device)

# load in the checkpoint (a dict whose "model_state_dict" entry holds the weights).
checkpoint = torch.load(model_save_path, map_location=torch.device(device))
model.load_state_dict(state_dict=checkpoint["model_state_dict"])
# set model into eval mode (disables train-only behaviour such as batchnorm statistic updates).
model.eval()

#### `Get inputs (X) and predict.`

In [None]:
# Take the first batch of the (unshuffled) predict split.
iter_loader = iter(predict_loader)
X, target = next(iter_loader)

# get raw scale: undo the normalisation for the inputs and targets, keeping CPU, graph-free copies.
X_reversed, target_reversed = (
    pprocessor.reverse_processing(tensor).cpu().detach() for tensor in (X, target)
)

In [None]:
# ** make prediction **
# Inference only: disable autograd so no computation graph (or its memory) is retained.
with torch.no_grad():
    pred = model(X.to(device))

# A Tanh output head produces values in [-1, 1]; rescale them to [0, 1] (the range a
# Sigmoid head already produces) before undoing the normalisation.
if isinstance(model.output_activation, nn.Tanh):
    yhat = transform_minus1_1_to_0_1(pred)
else:
    yhat = pred

pred_reversed = pprocessor.reverse_processing(yhat).detach().cpu()

In [None]:
# ** plot params **
figsize = (20, 3)
# index of the sample within the batch to plot (presumably must be < BATCH_SIZE).
b = 3
# colour-scale kwargs forwarded to plot_predicted_sequence: rain colour map with vmin/vmax bounds.
plot_params = dict(cmap=rain_cmap, vmin=0.5, vmax=8)

In [None]:
# Plot inputs, ground-truth targets and predictions for sample `b`, all on the raw (reversed) scale.
plot_predicted_sequence(
    batch_num=b,
    X=X_reversed,
    target=target_reversed,
    pred=pred_reversed,
    figsize=figsize,
    plot_params=plot_params,
)

### END OF SCRIPT.