<a href="https://colab.research.google.com/github/Dylan-Geraci/neuroimaging-tumor-detector/blob/main/notebooks/02_model_training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Model Training

## Notebook Setup

Import Libraries

In [12]:
# --- Standard ---
import os, json, math, time
from collections import Counter

# --- Numerical / data ---
import numpy as np
import pandas as pd

# --- Imaging & plotting ---
from PIL import Image

# --- Torch / ML ---
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

# --- Metrics ---
from sklearn.metrics import f1_score, accuracy_score

# --- Vision ---
from torchvision import transforms

# --- Pretrained models ---
import timm

# --- Google Drive ---
from google.colab import drive

import time, json, torch
from torch.amp import autocast, GradScaler

Mount Google Drive and set the training data path

In [2]:
# Mount Google Drive so files under MyDrive are reachable from this runtime.
drive.mount("/content/drive")

# Location of the training data on the mounted Drive.
TRAIN_PATH = "/content/drive/MyDrive/neuro-imaging/data/Training"

Mounted at /content/drive


In [3]:
# Directory on Drive that holds the persisted split files.
SPLITS_DIR = "/content/drive/MyDrive/neuro-imaging/splits"
# NOTE(review): no cell visible in this notebook reads this flag — confirm it is still needed.
USE_SAVED_SPLITS = False

## Paths and Constants

In [4]:
# --- Locations on the mounted Drive ---
TRAIN_PATH = "/content/drive/MyDrive/neuro-imaging/data/Training"
SPLITS_DIR = "/content/drive/MyDrive/neuro-imaging/splits"
OUT_DIR = "/content/drive/MyDrive/neuro-imaging/models"
# exist_ok=True makes this cell safe to re-run.
os.makedirs(OUT_DIR, exist_ok=True)

# --- Reproducibility ---
SEED = 42

# --- Data pipeline ---
IMG_SIZE = 224    # images are resized to IMG_SIZE x IMG_SIZE
BATCH_SIZE = 32   # batch size for both data loaders

# --- Optimisation ---
LR = 3e-4         # learning rate passed to the optimizer
EPOCHS = 10       # presumably the maximum number of epochs — confirm against the training loop
PATIENCE = 2      # presumably the early-stopping patience — confirm against the training loop

## Reproducibility

In [5]:
def set_seed(seed=SEED):
    """Seed the Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed: Integer seed applied to every generator (defaults to ``SEED``).
    """
    import random

    # cuDNN: turn off the autotuner and request deterministic kernels.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    # Apply the same seed to every random number generator in use.
    for seed_fn in (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    ):
        seed_fn(seed)

# Seed every RNG once, before any model or data loader is created.
set_seed(SEED)

# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Device:", device)

Device: cuda


## Load Splits and Class Maps

In [6]:
def _read_paths(txt_path):
    with open(txt_path, "r") as f:
        return [ln.strip() for ln in f if ln.strip()]

# Class-name -> index mapping persisted next to the split files.
with open(os.path.join(SPLITS_DIR, "class_to_idx.json"), "r") as fh:
    class_to_idx = json.load(fh)
num_classes = len(class_to_idx)
# Inverse mapping, used to list class names in index order.
idx_to_class = dict((idx, name) for name, idx in class_to_idx.items())
print("Classes:", [idx_to_class[i] for i in range(num_classes)])

# One image path per line in each split file.
train_paths, val_paths = (
    _read_paths(os.path.join(SPLITS_DIR, split_file))
    for split_file in ("train.txt", "val.txt")
)
print(f"Loaded {len(train_paths)} train, {len(val_paths)} val files")


Classes: ['glioma', 'meningioma', 'notumor', 'pituitary']
Loaded 4855 train, 857 val files


## Transforms

In [7]:
# Training pipeline: 3-channel grayscale -> fixed-size resize -> random horizontal
# flip -> tensor -> per-channel normalisation (standard ImageNet mean/std).
train_tfms = transforms.Compose(
    [
        transforms.Grayscale(3),
        transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# Validation pipeline: same as training but without the random flip.
val_tfms = transforms.Compose(
    [
        transforms.Grayscale(3),
        transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

## Dataset from path list

In [8]:
class PathDataset(Dataset):
    """Image-classification dataset backed by a flat list of file paths.

    The label of each sample is the name of the file's parent directory,
    looked up in ``class_to_idx``.

    Args:
        paths: Sequence of image file paths.
        class_to_idx: Mapping from class (directory) name to integer label.
        transform: Optional callable applied to the decoded PIL image.
    """

    def __init__(self, paths, class_to_idx, transform=None):
        self.paths = paths
        self.class_to_idx = class_to_idx
        self.transform = transform

    def __len__(self):
        """Return the number of samples, i.e. the number of paths."""
        return len(self.paths)

    def __getitem__(self, i):
        """Return ``(image, label)`` for sample ``i``.

        Raises:
            KeyError: If the parent directory name is not a key of ``class_to_idx``.
        """
        p = self.paths[i]
        cls_name = os.path.basename(os.path.dirname(p))
        y = self.class_to_idx[cls_name]
        # Image.open is lazy: decode the pixels inside the context manager so the
        # file handle is released here, not whenever the image is garbage-collected.
        # The decoded image stays usable after the block exits.
        with Image.open(p) as img:
            img.load()
        if self.transform:
            img = self.transform(img)
        return img, y

# Wrap each split's path list in a dataset with its own transform pipeline.
train_ds = PathDataset(train_paths, class_to_idx, transform=train_tfms)
val_ds = PathDataset(val_paths, class_to_idx, transform=val_tfms)

# Training batches are reshuffled each epoch; validation order stays fixed.
train_dl = DataLoader(
    train_ds,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
)
val_dl = DataLoader(
    val_ds,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=2,
    pin_memory=True,
)

## Class weights (handle class imbalance)

In [9]:
# Per-class sample counts; the class name is the parent directory of each path.
train_counts = Counter(os.path.basename(os.path.dirname(path)) for path in train_paths)
counts_by_idx = np.array(
    [train_counts[idx_to_class[idx]] for idx in range(num_classes)],
    dtype=np.float32,
)

# Inverse-frequency weights: n_samples / (n_classes * n_samples_in_class).
weights = len(train_paths) / (num_classes * counts_by_idx)
class_weights = torch.tensor(weights, dtype=torch.float32, device=device)
print("Class weights:", weights.tolist())

Class weights: [1.080810308456421, 1.0665642023086548, 0.8950958847999573, 0.980411946773529]


## Model, loss, optimizer, scaler

In [10]:
# EfficientNet-B0 with pretrained weights and a classifier head sized to our classes.
model = timm.create_model("efficientnet_b0", pretrained=True, num_classes=num_classes)
model.to(device)

# Cross-entropy weighted per class with the weights computed above.
criterion = nn.CrossEntropyLoss(weight=class_weights)
optimizer = optim.AdamW(model.parameters(), lr=LR)
# torch.cuda.amp.GradScaler is deprecated and emits a FutureWarning; use the
# torch.amp.GradScaler imported at the top of the notebook. The scaler is
# disabled when not running on CUDA.
scaler = GradScaler("cuda", enabled=(device.type == "cuda"))

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


model.safetensors:   0%|          | 0.00/21.4M [00:00<?, ?B/s]

  scaler = torch.cuda.amp.GradScaler(enabled=(device.type == "cuda"))


## Train / eval helpers

In [11]:
@torch.no_grad()
def evaluate(model, loader, device):
    """Run one gradient-free pass over ``loader`` and compute loss and metrics.

    The loss is computed with the module-level ``criterion``.

    Args:
        model: Module to evaluate; it is switched to eval mode.
        loader: Iterable of ``(inputs, targets)`` batches exposing ``.dataset``.
        device: ``torch.device`` the batches are moved to.

    Returns:
        dict with ``"loss"`` (sum of batch loss x batch size, divided by
        ``len(loader.dataset)``), ``"acc"`` and ``"f1_macro"``.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    model.eval()
    use_amp = device.type == "cuda"
    all_preds, all_targets = [], []
    total_loss = 0.0
    for x, y in loader:
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)
        # torch.cuda.amp.autocast is deprecated; torch.amp.autocast with
        # device_type="cuda" is its drop-in replacement.
        with autocast(device_type="cuda", enabled=use_amp):
            logits = model(x)
            loss = criterion(logits, y)
        total_loss += loss.item() * x.size(0)
        preds = torch.argmax(logits, dim=1)
        all_preds.append(preds.cpu().numpy())
        all_targets.append(y.cpu().numpy())

    if not all_preds:
        # np.concatenate would raise an opaque ValueError; say what is wrong.
        raise ValueError("evaluate() received a loader that yielded no batches")

    all_preds = np.concatenate(all_preds)
    all_targets = np.concatenate(all_targets)
    val_loss = total_loss / len(loader.dataset)
    acc = accuracy_score(all_targets, all_preds)
    f1_macro = f1_score(all_targets, all_preds, average="macro")
    return {"loss": val_loss, "acc": acc, "f1_macro": f1_macro}

def train_one_epoch(model, loader, device):
    """Train ``model`` for one pass over ``loader`` and return the mean loss.

    Uses the module-level ``criterion``, ``optimizer`` and ``scaler``.

    Args:
        model: Module to train; it is switched to train mode.
        loader: Iterable of ``(inputs, targets)`` batches exposing ``.dataset``.
        device: ``torch.device`` the batches are moved to.

    Returns:
        Sum of (batch loss x batch size) divided by ``len(loader.dataset)``.
    """
    model.train()
    use_amp = device.type == "cuda"
    total_loss = 0.0
    for x, y in loader:
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)
        optimizer.zero_grad(set_to_none=True)
        # torch.cuda.amp.autocast is deprecated; torch.amp.autocast with
        # device_type="cuda" is its drop-in replacement.
        with autocast(device_type="cuda", enabled=use_amp):
            logits = model(x)
            loss = criterion(logits, y)
        # Scale the loss before backward, then step and update through the scaler.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        total_loss += loss.item() * x.size(0)
    return total_loss / len(loader.dataset)

## Training loop with early stopping on val macro-F1

In [1]:
# Verify GPU
import torch, os, subprocess, textwrap
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    try:
        print(subprocess.check_output(["nvidia-smi"], text=True))
    except Exception:
        pass

# Copy dataset from Drive -> local (much faster I/O)
!mkdir -p /content/data/Training
!rsync -ah --delete --info=progress2 "/content/drive/MyDrive/neuro-imaging/data/Training/" "/content/data/Training/"

CUDA available: True
Mon Aug 25 19:15:21 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla T4                       Off |   00000000:00:04.0 Off |                    0 |
| N/A   44C    P8             10W /   70W |       2MiB /  15360MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                           