# Training model
The goal of this notebook is to train a model and upload it to kaggle.

## Setup

### Imports

In [1]:
import os
from glob import glob
from datetime import datetime
from itertools import pairwise
from os.path import join, realpath
from typing import Optional, Literal

import torch
import numpy as np
import pandas as pd
import plotly.express as px
from torch import nn, Tensor
from pandas import DataFrame as DF
from torch.utils.data import Dataset
from torch.utils.data import DataLoader as DL
from rich.progress import Progress, Task, track
from torch.optim.lr_scheduler import ConstantLR, LRScheduler
from kagglehub import dataset_download, model_upload, whoami

### Dataset Setup

In [2]:
BATCH_SIZE = 256
dataset_path = dataset_download(
    handle="mauroabidalcarrer/prepocessed-cmi-2025",
)

In [3]:
from torch.utils.data import TensorDataset

class CMIDataset(TensorDataset):
    def __init__(
        self,
        parent_dir: str,
        split: Optional[Literal["train", "validation"]]=None,
        subset: Optional[int]=None,
        force_download=False
    ):
        dataset_path = dataset_download(
            handle="mauroabidalcarrer/prepocessed-cmi-2025",
            force_download=force_download
        )
        parent_dir = join(dataset_path, "preprocessed_dataset", parent_dir)
        split = "" if split is None else split + "_"
        x = np.load(join(parent_dir, f"{split}X.npy")).swapaxes(1, 2)
        y = np.load(join(parent_dir, f"{split}Y.npy"))
        if subset is not None:
            x = x[:subset]
            y = y[:subset]
        super().__init__(
            torch.from_numpy(x), 
            torch.from_numpy(y),
        )


### device setup

In [4]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

## Model definition

In [5]:
class ResidualBlock(nn.Module):
    def __init__(self, in_chns:int, out_chns:int):
        super().__init__()
        self.blocks = nn.Sequential(
            nn.Conv1d(in_chns, out_chns, kernel_size=3, padding=1),
            nn.BatchNorm1d(out_chns),
            nn.ReLU(),
            nn.Conv1d(out_chns, out_chns, kernel_size=3, padding=1),
            nn.BatchNorm1d(out_chns),
        )
        if in_chns == out_chns:
            self.skip_connection = nn.Identity() 
        else:
            # TODO: set bias to False ?
            self.skip_connection = nn.Sequential(
                nn.Conv1d(in_chns, out_chns, 1),
                nn.BatchNorm1d(out_chns)
            )

    def forward(self, x:Tensor) -> Tensor:
        activaition_maps = self.skip_connection(x) + self.blocks(x)
        return nn.functional.relu(activaition_maps)

class Resnet(nn.Module):
    def __init__(
            self,
            in_channels:int,
            depth:int,
            # n_res_block_per_depth:int,
            mlp_width:int,
            n_class:int,
        ):
        super().__init__()
        chs_per_depth = [in_channels * 2 ** i for i in range(depth)]
        blocks_chns_it = pairwise(chs_per_depth)
        self.res_blocks = [ResidualBlock(in_chns, out_chns) for in_chns, out_chns in blocks_chns_it]
        self.res_blocks = nn.ModuleList(self.res_blocks)
        self.mlp_head = nn.Sequential(
            nn.LazyLinear(mlp_width),
            nn.ReLU(),
            nn.Linear(mlp_width, n_class),
            nn.Softmax(dim=1),
        )
        
        
    def forward(self, x:Tensor) -> Tensor:
        activation_maps = x
        for res_block in self.res_blocks:
            activation_maps = nn.functional.max_pool1d(res_block(activation_maps), 2)
        out = activation_maps.view(activation_maps.shape[0], -1)
        out = self.mlp_head(out)
        return out

## Training functions 

### Training

In [None]:
def fit(epochs:int,
        model: nn.Module,
        scheduler: LRScheduler,
        optimizer: torch.optim.Optimizer,
        train_loader: DL,
        criterion: callable=nn.L1Loss(),
        evaluation_func: callable=None,
        validation_loader: DL=None,
        save_checkpoints=True,
    ) -> tuple[DF, str]:
    """
    Returns:
        (training_metrics, path_to_checkpoints)
    """
    # Setup
    metrics: list[dict] = []
    checkpoints_dir = os.path.join("checkpoints", mk_date_now_str())
    step = 0
    model_device = next(model.parameters()).device
    last_epoch_metric = {}
    # Training loop
    with Progress() as progress:
        task: Task = progress.add_task(
            "training...",
            total=len(train_loader),
        )
        for epoch in range(epochs):
            progress.update(
                task,
                description=f"epoch: {epoch}",
                completed=0,
            )
            total_epoch_loss = 0
            total_accuracy = 0
            for batch_idx, (x, y) in enumerate(train_loader):
                # forward
                x = x.to(model_device)
                y = y.to(model_device)
                model.train()
                optimizer.zero_grad()
                y_pred: Tensor = model(x)
                loss_value = criterion(y_pred, y)
                # Verify loss value
                if torch.isnan(loss_value).any().item():
                    progress.print("Warning: Got NaN loss, stopped training.")
                    return DF.from_records(metrics) 
                if torch.isinf(loss_value).any().item():
                    progress.print("Warning: Got infinite loss, stopped training.")
                    return DF.from_records(metrics) 
                # TODO: Use gradient clipping?
                loss_value.backward()
                optimizer.step()
                if step > 0: # If it's not the first training step, idk why it throws an error otherwise
                    scheduler.step()
                # metrics
                total_epoch_loss += loss_value.item()
                batch_accuracy = (
                    (y.max(dim=1)[1] == y_pred.max(dim=1)[1])
                    .sum()
                    .item()
                ) / y.shape[0]
                total_accuracy += batch_accuracy
                metrics.append({
                    "step": step,
                    "epoch": epoch,
                    "batch_train_loss": loss_value.item(),
                    "lr": optimizer.state_dict()["param_groups"][-1]["lr"],
                    "batch_accuracy": batch_accuracy,
                })
                step += 1
                if "validation_accuracy" in last_epoch_metric:
                    last_validation_acc = "%.2f" % last_epoch_metric["validation_accuracy"]
                else:
                    last_validation_acc = None
                progress.update(
                    task,
                    advance=1,
                    description=f"epoch: {epoch}, batch_loss: {(total_epoch_loss / (batch_idx+1)):.2f}, val. acc: {last_validation_acc}"
                )
            # Post epoch evalution
            metrics[-1]["train_epoch_loss"] = total_epoch_loss / len(train_loader)
            metrics[-1]["train_epoch_accuracy"] = total_accuracy / len(train_loader)
            if evaluation_func:
                progress.update(
                    task,
                    completed=0,
                    description=f"epoch: {epoch}, evaluating..."
                )
                eval_metrics = evaluation_func(model, criterion, validation_loader)
                metrics[-1].update(eval_metrics)
            last_epoch_metric = metrics[-1]
            # Save checkpoint
            if save_checkpoints:
                checkpoint = mk_checkpoint(epoch, model, scheduler, optimizer)
                metrics_df = DF.from_records(metrics)
                best_model_metric = "validation_loss" if "validation_loss" in metrics_df.columns else "train_epoch_loss"
                is_best_checkpoint = (
                    DF.from_records(metrics)
                    .eval(f"min_{best_model_metric} = {best_model_metric}.min()")
                    .eval(f"is_best_{best_model_metric} = {best_model_metric} == min_{best_model_metric}")
                    .dropna(subset=f"is_best_{best_model_metric}")
                    .iloc[-1]
                    .loc[f"is_best_{best_model_metric}"]
                )
                save_checkpoint(checkpoint, checkpoints_dir, is_best_checkpoint)

    return DF.from_records(metrics), checkpoints_dir

def evaluate_model(
        model: torch.nn.Module,
        critirion:callable,
        validation_loader:DL,
    ) -> dict:
    model = model.eval()
    model_device = next(model.parameters()).device

    total_accuracy = 0
    total_test_loss = 0
    for x, y in validation_loader:
        x = x.to(model_device)
        y = y.to(model_device)
        y_pred = model(x)
        total_test_loss += critirion(y_pred, y).item()
        batch_accuracy = (
            (y.max(dim=1)[1] == y_pred.max(dim=1)[1])
            .sum()
            .item()
        ) / y.shape[0]
        total_accuracy += batch_accuracy

    return {
        "validation_loss": total_test_loss / len(validation_loader),
        "validation_accuracy": total_accuracy / len(validation_loader),
    }

def mk_checkpoint(
        epoch:int,
        model: torch.nn.Module,
        scheduler: LRScheduler,
        optimizer: torch.optim.Optimizer
    ) -> dict:
    return {
        "epoch": epoch,
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "scheduler": scheduler.state_dict(),
    }

def save_checkpoint(checkpoint:dict, parent_dir:str, is_best_checkpoint=False):
    # Create model name
    checkpoint_filename = f"epoch_{checkpoint['epoch']}_{mk_date_now_str()}.pth"
    # Save model
    os.makedirs(parent_dir, exist_ok=True)
    checkpoint_path = os.path.join(parent_dir, checkpoint_filename)
    torch.save(checkpoint, checkpoint_path)
    mk_symlink(checkpoint_path, os.path.join(parent_dir, "latest_checkpoint.pth"))
    if is_best_checkpoint:
        mk_symlink(checkpoint_path, os.path.join(parent_dir, "best_checkpoint.pth"))

def mk_date_now_str() -> str:
    return datetime.now().strftime("%d-%m-%Y_%H-%M")

def mk_symlink(dest:str, symlink_path:str):
    if os.path.islink(symlink_path) or os.path.exists(symlink_path):
        os.remove(symlink_path)
    os.symlink(dest, symlink_path)

In [15]:
TRAINING_EPOCHS = 2
STARTING_LR = 0.0005
def mk_model() -> nn.Module:
    return (
        Resnet(
            in_channels=17,
            depth=4,
            mlp_width=256,
            n_class=18
        )
    )

def mk_model_and_fit(train_loader:DL, validation_loader:Optional[DL]=None) -> tuple[nn.Module, DF, list[str]]:
    model = mk_model()
    optimizer = torch.optim.AdamW(model.parameters(), STARTING_LR)
    constant_lr_scheduler = ConstantLR(optimizer, factor=1, total_iters=len(train_loader) * TRAINING_EPOCHS)
    training_metrics, model_checkpoints = fit(
        epochs=TRAINING_EPOCHS,
        model=model,
        scheduler=constant_lr_scheduler,
        optimizer=optimizer,
        train_loader=train_loader,
        criterion=nn.CrossEntropyLoss(),
        evaluation_func=evaluate_model if validation_loader else None,
        validation_loader=validation_loader,
        save_checkpoints=True,
    )

    return model, training_metrics, model_checkpoints

### Training metrics plotting

In [8]:
def plt_training_metrics(training_metrics:DF):
    (
        px.scatter(
        (
            training_metrics
            .melt(
                id_vars="step",
                value_vars=[
                    "batch_train_loss",
                    "batch_accuracy",
                ]
            )
        ),
        x="step",
        y="value",
        facet_row="variable",
        trendline="ewm",
        trendline_options={"com": 30},
        trendline_color_override="red",
        title="batch metrics",
        )
        .update_yaxes(matches=None)
        .show()
    )
    if "validation_loss" in training_metrics.columns:
        (
            px.line(
                (
                    training_metrics
                    .query("validation_loss.notna()")
                    .melt(
                        id_vars="step",
                        value_vars=['train_epoch_loss', 'train_epoch_accuracy', 'validation_loss', 'validation_accuracy']
                    )
                ),
                x="step",
                y="value",
                color="variable",
                facet_row="variable",
                title="epoch metrics",
                render_mode="line+marker",
            ).update_yaxes(matches=None)
            .show()
        )    

## Cross validation training

In [9]:
# How many folds do we train the model on ?
NB_CROSS_VALIDATIONS = 2
fold_patterns = join(dataset_path, "preprocessed_dataset", "fold*")
fold_pths = glob(fold_patterns)[:NB_CROSS_VALIDATIONS]
all_training_metrics = {}

for fold_idx, fold_pth in enumerate(fold_pths, 1):
    print("training:", fold_idx)
    train_dataset = CMIDataset(fold_pth, "train")
    print(train_dataset[0][0].shape, train_dataset[0][1].shape)
    train_data_loader = DL(train_dataset, BATCH_SIZE, shuffle=True)
    validation_dataset = CMIDataset(fold_pth, "validation")
    print(validation_dataset[0][0].shape, validation_dataset[0][1].shape)
    validation_data_loader = DL(validation_dataset, BATCH_SIZE, shuffle=False)
    _, training_metrics, _ = mk_model_and_fit(train_data_loader, validation_data_loader)
    all_training_metrics["fold_" + str(fold_idx)] = training_metrics
    plt_training_metrics(training_metrics)
    print("=========================")

all_training_metrics = pd.concat(all_training_metrics)

training: 1
torch.Size([17, 130]) torch.Size([18])
torch.Size([17, 130]) torch.Size([18])


Output()

training: 2
torch.Size([17, 127]) torch.Size([18])


Output()

torch.Size([17, 127]) torch.Size([18])




In [10]:
all_training_metrics

Unnamed: 0,Unnamed: 1,step,epoch,batch_train_loss,lr,batch_accuracy,train_epoch_loss,train_epoch_accuracy,validation_loss,validation_accuracy
fold_1,0,0,0,2.889864,0.0005,0.042969,,,,
fold_1,1,1,0,2.877919,0.0005,0.160156,,,,
fold_1,2,2,0,2.865226,0.0005,0.148438,,,,
fold_1,3,3,0,2.855235,0.0005,0.152344,,,,
fold_1,4,4,0,2.842860,0.0005,0.179688,,,,
...,...,...,...,...,...,...,...,...,...,...
fold_2,47,47,1,2.675174,0.0005,0.316406,,,,
fold_2,48,48,1,2.585952,0.0005,0.398438,,,,
fold_2,49,49,1,2.632515,0.0005,0.355469,,,,
fold_2,50,50,1,2.663874,0.0005,0.316406,,,,


## Full dataset training

In [17]:
full_dataset = CMIDataset("full_dataset")
full_dataset_loader = DL(full_dataset, BATCH_SIZE, shuffle=True)
model, full_train_metrics, checkpoints = mk_model_and_fit(full_dataset_loader)

Output()

## Model upload

In [None]:
handle = f"{whoami()['username']}/cmi-resnet/pyTorch/{mk_date_now_str().replace('_', '-')}"
model_path = realpath(join(checkpoints, "best_checkpoint.pth"))
model_upload(handle, model_path)

Kaggle credentials successfully validated.
Uploading Model https://www.kaggle.com/models/mauroabidalcarrer/cmi-resnet/pyTorch/15-07-2025-15-09 ...


BackendError: Please upload at least one file