In [1]:
import kagglehub
dataset_path = kagglehub.dataset_download("awsaf49/brats20-dataset-training-validation")
print("Path to dataset files:", dataset_path)

Path to dataset files: /kaggle/input/brats20-dataset-training-validation


In [2]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import cv2 as cv
import os
import time
import gc

import nibabel as nib

from sklearn.model_selection import train_test_split

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader

import torchvision.transforms as transforms

seed = 42
torch.manual_seed(seed)
np.random.seed(seed)

if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

print("Random seed set.")

Random seed set.


In [3]:
class MRISegDataset(Dataset):
    def __init__(self, base_path, transform=None):
        """
        Args:
            base_path (str): Path to the dataset folder.
            transform (callable, optional): Optional transform to be applied on a sample.
        """
        self.base_path = base_path
        self.patients = [p for p in os.listdir(base_path) if os.path.isdir(os.path.join(base_path, p))]
        self.transform = transform
        
    def __len__(self):
        return len(self.patients)*155

    def __getitem__(self, ind):
        """
        Args:
            idx (int): Index of the patient to fetch data for.

        Returns:
            dict: A dictionary containing 'image', 'segmentation', and optionally 'patient_id'.
        """
        patient_id = self.patients[ind//155]
        if patient_id == 'BraTS20_Training_355': 
            flair_path = os.path.join(self.base_path, patient_id, "W39_1998.09.19_Segm.nii")
            seg_path = os.path.join(self.base_path, patient_id, "W39_1998.09.19_Segm.nii")
        else:
            flair_path = os.path.join(self.base_path, patient_id, f"{patient_id}_flair.nii")
            seg_path = os.path.join(self.base_path, patient_id, f"{patient_id}_seg.nii")

        flair = nib.load(flair_path).get_fdata()[:, :, ind%155].reshape(1, 240, 240)
        seg = nib.load(seg_path).get_fdata()[:, :, ind%155].reshape(1, 240, 240)
        seg = np.where(seg==4, 3, seg)

        # Normalization
        flair = (flair - 0) / (900 - 0)

        flair = torch.tensor(flair, dtype=torch.float32)
        seg = torch.tensor(seg, dtype=torch.long)
        

        if self.transform:
            flair = self.transform(flair)
            seg = self.transform(seg)

        sample = {'image': flair, 'segmentation': seg, 'patient_id': patient_id}

        return sample

In [4]:
base_path = dataset_path +'/BraTS2020_TrainingData/MICCAI_BraTS2020_TrainingData'
brats_dataset = MRISegDataset(base_path=base_path, transform=transforms.Resize((256, 256)))

# Example: Access a single sample
sample = brats_dataset[0]
print("Image shape:", sample['image'].shape)
print("Segmentation shape:", sample['segmentation'].shape)
print("Patient ID:", sample['patient_id'])

Image shape: torch.Size([1, 256, 256])
Segmentation shape: torch.Size([1, 256, 256])
Patient ID: BraTS20_Training_083


In [5]:
def dice_coefficient(pred, target, smooth=1e-6):
    pred = pred.contiguous()
    target = target.contiguous()

    pred_class_all = torch.argmax(pred, dim=1)
    
    dice_total = []
    for class_ind in range(4):
        pred_class = (pred_class_all == class_ind).float()
        target_class = (target == class_ind).float()
        
        intersection = (pred_class * target_class).sum(dim=(1,2))
        union = pred_class.sum(dim=(1,2)) + target_class.sum(dim=(1,2))
        
        dice_class = (2. * intersection + smooth) / (union + smooth)
        dice_total.append(dice_class)
    
    return torch.mean(torch.stack(dice_total))

In [6]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Currently using: {device}')

Currently using: cuda


In [7]:
class ResBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ResBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.skip = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)  # Match dimensions

    def forward(self, x):
        identity = self.skip(x)
        out = torch.relu(self.conv1(x))
        out = self.conv2(out)
        out += identity # Reisdual connection
        out = torch.relu(out)
        return out


class UNet_Res(nn.Module):
    def __init__(self):
        super(UNet_Res, self).__init__()

        self.enc1 = ResBlock(1, 16)
        self.pool1 = nn.MaxPool2d(2)

        self.enc2 = ResBlock(16, 32)
        self.pool2 = nn.MaxPool2d(2)

        self.bottleneck = ResBlock(32, 32)

        self.up2 = nn.ConvTranspose2d(32, 32, kernel_size=2, stride=2)
        self.dec2 = ResBlock(64, 32)

        self.up1 = nn.ConvTranspose2d(32, 16, kernel_size=2, stride=2)
        self.dec1 = ResBlock(32, 16)

        self.final_conv = nn.Conv2d(16, 4, kernel_size=1)

    def forward(self, x):
        # Encoder
        enc1 = self.enc1(x)
        enc1_pool = self.pool1(enc1)

        enc2 = self.enc2(enc1_pool)
        enc2_pool = self.pool2(enc2)

        bottleneck = self.bottleneck(enc2_pool)

        up2 = self.up2(bottleneck)
        dec2 = self.dec2(torch.cat([up2, enc2], dim=1))

        up1 = self.up1(dec2)
        dec1 = self.dec1(torch.cat([up1, enc1], dim=1))

        out = self.final_conv(dec1)
        return out


tmp_model = UNet_Res()

print(tmp_model)
print(f"Total trainable parameters: {sum(p.numel() for p in tmp_model.parameters() if p.requires_grad):,}")

UNet_Res(
  (enc1): ResBlock(
    (conv1): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (skip): Conv2d(1, 16, kernel_size=(1, 1), stride=(1, 1))
  )
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (enc2): ResBlock(
    (conv1): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (skip): Conv2d(16, 32, kernel_size=(1, 1), stride=(1, 1))
  )
  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (bottleneck): ResBlock(
    (conv1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (skip): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1))
  )
  (up2): ConvTranspose2d(32, 32, kernel_size=(2, 2), stride=(2, 2))
  (dec2): ResB

In [8]:
batch_size = 16
data_loader = DataLoader(brats_dataset, batch_size=batch_size, shuffle=True, pin_memory=True,
                         num_workers=2)
print(f"Number of Batches: {len(data_loader)}")

model_res = UNet_Res()
if torch.cuda.is_available(): 
    model_res = model_res.cuda()
    print("Moved to cuda.")
criterion_res = nn.CrossEntropyLoss()
optimizer_res = torch.optim.Adam(model_res.parameters(), lr=1e-3)

Number of Batches: 3575
Moved to cuda.


In [9]:
num_epochs = 5
training_losses = []
epoch_times = []

for epoch in range(num_epochs):
    epoch_start_time = time.time()
    model_res.train()
    running_loss = 0.0

    for batch_idx, batch in enumerate(data_loader):
        batch_start_time = time.time()
        images = batch['image'].to(device) 
        segmentations = batch['segmentation'].squeeze(1).to(device)

        optimizer_res.zero_grad()
    
        outputs = model_res(images)
        loss = criterion_res(outputs, segmentations)
        
        loss.backward()
        optimizer_res.step()

        running_loss += loss.item()
        
        del images
        del segmentations
        torch.cuda.empty_cache()
        gc.collect()
        
        if batch_idx%300 == 0:
            print(f"Batch {batch_idx}/{len(data_loader)}, Time: {time.time()-batch_start_time:.4f}, Loss: {loss.item():.4f}")
        
    avg_loss = running_loss / len(data_loader)
    epoch_time = time.time()-epoch_start_time
    print(f"Epoch {epoch+1}/{num_epochs}, Time: {epoch_time:.4f}, Loss: {avg_loss:.4f}")

    print("Saving weights")
    torch.save(model_res.state_dict(), f"UNet3_weights_{epoch}.pth")
    training_losses.append(avg_loss)
    epoch_times.append(epoch_time)
    
with open("UNet3.txt", "w") as f:
    f.write("Training_losses: \n")
    for i in training_losses: f.write(str(i)+"\n")
    f.write("\nEpoch_times: \n")
    for i in epoch_times: f.write(str(i)+"\n")

Batch 0/3575, Time: 1.3007, Loss: 1.3038
Batch 300/3575, Time: 0.2631, Loss: 0.0381
Batch 600/3575, Time: 0.3565, Loss: 0.0603
Batch 900/3575, Time: 0.2623, Loss: 0.0437
Batch 1200/3575, Time: 0.3594, Loss: 0.0540
Batch 1500/3575, Time: 0.2623, Loss: 0.0625
Batch 1800/3575, Time: 0.3297, Loss: 0.0189
Batch 2100/3575, Time: 0.3563, Loss: 0.0305
Batch 2400/3575, Time: 0.2605, Loss: 0.0394
Batch 2700/3575, Time: 0.3653, Loss: 0.0437
Batch 3000/3575, Time: 0.3495, Loss: 0.0189
Batch 3300/3575, Time: 0.2634, Loss: 0.0739
Epoch 1/5, Time: 1705.4893, Loss: 0.0484
Saving weights
Batch 0/3575, Time: 0.3759, Loss: 0.0680
Batch 300/3575, Time: 0.3739, Loss: 0.0080
Batch 600/3575, Time: 0.3719, Loss: 0.0357
Batch 900/3575, Time: 0.3745, Loss: 0.0274
Batch 1200/3575, Time: 0.3608, Loss: 0.0291
Batch 1500/3575, Time: 0.3654, Loss: 0.0247
Batch 1800/3575, Time: 0.3700, Loss: 0.0408
Batch 2100/3575, Time: 0.3205, Loss: 0.0485
Batch 2400/3575, Time: 0.2669, Loss: 0.0259
Batch 2700/3575, Time: 0.2920, L