<a href="https://colab.research.google.com/github/Abk0003/Brain_Tumor_Segmentation_BCP/blob/main/Week_3%264/BrainTumorSegmentation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os

import numpy as np
import nibabel as nib
import torch
from torch import nn
from torch.utils.data import Dataset,DataLoader
from google.colab import files


def normalize_mri(mri, lower_pct=1, upper_pct=99):
    """Robustly rescale MRI intensities to the [0, 1] range.

    Values are clipped to the [lower_pct, upper_pct] percentile window and then
    min-max scaled with those two percentile values, so a handful of extreme
    voxels cannot dominate the scaling.

    Args:
        mri: Array-like of intensities (any shape).
        lower_pct: Percentile mapped to 0 (default 1).
        upper_pct: Percentile mapped to 1 (default 99).

    Returns:
        float32 array with the same shape as ``mri`` and values in [0, 1].
        An all-zero float32 array is returned when both percentiles coincide
        (e.g. a constant slice), and an empty float32 array for empty input.
    """
    mri = np.asarray(mri)
    if mri.size == 0:
        # np.percentile cannot operate on an empty array.
        return np.zeros(mri.shape, dtype=np.float32)

    pmin, pmax = np.percentile(mri, [lower_pct, upper_pct])

    if pmax == pmin:
        return np.zeros_like(mri, dtype=np.float32)

    clipped = np.clip(mri, pmin, pmax)
    # Always float32: the result becomes a torch tensor fed to float32 conv
    # weights, and both return paths must agree on dtype.
    return ((clipped - pmin) / (pmax - pmin)).astype(np.float32, copy=False)


class MRILoader(Dataset):
    """Dataset of 2-D slices taken along the last axis of one multi-modal MRI volume.

    The constructor optionally opens the Colab upload dialog and then reads four
    modalities plus a segmentation mask from fixed filenames in the current
    working directory.

    Each item is an ``(image, mask)`` pair:

    * ``image`` -- float32 tensor of shape ``(4, H, W)``; channels are flair,
      t1, t1ce, t2 (in that order), each rescaled independently with
      ``normalize_mri``.
    * ``mask`` -- int64 tensor of shape ``(H, W)`` holding that slice's
      segmentation labels as stored in the file.
      NOTE(review): if the labels are BraTS-style {0, 1, 2, 4}, label 4 must be
      remapped before use with a 4-class output -- confirm against the data.

    Args:
        min_nonzero_ratio: Keep only slices in which at least this fraction of
            in-plane positions is nonzero in at least one modality. ``0.0``
            (default) keeps every slice.
        upload: If True (default), call ``files.upload()`` first. Pass False
            when the files are already on disk.

    Raises:
        ValueError: If a required file is missing, the modalities differ in
            shape, or the mask shape does not match the modalities.
    """

    # Expected input files; modality order defines the image channel order.
    MODALITY_FILES = ("flair.nii.gz", "t1.nii.gz", "t1ce.nii.gz", "t2.nii.gz")
    SEG_FILE = "seg.nii.gz"

    def __init__(self, min_nonzero_ratio=0.0, upload=True):
        if upload:
            # Uploaded files land in the working directory; the returned dict
            # is not needed because the filenames below are fixed.
            files.upload()

        required = (*self.MODALITY_FILES, self.SEG_FILE)
        missing = [path for path in required if not os.path.exists(path)]
        if missing:
            raise ValueError(
                "Upload flair, t1, t1ce, t2, and seg NIfTI files. "
                f"Missing: {', '.join(missing)}"
            )

        # float32 halves memory relative to the float64 default and matches
        # the dtype the model consumes.
        volumes = [
            nib.load(path).get_fdata(dtype=np.float32)
            for path in self.MODALITY_FILES
        ]
        shapes = [vol.shape for vol in volumes]
        if len(set(shapes)) != 1:
            raise ValueError(f"Modalities have different shapes: {shapes}")
        self.modalities = np.stack(volumes)

        # Round before casting so float scaling cannot turn e.g. 1.9999 into 1.
        self.seg = np.rint(nib.load(self.SEG_FILE).get_fdata()).astype(np.int64)
        if self.seg.shape != self.modalities.shape[1:]:
            raise ValueError(
                f"Segmentation shape {self.seg.shape} does not match "
                f"modality shape {self.modalities.shape[1:]}"
            )

        depth = self.modalities.shape[-1]
        if min_nonzero_ratio > 0:
            # Per slice: fraction of (H, W) positions nonzero in any modality.
            ratios = np.any(self.modalities != 0, axis=0).mean(axis=(0, 1))
            self.slices = [int(z) for z in np.flatnonzero(ratios >= min_nonzero_ratio)]
        else:
            self.slices = list(range(depth))

    def __len__(self):
        return len(self.slices)

    def __getitem__(self, idx):
        z = self.slices[idx]

        img = np.stack([
            normalize_mri(self.modalities[c, :, :, z])
            for c in range(self.modalities.shape[0])
        ]).astype(np.float32, copy=False)  # conv weights are float32

        # Copy so the returned tensor never aliases (and cannot mutate) self.seg.
        mask = self.seg[:, :, z].copy()

        return torch.from_numpy(img), torch.from_numpy(mask)

# Build the slice dataset (prompts for the NIfTI uploads) and a shuffled
# loader that yields mini-batches of 8 (image, mask) pairs.
dataset = MRILoader(min_nonzero_ratio=0.0)
loader = DataLoader(dataset=dataset, batch_size=8, shuffle=True)

def conv_block(in_channels, out_channels, inplace, num_convs=3):
    """Stack ``num_convs`` units of (3x3 Conv2d -> BatchNorm2d -> ReLU).

    ``padding=1`` keeps the spatial size, so only the channel count changes:
    ``(N, in_channels, H, W) -> (N, out_channels, H, W)``.

    Args:
        in_channels: Channels entering the first convolution.
        out_channels: Channels produced by every convolution in the block.
        inplace: Forwarded to each ``nn.ReLU``.
        num_convs: Number of conv/BN/ReLU units (default 3, the original depth).

    Returns:
        ``nn.Sequential`` with layers indexed ``0 .. 3*num_convs-1``. For the
        default depth this is the same layout as the original hand-written
        block, so existing state_dicts still load.

    Raises:
        ValueError: If ``num_convs`` is less than 1.
    """
    if num_convs < 1:
        raise ValueError(f"num_convs must be >= 1, got {num_convs}")

    layers = []
    channels = in_channels
    for _ in range(num_convs):
        layers += [
            nn.Conv2d(channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=inplace),
        ]
        channels = out_channels  # only the first conv sees in_channels
    return nn.Sequential(*layers)

class CNN(nn.Module):
    """U-Net style encoder/decoder mapping 4 input channels to 4 output maps.

    Encoder:
    * Three stages (64, 128, 256 channels) are each followed by 2x2 max pooling.
    * A 512-channel bottleneck follows.

    Decoder:
    * Each of the three stages upsamples its input to the spatial size of the
      matching encoder output.
    * It concatenates that encoder output (skip connection) along the channel
      axis and applies a conv block. This is why the decoder blocks take
      512+256, 256+128 and 128+64 input channels.

    Input:  float tensor of shape ``(N, 4, H, W)``.
    Output: raw, unnormalized per-pixel scores of shape ``(N, 4, H, W)``.
    H and W need not be divisible by 8 because upsampling targets the skip's
    exact size. They must be at least 8 so the bottleneck is non-empty.
    """

    def __init__(self):
        super().__init__()
        self.encoder1 = conv_block(4, 64, True)
        self.encoder2 = conv_block(64, 128, True)
        self.encoder3 = conv_block(128, 256, True)
        self.bottleneck = conv_block(256, 512, True)
        self.decoder3 = conv_block(512 + 256, 256, False)
        self.decoder2 = conv_block(256 + 128, 128, False)
        self.decoder1 = conv_block(128 + 64, 64, False)
        self.out = nn.Conv2d(64, 4, 1)
        # Parameter-free, so registering it leaves the state_dict unchanged.
        self.pool = nn.MaxPool2d(2)

    @staticmethod
    def _up_cat(x, skip):
        """Upsample ``x`` to ``skip``'s spatial size and concatenate on channels."""
        x = nn.functional.interpolate(
            x, size=skip.shape[-2:], mode="bilinear", align_corners=False
        )
        return torch.cat([x, skip], dim=1)

    def forward(self, x):
        e1 = self.encoder1(x)                      # (N, 64,  H,    W)
        e2 = self.encoder2(self.pool(e1))          # (N, 128, H//2, W//2)
        e3 = self.encoder3(self.pool(e2))          # (N, 256, H//4, W//4)
        b = self.bottleneck(self.pool(e3))         # (N, 512, H//8, W//8)
        d3 = self.decoder3(self._up_cat(b, e3))    # 512+256 -> 256
        d2 = self.decoder2(self._up_cat(d3, e2))   # 256+128 -> 128
        d1 = self.decoder1(self._up_cat(d2, e1))   # 128+64  -> 64
        return self.out(d1)





