# Brain Tumor MRI Segmentation using U-Net

This project uses a U-Net convolutional neural network to segment brain MRI scans into tumor sub-regions. The model is trained on the **BraTS2020** dataset, pre-organised and converted from .nii files to .npy files.

### Objective
To segment MRI images of the brain into:
- background
- edema
- non-enhancing tumor (core)
- enhancing tumor

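We use a U-Net with a **resnet18** encoder. Note that the model below is built with `encoder_weights=None`, i.e. it is trained from scratch; set `encoder_weights="imagenet"` to actually use transfer learning.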

We train a separate model for each *available modality (flair, t1, t1ce, t2)* and then ensemble them to make predictions.

In [1]:
import os, random
import torch
from dataset import BraTSDataset2D
from torch.utils.data import DataLoader
import random

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import segmentation_models_pytorch as smp
from torch.optim.lr_scheduler import CosineAnnealingLR

In [3]:
# Prefer the GPU when one is visible to PyTorch; otherwise fall back to the CPU.
has_gpu = torch.cuda.is_available()
device = torch.device("cuda") if has_gpu else torch.device("cpu")
print("Using device:", device)

Using device: cuda


## Data Pipeline

- Dataset Class: BraTSDataset2D
- Mode: Slice-wise 2D training
- Preprocessing:
  - Normalization per slice
  - Label remapping (4 → 3)

In [4]:
def check_files(root_dir, modality, seed=42, train_frac=0.8):
    """Split the case IDs found for one modality into train/validation lists.

    Case IDs are the ``.npy`` file names in ``root_dir/modality`` with the
    extension removed; any other files in that folder are ignored. The IDs are
    sorted before the seeded shuffle, so the split is reproducible across
    machines and, provided every modality folder holds the same case IDs,
    identical for every modality. That keeps the per-modality models (and
    their ensemble) from training on each other's validation cases.

    Args:
        root_dir: dataset root containing one sub-folder per modality.
        modality: name of the sub-folder to scan, e.g. "flair".
        seed: seed for the shuffle (default 42, as before).
        train_frac: fraction of cases assigned to training (default 0.8).

    Returns:
        ``(train_ids, val_ids)``: two lists of case-ID strings.
    """
    mod_folder = os.path.join(root_dir, modality)
    # os.listdir order is filesystem-dependent: sort so the seeded shuffle
    # yields the same split on every machine and for every modality folder.
    case_ids = sorted(
        name[: -len(".npy")]
        for name in os.listdir(mod_folder)
        if name.endswith(".npy")
    )

    # Private RNG: same permutation as random.seed(seed) + random.shuffle,
    # without resetting the notebook's global random state.
    random.Random(seed).shuffle(case_ids)
    split_idx = int(train_frac * len(case_ids))
    train_ids = case_ids[:split_idx]
    val_ids = case_ids[split_idx:]

    print(f"{modality} - Training: {len(train_ids)}, Validation: {len(val_ids)}")
    return train_ids, val_ids

In [5]:
def make_dataloaders(root_dir, train_ids, val_ids, modality, batch_size=16, num_workers=4):
    """Build the training and validation DataLoaders for one modality.

    Args:
        root_dir: dataset root passed through to ``BraTSDataset2D``.
        train_ids: case IDs for the training split.
        val_ids: case IDs for the validation split.
        modality: modality name passed to ``BraTSDataset2D`` as ``modality_name``.
        batch_size: samples per batch for both loaders (default 16, as before).
        num_workers: worker processes per loader (default 4, as before).

    Returns:
        ``(train_loader, val_loader)``. Only the training loader shuffles.
    """
    train_dataset = BraTSDataset2D(root_dir=root_dir, case_ids=train_ids, modality_name=modality)
    val_dataset = BraTSDataset2D(root_dir=root_dir, case_ids=val_ids, modality_name=modality)

    # Pinned host memory only speeds up host->GPU copies; skip it on CPU-only
    # machines, where PyTorch would just warn about it.
    loader_kwargs = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=torch.cuda.is_available(),
    )
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

    print("DataLoaders ready!")
    return train_loader, val_loader

## Model
- **Architecture**: U-Net (2D)
- Input Channels: Single modality (e.g., flair)
- Output: Pixel-wise segmentation (4 classes)

- **Loss Function**: Dice Loss
- **Optimizer**: Adam (lr = 1e-3)
- **Learning Rate Scheduler**: CosineAnnealingLR

In [7]:
# One network per MRI modality, so each model sees a single input channel.
in_channels = 1
classes = 4      # 0=background, 1=edema, 2=core, 3=enhancing

unet_config = {
    "encoder_name": "resnet18",
    # None = no pretrained encoder weights; the network starts from random init.
    "encoder_weights": None,
    "in_channels": in_channels,
    "classes": classes,
}
model = smp.Unet(**unet_config)
model = model.to(device)

In [8]:
learning_rate = 1e-3
# Cosine period in scheduler steps; train_val_loop steps the scheduler once per
# epoch and runs 10 epochs, so one period spans the whole run.
cosine_period = 10

# smp "multiclass" mode: one logit channel per class, integer class-index masks.
loss_fn = smp.losses.DiceLoss(mode="multiclass")
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)
scheduler = CosineAnnealingLR(optimizer=optimizer, T_max=cosine_period)

In [9]:
def _train_one_epoch(model, loader, optimizer, loss_fn, desc):
    """Train ``model`` for one pass over ``loader``; return the per-sample mean loss."""
    model.train()
    total = 0.0
    for imgs, masks in tqdm(loader, desc=desc):
        imgs, masks = imgs.to(device), masks.to(device)
        optimizer.zero_grad()
        loss = loss_fn(model(imgs), masks)
        loss.backward()
        optimizer.step()
        # Weight each batch's scalar loss by its sample count (last batch may be smaller).
        total += loss.item() * imgs.size(0)
    # max(..., 1) avoids ZeroDivisionError on an empty dataset.
    return total / max(len(loader.dataset), 1)


@torch.no_grad()
def _validate_one_epoch(model, loader, loss_fn, desc):
    """Evaluate ``model`` on ``loader`` without gradients; return the per-sample mean loss."""
    model.eval()
    total = 0.0
    for imgs, masks in tqdm(loader, desc=desc):
        imgs, masks = imgs.to(device), masks.to(device)
        total += loss_fn(model(imgs), masks).item() * imgs.size(0)
    return total / max(len(loader.dataset), 1)


def train_val_loop(model, train_loader, val_loader, num_epochs=10, *,
                   opt=None, sched=None, criterion=None):
    """Train ``model`` in place and validate it after every epoch.

    Args:
        model: network to train; its parameters are updated in place.
        train_loader: DataLoader yielding ``(imgs, masks)`` training batches.
        val_loader: DataLoader yielding ``(imgs, masks)`` validation batches.
        num_epochs: number of epochs (default 10, as before).
        opt: optimizer to step; defaults to the notebook-level ``optimizer``.
        sched: LR scheduler, stepped once per epoch; defaults to ``scheduler``.
        criterion: loss function; defaults to the notebook-level ``loss_fn``.

    Returns:
        ``(state_dict, metrics)``. ``metrics`` holds one ``[train_loss, val_loss]``
        pair per epoch. The state_dict references the live model tensors.

    Raises:
        ValueError: if the optimizer holds none of ``model``'s parameters (e.g.
            the global optimizer was built for a different model). Training
            would otherwise run but never change ``model``.
    """
    opt = optimizer if opt is None else opt
    sched = scheduler if sched is None else sched
    criterion = loss_fn if criterion is None else criterion

    model_param_ids = {id(p) for p in model.parameters()}
    opt_param_ids = {id(p) for group in opt.param_groups for p in group["params"]}
    if model_param_ids.isdisjoint(opt_param_ids):
        raise ValueError(
            "optimizer does not hold this model's parameters; "
            "create a new optimizer (and scheduler) for each model"
        )

    metrics = []
    for epoch in range(1, num_epochs + 1):
        train_loss = _train_one_epoch(
            model, train_loader, opt, criterion, f"Epoch {epoch} [Train]"
        )
        sched.step()
        print(f"Epoch {epoch} - Train Loss: {train_loss:.4f}")

        val_loss = _validate_one_epoch(
            model, val_loader, criterion, f"Epoch {epoch} [Val]"
        )
        print(f"Epoch {epoch} - Val Loss: {val_loss:.4f}")

        metrics.append([train_loss, val_loss])
    print("Training complete!")

    return model.state_dict(), metrics

## flair

In [10]:
# Dataset root: check_files expects one sub-folder per modality (e.g. flair/)
# containing <case_id>.npy files.
root_dir = "../data/BraTS_2020_Train"

In [11]:
train_ids, val_ids = check_files(root_dir, "flair")

flair - Training: 295, Validation: 74


In [13]:
train_loader, val_loader = make_dataloaders(root_dir, train_ids, val_ids, "flair")

DataLoaders ready!


In [14]:
weights, metrics = train_val_loop(model=model, train_loader=train_loader, val_loader=val_loader)

Epoch 1 [Train]:   2%|▏         | 50/2858 [03:17<3:04:41,  3.95s/it]


KeyboardInterrupt: 

In [None]:
# Persist the flair model's weights (the state_dict returned by train_val_loop).
# NOTE: if the training cell was interrupted, `weights` was never assigned and this raises NameError.
torch.save(weights, "./flair_model.pth")