In [None]:
from os import listdir
import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedKFold
from sklearn.model_selection import train_test_split
import sys
sys.path.append("./python")
from MRIPreMappedDataset import SampleMapper, MRIPreMappedDataset
from torch import nn
from torch.utils.data import DataLoader
import torch.optim as optim
import torch

In [None]:
# =============================
# Model Definition
# =============================
class CNN2D(nn.Module):
    def __init__(self, n_channels=1, n_classes=1, input_size=240):
        """
        2D CNN for multi-channel MRI slice classification.

        Four conv blocks (3x3 conv -> batch norm -> ReLU -> 2x2 max-pool) are
        followed by three fully connected layers and a sigmoid, so every output
        lies in (0, 1) and can be fed directly to ``nn.BCELoss``.

        Expects input of shape (batch, n_channels, input_size, input_size).

        Args:
            n_channels: Number of input channels (MRI modalities).
            n_classes: Number of sigmoid output units (1 for tumor / no-tumor).
            input_size: Height/width of the square input slices. Each of the four
                max-pools halves the spatial size (flooring), so the flattened
                feature map has 256 * (input_size // 16) ** 2 elements. Defaults
                to 240, the slice size the layer comments below refer to.

        Raises:
            ValueError: If ``input_size`` < 16, because the pooled feature map
                would then be empty.
        """
        super().__init__()

        feature_size = input_size // 16
        if feature_size < 1:
            raise ValueError(f"input_size must be >= 16, got {input_size}")

        self.conv_layers = nn.Sequential(
            # Block 1: n_channels -> 32
            nn.Conv2d(n_channels, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2),  # 240x240 -> 120x120 (default input_size)

            # Block 2: 32 -> 64
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),  # 120x120 -> 60x60

            # Block 3: 64 -> 128
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2),  # 60x60 -> 30x30

            # Block 4: 128 -> 256
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(2),  # 30x30 -> 15x15
        )

        # Fully connected layers; the input width follows from input_size rather
        # than being hard-coded for 240x240 slices.
        self.fc_layers = nn.Sequential(
            nn.Linear(256 * feature_size * feature_size, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, n_classes),
            # Outputs are probabilities, which nn.BCELoss expects.
            nn.Sigmoid()
        )

    def forward(self, x):
        """Map a (batch, n_channels, H, W) tensor to (batch, n_classes) probabilities."""
        x = self.conv_layers(x)
        x = torch.flatten(x, start_dim=1)
        x = self.fc_layers(x)
        return x

# =============================================
# Training routine
# =============================================
def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one pass over ``loader`` and return the mean loss per sample.

    With an ``optimizer`` the model is put in training mode and updated after
    every batch. Without one, it is put in eval mode with gradients disabled.
    Batch losses are weighted by batch size. This assumes ``criterion`` averages
    over the batch (torch's default reduction='mean'), so a smaller final batch
    does not skew the epoch loss. Returns NaN when the loader yields no samples.
    """
    training = optimizer is not None
    model.train(training)
    loss_sum = 0.0
    n_samples = 0
    with torch.set_grad_enabled(training):
        for images, labels in loader:
            # NOTE(review): labels are passed through unchanged; they must already
            # have the shape/dtype the criterion requires.
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            batch_size = images.size(0)
            loss_sum += loss.item() * batch_size
            n_samples += batch_size
    return loss_sum / n_samples if n_samples else float("nan")


def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=5, device='cuda'):
    """
    Train the CNN model and record its training and validation loss per epoch.

    Args:
        model: Network to train; moved to ``device``.
        train_loader: Iterable of ``(images, labels)`` batches used for updates.
        val_loader: Iterable of ``(images, labels)`` batches evaluated after
            every epoch without gradient tracking.
        criterion: Loss called as ``criterion(outputs, labels)``, assumed to use
            mean reduction.
        optimizer: Optimizer over ``model``'s parameters.
        num_epochs: Number of passes over ``train_loader``.
        device: Device passed to ``.to()``.

    Returns:
        train_losses, val_losses: one per-sample mean loss per epoch each.
        An epoch over an empty loader records NaN.
    """
    model = model.to(device)

    train_losses, val_losses = [], []

    for epoch in range(num_epochs):
        train_loss = _run_epoch(model, train_loader, criterion, device, optimizer)
        val_loss = _run_epoch(model, val_loader, criterion, device)
        train_losses.append(train_loss)
        val_losses.append(val_loss)

        print(f"Epoch {epoch+1}/{num_epochs}")
        print(f"  Train Loss: {train_loss:.4f}")
        print(f"  Val Loss: {val_loss:.4f}")

    return train_losses, val_losses

In [None]:
# Construct patient list and related tumor grade
data_dir = "../data"
full_patient_list = []
# listdir order is filesystem-dependent; sort it so the seeded splits in later
# cells see the same patient order on every machine.
for patient in sorted(listdir(data_dir)):
    # Skip entries whose name contains "FU" (presumably follow-up scans — TODO confirm).
    if "FU" in patient:
        continue
    full_patient_list.append(patient.replace("_nifti", ""))

meta_data = pd.read_csv("../processed-data/UCSF-PDGM-metadata_v5.csv")
# Zero-pad the numeric suffix of each metadata ID to 4 digits (e.g. "X-12" -> "X-0012")
# so it compares equal to the IDs derived from the directory names.
meta_data["ID"] = meta_data["ID"].apply(lambda x: "-".join(x.split("-")[:-1]) + "-" + x.split("-")[-1].rjust(4, "0"))
grade_key = "WHO CNS Grade"
# One ID -> grade lookup instead of a full-table scan per patient. Keep the first
# row of any duplicated ID, matching the previous `.values[0]` selection.
grade_by_id = meta_data.drop_duplicates("ID").set_index("ID")[grade_key]
missing_ids = [patient_id for patient_id in full_patient_list if patient_id not in grade_by_id.index]
if missing_ids:
    raise KeyError(f"No '{grade_key}' metadata for patient IDs: {missing_ids}")
tumor_grade = [grade_by_id[patient_id] for patient_id in full_patient_list]
print(len(full_patient_list), len(tumor_grade))

In [None]:
# Build the slice-level sample map for every patient with the project helper
# SampleMapper (./python). Judging by its arguments it picks up to 2 slices per
# label per patient along MRI axis 2, skipping slices below 25% relative brain
# area, seeded with 360 — TODO confirm against MRIPreMappedDataset.SampleMapper.
data_dir = "../data"
patient_id_list = full_patient_list
full_sample_map = SampleMapper(
    data_dir,
    patient_id_list=patient_id_list,
    mri_axis=2,
    samples_per_patient_per_label=2,
    min_relative_brain_area_per_sample=0.25,
    random_seed=360,
)

In [None]:
# Hold out 20% of patients as a test set, stratified by tumor grade, then print
# the grade proportions on each side to check the stratification held.
train_patients, test_patients, train_tumor, test_tumor = train_test_split(
    full_patient_list,
    tumor_grade,
    test_size=0.2,
    stratify=tumor_grade,
    random_state=360,
)
print("N training patients:", len(train_patients))
print("N test patients:", len(test_patients))
split_report_rows = (
    ("N train tumor grade:", "\tTraining tumor classes:", "\tTraining tumor class proportion:", train_tumor),
    ("N test tumor grade:", "\tTest tumor classes:", "\tTest tumor class proportion:", test_tumor),
)
for count_label, classes_label, proportion_label, split_grades in split_report_rows:
    print(count_label, len(split_grades))
    classes, counts = np.unique(split_grades, return_counts=True)
    print(classes_label, classes)
    print(proportion_label, counts / counts.sum())

In [None]:
import time

full_data_map = full_sample_map.data_map
mri_axis = full_sample_map.mri_axis

## Separate the training patients into 4 folds for 4-fold cross-validation.
n_splits = 4
skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=360)

# Dataset sampling parameters
selected_modalities = ["T1"]

# DataLoader parameters
batch_size = 32

n_epochs = 20
models = []
loss = {}      # fold index -> [train_losses, val_losses]
accuracy = {}  # fold index -> validation accuracy (sigmoid output thresholded at 0.5)

# The device does not change between folds, so resolve it once.
device = 'cuda' if torch.cuda.is_available() else 'cpu'


def _load_fold_dataset(split_name, data_map):
    """Build an MRIPreMappedDataset over ``data_map`` and print how long loading took."""
    print(f"Loading the {split_name} dataset...")
    load_start_time = time.time()
    dataset = MRIPreMappedDataset(data_dir,
                                  data_map,
                                  selected_modalities=selected_modalities,
                                  mri_axis=mri_axis)
    print(f"Finished loading the {split_name} dataset!")
    print(f"Time to load: {(time.time() - load_start_time)/60:.2f} minutes")
    return dataset


def _binary_accuracy(model, loader, device):
    """Return the fraction of labels matched by the model output thresholded at 0.5.

    Labels are assumed to have the same shape as the model output, which
    nn.BCELoss already requires during training. Returns NaN for an empty loader.
    """
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            predictions = (model(images) >= 0.5).to(labels.dtype)
            correct += (predictions == labels).sum().item()
            total += labels.numel()
    return correct / total if total else float("nan")


start_time = time.time()
# skf.split yields POSITIONS into train_patients / train_tumor (the held-out test
# patients are not part of it), so fold patients must be looked up in
# train_patients. Indexing full_patient_list would pull in wrong patients,
# including test patients.
train_patient_array = np.asarray(train_patients)
for i, (train_indices, validation_indices) in enumerate(skf.split(train_patients, train_tumor)):
    # Multiple slices are taken from the MRI of a single patient_id, so folds are
    # split by patient and each fold's slices are selected from the sample map by patient_id.
    print("Currently running through fold", i)

    training_patients = train_patient_array[train_indices]
    validation_patients = train_patient_array[validation_indices]
    # Guard against leaking held-out test patients into cross-validation.
    assert set(training_patients).isdisjoint(test_patients)
    assert set(validation_patients).isdisjoint(test_patients)

    training_data_map = full_data_map.loc[full_data_map["patient_id"].isin(training_patients)]
    validation_data_map = full_data_map.loc[full_data_map["patient_id"].isin(validation_patients)]

    training_dataset = _load_fold_dataset("training", training_data_map)
    validation_dataset = _load_fold_dataset("validation", validation_data_map)
    print("N training samples:", len(training_dataset))
    print("N validation samples:", len(validation_dataset))

    # Seed per fold so weight initialisation and shuffle order are reproducible.
    torch.manual_seed(360 + i)

    print("Building dataloaders...")
    training_loader = DataLoader(training_dataset, batch_size=batch_size, shuffle=True)
    validation_loader = DataLoader(validation_dataset, batch_size=batch_size, shuffle=False)

    print("Building model and setting the criterion.")
    model = CNN2D(n_channels=len(selected_modalities), n_classes=1)
    criterion = nn.BCELoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    print("Training!")
    train_start_time = time.time()
    train_losses, val_losses = train_model(
        model, training_loader, validation_loader, criterion, optimizer,
        num_epochs=n_epochs, device=device
    )
    print("Finished training!")
    print(f"Time to train: {(time.time() - train_start_time)/60:.2f} minutes")
    print("Losses:")
    print("\tTraining:", train_losses)
    print("\tValidation:", val_losses)
    accuracy[i] = _binary_accuracy(model, validation_loader, device)
    print("Accuracy:")
    print("\tValidation:", accuracy[i])
    loss[i] = [train_losses, val_losses]
    models.append(model)

print(f"Total cross-validation time: {(time.time() - start_time)/60:.2f} minutes")


In [None]:
import matplotlib.pyplot as plt

# One panel per cross-validation fold, two per row. The grid follows however many
# folds were trained instead of assuming exactly four.
n_folds = len(loss)
n_cols = 2
n_rows = max(1, (n_folds + n_cols - 1) // n_cols)
fig, axes = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 4 * n_rows), squeeze=False)
for ax, fold in zip(axes.flat, sorted(loss)):
    train_curve, val_curve = loss[fold]
    ax.plot(range(1, len(train_curve) + 1), train_curve, label="Training")
    ax.plot(range(1, len(val_curve) + 1), val_curve, label="Validation")
    ax.set(title=f"Fold {fold}", xlabel="Epoch", ylabel="Loss")
    ax.legend()
# Hide grid cells left over when the number of folds is odd (or zero).
for ax in axes.flat[n_folds:]:
    ax.set_visible(False)
fig.tight_layout()
plt.show()
