In [None]:
# Ask Colab for a file upload and keep the returned mapping in `uploaded`;
# a later cell uses its first key as the path of the dataset archive.
from google.colab import files
uploaded = files.upload()

Saving brats_final_split.zip to brats_final_split.zip


In [None]:
import zipfile

# `uploaded` may be empty (e.g. if the upload was cancelled); fail with a
# clear message instead of an IndexError on the key lookup below.
if not uploaded:
    raise RuntimeError("No file was uploaded - re-run the upload cell first.")

# The first key of `uploaded` is used as the path of the archive to unpack.
zip_path = next(iter(uploaded))
with zipfile.ZipFile(zip_path, 'r') as z:
    z.extractall("/content/data")

# Dataset folder inside the extraction directory; read by the dataset cells.
DATA_ROOT = "/content/data/brats_final_split"
print("Dataset extracted to:", DATA_ROOT)

Dataset extracted to: /content/data/brats_final_split


In [2]:
# Mount Google Drive under /content/drive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
import os
import numpy as np
from PIL import Image
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

In [4]:
# Lookup table from an exact mask colour (R, G, B) to its integer class index.
# rgb_to_label() iterates over this mapping to build the label map.
COLOR_TO_LABEL = {
    (0, 0, 0): 0,        # Background
    (0, 0, 255): 1,      # CSF
    (0, 255, 0): 2,      # Gray Matter
    (255, 255, 0): 3,    # White Matter
    (255, 0, 0): 4       # Tumor
}

In [5]:
def rgb_to_label(mask_rgb, color_map=None):
    """Convert a colour-coded mask into a 2-D array of class indices.

    Args:
        mask_rgb: array-like of shape (H, W, 3) or (H, W, 4). A fourth
            (alpha) channel is ignored.
        color_map: optional mapping ``(R, G, B) -> class index``. Defaults to
            the module-level ``COLOR_TO_LABEL``.

    Returns:
        ``np.ndarray`` of shape (H, W) and dtype ``int64``. Pixels whose
        colour matches no key of the mapping keep the initial value 0.

    Raises:
        ValueError: if ``mask_rgb`` is not a 3-D array with 3 or 4 channels.
    """
    if color_map is None:
        color_map = COLOR_TO_LABEL

    mask_rgb = np.asarray(mask_rgb)
    if mask_rgb.ndim != 3 or mask_rgb.shape[-1] not in (3, 4):
        raise ValueError(
            "expected a colour mask of shape (H, W, 3) or (H, W, 4), "
            f"got shape {mask_rgb.shape}"
        )

    # Compare on the colour channels only, so an RGBA mask does not break the
    # per-pixel comparison against 3-element colour keys.
    colours = mask_rgb[..., :3]

    label = np.zeros(colours.shape[:2], dtype=np.int64)
    for rgb, cls in color_map.items():
        # A pixel belongs to the class only when all three channels match.
        label[(colours == np.asarray(rgb)).all(axis=-1)] = cls
    return label

In [6]:
class BratsFlairSliceDataset(Dataset):
    """Paired FLAIR-slice / colour-mask PNG dataset.

    Looks in ``<root>/<split>/flair`` and ``<root>/<split>/mask``. A FLAIR
    file is kept only when its name ends in ``.png`` and a file with the same
    name exists in the mask directory; everything else is skipped silently.

    Args:
        root: dataset root directory.
        split: sub-directory of ``root`` to read (default ``"train"``).

    Attributes:
        items: list of ``(image_path, mask_path)`` tuples, ordered by the
            sorted FLAIR file names.
    """

    def __init__(self, root, split="train"):
        flair_dir = os.path.join(root, split, "flair")
        mask_dir = os.path.join(root, split, "mask")

        self.items = []
        for fname in sorted(os.listdir(flair_dir)):
            if not fname.endswith(".png"):
                continue
            mask_path = os.path.join(mask_dir, fname)
            if os.path.exists(mask_path):
                self.items.append((os.path.join(flair_dir, fname), mask_path))

    def __len__(self):
        """Number of (image, mask) pairs found."""
        return len(self.items)

    def __getitem__(self, idx):
        """Load one pair.

        Returns:
            ``(img, mask)`` where ``img`` is a float32 tensor of shape
            (1, H, W) scaled to [0, 1] and ``mask`` is an int64 tensor of
            shape (H, W) holding class indices from ``rgb_to_label``.
        """
        img_path, mask_path = self.items[idx]

        # FLAIR slice: force single-channel grayscale, scale 0..255 -> 0..1.
        with Image.open(img_path) as flair_img:
            img = np.array(flair_img.convert("L"), dtype=np.float32) / 255.0
        img = torch.from_numpy(img).unsqueeze(0)

        # Mask: force 3-channel RGB so palette ("P") or RGBA PNGs still give
        # the (H, W, 3) array that rgb_to_label compares against.
        with Image.open(mask_path) as mask_img:
            mask_rgb = np.array(mask_img.convert("RGB"))
        mask = torch.from_numpy(rgb_to_label(mask_rgb)).long()

        return img, mask

In [7]:
class DoubleConv(nn.Module):
    """Two consecutive (3x3 conv, padding 1) -> BatchNorm -> ReLU stages.

    The first stage maps ``in_ch`` to ``out_ch`` channels, the second keeps
    ``out_ch``. Layers live in ``self.net`` at indices 0..5.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        stages = []
        # Stage 1 consumes in_ch channels, stage 2 consumes stage 1's out_ch.
        for stage_in in (in_ch, out_ch):
            stages.extend([
                nn.Conv2d(stage_in, out_ch, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ])
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        out = self.net(x)
        return out


class Down(nn.Module):
    """Encoder stage: 2x2 max-pooling followed by a DoubleConv."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.pool = nn.MaxPool2d(kernel_size=2)
        self.conv = DoubleConv(in_ch, out_ch)

    def forward(self, x):
        # Pool first, then convolve the pooled map.
        pooled = self.pool(x)
        return self.conv(pooled)


class Up(nn.Module):
    """Decoder stage: transposed-conv upsampling, skip concat, DoubleConv.

    ``self.up`` halves the channel count (``in_ch -> in_ch // 2``) with a
    kernel-2 / stride-2 transposed convolution; ``self.conv`` is a
    ``DoubleConv(in_ch, out_ch)`` applied to the concatenated tensor.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_ch, in_ch // 2, kernel_size=2, stride=2)
        self.conv = DoubleConv(in_ch, out_ch)

    def forward(self, x1, x2):
        """Upsample ``x1``, align it to ``x2`` spatially, concat and convolve.

        ``x2`` is placed first along the channel axis.
        """
        upsampled = self.up(x1)

        # Size gap to the skip tensor along height (dim 2) and width (dim 3).
        gap_h = x2.size(2) - upsampled.size(2)
        gap_w = x2.size(3) - upsampled.size(3)
        left, top = gap_w // 2, gap_h // 2

        # F.pad takes (left, right, top, bottom) for the last two dims.
        upsampled = F.pad(upsampled, [left, gap_w - left, top, gap_h - top])

        merged = torch.cat([x2, upsampled], dim=1)
        return self.conv(merged)


class UNetSeg(nn.Module):
    """U-Net with four down stages, four up stages and a 1x1 output conv.

    Args:
        in_channels: channels of the input image (default 1).
        num_classes: channels of the output map (default 5).
        base_ch: width of the first stage; later stages use 2x, 4x, 8x and
            16x this value.
    """

    def __init__(self, in_channels=1, num_classes=5, base_ch=64):
        super().__init__()
        w1, w2, w4, w8, w16 = (base_ch * mult for mult in (1, 2, 4, 8, 16))

        # Encoder path
        self.inc = DoubleConv(in_channels, w1)
        self.down1 = Down(w1, w2)
        self.down2 = Down(w2, w4)
        self.down3 = Down(w4, w8)
        self.down4 = Down(w8, w16)

        # Decoder path
        self.up1 = Up(w16, w8)
        self.up2 = Up(w8, w4)
        self.up3 = Up(w4, w2)
        self.up4 = Up(w2, w1)

        # Per-pixel projection to class channels
        self.outc = nn.Conv2d(w1, num_classes, kernel_size=1)

    def forward(self, x):
        # Encoder: keep every stage's output for the skip connections.
        skip1 = self.inc(x)
        skip2 = self.down1(skip1)
        skip3 = self.down2(skip2)
        skip4 = self.down3(skip3)
        bottom = self.down4(skip4)

        # Decoder: each up stage takes the running tensor plus one skip.
        dec = self.up1(bottom, skip4)
        dec = self.up2(dec, skip3)
        dec = self.up3(dec, skip2)
        dec = self.up4(dec, skip1)

        return self.outc(dec)

In [8]:
def dice_multiclass(pred, target, num_classes=5):
    """Mean Dice score over ``num_classes`` classes.

    ``pred`` is reduced with ``argmax`` over dim 1 and compared with
    ``target`` class by class; sums run over every element of the tensors
    passed in. A smoothing term of 1e-6 is added to numerator and
    denominator, so a class absent from both prediction and target
    contributes a score of 1.

    Returns:
        0-dim tensor: the unweighted mean of the per-class Dice values.
    """
    eps = 1e-6
    hard_pred = pred.argmax(dim=1)

    def class_dice(cls):
        pred_mask = (hard_pred == cls).float()
        true_mask = (target == cls).float()
        overlap = (pred_mask * true_mask).sum()
        total = pred_mask.sum() + true_mask.sum()
        return (2 * overlap + eps) / (total + eps)

    return sum(class_dice(cls) for cls in range(num_classes)) / num_classes

In [12]:
def extract_encoder_features(model, x):
    """Run ``x`` through the model's encoder stages only.

    Applies ``model.inc`` and then ``model.down1`` .. ``model.down4`` in
    order, returning the output of ``down4``.
    """
    features = model.inc(x)
    for stage in (model.down1, model.down2, model.down3, model.down4):
        features = stage(features)
    return features

In [11]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Shape sanity check: push one random 1x240x240 input through the encoder.
model = UNetSeg(in_channels=1, num_classes=5).to(device)
dummy = torch.randn(1, 1, 240, 240).to(device)

# Only the output shape is inspected, so skip building the autograd graph.
with torch.no_grad():
    feat = extract_encoder_features(model, dummy)
print("Feature shape:", feat.shape)

Feature shape: torch.Size([1, 1024, 15, 15])


In [None]:
# Rebind `model` to a freshly initialised network on the selected device.
model = UNetSeg(num_classes=5, in_channels=1)
model = model.to(device)

# Adam over all model parameters, paired with a plain cross-entropy loss.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

In [None]:
# One dataset per split, both read from DATA_ROOT.
train_ds, val_ds = (
    BratsFlairSliceDataset(DATA_ROOT, split_name) for split_name in ("train", "val")
)

# Batches of 8; only the training loader shuffles.
train_loader = DataLoader(dataset=train_ds, batch_size=8, shuffle=True)
val_loader = DataLoader(dataset=val_ds, batch_size=8, shuffle=False)

In [None]:
# Train for EPOCHS epochs. After every epoch the model is evaluated on
# val_loader, and its weights are written to "unet_best_5class.pth" whenever
# the mean validation Dice beats the best value seen so far.
EPOCHS = 19
best_dice = 0  # best mean validation Dice so far

for epoch in range(1, EPOCHS + 1):
    model.train()
    train_loss = 0  # running sum of per-batch losses for this epoch

    for imgs, masks in train_loader:
        imgs, masks = imgs.to(device), masks.to(device)

        # Standard step: clear grads, forward, loss, backward, update.
        optimizer.zero_grad()
        preds = model(imgs)
        loss = criterion(preds, masks)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

    # Reported value is the mean of the per-batch losses.
    print(f"Epoch {epoch} Train Loss: {train_loss / len(train_loader):.4f}")

    # Validation: eval mode, no gradients, accumulate per-batch Dice.
    model.eval()
    val_d = 0

    with torch.no_grad():
        for imgs, masks in val_loader:
            imgs, masks = imgs.to(device), masks.to(device)
            preds = model(imgs)
            val_d += dice_multiclass(preds, masks).item()

    # Average over batches (each batch counts equally, whatever its size).
    val_d /= len(val_loader)
    print(f"Epoch {epoch} Val Dice: {val_d:.4f}")

    # Checkpoint only on strict improvement.
    if val_d > best_dice:
        best_dice = val_d
        torch.save(model.state_dict(), "unet_best_5class.pth")
        print("Saved best model!")

Epoch 1 Train Loss: 0.1506
Epoch 1 Val Dice: 0.6324
Saved best model!
Epoch 2 Train Loss: 0.1091
Epoch 2 Val Dice: 0.6316
Epoch 3 Train Loss: 0.1007
Epoch 3 Val Dice: 0.6400
Saved best model!
Epoch 4 Train Loss: 0.0952
Epoch 4 Val Dice: 0.6642
Saved best model!
Epoch 5 Train Loss: 0.0904
Epoch 5 Val Dice: 0.6749
Saved best model!
Epoch 6 Train Loss: 0.0840
Epoch 6 Val Dice: 0.6764
Saved best model!
Epoch 7 Train Loss: 0.0784
Epoch 7 Val Dice: 0.6419
Epoch 8 Train Loss: 0.0725
Epoch 8 Val Dice: 0.6753
Epoch 9 Train Loss: 0.0662
Epoch 9 Val Dice: 0.6768
Saved best model!
Epoch 10 Train Loss: 0.0607
Epoch 10 Val Dice: 0.6821
Saved best model!
Epoch 11 Train Loss: 0.0561
Epoch 11 Val Dice: 0.6734
Epoch 12 Train Loss: 0.0522
Epoch 12 Val Dice: 0.6819
Epoch 13 Train Loss: 0.0489
Epoch 13 Val Dice: 0.6840
Saved best model!
Epoch 14 Train Loss: 0.0460
Epoch 14 Val Dice: 0.6672
Epoch 15 Train Loss: 0.0432
Epoch 15 Val Dice: 0.6826
Epoch 16 Train Loss: 0.0411
Epoch 16 Val Dice: 0.6743
Epoch 17 T

In [None]:
# Resume training from the saved best checkpoint.
# map_location=device lets a checkpoint written on a GPU runtime load on a
# CPU-only runtime as well (and vice versa).
model.load_state_dict(
    torch.load("/content/unet_best_5class.pth", map_location=device)
)
print("Loaded best checkpoint from epoch 1–19")

# NOTE(review): only the model weights are restored here; `optimizer` is used
# as it currently exists in the session, its state is not reloaded.

START_EPOCH = 19
END_EPOCH = 45   # train until epoch 45

# NOTE(review): hard-coded threshold, presumably the best validation Dice of
# the earlier run - confirm it matches the checkpoint being loaded.
best_dice = 0.6894

for epoch in range(START_EPOCH + 1, END_EPOCH + 1):
    model.train()
    train_loss = 0  # running sum of per-batch losses for this epoch

    for imgs, masks in train_loader:
        imgs, masks = imgs.to(device), masks.to(device)

        optimizer.zero_grad()
        preds = model(imgs)
        loss = criterion(preds, masks)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

    # Reported value is the mean of the per-batch losses.
    print(f"Epoch {epoch} Train Loss: {train_loss / len(train_loader):.4f}")

    # Validation: eval mode, no gradients, accumulate per-batch Dice.
    model.eval()
    val_d = 0

    with torch.no_grad():
        for imgs, masks in val_loader:
            imgs, masks = imgs.to(device), masks.to(device)
            preds = model(imgs)
            val_d += dice_multiclass(preds, masks).item()

    # Average over batches (each batch counts equally, whatever its size).
    val_d /= len(val_loader)
    print(f"Epoch {epoch} Val Dice: {val_d:.4f}")

    # Overwrite the checkpoint only when the previous best is beaten.
    if val_d > best_dice:
        best_dice = val_d
        torch.save(model.state_dict(), "unet_best_5class.pth")
        print("Saved best model!")

Loaded best checkpoint from epoch 1–19
Epoch 20 Train Loss: 0.0356
Epoch 20 Val Dice: 0.6753
Epoch 21 Train Loss: 0.0336
Epoch 21 Val Dice: 0.6876
Epoch 22 Train Loss: 0.0325
Epoch 22 Val Dice: 0.6738
Epoch 23 Train Loss: 0.0310
Epoch 23 Val Dice: 0.6500
Epoch 24 Train Loss: 0.0296
Epoch 24 Val Dice: 0.6789
Epoch 25 Train Loss: 0.0282
Epoch 25 Val Dice: 0.6762
Epoch 26 Train Loss: 0.0271
Epoch 26 Val Dice: 0.6721
Epoch 27 Train Loss: 0.0264
Epoch 27 Val Dice: 0.6629
Epoch 28 Train Loss: 0.0251
Epoch 28 Val Dice: 0.6771
Epoch 29 Train Loss: 0.0242
Epoch 29 Val Dice: 0.6829
Epoch 30 Train Loss: 0.0236
Epoch 30 Val Dice: 0.6737
Epoch 31 Train Loss: 0.0227
Epoch 31 Val Dice: 0.6563
Epoch 32 Train Loss: 0.0221
Epoch 32 Val Dice: 0.6695
Epoch 33 Train Loss: 0.0214
Epoch 33 Val Dice: 0.6787
Epoch 34 Train Loss: 0.0208
Epoch 34 Val Dice: 0.6807
Epoch 35 Train Loss: 0.0202
Epoch 35 Val Dice: 0.6723
Epoch 36 Train Loss: 0.0197
Epoch 36 Val Dice: 0.6698
Epoch 37 Train Loss: 0.0193
Epoch 37 Val Di

In [None]:
# Send the checkpoint file written by the training cells to the browser.
from google.colab import files
files.download("unet_best_5class.pth")

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>