## Data loading and preprocessing

In [None]:
# !pip install SimpleITK
# !pip install torchio
# !pip install monai
# !pip install einops

# !unzip /content/drive/MyDrive/COVID-19-20_v2.zip

# %load_ext tensorboard
# %tensorboard --logdir runs

import torch
import torchvision
import numpy as np
import pandas as pd
import torchio as tio
import matplotlib.pyplot as plt
import SimpleITK as sitk
import functools
import logging
import einops

from monai.networks.nets import VNet
from monai.losses import DiceLoss
from functools import cached_property
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter

# TensorBoard writer shared by the training code below.
writer = SummaryWriter('runs/covid_segmentation')

# Train on the GPU when one is available.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Debug logger that writes to 'ct.log' only (propagation to the root logger is off).
logger = logging.getLogger('CT_logger')
logger.setLevel(logging.DEBUG)
# Notebook cells are often re-run; getLogger returns the same object each time,
# so only attach the file handler once. Otherwise every message would be
# written to ct.log once per execution of this cell.
if not logger.handlers:
  file_log = logging.FileHandler('ct.log')
  file_log.setLevel(logging.DEBUG)
  logger.addHandler(file_log)
logger.propagate = False

In [None]:
# Both splits live in one workbook: 'Train set' and 'Validation set' sheets.
xlsx = pd.ExcelFile('/content/COVID-19-20_TrainValidation.xlsx')
df_train = pd.read_excel(xlsx, 'Train set')
df_test = pd.read_excel(xlsx, 'Validation set')

# The first 80% of the 'Train set' rows are used for training, the remaining
# 20% for validation; the 'Validation set' sheet is held out as the test set.
val_split = int(len(df_train) * 0.8)

train_filenames = list(df_train['FILENAME'])
train_namelist = train_filenames[:val_split]
val_namelist = train_filenames[val_split:]
test_namelist = list(df_test['FILENAME'])

In [None]:
class CT:
  """One CT scan and its segmentation mask, loaded lazily from NIfTI files.

  The scan is read from ``<folder_path><CT_ID>.nii``, falling back to
  ``<folder_path><CT_ID>_ct.nii``. The mask is read from
  ``<folder_path><CT_ID>_seg.nii``.
  """

  # Intensity window applied to the CT volume; values outside are clipped.
  HU_MIN = -1000
  HU_MAX = 1000
  # Number of leading entries along the first array axis kept from each volume.
  # TODO: replace this fixed crop with proper resampling / patch extraction.
  MAX_SLICES = 32

  def __init__(self, CT_ID, folder_path):
    self.CT_ID = CT_ID
    # Concatenated directly with the file name, so it should end with a separator.
    self.folder_path = folder_path

  @cached_property
  def volume(self):
    """Load, preprocess and cache this scan as a ``(ct, mask)`` pair of NumPy arrays.

    Raises:
      RuntimeError: (from SimpleITK) if the scan or mask file cannot be read.
    """
    base = self.folder_path + self.CT_ID
    try:
      ct = sitk.ReadImage(base + '.nii')
    except RuntimeError:
      # Some scans are stored as '<id>_ct.nii' instead of '<id>.nii'.
      ct = sitk.ReadImage(base + '_ct.nii')
    mask = sitk.ReadImage(base + '_seg.nii')

    ct_np = self.preprocessing(sitk.GetArrayFromImage(ct))
    # The mask holds labels, not intensities: crop it to stay aligned with the
    # CT, but do not apply the intensity clip to it.
    mask_np = self._crop(sitk.GetArrayFromImage(mask))

    return (ct_np, mask_np)

  def preprocessing(self, image):
    """Clip *image* to ``[HU_MIN, HU_MAX]`` and keep its first ``MAX_SLICES`` slices."""
    return self._crop(np.clip(image, self.HU_MIN, self.HU_MAX))

  def _crop(self, image):
    """Keep only the first ``MAX_SLICES`` entries along the first axis of *image*."""
    return image[:self.MAX_SLICES]

In [None]:
def augment(ct, mask, aug_type):
  """Apply one TorchIO augmentation jointly to a CT volume and its mask.

  The CT is wrapped as a ``tio.ScalarImage`` and the mask as a
  ``tio.LabelMap`` inside one ``tio.Subject``. Spatial transforms therefore
  move both consistently, and TorchIO resamples label maps with
  nearest-neighbour interpolation so the mask stays discrete. Intensity
  transforms such as noise and blur change only the CT and leave the labels
  intact. Previously both volumes were concatenated into a single scalar
  image, so noise/blur were added to the mask as well.

  Args:
    ct: 3-D CT tensor (assumed slices-first, as produced by ``CT.volume``).
    mask: 3-D mask tensor with the same shape as ``ct``.
    aug_type: augmentation key. One of ``'flip'``, ``'ED'``, ``'affine'``,
      ``'anistropy'`` / ``'anisotropy'``, ``'noise'``, ``'blur'``, ``'swap'``.

  Returns:
    ``(ct, mask)`` tensors with the same shape as the inputs.

  Raises:
    KeyError: if ``aug_type`` is not a known augmentation key.
  """
  # Map keys to transform classes and instantiate only the one requested,
  # instead of constructing every transform on each call.
  aug_factories = {'flip': tio.RandomFlip,
                   'ED': tio.RandomElasticDeformation,
                   'affine': tio.RandomAffine,
                   'anistropy': tio.RandomAnisotropy,  # original misspelled key, kept for callers
                   'anisotropy': tio.RandomAnisotropy,
                   'noise': tio.RandomNoise,
                   'blur': tio.RandomBlur,
                   'swap': tio.RandomSwap}
  transform = aug_factories[aug_type]()

  # TorchIO images are 4-D (channels, spatial...), hence the extra leading axis.
  subject = tio.Subject(
      ct=tio.ScalarImage(tensor=ct.unsqueeze(0)),
      mask=tio.LabelMap(tensor=mask.unsqueeze(0)),
  )
  augmented = transform(subject)

  return augmented['ct'].data.squeeze(0), augmented['mask'].data.squeeze(0)

In [None]:
class CTDataset(Dataset):
  """Map-style dataset of ``(ct, mask)`` tensor pairs, optionally expanded with augmentations.

  Each name in ``namelist`` appears once without augmentation. It then
  appears once more for every entry of ``augmentation_list``, with that
  augmentation applied. The dataset length is therefore
  ``len(namelist) * (1 + len(augmentation_list))``.
  """
  def __init__(self, namelist, folder_path, augmentation_list=None):
    self.folder_path = folder_path
    # None stands for "no augmentation"; avoids a shared mutable default list.
    aug_types = [None] + list(augmentation_list or [])
    # Ordering: all un-augmented scans first, then one full pass per augmentation.
    self.namelist = [(name, aug_type) for aug_type in aug_types for name in namelist]

  def __len__(self):
    return len(self.namelist)

  def __getitem__(self, idx):
    """Load scan *idx*, apply its augmentation (if any) and return ``(ct, mask)`` tensors."""
    ct_id, aug_type = self.namelist[idx]
    ct, mask = CT(ct_id, self.folder_path).volume
    ct, mask = torch.from_numpy(ct), torch.from_numpy(mask)

    if aug_type:
      ct, mask = augment(ct, mask, aug_type)

    return ct, mask

In [None]:
# Training scans plus four augmented copies of each (flip, affine, noise, swap).
trainset = CTDataset(train_namelist, '/content/Train/',
                     augmentation_list=['flip', 'affine', 'noise', 'swap'])

# Validation scans come from the same folder and are never augmented.
valset = CTDataset(val_namelist, '/content/Train/')

trainloader = DataLoader(trainset, batch_size=1, shuffle=True)
# Validation order does not affect aggregate metrics; keeping it fixed makes
# per-batch validation logs reproducible across runs.
valloader = DataLoader(valset, batch_size=1, shuffle=False)

## Model training

In [None]:
# One input channel (the CT volume) and one output channel (the binary lesion mask).
model = VNet(in_channels=1, out_channels=1)
model = model.to(device)

In [None]:
class ModelEvaluation:
  """Train a binary segmentation model with a soft Dice loss and log progress.

  Relies on the module-level ``device``, ``writer`` (TensorBoard
  SummaryWriter) and ``logger`` objects.
  """
  def __init__(self, model, loader):
    self.model = model
    self.loader = loader

    self.loss_fn = self.DiceLoss
    self.optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001)

  def DiceLoss(self, pred_mask, true_mask, smooth=1.0):
    """Return the batch-mean soft Dice loss ``1 - (2*|P*T| + s) / (|P| + |T| + s)``.

    Args:
      pred_mask: predicted foreground probabilities of shape ``(batch, ...)``
        with at least one non-batch axis. Values are expected in [0, 1], so
        apply a sigmoid to raw logits first.
      true_mask: ground-truth mask with the same shape as ``pred_mask``.
      smooth: additive smoothing term; keeps the loss defined (0) when both
        masks are empty.

    Returns:
      Scalar tensor: the Dice loss averaged over the batch.
    """
    # Reduce over every axis except the batch axis, so 4-D and 5-D inputs both work.
    dims = tuple(range(1, pred_mask.dim()))
    inter = (pred_mask * true_mask).sum(dim=dims)
    union = pred_mask.sum(dim=dims) + true_mask.sum(dim=dims)
    dice = 1 - (2 * inter + smooth) / (union + smooth)

    return dice.mean()

  def train(self, epochs=1):
    """Run *epochs* passes over ``self.loader``, minimizing the Dice loss.

    On every 10th batch, the loss is written to TensorBoard (``'Loss/train'``,
    with a step that keeps increasing across epochs) and to the debug log.
    After training, the model graph is added to TensorBoard using the last
    input batch, if any batch was seen.
    """
    import gc  # local import: do not depend on a later notebook cell importing gc

    gc.collect()
    torch.cuda.empty_cache()

    self.model.train()
    n_batches = len(self.loader)
    ct = None
    for epoch in range(epochs):
      for batch, (ct, mask) in enumerate(self.loader):
        # Add the channel axis expected by the network: (B, D, H, W) -> (B, 1, D, H, W).
        ct = ct.unsqueeze(1).to(device, dtype=torch.float32)
        mask = mask.unsqueeze(1).to(device, dtype=torch.float32)
        # The network's final layer produces unbounded scores; the Dice loss
        # needs probabilities in [0, 1].
        model_out = torch.sigmoid(self.model(ct))

        loss = self.loss_fn(model_out, mask)
        if batch % 10 == 0:
          # Global step, so later epochs do not overwrite earlier points in TensorBoard.
          step = epoch * n_batches + batch
          writer.add_scalar('Loss/train', loss.item(), step)
          logger.debug(f'DEBUG| location: ModelEvaluation.train | epoch: {epoch}, batch: {batch} loss: {loss.item()}')

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    # Skip the graph when the loader was empty (no example input to trace with).
    if ct is not None:
      writer.add_graph(self.model, ct)
    writer.flush()

In [None]:
# Trainer bound to the VNet and the augmented training loader.
evaluate = ModelEvaluation(model=model, loader=trainloader)

In [None]:
import gc  # ModelEvaluation.train() calls gc.collect()

# Launch a single training epoch.
evaluate.train(epochs=1)

KeyboardInterrupt: ignored