## Data Loading and preprocessing

In [2]:
!pip install monai
!pip install einops
!pip install pynvml
!pip install tensorboard-plugin-customizable-plots

import os
import torch
import torchvision
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import logging
import einops
import random


from torch import nn
from torchvision import transforms
from torchvision.io import read_image
from monai.losses import DiceLoss
from monai.networks.nets import UNet, BasicUNet, FlexibleUNet
from monai.networks.nets import SegResNet, UNETR, SwinUNETR
from monai.metrics import compute_iou, compute_generalized_dice, CumulativeAverage
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter('runs/covid_segmentation')

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

logger = logging.getLogger('CT_logger')
logger.setLevel(logging.DEBUG)
file_log = logging.FileHandler('ct.log')
file_log.setLevel(logging.DEBUG)
logger.addHandler(file_log)
logger.propagate = False

In [3]:
def get_path_dict() -> list:
  ''' creates a list of image addresses 
      return:
        names_dict: dictionary, that contains
                    a list of image paths for each set
  '''
  
  names = os.listdir('/content/PNG_Covid/frames')
  random.shuffle(names)
  namelist = []
  names_dict = {}
  for name in names:
    image = '/content/PNG_Covid/frames/' + name
    mask = '/content/PNG_Covid/masks/' + name
    namelist.append((image, mask))

  # split into three subsets
  set_len = len(namelist)
  names_dict['Train'] = namelist[:int(set_len*0.6)]
  names_dict['Val'] = namelist[int(set_len*0.6):int(set_len*0.8)]
  names_dict['Test'] = namelist[int(set_len*0.8):]
  
  return names_dict

In [6]:
def get_set(namelist: list, augment=False) -> list:
  ''' load and preproccess images 
      params:
        namelist: list of imge paths
        augment: wether to augment dataset
      return:
        loaded: list of loaded images
  '''

  loaded = []
  for image, mask in namelist:
    image, mask = get_images(image, mask)

    if augment == True:
      loaded += augment_set(image, mask)
    else:
      loaded.append((image, mask))

  return loaded

def get_images(image_path: str, mask_path: str) -> tuple:
  ''' load image and mask '''

  image, mask = read_image(image_path), read_image(mask_path)
  image, mask = to_standart_format(image, mask)

  return (image, mask)

def to_standart_format(image: torch.tensor, mask: torch.tensor) -> tuple:
  ''' transform to standart image format '''
  
  resize = torchvision.transforms.Resize(224)
  image, mask = resize(image), resize(mask)
  
  image = image/255
  mask = mask/255

  return (image, mask)

def augment_set(image: torch.tensor, mask: torch.tensor) -> list:
  ''' augment dataset '''

  new_loaded_list = []
  augmentations = [transforms.RandomRotation(180), transforms.RandomAffine(180),
                   transforms.RandomHorizontalFlip(p=1), transforms.RandomVerticalFlip(p=1)]
  
  new_loaded_list.append((image, mask))
  combined = torch.cat((image.unsqueeze(0), mask.unsqueeze(0)), 0) # compbine to augment simultaneously

  for augment in augmentations:
    aug_image, aug_mask = augment(combined)
    new_loaded_list.append((aug_image, aug_mask))

  return new_loaded_list

In [7]:
class CTDataset(Dataset):
  ''' create dataset of ct images 
      args:
        path_dict: dictionary of sets and their list of paths
        mode: a mode of evaluation
        augment: wether to augment dataset
  '''

  def __init__(self, path_dict: dict, mode='Train', augment=False):
    self.namelist = get_set(path_dict[mode], augment)

  def __len__(self):
    return len(self.namelist)
  
  def __getitem__(self, idx: int):
    image, mask = self.namelist[idx]
    return (image, mask)

## Model training

In [8]:
class ModelEvaluation:
  ''' train and evaluate model
      args:
        model: model object
        loade_dict: dictionary of loaders for each set
  '''

  def __init__(self, model: nn.Module, loader_dict: dict):
    self.model = model
    self.loader_dict = loader_dict

    self.loss_fn = DiceLoss(sigmoid=True)
    self.optimizer = torch.optim.Adam(self.model.parameters())
    self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, 'min')

  
  def train(self, epochs=1, batch_size=1):
    self.model.train()
    torch.backends.cudnn.benchmark = True
    for epoch in range(epochs):
      loss_cumul = CumulativeAverage()
      iou_cumul = CumulativeAverage()
      gd_cumul = CumulativeAverage()
      recall_cumul = CumulativeAverage()
      precision_cumul = CumulativeAverage()
      f1_cumul = CumulativeAverage()
      for batch, (ct, mask) in enumerate(self.loader_dict['train']):
        ct = ct.to(device)
        mask = mask.to(device)

        self.optimizer.zero_grad()
        model_out = self.model(ct)
        loss = self.loss_fn(model_out, mask)

        loss.backward()
        self.optimizer.step()
      
        model_out = nn.Sigmoid()(model_out)
        gd, iou, recall, precision, f1 = self.compute_metrics(model_out, mask)
        loss_cumul.append(loss)
        iou_cumul.append(iou)
        gd_cumul.append(gd)
        recall_cumul.append(recall)
        precision_cumul.append(precision)
        f1_cumul.append(f1)

      loss_avg = loss_cumul.aggregate()
      iou_avg = iou_cumul.aggregate()
      gd_avg = gd_cumul.aggregate()
      recall_avg = recall_cumul.aggregate()
      precision_avg = precision_cumul.aggregate()
      f1_avg = f1_cumul.aggregate()

      writer.add_scalar('Loss_AVG/train', loss_avg, epoch)
      writer.add_scalar('IOU_AVG/train', iou_avg, epoch)
      writer.add_scalar('GD_AVG/train', gd_avg, epoch)
      writer.add_scalar('Recall_AVG/train', recall_avg, epoch)
      writer.add_scalar('Precision_AVG/train', precision_avg, epoch)
      writer.add_scalar('F1_AVG/train', f1_avg, epoch)
      logger.debug(f'DEBUG| location: ModelEvalutation.train | loss_avg: {loss_avg}, iou_avg: {iou_avg}')
      self.scheduler.step(loss)

    writer.flush()
  

  def evaluate(self, mode='val', batch_size=1):
    self.model.eval()
    loss_cumul = CumulativeAverage()
    iou_cumul = CumulativeAverage()
    gd_cumul = CumulativeAverage()
    recall_cumul = CumulativeAverage()
    precision_cumul = CumulativeAverage()
    f1_cumul = CumulativeAverage()
    with torch.no_grad():
      for batch, (ct, mask) in enumerate(self.loader_dict[mode]):
        ct = ct.to(device)
        mask = mask.to(device)
        model_out = self.model(ct)
        loss = self.loss_fn(model_out, mask)
        
        model_out = nn.Sigmoid()(model_out)
        gd, iou, recall, precision, f1 = self.compute_metrics(model_out, mask)
        loss_cumul.append(loss, count=batch_size)

        iou_cumul.append(iou)
        gd_cumul.append(gd)
        recall_cumul.append(recall)
        precision_cumul.append(precision)
        f1_cumul.append(f1)

        if batch % 50 == 0: 
           writer.add_images(f'model_out/{mode}', (model_out > 0.5), global_step=batch)
           writer.add_images(f'mask/{mode}', mask, global_step=batch)
    
    loss_avg = loss_cumul.aggregate()
    iou_avg = iou_cumul.aggregate()
    gd_avg = gd_cumul.aggregate()
    recall_avg = recall_cumul.aggregate()
    precision_avg = precision_cumul.aggregate()
    f1_avg = f1_cumul.aggregate()
            
    print(f'loss_avg: {loss_avg}, iou_avg: {iou_avg}, gd_avg: {gd_avg}')
    print(f'recall_avg: {recall_avg}, precision_avg: {precision_avg}, f1_avg: {f1_avg}')

    writer.flush()
  
  def to_monai_form(self, y_pred: torch.tensor, y: torch.tensor) -> tuple:
    ''' transform to monai-compatible form '''
    y_pred = y_pred > 0.5
    y = y > 0.5
    return (y_pred, y)

  def compute_metrics(self, y_pred: torch.tensor, y: torch.tensor) -> list:
    ''' compute Monai metrics '''

    y_pred, y = self.to_monai_form(y_pred, y)
    GD = compute_generalized_dice(y_pred, y).mean()
    IOU = compute_iou(y_pred, y, ignore_empty=False).mean()

    recall, precision, f1 = self.additional_metrics(y_pred, y)
    
    metrics = [float(GD), float(IOU), float(recall),
               float(precision), float(f1)]

    return metrics
  
  def additional_metrics(self, y_pred: torch.tensor, y: torch.tensor) -> tuple:
    ''' compute additional metrics '''

    inter = (y_pred * y).sum(dim=[1,2,3])
    
    recall = (inter + 1)/(y.sum(dim=[1,2,3]) + 1)
    precision = (inter + 1)/(y_pred.sum(dim=[1,2,3]) + 1)
    f1 = 2*((precision*recall)/(precision+recall))

    return recall.mean(), precision.mean(), f1.mean()

## Execution

In [9]:
swin = SwinUNETR(in_channels=3, out_channels=3, img_size=(224,224), drop_rate=0.5, spatial_dims=2, use_checkpoint=True).to(device)
unet = UNet(spatial_dims=2, in_channels=3, out_channels=3, channels=(4, 8, 16), strides=(2, 2)).to(device)
basic = BasicUNet(spatial_dims=2, in_channels=3, out_channels=3, dropout=0.5).to(device)
unetr = UNETR(in_channels=3, out_channels=3, img_size=(224, 224), dropout_rate=0.5, spatial_dims=2).to(device)
segresnet = SegResNet(in_channels=3, out_channels=3, dropout_prob=0.5, spatial_dims=2).to(device)
flexible = FlexibleUNet(in_channels=3, out_channels=3, backbone='efficientnet-b0', spatial_dims=2).to(device)

BasicUNet features: (32, 32, 64, 128, 256, 32).


In [None]:
model = swin # pick what model to use

In [None]:
path_dict = get_path_dict()

trainset = CTDataset(path_dict, 'Train', True)
valset = CTDataset(path_dict, 'Val')
testset = CTDataset(path_dict, 'Test')

batch_size = 32
num_workers = 4

trainloader = DataLoader(trainset, batch_size=batch_size,
                         shuffle=True, num_workers=num_workers, pin_memory=True)
valloader = DataLoader(valset, batch_size=batch_size,
                         shuffle=True, num_workers=num_workers, pin_memory=True)
testloader = DataLoader(testset, batch_size=batch_size,
                         shuffle=True, num_workers=num_workers, pin_memory=True)

loader_dict = {'train': trainloader, 'val': valloader, 'test': testloader}



In [None]:
evaluate = ModelEvaluation(model, loader_dict)
evaluate.train(epochs=30, batch_size=batch_size)

In [None]:
evaluate.evaluate('train', batch_size)

loss_avg: 0.22081801295280457, iou_avg: 0.6617366075515747, gd_avg: 0.7819307446479797
recall_avg: 0.8226538896560669, precision_avg: 0.7754347920417786, f1_avg: 0.7826881408691406


In [None]:
evaluate.evaluate('val', batch_size)

loss_avg: 0.24651551246643066, iou_avg: 0.6342942714691162, gd_avg: 0.7557283043861389
recall_avg: 0.7921702265739441, precision_avg: 0.7616923451423645, f1_avg: 0.7569184899330139


In [None]:
evaluate.evaluate('test', batch_size)

loss_avg: 0.24115262925624847, iou_avg: 0.640584409236908, gd_avg: 0.7609906196594238
recall_avg: 0.7951751351356506, precision_avg: 0.7647630572319031, f1_avg: 0.7618975043296814
