# Training the model
The goal of this notebook is to train a 1D ResNet classifier on the preprocessed CMI 2025 dataset and upload it to Kaggle.

## Imports

In [3]:
from os.path import join
from typing import Optional

import torch
import numpy as np
from torch import nn, Tensor
from kagglehub import dataset_download
from torch.utils.data import Dataset, DataLoader

In [34]:
# TODO: Switch to TensorDataset w/ cross validation splits
class CMIDataset(Dataset):
    """Memory-mapped view over the preprocessed CMI 2025 arrays downloaded from Kaggle.

    Args:
        use_agg_tof: when truthy, features come from ``tof_meaned_X.npy``; otherwise from ``X.npy``.
        subset: if not None, keep only the first ``subset`` samples.
        force_download: forwarded to ``kagglehub.dataset_download``.
    """

    def __init__(self, use_agg_tof:bool, subset:Optional[int]=None, force_download=False):
        super().__init__()
        root = dataset_download("mauroabidalcarrer/prepocessed-cmi-2025", force_download=force_download)
        if use_agg_tof:
            features_file = "tof_meaned_X.npy"
        else:
            features_file = "X.npy"
        # mmap_mode="r" keeps the arrays on disk; axes 1 and 2 are swapped, presumably
        # (time, channels) -> (channels, time) for Conv1d — TODO confirm against the dataset.
        features = np.load(join(root, features_file), mmap_mode="r").swapaxes(1, 2)
        targets = np.load(join(root, "Y.npy"), mmap_mode="r")
        if subset is not None:
            features, targets = features[:subset], targets[:subset]
        self.x = features
        self.y = targets

    def __len__(self):
        return self.x.shape[0]

    def __getitem__(self, idx):
        # Copy out of the memory map so each sample is an owned, writable array.
        sample = self.x[idx].copy()
        target = self.y[idx].copy()
        return sample, target
    
# The first positional parameter of CMIDataset is the `use_agg_tof` flag (a bare `100` here
# bound to it, not to `subset`). use_agg_tof=True loads tof_meaned_X.npy; pass subset=N to
# keep only the first N samples.
dataset = CMIDataset(use_agg_tof=True, force_download=False)
data_loader = DataLoader(dataset, batch_size=64, shuffle=True)

In [27]:
from itertools import pairwise

class ResidualBlock(nn.Module):
    """Conv1d-BN-ReLU-Conv1d-BN branch added to a skip path, followed by a ReLU.

    Both convolutions use kernel_size=3 with padding=1, so the sequence length is kept.
    The skip path is the identity when ``in_chns == out_chns``; otherwise it is a
    1x1 Conv1d + BatchNorm projection to ``out_chns`` channels.
    """

    def __init__(self, in_chns:int, out_chns:int):
        super().__init__()
        main_branch = [
            nn.Conv1d(in_chns, out_chns, kernel_size=3, padding=1),
            nn.BatchNorm1d(out_chns),
            nn.ReLU(),
            nn.Conv1d(out_chns, out_chns, kernel_size=3, padding=1),
            nn.BatchNorm1d(out_chns),
        ]
        self.blocks = nn.Sequential(*main_branch)
        if in_chns != out_chns:
            # TODO: set bias to False ? (the conv is immediately followed by BatchNorm)
            self.skip_connection = nn.Sequential(
                nn.Conv1d(in_chns, out_chns, 1),
                nn.BatchNorm1d(out_chns),
            )
        else:
            self.skip_connection = nn.Identity()

    def forward(self, x:Tensor) -> Tensor:
        residual = self.skip_connection(x)
        summed = residual + self.blocks(x)
        return nn.functional.relu(summed)

class Resnet(nn.Module):
    def __init__(
            self,
            in_channels:int,
            depth:int,
            # n_res_block_per_depth:int,
            hidden_mlp_width:int,
            n_class:int,
        ):
        super().__init__()
        chs_per_depth = [in_channels * 2 ** i for i in range(depth)]
        blocks_chns_it = pairwise(chs_per_depth)
        self.res_blocks = [ResidualBlock(in_chns, out_chns) for in_chns, out_chns in blocks_chns_it]
        self.res_blocks = nn.ModuleList(self.res_blocks)
        self.mlp_head = nn.Sequential(
            nn.LazyLinear(hidden_mlp_width),
            nn.ReLU(),
            nn.Linear(hidden_mlp_width, n_class),
            nn.Softmax(),
        )
        
        
    def forward(self, x:Tensor) -> Tensor:
        activation_maps = x
        for res_block in self.res_blocks:
            activation_maps = nn.functional.max_pool1d(res_block(activation_maps), 2)
        out = activation_maps.view(activation_maps.shape[0], -1)
        out = self.mlp_head(out)
        return out

# 17 input channels, 18 classes; depth=4 gives 3 residual blocks (17 -> 34 -> 68 -> 136 channels).
model = Resnet(in_channels=17, depth=4, hidden_mlp_width=256, n_class=18)

In [6]:
# Display the module tree to check the channel progression of the residual blocks.
model

Resnet(
  (res_blocks): ModuleList(
    (0): ResidualBlock(
      (blocks): Sequential(
        (0): Conv1d(17, 34, kernel_size=(3,), stride=(1,), padding=(1,))
        (1): BatchNorm1d(34, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
        (3): Conv1d(34, 34, kernel_size=(3,), stride=(1,), padding=(1,))
        (4): BatchNorm1d(34, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (skip_connection): Sequential(
        (0): Conv1d(17, 34, kernel_size=(1,), stride=(1,))
        (1): BatchNorm1d(34, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): ResidualBlock(
      (blocks): Sequential(
        (0): Conv1d(34, 68, kernel_size=(3,), stride=(1,), padding=(1,))
        (1): BatchNorm1d(68, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
        (3): Conv1d(68, 68, kernel_size=(3,), stride=(1,), padding=(1,))
        (4): BatchNorm1d(68, eps=1e-05

In [22]:
# Pull a single shuffled batch to sanity-check tensor shapes.
x, y = next(iter(data_loader))

In [23]:
# Batch layout after CMIDataset's swapaxes(1, 2) — presumably (batch, channels, time); TODO confirm.
x.shape

torch.Size([128, 17, 114])

In [28]:
# Forward the sample batch; the head's last Linear maps to n_class outputs, so expect (batch, 18).
y_pred = model(x)
y_pred.shape

  return self._call_impl(*args, **kwargs)


torch.Size([128, 18])

In [14]:
# Size of the flattened features entering the MLP head: 136 channels (17 * 2**3, the last stage for
# depth=4) times 14 time steps (114 halved by three max_pool1d(2) calls: 114 -> 57 -> 28 -> 14).
136 * 14

1904

In [None]:
import os
from datetime import datetime

import torch
from torch import nn
import plotly.express as px
from rich.progress import Progress, Task, track
from pandas import DataFrame as DF
from torch.utils.data import DataLoader as DL
from torch.optim.lr_scheduler import LRScheduler


def fit(epochs:int,
        model: nn.Module,
        scheduler: LRScheduler,
        optimizer: torch.optim.Optimizer,
        train_loader: DL,
        criterion: callable=nn.L1Loss(),
        evaluation_func: callable=None,
        validation_loader: DL=None,
        save_checkpoints=True,
    ) -> DF:
    """Train ``model`` for ``epochs`` epochs and return per-step metrics.

    Every training step records ``step``, ``epoch``, ``batch_train_loss`` and ``lr``.
    After each epoch, the last record of that epoch also gets ``train_epoch_loss``
    and, when ``evaluation_func`` is given, the entries of the dict it returns
    (which must contain ``validation_loss``). When ``save_checkpoints`` is True, a
    checkpoint is written after every epoch under ``checkpoints/<date>/``, flagged
    as best when the epoch's ``validation_loss`` (or ``train_epoch_loss`` without
    validation) is the lowest seen so far.

    Args:
        epochs: number of passes over ``train_loader``.
        model: module to train; batches are moved to the device of its first parameter.
        scheduler: LR scheduler, stepped after every optimizer step except the first.
        optimizer: optimizer over ``model``'s parameters.
        train_loader: yields ``(x, y)`` batches.
        criterion: loss function called as ``criterion(y_pred, y)``.
        evaluation_func: optional ``evaluation_func(model, criterion, validation_loader) -> dict``.
        validation_loader: forwarded to ``evaluation_func``.
        save_checkpoints: whether to write checkpoints to disk.

    Returns:
        DataFrame with one row per training step. Training stops early, returning the
        metrics gathered so far, if the loss becomes NaN or infinite.
    """
    # Setup
    metrics: list[dict] = []
    checkpoints_dir = os.path.join("checkpoints", mk_date_now_str())
    step = 0
    model_device = next(model.parameters()).device
    # Training loop
    with Progress() as progress:
        # add_task() requires a description; the bar is reset at the start of every epoch
        # so it counts the batches of the current epoch only.
        task = progress.add_task("epoch: 0", total=len(train_loader))
        for epoch in range(epochs):
            progress.reset(task, total=len(train_loader), description=f"epoch: {epoch}")
            total_epoch_loss = 0
            n_batches = 0
            for x, y in train_loader:
                # forward
                x = x.to(model_device)
                y = y.to(model_device)
                model.train()
                optimizer.zero_grad()
                y_pred = model(x)
                loss_value = criterion(y_pred, y)
                # Verify loss value
                if torch.isnan(loss_value).any().item():
                    progress.print("Warning: Got NaN loss, something went wrong.")
                    return DF.from_records(metrics)
                if torch.isinf(loss_value).any().item():
                    progress.print("Warning: Got infinite loss, something went wrong.")
                    return DF.from_records(metrics)
                # TODO: Use gradient clipping?
                loss_value.backward()
                optimizer.step()
                if step > 0: # If it's not the first training step
                    # Kept from earlier experiments, which reported an error when the scheduler
                    # was stepped on the very first step (cause not identified).
                    # NOTE(review): the scheduler is therefore stepped one fewer time than the
                    # optimizer — account for it when sizing schedules.
                    scheduler.step()
                # metrics
                batch_loss = loss_value.item()
                total_epoch_loss += batch_loss
                n_batches += 1
                metrics.append({
                    "step": step,
                    "epoch": epoch,
                    "batch_train_loss": batch_loss,
                    # Read the live param groups instead of building a full state_dict() per step.
                    "lr": optimizer.param_groups[-1]["lr"],
                })
                step += 1
                progress.update(
                    task,
                    advance=1,
                    description=f"epoch: {epoch}, batch_loss: {(total_epoch_loss / n_batches):.2f}"
                )
            if n_batches == 0:
                # Empty loader: nothing to evaluate or checkpoint for this epoch.
                continue
            # Post epoch evaluation (inside the epoch loop so it runs after every epoch)
            metrics[-1]["train_epoch_loss"] = total_epoch_loss / n_batches
            if evaluation_func:
                eval_metrics = evaluation_func(model, criterion, validation_loader)
                progress.print("validation loss:", eval_metrics["validation_loss"])
                metrics[-1].update(eval_metrics)
            # Save checkpoint
            if save_checkpoints:
                checkpoint = mk_checkpoint(epoch, model, scheduler, optimizer)
                save_checkpoint(checkpoint, checkpoints_dir, _is_best_checkpoint(metrics))

    return DF.from_records(metrics)

def _is_best_checkpoint(metrics: list[dict]) -> bool:
    """Return True if the metric on the latest record is the lowest recorded so far.

    Uses ``validation_loss`` when any record has it, otherwise ``train_epoch_loss``.
    NaN values are ignored by the minimum; a NaN latest value is never "best".
    """
    metrics_df = DF.from_records(metrics)
    best_model_metric = "validation_loss" if "validation_loss" in metrics_df.columns else "train_epoch_loss"
    history = metrics_df[best_model_metric]
    return bool(history.iloc[-1] == history.min())

def evaluate_model(model: torch.nn.Module, critirion:callable, validation_loader:DL) -> dict:
    """Return the mean per-batch loss of ``model`` over ``validation_loader``.

    Switches ``model`` to eval mode (and leaves it there) and runs without autograd.

    Args:
        model: module to evaluate; batches are moved to the device of its first parameter.
        critirion: loss function called as ``critirion(y_pred, y)``.
        validation_loader: iterable of ``(x, y)`` batches supporting ``len()``.

    Returns:
        ``{"validation_loss": <mean of the per-batch losses>}``.
    """
    model = model.eval()
    model_device = next(model.parameters()).device

    total_test_loss = 0
    # No gradients are needed here; without no_grad every batch would build an autograd graph.
    with torch.no_grad():
        for x, y in track(validation_loader, description="Evaluating...", transient=True):
            x = x.to(model_device)
            y = y.to(model_device)
            y_pred = model(x)
            total_test_loss += critirion(y_pred, y).item()
    total_test_loss /= len(validation_loader)

    return {"validation_loss": total_test_loss}

def mk_checkpoint(
        epoch:int,
        model: torch.nn.Module,
        scheduler: LRScheduler,
        optimizer: torch.optim.Optimizer
    ) -> dict:
    """Bundle the state needed to resume training into a dict suitable for ``torch.save``.

    Keys, in order: ``epoch`` (as given), then the ``state_dict()`` of ``model``,
    ``optimizer`` and ``scheduler``.
    """
    checkpoint = dict(epoch=epoch)
    checkpoint["model"] = model.state_dict()
    checkpoint["optimizer"] = optimizer.state_dict()
    checkpoint["scheduler"] = scheduler.state_dict()
    return checkpoint

def save_checkpoint(checkpoint:dict, parent_dir:str, is_best_checkpoint=False):
    """Write ``checkpoint`` into ``parent_dir`` and refresh the convenience symlinks.

    The file is named ``<date>.pth``, or ``<date>_epoch-<n>.pth`` when the checkpoint
    dict has an ``"epoch"`` key, so that epochs finishing within the same minute do not
    overwrite each other (the date string has minute resolution).
    ``latest_checkpoint.pth`` always points to the new file; ``best_checkpoint.pth``
    does too when ``is_best_checkpoint`` is truthy.

    Args:
        checkpoint: object passed to ``torch.save`` (typically built by ``mk_checkpoint``).
        parent_dir: directory to write into; created if missing.
        is_best_checkpoint: whether to also repoint ``best_checkpoint.pth``.
    """
    # Create model name
    checkpoint_filename = mk_date_now_str()
    if isinstance(checkpoint, dict) and "epoch" in checkpoint:
        checkpoint_filename += f"_epoch-{checkpoint['epoch']}"
    checkpoint_filename += ".pth"
    # Save model
    os.makedirs(parent_dir, exist_ok=True)
    torch.save(checkpoint, os.path.join(parent_dir, checkpoint_filename))
    # A relative symlink target is resolved against the link's own directory. Link and file
    # both live in parent_dir, so the bare filename is the correct target; the joined path
    # would double the prefix for a relative parent_dir and leave a dangling link.
    mk_symlink(checkpoint_filename, os.path.join(parent_dir, "latest_checkpoint.pth"))
    if is_best_checkpoint:
        mk_symlink(checkpoint_filename, os.path.join(parent_dir, "best_checkpoint.pth"))

def mk_date_now_str() -> str:
    """Return the current local time as ``DD-MM-YYYY_HH-MM`` (minute resolution)."""
    now = datetime.now()
    return f"{now:%d-%m-%Y_%H-%M}"

def mk_symlink(dest:str, symlink_path:str):
    """Create, or replace, ``symlink_path`` as a symlink whose target is ``dest``.

    Any existing entry at ``symlink_path`` (including a dangling symlink) is removed
    first. ``dest`` is stored verbatim, so a relative target resolves against the
    directory containing the link.
    """
    # lexists() is true for existing paths and for symlinks, even broken ones.
    if os.path.lexists(symlink_path):
        os.remove(symlink_path)
    os.symlink(dest, symlink_path)

In [47]:
TRAINING_EPOCHS = 30
STARTING_LR = 0.0005
SEED = 42
# Seed torch so weight initialisation and DataLoader shuffling are reproducible.
torch.manual_seed(SEED)
model = Resnet(17, 4, 256, 18)
# Dry run to materialise the LazyLinear head before its parameters are handed to the optimizer.
# eval() + no_grad() keep the BatchNorm running statistics untouched by this extra batch.
model.eval()
with torch.no_grad():
    model(next(iter(data_loader))[0])
model.train()
optimizer = torch.optim.AdamW(model.parameters(), STARTING_LR)
# factor=1 keeps the learning rate constant for the whole run.
constant_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1, total_iters=len(data_loader) * TRAINING_EPOCHS)
training_metrics = fit(TRAINING_EPOCHS, model, constant_lr_scheduler, optimizer, data_loader, nn.CrossEntropyLoss())

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

In [40]:
# One row per optimizer step, as recorded by `fit`.
training_metrics

Unnamed: 0,step,epoch,batch_train_loss,lr,train_epoch_loss
0,0,0,2.890818,0.001,
1,1,0,2.873779,0.001,
2,2,0,2.846834,0.001,
3,3,0,2.824693,0.001,
4,4,0,2.907222,0.001,
...,...,...,...,...,...
3835,3835,29,2.658956,0.001,
3836,3836,29,2.715559,0.001,
3837,3837,29,2.689245,0.001,
3838,3838,29,2.735331,0.001,


In [49]:
import plotly.express as px

# Raw batch losses are noisy; the red EWM trendline (com=60) shows the overall trend.
px.scatter(
    training_metrics,
    x="step",
    y="batch_train_loss",
    trendline="ewm",
    trendline_options={"com": 60},
    trendline_color_override="red",
    title="Training loss per batch (red: exponentially weighted moving average)",
    labels={"step": "training step", "batch_train_loss": "batch train loss"},
)