<a href="https://colab.research.google.com/github/Abk0003/Brain_Tumor_Segmentation_BCP/blob/main/BrainTumorSegmentation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import numpy as np
import nibabel as nib
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from google.colab import files



def normalize_mri(mri):
    """Robustly standardize an intensity array.

    Values are clipped to the [1st, 99th] percentile range to suppress
    outliers, then shifted to zero mean and divided by the standard deviation
    (plus a small epsilon so a near-constant input does not divide by zero).

    Args:
        mri: array of raw intensities (any shape).

    Returns:
        float32 array with the same shape as ``mri``. All zeros when the 1st
        and 99th percentiles coincide (e.g. a constant slice).
    """
    low, high = np.percentile(mri, [1, 99])
    if low == high:
        # Degenerate input: there is no intensity range to normalize.
        return np.zeros_like(mri, dtype=np.float32)
    clipped = np.clip(mri, low, high)
    centered = clipped - clipped.mean()
    return (centered / (clipped.std() + 1e-6)).astype(np.float32)


class MRILoader(Dataset):
    """Slice-wise dataset over one uploaded multi-modal MRI volume.

    Construction triggers a Colab ``files.upload()`` prompt and then reads,
    from the working directory, ``flair.nii``, ``t1.nii``, ``t1ce.nii`` and
    ``t2.nii`` (stacked in that order as input channels) plus ``seg.nii``
    (the label volume). Each item is one slice along the volumes' last axis.
    """

    def __init__(self):
        files.upload()

        volumes = [
            nib.load(path).get_fdata()
            for path in ("flair.nii", "t1.nii", "t1ce.nii", "t2.nii")
        ]
        # Axis 0 selects the modality; the last axis indexes slices.
        self.modalities = np.stack(volumes)
        # get_fdata() yields floats; labels are cast to integer class ids.
        self.seg = nib.load("seg.nii").get_fdata().astype(np.int64)
        slice_count = self.modalities.shape[-1]
        self.slices = list(range(slice_count))

    def __len__(self):
        return len(self.slices)

    def __getitem__(self, idx):
        """Return ``(image, mask)`` for the ``idx``-th slice.

        ``image`` stacks the four modalities as channels, each normalized
        independently with ``normalize_mri``; ``mask`` is the matching int64
        label slice.
        """
        z = self.slices[idx]
        channels = [normalize_mri(volume[:, :, z]) for volume in self.modalities]
        image = torch.from_numpy(np.stack(channels)).float()
        label = torch.from_numpy(self.seg[:, :, z]).long()
        return image, label


def conv_block(in_ch, out_ch, inplace):
    """Build the double-convolution unit used throughout the U-Net.

    Two rounds of 3x3 convolution (padding 1, so H and W are preserved),
    batch normalization and ReLU. The six layers live in one flat
    ``nn.Sequential`` (indices 0-5) so saved ``state_dict`` keys stay stable.

    Args:
        in_ch: channels entering the first convolution.
        out_ch: channels produced by both convolutions.
        inplace: forwarded to each ``nn.ReLU``.

    Returns:
        The ``nn.Sequential`` module.
    """
    layers = []
    for channels_in in (in_ch, out_ch):
        layers += [
            nn.Conv2d(channels_in, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=inplace),
        ]
    return nn.Sequential(*layers)

class CNN(nn.Module):
    """Three-level 2-D U-Net producing per-pixel class logits.

    Encoder: ``conv_block`` stages with 64, 128 and 256 channels separated by
    2x2 max-pooling, then a 512-channel bottleneck. Decoder: each level
    upsamples with a stride-2 transposed convolution, concatenates the
    matching encoder output along the channel axis (skip connection) and
    applies a ``conv_block``. A final 1x1 convolution produces the logits.

    Args:
        in_channels: number of input channels (default 4, as fed by the rest
            of this file: one per MRI modality).
        num_classes: number of output logit channels (default 4).

    ``forward`` maps (N, in_channels, H, W) to (N, num_classes, H, W). H and
    W need not be multiples of 8: when pooling floors an odd size, the
    upsampled tensor is zero-padded to the skip connection's size before
    concatenation. Parameter names are unchanged, so existing state dicts
    still load.
    """

    def __init__(self, in_channels=4, num_classes=4):
        super().__init__()
        self.encoder1 = conv_block(in_channels, 64, True)
        self.encoder2 = conv_block(64, 128, True)
        self.encoder3 = conv_block(128, 256, True)

        self.pool = nn.MaxPool2d(2)
        self.bottleneck = conv_block(256, 512, True)

        self.up3 = nn.ConvTranspose2d(512, 256, 2, 2)
        self.up2 = nn.ConvTranspose2d(256, 128, 2, 2)
        self.up1 = nn.ConvTranspose2d(128, 64, 2, 2)

        # Decoder inputs are upsampled + skip features concatenated, hence
        # twice the channel count of each level.
        self.decoder3 = conv_block(512, 256, False)
        self.decoder2 = conv_block(256, 128, False)
        self.decoder1 = conv_block(128, 64, False)

        self.out = nn.Conv2d(64, num_classes, 1)

    @staticmethod
    def _pad_to(up, skip):
        """Zero-pad ``up`` so its last two dims equal ``skip``'s (no-op if equal)."""
        dh = skip.size(-2) - up.size(-2)
        dw = skip.size(-1) - up.size(-1)
        if dh == 0 and dw == 0:
            return up
        # F.pad lists (left, right, top, bottom) for the last two dimensions.
        return F.pad(up, [dw // 2, dw - dw // 2, dh // 2, dh - dh // 2])

    def forward(self, x):
        e1 = self.encoder1(x)
        e2 = self.encoder2(self.pool(e1))
        e3 = self.encoder3(self.pool(e2))
        b = self.bottleneck(self.pool(e3))

        d3 = self.decoder3(torch.cat([self._pad_to(self.up3(b), e3), e3], 1))
        d2 = self.decoder2(torch.cat([self._pad_to(self.up2(d3), e2), e2], 1))
        d1 = self.decoder1(torch.cat([self._pad_to(self.up1(d2), e1), e1], 1))
        return self.out(d1)


class LossCriterion(nn.Module):
    """Cross-entropy plus soft-Dice loss for multi-class segmentation.

    ``loss = CE(pred, target) + (1 - mean Dice)``. The Dice term uses softmax
    probabilities, is computed per sample and per class, and is averaged over
    every class except index 0.

    Args:
        smooth: additive smoothing in the Dice ratio (default 1e-6, the
            previously hard-coded value).
    """

    def __init__(self, smooth=1e-6):
        super().__init__()
        self.ce = nn.CrossEntropyLoss()
        self.smooth = smooth

    def forward(self, pred, target):
        """Compute the combined loss.

        Args:
            pred: raw logits, (N, C, H, W).
            target: integer labels, (N, H, W), with values in ``[0, C)``.

        Returns:
            Scalar loss tensor.
        """
        ce_loss = self.ce(pred, target)

        # Derive the class count from the logits instead of hard-coding it,
        # so the one-hot target always matches pred's channel dimension.
        num_classes = pred.shape[1]
        probs = F.softmax(pred, dim=1)
        target_oh = F.one_hot(target, num_classes).permute(0, 3, 1, 2).float()

        intersection = (probs * target_oh).sum((2, 3))
        union = probs.sum((2, 3)) + target_oh.sum((2, 3))
        dice = (2 * intersection + self.smooth) / (union + self.smooth)

        # Exclude class 0 (presumably background) from the Dice average.
        dice = dice[:, 1:].mean()
        return ce_loss + (1 - dice)


def validation_dice(model, loader, device):
    """Average hard Dice for classes 1, 2 and 3 over a loader.

    Puts ``model`` in eval mode and runs without gradients. Ground-truth
    label 4 is folded into class 3 (the same remap the training loop uses),
    predictions are the argmax over dim 1, and a Dice score per class is
    computed for every batch and then averaged across batches.

    Args:
        model: network returning per-class scores along dim 1.
        loader: iterable of ``(image, mask)`` batches.
        device: device each batch is moved to.

    Returns:
        Tensor of 3 Dice values (classes 1, 2, 3) on ``device``. Division by
        a zero batch count yields non-finite values if ``loader`` is empty.
    """
    model.eval()
    foreground = (1, 2, 3)
    scores = torch.zeros(len(foreground), device=device)
    n_batches = 0

    with torch.no_grad():
        for image, target in loader:
            image = image.to(device)
            target = target.to(device)
            target[target == 4] = 3

            predicted = model(image).argmax(dim=1)

            for slot, label in enumerate(foreground):
                pred_hits = (predicted == label).float()
                true_hits = (target == label).float()
                numerator = 2 * (pred_hits * true_hits).sum() + 1e-6
                denominator = pred_hits.sum() + true_hits.sum() + 1e-6
                scores[slot] += numerator / denominator
            n_batches += 1

    return scores / n_batches


def grad_norm(model):
    """Return the global L2 norm of all parameter gradients of ``model``.

    Equal to the L2 norm of every ``p.grad`` concatenated; parameters whose
    ``grad`` is None are skipped. Per-parameter norms are computed where the
    gradients live and copied to the host once, instead of one ``.item()``
    (a device synchronization on GPU) per parameter.

    Args:
        model: module whose ``parameters()`` gradients are inspected.

    Returns:
        Python float; 0.0 when no parameter has a gradient.
    """
    norms = [p.grad.detach().norm(2) for p in model.parameters() if p.grad is not None]
    if not norms:
        return 0.0
    # Gather on one device, transfer once, then accumulate the squares in
    # float64 on the host (the precision the per-item Python sum had).
    stacked = torch.stack([n.to(norms[0].device) for n in norms]).cpu().double()
    return stacked.pow(2).sum().sqrt().item()


# --- Training: 80/20 slice-level split of the uploaded volume ---
dataset = MRILoader()
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_ds, val_ds = random_split(dataset, [train_size, val_size])

train_loader = DataLoader(train_ds, batch_size=8, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=8, shuffle=False)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CNN().to(device)
criterion = LossCriterion()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Divisor for the per-epoch averages; guards against an empty training split.
num_batches = max(len(train_loader), 1)

for epoch in range(300):
    model.train()
    loss_sum = ce_sum = dice_sum = grad_sum = 0

    for img, mask in train_loader:
        img, mask = img.to(device), mask.to(device)
        # Fold label 4 into class 3 so labels are contiguous in [0, 4).
        mask[mask == 4] = 3

        optimizer.zero_grad()
        pred = model(img)

        ce = criterion.ce(pred, mask)

        # Soft Dice over classes 1..3, computed inline so the CE and Dice
        # terms can be logged separately.
        pred_soft = F.softmax(pred, dim=1)
        mask_oh = F.one_hot(mask, 4).permute(0, 3, 1, 2).float()
        inter = (pred_soft * mask_oh).sum((2, 3))
        union = pred_soft.sum((2, 3)) + mask_oh.sum((2, 3))
        dice_loss = 1 - (2 * inter + 1e-6) / (union + 1e-6)
        dice_loss = dice_loss[:, 1:].mean()

        loss = ce + dice_loss
        loss.backward()

        g = grad_norm(model)
        optimizer.step()

        loss_sum += loss.item()
        ce_sum += ce.item()
        dice_sum += dice_loss.item()
        grad_sum += g

    # Per-class validation Dice (classes 1, 2, 3) and the current learning rate.
    val_d = validation_dice(model, val_loader, device).tolist()
    lr = optimizer.param_groups[0]["lr"]

    print(
        f"Epoch {epoch+1:03d} ; "
        f"Loss {loss_sum/num_batches:.4f} ; "
        f"CE {ce_sum/num_batches:.4f} ; "
        f"DiceLoss {dice_sum/num_batches:.4f} ; "
        f"Grad {grad_sum/num_batches:.2e} | "
        f"ValDice {val_d[0]:.4f}/{val_d[1]:.4f}/{val_d[2]:.4f} ; "
        f"LR {lr:.1e}"
    )

# NOTE(review): the filename says epoch104 although it is written after the
# full loop; kept as-is because the inference code below loads this exact path.
torch.save(model.state_dict(), "unet_mri_epoch104.pth")

# --- Inference: segment every slice of an uploaded volume, save as NIfTI ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

files.upload()

model = CNN().to(device)
# The checkpoint may be a user-uploaded file and torch.load unpickles it;
# weights_only=True limits loading to tensors/containers, enough for a state_dict.
model.load_state_dict(
    torch.load("unet_mri_epoch104.pth", map_location=device, weights_only=True)
)
model.eval()

# Keep the FLAIR image object so its affine can be reused for the output.
flair_img = nib.load("flair.nii")
flair = flair_img.get_fdata()
t1    = nib.load("t1.nii").get_fdata()
t1ce  = nib.load("t1ce.nii").get_fdata()
t2    = nib.load("t2.nii").get_fdata()

modalities = np.stack([flair, t1, t1ce, t2])  # (4, H, W, Z)

H, W, Z = flair.shape
prediction = np.zeros((H, W, Z), dtype=np.uint8)

with torch.no_grad():
    for z in range(Z):
        img = np.stack([
            normalize_mri(modalities[c, :, :, z])
            for c in range(4)
        ])

        img = torch.from_numpy(img).unsqueeze(0).to(device)

        logits = model(img)
        seg = torch.argmax(logits, dim=1).squeeze(0)

        prediction[:, :, z] = seg.cpu().numpy()

# NOTE(review): training folded ground-truth label 4 into class 3, so class 3
# here may correspond to label 4 in seg.nii; remap before comparing if needed.
pred_nii = nib.Nifti1Image(prediction, affine=flair_img.affine)
nib.save(pred_nii, "prediction.nii")

files.download("prediction.nii")







KeyboardInterrupt: 