# Training model
The goal of this notebook is to train a model and upload it to kaggle.

## Imports

In [1]:
from itertools import pairwise
from os.path import join, realpath
from typing import Optional, Literal

import torch
import numpy as np
import plotly.express as px
from torch import nn, Tensor
from kagglehub import dataset_download, model_upload, whoami
from torch.optim.lr_scheduler import ConstantLR
from torch.utils.data import Dataset, DataLoader

In [2]:
BATCH_SIZE = 256

In [3]:
# TODO: Switch to TensorDataset w/ cross validation splits
class CMIDataset(Dataset):
    def __init__(
        self,
        parent_dir: str,
        split: Optional[Literal["train", "validation"]]=None,
        subset: Optional[int]=None,
        force_download=False
    ):
        super().__init__()
        dataset_path = dataset_download(
            handle="mauroabidalcarrer/prepocessed-cmi-2025",
            force_download=force_download
        )
        parent_dir = join(dataset_path, "preprocessed_dataset", parent_dir)
        split = "" if split is None else split + "_"
        self.x = np.load(join(parent_dir, f"{split}X.npy"), mmap_mode="r").swapaxes(1, 2)
        self.y = np.load(join(parent_dir, f"{split}Y.npy"), mmap_mode="r")
        if subset is not None:
            self.x = self.x[:subset]
            self.y = self.y[:subset]

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        return self.x[idx].copy(), self.y[idx].copy()

train_dataset = CMIDataset("fold_0", "train", force_download=False)
train_data_loader = DataLoader(train_dataset, BATCH_SIZE, shuffle=True)

validation_dataset = CMIDataset("fold_0", "validation", force_download=False)
validation_data_loader = DataLoader(train_dataset, BATCH_SIZE * 4, shuffle=False)

In [4]:
class ResidualBlock(nn.Module):
    def __init__(self, in_chns:int, out_chns:int):
        super().__init__()
        self.blocks = nn.Sequential(
            nn.Conv1d(in_chns, out_chns, kernel_size=3, padding=1),
            nn.BatchNorm1d(out_chns),
            nn.ReLU(),
            nn.Conv1d(out_chns, out_chns, kernel_size=3, padding=1),
            nn.BatchNorm1d(out_chns),
        )
        if in_chns == out_chns:
            self.skip_connection = nn.Identity() 
        else:
            # TODO: set bias to False ?
            self.skip_connection = nn.Sequential(
                nn.Conv1d(in_chns, out_chns, 1),
                nn.BatchNorm1d(out_chns)
            )

    def forward(self, x:Tensor) -> Tensor:
        activaition_maps = self.skip_connection(x) + self.blocks(x)
        return nn.functional.relu(activaition_maps)

class Resnet(nn.Module):
    def __init__(
            self,
            in_channels:int,
            depth:int,
            # n_res_block_per_depth:int,
            mlp_width:int,
            n_class:int,
        ):
        super().__init__()
        chs_per_depth = [in_channels * 2 ** i for i in range(depth)]
        blocks_chns_it = pairwise(chs_per_depth)
        self.res_blocks = [ResidualBlock(in_chns, out_chns) for in_chns, out_chns in blocks_chns_it]
        self.res_blocks = nn.ModuleList(self.res_blocks)
        self.mlp_head = nn.Sequential(
            nn.LazyLinear(mlp_width),
            nn.ReLU(),
            nn.Linear(mlp_width, n_class),
            nn.Softmax(dim=1),
        )
        
        
    def forward(self, x:Tensor) -> Tensor:
        activation_maps = x
        for res_block in self.res_blocks:
            activation_maps = nn.functional.max_pool1d(res_block(activation_maps), 2)
        out = activation_maps.view(activation_maps.shape[0], -1)
        out = self.mlp_head(out)
        return out

model = Resnet(
    in_channels=17,
    depth=4,
    mlp_width=256,
    n_class=18,
)

In [8]:
import os
from datetime import datetime

import torch
from torch import nn
import plotly.express as px
from rich.progress import Progress, Task, track
from pandas import DataFrame as DF
from torch.utils.data import DataLoader as DL
from torch.optim.lr_scheduler import LRScheduler


def fit(epochs:int,
        model: nn.Module,
        scheduler: LRScheduler,
        optimizer: torch.optim.Optimizer,
        train_loader: DL,
        criterion: callable=nn.L1Loss(),
        evaluation_func: callable=None,
        validation_loader: DL=None,
        save_checkpoints=True,
    ) -> tuple[DF, str]:
    """
    Returns:
        (training_metrics, path_to_checkpoints)
    """
    # Setup
    metrics: list[dict] = []
    checkpoints_dir = os.path.join("checkpoints", mk_date_now_str())
    step = 0
    model_device = next(model.parameters()).device
    # Training loop
    with Progress() as progress:
        task: Task = progress.add_task(
            "training...",
            # f"epoch: {epoch + 1}, batch_loss: ...",
            total=len(train_loader),
        )
        for epoch in range(epochs):
            progress.update(
                task,
                description=f"epoch: {epoch}",
                completed=0,
            )
            total_epoch_loss = 0
            total_accuracy = 0
            for batch_idx, (x, y) in enumerate(train_loader):
                # forward
                x = x.to(model_device)
                y = y.to(model_device)
                model.train()
                optimizer.zero_grad()
                y_pred: Tensor = model(x)
                loss_value = criterion(y_pred, y)
                # Verify loss value
                if torch.isnan(loss_value).any().item():
                    progress.print("Warning: Got NaN loss, something went wrong.")
                    return DF.from_records(metrics) 
                if torch.isinf(loss_value).any().item():
                    progress.print("Warning: Got infinite loss, something went wrong.")
                    return DF.from_records(metrics) 
                # TODO: Use gradient clipping?
                loss_value.backward()
                optimizer.step()
                if step > 0: # If it's not the first training step
                    # Call the scheduler step method, idk why it throws an error otherwise
                    scheduler.step()
                # metrics
                total_epoch_loss += loss_value.item()
                batch_accuracy = (
                    (y.max(dim=1)[1] == y_pred.max(dim=1)[1])
                    .sum()
                    .item()
                ) / y.shape[0]
                total_accuracy += batch_accuracy
                metrics.append({
                    "step": step,
                    "epoch": epoch,
                    "batch_train_loss": loss_value.item(),
                    "lr": optimizer.state_dict()["param_groups"][-1]["lr"],
                    "batch_accuracy": batch_accuracy,
                })
                step += 1
                progress.update(
                    task,
                    advance=1,
                    description=f"epoch: {epoch}, batch_loss: {(total_epoch_loss / (batch_idx+1)):.2f}"
                )
            # Post epoch evalution
            metrics[-1]["train_epoch_loss"] = total_epoch_loss / len(train_loader)
            metrics[-1]["train_epoch_accuracy"] = total_accuracy / len(train_data_loader)
            if evaluation_func:
                progress.update(
                    task,
                    completed=0,
                    description=f"epoch: {epoch}, evaluating..."
                )
                eval_metrics = evaluation_func(model, criterion, validation_loader)
                progress.print("validation loss:", eval_metrics["validation_loss"])
                metrics[-1].update(eval_metrics)
            # Save checkpoint
            if save_checkpoints:
                checkpoint = mk_checkpoint(epoch, model, scheduler, optimizer)
                metrics_df = DF.from_records(metrics)
                best_model_metric = "validation_loss" if "validation_loss" in metrics_df.columns else "train_epoch_loss"
                is_best_checkpoint = (
                    DF.from_records(metrics)
                    .eval(f"min_{best_model_metric} = {best_model_metric}.min()")
                    .eval(f"is_best_{best_model_metric} = {best_model_metric} == min_{best_model_metric}")
                    .dropna(subset=f"is_best_{best_model_metric}")
                    .iloc[-1]
                    .loc[f"is_best_{best_model_metric}"]
                )
                save_checkpoint(checkpoint, checkpoints_dir, is_best_checkpoint)

    return DF.from_records(metrics), checkpoints_dir

def evaluate_model(
        model: torch.nn.Module,
        critirion:callable,
        validation_loader:DL,
    ) -> dict:
    model = model.eval()
    model_device = next(model.parameters()).device

    total_accuracy = 0
    total_test_loss = 0
    for x, y in validation_loader:
        x = x.to(model_device)
        y = y.to(model_device)
        y_pred = model(x)
        total_test_loss += critirion(y_pred, y).item()
        batch_accuracy = (
            (y.max(dim=1)[1] == y_pred.max(dim=1)[1])
            .sum()
            .item()
        ) / y.shape[0]
        total_accuracy += batch_accuracy

    return {
        "validation_loss": total_test_loss / len(validation_loader),
        "validation_accuracy": total_accuracy / len(validation_loader),
    }

def mk_checkpoint(
        epoch:int,
        model: torch.nn.Module,
        scheduler: LRScheduler,
        optimizer: torch.optim.Optimizer
    ) -> dict:
    return {
        "epoch": epoch,
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "scheduler": scheduler.state_dict(),
    }

def save_checkpoint(checkpoint:dict, parent_dir:str, is_best_checkpoint=False):
    # Create model name
    checkpoint_filename = f"epoch_{checkpoint['epoch']}_{mk_date_now_str()}.pth"
    # Save model
    os.makedirs(parent_dir, exist_ok=True)
    checkpoint_path = os.path.join(parent_dir, checkpoint_filename)
    torch.save(checkpoint, checkpoint_path)
    mk_symlink(checkpoint_path, os.path.join(parent_dir, "latest_checkpoint.pth"))
    if is_best_checkpoint:
        mk_symlink(checkpoint_path, os.path.join(parent_dir, "best_checkpoint.pth"))

def mk_date_now_str() -> str:
    return datetime.now().strftime("%d-%m-%Y_%H-%M")

def mk_symlink(dest:str, symlink_path:str):
    if os.path.islink(symlink_path) or os.path.exists(symlink_path):
        os.remove(symlink_path)
    os.symlink(dest, symlink_path)

In [None]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
TRAINING_EPOCHS = 1
STARTING_LR = 0.0005
model = (
    Resnet(
        in_channels=17,
        depth=4,
        mlp_width=256,
        n_class=18
    )
)
optimizer = torch.optim.AdamW(model.parameters(), STARTING_LR)
constant_lr_scheduler = ConstantLR(optimizer, factor=1, total_iters=len(train_data_loader) * TRAINING_EPOCHS)
training_metrics, model_checkpoints = fit(
    epochs=TRAINING_EPOCHS,
    model=model,
    scheduler=constant_lr_scheduler,
    optimizer=optimizer,
    train_loader=train_data_loader,
    criterion=nn.CrossEntropyLoss(),
    evaluation_func=evaluate_model,
    validation_loader=validation_data_loader,
    save_checkpoints=True,
)

Output()

In [10]:
(
    px.scatter(
    (
        training_metrics
        .melt(
            id_vars="step",
            value_vars=[
                "batch_train_loss",
                "batch_accuracy",
            ]
        )
    ),
    x="step",
    y="value",
    facet_row="variable",
    trendline="ewm",
    trendline_options={"com": 30},
    trendline_color_override="red",
    title="batch metrics",
    )
    .update_yaxes(matches=None)
    .show()
)
px.line(
    (
        training_metrics
        .query("validation_loss.notna()")
        .melt(
            id_vars="step",
            value_vars=['train_epoch_loss', 'train_epoch_accuracy', 'validation_loss', 'validation_accuracy']
        )
    ),
    x="step",
    y="value",
    color="variable",
    facet_row="variable",
    title="epoch metrics",
    render_mode="line+marker",
).update_yaxes(matches=None)

In [None]:
training_metrics.query("validation_accuracy.notna()")["validation_accuracy"]

25      0.278908
51      0.353251
77      0.392261
103     0.416454
129     0.461623
155     0.486220
181     0.487553
207     0.492515
233     0.536963
259     0.555424
285     0.559028
311     0.570649
337     0.589603
363     0.609915
389     0.627314
415     0.628471
441     0.634847
467     0.645192
493     0.656354
519     0.672579
545     0.675448
571     0.680433
597     0.689282
623     0.681468
649     0.702517
675     0.707957
701     0.715432
727     0.714037
753     0.728488
779     0.737437
805     0.741683
831     0.743078
857     0.744413
883     0.747802
909     0.760658
935     0.768252
961     0.770484
987     0.772576
1013    0.772218
1039    0.775088
1065    0.776782
1091    0.777001
1117    0.779573
1143    0.780310
1169    0.782223
1195    0.798069
1221    0.800780
1247    0.801138
1273    0.803331
1299    0.804786
1325    0.804467
1351    0.805962
1377    0.806241
1403    0.801277
1429    0.802493
1455    0.805723
1481    0.806560
1507    0.807636
1533    0.8091

In [11]:
handle = f"{whoami()}/cmi-resnet/pytorch/1"
model_path = join(model_checkpoints, realpath("best_checkpoint.pth"))
model_upload(handle, model_path)

Kaggle credentials successfully validated.
Uploading Model https://www.kaggle.com/models/{'username': 'mauroabidalcarrer'}/cmi-resnet/pytorch/1 ...
Model 'cmi-resnet' does not exist or access is forbidden for user '{'username': 'mauroabidalcarrer'}'. Creating or handling Model...


BackendError: Permission 'models.create' was denied