## Data loading and preprocessing

In [14]:
# !pip install monai
# !pip install tensorboard-plugin-3d
# !pip install pynvml
# !pip install einops

# !unzip /content/drive/MyDrive/X-Ray_segmentation.zip

# %load_ext tensorboard
# %tensorboard --logdir runs

import os
import torch
import torchvision
import matplotlib.pyplot as plt
import logging
import einops

from torch import nn
from torchvision import transforms
from torchvision.io import read_image
from monai.networks.nets import UNet, BasicUNet, FlexibleUNet
from monai.networks.nets import VNet, UNETR, SwinUNETR
from monai.metrics import compute_dice, compute_iou, CumulativeAverage
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter('runs/swinunetr')

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

logger = logging.getLogger('Xray_logger')
logger.setLevel(logging.DEBUG)
file_log = logging.FileHandler('xray.log')
file_log.setLevel(logging.DEBUG)
logger.addHandler(file_log)
logger.propagate = False

In [2]:
def collect_addresses():
  ''' collect paths to images from given folders '''
  
  folder = '/content/X-Ray_segmentation/Covid-xray/'

  files_dict = {'Train': [], 'Test': [], 'Val': []}

  for eval_method in files_dict:
    link = folder + eval_method + '/'
    image_files = os.listdir(link + 'images/')

    for image_file in image_files:
      image = link + 'images/' + image_file
      mask = link + 'masks/' + image_file
      files_dict[eval_method].append((image, mask))
  
  return files_dict

In [3]:
class XRay:
  ''' load and preprocess X-Ray images '''

  def __init__(self, image_path, mask_path):
    self.image_path = image_path
    self.mask_path = mask_path
  
  def get_images(self):
    image, mask = read_image(self.image_path), read_image(self.mask_path)
    image, mask = self.to_standart_format(image, mask)

    return (image, mask)
  
  def to_standart_format(self, image, mask):
    ''' transform to standart image format '''
    
    if image.shape[1] > 256:
      resize = torchvision.transforms.Resize(256)
      image, mask = resize(image), resize(mask)

    elif image.shape[1] < 256:
      pad = torchvision.transforms.Pad(256-image.shape[1])
      image, mask = pad(image), pad(mask)
    
    image = image/255
    mask = mask/255

    if image.shape[0] > 1:
      image = image[0]
    
    return (image, mask)

In [4]:
class XRayDataset(Dataset):
  ''' create Dataset object from one of the sets
      params:
        data_dict: dictionary of filenames for each set
        mode: which set to use
        transforms: list of transforms to apply
  '''

  def __init__(self, data_dict, mode='Train', transforms=None):
    self.data_addresses = data_dict[mode]
    self.transforms = transforms

  def __len__(self):
    return len(self.data_addresses)
  
  def __getitem__(self, idx):
    image_path, mask_path = self.data_addresses[idx]
    image, mask = XRay(image_path, mask_path).get_images()

    if self.transforms:
      randtransf = transforms.RandomChoice(self.transforms) # pick random augmentation method
      # apply augmentation to image and mask simultaneously
      combined = torch.cat((image.unsqueeze(0), mask.unsqueeze(0)), 0)
      image, mask = randtransf(combined)

    return (image, mask)

## Model training

In [5]:
class ModelEvaluation:
  ''' class for training and evaluation of a given model
      params:
        model: an object of a model to evaluate
        loader_dict: dictionary of running modes and their DataLoader objects
  '''

  def __init__(self, model, loader_dict):
    self.model = model
    self.loader_dict = loader_dict

    weight = torch.tensor([4], device=device) # increase weight of positive instances
    self.loss_fn = nn.BCEWithLogitsLoss(pos_weight=weight)
    self.optimizer = torch.optim.Adam(self.model.parameters())
    self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, 'min')

  
  def train(self, epochs=1, batch_size=1):
    ''' run training loop '''

    self.model.train()
    torch.backends.cudnn.benchmark = True
    for epoch in range(epochs):
      loss_cumul = CumulativeAverage()
      iou_cumul = CumulativeAverage()
      dice_cumul = CumulativeAverage()
      recall_cumul = CumulativeAverage()
      precision_cumul = CumulativeAverage()
      f1_cumul = CumulativeAverage()
      for batch, (xray, mask) in enumerate(self.loader_dict['train']):
        xray = xray.to(device)
        mask = mask.to(device)

        self.optimizer.zero_grad()
        model_out = self.model(xray)
        loss = self.loss_fn(model_out, mask)

        loss.backward()
        self.optimizer.step()

        model_out = nn.Sigmoid()(model_out)
        dice, iou, recall, precision, f1 = self.compute_metrics(model_out, mask)
        loss_cumul.append(loss, count=batch_size)

        iou_cumul.append(iou)
        dice_cumul.append(dice)
        recall_cumul.append(recall)
        precision_cumul.append(precision)
        f1_cumul.append(f1)

        if batch % 10 == 0:
          writer.add_images('model_out/train', (model_out > 0.5), global_step=batch)
          writer.add_images('mask/train', mask, global_step=batch)
          logger.debug(f'DEBUG| location: ModelEvalutation.train | epoch: {epoch}, batch: {batch}, loss: {loss.item()}, IOU: {iou}, Dice: {dice}')
      
      loss_avg = loss_cumul.aggregate()
      iou_avg = iou_cumul.aggregate()
      dice_avg = dice_cumul.aggregate()
      recall_avg = recall_cumul.aggregate()
      precision_avg = precision_cumul.aggregate()
      f1_avg = f1_cumul.aggregate()

      writer.add_scalar('Loss_AVG/train', loss_avg, epoch)
      writer.add_scalar('IOU_AVG/train', iou_avg, epoch)
      writer.add_scalar('Dice_AVG/train', dice_avg, epoch)
      writer.add_scalar('Recall_AVG/train', recall_avg, epoch)
      writer.add_scalar('Precision_AVG/train', precision_avg, epoch)
      writer.add_scalar('F1_AVG/train', f1_avg, epoch)
  
      logger.debug(f'DEBUG| location: ModelEvalutation.train | loss_avg: {loss_avg}, iou_avg: {iou_avg}, , dice_avg: {dice_avg}')
      self.scheduler.step(loss)

    writer.flush()
  

  def evaluate(self, mode='val', batch_size=1):
    ''' run evaluation loop '''

    self.model.eval()
    loss_cumul = CumulativeAverage()
    iou_cumul = CumulativeAverage()
    dice_cumul = CumulativeAverage()
    recall_cumul = CumulativeAverage()
    precision_cumul = CumulativeAverage()
    f1_cumul = CumulativeAverage()
    with torch.no_grad():
      for batch, (xray, mask) in enumerate(self.loader_dict[mode]):
        xray = xray.to(device)
        mask = mask.to(device)
        model_out = self.model(xray)
        loss = self.loss_fn(model_out, mask)

        model_out = nn.Sigmoid()(model_out)
        dice, iou, recall, precision, f1 = self.compute_metrics(model_out, mask)
        loss_cumul.append(loss, count=batch_size)

        iou_cumul.append(iou)
        dice_cumul.append(dice)
        recall_cumul.append(recall)
        precision_cumul.append(precision)
        f1_cumul.append(f1)

        if batch % 10 == 0:
          writer.add_images(f'model_out/{mode}', (model_out > 0.5), global_step=batch)
          writer.add_images(f'mask/{mode}', mask, global_step=batch)
      
    loss_avg = loss_cumul.aggregate()
    iou_avg = iou_cumul.aggregate()
    dice_avg = dice_cumul.aggregate()
    recall_avg = recall_cumul.aggregate()
    precision_avg = precision_cumul.aggregate()
    f1_avg = f1_cumul.aggregate()
            
    print(f'loss_avg: {loss_avg}, iou_avg: {iou_avg}, dice_avg: {dice_avg}')
    print(f'recall_avg: {recall_avg}, precision_avg: {precision_avg}, f1_avg: {f1_avg}')

    writer.flush()
  
  def to_monai_form(self, y_pred, y):
    ''' transform xray and mask to binary tensor '''
    
    y_pred = y_pred > 0.5
    y = y > 0.5
    return (y_pred, y)

  def compute_metrics(self, y_pred, y):
    ''' compute Monai metrics '''

    y_pred, y = self.to_monai_form(y_pred, y)
    Dice = compute_dice(y_pred, y, ignore_empty=False).mean()
    IOU = compute_iou(y_pred, y, ignore_empty=False).mean()

    recall, precision, f1 = self.additional_metrics(y_pred, y)
    
    metrics = [float(Dice), float(IOU), float(recall),
               float(precision), float(f1)]

    return metrics
  
  def additional_metrics(self, y_pred, y):
    ''' compute additional metrics '''

    inter = (y_pred * y).sum(dim=[1,2,3])
    
    recall = (inter + 1)/(y.sum(dim=[1,2,3]) + 1)
    precision = (inter + 1)/(y_pred.sum(dim=[1,2,3]) + 1)
    f1 = 2*((precision*recall)/(precision+recall))

    return recall.mean(), precision.mean(), f1.mean()

## Execution

In [6]:
model = SwinUNETR(in_channels=1, out_channels=1, img_size=(256,256), drop_rate=0.5, spatial_dims=2, use_checkpoint=True).to(device)
# model = UNet(spatial_dims=2, in_channels=1, out_channels=1, channels=(4, 8, 16), strides=(2, 2)).to(device)
# model = VNet(in_channels=1, out_channels=1, spatial_dims=2).to(device)
# model = BasicUNet(spatial_dims=2, in_channels=1, out_channels=1, dropout=0.5).to(device)
# model = UNETR(in_channels=1, out_channels=1, img_size=(256,256), dropout_rate=0.5, spatial_dims=2).to(device)
# model = FlexibleUNet(in_channels=1, out_channels=1, backbone='efficientnet-b0', spatial_dims=2).to(device)

In [7]:
augmentations = [transforms.RandomRotation(180), transforms.RandomAffine(180),
                 transforms.ElasticTransform(), transforms.RandomHorizontalFlip(p=0.6),
                 transforms.RandomVerticalFlip(p=0.6), transforms.GaussianBlur(3)]

In [8]:
data_dict = collect_addresses()

trainset = XRayDataset(data_dict, 'Train')
valset = XRayDataset(data_dict, 'Val')
testset = XRayDataset(data_dict, 'Test')

batch_size = 32
num_workers = 4

trainloader = DataLoader(trainset, batch_size=batch_size,
                         shuffle=True, num_workers=num_workers, pin_memory=True)
valloader = DataLoader(valset, batch_size=batch_size, shuffle=True,
                       num_workers=num_workers, pin_memory=True)
testloader = DataLoader(testset, batch_size=batch_size,
                        num_workers=num_workers, pin_memory=True)

loader_dict = {'train': trainloader, 'val': valloader, 'test': testloader}

In [9]:
# load pretrained state
# model.load_state_dict(torch.load('/content/xray_model_3.pt'))

evaluate = ModelEvaluation(model, loader_dict)
evaluate.train(epochs=15, batch_size=batch_size)

In [10]:
evaluate.evaluate('val', batch_size)



loss_avg: 0.23560521006584167, iou_avg: 0.5917662382125854, dice_avg: 0.648518443107605
recall_avg: 0.8311728835105896, precision_avg: 0.7285675406455994, f1_avg: 0.6581042408943176


In [11]:
evaluate.evaluate('test', batch_size)

loss_avg: 0.25243592262268066, iou_avg: 0.600612461566925, dice_avg: 0.6573601961135864
recall_avg: 0.8254415392875671, precision_avg: 0.7375897765159607, f1_avg: 0.665016770362854


In [12]:
# torch.save(model.state_dict(), '/content/xray_model.pt')

## Inference

In [13]:
loaded_model = SwinUNETR(in_channels=1, out_channels=1, img_size=(256,256), spatial_dims=2).to(device)
loaded_model.load_state_dict(torch.load('/content/xray_model_3.pt'))
loaded_model.eval()

FileNotFoundError: ignored

In [None]:
def image_preprocessing(image_path):
  image = read_image(image_path)
  
  if image.shape[1] > 256:
    resize = torchvision.transforms.Resize(256)
    image = resize(image)

  elif image.shape[1] < 256:
    pad = torchvision.transforms.Pad(256-image.shape[1])
    image = pad(image)
  
  image = image/255

  if image.shape[0] > 1:
    image = image[0]
  
  return image

In [None]:
image_path = '/content/X-Ray_segmentation/Covid-xray/Test/images/covid_1702.png'
image = image_preprocessing(image_path)
model_out = loaded_model(image.unsqueeze(0).to(device))
model_out = model_out.detach().cpu()