In [None]:
pip install monai

In [None]:
import os
import re

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import pydicom
import torch
import torch.nn as nn
from sklearn.model_selection import train_test_split
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

# Transformations
from scipy.ndimage import zoom 
from monai.transforms import (
   Compose,
   ScaleIntensityd,
   NormalizeIntensityd,
   ScaleIntensityRanged,
   RandFlipd,
   RandRotate90d,
   RandShiftIntensityd
)

# Model
from monai.networks.nets import DenseNet121, DenseNet169

In [None]:
# Root folder of the competition data; both inputs below live underneath it.
data_root = '/kaggle/input/rsna-miccai-brain-tumor-radiogenomic-classification'

# Folder with one sub-directory per patient, and the CSV holding the per-patient labels.
source_dir = f'{data_root}/train'
annotations_file = f'{data_root}/train_labels.csv'

In [None]:
# Read the patient ID column as text so zero-padded IDs such as '00109' keep their leading zeros.
id_column_types = {"BraTS21ID": str}
annotations_df = pd.read_csv(annotations_file, dtype=id_column_types)
annotations_df.head()

## Define the Problematic Cases

Patient IDs listed here are removed before the data is split into train, validation and test sets.

In [None]:
# Patient IDs (zero-padded strings, matching the folder names) to leave out of every split.
problematic_cases = [
    '00109',
    '00123',
    '00709',
]

## Split Data by Patient Cases

In [None]:
def get_slice_paths(patient_modality_path):
    """Return the paths of all ``.dcm`` files in a folder, in natural (numeric) order.

    A plain lexicographic sort places ``Image-10.dcm`` before ``Image-2.dcm``,
    which would scramble the slice order once the slices are stacked into a
    volume. Sorting on the numbers embedded in the file name avoids that.

    NOTE(review): this assumes the file name carries the slice index
    (e.g. ``Image-12.dcm``). If it does not, order by a DICOM tag such as
    ``InstanceNumber`` instead — confirm against the dataset.

    Args:
        patient_modality_path: Directory holding the DICOM slices of one series.

    Returns:
        List of full paths to the ``.dcm`` files, naturally sorted by file name.
    """
    def _natural_key(path):
        # re.split with a capturing group alternates text / digit chunks, so the
        # chunks at odd positions are always digit runs and can be compared as ints.
        chunks = re.split(r'(\d+)', os.path.basename(path))
        return [int(chunk) if position % 2 else chunk
                for position, chunk in enumerate(chunks)]

    slice_paths = [
        os.path.join(patient_modality_path, file_name)
        for file_name in os.listdir(patient_modality_path)
        if file_name.endswith('.dcm')
    ]
    return sorted(slice_paths, key=_natural_key)

In [None]:
def split_data(source_dir, problematic_cases, train_size, test_size, val_size, modality):
    """Split the patient folders under ``source_dir`` into train / val / test sets.

    The split is done per patient, so all slices of one patient end up in the
    same set.

    Args:
        source_dir: Directory containing one sub-directory per patient.
        problematic_cases: Patient IDs to exclude before splitting.
        train_size: Fraction of the data intended for training. It is only used
            to rescale ``val_size`` for the second split.
        test_size: Passed to ``train_test_split`` as the test share of all cases.
        val_size: Fraction of the data intended for validation.
        modality: Name of the per-patient sub-folder whose slices are listed
            (e.g. ``'FLAIR'``).

    Returns:
        Dict with keys ``'train'``, ``'val'`` and ``'test'``. Each value is a dict
        with ``'patient_ids'`` (list), ``'paths'`` (patient ID -> patient folder)
        and ``'slices'`` (patient ID -> slice paths from ``get_slice_paths``).
    """
    excluded = set(problematic_cases)

    # os.listdir returns entries in arbitrary order, so sort them: otherwise the
    # fixed random_state below would not give the same split on every machine.
    patient_cases = sorted(
        patient_id
        for patient_id in os.listdir(source_dir)
        if patient_id not in excluded
        and os.path.isdir(os.path.join(source_dir, patient_id))
    )

    train_val_cases, test_cases = train_test_split(
        patient_cases,
        test_size = test_size,
        random_state = 42
    )

    # The second split acts on the train+val remainder, so express the
    # validation share relative to that remainder.
    val_size_adjusted = val_size/(train_size + val_size)

    train_cases, val_cases = train_test_split(
        train_val_cases,
        test_size = val_size_adjusted,
        random_state = 42
    )

    def _describe(cases):
        # Build the per-split record: IDs, patient folders, and modality slice paths.
        return {
            'patient_ids': cases,
            'paths': {patient_id: os.path.join(source_dir, patient_id) for patient_id in cases},
            'slices': {patient_id: get_slice_paths(os.path.join(source_dir, patient_id, modality)) for patient_id in cases}
        }

    return {
        'train': _describe(train_cases),
        'val': _describe(val_cases),
        'test': _describe(test_cases)
    }

In [None]:
# Split the patients 70 / 15 / 15 into train / validation / test, using the FLAIR series.
split_config = dict(
    source_dir=source_dir,
    problematic_cases=problematic_cases,
    train_size=0.7,
    val_size=0.15,
    test_size=0.15,
    modality='FLAIR',
)
splits = split_data(**split_config)

In [None]:
# splits['train']['slices']

## Define the Custom BraTs Dataset

In [None]:
class BraTsDataset(Dataset):
    """Dataset yielding one 3D volume (built from DICOM slices) and label per patient.

    Each item is a dict with keys ``"image"``, ``"patient_id"`` and ``"label"``.
    The image is a float tensor with a leading channel axis of size 1 and the
    slices stacked along the last axis. Every slice is resized in-plane to
    ``slice_size`` x ``slice_size``; the number of slices is left as found on disk.

    Args:
        data_dict: One split record with ``'patient_ids'`` (sequence of IDs) and
            ``'slices'`` (patient ID -> list of DICOM file paths).
        annotations_df: DataFrame with ``'BraTS21ID'`` and ``'MGMT_value'`` columns.
        transforms: Optional callable applied to the item dict in ``__getitem__``.
        cache_size: Maximum number of loaded volumes kept in memory (0 disables caching).
        slice_size: In-plane edge length every slice is resized to.
    """

    def __init__(self, data_dict, annotations_df, transforms = None, cache_size = 0, slice_size = 64):
        self.patient_ids = data_dict['patient_ids']
        self.slice_paths = data_dict['slices']
        self.transforms = transforms
        self.slice_size = slice_size

        self.cache_size = cache_size
        self.cache = {}
        # Patient ID -> label lookup built once from the annotation table.
        self.labels = dict(zip(annotations_df['BraTS21ID'], annotations_df['MGMT_value']))

    def __len__(self):
        """Number of patients in this split."""
        return len(self.patient_ids)

    @staticmethod
    def _min_max_normalize(volume_tensor):
        """Rescale a tensor to [0, 1]; a constant tensor maps to all zeros."""
        lowest = volume_tensor.min()
        value_range = volume_tensor.max() - lowest
        if value_range == 0:
            # Avoid 0/0 -> NaN for blank (constant) volumes.
            return torch.zeros_like(volume_tensor)
        return (volume_tensor - lowest) / value_range

    def load_volume(self, patient_id):
        """Load, resize, stack and min-max normalise all slices of one patient.

        Returns:
            Float tensor with a channel axis first and the slice axis last.

        Raises:
            ValueError: If the patient has no slice paths.
        """
        if patient_id in self.cache:
            return self.cache[patient_id]

        slices = []
        for slice_path in self.slice_paths[patient_id]:
            dicom_image = pydicom.dcmread(slice_path)
            # Interpolate in float: zoom returns the input dtype by default, so an
            # integer pixel array would force the interpolated values back into
            # that integer type.
            image_2d = dicom_image.pixel_array.astype(np.float32)
            zoom_factors = (self.slice_size/image_2d.shape[0], self.slice_size/image_2d.shape[1])
            slices.append(zoom(image_2d, zoom_factors))

        if not slices:
            raise ValueError(f"No DICOM slices found for patient {patient_id!r}")

        volume = np.stack(slices, axis=-1)        # slices along the last axis
        volume = np.expand_dims(volume, axis=0)   # add the channel axis
        volume_tensor = self._min_max_normalize(torch.from_numpy(volume).float())

        if len(self.cache) < self.cache_size:
            self.cache[patient_id] = volume_tensor

        return volume_tensor

    def __getitem__(self, idx):
        """Return the item dict for the patient at position ``idx``."""
        patient_id = self.patient_ids[idx]
        volume = self.load_volume(patient_id)
        label = torch.tensor(self.labels[patient_id])

        data = {"image": volume,
                "patient_id": patient_id,
                "label": label
               }

        if self.transforms is not None:
            # Hand the transforms a copy so an in-place transform cannot
            # corrupt a cached volume.
            data["image"] = volume.clone()
            data = self.transforms(data)

        return data
        
    

In [None]:
# Pull the three per-split records out of `splits` in one go.
train_split, val_split, test_split = (
    splits[split_name] for split_name in ('train', 'val', 'test')
)

In [None]:
# Sanity-check the structure of the training split record.
train_ids = train_split['patient_ids']
print("Train split keys:", train_split.keys())
print("First few patient IDs:", train_ids[:3])

# Show the first few slice paths of the first patient to confirm they look right.
first_patient = train_ids[0]
first_patient_slices = train_split['slices'][first_patient]
print("First patient paths:", first_patient_slices[:3])

# Build a dataset and load a single item to confirm loading works end to end.
train_dataset = BraTsDataset(train_split, annotations_df)
sample = train_dataset[0]
sample_shape = sample['image'].shape
print(f"Sample shape: {sample_shape}")

## Define the Transformations

In [None]:
# Intensity rescaling from the window [-76, 3158] to [0, 1], clipped.
# NOTE(review): BraTsDataset.load_volume already min-max normalises each volume,
# so this window may no longer match the data this transform receives — confirm
# before passing this pipeline to a dataset.
intensity_window = ScaleIntensityRanged(
    keys=["image"],
    a_min=-76,
    a_max=3158,
    b_min=0.0,
    b_max=1.0,
    clip=True,
)

# Random augmentations, each applied with probability 0.5.
random_augmentations = [
    RandFlipd(keys=["image"], spatial_axis=[0, 1], prob=0.5),
    RandRotate90d(keys=["image"], prob=0.5, spatial_axes=[0, 1]),
    RandShiftIntensityd(keys=["image"], prob=0.5, offsets=0.1),
]

train_transforms = Compose([intensity_window, *random_augmentations])



In [None]:
# List the steps that ended up in the training pipeline.
composed_steps = train_transforms.transforms
print("Transforms:", composed_steps)

In [None]:
# Validation pipeline: intensity scaling to [0, 1] only, no random augmentation.
val_intensity_scale = ScaleIntensityd(keys=["image"], minv=0.0, maxv=1.0)
val_transforms = Compose([val_intensity_scale])

In [None]:
# Test pipeline: intensity scaling to [0, 1] only, no random augmentation.
test_intensity_scale = ScaleIntensityd(keys=["image"], minv=0.0, maxv=1.0)
test_transforms = Compose([test_intensity_scale])

In [None]:
# Training dataset wired to the augmentation pipeline.
# NOTE(review): the following cell rebinds `train_dataset` without transforms,
# so this assignment is overwritten when the notebook runs top to bottom.
train_dataset = BraTsDataset(
    train_split,
    annotations_df,
    transforms=train_transforms,
)

In [None]:
# One dataset per split, all sharing the same annotation table and no transforms.
train_dataset, val_dataset, test_dataset = (
    BraTsDataset(split_record, annotations_df)
    for split_record in (train_split, val_split, test_split)
)

In [None]:
# Check split data structure
print("Train split keys:", train_split.keys())
print("First few patient IDs:", train_split['patient_ids'][:3])

# Check if paths are correct
first_patient = train_split['patient_ids'][0]
print("First patient paths:", train_split['slices'][first_patient][:3])

# Create one dataset per split and load a single item from each as a smoke test
train_dataset = BraTsDataset(train_split, annotations_df)
sample = train_dataset[0]
print(f"Train Sample shape: {sample['image'].shape}")
print()

val_dataset = BraTsDataset(val_split, annotations_df)
sample = val_dataset[0]
print(f"Validation Sample shape: {sample['image'].shape}")
print()

# The test dataset must come from the test split. Building it from val_split
# would make the final evaluation run on validation patients.
test_dataset = BraTsDataset(test_split, annotations_df)
sample = test_dataset[0]
print(f"Test Sample shape: {sample['image'].shape}")

## Define the DataLoader

In [None]:
def custom_collate(batch):
    """Collate items whose volumes differ in size along the last axis.

    Every volume is zero-padded at the end of its last axis up to the largest
    last-axis size in the batch, then the volumes are stacked.

    Args:
        batch: Non-empty list of dicts with an ``"image"`` tensor and a ``"label"``.
            All images must agree in every dimension except the last.

    Returns:
        Dict with ``"image"`` (stacked, padded volumes) and ``"label"``
        (one label per item, in batch order). Other keys of the items are dropped.
    """
    max_depth = max(item["image"].shape[-1] for item in batch)

    padded_volumes = []
    labels = []
    for item in batch:
        volume = item["image"]
        pad_size = max_depth - volume.shape[-1]
        if pad_size > 0:
            # A (0, pad_size) spec pads only the end of the last dimension.
            volume = torch.nn.functional.pad(volume, (0, pad_size))
        padded_volumes.append(volume)
        # as_tensor accepts both tensors and plain numbers without re-copying tensors.
        labels.append(torch.as_tensor(item["label"]))

    return {
        "image": torch.stack(padded_volumes),
        "label": torch.stack(labels)
           }

In [None]:
# Settings shared by all three loaders.
loader_kwargs = dict(
    batch_size=4,  # Small batch size for 3D data
    num_workers=2,
    collate_fn=custom_collate,  # pads volumes that differ along the last axis
    pin_memory=torch.cuda.is_available(),  # pinning only helps when batches go to a GPU
)

# Only the training data is shuffled. Validation and test keep a fixed order so
# that per-batch averaged metrics are the same on every evaluation pass.
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)


In [None]:
def inspect_batch(dataloader, split_name="Unknown"):
    """Print the image/label shapes and the image value range of the first batch.

    Args:
        dataloader: Iterable yielding dicts with ``"image"`` and ``"label"`` tensors.
        split_name: Label used in the printed header.
    """
    first_batch = next(iter(dataloader))
    images = first_batch["image"]
    labels = first_batch["label"]

    lowest = images.min().item()
    highest = images.max().item()

    report_lines = [
        f"\n--- {split_name} Split Batch Inspection ---",
        f"Image Batch Shape: {images.shape}",  # Expected: (batch_size, channels, height, width, depth)
        f"Label Batch Shape: {labels.shape}",  # Should match batch_size
        f"Value Range: [{lowest:.3f}, {highest:.3f}]",
    ]
    print("\n".join(report_lines))


In [None]:
# inspect_batch(train_loader, "Train")
# inspect_batch(val_loader, "Validation")
# inspect_batch(test_loader, "Test")


## Define the Model

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)

In [None]:
# 3D DenseNet-169 with one input channel and a single output value per volume.
# DenseNet121 accepts the same three arguments if a smaller backbone is wanted.
model_kwargs = dict(spatial_dims=3, in_channels=1, out_channels=1)
model = DenseNet169(**model_kwargs)
model = model.to(device)



In [None]:
# Binary cross-entropy computed on raw logits (the sigmoid is applied inside the loss).
criterion = nn.BCEWithLogitsLoss()

# Adam over all model parameters with a learning rate of 1e-4.
learning_rate = 1e-4
optimizer = Adam(model.parameters(), lr=learning_rate)

## Set up the Training Loop

In [None]:
def train_epoch(model, train_loader, val_loader, epochs = 50,
                checkpoint_path = '/kaggle/working/best_model.pth'):
    """Train ``model`` for ``epochs`` epochs, validating after every epoch.

    Despite its name this runs the whole training loop, not a single epoch.
    The model's ``state_dict`` is written to ``checkpoint_path`` whenever the
    validation accuracy improves on the best seen so far.

    Uses the notebook-level ``device``, ``criterion`` and ``optimizer``.

    Args:
        model: Network producing one logit per sample.
        train_loader: Loader yielding dicts with ``'image'`` and ``'label'``.
        val_loader: Loader of the same form used for validation.
        epochs: Number of passes over ``train_loader``.
        checkpoint_path: Where the best ``state_dict`` is saved.
    """
    # Start below any reachable accuracy so the first epoch always writes a
    # checkpoint, even if its validation accuracy is 0%.
    best_val_acc = float('-inf')
    for epoch in range(epochs):

        # ---- Training phase ----
        model.train()
        train_loss = 0.0  # running sum of per-batch losses

        for batch in tqdm(train_loader, desc = f'Epoch {epoch+1}/{epochs} - Training'):
            images = batch['image'].to(device)
            labels = batch['label'].float().to(device)

            optimizer.zero_grad()
            # reshape(-1), not squeeze(): squeeze() would also drop the batch
            # axis of a final batch of size 1, and the loss requires logits and
            # labels to have the same shape.
            logits = model(images).reshape(-1)

            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()

        # Calculate average loss for the epoch (mean over batches)
        avg_train_loss = train_loss/len(train_loader)

        # ---- Validation phase ----
        model.eval()
        val_loss = 0.0
        correct = 0
        total = 0

        with torch.no_grad():  # No gradient calculation needed
            for batch in tqdm(val_loader, desc = f'Epoch {epoch + 1}/{epochs} - Validation'):
                images = batch['image'].to(device)
                labels = batch['label'].float().to(device)

                logits = model(images).reshape(-1)

                val_loss += criterion(logits, labels).item()

                # Threshold the predicted probability at 0.5
                predicted = (torch.sigmoid(logits) > 0.5).int()
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        val_accuracy = 100*correct/total
        avg_val_loss = val_loss/len(val_loader)

        if val_accuracy > best_val_acc:
            best_val_acc = val_accuracy
            torch.save(model.state_dict(), checkpoint_path)

        # Print epoch results
        print(f'Epoch [{epoch+1}/{epochs}]')
        print(f'Training Loss: {avg_train_loss:.4f}')
        print(f'Validation Loss: {avg_val_loss:.4f}')
        print(f'Validation Accuracy: {val_accuracy:.2f}%')
        print('-' * 50)


        
        

In [None]:
# Run training for `num_epochs` epochs
num_epochs = 30
train_epoch(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    epochs=num_epochs,
)

In [None]:
def test_model(model, test_loader, criterion, checkpoint_path = '/kaggle/working/best_model.pth'):
    """Evaluate the saved best checkpoint on ``test_loader``.

    Uses the notebook-level ``device``.

    Args:
        model: Network whose weights are replaced by the checkpoint.
        test_loader: Loader yielding dicts with ``'image'`` and ``'label'``.
        criterion: Loss applied to the logits and the float labels.
        checkpoint_path: ``state_dict`` file to load before evaluating.

    Returns:
        Tuple ``(test_accuracy, avg_test_loss, all_predictions, all_labels)``:
        accuracy in percent, loss averaged over batches, and the per-sample
        predictions and labels in loader order.
    """
    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    model.eval()  # Set to evaluation mode

    test_loss = 0.0
    correct = 0
    total = 0

    # For storing predictions and true labels
    all_predictions = []
    all_labels = []

    test_pbar = tqdm(test_loader, desc='Testing')
    with torch.no_grad():
        for batch in test_pbar:
            images = batch['image'].to(device)
            labels = batch['label'].to(device)

            # reshape(-1), not squeeze(): keeps the batch axis for a final batch
            # of size 1, so the loss shapes match and the predictions stay
            # iterable for extend() below.
            logits = model(images).reshape(-1)
            test_loss += criterion(logits, labels.float()).item()

            # Threshold the predicted probability at 0.5
            predicted = (torch.sigmoid(logits) > 0.5).int()
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            # Store predictions and labels
            all_predictions.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

            # Update progress bar with the running accuracy
            current_acc = 100 * correct / total
            test_pbar.set_postfix({'acc': f'{current_acc:.2f}%'})

    # Calculate final metrics
    test_accuracy = 100 * correct / total
    avg_test_loss = test_loss / len(test_loader)

    print('\nTest Results:')
    print(f'Test Loss: {avg_test_loss:.4f}')
    print(f'Test Accuracy: {test_accuracy:.2f}%')

    return test_accuracy, avg_test_loss, all_predictions, all_labels


In [None]:
# Evaluate the best saved checkpoint on the test loader and unpack the results.
test_results = test_model(model, test_loader, criterion)
test_accuracy, test_loss, predictions, true_labels = test_results

In [None]:
# Report the final test metrics.
for metric_label, metric_value in (('Test Accuracy: ', test_accuracy),
                                   ('Test Loss: ', test_loss)):
    print(metric_label, metric_value)