In [1]:
import kagglehub
dataset_path = kagglehub.dataset_download("awsaf49/brats20-dataset-training-validation")
print("Path to dataset files:", dataset_path)

Path to dataset files: /kaggle/input/brats20-dataset-training-validation


In [2]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import cv2 as cv
import os
import time
import gc

import nibabel as nib

from sklearn.model_selection import train_test_split

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader

import torchvision.transforms as transforms

seed = 42
torch.manual_seed(seed)
np.random.seed(seed)

if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

print("Random seed set.")

Random seed set.


In [3]:
class MRISegDataset(Dataset):
    def __init__(self, base_path):
        """
        Args:
            base_path (str): Path to the dataset folder.
            transform (callable, optional): Optional transform to be applied on a sample.
        """
        self.base_path = base_path
        self.patients = [p for p in os.listdir(base_path) if os.path.isdir(os.path.join(base_path, p))]
        
    def __len__(self):
        return len(self.patients)

    def __getitem__(self, ind):
        """
        Args:
            idx (int): Index of the patient to fetch data for.

        Returns:
            dict: A dictionary containing 'image', 'segmentation', and optionally 'patient_id'.
        """
        patient_id = self.patients[ind]
        if patient_id == 'BraTS20_Training_355': 
            flair_path = os.path.join(self.base_path, patient_id, "W39_1998.09.19_Segm.nii")
            seg_path = os.path.join(self.base_path, patient_id, "W39_1998.09.19_Segm.nii")
        else:
            flair_path = os.path.join(self.base_path, patient_id, f"{patient_id}_flair.nii")
            seg_path = os.path.join(self.base_path, patient_id, f"{patient_id}_seg.nii")

        flair = nib.load(flair_path).get_fdata()
        seg = nib.load(seg_path).get_fdata()
        seg = np.where(seg==4, 3, seg)

        # Normalization
        flair = (flair - 0) / (900 - 0)

        flair = torch.tensor(flair, dtype=torch.float32).unsqueeze(0)
        seg = torch.tensor(seg, dtype=torch.long)

        sample = {'image': flair, 'segmentation': seg, 'patient_id': patient_id}

        return sample

In [4]:
base_path = dataset_path +'/BraTS2020_TrainingData/MICCAI_BraTS2020_TrainingData'
brats_dataset = MRISegDataset(base_path=base_path)

# Example: Access a single sample
sample = brats_dataset[0]
print("Image shape:", sample['image'].shape)
print("Segmentation shape:", sample['segmentation'].shape)
print("Patient ID:", sample['patient_id'])

Image shape: torch.Size([1, 240, 240, 155])
Segmentation shape: torch.Size([240, 240, 155])
Patient ID: BraTS20_Training_083


In [5]:
def dice_coefficient(pred, target, smooth=1e-6):
    pred = pred.contiguous()
    target = target.contiguous()

    pred_class_all = torch.argmax(pred, dim=1)
    
    dice_total = []
    for class_ind in range(4):
        pred_class = (pred_class_all == class_ind).float()
        target_class = (target == class_ind).float()
        
        intersection = (pred_class * target_class).sum(dim=(1,2))
        union = pred_class.sum(dim=(1,2)) + target_class.sum(dim=(1,2))
        
        dice_class = (2. * intersection + smooth) / (union + smooth)
        dice_total.append(dice_class)
    
    return torch.mean(torch.stack(dice_total))

In [6]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Currently using: {device}')

Currently using: cuda


In [7]:
class Conv3D(nn.Module):
    def __init__(self):
        super(Conv3D, self).__init__()
        self.conv1 = nn.Conv3d(1, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv3d(16, 32, kernel_size=3, padding=1)
        self.conv3 = nn.Conv3d(32, 16, kernel_size=3, padding=1)
        self.conv4 = nn.Conv3d(16, 4, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
    
    def forward(self,x):
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.relu(self.conv3(x))
        x = self.conv4(x)
        return x

model_3d = Conv3D()
print(model_3d)
print(f"Total trainable parameters: {sum(p.numel() for p in model_3d.parameters() if p.requires_grad):,}")

Conv3D(
  (conv1): Conv3d(1, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
  (conv2): Conv3d(16, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
  (conv3): Conv3d(32, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
  (conv4): Conv3d(16, 4, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
  (relu): ReLU()
)
Total trainable parameters: 29,876


In [8]:
batch_size = 2
data_loader = DataLoader(brats_dataset, batch_size=batch_size, shuffle=True, pin_memory=True,
                         num_workers=2)
print(f"Number of Batches: {len(data_loader)}")

model_3d = Conv3D()
if torch.cuda.is_available(): 
    model_3d = model_3d.cuda()
    print("Moved to cuda.")
criterion_3d = nn.CrossEntropyLoss()
optimizer_3d = torch.optim.Adam(model_3d.parameters(), lr=1e-3)

Number of Batches: 185
Moved to cuda.


In [9]:
num_epochs = 5
training_losses = []
epoch_times = []

for epoch in range(num_epochs):
    epoch_start_time = time.time()
    model_3d.train()
    running_loss = 0.0

    for batch_idx, batch in enumerate(data_loader):
        batch_start_time = time.time()
        images = batch['image'].to(device) 
        segmentations = batch['segmentation'].squeeze(1).to(device)

        optimizer_3d.zero_grad()
    
        outputs = model_3d(images)
        loss = criterion_3d(outputs, segmentations)
        
        loss.backward()
        optimizer_3d.step()

        running_loss += loss.item()
        
        del images
        del segmentations
        torch.cuda.empty_cache()
        gc.collect()
        
        if batch_idx%300 == 0:
            print(f"Batch {batch_idx}/{len(data_loader)}, Time: {time.time()-batch_start_time:.4f}, Loss: {loss.item():.4f}")
        
    avg_loss = running_loss / len(data_loader)
    epoch_time = time.time()-epoch_start_time
    print(f"Epoch {epoch+1}/{num_epochs}, Time: {epoch_time:.4f}, Loss: {avg_loss:.4f}")

    print("Saving weights")
    torch.save(model_3d.state_dict(), f"Conv3D_weights_{epoch}.pth")
    training_losses.append(avg_loss)
    epoch_times.append(epoch_time)
    
with open("Conv3D.txt", "w") as f:
    f.write("Training_losses: \n")
    for i in training_losses: f.write(str(i)+"\n")
    f.write("\nEpoch_times: \n")
    for i in epoch_times: f.write(str(i)+"\n")

Batch 0/185, Time: 3.8468, Loss: 1.3738
Epoch 1/5, Time: 378.7881, Loss: 0.1367
Saving weights
Batch 0/185, Time: 2.1009, Loss: 0.0526
Epoch 2/5, Time: 374.6736, Loss: 0.0491
Saving weights
Batch 0/185, Time: 2.0959, Loss: 0.0434
Epoch 3/5, Time: 375.2402, Loss: 0.0480
Saving weights
Batch 0/185, Time: 2.0961, Loss: 0.0529
Epoch 4/5, Time: 375.6489, Loss: 0.0484
Saving weights
Batch 0/185, Time: 2.1073, Loss: 0.0380
Epoch 5/5, Time: 375.2379, Loss: 0.0494
Saving weights


In [10]:
with open("UNet2.txt", "r") as f:
    content = f.read()

print(content)

FileNotFoundError: [Errno 2] No such file or directory: 'UNet2.txt'

In [None]:
plt.plot(training_losses)
print(training_losses)

In [None]:
plt.plot(epoch_times)
print(epoch_times)