# Applying Vess2Image on COVID-19 CT

Yiheng Zhou (yz996) | Eva Gao (eyg2) | Qiuyu Zhu (qz258) 

*This notebook applies the Vess2Image model to a novel dataset.*

In [None]:
# Mount Google Drive at /content/gdrive; the data paths defined later in this
# notebook all live under that mount point.
from google.colab import drive
drive.mount('/content/gdrive') 

Mounted at /content/gdrive


## Data Preprocessing

In [None]:
!pip install pytorch-lightning

In [None]:
import torch.nn as nn 
import nibabel as nb
import os
import numpy as np
import matplotlib.pyplot as plt 


from scipy import misc 
from IPython.core.displayhook import Float
from typing_extensions import dataclass_transform
import pandas as pd

from torch.utils.data import Dataset, DataLoader

from glob import glob
from pathlib import Path

import matplotlib.pyplot as plt
import pytorch_lightning as pl
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.transforms.functional import center_crop
from torchvision.utils import make_grid
from tqdm.auto import tqdm


In [None]:
# Pick the GPU when PyTorch can see one and fall back to the CPU otherwise, so
# `device` always names hardware that exists on the current runtime.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Make float64 the default floating-point dtype, so newly created modules hold
# double-precision parameters (convert_data() also produces double tensors).
torch.set_default_dtype(torch.float64)

In [None]:
# Folder on the mounted Google Drive that holds the NIfTI (.nii) input files.
intermediate_data = "/content/gdrive/MyDrive/Attention U-Net/Intermediate Data/"

# All three files live in that folder, so each path is built from the shared
# prefix instead of repeating the directory literal.
og_imgs = intermediate_data + "train.nii"                        # training image volume
binary_train_mask = intermediate_data + "binary_train_mask.nii"  # mask volume used together with og_imgs
binary_test_mask = intermediate_data + "binary_test_mask.nii"    # second mask volume (not used in the cells shown here)

## Dataloader

### Define data loading/visualization helper functions and create dataloader 

*The data are chest CT images and masks from the MedSeg COVID-19 CT segmentation dataset [3].*

In [None]:
# Data processing
def load_nii(path):
  '''
  Opens the NIfTI file at `path` with nibabel and returns its voxel data,
  i.e. the result of calling get_fdata() on the loaded image.
  '''
  return nb.load(path).get_fdata()

def convert_data(data, mask=False, max_slices=30):
  '''
  Converts a 3-axis volume into a stack of single-channel 2-D slices.

  Parameters:
      data: array-like with three axes; the LAST axis indexes the slices
      mask: accepted so existing callers keep working; it has no effect on
          the result
      max_slices: keep only the first `max_slices` slices. Defaults to 30,
          the value that was previously hard-coded. Pass None to keep all.
  Returns: a float64 tensor of shape (slices, 1, dim0, dim1), where
      result[k, 0] is data[:, :, k]
  '''
  # torch.tensor copies, so the result never aliases the caller's array.
  volume = torch.tensor(np.asarray(data), dtype=torch.float64)
  # (dim0, dim1, slices) -> (1, dim0, dim1, slices) -> (slices, 1, dim0, dim1)
  slices = volume.unsqueeze(0).permute(3, 0, 1, 2)
  if max_slices is None:
    return slices
  return slices[:max_slices]

# img_path, mask_path = FULL path to nii file 
class CustomImageDataset(Dataset):
    '''
    Dataset of paired (image slice, mask slice) tensors read from two NIfTI files.

    Parameters:
        img_path: full path to the .nii file holding the images
        mask_path: full path to the .nii file holding the masks

    Each volume is read with load_nii() and converted with convert_data() the
    first time it is needed, then kept in memory. Calling len() or indexing
    therefore no longer re-reads the whole file from disk on every call.
    '''
    def __init__(self, img_path, mask_path):
        self.img = img_path
        self.mask = mask_path 
        # Converted volumes, filled in lazily by _images() / _masks().
        self._img_cache = None
        self._mask_cache = None

    def _images(self):
        '''Returns the converted image volume, loading it on first use.'''
        if self._img_cache is None:
            self._img_cache = convert_data(load_nii(self.img))
        return self._img_cache

    def _masks(self):
        '''Returns the converted mask volume, loading it on first use.'''
        if self._mask_cache is None:
            self._mask_cache = convert_data(load_nii(self.mask), True)
        return self._mask_cache

    def __len__(self):
        '''Number of slices in the converted image volume.'''
        return self._images().shape[0]

    def __getitem__(self, idx):
        '''
        Returns the pair (image slice, mask slice) at position idx;
        each slice keeps its leading channel axis.
        '''
        img_slice = self._images()[idx, :, :, :]
        mask_slice = self._masks()[idx, :, :, :]
        return img_slice, mask_slice 

# Data loader - batch size = 1 
# The dataset pairs the image volume with the binary training-mask volume.
# Every batch holds one (image slice, mask slice) pair, and the order is
# reshuffled on each pass through the loader.
from torch.utils.data import DataLoader
train_data = CustomImageDataset(og_imgs, binary_train_mask)
train_dataloader = DataLoader(train_data, batch_size=1, shuffle=True) 

## Model Architecture 

*Vess2Image is a modified version of the Pix2Pix conditional GAN [1], which uses a U-Net-based architecture as the generator and a convolutional PatchGAN classifier as the discriminator.*

### U-Net based generator

*The U-Net architecture consists of a contracting path (**DownSampleConv**) and an expanding path (**UpSampleConv**).*

In [None]:
class DownSampleConv(nn.Module):
    '''
    Downsampling block: Conv2d, then an optional BatchNorm2d, then a
    LeakyReLU with negative slope 0.2.
    '''

    def __init__(self, in_channels, out_channels, kernel=4, stride=2, pad=1, batchnorm=True):
        super().__init__()
        self.batchnorm = batchnorm
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel,
            stride=stride,
            padding=pad,
        )
        # The norm layer only exists when it was requested.
        if self.batchnorm:
            self.bn = nn.BatchNorm2d(out_channels)
        self.act = nn.LeakyReLU(negative_slope=0.2)

    def forward(self, x):
        '''
        Runs x through the convolution, the batch norm (when enabled)
        and the activation, and returns the result.
        '''
        features = self.conv(x)
        if self.batchnorm:
            features = self.bn(features)
        return self.act(features)
class UpSampleConv(nn.Module):
    '''
    Upsampling block: ConvTranspose2d, then an optional BatchNorm2d, then an
    optional Dropout2d with p=0.5. No activation function is applied.
    '''
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel=4,
        strides=2,
        padding=1,
        batchnorm=True,
        dropout=False
    ):
        super().__init__()
        self.batchnorm = batchnorm
        self.dropout = dropout

        self.deconv = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size=kernel,
            stride=strides,
            padding=padding,
        )

        # Optional layers are only created when their flag is set.
        if self.batchnorm:
            self.bn = nn.BatchNorm2d(out_channels)
        if self.dropout:
            self.drop = nn.Dropout2d(p=0.5)

    def forward(self, x):
        '''
        Runs x through the transposed convolution, then through batch norm
        and dropout when they are enabled, and returns the result.
        '''
        upsampled = self.deconv(x)
        if self.batchnorm:
            upsampled = self.bn(upsampled)
        if self.dropout:
            upsampled = self.drop(upsampled)
        return upsampled

In [None]:
class Generator(nn.Module):

    '''
    U-Net generator.

    Eight DownSampleConv encoders (enc1..enc8) are followed by seven
    UpSampleConv decoders (dec1..dec7), a final ConvTranspose2d and a Tanh.
    The outputs of enc7..enc2 are concatenated, channel-wise, onto the outputs
    of dec1..dec6 respectively. The output of enc1 is not used as a skip
    connection.
    NOTE(review): with eight stride-2 encoders the input height/width
    presumably has to be a multiple of 256 - TODO confirm.
    '''
    def __init__(self, in_channels, out_channels): 
        super().__init__()
        # (input channels, output channels, use batch norm) for enc1 .. enc8
        encoder_plan = [
            (in_channels, 64, False),
            (64, 128, True),
            (128, 256, True),
            (256, 512, True),
            (512, 512, True),
            (512, 512, True),
            (512, 512, True),
            (512, 512, False),
        ]
        for position, (n_in, n_out, use_bn) in enumerate(encoder_plan, start=1):
            setattr(self, f'enc{position}', DownSampleConv(n_in, n_out, batchnorm=use_bn))

        # (input channels, output channels, use dropout) for dec1 .. dec7.
        # From dec2 on, the input width is the previous decoder's output
        # plus the channels of the skip tensor concatenated onto it.
        decoder_plan = [
            (512, 512, True),
            (1024, 512, True),
            (1024, 512, True),
            (1024, 512, False),
            (1024, 256, False),
            (512, 128, False),
            (256, 64, False),
        ]
        for position, (n_in, n_out, use_dropout) in enumerate(decoder_plan, start=1):
            setattr(self, f'dec{position}', UpSampleConv(n_in, n_out, dropout=use_dropout))

        self.final_conv = nn.ConvTranspose2d(64, out_channels, kernel_size=4, stride=2, padding=1)
        self.tanh = nn.Tanh() 

    def forward(self, x):
        '''
        Parameters: x an input image tensor
        Returns: the decoded image tensor, passed through tanh
        '''
        encoders = [getattr(self, f'enc{position}') for position in range(1, 9)]
        skip_decoders = [getattr(self, f'dec{position}') for position in range(1, 7)]

        # Encode, remembering every intermediate output.
        encoded = []
        features = x
        for encoder in encoders:
            features = encoder(features)
            encoded.append(features)

        # The enc8 output is the decoder input, not a skip connection.
        encoded.pop()

        # dec1..dec6: upsample, then concatenate the matching encoder output
        # (enc7 first, enc2 last) along the channel axis.
        for decoder in skip_decoders:
            features = torch.cat((decoder(features), encoded.pop()), dim=1)

        # Last decoder and output head; no skip connection here.
        features = self.dec7(features)
        features = self.final_conv(features)
        return self.tanh(features)

### The PatchGAN discriminator

In [None]:
class PatchGAN(nn.Module):
    '''
    Convolutional PatchGAN discriminator: four successive DownSampleConv
    blocks followed by a 1x1 convolution with a single output channel.
    '''
    def __init__(self, input_channels):
        super().__init__()
        # The first block has no batch norm; the remaining three use it.
        self.d1 = DownSampleConv(input_channels, 64, batchnorm=False)
        self.d2 = DownSampleConv(64, 128)
        self.d3 = DownSampleConv(128, 256)
        self.d4 = DownSampleConv(256, 512)
        self.final = nn.Conv2d(512, 1, kernel_size=1)

    def forward(self, x, y):
        '''
        Parameters: 
            x: the image to judge (generated or real)
            y: the conditioned image
        Returns: a tensor of logits
        '''
        # Image and condition are stacked along the channel axis.
        joint = torch.cat([x, y], dim=1)
        for stage in (self.d1, self.d2, self.d3, self.d4):
            joint = stage(joint)
        return self.final(joint)

### Helper Functions: Weight initialization and Output Visualization 

In [None]:
def weights_init(m):
    '''
    Initializes a single module in place; intended to be passed to
    nn.Module.apply so that it visits every submodule.

    Conv2d / ConvTranspose2d: weight drawn from a normal distribution with
        mean 0 and std 0.02 (biases are left untouched).
    BatchNorm2d: scale (weight) drawn from a normal distribution with mean 1
        and std 0.02, shift (bias) set to 0.
    Any other module type is left unchanged.
    '''
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    elif isinstance(m, nn.BatchNorm2d):
        # The scale must be centred on 1, not 0. A scale near zero multiplies
        # every normalized activation by ~0, wiping out the signal at start.
        torch.nn.init.normal_(m.weight, 1.0, 0.02)
        torch.nn.init.constant_(m.bias, 0)
        
def display_progress(cond, real, fake, figsize=(10,5)):
    '''
    Shows three images side by side, left to right: the conditioning mask,
    the real image and the synthetic image.

    Each argument is one 3-axis image tensor. It is detached, moved to the
    CPU and permuted so its first axis becomes the last before being drawn
    with a gray colormap.
    '''
    panels = [tensor.detach().cpu().permute(1, 2, 0) for tensor in (cond, real, fake)]

    fig, axes = plt.subplots(1, 3, figsize=figsize)
    for axis, panel in zip(axes, panels):
        axis.imshow(panel, cmap = 'gray')

    plt.show()

### Vess2Image

In [None]:
# Per-step loss history, appended to by Pix2Pix.training_step. Entries are
# detached CPU tensors (one scalar per training step). The lists are
# module-level, so running training again keeps appending to them; re-run this
# cell first to start from empty lists.
d_loss = []
g_loss = [] 
class Pix2Pix(pl.LightningModule): 
    '''
    Vess2Image framework implemented with PyTorch Lightning.

    A U-Net Generator is trained against a PatchGAN discriminator using
    manual optimization (both optimizers are stepped inside training_step).

    Parameters:
        in_channels: channels of the conditioning image fed to the generator
        out_channels: channels of the image the generator produces
        learning_rate: learning rate shared by both Adam optimizers
        lambda_recon: weight of the L1 reconstruction term in the generator loss
        display_step: a preview figure is shown on the first batch of every
            epoch whose index is a multiple of this value
    '''
    def __init__(self, in_channels, out_channels, learning_rate=0.00002, lambda_recon=160, display_step=5):
        super().__init__()
        # Stores the constructor arguments in self.hparams.
        self.save_hyperparameters()
        # Both optimizers are driven by hand in training_step.
        self.automatic_optimization = False
        self.display_step = display_step
        self.generator = Generator(in_channels, out_channels)
        # The discriminator sees the image and its condition stacked channel-wise.
        self.patchgan = PatchGAN(in_channels + out_channels)
        self.generator = self.generator.apply(weights_init)
        self.patchgan = self.patchgan.apply(weights_init) 
        self.adversarial_criterion = nn.BCEWithLogitsLoss()
        self.recon_criterion = nn.L1Loss()

    def _gen_step(self, real_images, conditioned_images):
        '''
        Generator loss: adversarial loss (the discriminator should label the
        fakes as real) plus lambda_recon times the L1 distance between the
        fake and real images.
        '''
        fake_images = self.generator(conditioned_images)
        dis_logits = self.patchgan(fake_images, conditioned_images)
        adversarial_loss = self.adversarial_criterion(dis_logits, torch.ones_like(dis_logits))
        recon_loss = self.recon_criterion(fake_images, real_images)
        lambda_recon = self.hparams.lambda_recon

        return adversarial_loss + (lambda_recon * recon_loss)

    def _dis_step(self, real_images, conditioned_images):
        '''
        Discriminator loss: mean of the loss on real pairs (target 1) and on
        generated pairs (target 0).
        '''
        # Detached so this loss sends no gradient into the generator.
        fake_images = self.generator(conditioned_images).detach()
        fake_logits = self.patchgan(fake_images, conditioned_images)

        real_logits = self.patchgan(real_images, conditioned_images)

        fake_loss = self.adversarial_criterion(fake_logits, torch.zeros_like(fake_logits))
        real_loss = self.adversarial_criterion(real_logits, torch.ones_like(real_logits))
        return (real_loss + fake_loss) / 2

    def configure_optimizers(self):
        '''
        Returns [discriminator optimizer, generator optimizer]; training_step
        unpacks them in this order.
        '''
        lr = self.hparams.learning_rate
        # NOTE(review): the Pix2Pix paper [1] uses betas=(0.5, 0.999); confirm
        # that (0.5, 0.5) is an intentional change.
        gen_opt = torch.optim.Adam(self.generator.parameters(), lr=lr,betas = (0.5,0.5) )
        dis_opt = torch.optim.Adam(self.patchgan.parameters(), lr=lr,betas  = (0.5,0.5))
        return [dis_opt, gen_opt]

    def training_step(self, batch, batch_idx):
        '''
        Runs one discriminator update followed by one generator update on the
        batch, logs both losses and records them in d_loss / g_loss.

        Parameters:
            batch: a pair (real images, conditioning images)
            batch_idx: index of the batch within the current epoch
        '''
        real, condition = batch

        # manual optimization
        dis_opt, gen_opt = self.optimizers()

        # Train discriminator
        dis_loss = self._dis_step(real, condition)

        self.log('PatchGAN Loss', dis_loss)
        dis_opt.zero_grad()
        # manual_backward routes the backward pass through Lightning, so any
        # precision / strategy handling configured on the Trainer is applied.
        self.manual_backward(dis_loss)
        dis_opt.step()

        # Train generator
        gen_loss = self._gen_step(real, condition)

        self.log('Generator Loss', gen_loss)
        gen_opt.zero_grad()
        self.manual_backward(gen_loss)
        gen_opt.step()
        
        # display ground-truth mask, real image, and synthetic image
        if self.current_epoch%self.display_step==0 and batch_idx==0:
            # Preview only: no autograd graph is needed for this forward pass.
            with torch.no_grad():
                fake = self.generator(condition)
            display_progress(condition[0], real[0], fake[0])
        
        # Store detached CPU copies so the history holds neither the autograd
        # history nor device memory for every step of training.
        d_loss.append(dis_loss.detach().cpu())
        g_loss.append(gen_loss.detach().cpu()) 

## Training

In [None]:
# One input channel (the conditioning image) and one output channel (the generated image).
pix2pix = Pix2Pix(1, 1)
# Train for at most 30 epochs; every other Trainer option keeps its default.
trainer = pl.Trainer(max_epochs=30)
trainer.fit(pix2pix, train_dataloader)

## Citation

[1] Isola, P., Zhu, J., Zhou, T., & Efros, A.A. (2018). Image-to-image translation with conditional adversarial networks. arXiv, accessed at: https://arxiv.org/abs/1611.07004

[2] Maurya, A. (2021). Pix2Pix-Image to image translation with conditional Adversarial Networks, accessed at: https://librecv.github.io/blog/gans/pytorch/2021/02/13/Pix2Pix-explained-with-code.html

[3] MedSeg. (2020) COVID-19 CT segmentation dataset. *MedSeg: COVID-19*, accessed at: http://medicalsegmentation.com/covid19/ 