# Applying VessGAN on COVID-19 CT

Yiheng Zhou (yz996) | Eva Gao (eyg2) | Qiuyu Zhu (qz258) 

*This notebook applies the VessGAN model to a novel dataset: COVID-19 chest CT scans.*

In [None]:
# Mount Google Drive (Colab only) so the data files referenced later under
# /content/gdrive/MyDrive are readable from this runtime.
from google.colab import drive

drive.mount(mountpoint='/content/gdrive')

## Data Preprocessing

In [None]:
import torch.nn as nn 
import nibabel as nb
import os
import numpy as np
import matplotlib.pyplot as plt 
import torch
import torch.nn as nn 
import skimage.io 
import skimage.filters 
from scipy.ndimage import gaussian_filter 
from scipy import misc 
from torchvision import transforms, datasets 
from IPython.core.displayhook import Float
from typing_extensions import dataclass_transform
import pandas as pd
from torchvision.io import read_image
from torch.utils.data import Dataset 
from torch import optim as optim 
import math
from torch.utils.data import DataLoader, Dataset 
from PIL import Image 
from google.colab.patches import cv2_imshow 
import torchvision.transforms as T 

## Preprocessing and defining a custom image dataset

In [None]:
# Prefer the GPU, but fall back to the CPU so the notebook still runs without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Tensors/layers created without an explicit dtype will default to float64.
torch.set_default_dtype(torch.float64)

In [None]:
# Folder on the mounted Google Drive holding the preprocessed NIfTI volumes.
# Every path below is derived from it, so relocating the data is a one-line change.
intermediate_data = "/content/gdrive/MyDrive/Attention U-Net/Intermediate Data/"

# Training CT volume plus the train/test mask volumes (binary, per the file names).
og_imgs = intermediate_data + "train.nii"
binary_train_mask = intermediate_data + "binary_train_mask.nii"
binary_test_mask = intermediate_data + "binary_test_mask.nii"

### Define data loading/visualization helper functions and create the dataloader

*The data are chest CT images and segmentation masks from the MedSeg COVID-19 dataset [2].*

In [None]:
def display_progress(real, fake, figsize=(10,5), titles=None):
    '''
    Show two single-channel images side by side in grayscale.

    Used both for an (image, mask) pair and for (ground-truth mask, generated mask).

    Parameters:
        real: left-hand image tensor, (C, H, W) -- the notebook passes
            single-channel (1, H, W) slices -- or a plain (H, W) tensor.
        fake: right-hand image tensor, same layout as `real`.
        figsize: matplotlib figure size in inches.
        titles: optional (left_title, right_title); no titles are drawn when None.
    '''
    def _to_hw(image):
        # Detach from autograd, move to CPU and go CHW -> HWC for imshow.
        # A trailing singleton channel is dropped explicitly so grayscale
        # rendering works on matplotlib versions that reject (H, W, 1).
        image = image.detach().cpu()
        if image.dim() == 3:
            image = image.permute(1, 2, 0)
            if image.shape[-1] == 1:
                image = image[..., 0]
        return image.numpy()

    fig, ax = plt.subplots(1, 2, figsize=figsize)
    for axis, image, title in zip(ax, (real, fake), titles or (None, None)):
        axis.imshow(_to_hw(image), cmap='gray')
        if title is not None:
            axis.set_title(title)

    plt.show()
    # Release the figure; this helper is called once per training epoch.
    plt.close(fig)

In [None]:
# Data processing
def load_nii(path):
  """Read a NIfTI file and return its voxel data.

  Parameters: path: full path to a .nii file.
  Returns: the floating-point array produced by nibabel's `get_fdata()`.
  """
  return nb.load(path).get_fdata()

def convert_data(data, mask=False, max_slices=30):
  """Convert a 3-D volume into a stack of single-channel 2-D slices.

  The last axis of `data` is treated as the slice axis, so an (H, W, D)
  volume becomes a float64 tensor of shape (min(D, max_slices), 1, H, W).

  Parameters:
    data: array-like 3-D volume (e.g. the output of `load_nii`).
    mask: unused; kept so existing `convert_data(mask, True)` calls still work.
    max_slices: keep only the first `max_slices` slices (default 30, the
      original hard-coded limit); pass None to keep every slice.
  Returns:
    A contiguous torch.float64 tensor of shape (N, 1, H, W).
  """
  volume = torch.tensor(np.asarray(data), dtype=torch.float64)
  # (H, W, D) -> (D, H, W) -> (D, 1, H, W): one channel-first image per slice.
  slices = volume.permute(2, 0, 1).unsqueeze(1)
  if max_slices is not None:
    slices = slices[:max_slices]
  # .contiguous() copies only the kept slices, so the full volume can be freed.
  return slices.contiguous()

# img_path, mask_path = FULL path to nii file
class CustomImageDataset(Dataset):
    '''
    Paired (image slice, mask slice) dataset backed by two NIfTI volumes.

    Each volume is read from disk and converted with `convert_data` the first
    time it is needed and then cached on the instance, instead of being
    re-read on every `__len__` / `__getitem__` call.

    Parameters:
        img_path: full path to the image .nii file.
        mask_path: full path to the mask .nii file (assumed to share the
            image's slice layout -- TODO confirm for new data).
    '''
    def __init__(self, img_path, mask_path):
        self.img = img_path
        self.mask = mask_path
        # Lazily-populated caches of the converted (N, 1, H, W) tensors.
        self._img_cache = None
        self._mask_cache = None

    def _image_slices(self):
        # Load and convert the image volume once.
        if self._img_cache is None:
            self._img_cache = convert_data(load_nii(self.img))
        return self._img_cache

    def _mask_slices(self):
        # Load and convert the mask volume once.
        if self._mask_cache is None:
            self._mask_cache = convert_data(load_nii(self.mask), True)
        return self._mask_cache

    def __len__(self):
        # Number of image slices kept by convert_data.
        return self._image_slices().shape[0]

    def __getitem__(self, idx):
        '''Return the idx-th (image, mask) pair, each shaped (1, H, W).'''
        return self._image_slices()[idx], self._mask_slices()[idx]

# Training set: CT slices paired with their (binary) mask slices, one pair per batch.
from torch.utils.data import DataLoader
train_data = CustomImageDataset(img_path=og_imgs, mask_path=binary_train_mask)
dataloader = DataLoader(dataset=train_data, batch_size=1, shuffle=True)

### Retrieve one sample from the dataset and visualize 

In [None]:
# Draw one shuffled (image, mask) batch and drop the batch axis before plotting.
img, mask = next(iter(dataloader))
display_progress(real=img.squeeze(0), fake=mask.squeeze(0))

## Model Architecture

*VessGAN is a modified version of the vanilla GAN model, adapted from class materials and from Goodfellow et al. (2014) [1].*

### Generator and Discriminator classes

In [None]:
# Create the Generator model class, which will be used to initialize the generator
class Generator(nn.Module):
    '''
    Fully-connected GAN generator mapping a noise vector to a flattened image.

    Architecture: three Linear + LeakyReLU(0.2) hidden layers
    (input_dim -> 64 -> 128 -> 256) followed by a Linear layer with a Tanh
    activation (256 -> output_dim), so outputs lie in [-1, 1].

    NOTE(review): the real masks compared against these outputs are presumably
    in [0, 1]; confirm the [-1, 1] Tanh range is intended.

    Parameters:
        input_dim: length of the input noise vector.
        output_dim: length of the generated (flattened) image.
    '''
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.hidden_layer1 = nn.Sequential(
            nn.Linear(input_dim, 64),
            nn.LeakyReLU(0.2)
        )
        self.hidden_layer2 = nn.Sequential(
            nn.Linear(64, 128),
            nn.LeakyReLU(0.2)
        )
        self.hidden_layer3 = nn.Sequential(
            nn.Linear(128, 256),
            nn.LeakyReLU(0.2)
        )
        self.hidden_layer4 = nn.Sequential(
            nn.Linear(256, output_dim),
            nn.Tanh()
        )

        # If the default dtype was switched to float64, the layers above are
        # created in float64; keep all weights in float32 to match `forward`.
        self.float()

    def forward(self, x):
        '''
        Parameters: x: an input tensor of noise, shaped (batch, input_dim).
        Returns: output: a float32 tensor of shape (batch, output_dim) with
            values in [-1, 1], on the same device as the model.
        '''
        # Follow the model's own device (CPU or GPU) and match its float32 weights.
        x = x.to(next(self.parameters()).device, dtype=torch.float32)
        output = self.hidden_layer1(x)
        output = self.hidden_layer2(output)
        output = self.hidden_layer3(output)
        output = self.hidden_layer4(output)
        return output

class Discriminator(nn.Module):
    '''
    Fully-connected GAN discriminator that scores flattened images.

    Architecture: three Linear + LeakyReLU(0.2) + Dropout(0.3) hidden layers
    (input_dim -> 256 -> 128 -> 64) followed by a Linear layer with a Sigmoid
    activation (64 -> output_dim) and no dropout.

    Parameters:
        input_dim: number of features per flattened image.
        output_dim: number of scores per image (default 1).
    '''
    def __init__(self, input_dim, output_dim=1):
        super().__init__()
        self.hidden_layer1 = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3)
        )

        self.hidden_layer2 = nn.Sequential(
            nn.Linear(256, 128),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3)
        )

        self.hidden_layer3 = nn.Sequential(
            nn.Linear(128, 64),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3)
        )

        self.hidden_layer4 = nn.Sequential(
            nn.Linear(64, output_dim),
            nn.Sigmoid()
        )

        # If the default dtype was switched to float64, the layers above are
        # created in float64; keep all weights in float32 to match `forward`.
        self.float()

    def forward(self, x):
        '''
        Parameters: x: real or generated images, either already flattened to
            (batch, input_dim) or shaped (batch, ...) with prod(...) == input_dim.
        Returns: output: sigmoid probabilities in [0, 1] of shape
            (batch, output_dim) -- probabilities, not logits, as nn.BCELoss expects.
        '''
        # Follow the model's own device (CPU or GPU) and match its float32 weights.
        x = x.to(next(self.parameters()).device, dtype=torch.float32)
        if x.dim() > 2:
            # Image-shaped input: flatten everything except the batch axis.
            x = x.flatten(1)
        output = self.hidden_layer1(x)
        output = self.hidden_layer2(output)
        output = self.hidden_layer3(output)
        output = self.hidden_layer4(output)
        return output


## Training

### Define training procedures for generator and discriminator 

In [None]:
# Training procedures
# Shared loss for both players: binary cross-entropy between the discriminator's
# probabilities and the 0/1 targets, averaged over all elements.
lossf = nn.BCELoss(reduction='mean')
def train_generator(batch_size, noise_dim=100):
    '''
    Run one generator update.

    Samples `batch_size` noise vectors, generates fake images and updates the
    generator so that the discriminator labels them as real (target 1).

    Parameters:
        batch_size: number of noise vectors / fake images in this step.
        noise_dim: length of each noise vector (default 100, the Generator
            input size used in this notebook).
    Returns:
        The batch-mean BCE loss as a detached 0-dim tensor.
    '''
    generator_optimizer.zero_grad()
    noise = torch.randn(batch_size, noise_dim).to(device)
    fake_img = generator(noise)
    dis_out = discriminator(fake_img)
    real_targets = torch.ones(batch_size, 1).to(device)

    loss = lossf(dis_out.float(), real_targets.float())

    # Gradients also reach the discriminator's parameters here; train_discriminator
    # clears them with zero_grad() before its own update.
    loss.backward()

    generator_optimizer.step()

    # BCELoss already averages over the batch, so no further division by
    # batch_size; detach so the returned value does not keep the graph alive.
    return loss.detach()

def train_discriminator(batch_size, images, noise_dim=100):
    '''
    Run one discriminator update on a batch of real images plus fresh fakes.

    Parameters:
        batch_size: number of fake images to generate for this step.
        images: batch of real images shaped (B, ...); flattened to (B, -1).
        noise_dim: length of each noise vector (default 100).
    Returns:
        The mean of the fake-batch and real-batch BCE losses, as a detached
        0-dim tensor.
    '''
    discriminator_optimizer.zero_grad()
    images = images.reshape(images.size(0), -1).to(device)

    # The fakes are only inputs here. Generating them without autograd skips
    # back-propagating through the generator, whose gradients would be
    # discarded by its next zero_grad() anyway.
    noise = torch.randn(batch_size, noise_dim).to(device)
    with torch.no_grad():
        fake_img = generator(noise)

    fake_out = discriminator(fake_img)
    true_out = discriminator(images)

    # Targets shaped like each output, so a real batch whose size differs from
    # batch_size (e.g. a short final batch) cannot cause a shape mismatch.
    fake_y = torch.zeros_like(fake_out)
    true_y = torch.ones_like(true_out)

    loss_fake = lossf(fake_out, fake_y)
    loss_true = lossf(true_out, true_y)

    loss = (loss_fake + loss_true) / 2

    loss.backward()

    discriminator_optimizer.step()

    # BCELoss already averages over the batch; detach so no graph is retained.
    return loss.detach()

### Define training parameters

In [None]:
# Hyper-parameters for the GAN training run below.
training_parameters = dict(
    img_size=512,                      # side length (pixels) of each square slice
    n_epochs=200,                      # passes over the training slices
    batch_size=1,
    learning_rate_generator=2e-5,      # Adam learning rate for the generator
    learning_rate_discriminator=1e-5,  # Adam learning rate for the discriminator
)

### Initialize generator and discriminator models and their optimizers 

In [None]:
# initialize models
# Both networks work on flattened img_size x img_size slices (512 * 512 = 262144
# features); the generator consumes 100-d noise, matching the training functions.
n_pixels = training_parameters['img_size'] ** 2
discriminator = Discriminator(n_pixels, 1).to(device)
generator = Generator(100, n_pixels).to(device)
# initialize optimizers (Adam, with per-model learning rates)
discriminator_optimizer = optim.Adam(discriminator.parameters(), lr=training_parameters['learning_rate_discriminator'])
generator_optimizer = optim.Adam(generator.parameters(), lr=training_parameters['learning_rate_generator'])

### Training loop and output visualization 

In [None]:
# Per-epoch mean losses (one entry per epoch), used by the loss table and plot below.
d_loss = []
g_loss = []
for epoch in range(training_parameters['n_epochs']):
    G_loss = []
    D_loss = []
    # `labels` are the ground-truth masks; the GAN learns to generate masks from
    # noise, so the CT images themselves are not used in this loop.
    for _imgs, labels in dataloader:
        # Use the actual batch size so noise/fake batches always match the real one.
        batch_size = labels.size(0)
        lossG = train_generator(batch_size)
        G_loss.append(lossG)
        # Two discriminator updates per generator update; only the second loss is recorded.
        lossD = train_discriminator(batch_size, labels)
        lossD = train_discriminator(batch_size, labels)
        D_loss.append(lossD)

    if not G_loss:
        continue  # empty dataloader: nothing to log or display

    # Log once per epoch over all batches, instead of at a hard-coded batch index
    # (19) that never triggers when there are fewer than 20 slices.
    # Detach and move to CPU so no autograd graph or GPU memory is retained.
    d_loss.append(torch.stack([x.detach().cpu().float() for x in D_loss]).mean())
    g_loss.append(torch.stack([x.detach().cpu().float() for x in G_loss]).mean())

    # Show one generated mask next to a ground-truth mask from the last batch.
    with torch.no_grad():
        noise = torch.randn(1, 100).to(device)
        fake = generator(noise).cpu().view(1, *labels.shape[-2:])
    display_progress(labels[0], fake)

### Training loss

In [None]:
import pandas as pd
# One row per logged epoch. Fresh arrays are built instead of overwriting the
# g_loss / d_loss histories, so this cell can be re-run safely.
df = pd.DataFrame({
    "G_Loss": torch.tensor(g_loss).detach().cpu().numpy(),
    "D_Loss": torch.tensor(d_loss).detach().cpu().numpy(),
})
df.index.name = "epoch"
df.head()

In [None]:
import seaborn as sns
# Per-epoch generator and discriminator losses (wide-form: one line per column of df).
line_loss = sns.lineplot(data=df)
line_loss.set(title="VessGAN training loss", xlabel="Epoch", ylabel="BCE loss");

In [None]:
# Release unused cached GPU memory held by PyTorch's caching allocator, to help
# avoid CUDA out-of-memory errors. Opt-in: set FREE_CUDA_CACHE = True and re-run.
FREE_CUDA_CACHE = False
if FREE_CUDA_CACHE and torch.cuda.is_available():
    torch.cuda.empty_cache()

## Citation

[1] Goodfellow, I., Pouget-Abadie, J., Mirza, M., Xu, B., Warde-Farley, D., Ozair, S., Courville, A., & Bengio, Y. (2020). Generative adversarial networks. *Communications of the ACM*, 63(11), 139–144 (originally presented at NIPS 2014). Accessed at:
https://dl.acm.org/doi/pdf/10.1145/3422622

[2] MedSeg. (2020) COVID-19 CT segmentation dataset. *MedSeg: COVID-19*, accessed at: http://medicalsegmentation.com/covid19/ 