## Data loading and preprocessing

In [1]:
# !pip install SimpleITK
# !pip install torchio
# !pip install monai
# !pip install einops
# !pip install tensorboard-plugin-3d
# !pip install pynvml

# !unzip /content/drive/MyDrive/COVID-19-20_v2.zip

# %load_ext tensorboard
# %tensorboard --logdir runs

import gc
import torch
import torchvision
import numpy as np
import pandas as pd
import torchio as tio
import matplotlib.pyplot as plt
import SimpleITK as sitk
import functools
import logging
import einops


from torch import nn
from scipy.ndimage import zoom
from torchvision import transforms
from monai.networks.nets import VNet, UNETR, SwinUNETR
from monai.metrics import compute_generalized_dice, compute_average_surface_distance
from monai.metrics import compute_surface_dice, compute_roc_auc, compute_iou
from monai.metrics import compute_hausdorff_distance, CumulativeAverage
from monai.visualize.img2tensorboard import plot_2d_or_3d_image
from monai.losses import GeneralizedDiceLoss, DiceLoss, DiceCELoss
from functools import cached_property
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter

# TensorBoard event files for this experiment are written beneath
# runs/covid_segmentation (the directory the %tensorboard magic above points at).
writer = SummaryWriter(log_dir='runs/covid_segmentation')

# Run on the GPU whenever PyTorch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
  device = 'cuda'
else:
  device = 'cpu'

# Dedicated logger for the CT pipeline; its records are written to ct.log.
logger = logging.getLogger('CT_logger')
logger.setLevel(logging.DEBUG)

# logging.getLogger returns the same object on every call, so re-running this
# cell used to attach one more FileHandler each time and every record was then
# written to ct.log once per attached handler. Reuse the handler that is
# already there instead of stacking another one.
file_log = next(
  (handler for handler in logger.handlers if isinstance(handler, logging.FileHandler)),
  None,
)
if file_log is None:
  file_log = logging.FileHandler('ct.log')
  logger.addHandler(file_log)
file_log.setLevel(logging.DEBUG)

# Do not hand these records on to the handlers of ancestor loggers.
logger.propagate = False

In [2]:
# Fraction of the 'Train set' sheet used for fitting; the remaining rows are
# held out for validation.
TRAIN_FRACTION = 0.8

# Open the workbook once for both sheets and close it again afterwards; the
# handle was previously left open for the lifetime of the kernel.
with pd.ExcelFile('/content/COVID-19-20_TrainValidation.xlsx') as xlsx:
  df_train = pd.read_excel(xlsx, sheet_name='Train set')
  df_test = pd.read_excel(xlsx, sheet_name='Validation set')

# Index of the first validation row. The split is positional (sheet order),
# with no shuffling.
val_split = int(len(df_train)*TRAIN_FRACTION)

train_filenames = list(df_train['FILENAME'])
train_namelist = train_filenames[:val_split]
val_namelist = train_filenames[val_split:]
# The sheet called 'Validation set' supplies the names used as the test list.
test_namelist = list(df_test['FILENAME'])

In [3]:
class CT:
  """One scan and its segmentation mask, read from NIfTI files on disk.

  The scan is looked up as ``<folder_path><CT_ID>.nii`` and, when that file does
  not exist, as ``<folder_path><CT_ID>_ct.nii``. The mask is always read from
  ``<folder_path><CT_ID>_seg.nii``. Paths are built by plain string
  concatenation, so ``folder_path`` must end with a path separator.

  Args:
    CT_ID: file-name stem identifying the scan.
    folder_path: directory prefix (including the trailing separator).
  """

  # Intensity window applied with np.clip (presumably Hounsfield units - confirm).
  CLIP_MIN = -1000
  CLIP_MAX = 1000
  # Length the first array axis of every volume is resampled to.
  TARGET_DEPTH = 48

  def __init__(self, CT_ID, folder_path):
    self.CT_ID = CT_ID
    self.folder_path = folder_path

  @cached_property
  def volume(self):
    """Load, clip and resample the scan and its mask.

    The result is cached on the instance after the first access.

    Returns:
      Tuple ``(ct, mask)`` of float32 tensors. For 3-D input arrays, axis 0 has
      length ``TARGET_DEPTH`` and the two remaining axes keep their original
      size.

    Raises:
      Whatever ``sitk.ReadImage`` raises when a selected file is missing or
      unreadable.
    """
    import os  # local import so this block does not depend on the import cell

    base_path = self.folder_path + self.CT_ID

    # Choose the scan file explicitly. The previous try/except fallback caught
    # every exception, so a corrupt '<id>.nii' was silently retried as
    # '<id>_ct.nii' and the real error was replaced by a misleading one.
    ct_path = base_path + '.nii'
    if not os.path.isfile(ct_path):
      ct_path = base_path + '_ct.nii'

    ct = sitk.ReadImage(ct_path)
    mask = sitk.ReadImage(base_path + '_seg.nii')

    ct_np = sitk.GetArrayFromImage(ct)
    mask_np = sitk.GetArrayFromImage(mask)

    ct_np = np.clip(ct_np, self.CLIP_MIN, self.CLIP_MAX)
    # NOTE(review): the mask goes through the same clip as the scan; for a 0/1
    # label map this is presumably a no-op - kept so behaviour is unchanged.
    mask_np = np.clip(mask_np, self.CLIP_MIN, self.CLIP_MAX)

    # Move axis 0 to the end: with a 3-D input, interpolate() treats the tensor
    # as (N, C, L) and only resizes the last axis.
    ct_tr = torch.from_numpy(ct_np).to(dtype=torch.float32).permute(1,2,0)
    mask_tr = torch.from_numpy(mask_np).to(dtype=torch.float32).permute(1,2,0)

    # Resize with interpolate()'s default mode, then restore the axis order.
    ct_tr = torch.nn.functional.interpolate(ct_tr, size=self.TARGET_DEPTH).permute(2,0,1)
    mask_tr = torch.nn.functional.interpolate(mask_tr, size=self.TARGET_DEPTH).permute(2,0,1)

    return (ct_tr, mask_tr)

In [4]:
class CTDataset(Dataset):
  """Map-style dataset that yields one ``(ct, mask)`` pair per scan identifier.

  Args:
    namelist: sequence of scan identifiers; its length is the dataset length.
    folder_path: directory prefix passed on to ``CT`` for every identifier.
    transforms: optional sequence of callables. When non-empty, one of them is
      picked at random for each item and applied to the stacked scan and mask.
  """

  def __init__(self, namelist, folder_path, transforms=None):
    self.namelist = namelist
    self.folder_path = folder_path
    self.transforms = transforms

  def __len__(self):
    n_scans = len(self.namelist)
    return n_scans

  def __getitem__(self, idx):
    """Load item ``idx`` and, if augmentations are configured, apply one."""
    scan = CT(self.namelist[idx], self.folder_path)
    ct, mask = scan.volume

    # Nothing to augment with (None or an empty sequence): return the raw pair.
    if not self.transforms:
      return ct, mask

    # Here `transforms` is the torchvision module, not the constructor argument
    # (which only shadows it inside __init__).
    picker = transforms.RandomChoice(self.transforms)

    # Scan and mask travel through the chosen transform as one tensor with a
    # new leading axis of size 2, then are split apart again along that axis.
    stacked = torch.stack((ct, mask), dim=0)
    ct, mask = picker(stacked)
    return ct, mask

## Model training

In [5]:
class ModelEvaluation:
  """Train and evaluate a volumetric segmentation model with TensorBoard logging.

  Uses the notebook-level ``device``, ``writer`` and ``logger`` objects.

  Args:
    model: network to optimise. Its output is treated as raw logits, because
      the loss is built with ``sigmoid=True``.
    loader_dict: mapping from split name to an iterable of ``(ct, mask)``
      batches. ``train`` reads the ``'train'`` entry; ``evaluate`` reads the
      entry named by its ``mode`` argument.
  """

  def __init__(self, model, loader_dict):
    self.model = model
    self.loader_dict = loader_dict

    self.loss_fn = GeneralizedDiceLoss(sigmoid=True)
    self.optimizer = torch.optim.Adam(self.model.parameters(), weight_decay=0.001)

  def train(self, epochs=1, batch_size=1):
    """Run ``epochs`` passes over ``loader_dict['train']``.

    Args:
      epochs: number of passes over the training loader.
      batch_size: weight given to each batch in the per-epoch running averages.
    """
    gc.collect()
    torch.cuda.empty_cache()

    # Counts batches across epochs. The TensorBoard step used to be the
    # per-epoch batch index, so every epoch overwrote the previous curve.
    global_step = 0

    self.model.train()
    for epoch in range(epochs):
      loss_cumul = CumulativeAverage()
      iou_cumul = CumulativeAverage()
      for batch, (ct, mask) in enumerate(self.loader_dict['train']):
        # Insert the channel axis the network expects.
        ct = ct.unsqueeze(1).to(device)
        mask = mask.unsqueeze(1).to(device)
        model_out = self.model(ct)
        loss = self.loss_fn(model_out.permute(0,1,3,4,2), mask.permute(0,1,3,4,2)) # temp

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        loss_value = loss.item()
        # Metrics and plots threshold probabilities, so turn the logits into
        # probabilities once, outside the autograd graph.
        probs = torch.sigmoid(model_out.detach())

        gdl, iou = self.compute_metrics(probs, mask)
        loss_cumul.append(loss_value, count=batch_size)
        iou_cumul.append(iou, count=batch_size)

        if batch == 0: # temporary
          logger.debug(
            'DEBUG| location: ModelEvalutation.train | '
            f'min: {float(torch.min(model_out))}, max: {float(torch.max(model_out))}'
          )

        if batch % 10 == 0:
          writer.add_scalar('Loss/train', loss_value, global_step)
          writer.add_scalar('GDL/train', gdl, global_step)
          # The ground-truth mask is plotted as is; it used to be passed
          # through a sigmoid first, which distorted its values.
          plot_2d_or_3d_image((probs > 0.5).cpu(), step=global_step, writer=writer, tag='model_out') # !
          plot_2d_or_3d_image(mask.detach().cpu(), step=global_step, writer=writer, tag='mask')
          logger.debug(f'DEBUG| location: ModelEvalutation.train | epoch: {epoch}, batch: {batch}, loss: {loss_value}, IOU: {iou}')

        global_step += 1

      loss_avg = loss_cumul.aggregate()
      iou_avg = iou_cumul.aggregate()
      logger.debug(f'DEBUG| location: ModelEvalutation.train | loss_avg: {loss_avg}, iou_avg: {iou_avg}')

    writer.flush()

  def evaluate(self, mode='val', batch_size=1):
    """Compute the average loss over ``loader_dict[mode]`` without gradients.

    Every tenth batch also logs the loss, the generalized Dice score and image
    summaries to TensorBoard. The average loss is printed.

    Args:
      mode: key of the loader to iterate, also used in the scalar tag names.
      batch_size: weight given to each batch in the running average.
    """
    gc.collect()
    torch.cuda.empty_cache()

    run_avg = CumulativeAverage()

    self.model.eval()
    with torch.no_grad():
      for batch, (ct, mask) in enumerate(self.loader_dict[mode]):
        ct = ct.unsqueeze(1).to(device, dtype=torch.float32)
        mask = mask.unsqueeze(1).to(device, dtype=torch.float32)
        model_out = self.model(ct)
        loss_value = self.loss_fn(model_out, mask).item()
        run_avg.append(loss_value, count=batch_size)

        if batch % 10 == 0:
          # Same probability threshold as in train(); the raw logits used to
          # be compared against 0.5 here.
          probs = torch.sigmoid(model_out)
          gdl, iou = self.compute_metrics(probs, mask)
          writer.add_scalar(f'Loss/{mode}', loss_value, batch)
          writer.add_scalar(f'GDL/{mode}', gdl, batch)
          plot_2d_or_3d_image((probs > 0.5), step=batch, writer=writer, tag='model_out')
          plot_2d_or_3d_image(mask, step=batch, writer=writer, tag='mask')

    avg_loss = run_avg.aggregate()
    print('avg_loss:', avg_loss)

    writer.flush()

  def to_monai_shape(self, y_pred, y):
    """Binarise prediction and target and move axis 2 to the end.

    Both inputs must be 5-D tensors. Neither input is modified.

    Args:
      y_pred: prediction scores; values above 0.5 count as foreground.
      y: target mask; values above 0.5 count as foreground.

    Returns:
      Tuple ``(y_pred, y)`` of boolean tensors with axes ordered
      ``(0, 1, 3, 4, 2)`` relative to the inputs. Element ``[0, 0, 0, 0, 0]``
      of both is always ``True``.
    """
    # The comparison allocates new boolean tensors, so the writes below no
    # longer reach the caller's tensors. The old code wrote 1 into the model
    # output and the mask themselves before they were logged and plotted.
    y_pred = y_pred.permute(0,1,3,4,2) > 0.5
    y = y.permute(0,1,3,4,2) > 0.5

    # Force one shared foreground voxel (presumably so the overlap metrics
    # stay defined when a mask is empty - confirm).
    y_pred[0, 0, 0, 0, 0] = True
    y[0, 0, 0, 0, 0] = True
    return (y_pred, y)

  def compute_metrics(self, y_pred, y):
    """Return ``(generalized_dice, iou)`` as Python floats.

    Inputs are binarised at 0.5 by ``to_monai_shape``, so ``y_pred`` should
    hold probabilities rather than logits.
    """
    y_pred, y = self.to_monai_shape(y_pred, y)
    GDL = compute_generalized_dice(y_pred, y)
    IOU = compute_iou(y_pred, y)

    # Reduce before converting: float() on a tensor with more than one element
    # (batch size > 1) raises. For a single element the value is unchanged.
    return (float(torch.nanmean(GDL)), float(torch.nanmean(IOU)))

## Execution

In [6]:
# Alternative backbone, kept for reference:
# model = VNet(in_channels=1, out_channels=1).to(device)

# The leading 48 in img_size equals the length CT.volume resamples axis 0 to;
# the 512x512 remainder is assumed to match the scans - confirm against the data.
model = UNETR(
  in_channels=1,
  out_channels=1,
  img_size=(48, 512, 512),
  dropout_rate=0.6,
)
model = model.to(device)

In [7]:
# Pool of candidate augmentations handed to the training dataset.
augmentations = [
  transforms.RandomRotation(degrees=180),
  transforms.RandomAffine(degrees=180),
  transforms.ElasticTransform(),
  transforms.RandomHorizontalFlip(p=0.5),
  transforms.RandomVerticalFlip(p=0.5),
  transforms.GaussianBlur(kernel_size=1),
]

In [8]:
batch_size = 1

# Only the training split receives augmentations. Train and validation names
# both resolve against /content/Train/, the test names against /content/Validation/.
trainset = CTDataset(train_namelist, '/content/Train/', transforms=augmentations)
valset = CTDataset(val_namelist, '/content/Train/')
testset = CTDataset(test_namelist, '/content/Validation/')

# Train and validation loaders shuffle; the test loader keeps dataset order.
trainloader = DataLoader(trainset, shuffle=True, batch_size=batch_size)
valloader = DataLoader(valset, shuffle=True, batch_size=batch_size)
testloader = DataLoader(testset, batch_size=batch_size)

# Split name -> loader, in the form ModelEvaluation expects.
loader_dict = dict(train=trainloader, val=valloader, test=testloader)

In [9]:
# Bind the model to its data loaders, then run three training epochs.
evaluate = ModelEvaluation(model=model, loader_dict=loader_dict)
evaluate.train(batch_size=batch_size, epochs=3)

KeyboardInterrupt: ignored