In [1]:
!pip install pytorch_lightning
#!git clone https://github.com/black0017/MedicalZooPytorch.git
#!pip install -r MedicalZooPytorch/installation/requirements.txt
#!pip install torch==1.7.1+cu101 torchvision==0.8.2+cu101 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html



In [2]:
!pip install torchsummaryX



In [3]:
# Mount Google Drive so the data and project code under MyDrive are reachable.
# force_remount=True remounts even when Drive is already mounted.
from google.colab import drive

drive.mount(mountpoint='/content/drive', force_remount=True)

Mounted at /content/drive


In [4]:
cd drive/MyDrive/MacAI

/content/drive/MyDrive/MacAI


In [5]:
import os
import torch
from torch import nn
import numpy as np
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
import pytorch_lightning as pl
import sys

from lib import medzoo
import nibabel as nb
from skimage import transform
import matplotlib.pyplot as plt


In [7]:
#Pytorch-lightning setup
class TumourSegmentation(pl.LightningModule):
  """Lightning wrapper that trains a segmentation network with an L1 loss.

  Args:
    model: the underlying network (here a UNet3D). It is called on the input
      batch and must return a tensor with as many elements as the target.
    lr: Adam learning rate. Defaults to 0.02, the value that used to be
      hard-coded in ``configure_optimizers``.
  """

  def __init__(self, model, lr=0.02):
    super().__init__()
    self.model = model
    self.lr = lr

  def forward(self, x):
    # Call the module itself rather than ``.forward`` so nn.Module hooks still run.
    return self.model(x)

  def training_step(self, batch, batch_idx):
    """Return the mean absolute error between prediction and target for one batch."""
    x, y = batch
    y_hat = self(x)

    # The dataset adds a channel axis and the DataLoader a batch axis, so the target
    # is (batch, 1, *spatial). The prediction is presumably (batch, 1, *spatial) too.
    # The previous torch.squeeze(y_hat, axis=1) turned it into (batch, *spatial),
    # which broadcasts against the target to (batch, batch, *spatial). That compared
    # every prediction with every target and was only correct for batch_size=1.
    # Reshaping to the target's shape raises on an element-count mismatch instead of
    # silently broadcasting.
    if y_hat.shape != y.shape:
      y_hat = y_hat.reshape(y.shape)

    loss = torch.mean(torch.abs(y_hat - y))
    # NOTE(review): binary cross-entropy hit a CUDA error here. F.binary_cross_entropy
    # needs both inputs and targets in [0, 1]. The target holds the seg file's label
    # values unmodified, and the network output is presumably not sigmoid-activated.
    # Consider F.binary_cross_entropy_with_logits(y_hat, (y > 0).to(y_hat.dtype)) — TODO.
    self.log('train_loss', loss)
    return loss

  def configure_optimizers(self):
    return torch.optim.Adam(self.parameters(), lr=self.lr)


In [8]:
# Model initialization.
# 4 input channels = the stacked t1, t1ce, t2 and flair volumes of each case; the single
# output channel is compared against the segmentation target in training_step.
# base_n_filter=8 was tried earlier; 4 is used now (presumably to fit GPU memory — TODO confirm).
# No explicit .cuda(): Lightning moves the module to the device the Trainer is configured
# with, and a hard-coded .cuda() fails outright on a CPU-only runtime.
unet_model = medzoo.Unet3D.UNet3D(in_channels=4, n_classes=1, base_n_filter=4)
model = TumourSegmentation(unet_model)

In [9]:
#Dataset
class brats_dataset(torch.utils.data.Dataset):
  """In-memory BraTS 2020 training dataset.

  Every case folder is loaded eagerly in ``__init__``. Each item is an
  ``[image, seg]`` pair of float16 tensors:

  * ``image``: the four MRI modalities stacked on axis 0 in the order
    t1, t1ce, t2, flair -> shape ``(4, *padded_spatial)``;
  * ``seg``: the segmentation volume as stored in the file -> ``(1, *padded_spatial)``.

  Every spatial axis is zero-padded up to a multiple of ``pad_multiple``, because
  the U-Net needs all spatial dims divisible by 8 (by default). Otherwise the
  padding would have to be done manually inside the model. Padding is split
  floor-half before and the rest after. For 240x240x155 BraTS volumes this is
  exactly the previous fixed padding: 2 slices before and 3 after on the last axis.

  Args:
    data_folders: case folders named like ``.../BraTS20_Training_001``. The last
      three characters of the folder name are the case id used in the file names.
    pad_multiple: every spatial dim is padded up to a multiple of this (default 8).
  """

  # Input modalities, in the channel order fed to the model.
  _MODALITIES = ('t1', 't1ce', 't2', 'flair')

  def __init__(self, data_folders, pad_multiple=8):
    self.pad_multiple = pad_multiple
    self.data_list = []
    for folder in data_folders:
      # normpath strips a trailing separator, which would otherwise turn the
      # case id "001" into "01/".
      case_id = os.path.basename(os.path.normpath(folder))[-3:]
      channels = [self._load_volume(folder, case_id, modality) for modality in self._MODALITIES]
      image = torch.cat(channels, dim=0)
      seg = self._load_volume(folder, case_id, 'seg')
      self.data_list.append([image, seg])

  def _load_volume(self, folder, case_id, modality):
    """Load one NIfTI volume, zero-pad it and return it as a (1, ...) float16 tensor."""
    path = os.path.join(folder, 'BraTS20_Training_%s_%s.nii' % (case_id, modality))
    volume = nb.load(path, mmap=False).get_fdata()
    volume = np.pad(volume, self._pad_widths(volume.shape, self.pad_multiple))
    return torch.as_tensor(np.expand_dims(volume, axis=0)).half()

  @staticmethod
  def _pad_widths(shape, multiple):
    """Per-axis ``(before, after)`` padding that rounds each dim up to ``multiple``."""
    widths = []
    for size in shape:
      total = -size % multiple
      widths.append((total // 2, total - total // 2))
    return widths

  def __len__(self):
    return len(self.data_list)

  def __getitem__(self, index):
    return self.data_list[index]

In [10]:
# Load the first N_CASES BraTS 2020 training cases and hold out VAL_FRACTION of them
# for validation. The seeded generator makes the train/validation split reproducible.
N_CASES = 10
VAL_FRACTION = 0.2
SPLIT_SEED = 42

data_folders = ['sample_data/BraTS20_Training_%s' % str(x).zfill(3) for x in range(1, N_CASES + 1)]
dataset = brats_dataset(data_folders)
n_val = int(round(len(dataset) * VAL_FRACTION))
train_dataset, val_dataset = random_split(
    dataset,
    lengths=[len(dataset) - n_val, n_val],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

In [11]:
def train_collate_fn(batch):
  """Collate a list of ``[image, seg]`` samples into a training batch.

  Placeholder for training-time data augmentation. For now it applies no
  augmentation and defers to PyTorch's default collation, so it can already be
  passed as ``collate_fn`` without changing the batches. The previous body was
  ``pass``, which returned None and broke any DataLoader that used it.
  """
  from torch.utils.data.dataloader import default_collate
  # TODO: apply data augmentation to `batch` here before collating.
  return default_collate(batch)
  

In [12]:
# Data loaders. Shuffle the training cases so each epoch sees them in a new order.
# batch_size=1 (the DataLoader default, written out) feeds one padded 4-channel volume per step.
# Pass collate_fn=train_collate_fn to the training loader once augmentation is implemented.
train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=1)

In [15]:
# Training. `gpus=` is the Lightning 1.x way to request a device (removed in 2.0).
trainer = pl.Trainer(
    accumulate_grad_batches=1,
    gpus=1,
    max_epochs=1,
    precision=16,  # 16-bit mixed precision
)
# Pass the loader positionally. The `train_dataloader=` keyword of Trainer.fit was
# renamed to `train_dataloaders=` and the old name later removed, while the second
# positional parameter has always been the training loader(s).
# Validation is not wired up: TumourSegmentation has no validation_step yet. Once it
# does, add check_val_every_n_epoch=1 above and pass val_dataloaders=... here.
trainer.fit(model, train_dataloader)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Using native 16bit precision.

  | Name  | Type   | Params
---------------------------------
0 | model | UNet3D | 445 K 
---------------------------------
445 K     Trainable params
0         Non-trainable params
445 K     Total params


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…




1

In [None]:

# Disabled manual (non-Lightning) sanity check: one forward/backward/optimizer step on
# the first stored sample. dataset.data_list holds unbatched samples, so unsqueeze adds
# the batch axis. squeeze(axis=1) then presumably lines the (1, 1, ...) prediction up
# with the unbatched (1, ...) target.
# NOTE(review): the stored tensors are float16 (.half() in brats_dataset). Outside
# autocast, a float32 model likely needs the input cast with .float() — TODO confirm.
# Assumes unet_model is already on the GPU — TODO confirm before re-enabling.
#outputs = unet_model(torch.unsqueeze(dataset.data_list[0][0], axis=0).cuda())
#loss = torch.mean(torch.abs(torch.squeeze(outputs,axis=1) - dataset.data_list[0][1].cuda()))

#optimizer = torch.optim.Adam(unet_model.parameters()) 
#optimizer.zero_grad()
#loss.backward()
#optimizer.step()
#print(loss)

In [16]:
    # to be used later
# Cropping helpers (not yet called from brats_dataset). Cropping each case to the region
# where any modality is non-zero removes empty background before padding.
# They replace an earlier inline draft that was an IndentationError at top level. The
# draft also reduced with np.min(xmin) (the last volume's value) instead of over the
# collected xmins list, and sliced xmin:xmax, which dropped the last non-zero slice
# because the max index is inclusive.


def bbox2_3D(img):
  """Return the inclusive bounding box of the non-zero voxels of a 3-D array.

  Returns:
    ``(xmin, xmax, ymin, ymax, zmin, zmax)`` with every max index inclusive, or
    None when ``img`` has no non-zero voxel (indexing the empty ``np.where``
    result would otherwise raise IndexError).
  """
  img = np.asarray(img)
  if not img.any():
    return None
  x_any = np.any(img, axis=(1, 2))
  y_any = np.any(img, axis=(0, 2))
  z_any = np.any(img, axis=(0, 1))
  xmin, xmax = np.where(x_any)[0][[0, -1]]
  ymin, ymax = np.where(y_any)[0][[0, -1]]
  zmin, zmax = np.where(z_any)[0][[0, -1]]
  return xmin, xmax, ymin, ymax, zmin, zmax


def crop_to_union_bbox(*volumes):
  """Crop same-shaped 3-D volumes to the union of their non-zero bounding boxes.

  Volumes that are entirely zero (e.g. a segmentation with no tumour) do not
  contribute to the box. If every volume is empty, they are returned unchanged.

  Returns:
    A tuple of the cropped volumes, in input order.
  """
  boxes = [box for box in (bbox2_3D(v) for v in volumes) if box is not None]
  if not boxes:
    return tuple(volumes)
  mins = np.min([box[0::2] for box in boxes], axis=0)
  maxs = np.max([box[1::2] for box in boxes], axis=0)
  # +1 because the bounding-box max indices are inclusive.
  region = tuple(slice(lo, hi + 1) for lo, hi in zip(mins, maxs))
  return tuple(v[region] for v in volumes)