# ConvLSTM (CLSTM) Notebook

This notebook packages:
- L1 + SSIM loss (SSIM from the external `pytorch-msssim` library)
- Training loop with loss history
- Loss-epoch plot
- Prediction visualization
- Hooks to plug in your own `Task1Dataset`

**Required variable names:** `train_ds, val_ds, train_loader, val_loader`

In [1]:
# Cell 1: basics
import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)
print("Device:", device)


Device: cuda


In [2]:
# Cell 2: import model code (keep repo structure)
# Assumes you are running in the repo root (same folder as model.py, net_params.py, etc.)
from model import EncoderDecoderConvLSTM
from net_params import convlstm_encoder_params, convlstm_decoder_params

# Build the encoder/decoder ConvLSTM from the repo's parameter specs, then move it to `device`.
model = EncoderDecoderConvLSTM(convlstm_encoder_params, convlstm_decoder_params)
model = model.to(device)

print(type(model).__name__)


ModuleNotFoundError: No module named 'model'

In [None]:
# Cell 3: SSIM external lib
# If not installed, uncomment the next line:
# !pip -q install pytorch-msssim

from pytorch_msssim import ssim


In [None]:
# Cell 4: L1 + SSIM Loss (works with z-score by using dynamic data_range)
class L1SSIMLoss(nn.Module):
    """Weighted combination of L1 loss and (1 - SSIM).

    loss = alpha * L1(pred, target) + (1 - alpha) * (1 - SSIM(pred, target))

    SSIM is computed with ``pytorch_msssim.ssim``. Inputs may be z-score
    normalised (negative values, no fixed range), so ``data_range`` is derived
    per call from the joint min/max of ``pred`` and ``target``. It is not a
    fixed constant.

    Args:
        alpha: weight of the L1 term, must lie in [0, 1]; SSIM gets 1 - alpha.
        ssim_win_size: Gaussian window size forwarded to ``ssim``.
        eps: lower bound on the dynamic data range, so constant inputs do not
            yield a zero range.

    Raises:
        ValueError: if ``alpha`` is outside [0, 1].
    """

    def __init__(self, alpha=0.8, ssim_win_size=11, eps=1e-6):
        super().__init__()
        if not 0.0 <= alpha <= 1.0:
            raise ValueError(f"alpha must be in [0, 1], got {alpha!r}")
        self.alpha = alpha
        self.l1 = nn.L1Loss()
        self.win_size = ssim_win_size
        self.eps = eps

    def forward(self, pred, target):
        """Return the scalar combined loss.

        Args:
            pred, target: tensors of identical shape, either (B, T, C, H, W)
                or (B, C, H, W).

        Raises:
            ValueError: if the shapes differ or the rank is not 4 or 5.
        """
        # Without this check nn.L1Loss would broadcast mismatched shapes
        # (only emitting a warning) and compute a meaningless loss.
        if pred.shape != target.shape:
            raise ValueError(
                f"pred and target must have the same shape, got "
                f"{tuple(pred.shape)} vs {tuple(target.shape)}"
            )
        if pred.dim() not in (4, 5):
            raise ValueError(
                f"expected a 4-D (B,C,H,W) or 5-D (B,T,C,H,W) tensor, got {pred.dim()}-D"
            )

        l1_loss = self.l1(pred, target)

        # Fold time into the batch axis so SSIM sees a stack of 2-D images.
        if pred.dim() == 5:
            B, T, C, H, W = pred.shape
            pred_ = pred.reshape(B * T, C, H, W)
            tgt_ = target.reshape(B * T, C, H, W)
        else:
            pred_ = pred
            tgt_ = target

        # Dynamic data_range for z-scored data. The values are detached because
        # the range only scales SSIM's stabilising constants; no gradient should
        # flow through it.
        vmin = torch.minimum(pred_.detach().min(), tgt_.detach().min())
        vmax = torch.maximum(pred_.detach().max(), tgt_.detach().max())
        data_range = (vmax - vmin).clamp_min(self.eps)

        ssim_val = ssim(
            pred_, tgt_,
            data_range=data_range,
            size_average=True,
            win_size=self.win_size,
        )
        ssim_loss = 1.0 - ssim_val

        return self.alpha * l1_loss + (1.0 - self.alpha) * ssim_loss


## Plug in your dataset

Choose one of the following.

### Option A: Your Task1Dataset (recommended)

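⚠️ You previously hit: `Task1Dataset.__init__() got multiple values for argument 'log_min'`.
Python raises this when one parameter receives a value both positionally and by keyword. Here, a positional argument already filled the `log_min` slot (e.g. because it is the 3rd positional parameter), and `log_min=` was passed as well. Supply each of `log_min`/`log_max` exactly once — either positionally or by keyword, not both.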

### Option B: MovingMNIST (repo demo)

If you just want to test the pipeline first, you can use MovingMNIST.

In [None]:
from huggingface_hub import hf_hub_download
# Shared location of the course dataset; both files land in ./data.
_HF_SOURCE = dict(
    repo_id="benmoseley/ese-dl-2025-26-group-project",
    repo_type="dataset",
    local_dir="data",
)
hf_hub_download(filename="train.h5", **_HF_SOURCE)
hf_hub_download(filename="events.csv", **_HF_SOURCE)


In [None]:
import os
import h5py
import numpy as np
import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split

import torch
from torch.utils.data import Dataset, DataLoader
# Sanity check: confirm where we are running and that the downloads exist.
_cwd = os.getcwd()
print("CWD:", _cwd)
_data_listing = os.listdir("data")
print("data/ contains:", _data_listing)

# Raw HDF5 archive; each top-level key is a storm ID holding a `vil` dataset.
H5_PATH = "data/train.h5"

In [None]:
# Inspect file structure (storm IDs)
with h5py.File(H5_PATH, "r") as h5:
    # Top-level keys of the file are the storm IDs; sort for a stable order.
    storm_ids_all = sorted(h5.keys())
    n_storms = len(storm_ids_all)
    print("Number of storms:", n_storms)
    print("First 10 storm IDs:", storm_ids_all[:10])

In [None]:
#  Helpers: (H, W, T) -> (T, H, W) and Task-1 slicing
def vil_to_THW(vil_hwT: np.ndarray) -> np.ndarray:
    """
    Reorder a raw VIL cube from (H, W, T) to time-first (T, H, W).

    The HDF5 file stores each storm's VIL as (384, 384, 36) = (H, W, T);
    the rest of the notebook works time-first.
    """
    h_axis, w_axis, t_axis = 0, 1, 2
    time_first_order = (t_axis, h_axis, w_axis)
    return np.transpose(vil_hwT, time_first_order)

In [None]:
def extract_task1_pair_THW(vil_THW: np.ndarray, T_in: int = 12, T_out: int = 12):
    """
    Split a time-first VIL sequence into a Task 1 (input, target) pair.

      Input  = frames [0, T_in)
      Target = frames [T_in, T_in + T_out)

    With the defaults this is "first 12 frames -> next 12 frames".

    Parameters
    ----------
    vil_THW : np.ndarray, shape (T_total, H, W)  (T_total=36 for this dataset)
    T_in    : int, number of input frames (>= 1)
    T_out   : int, number of target frames (>= 1)

    Returns
    -------
    X_THW : np.ndarray, shape (T_in,  H, W)
    Y_THW : np.ndarray, shape (T_out, H, W)

    Raises
    ------
    ValueError
        If T_in or T_out is < 1, or the sequence has fewer than T_in + T_out
        frames. Plain slicing would otherwise silently return a short target.
    """
    if T_in < 1 or T_out < 1:
        raise ValueError(f"T_in and T_out must be >= 1, got T_in={T_in}, T_out={T_out}")
    n_frames = vil_THW.shape[0]
    if n_frames < T_in + T_out:
        raise ValueError(
            f"sequence has {n_frames} frames but T_in + T_out = {T_in + T_out} are required"
        )
    X = vil_THW[:T_in]
    Y = vil_THW[T_in:T_in + T_out]
    return X, Y

In [None]:
#Train/val split by storm ID (prevents leakage)
VAL_FRACTION = 0.2
SPLIT_SEED = 42
ids_train, ids_val = train_test_split(
    storm_ids_all, test_size=VAL_FRACTION, random_state=SPLIT_SEED
)
print("Train storms:", len(ids_train), "Val storms:", len(ids_val))


In [None]:
# DataLoaders (single-process quick version; Cell 6 below rebuilds them with worker processes).
# Tip: 384x384 is big. Start small (batch_size=4/8). Increase later if stable.
# NOTE: `Task1StormDataset` is defined in the cell *after* this one. When running
# the notebook top to bottom, execute that cell before this one.

T_IN, T_OUT = 12, 12


def _compute_train_stats(h5_path, storm_ids, n_frames):
    """Return (mean, std) of VIL over the first `n_frames` frames of `storm_ids`.

    Streams one storm at a time into float64 running sums, so the full training
    set never has to be held in memory.
    """
    count = 0
    total = 0.0
    total_sq = 0.0
    with h5py.File(h5_path, "r") as f:
        for sid in storm_ids:
            vil = f[f"{sid}/vil"][:, :, :n_frames].astype(np.float64)
            count += vil.size
            total += float(vil.sum())
            total_sq += float(np.square(vil).sum())
    if count == 0:
        raise ValueError("no training frames found; cannot compute normalisation stats")
    mean = total / count
    # Clamp tiny negative values caused by floating-point round-off.
    var = max(total_sq / count - mean * mean, 0.0)
    return mean, var ** 0.5


# Normalisation statistics come from the TRAIN storms only, so nothing about the
# validation storms leaks into the inputs. Values defined earlier are reused.
if "train_mean" not in globals() or "train_std" not in globals():
    train_mean, train_std = _compute_train_stats(H5_PATH, ids_train, T_IN + T_OUT)
print(f"Train VIL mean={train_mean:.4f}  std={train_std:.4f}")

use_pin_memory = torch.cuda.is_available()

train_ds = Task1StormDataset(H5_PATH, ids_train, train_mean, train_std, T_in=T_IN, T_out=T_OUT)
val_ds   = Task1StormDataset(H5_PATH, ids_val,   train_mean, train_std, T_in=T_IN, T_out=T_OUT)

train_loader = DataLoader(train_ds, batch_size=4, shuffle=True,  num_workers=0, pin_memory=use_pin_memory)
val_loader   = DataLoader(val_ds,   batch_size=4, shuffle=False, num_workers=0, pin_memory=use_pin_memory)



In [None]:
# PyTorch Dataset (VIL only, Z-normalized) — outputs (T,1,H,W)
#    Key points:
#      - reads only (T_in + T_out) frames (24) to save RAM
#      - normalizes BOTH X and Y using TRAIN mean/std
#      - lazily opens the HDF5 file once per process (faster than reopening in
#        every __getitem__, and safe with DataLoader worker processes)
class Task1StormDataset(Dataset):
    """
    Returns (X, Y, storm_id) for Task 1.

    Reads the first ``T_in + T_out`` frames of ``<storm_id>/vil`` (stored as
    (H, W, T)), moves time to the front, z-normalises with the given mean/std
    and adds a channel axis.

    Shapes:
      X: (T_in,  1, H, W)  float32
      Y: (T_out, 1, H, W)  float32

    The HDF5 handle is opened lazily and tracked per process. A handle
    inherited by a forked DataLoader worker is not reused, and the handle is
    dropped when the dataset is pickled (spawn-based workers).
    """
    def __init__(self, h5_path: str, storm_ids, mean: float, std: float, T_in: int = 12, T_out: int = 12):
        self.h5_path = h5_path
        self.storm_ids = list(storm_ids)

        self.mean = float(mean)
        self.std = float(std) + 1e-6  # avoid divide-by-zero

        self.T_in = int(T_in)
        self.T_out = int(T_out)

        self._file = None      # lazy open
        self._file_pid = None  # PID of the process that opened self._file

    def __len__(self):
        return len(self.storm_ids)

    def __getstate__(self):
        # h5py File objects cannot be pickled; the receiving process reopens on first use.
        state = self.__dict__.copy()
        state["_file"] = None
        state["_file_pid"] = None
        return state

    def _get_file(self):
        pid = os.getpid()
        if self._file is None or self._file_pid != pid:
            # A handle opened by a parent process (inherited via fork) is left
            # alone rather than reused or closed from here.
            self._file = h5py.File(self.h5_path, "r")
            self._file_pid = pid
        return self._file

    def close(self):
        """Close the HDF5 handle if this process opened it, then forget it."""
        if self._file is not None and self._file_pid == os.getpid():
            self._file.close()
        self._file = None
        self._file_pid = None

    def __getitem__(self, idx):
        sid = self.storm_ids[idx]
        f = self._get_file()

        # load ONLY frames we need: first (T_in + T_out)
        T_needed = self.T_in + self.T_out
        vil_hwT = f[f"{sid}/vil"][:, :, :T_needed].astype(np.float32)  # (H,W,T_needed)
        if vil_hwT.shape[2] < T_needed:
            # Slicing past the end would otherwise yield a silently shorter target.
            raise ValueError(
                f"storm {sid!r} has {vil_hwT.shape[2]} frames; need {T_needed}"
            )

        vil_THW = np.transpose(vil_hwT, (2, 0, 1))  # (T_needed,H,W)

        X_THW = vil_THW[:self.T_in]                           # (T_in,H,W)
        Y_THW = vil_THW[self.T_in:self.T_in + self.T_out]     # (T_out,H,W)

        # Z-normalize using TRAIN stats
        X_THW = (X_THW - self.mean) / self.std
        Y_THW = (Y_THW - self.mean) / self.std

        # Convert to torch tensors and add channel dim -> (T,1,H,W)
        X = torch.from_numpy(np.ascontiguousarray(X_THW)).unsqueeze(1)
        Y = torch.from_numpy(np.ascontiguousarray(Y_THW)).unsqueeze(1)

        return X, Y, sid

In [None]:
# Cell 5B: MovingMNIST demo (optional)
# Comment out if you use Task1Dataset

# from data.mm import MovingMNIST
# train_ds = MovingMNIST(is_train=True, root="data/", n_frames_input=10, n_frames_output=10, num_objects=[3])
# val_ds   = MovingMNIST(is_train=False, root="data/", n_frames_input=10, n_frames_output=10, num_objects=[3])


In [None]:
# Cell 6: DataLoader (keep variable names)
# Settings shared by both loaders. Each worker process reads HDF5 data through
# the dataset, so keep num_workers modest if the file system is slow.
_loader_kwargs = dict(
    batch_size=4,
    num_workers=4,
    pin_memory=torch.cuda.is_available(),
)
train_loader = DataLoader(train_ds, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_ds, shuffle=False, **_loader_kwargs)
print("train batches:", len(train_loader), "val batches:", len(val_loader))


In [None]:
# Cell 7: Loss & optimizer (L1 + SSIM)
loss_fn = L1SSIMLoss(ssim_win_size=11, alpha=0.8)  # 80% L1, 20% (1 - SSIM)
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)


In [None]:
# Cell 8: Train / Eval loops (history for plots)
def _ensure_tensor(x):
    """Return the prediction tensor from a model output.

    Some repos return (pred, states) etc.; for a list/tuple output only the
    first element is kept, anything else is returned unchanged.
    """
    # No @torch.no_grad(): this is plain Python indexing (no tensor ops), and
    # the helper is also used on the training path.
    if isinstance(x, (list, tuple)):
        return x[0]
    return x


def _unpack_batch(batch):
    """Return (inputs, targets) from a DataLoader batch.

    Task1StormDataset yields (X, Y, storm_id); any trailing items such as the
    storm IDs are ignored, so plain (X, Y) datasets work as well.
    """
    return batch[0], batch[1]


def train_one_epoch(model, loader, optimizer, loss_fn, device):
    """Run one optimisation pass over `loader`.

    Returns the mean loss per sample; 0.0 for an empty loader.
    """
    model.train()
    total = 0.0
    n_samples = 0
    for batch in loader:
        inputs, targets = _unpack_batch(batch)
        inputs  = inputs.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        optimizer.zero_grad()
        outputs = _ensure_tensor(model(inputs))
        loss = loss_fn(outputs, targets)
        loss.backward()
        optimizer.step()

        # Weight by batch size so a smaller final batch does not skew the epoch mean.
        bs = inputs.shape[0]
        total += loss.item() * bs
        n_samples += bs
    return total / max(1, n_samples)


@torch.no_grad()
def eval_one_epoch(model, loader, loss_fn, device):
    """Evaluate `model` on `loader` without gradients.

    Returns the mean loss per sample; 0.0 for an empty loader.
    """
    model.eval()
    total = 0.0
    n_samples = 0
    for batch in loader:
        inputs, targets = _unpack_batch(batch)
        inputs  = inputs.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        outputs = _ensure_tensor(model(inputs))
        loss = loss_fn(outputs, targets)

        bs = inputs.shape[0]
        total += loss.item() * bs
        n_samples += bs
    return total / max(1, n_samples)


In [None]:
# Cell 9: Fit
history = {"train_loss": [], "val_loss": []}
num_epochs = 20
save_path = "checkpoints/best_convlstm_ed.pt"
best_val = float("inf")

ckpt_dir = os.path.dirname(save_path)
os.makedirs(ckpt_dir, exist_ok=True)

for epoch in range(1, num_epochs + 1):
    tr = train_one_epoch(model, train_loader, optimizer, loss_fn, device)
    va = eval_one_epoch(model, val_loader, loss_fn, device)
    for key, value in (("train_loss", tr), ("val_loss", va)):
        history[key].append(value)

    # Checkpoint only when validation loss strictly improves.
    if va < best_val:
        best_val = va
        torch.save(model.state_dict(), save_path)

    print(f"Epoch {epoch:03d}/{num_epochs:03d} | train={tr:.6f} | val={va:.6f} | best_val={best_val:.6f}")


In [None]:
# Cell 10: loss-epoch plot
import matplotlib.pyplot as plt

# The fit loop numbers epochs from 1, so label the x-axis the same way
# (plotting the bare lists would start at 0).
train_epochs = range(1, len(history["train_loss"]) + 1)
val_epochs = range(1, len(history["val_loss"]) + 1)

plt.figure()
plt.plot(train_epochs, history["train_loss"], label="train")
plt.plot(val_epochs, history["val_loss"], label="val")
plt.xlabel("Epoch")
plt.ylabel("Loss (L1 + SSIM)")
plt.legend()
plt.show()


In [None]:
# Cell 11: Prediction visualization (GT vs Pred frames)
import numpy as np
import matplotlib.pyplot as plt

@torch.no_grad()
def visualize_prediction(model, loader, device, channel_idx=0, max_frames=12):
    """Plot ground-truth vs predicted frames for the first sample of one batch.

    The top row shows target frames and the bottom row shows predictions, for
    up to `max_frames` time steps of channel `channel_idx`. The channel index
    is clamped to the last available channel. Both rows share one colour scale
    so intensities can be compared directly.

    Assumes the model output and targets are laid out as (B, T, C, H, W) —
    TODO confirm for the model in use.

    Raises:
        ValueError: if no frames would be shown (max_frames < 1 or empty sequences).
    """
    model.eval()
    batch = next(iter(loader))
    # Task1StormDataset batches are (X, Y, storm_ids); keep only X and Y.
    inputs, targets = batch[0], batch[1]
    inputs  = inputs.to(device)
    targets = targets.to(device)

    preds = _ensure_tensor(model(inputs))

    # pick first sample: (T,C,H,W)
    pred0 = preds[0]
    tgt0  = targets[0]

    # Limit to frames present in BOTH prediction and target to avoid indexing past either.
    T = min(int(pred0.shape[0]), int(tgt0.shape[0]), int(max_frames))
    if T < 1:
        raise ValueError(f"nothing to plot: T={T} (check max_frames and sequence lengths)")
    C = int(pred0.shape[1])
    ch = min(int(channel_idx), C - 1)

    pred_seq = pred0[:T, ch].detach().cpu().numpy()
    tgt_seq  = tgt0[:T,  ch].detach().cpu().numpy()

    # Shared colour limits; imshow would otherwise autoscale each panel on its own.
    vmin = float(min(tgt_seq.min(), pred_seq.min()))
    vmax = float(max(tgt_seq.max(), pred_seq.max()))

    plt.figure(figsize=(2 * T, 4))
    for t in range(T):
        plt.subplot(2, T, t + 1)
        plt.imshow(tgt_seq[t], vmin=vmin, vmax=vmax)
        plt.axis("off")
        plt.title(f"GT t={t}")

        plt.subplot(2, T, T + t + 1)
        plt.imshow(pred_seq[t], vmin=vmin, vmax=vmax)
        plt.axis("off")
        plt.title(f"Pred t={t}")

    plt.tight_layout()
    plt.show()

# For your dataset: if C==1, channel_idx must be 0
visualize_prediction(model, val_loader, device, channel_idx=0, max_frames=12)
