In [4]:
!pip install pytorch_lightning
!git clone https://github.com/black0017/MedicalZooPytorch.git
!pip install -r MedicalZooPytorch/installation/requirements.txt
!pip install torch==1.7.1+cu101 torchvision==0.8.2+cu101 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html

Collecting pytorch_lightning
[?25l  Downloading https://files.pythonhosted.org/packages/94/7b/0d5ef515a695fa55a29786297f9e8bd6e0f35a689decd53574cdd80597bc/pytorch_lightning-1.1.1-py3-none-any.whl (669kB)
[K     |████████████████████████████████| 675kB 19.5MB/s 
Collecting PyYAML>=5.1
[?25l  Downloading https://files.pythonhosted.org/packages/64/c2/b80047c7ac2478f9501676c988a5411ed5572f35d1beff9cae07d321512c/PyYAML-5.3.1.tar.gz (269kB)
[K     |████████████████████████████████| 276kB 60.8MB/s 
[?25hCollecting fsspec>=0.8.0
[?25l  Downloading https://files.pythonhosted.org/packages/ec/80/72ac0982cc833945fada4b76c52f0f65435ba4d53bc9317d1c70b5f7e7d5/fsspec-0.8.5-py3-none-any.whl (98kB)
[K     |████████████████████████████████| 102kB 12.4MB/s 
Collecting future>=0.17.1
[?25l  Downloading https://files.pythonhosted.org/packages/45/0b/38b06fd9b92dc2b68d58b75f900e97884c45bedd2ff83203d933cf5851c9/future-0.18.2.tar.gz (829kB)
[K     |████████████████████████████████| 829kB 50.1MB/s 
Buil

Looking in links: https://download.pytorch.org/whl/torch_stable.html
Collecting torch==1.7.1+cu101
[?25l  Downloading https://download.pytorch.org/whl/cu101/torch-1.7.1%2Bcu101-cp36-cp36m-linux_x86_64.whl (735.4MB)
[K     |████████████████████████████████| 735.4MB 17kB/s 
[?25hCollecting torchvision==0.8.2+cu101
[?25l  Downloading https://download.pytorch.org/whl/cu101/torchvision-0.8.2%2Bcu101-cp36-cp36m-linux_x86_64.whl (12.8MB)
[K     |████████████████████████████████| 12.8MB 59.4MB/s 
[?25hCollecting torchaudio==0.7.2
[?25l  Downloading https://files.pythonhosted.org/packages/2a/f9/618434cf4e46dc975871e1516f5499abef6564ab4366f9b2321ee536be14/torchaudio-0.7.2-cp36-cp36m-manylinux1_x86_64.whl (7.6MB)
[K     |████████████████████████████████| 7.6MB 7.5MB/s 
Installing collected packages: torch, torchvision, torchaudio
  Found existing installation: torch 1.4.0
    Uninstalling torch-1.4.0:
      Successfully uninstalled torch-1.4.0
  Found existing installation: torchvision 0.

In [5]:
# Mount Google Drive at /content/drive; force_remount=True requests a fresh mount even
# if one already exists from an earlier run of this cell.
from google.colab import drive

drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [2]:
import os
import torch
from torch import nn
import numpy as np
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
import pytorch_lightning as pl
import sys
sys.path.append('./MedicalZooPytorch')
from lib import medzoo
import nibabel as nb
from skimage import transform

In [None]:
from pytorch_lightning.metrics import Metric

#Not sure how to implement this yet, but this is the outline of the metric
class IntersectionOverUnion(Metric):
    def __init__(self, dist_sync_on_step=False):
        super().__init__(dist_sync_on_step=dist_sync_on_step)

        self.add_state("intersection", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("union", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        preds, target = self._input_format(preds, target)
        assert preds.shape == target.shape
  
        self.intersection = torch.logical_and(target, preds)
        self.union = torch.logical_or(target,preds)
        iou_score = torch.sum(intersection) / torch.sum(union)

    def compute(self):
        return self.iou_score

In [11]:
#Pytorch-lightning setup
class TumourSegmentation(pl.LightningModule):
  """LightningModule that trains a segmentation network with cross-entropy loss.

  Args:
    model: network mapping a float input batch to per-class logits. The loss below
      assumes an output of shape (batch, n_classes, D, H, W) — TODO confirm for the
      medzoo UNet3D used in this notebook.
    lr: Adam learning rate. The default keeps the original 0.02, which is unusually
      high for Adam; consider 1e-3 or lower.
  """
  def __init__(self, model, lr=0.02):
    super().__init__()
    self.model = model
    self.lr = lr

  def forward(self, x):
    # The dataset yields numpy-derived arrays; cast to float32 to match the weights.
    return self.model(x.float())

  def training_step(self, batch, batch_idx):
    """Compute the cross-entropy loss for one (image, segmentation) batch."""
    x, y = batch
    logits = self(x)
    # F.cross_entropy needs integer class indices, not floats. The network in this
    # notebook is built with n_classes=2, so every non-zero label is mapped to a
    # single foreground ("tumour") class. Assumption: binary whole-tumour
    # segmentation — revisit if more classes are used.
    target = (y > 0).long()
    loss = F.cross_entropy(logits, target)
    self.log('train_loss', loss)
    return loss

  def configure_optimizers(self):
    return torch.optim.Adam(self.parameters(), lr=self.lr)


In [12]:
# Model: a MedicalZooPytorch 3D U-Net over the four stacked MRI modalities
# (in_channels=4) with two output classes, wrapped in the Lightning module.
unet_model = medzoo.Unet3D.UNet3D(
  in_channels=4,
  n_classes=2,
  base_n_filter=8,
)
model = TumourSegmentation(unet_model)

In [5]:
#Dataset
class brats_dataset(torch.utils.data.Dataset):
  """In-memory BraTS dataset holding one (image, segmentation) pair per case folder.

  Each folder is named after its case id (e.g. ``BraTS20_Training_001``). It must
  contain ``<case_id>_t1.nii``, ``_t1ce.nii``, ``_t2.nii``, ``_flair.nii`` and
  ``_seg.nii``. All volumes are loaded and resized eagerly in ``__init__``.

  Args:
    data_folders: iterable of case-folder paths.
    target_shape: spatial shape every volume is resized to, because of the UNet's
      input-shape requirements. The default keeps the original (320, 400, 320).

  Each item is ``[image, seg]``:
    image: float32 array of shape ``(4, *target_shape)``, channels ordered
      t1, t1ce, t2, flair.
    seg: int64 label array of shape ``target_shape``.
  """
  # Channel order of the stacked input image.
  MODALITIES = ('t1', 't1ce', 't2', 'flair')

  def __init__(self, data_folders, target_shape=(320, 400, 320)):
    self.data_list = []
    for folder in data_folders:
      # Intensity volumes: skimage's default (linear) interpolation; float32 halves
      # memory compared with the float64 returned by get_fdata().
      channels = [
        transform.resize(self._load(folder, modality), target_shape).astype(np.float32)
        for modality in self.MODALITIES
      ]
      # Label volume: nearest-neighbour (order=0) so resizing cannot create
      # fractional, in-between class values; preserve_range keeps the raw labels.
      seg = transform.resize(
        self._load(folder, 'seg'), target_shape,
        order=0, preserve_range=True, anti_aliasing=False,
      )
      self.data_list.append([np.stack(channels), seg.astype(np.int64)])

  @staticmethod
  def _modality_path(folder, modality):
    """Return ``<folder>/<case_id>_<modality>.nii``, where case_id is the folder's name."""
    case_id = os.path.basename(os.path.normpath(folder))
    return os.path.join(folder, f'{case_id}_{modality}.nii')

  @classmethod
  def _load(cls, folder, modality):
    """Read one NIfTI volume fully into memory (mmap disabled) as a numpy array."""
    return nb.load(cls._modality_path(folder, modality), mmap=False).get_fdata()

  def __len__(self):
    return len(self.data_list)

  def __getitem__(self, index):
    return self.data_list[index]

In [6]:
# BraTS case folders on the mounted Drive; each one becomes a single dataset item.
data_folders = ['./drive/MyDrive/sample_data/BraTS20_Training_001']
dataset = brats_dataset(data_folders)

# random_split requires the lengths to sum to len(dataset), so derive them from the
# dataset size instead of hard-coding them. With a single case this gives [1, 0].
# The generator is seeded so the split is reproducible across kernel restarts.
VAL_FRACTION = 0.2
SPLIT_SEED = 42
n_val = int(len(dataset) * VAL_FRACTION)
train_dataset, val_dataset = random_split(
    dataset,
    lengths=[len(dataset) - n_val, n_val],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

In [20]:
def train_collate_fn(batch):
  """Collate a list of dataset items into a training batch.

  This is the hook for training-time data augmentation. For now it only applies
  PyTorch's default collation, so it is safe to pass as ``collate_fn`` (it never
  returns None). Numpy arrays become stacked tensors.

  Args:
    batch: list of dataset items, e.g. ``[image, seg]`` pairs.

  Returns:
    The collated batch, as produced by ``default_collate``.
  """
  from torch.utils.data.dataloader import default_collate
  # TODO: apply data augmentation to `batch` here, before collating.
  return default_collate(batch)
  

In [9]:
#Data Loader
# Reshuffle the training cases every epoch. batch_size=1 (the DataLoader default) is
# made explicit because each item is a full multi-channel 3D volume.
train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True)
# TODO: enable validation once val_dataset is non-empty:
# val_dataloader = DataLoader(val_dataset)

In [None]:
#finding batch size and learning rate
#trainer = pl.Trainer(auto_scale_batch_size='binsearch')
#trainer.tune(model)
#trainer = pl.Trainer(auto_lr_find = True,)
#trainer.tune(model)

In [17]:
#Training
# Single-GPU run for one epoch; accumulate_grad_batches=1 steps the optimiser every batch.
trainer = pl.Trainer(gpus=1, max_epochs=1, accumulate_grad_batches=1)
trainer.fit(model=model, train_dataloader=train_dataloader)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type   | Params
---------------------------------
0 | model | UNet3D | 1.8 M 
---------------------------------
1.8 M     Trainable params
0         Non-trainable params
1.8 M     Total params


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…




RuntimeError: ignored

In [None]:
#Run this cell then the one below to empty gpu memory, CUDA is wierd
1/0 

In [15]:
import gc


def report_cuda_memory(device=0):
  """Garbage-collect, release PyTorch's cached CUDA blocks, and report GPU memory.

  Args:
    device: CUDA device index to report on.

  Returns:
    A dict with 'total' (device capacity), 'reserved' (held by PyTorch's caching
    allocator) and 'allocated' (occupied by live tensors), all in bytes. Returns None
    when CUDA is unavailable.
  """
  gc.collect()
  if not torch.cuda.is_available():
    print('CUDA is not available; no GPU memory to report.')
    return None
  torch.cuda.empty_cache()
  return {
    'total': torch.cuda.get_device_properties(device).total_memory,
    # memory_reserved replaces the deprecated memory_cached.
    'reserved': torch.cuda.memory_reserved(device),
    'allocated': torch.cuda.memory_allocated(device),
  }


report_cuda_memory()

