# Training the model
The goal of this notebook is to train a model on the preprocessed CMI 2025 dataset and upload it to Kaggle.

## Imports

In [12]:
from os.path import join
from typing import Optional

import torch
import numpy as np
from torch import nn, Tensor
from kagglehub import dataset_download
from torch.utils.data import Dataset, DataLoader

In [13]:
# TODO: Switch to TensorDataset w/ cross validation splits
class CMIDataset(Dataset):
    """Preprocessed CMI 2025 arrays fetched with kagglehub and read lazily.

    Both the feature and target arrays are opened with ``mmap_mode="r"``;
    individual samples are copied out in ``__getitem__``.
    """

    def __init__(self, use_agg_tof:bool, subset:Optional[int]=None, force_download=False):
        """Fetch the dataset via ``dataset_download`` and open its arrays.

        Args:
            use_agg_tof: if truthy, read ``tof_meaned_X.npy``; otherwise ``X.npy``.
            subset: when not None, keep only the first ``subset`` samples.
            force_download: passed through to ``dataset_download``.
        """
        super().__init__()
        root = dataset_download("mauroabidalcarrer/prepocessed-cmi-2025", force_download=force_download)
        features_name = "tof_meaned_X.npy" if use_agg_tof else "X.npy"
        # Read-only memory map. swapaxes(1, 2) presumably moves the channel axis
        # ahead of the time axis for Conv1d — confirm against the preprocessing.
        features = np.load(join(root, features_name), mmap_mode="r").swapaxes(1, 2)
        targets = np.load(join(root, "Y.npy"), mmap_mode="r")
        if subset is not None:
            features, targets = features[:subset], targets[:subset]
        self.x, self.y = features, targets

    def __len__(self):
        """Number of samples (length of the first axis of the feature array)."""
        return len(self.x)

    def __getitem__(self, idx):
        """Return ``(x, y)`` for sample ``idx`` as independent numpy copies.

        Copying detaches the sample from the read-only memory map.
        """
        sample = self.x[idx]
        target = self.y[idx]
        return sample.copy(), target.copy()
    
# Pass the flags by keyword: `use_agg_tof` is a bool (True -> tof_meaned_X.npy),
# and `subset` takes an int to keep only the first N samples (None = all).
dataset = CMIDataset(use_agg_tof=True, subset=None, force_download=False)
data_loader = DataLoader(dataset, batch_size=128, shuffle=True)

Resuming download from 0 bytes (296905494 bytes left)...
Resuming download from https://www.kaggle.com/api/v1/datasets/download/mauroabidalcarrer/prepocessed-cmi-2025?dataset_version_number=7 (0/296905494) bytes left.


100%|██████████| 283M/283M [00:32<00:00, 9.26MB/s] 

Extracting files...





In [21]:
from itertools import pairwise

class ResidualBlock(nn.Module):
    """1D residual block: two conv/BatchNorm stages plus a skip path, then ReLU.

    Main path: Conv1d -> BatchNorm1d -> ReLU -> Conv1d -> BatchNorm1d.
    Skip path: identity when ``in_chns == out_chns``, otherwise a 1x1 Conv1d
    followed by BatchNorm1d to project to ``out_chns``. With an odd kernel and
    "same" padding the sequence length is preserved, so both paths can be summed.

    Args:
        in_chns: number of input channels.
        out_chns: number of output channels.
        kernel_size: positive odd kernel size of the two main-path convolutions
            (default 3, the previously hard-coded value).

    Raises:
        ValueError: if ``kernel_size`` is not a positive odd integer.
    """
    def __init__(self, in_chns:int, out_chns:int, kernel_size:int=3):
        super().__init__()
        if kernel_size < 1 or kernel_size % 2 == 0:
            raise ValueError(f"kernel_size must be a positive odd integer, got {kernel_size}")
        # "Same" padding for odd kernels keeps the length unchanged for the residual sum.
        padding = kernel_size // 2
        self.blocks = nn.Sequential(
            nn.Conv1d(in_chns, out_chns, kernel_size=kernel_size, padding=padding),
            nn.BatchNorm1d(out_chns),
            nn.ReLU(),
            nn.Conv1d(out_chns, out_chns, kernel_size=kernel_size, padding=padding),
            nn.BatchNorm1d(out_chns),
        )
        if in_chns == out_chns:
            self.skip_connection = nn.Identity()
        else:
            # TODO: set bias to False ? (a BatchNorm follows, which subtracts the per-channel mean)
            self.skip_connection = nn.Sequential(
                nn.Conv1d(in_chns, out_chns, 1),
                nn.BatchNorm1d(out_chns)
            )

    def forward(self, x:Tensor) -> Tensor:
        """Return ``relu(skip_connection(x) + blocks(x))``.

        ``x`` is (batch, in_chns, length); the result is (batch, out_chns, length).
        """
        activation_maps = self.skip_connection(x) + self.blocks(x)
        return nn.functional.relu(activation_maps)

class Resnet(nn.Module):
    def __init__(
            self,
            in_channels:int,
            depth:int,
            # n_res_block_per_depth:int
        ):
        super().__init__()
        chs_per_depth = [in_channels * 2 ** i for i in range(depth)]
        blocks_chns_it = pairwise(chs_per_depth)
        self.res_blocks = [ResidualBlock(in_chns, out_chns) for in_chns, out_chns in blocks_chns_it]
        self.res_blocks = nn.ModuleList(self.res_blocks)
        
    def forward(self, x:Tensor) -> Tensor:
        out = x
        for res_block in self.res_blocks:
            out = nn.functional.max_pool1d(res_block(out), 2)
        return out

# Take the input channel count from a data sample (channels are the first axis
# of each sample) instead of hard-coding it; it changes with `use_agg_tof`.
n_input_channels = dataset[0][0].shape[0]
model = Resnet(n_input_channels, depth=4)

In [22]:
# Display the module tree (layer types and channel sizes).
model

Resnet(
  (res_blocks): ModuleList(
    (0): ResidualBlock(
      (blocks): Sequential(
        (0): Conv1d(17, 34, kernel_size=(3,), stride=(1,), padding=(1,))
        (1): BatchNorm1d(34, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
        (3): Conv1d(34, 34, kernel_size=(3,), stride=(1,), padding=(1,))
        (4): BatchNorm1d(34, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (skip_connection): Sequential(
        (0): Conv1d(17, 34, kernel_size=(1,), stride=(1,))
        (1): BatchNorm1d(34, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): ResidualBlock(
      (blocks): Sequential(
        (0): Conv1d(34, 68, kernel_size=(3,), stride=(1,), padding=(1,))
        (1): BatchNorm1d(68, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
        (3): Conv1d(68, 68, kernel_size=(3,), stride=(1,), padding=(1,))
        (4): BatchNorm1d(68, eps=1e-05

In [23]:
# A single un-batched sample. Despite its name, `out` is the sample's target
# (from Y.npy), not a model output.
x, out = dataset[0]

In [24]:
# Per-sample shape: (n_channels, sequence_length), channels first for Conv1d.
x.shape

(17, 114)

In [25]:
# One shuffled batch; the DataLoader's default collation turns the numpy samples
# into tensors. `out` holds the batch targets, not model outputs.
x, out = next(iter(data_loader))

In [26]:
# Batch shape: (batch_size, n_channels, sequence_length).
x.shape

torch.Size([128, 17, 114])

In [27]:
# Shape sanity check of a forward pass. Resnet has no classification head, so
# `y_pred` is a (batch, channels, length) feature map rather than class scores.
# Run in eval mode without autograd so this probe neither builds a graph nor
# updates the BatchNorm running statistics, then restore the previous mode.
was_training = model.training
model.eval()
with torch.no_grad():
    y_pred = model(x)
model.train(was_training)
y_pred.shape

torch.Size([128, 136, 14])