In [None]:
%load_ext autoreload
%autoreload 2

# neural imaging
import nibabel as nib

import os
import numpy as np
import matplotlib.pyplot as plt

import sys
sys.path.append("../../")
import utils
if not utils.hpc.running_on_hpc():
    import kagglehub

import torch
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from sklearn.model_selection import train_test_split
from tqdm import tqdm

import utils

from utils.datasets import BraTSDataset3D

from utils.metrics import (
    dice_score,
    dice_score_background,
    dice_score_necrotic,
    dice_score_edema,
    dice_score_enhancing,
    iou_score,
    iou_score_background,
    iou_score_necrotic,
    iou_score_edema,
    iou_score_enhancing
)

In [None]:
# Training configuration
NUM_EPOCHS = 60
MODEL_NAME = "temp" # change for each model!!
CHECKPOINT_DIR = 'checkpoints'
MODEL_SAVE_PATH = f'{CHECKPOINT_DIR}/{MODEL_NAME}.pth'
# NOTE(review): 64 full 3D volumes per batch is large for a 3D U-Net; confirm it fits in device memory.
BATCH_SIZE = 64
NUM_WORKERS = 1
N_SLICES = 5 # for 2.5D model; not used by the 3D cells shown here

# Seed NumPy and PyTorch global RNGs so weight init and DataLoader shuffling are repeatable.
# NOTE(review): split_patients may use its own randomness; full GPU determinism also needs cuDNN flags.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Make sure the checkpoint directory exists before anything tries to save into it.
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

In [None]:
# Resolve the root directory of the BraTS training data; all three splits below come from it.
TRAIN_DATASET_PATH = utils.datasets.load_dataset()

# Fail fast with a clear message instead of an obscure error inside the data-loading helpers.
if not os.path.isdir(TRAIN_DATASET_PATH):
    raise FileNotFoundError(f"Dataset directory not found: {TRAIN_DATASET_PATH}")

In [None]:
all_patients = utils.data.load_patients(TRAIN_DATASET_PATH)

train_patients, val_patients, test_patients = utils.data.split_patients(all_patients)

print(f"Train: {len(train_patients)} patients")
print(f"Val: {len(val_patients)} patients")
print(f"Test: {len(test_patients)} patients")

# Sanity-check the split: training needs data, and a patient present in two
# splits would leak information and inflate validation/test scores.
assert len(train_patients) > 0 and len(val_patients) > 0, "train/val split is empty"
assert not set(train_patients) & set(val_patients), "train and val splits share patients"
assert not set(train_patients) & set(test_patients), "train and test splits share patients"
assert not set(val_patients) & set(test_patients), "val and test splits share patients"


def _make_dataset(patient_ids):
    """Build a BraTSDataset3D over `patient_ids` with this notebook's shared settings."""
    return BraTSDataset3D(data_dir=TRAIN_DATASET_PATH, patient_ids=patient_ids, transform=None, in_memory=False)


def _make_loader(dataset, shuffle):
    """Wrap `dataset` in a DataLoader using the configured batch size and worker count."""
    return DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=shuffle, num_workers=NUM_WORKERS, pin_memory=True)


train_dataset = _make_dataset(train_patients)
val_dataset = _make_dataset(val_patients)
test_dataset = _make_dataset(test_patients)

# Only the training loader shuffles; val/test order stays fixed for comparable evaluation.
train_loader = _make_loader(train_dataset, shuffle=True)
val_loader = _make_loader(val_dataset, shuffle=False)
test_loader = _make_loader(test_dataset, shuffle=False)

In [None]:
# Choose the torch device the model is moved to below (selection logic lives in utils.extra.get_device).
device = utils.extra.get_device()
# Show the selected device in the cell output so a saved run records where it trained.
device

In [None]:
# Initialize model: 4 input channels, 4 output classes, placed on the selected device.
model = utils.models.UNet3D(n_channels=4, n_classes=4).to(device)

# Loss function - project-defined CombinedLoss (see utils.losses for what it combines).
criterion = utils.losses.CombinedLoss()

# Optimizer: AdamW (weight decay decoupled from the gradient update)
learning_rate = 1e-4
weight_decay = 1e-5

optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

# Learning rate scheduler: multiply the LR by 0.5 once the monitored value has not
# improved for more than 5 consecutive epochs.
# NOTE(review): mode='min' assumes train_loop passes scheduler.step() a value that should
# decrease (e.g. validation loss). If it passes the primary metric (dice, higher is better),
# this must be mode='max' — confirm in utils.training.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=5, factor=0.5)

# Early stopping patience (epochs without improvement before training stops)
early_stopping_patience = 10

num_params = sum(p.numel() for p in model.parameters())
print(f"Model has {num_params} parameters")

In [None]:
def save_history_fn(history):
    """Save a training `history` plus this run's configuration.

    Delegates to utils.visualizations.history_to_json with save_dir='checkpoints'.
    Used as the `save_history_fn` callback of train_loop.

    Args:
        history: the training history object produced by train_loop.

    Returns:
        Whatever history_to_json returns.
    """
    return utils.visualizations.history_to_json(
        history=history,
        model_name=MODEL_NAME,
        save_dir='checkpoints',

        # Configuration parameters
        batch_size=BATCH_SIZE,
        num_epochs=NUM_EPOCHS,
        # Derived from the live optimizer so the record cannot drift from what was
        # actually built (it was hard-coded as 'Adam' while AdamW is used).
        optimizer=type(optimizer).__name__,
        loss_function='CombinedLoss',
        model_type='UNet3D',
        model_parameters=sum(p.numel() for p in model.parameters()),
        dataset_type='3D',
        early_stopping_patience=early_stopping_patience,
        augmentations=False,
        device=str(device),
    )

In [None]:
# Gather every training setting in one place, then hand it to the shared loop.
# save_best_model / model_save_path / early_stopping_patience presumably control
# checkpointing and early stopping inside utils.training.train_loop — see there.
train_loop_kwargs = dict(
    # what to train and how
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    device=device,
    # data
    train_loader=train_loader,
    val_loader=val_loader,
    # schedule and model selection
    num_epochs=NUM_EPOCHS,
    primary_metric='dice',
    early_stopping_patience=early_stopping_patience,
    # persistence
    save_best_model=True,
    model_save_path=MODEL_SAVE_PATH,
    save_history_fn=save_history_fn,
)

history = utils.training.train_loop(**train_loop_kwargs)