<a href="https://colab.research.google.com/github/SinghAnkit1010/Brain-Tumor-Segmentation/blob/main/Brats.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install nibabel
!pip install celluloid
!pip install pytorch_lightning
!pip install torchio
!pip install monai

# **Import Libraries**

In [None]:
import torch
import nibabel as nib
import numpy as np
import matplotlib.pyplot as plt
import os
import torchio as tio
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from celluloid import Camera
from IPython.display import HTML
from monai import metrics

# **Data Visualization**

In [None]:
# Paths of one example volume (BRATS_004) and its segmentation mask, used
# only for the visual inspection cells below.
image_path = "/content/drive/MyDrive/Task01_BrainTumour/imagesTr/BRATS_004.nii.gz"
mask_path = "/content/drive/MyDrive/Task01_BrainTumour/labelsTr/BRATS_004.nii.gz"

In [None]:
# Open both NIfTI files, then pull their voxel arrays into memory.
# The mask is stored as uint8.
image_nifti = nib.load(image_path)
mask_nifti = nib.load(mask_path)
image = image_nifti.get_fdata()
label = mask_nifti.get_fdata().astype(np.uint8)

In [None]:
# Display the axis orientation codes derived from the image's affine matrix.
nib.aff2axcodes(nib.load(image_path).affine)

('R', 'A', 'S')

In [None]:
# Display the array shapes of the image and of the mask side by side.
image.shape,label.shape

((240, 240, 155, 4), (240, 240, 155))

In [None]:
# Reference lookup of channel index -> modality name and label id -> class name.
# Evaluating the cell only displays the dict; nothing is assigned.
# NOTE(review): presumably copied from the dataset's metadata file - confirm.
{"modality": {
	 "0": "FLAIR",
	 "1": "T1w",
	 "2": "t1gd",
	 "3": "T2w"
 } ,
"labels": {
	 "0": "background",
	 "1": "edema",
	 "2": "non-enhancing tumor",
	 "3": "enhancing tumour"
 }}

{'modality': {'0': 'FLAIR', '1': 'T1w', '2': 't1gd', '3': 'T2w'},
 'labels': {'0': 'background',
  '1': 'edema',
  '2': 'non-enhancing tumor',
  '3': 'enhancing tumour'}}

In [None]:
# Keep channel 0 of the image and cut image and mask to the same
# 128 x 128 x 128 window (180-52, 170-42, 140-12).
roi = (slice(52, 180), slice(42, 170), slice(12, 140))
FLAIR = image[:, :, :, 0][roi]
label = label[roi]

In [None]:
# Build an animation that steps through the third axis of the cropped volume,
# drawing the grayscale slice with the semi-transparent mask on top.
label_names = ['background','edema','non-enhancing tumor','enhancing tumour']
clim = [0, len(label_names) - 1]  # fixed colour range so every frame uses the same label colours

fig, ax = plt.subplots()
camera = Camera(fig)
for i in range(FLAIR.shape[2]):
    im = ax.imshow(FLAIR[:, :, i], cmap='gray')
    im_labels = ax.imshow(label[:, :, i], alpha=0.5)
    im_labels.set_clim(clim)
    camera.snap()  # capture the current state of the figure as one frame

# The colour bar is built from the last overlay and labelled with class names.
cbar = fig.colorbar(im_labels, ax=ax, ticks=range(len(label_names)))
cbar.set_ticklabels(label_names)
cbar.set_label('Label')
animation = camera.animate()

In [None]:
# Display the shape of the cropped volume.
FLAIR.shape

(128, 128, 128)

In [None]:
# Render the animation inline as an HTML5 video.
HTML(animation.to_html5_video())

# **Data Preparation**

In [None]:
# Folders holding the training images and the label maps. The file-listing
# cell below joins the same file name onto both folders.
image_directory = '/content/drive/MyDrive/Task01_BrainTumour/imagesTr'
label_directory = '/content/drive/MyDrive/Task01_BrainTumour/labelsTr'

In [None]:
# Collect the image path and the same-named label path for every BRATS volume.
# os.listdir returns entries in arbitrary order, so the names are sorted first:
# the train/validation split further down is purely positional and would
# otherwise change between runs or machines.
brats_files = sorted(
    file_name
    for file_name in os.listdir(image_directory)
    if file_name.startswith('BRATS')
)
image_filepath = [os.path.join(image_directory, file_name) for file_name in brats_files]
label_filepath = [os.path.join(label_directory, file_name) for file_name in brats_files]

In [None]:
# Display how many image paths and label paths were collected.
len(image_filepath),len(label_filepath)

(484, 484)

In [None]:
# Wrap each (image, label) path pair in a torchio Subject with the keys
# 'MRI' and 'label'.
subjects = [
    tio.Subject({'MRI': tio.ScalarImage(mri_path), 'label': tio.LabelMap(seg_path)})
    for mri_path, seg_path in zip(image_filepath, label_filepath)
]

In [None]:
# Deterministic preprocessing, listed in the order it is applied.
preprocessing_steps = [
    # Keep only the first channel of the 4-D tensor.
    tio.Lambda(lambda volume: volume[0:1, :, :, :]),
    # NOTE(review): presumably (start, end) voxels trimmed per axis. On a
    # 240 x 240 x 155 volume that matches the 52:180, 42:170, 12:140 window
    # used in the visualization cell - confirm.
    tio.Crop((52, 60, 42, 70, 12, 15)),
    tio.ZNormalization(),
]
transformation = tio.Compose(preprocessing_steps)

# Random augmentations.
augmentation_steps = [
    tio.RandomFlip(p=0.25),
    tio.RandomBlur(std=(0.5, 1.5)),
]
augmentation = tio.Compose(augmentation_steps)

In [None]:
# Validation uses only the deterministic preprocessing; training applies the
# same preprocessing followed by the random augmentations.
val_transform = transformation
train_transform = tio.Compose([transformation,augmentation])

In [None]:
# Positional split of the subject list: the first 350 subjects train the
# model, the following 50 validate it. Subjects from index 400 on are not
# used by either dataset.
train_subjects = subjects[:350]
val_subjects = subjects[350:400]
train_dataset = tio.SubjectsDataset(train_subjects, transform=train_transform)
val_dataset = tio.SubjectsDataset(val_subjects, transform=val_transform)

In [None]:
# Patch sampler of size 64 driven by the 'label' map. Labels 2 and 3 get a
# higher sampling probability (0.4 each) than labels 0 and 1 (0.1 each).
label_probabilities = {0: 0.1, 1: 0.1, 2: 0.4, 3: 0.4}
sampler = tio.LabelSampler(
    patch_size=64,
    label_name='label',
    label_probabilities=label_probabilities,
)

In [None]:
# Patch queues for training and validation, both configured identically:
# at most 30 patches buffered, 5 patches drawn per volume, same sampler.
queue_options = dict(max_length=30, samples_per_volume=5, sampler=sampler)
train_pathches_queue = tio.Queue(train_dataset, **queue_options)
val_patches_queue = tio.Queue(val_dataset, **queue_options)

In [None]:
# Batch the queued patches, eight per batch for both loaders.
BATCH_SIZE = 8
train_loader = torch.utils.data.DataLoader(train_pathches_queue, batch_size=BATCH_SIZE)
validation_loader = torch.utils.data.DataLoader(val_patches_queue, batch_size=BATCH_SIZE)

In [None]:
# Sanity check on the first training batch only: print the image and label
# tensor shapes and the largest label id in the first sample.
for data in train_loader:
    x, y = data['MRI']['data'], data['label']['data']
    for item in (x.shape, y.shape, y[0].max()):
        print(item)
    break

torch.Size([8, 1, 64, 64, 64])
torch.Size([8, 1, 64, 64, 64])
tensor(3, dtype=torch.uint8)


# **Model Creation**

In [None]:
class ConvBlock(torch.nn.Module):
    """Two stacked Conv3d -> BatchNorm3d -> ReLU stages.

    Both convolutions use kernel size 3 with padding 1. The first maps
    ``in_channel`` to ``out_channel`` channels; the second keeps
    ``out_channel`` channels.
    """

    def __init__(self, in_channel, out_channel):
        super().__init__()
        layers = []
        # The first stage reads in_channel channels, the second reads the
        # out_channel channels produced by the first.
        for source_channels in (in_channel, out_channel):
            layers.append(
                torch.nn.Conv3d(
                    in_channels=source_channels,
                    out_channels=out_channel,
                    kernel_size=3,
                    padding=1,
                )
            )
            layers.append(torch.nn.BatchNorm3d(out_channel))
            layers.append(torch.nn.ReLU())
        self.step = torch.nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through both stages and return the result."""
        return self.step(x)

In [None]:
class UNet(torch.nn.Module):
    """3D U-Net with three down-sampling levels and skip connections.

    Args:
        in_channel: number of channels of the input volume.
        out_channel: number of output channels (one logit map per class).

    The input's spatial sizes must be divisible by 8: the encoder halves them
    three times and each decoder level concatenates with the encoder feature
    map of the same level.

    The transposed convolutions are registered as sub-modules here. Creating
    them inside ``forward`` would give them fresh random weights on every
    call, keep them out of ``parameters()`` (so they are never optimised) and
    out of ``state_dict()`` (so they are never saved), and tie the network to
    CUDA. Checkpoints written before this change contain no ``upconv*``
    weights and must be retrained or loaded non-strictly.
    """

    def __init__(self, in_channel, out_channel):
        super().__init__()
        # Encoder
        self.conv1 = ConvBlock(in_channel, 32)
        self.conv2 = ConvBlock(32, 64)
        self.conv3 = ConvBlock(64, 128)

        # Bottleneck
        self.conv4 = ConvBlock(128, 256)

        # Learnable 2x up-sampling, one per decoder level.
        self.upconv1 = torch.nn.ConvTranspose3d(in_channels=256, out_channels=128, kernel_size=2, stride=2)
        self.upconv2 = torch.nn.ConvTranspose3d(in_channels=128, out_channels=64, kernel_size=2, stride=2)
        self.upconv3 = torch.nn.ConvTranspose3d(in_channels=64, out_channels=32, kernel_size=2, stride=2)

        # Decoder blocks; their input is the skip connection concatenated
        # with the up-sampled features, hence twice the output channel count.
        self.deconv1 = ConvBlock(256, 128)
        self.deconv2 = ConvBlock(128, 64)
        self.deconv3 = ConvBlock(64, 32)

        self.output = torch.nn.Conv3d(in_channels=32, out_channels=out_channel, kernel_size=1)

        self.maxpool = torch.nn.MaxPool3d(2)

    def forward(self, x):
        """Return un-normalised class scores with the spatial size of ``x``."""
        x1 = self.conv1(x)
        x1m = self.maxpool(x1)

        x2 = self.conv2(x1m)
        x2m = self.maxpool(x2)

        x3 = self.conv3(x2m)
        x3m = self.maxpool(x3)

        encoder_output = self.conv4(x3m)

        y3 = self.upconv1(encoder_output)
        y3 = torch.cat([x3, y3], dim=1)
        y3 = self.deconv1(y3)

        y2 = self.upconv2(y3)
        y2 = torch.cat([x2, y2], dim=1)
        y2 = self.deconv2(y2)

        y1 = self.upconv3(y2)
        y1 = torch.cat([x1, y1], dim=1)
        y1 = self.deconv3(y1)
        return self.output(y1)

In [None]:
class DiceLoss(torch.nn.Module):
    """Soft Dice loss computed over all elements of the two tensors.

    Args:
        smooth: small constant added to both the numerator and the
            denominator of the Dice ratio. Adding it to the denominator only
            avoids a division by zero but reports the maximum loss (1) when
            prediction and target are both empty, although they agree
            perfectly. With the constant on both sides that case yields 0.
    """

    def __init__(self, smooth=1e-8):
        super().__init__()
        self.smooth = smooth

    def forward(self, pred, actual):
        """Return ``1 - Dice(pred, actual)`` as a scalar tensor.

        Both tensors are flattened, so they only need to hold the same number
        of elements in the same order.
        """
        pred = torch.flatten(pred)
        actual = torch.flatten(actual)

        intersection = (actual * pred).sum()
        numer = 2 * intersection + self.smooth
        denum = actual.sum() + pred.sum() + self.smooth
        return 1 - numer / denum

# **Model Training**

In [None]:
class segmentation_model(pl.LightningModule):
    """Lightning wrapper that trains a 1-channel, 4-class UNet with cross-entropy.

    Batches are expected to be dicts with ``batch['MRI']['data']`` (network
    input) and ``batch['label']['data']`` (label map with a leading channel
    axis, which is dropped before the loss).
    """

    def __init__(self):
        super().__init__()

        # The network is deliberately not moved to CUDA here. The Trainer, or
        # an explicit .to(device) by the caller, handles placement, so the
        # module can also be built and restored from a checkpoint on a
        # CPU-only machine.
        self.model = UNet(1,4)
        self.optimizer = torch.optim.Adam(self.model.parameters(),lr=1e-4)
        self.loss_function = torch.nn.CrossEntropyLoss()

    def forward(self,data):
        """Return the network's class scores for ``data``."""
        return self.model(data)

    def _shared_step(self, batch, log_name):
        """Compute the cross-entropy loss for one batch and log it as ``log_name``."""
        x = batch['MRI']['data']
        # Drop the channel axis and cast to int64 class indices for the loss.
        y = batch['label']['data'][:,0].long()

        loss = self.loss_function(self(x), y)
        self.log(log_name, loss)
        return loss

    def training_step(self,batch,batch_idx):
        """Return the training loss for one batch."""
        # The logged value is the cross-entropy loss, not a Dice score; the
        # key is kept so existing logs stay comparable.
        return self._shared_step(batch, 'Train Dice')

    def validation_step(self,batch,batch_idx):
        """Return the validation loss for one batch."""
        # Cross-entropy loss as well; the key must stay 'val Dice' because
        # the checkpoint callback monitors it by name.
        return self._shared_step(batch, 'val Dice')

    def configure_optimizers(self):
        """Return the Adam optimizer created in ``__init__``."""
        return [self.optimizer]

In [None]:
# Create a fresh, untrained Lightning module.
model = segmentation_model()

In [None]:
# Move the model to the GPU when one is present. An unconditional .cuda()
# raises on a CPU-only runtime, so fall back to the CPU in that case.
model.to("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Keep the five checkpoints with the lowest logged 'val Dice' value. The
# Lightning module logs its cross-entropy loss under that key, which is why
# the monitored quantity is minimised.
checkpoint_callback = ModelCheckpoint(monitor = 'val Dice',save_top_k=5,mode = 'min')

In [None]:
# Trainer that writes TensorBoard logs to ./logs on every step, uses the
# checkpoint callback, and runs for at most 25 epochs.
tb_logger = TensorBoardLogger(save_dir='logs')
trainer = pl.Trainer(
    logger=tb_logger,
    log_every_n_steps=1,
    callbacks=checkpoint_callback,
    max_epochs=25,
)

INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


In [None]:
# Train on the training patch loader and validate on the validation patch loader.
trainer.fit(model,train_loader,validation_loader)

# **Model Evaluation**

In [None]:
# Replace the in-memory model with one restored from a checkpoint file on Drive.
# NOTE(review): presumably a checkpoint copied from an earlier training run - confirm it matches the current architecture.
model = segmentation_model.load_from_checkpoint('/content/drive/MyDrive/epoch=25-step=224.ckpt')

In [None]:
# Use the first CUDA device when CUDA is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
# Switch the model to evaluation mode and move it to the selected device.
model = model.eval()
model.to(device)

In [None]:
# Dice metric object from MONAI, created with its default settings.
dice_score = metrics.DiceMetric()

In [None]:
# Patch-wise inference over every validation volume, followed by a Dice score
# against the reference mask.
#
# DiceMetric expects binarized, channel-first one-hot tensors passed as
# (y_pred, y). Feeding it the raw class-index maps (values 0..3 in a single
# channel) multiplies label ids instead of counting overlapping voxels, so
# both maps are one-hot encoded first. The metric's own aggregate() then
# averages over subjects and classes.
dice_score.reset()  # drop results buffered by an earlier run of this cell
for index in range(len(val_dataset)):
    # Indexing the dataset loads and transforms the subject, so do it once.
    subject = val_dataset[index]
    y = subject['label']['data']

    # Overlapping 64-voxel patches (8-voxel overlap per axis) cover the volume.
    grid_sampler = tio.inference.GridSampler(subject, 64, (8, 8, 8))
    aggregator = tio.inference.GridAggregator(grid_sampler)
    patch_loader = torch.utils.data.DataLoader(grid_sampler, batch_size=8)
    with torch.no_grad():
        for patches_batch in patch_loader:
            input_tensor = patches_batch['MRI']['data'].to(device)
            locations = patches_batch[tio.LOCATION]
            pred = model(input_tensor)
            aggregator.add_batch(pred, locations)

    output_tensor = aggregator.get_output_tensor()  # class scores, channel first
    num_classes = output_tensor.shape[0]
    y_pred = output_tensor.argmax(0)

    # (W, H, D) class indices -> (1, C, W, H, D) one-hot.
    y_pred_onehot = torch.nn.functional.one_hot(y_pred.long(), num_classes).permute(3, 0, 1, 2).unsqueeze(0)
    y_onehot = torch.nn.functional.one_hot(y[0].long(), num_classes).permute(3, 0, 1, 2).unsqueeze(0)
    dice_score(y_pred=y_pred_onehot, y=y_onehot)
total_dice_score = dice_score.aggregate()

In [None]:
# Print the Dice score computed over the validation set.
print(total_dice_score)

tensor([[0.4934]])


# **Visualization on Validation Data**

In [None]:
# Predict the full label map of one validation subject for visual comparison.
index = 4
# Indexing the dataset loads and transforms the subject each time, so fetch
# it once and reuse it for the image, the mask and the grid sampler.
subject = val_dataset[index]
x = subject['MRI']['data']
y = subject['label']['data']

# Overlapping 64-voxel patches (8-voxel overlap per axis) cover the volume.
grid_sampler = tio.inference.GridSampler(subject, 64, (8, 8, 8))
aggregator = tio.inference.GridAggregator(grid_sampler)
patch_loader = torch.utils.data.DataLoader(grid_sampler, batch_size=8)
with torch.no_grad():
    for patches_batch in patch_loader:
        input_tensor = patches_batch['MRI']['data'].to(device)
        locations = patches_batch[tio.LOCATION]
        pred = model(input_tensor)
        aggregator.add_batch(pred, locations)
output_tensor = aggregator.get_output_tensor()
y_pred = output_tensor.argmax(0)  # highest-scoring class per voxel

# Drop the channel axis so the image and the mask are plain 3-D volumes.
x = x[0]
y_true = y[0]

In [None]:
# Animate the slices along the third axis with the reference mask on top.
label_names = ['background','edema','non-enhancing tumor','enhancing tumour']
clim = [0, len(label_names) - 1]  # fixed colour range so every frame uses the same label colours

fig, ax = plt.subplots()
camera = Camera(fig)
for i in range(x.shape[2]):
    im = ax.imshow(x[:, :, i], cmap='gray')
    im_labels = ax.imshow(y_true[:, :, i], alpha=0.5)
    im_labels.set_clim(clim)
    camera.snap()  # capture the current state of the figure as one frame

# The colour bar is built from the last overlay and labelled with class names.
cbar = fig.colorbar(im_labels, ax=ax, ticks=range(len(label_names)))
cbar.set_ticklabels(label_names)
cbar.set_label('Label')
animation = camera.animate()

In [None]:
# Render the reference-mask animation inline as an HTML5 video.
HTML(animation.to_html5_video())

In [None]:
# Animate the slices along the third axis with the predicted mask on top.
label_names = ['background','edema','non-enhancing tumor','enhancing tumour']
clim = [0, len(label_names) - 1]  # fixed colour range so every frame uses the same label colours

fig, ax = plt.subplots()
camera = Camera(fig)
for i in range(x.shape[2]):
    im = ax.imshow(x[:, :, i], cmap='gray')
    im_labels = ax.imshow(y_pred[:, :, i], alpha=0.5)
    im_labels.set_clim(clim)
    camera.snap()  # capture the current state of the figure as one frame

# The colour bar is built from the last overlay and labelled with class names.
cbar = fig.colorbar(im_labels, ax=ax, ticks=range(len(label_names)))
cbar.set_ticklabels(label_names)
cbar.set_label('Label')
animation = camera.animate()

In [None]:
# Render the predicted-mask animation inline as an HTML5 video.
HTML(animation.to_html5_video())