# Brain Tumour Segmentation Network
## Import packages
Please make sure you have all the required packages installed for this project: PyTorch, torchvision, NumPy, nibabel, Matplotlib, tqdm, OpenCV (`cv2`) and SimpleITK.

# Brain Tumour Segmentation Network
## Import packages
Please make sure you have all the required packages installed for this project.

# Brain Tumour Segmentation Network
## Import packages
Please make sure you have all the required packages installed for this project.

In [1]:
# Imports
import os
import torch
import numpy as np
import nibabel as nib
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

## Visualise MRI Volume Slices and Segmentation Maps
Each MRI image contains information about a three-dimensional (3D) volume of space. An MRI image is composed of a number of voxels, which are like pixels in 2D images. Here, try to visualise the axial plane (which usually has higher resolution) of some of the volumes and the corresponding segmentation maps.

In [2]:

# ⚙️ Step: Skull Stripping (Simplified Example)
# This is a dummy skull stripping using simple thresholding.
# For real use, consider using brain extraction tools like FSL's BET.

import numpy as np
import cv2

def skull_strip(image, threshold=30):
    """Zero out every pixel/voxel whose intensity is not strictly above ``threshold``.

    A crude stand-in for real brain extraction (e.g. FSL's BET): a single
    global intensity cut-off removes dark background.

    Args:
        image: 2-D slice or N-D volume of any numeric dtype.
        threshold: Intensity cut-off; values ``<= threshold`` (and NaN) become 0.
            The default of 30 appears tuned for 8-bit images — rescale it for
            raw or normalised MRI intensities.

    Returns:
        A new array with the same shape and dtype as ``image``; the input is
        not modified.
    """
    image = np.asarray(image)
    # Same rule as cv2.threshold(THRESH_BINARY) followed by a masked
    # bitwise_and: keep values strictly greater than the threshold. Done in
    # NumPy so any dtype and any number of dimensions work (cv2 rejects many
    # dtypes and treats a 3-D volume as a multi-channel image).
    keep = image > threshold
    stripped = image.copy()
    stripped[~keep] = 0
    return stripped

# Example:
# img = cv2.imread("sample_image.png", 0)
# stripped = skull_strip(img)
# plt.imshow(stripped, cmap='gray')


In [3]:

# ⚙️ Step: Image Registration using SimpleITK (translation-only example; not a full rigid registration)

import SimpleITK as sitk

def register_images(fixed_image_np, moving_image_np, *, learning_rate=1.0,
                    min_step=1e-6, iterations=200):
    """Align ``moving_image_np`` to ``fixed_image_np`` with a pure translation.

    Only a ``sitk.TranslationTransform`` is optimised (no rotation), so this
    is not a full rigid registration. Both arrays are cast to float32 and get
    SimpleITK's default geometry (unit spacing, zero origin), so the
    translation is expressed in pixel/voxel units.

    Args:
        fixed_image_np: Reference image (2-D or 3-D array).
        moving_image_np: Image to move; must have the same dimensionality.
        learning_rate: Initial step of the regular-step gradient descent.
        min_step: Optimiser stops once the step shrinks below this.
        iterations: Maximum number of optimiser iterations.

    Returns:
        The moving image resampled onto the fixed image grid (float32 array);
        pixels that map outside the moving image are filled with 0.

    Raises:
        ValueError: if the two images have different dimensionality.
    """
    fixed = sitk.GetImageFromArray(fixed_image_np.astype(np.float32))
    moving = sitk.GetImageFromArray(moving_image_np.astype(np.float32))
    if fixed.GetDimension() != moving.GetDimension():
        raise ValueError(
            f"fixed image is {fixed.GetDimension()}-D but moving image is "
            f"{moving.GetDimension()}-D"
        )

    registration = sitk.ImageRegistrationMethod()
    # Mean squares assumes both images share an intensity scale (same modality).
    registration.SetMetricAsMeanSquares()
    registration.SetOptimizerAsRegularStepGradientDescent(
        learning_rate, min_step, iterations
    )
    registration.SetInitialTransform(sitk.TranslationTransform(fixed.GetDimension()))
    registration.SetInterpolator(sitk.sitkLinear)

    transform = registration.Execute(fixed, moving)
    # Resample the moving image onto the fixed grid; default value 0 outside.
    resampled = sitk.Resample(
        moving, fixed, transform, sitk.sitkLinear, 0.0, moving.GetPixelID()
    )
    return sitk.GetArrayFromImage(resampled)

# Example usage:
# fixed = cv2.imread("fixed.png", 0)
# moving = cv2.imread("moving.png", 0)
# registered = register_images(fixed, moving)
# plt.imshow(registered, cmap='gray')


In [4]:

# ⚙️ Step: Resampling to uniform voxel spacing using SimpleITK

def resample_image(image_np, new_spacing=(1.0, 1.0), original_spacing=None):
    """Resample a NumPy image to a new pixel/voxel spacing (linear interpolation).

    NumPy arrays carry no spacing metadata: ``sitk.GetImageFromArray`` always
    assigns a spacing of 1.0 per axis. Without ``original_spacing`` the
    resampling is therefore relative to unit spacing, and the default
    ``new_spacing`` is a no-op. Pass the real spacing (e.g. from the NIfTI
    header) to get a physically meaningful result.

    Args:
        image_np: 2-D or 3-D array in NumPy axis order.
        new_spacing: Target spacing, one value per axis, in SimpleITK
            (x, y[, z]) order, i.e. the reverse of NumPy axis order.
        original_spacing: Physical spacing of ``image_np`` in the same
            (x, y[, z]) order. Defaults to 1.0 per axis.

    Returns:
        The resampled image as a float32 NumPy array (NumPy axis order).

    Raises:
        ValueError: if a spacing has a different length than the image dimension.
    """
    image = sitk.GetImageFromArray(np.asarray(image_np).astype(np.float32))
    dimension = image.GetDimension()

    if original_spacing is not None:
        if len(original_spacing) != dimension:
            raise ValueError(
                f"original_spacing has {len(original_spacing)} values "
                f"but the image is {dimension}-D"
            )
        image.SetSpacing([float(s) for s in original_spacing])
    if len(new_spacing) != dimension:
        # zip() would otherwise silently drop axes and produce a size vector
        # of the wrong length.
        raise ValueError(
            f"new_spacing has {len(new_spacing)} values but the image is {dimension}-D"
        )

    new_spacing = [float(s) for s in new_spacing]
    # Preserve the physical extent: size_new = size_old * spacing_old / spacing_new,
    # with at least one voxel per axis.
    new_size = [
        max(1, int(round(size * old * 1.0 / new)))
        for size, old, new in zip(image.GetSize(), image.GetSpacing(), new_spacing)
    ]

    resampler = sitk.ResampleImageFilter()
    resampler.SetOutputSpacing(new_spacing)
    resampler.SetSize(new_size)
    resampler.SetInterpolator(sitk.sitkLinear)
    resampler.SetOutputOrigin(image.GetOrigin())
    resampler.SetOutputDirection(image.GetDirection())
    return sitk.GetArrayFromImage(resampler.Execute(image))

# Example:
# img = cv2.imread("input.png", 0)
# resampled_img = resample_image(img)


In [5]:
# Visualize a single MRI slice and corresponding segmentation map
def show_mri_slice(mri_path, mask_path=None, slice_index=50):
    """Display slice ``slice_index`` (third volume axis) of an MRI file.

    When ``mask_path`` is given, a second panel shows the same slice with the
    segmentation drawn on top in red at 50% opacity. Slices are transposed
    and drawn with ``origin='lower'`` for display.
    """
    volume = nib.load(mri_path).get_fdata()
    image = volume[:, :, slice_index]
    n_cols = 2 if mask_path else 1

    plt.figure(figsize=(10, 5))
    plt.subplot(1, n_cols, 1)
    plt.imshow(image.T, cmap='gray', origin='lower')
    plt.title('MRI Slice')

    if not mask_path:
        plt.show()
        return

    label = nib.load(mask_path).get_fdata()[:, :, slice_index]
    plt.subplot(1, 2, 2)
    plt.imshow(image.T, cmap='gray', origin='lower')
    plt.imshow(label.T, cmap='Reds', alpha=0.5, origin='lower')
    plt.title('With Mask Overlay')
    plt.show()

## Data preprocessing (Optional)

Images in the original dataset are usually of different sizes, so sometimes we need to resize and normalise them (z-score is commonly used in preprocessing MRI images) to fit the CNN model. Depending on the images you choose to use for training your model, you may need to apply other preprocessing methods. If preprocessing methods like cropping are applied, remember to convert the segmentation result back to its original size.

In [6]:
# Optional: preprocessing function (not required if using Dataset)
def normalize_slice(slice, eps=1e-8):
    """Return the z-score normalisation of ``slice``: (x - mean) / std.

    A constant (e.g. all-background) slice has std of 0 — or a rounding-level
    value such as 1e-17 — and dividing by it yields NaN/inf that poisons
    training. Such slices are only mean-centred, which gives (near-)zeros.

    Args:
        slice: Array-like of intensities. The name shadows the builtin but is
            kept so keyword callers keep working.
        eps: Standard deviations ``<= eps`` are treated as zero.

    Returns:
        A float NumPy array with the same shape as ``slice``.
    """
    centred = slice - np.mean(slice)
    std = np.std(slice)
    return centred / std if std > eps else centred

## Train-time data augmentation
Generalizability is crucial for a deep learning model: it describes how well performance on seen data (training data) carries over to unseen data (testing data), so a large performance gap between the two indicates poor generalization. Improving the generalizability of these models has always been a difficult challenge.

**Data Augmentation** is an effective way of improving generalizability, because the augmented data represent a more comprehensive set of possible data samples and reduce the gap between the training data and the validation/testing data.

There are many data augmentation methods you can choose from in this project, including rotation, shifting, flipping, etc. PyTorch provides excellent data augmentation capabilities through torchvision.transforms, which you can combine with custom transforms for medical imaging.

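You are encouraged to try different augmentation methods to get the best segmentation result. Keep in mind that geometric transforms (flips, rotations, shifts) must be applied identically to an image and its segmentation mask; otherwise the labels no longer line up with the anatomy.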


## Get the data generator ready

In [7]:
# Optional train-time augmentation: a random left/right flip (p=0.5)
# followed by a random rotation drawn from [-15, 15] degrees.
# NOTE: both operations are geometric. For segmentation, the same random
# flip/rotation has to be applied to the mask as well, otherwise image and
# labels stop lining up.
train_transform = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(degrees=15),
    ]
)

## Define a metric for the performance of the model
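The Dice score is used here to evaluate the performance of your model. For a predicted mask $P$ and a ground-truth mask $G$, $\mathrm{Dice}(P, G) = \frac{2\,|P \cap G|}{|P| + |G|}$. It ranges from 0 (no overlap) to 1 (perfect overlap).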
More details about the Dice score and other metrics can be found at 
https://towardsdatascience.com/metrics-to-evaluate-your-semantic-segmentation-model-6bcb99639aa2. The Dice score can also be used as the loss function for training your model.

In [8]:
# Dice coefficient metric
def dice_coef(pred, target, epsilon=1e-6, threshold=0.5):
    """Compute the Dice overlap between thresholded predictions and a target.

    Dice = (2 * |P ∩ T| + eps) / (|P| + |T| + eps), summed over every element
    of the tensors (the whole batch is pooled, not averaged per sample).

    Args:
        pred: Predicted probabilities; values ``> threshold`` count as foreground.
        target: Ground-truth mask, expected to contain 0/1 values.
        epsilon: Smoothing term. It is added to the numerator as well as the
            denominator, so an empty prediction on an empty target scores 1
            (perfect agreement) instead of 0.
        threshold: Probability cut-off used to binarise ``pred``.

    Returns:
        A 0-d tensor with the Dice score in [0, 1].
    """
    pred = (pred > threshold).float()
    target = target.float()
    intersection = (pred * target).sum()
    return (2.0 * intersection + epsilon) / (pred.sum() + target.sum() + epsilon)

## Build your own model here
The U-Net (https://link.springer.com/chapter/10.1007/978-3-319-24574-4_28) structure is widely used for medical image segmentation tasks. You can build your own model or modify the U-Net by changing the hyperparameters for our task. If you choose to use PyTorch, more information about PyTorch layers, including Conv2d, MaxPool2d, and Dropout, can be found at https://pytorch.org/docs/stable/nn.html. You can also explore popular PyTorch implementations of U-Net for medical image segmentation.

In [9]:
# Basic U-Net model for 2D segmentation
class UNet(nn.Module):
    """Minimal two-level encoder/decoder for 2-D binary segmentation.

    Layout: conv-BN-ReLU (base) -> max-pool -> conv-BN-ReLU (2*base) ->
    bilinear upsample -> conv-BN-ReLU (base) -> 1x1 conv -> sigmoid.

    NOTE: unlike the original U-Net, there is no skip connection from ``enc1``
    into ``dec1``. Adding one would change ``dec1``'s input channels and make
    existing checkpoints incompatible.

    Args:
        in_channels: Channels of the input image (1 for a single MRI sequence).
        out_channels: Output channels, each passed through its own sigmoid.
        base_channels: Width of the first encoder stage; the second is twice this.
    """

    def __init__(self, in_channels=1, out_channels=1, base_channels=64):
        super().__init__()

        def CBR(in_ch, out_ch):
            # 3x3 conv (padding=1 keeps H and W) -> batch norm -> ReLU
            return nn.Sequential(
                nn.Conv2d(in_ch, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            )

        self.enc1 = CBR(in_channels, base_channels)
        self.enc2 = CBR(base_channels, 2 * base_channels)
        self.pool = nn.MaxPool2d(2)
        # Kept as an attribute for compatibility. forward() resizes to enc1's
        # exact spatial size instead, so odd input sizes also work.
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.dec1 = CBR(2 * base_channels, base_channels)
        self.final = nn.Conv2d(base_channels, out_channels, 1)

    def forward(self, x):
        """Return per-pixel probabilities with the same H and W as ``x``.

        ``x`` is expected as (N, in_channels, H, W) — BatchNorm2d needs 4-D input.
        """
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        # MaxPool2d floors odd sizes, so a fixed x2 upsample would return
        # H-1 / W-1 and the loss against the mask would fail. Resizing to e1's
        # size is numerically identical to the x2 upsample for even sizes
        # (with align_corners=True the scale is derived from the sizes).
        up = F.interpolate(e2, size=e1.shape[-2:], mode='bilinear', align_corners=True)
        d1 = self.dec1(up)
        return torch.sigmoid(self.final(d1))

## Train your model here
Once you have defined the model and data generator, you can start training your model. In PyTorch, you'll need to set up the training loop with the optimizer, loss function, and implement forward and backward passes manually.

In [10]:
# Dataset class using real files
class BrainMRISegmentationDataset(Dataset):
    """One fixed axial slice per case: z-scored FLAIR image plus binary mask.

    Cases are discovered under ``root_dir`` in either layout. Both ``.nii``
    and ``.nii.gz`` are accepted (``.nii`` is preferred when both exist):

    * per-case folders: ``<root>/<case>/<case>_fla.nii`` + ``<case>_seg.nii``
    * flat files:       ``<root>/<case>_fla.nii.gz`` + ``<case>_seg.nii.gz``

    Cases missing either file are skipped. Entries are visited in sorted order
    so sample indices are reproducible across machines.

    Args:
        root_dir: Directory to scan.
        slice_index: Index along the third volume axis to extract.
        transform: Optional callable applied to the image tensor only
            (suitable for intensity transforms).
        joint_transform: Optional callable applied to the image and mask
            stacked as one (2, H, W) tensor. Use this for geometric
            augmentation, which must move both identically. Prefer
            nearest-neighbour interpolation so the mask stays binary.

    Raises:
        RuntimeError: if no complete case is found.
    """

    _EXTENSIONS = (".nii", ".nii.gz")

    def __init__(self, root_dir, slice_index=50, transform=None, joint_transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.joint_transform = joint_transform
        self.slice_index = slice_index
        self.samples = self._discover_cases(root_dir)

        if not self.samples:
            raise RuntimeError(f"No valid cases found in {root_dir}")

    @classmethod
    def _find_volume(cls, directory, stem):
        """Return the path of ``<stem>.nii`` or ``<stem>.nii.gz`` in ``directory``, else None."""
        for ext in cls._EXTENSIONS:
            candidate = os.path.join(directory, stem + ext)
            if os.path.isfile(candidate):
                return candidate
        return None

    @classmethod
    def _flat_case_id(cls, filename):
        """Return ``<case>`` for a file named ``<case>_fla.nii[.gz]``, else None."""
        for ext in cls._EXTENSIONS:
            suffix = "_fla" + ext
            if filename.endswith(suffix) and len(filename) > len(suffix):
                return filename[: -len(suffix)]
        return None

    @classmethod
    def _discover_cases(cls, root_dir):
        """Return sorted (flair_path, seg_path) pairs found under ``root_dir``."""
        samples = []
        seen = set()
        for entry in sorted(os.listdir(root_dir)):
            entry_path = os.path.join(root_dir, entry)
            if os.path.isdir(entry_path):
                directory, case = entry_path, entry
            else:
                case = cls._flat_case_id(entry)
                if case is None:
                    continue
                directory = root_dir
            # Both <case>_fla.nii and <case>_fla.nii.gz may exist; keep one sample.
            if (directory, case) in seen:
                continue
            seen.add((directory, case))
            flair_path = cls._find_volume(directory, f"{case}_fla")
            seg_path = cls._find_volume(directory, f"{case}_seg")
            if flair_path is not None and seg_path is not None:
                samples.append((flair_path, seg_path))
        return samples

    @staticmethod
    def _normalize(slice_2d, eps=1e-8):
        """Z-score a slice; a (near-)constant slice is only mean-centred.

        Dividing a blank slice by a ~0 std would produce NaN/inf.
        """
        centred = slice_2d - np.mean(slice_2d)
        std = np.std(slice_2d)
        return centred / std if std > eps else centred

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, mask)``, each a float tensor of shape (1, H, W)."""
        flair_path, seg_path = self.samples[idx]
        flair_slice = nib.load(flair_path).get_fdata()[:, :, self.slice_index]
        seg_slice = nib.load(seg_path).get_fdata()[:, :, self.slice_index]

        flair_slice = self._normalize(flair_slice)
        # The model (single sigmoid channel trained with BCELoss) needs targets
        # in [0, 1]. Segmentation files may store several integer labels
        # (presumably tumour sub-regions — confirm against the dataset docs),
        # so every non-zero label is treated as foreground.
        seg_slice = (seg_slice > 0).astype(np.float32)

        image = torch.tensor(flair_slice).unsqueeze(0).float()
        mask = torch.tensor(seg_slice).unsqueeze(0).float()

        if self.joint_transform is not None:
            # Stack so a single random draw moves image and mask identically.
            stacked = self.joint_transform(torch.cat([image, mask], dim=0))
            image, mask = stacked[:1], stacked[1:]

        if self.transform:
            image = self.transform(image)

        return image, mask

# Create dataset and dataloader
# NOTE(review): transform=None, so the train_transform defined earlier is never
# applied. Its flip/rotation are geometric and would also have to move the
# mask; the dataset's `transform` argument only touches the image.
# The constructor raises RuntimeError when it finds no usable case under the
# root directory (as in the output shown below this cell).
dataset = BrainMRISegmentationDataset('dataset_segmentation/train', slice_index=50, transform=None)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

# Initialize model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = UNet().to(device)
# BCELoss takes probabilities (UNet.forward already applies a sigmoid) and
# targets in [0, 1]; masks holding larger integer labels must be binarised
# first — TODO confirm the label values stored in the *_seg files.
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Train
# Only the BCE loss is optimised. dice_coef thresholds the predictions, so it
# is a monitoring metric with no gradient contribution.
for epoch in range(3):
    model.train()
    total_loss, total_dice = 0, 0
    for imgs, masks in tqdm(dataloader):
        imgs, masks = imgs.to(device), masks.to(device)
        optimizer.zero_grad()
        preds = model(imgs)
        loss = criterion(preds, masks)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        total_dice += dice_coef(preds, masks).item()
    # Reported values are means over batches (each batch weighted equally,
    # so a smaller final batch counts as much as a full one).
    print(f"Epoch {epoch+1} | Loss: {total_loss/len(dataloader):.4f} | Dice: {total_dice/len(dataloader):.4f}")

RuntimeError: No valid cases found in dataset_segmentation/train

## Save the model
Once your model is trained, remember to save it for testing. In PyTorch, you can save the model state dictionary using torch.save() for later loading and inference.

In [None]:
# Persist only the learned weights (the state_dict), not the pickled module.
# Reload later with model.load_state_dict(torch.load("brain_tumour_segmentation.pth")).
torch.save(obj=model.state_dict(), f="brain_tumour_segmentation.pth")

## Run the model on the test set
After your last Q&A session, you will be given the test set. Run your model on the test set to get the segmentation results and submit your results in a .zip file. If the MRI image is named '100_fla.nii.gz', save your segmentation result as '100_seg.nii.gz'. 

In [None]:
# Inference and test loop placeholder
# NOTE(review): this only smoke-tests the model on ONE shuffled batch of the
# *training* dataloader and prints the output shape. It does not yet do what
# the task asks: run on the test set, threshold the sigmoid output, and write
# each result as '<id>_seg.nii.gz'.
model.eval()  # BatchNorm switches to its running statistics
with torch.no_grad():  # no autograd graph is needed for inference
    for imgs, masks in dataloader:
        imgs = imgs.to(device)
        preds = model(imgs)  # per-pixel probabilities (UNet ends in a sigmoid)
        print("Pred shape:", preds.shape)
        break  # only run one batch
