# Vess2Image 

Yiheng Zhou (yz996) | Eva Gao (eyg2) | Qiuyu Zhu (qz258) 

*Translating vessel masks into retinal images*

In [None]:
# Mount Google Drive at /content/drive (Colab only); the dataset paths used in
# this notebook all live under that mount point.
from google.colab import drive
drive.mount('/content/drive')

## Data Preprocessing

In [None]:
import cv2
import os
import numpy as np
from PIL import Image

import torch
import torch.nn as nn 
from torch.utils.data import DataLoader, Dataset 
import torchvision.transforms as T

In [None]:
# Every folder hangs off the same DRIVE dataset root on the mounted Drive.
_drive_root = '/content/drive/MyDrive/Deep Learning Group Project/DRIVE'
_original_root = _drive_root + '/original'
_preprocessed_root = _drive_root + '/preprocessed'

# Source folders: retinal images, vessel masks and eyeball masks.
original_images = _original_root + '/images'
original_masks = _original_root + '/masks'
original_eyeball = _original_root + '/eyeballs'

# Destination folders for the PNG files written by the preprocessing cells.
preprocessed_images = _preprocessed_root + '/images'
preprocessed_masks = _preprocessed_root + '/masks'
preprocessed_combined_masks = _preprocessed_root + '/combined_masks'

### Load retinal images & binary vessel tree masks and save as PNG files 

In [None]:
Nimgs = 20
channels = 1
height = 584
width = 565

# Zero-filled (rather than uninitialised) buffers, so a folder holding fewer
# than Nimgs files leaves blank frames instead of arbitrary memory contents.
imgs = np.zeros((Nimgs, height, width, channels))
lbls = np.zeros((Nimgs, height, width))
eyeballs = np.zeros((Nimgs, height, width))

# data preprocessing
# Retinal photographs are read as single-channel grayscale so that each frame
# fits its (height, width, channels) slot; a colour read yields three channels
# and cannot be stored in a one-channel buffer.
for i, file in enumerate(sorted(os.listdir(original_images))):
    img_path = os.path.join(original_images, file)
    img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise IOError('Could not read image: ' + img_path)
    imgs[i] = img[:, :, np.newaxis]

# Vessel masks are opened with PIL; the context manager closes each file
# handle as soon as its pixels have been copied into the buffer.
for i, file in enumerate(sorted(os.listdir(original_masks))):
    with Image.open(os.path.join(original_masks, file)) as g_truth:
        lbls[i] = np.asarray(g_truth)

# Eyeball background masks, loaded the same way as the vessel masks.
for i, file in enumerate(sorted(os.listdir(original_eyeball))):
    with Image.open(os.path.join(original_eyeball, file)) as e_truth:
        eyeballs[i] = np.asarray(e_truth)

print(imgs.shape)
print(lbls.shape)

# Save every frame as <index>.png. The buffers are float arrays, so they are
# cast to 8-bit explicitly instead of relying on an implicit conversion
# inside cv2.imwrite.
for i in range(Nimgs):
    cv2.imwrite(os.path.join(preprocessed_images, f'{i}.png'), imgs[i].astype(np.uint8))

for i in range(Nimgs):
    cv2.imwrite(os.path.join(preprocessed_masks, f'{i}.png'), lbls[i].astype(np.uint8))

### Load eyeball background masks, combine with vessel tree structures, and save the combined masks as PNG files

In [None]:
# Start from a copy of the eyeball background masks so the originals stay intact.
test = eyeballs.copy()

preprocessed_combined_masks = '/content/drive/MyDrive/Deep Learning Group Project/DRIVE/preprocessed/combined_masks'

# Paint every vessel pixel (value 255 in the vessel masks) with gray level 100.
# A single boolean-mask assignment replaces the per-pixel Python loops and
# produces the same array.
test[lbls == 255] = 100

In [None]:
# matplotlib is only imported further down in this notebook, so import it here
# to keep the cell runnable when the notebook is executed top to bottom.
import matplotlib.pyplot as plt

# Preview the first combined mask in grayscale.
plt.imshow(test[0, :, :], cmap='gray')

In [None]:
preprocessed_combined_masks = '/content/drive/MyDrive/Deep Learning Group Project/DRIVE/preprocessed/combined_masks'

# Write one PNG per combined mask, named by its index. The count comes from
# the array itself rather than a hard-coded 20, and the float data is cast to
# 8-bit explicitly before it is handed to cv2.imwrite.
for i in range(test.shape[0]):
    cv2.imwrite(os.path.join(preprocessed_combined_masks, f'{i}.png'), test[i].astype(np.uint8))

## Dataloader 

### Create a Custom Dataset for retinal images and their corresponding masks (eyeball + vessel)

*The data are retinal images and masks from the DRIVE dataset, hosted on Grand Challenge [3].*

In [None]:
class CustomImageDataset(Dataset):
    '''
    Paired dataset of retinal images and their combined (eyeball + vessel) masks.

    Files are listed from ``<path>/images`` and ``<path>/combined_masks``. The two
    sorted listings are paired by position, so both folders are expected to hold
    matching file names.

    Parameters:
        path: root folder containing the ``images`` and ``combined_masks``
              sub-folders (with or without a trailing slash)
        size: (height, width) that every image and mask is resized to

    NOTE(review): pixel values are returned as floats in their stored range (no
    rescaling happens here), while the Generator in this notebook ends in tanh -
    confirm whether the data should be normalised to [-1, 1].
    '''
    def __init__(self, path, size=(512, 512)):
        # Keep the root on the instance so that __getitem__ does not depend on
        # a module-level variable that happens to be called ``path``.
        self.path = path
        self.size = size
        self.imgs = sorted(os.listdir(os.path.join(path, 'images')))
        self.lbls = sorted(os.listdir(os.path.join(path, 'combined_masks')))

    def __len__(self):
        return len(self.imgs)

    def _load(self, folder, name):
        '''Reads one file, resizes it and returns a float tensor with a leading channel axis.'''
        transform = T.Resize(size=self.size)
        # The context manager releases the file handle once the pixels are copied.
        with Image.open(os.path.join(self.path, folder, name)) as picture:
            pixels = np.asarray(transform(picture))
        return torch.tensor(pixels).unsqueeze(0).float()

    def __getitem__(self, idx):
        '''
        Returns the (image, mask) pair stored at position idx.
        '''
        img = self._load('images', self.imgs[idx])
        lbl = self._load('combined_masks', self.lbls[idx])
        return img, lbl

In [None]:
# Root folder of the preprocessed data; the dataset reads its images and
# combined masks from sub-folders of this path.
path = '/content/drive/MyDrive/Deep Learning Group Project/DRIVE/preprocessed/'
dataset = CustomImageDataset(path)
# One sample per batch, in directory order, for the visual check below.
dataloader = DataLoader(dataset, batch_size=1, shuffle=False)

### Retrieve one sample from the dataset and visualize 

In [None]:
# Pull the first (image, mask) batch from the dataloader.
img, mask = next(iter(dataloader))

In [None]:
# Inspect the shapes of the image batch and the mask batch.
img.shape, mask.shape

In [None]:
from google.colab.patches import cv2_imshow
# Display the retinal image and then its mask; [0][0] selects the first batch
# item and its single channel, leaving a 2-D array.
cv2_imshow(img.numpy()[0][0])
cv2_imshow(mask.numpy()[0][0])

## Model Architecture

*Vess2Image is a modified version of the Pix2Pix conditional GAN [1], which uses a U-Net-based generator and a convolutional PatchGAN classifier as the discriminator.*

In [None]:
!pip install pytorch-lightning

In [None]:
import os
from glob import glob
from pathlib import Path

import matplotlib.pyplot as plt
import pytorch_lightning as pl
import torch
from PIL import Image
from torch import nn
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.transforms.functional import center_crop
from torchvision.utils import make_grid
from tqdm.auto import tqdm

### U-Net based Generator

*The U-Net architecture consists of a contracting path (**DownSampleConv**) and an expanding path (**UpSampleConv**).*

In [None]:
class DownSampleConv(nn.Module):
    '''
    Encoder building block: a 2D convolution, an optional batch normalisation
    and a LeakyReLU(0.2) activation, applied in that order.
    '''

    def __init__(self, in_channels, out_channels, kernel=4, stride=2, pad=1, batchnorm=True):
        super().__init__()
        self.batchnorm = batchnorm
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel,
            stride=stride,
            padding=pad,
        )
        # The normalisation layer is only created when it is requested.
        if self.batchnorm:
            self.bn = nn.BatchNorm2d(out_channels)
        self.act = nn.LeakyReLU(negative_slope=0.2)

    def forward(self, x):
        '''
        Runs x through the block and returns the activated feature map.
        '''
        features = self.conv(x)
        normalised = self.bn(features) if self.batchnorm else features
        return self.act(normalised)

In [None]:
class UpSampleConv(nn.Module):
    '''
    This class implements the upsampling operations:
    2DTransposedConvolution-BatchNorm-Dropout (the last two are optional).

    NOTE(review): no activation follows the normalisation, whereas the decoder
    blocks described in the Pix2Pix paper end with a ReLU - confirm that the
    omission is intentional.
    '''
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel=4,
        strides=2,
        padding=1,
        batchnorm=True,
        dropout=False
    ):
        super().__init__()
        self.batchnorm = batchnorm
        self.dropout = dropout

        self.deconv = nn.ConvTranspose2d(in_channels, out_channels, kernel, strides, padding)

        if batchnorm:
            self.bn = nn.BatchNorm2d(out_channels)

        if dropout:
            self.drop = nn.Dropout2d(0.5)

    def forward(self, x):
        '''
        Takes in a tensor x and returns it after performing upsampling convolution
        '''
        # The docstring above must be indented with the body: at the same
        # level as ``def`` it is an IndentationError.
        x = self.deconv(x)
        if self.batchnorm:
            x = self.bn(x)

        if self.dropout:
            x = self.drop(x)
        return x

In [None]:
class Generator(nn.Module):
    '''
    U-Net generator built from eight downsampling encoder blocks and seven
    upsampling decoder blocks, closed by a transposed convolution and a tanh.
    The outputs of encoders 2-7 are concatenated, deepest first, onto the
    outputs of decoders 1-6 as skip connections.
    '''
    def __init__(self, in_channels, out_channels):
        super().__init__()

        # Contracting path: (input channels, output channels, batchnorm?).
        encoder_plan = [
            (in_channels, 64, False),
            (64, 128, True),
            (128, 256, True),
            (256, 512, True),
            (512, 512, True),
            (512, 512, True),
            (512, 512, True),
            (512, 512, False),
        ]
        for number, (c_in, c_out, use_bn) in enumerate(encoder_plan, start=1):
            setattr(self, f'enc{number}', DownSampleConv(c_in, c_out, batchnorm=use_bn))

        # Expanding path: (input channels, output channels, dropout?). From the
        # second block on, the input width includes the concatenated skip.
        decoder_plan = [
            (512, 512, True),
            (1024, 512, True),
            (1024, 512, True),
            (1024, 512, False),
            (1024, 256, False),
            (512, 128, False),
            (256, 64, False),
        ]
        for number, (c_in, c_out, use_drop) in enumerate(decoder_plan, start=1):
            setattr(self, f'dec{number}', UpSampleConv(c_in, c_out, dropout=use_drop))

        self.final_conv = nn.ConvTranspose2d(64, out_channels, kernel_size=4, stride=2, padding=1)
        self.tanh = nn.Tanh()

    def forward(self, x):
        '''
        Parameters: x: an input tensor of image
        Returns: a tensor of the decoded image, squashed by tanh
        '''
        # Contracting path, remembering every intermediate activation.
        encoder_blocks = (
            self.enc1, self.enc2, self.enc3, self.enc4,
            self.enc5, self.enc6, self.enc7, self.enc8,
        )
        activations = []
        hidden = x
        for block in encoder_blocks:
            hidden = block(hidden)
            activations.append(hidden)

        # The deepest activation is the bottleneck and is not used as a skip.
        activations.pop()

        # Expanding path: decoders 1-6 each receive the matching encoder
        # output (enc7 down to enc2) concatenated along the channel axis.
        decoder_blocks = (
            self.dec1, self.dec2, self.dec3,
            self.dec4, self.dec5, self.dec6,
        )
        for block, skip in zip(decoder_blocks, reversed(activations[1:])):
            hidden = torch.cat((block(hidden), skip), dim=1)

        # last decoder
        hidden = self.dec7(hidden)
        return self.tanh(self.final_conv(hidden))

###  PatchGAN Discriminator

In [None]:
class PatchGAN(nn.Module):
    '''
    Convolutional PatchGAN discriminator: four successive downsampling blocks
    followed by a 1x1 convolution that outputs a single-channel map of logits.
    '''
    def __init__(self, input_channels):
        super().__init__()
        # Channel widths between consecutive blocks; only the first block
        # skips batch normalisation.
        widths = (input_channels, 64, 128, 256, 512)
        for level, (c_in, c_out) in enumerate(zip(widths, widths[1:]), start=1):
            setattr(self, f'd{level}', DownSampleConv(c_in, c_out, batchnorm=level > 1))
        self.final = nn.Conv2d(512, 1, kernel_size=1)

    def forward(self, x, y):
        '''
        Parameters:
            x: the image being judged (generated or real)
            y: the conditioned image
        Returns: a tensor of logits
        '''
        # The image and its condition are stacked along the channel axis.
        features = torch.cat((x, y), dim=1)
        for stage in (self.d1, self.d2, self.d3, self.d4):
            features = stage(features)
        return self.final(features)

### Helper Functions: Weight initialization and Output Visualization 

In [None]:
def weights_init(m):
    '''
    Initialises a single module in place; it is passed to ``nn.Module.apply``
    so that it visits every sub-module of a network.

    - Conv2d / ConvTranspose2d: weights drawn from N(0.0, 0.02)
    - BatchNorm2d: scale drawn from N(1.0, 0.02), bias set to 0

    Other module types are left untouched.
    '''
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    elif isinstance(m, nn.BatchNorm2d) and m.weight is not None:
        # The scale is centred on 1.0: centring it on 0.0 would make every
        # normalisation layer start out multiplying its input by roughly zero.
        # (weight is None when the layer was built with affine=False.)
        torch.nn.init.normal_(m.weight, 1.0, 0.02)
        torch.nn.init.constant_(m.bias, 0)
        
def display_progress(cond, real, fake, figsize=(10,5)):
    '''
    Shows three grayscale panels side by side, left to right: the conditioning
    mask, the real image and the synthetic image.
    '''
    # Detach from autograd, move to the CPU and put the channel axis last.
    panels = [tensor.detach().cpu().permute(1, 2, 0) for tensor in (cond, real, fake)]

    fig, axes = plt.subplots(1, 3, figsize=figsize)
    for axis, panel in zip(axes, panels):
        axis.imshow(panel, cmap = 'gray')

    plt.show()

### Vess2Image

In [None]:
# Per-step losses collected by Pix2Pix.training_step; they are read after
# training to compute and plot the averaged loss curves.
d_loss = []
g_loss = []


class Pix2Pix(pl.LightningModule):
    '''
    This class implements the Vess2Image framework using PyTorch Lightning.

    A U-Net Generator translates a conditioning image into an output image and
    a PatchGAN discriminator judges (image, condition) pairs. Both networks are
    optimised manually inside ``training_step``.

    Parameters:
        in_channels: number of channels of the conditioning image
        out_channels: number of channels of the generated image
        learning_rate: Adam learning rate shared by both optimizers
        lambda_recon: weight of the L1 reconstruction term in the generator loss
        display_step: a preview figure is shown on the first batch of every
                      ``display_step``-th epoch
    '''
    def __init__(self, in_channels, out_channels, learning_rate=0.00002, lambda_recon=160, display_step=5):
        super().__init__()
        # Define key variables
        self.save_hyperparameters()
        self.automatic_optimization = False # activate manual optimization
        self.display_step = display_step
        self.generator = Generator(in_channels, out_channels)
        # The discriminator sees the image and its condition stacked on the channel axis.
        self.patchgan = PatchGAN(in_channels + out_channels)
        self.generator = self.generator.apply(weights_init)
        self.patchgan = self.patchgan.apply(weights_init)
        self.adversarial_criterion = nn.BCEWithLogitsLoss()
        self.recon_criterion = nn.L1Loss()

    def _gen_step(self, real_images, conditioned_images):
        '''Returns the generator loss: adversarial term + lambda_recon * L1 term.'''
        fake_images = self.generator(conditioned_images)
        dis_logits = self.patchgan(fake_images, conditioned_images)
        # The generator is rewarded when the discriminator labels its output as real.
        adversarial_loss = self.adversarial_criterion(dis_logits, torch.ones_like(dis_logits))
        recon_loss = self.recon_criterion(fake_images, real_images)
        lambda_recon = self.hparams.lambda_recon

        return adversarial_loss + (lambda_recon * recon_loss)

    def _dis_step(self, real_images, conditioned_images):
        '''Returns the discriminator loss: the mean of the real and fake BCE terms.'''
        # detach() keeps the discriminator loss from back-propagating into the generator.
        fake_images = self.generator(conditioned_images).detach()
        fake_logits = self.patchgan(fake_images, conditioned_images)

        real_logits = self.patchgan(real_images, conditioned_images)

        fake_loss = self.adversarial_criterion(fake_logits, torch.zeros_like(fake_logits))
        real_loss = self.adversarial_criterion(real_logits, torch.ones_like(real_logits))
        return (real_loss + fake_loss) / 2

    def configure_optimizers(self):
        '''Returns the optimizers in the order [discriminator, generator].'''
        lr = self.hparams.learning_rate
        gen_opt = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=(0.5, 0.5))
        dis_opt = torch.optim.Adam(self.patchgan.parameters(), lr=lr, betas=(0.5, 0.5))
        return [dis_opt, gen_opt]

    def training_step(self, batch, batch_idx):
        '''
        Runs one discriminator update followed by one generator update on
        ``batch``, a (real image, conditioning image) pair.
        '''
        real, condition = batch

        # manual optimization; the order matches configure_optimizers
        dis_opt, gen_opt = self.optimizers()

        # Train discriminator
        dis_loss = self._dis_step(real, condition)

        self.log('PatchGAN Loss', dis_loss)
        dis_opt.zero_grad()
        # manual_backward (rather than loss.backward()) lets Lightning apply
        # whatever precision / strategy handling the Trainer is configured with.
        self.manual_backward(dis_loss)
        dis_opt.step()

        # Train generator
        gen_loss = self._gen_step(real, condition)

        self.log('Generator Loss', gen_loss)
        gen_opt.zero_grad()
        self.manual_backward(gen_loss)
        gen_opt.step()

        # display ground-truth mask, real image, and synthetic image
        if self.current_epoch % self.display_step == 0 and batch_idx == 0:
            # The preview needs no gradients, so skip building an autograd graph.
            with torch.no_grad():
                fake = self.generator(condition)
            display_progress(condition[0], real[0], fake[0])

        # Store detached values so the global lists do not keep each step's
        # autograd graph alive for the whole run.
        d_loss.append(dis_loss.detach())
        g_loss.append(gen_loss.detach())

## Training

In [None]:
# Rebuild the dataset for training; batches of 5 samples, reshuffled each epoch.
path = '/content/drive/MyDrive/Deep Learning Group Project/DRIVE/preprocessed/'
dataset = CustomImageDataset(path)
dataloader = DataLoader(dataset, batch_size=5, shuffle=True)

In [None]:
# One input channel and one output channel; train for 30 epochs.
pix2pix = Pix2Pix(1, 1)
trainer = pl.Trainer(max_epochs=30)
trainer.fit(pix2pix, dataloader)

### Training Loss

In [None]:
# Output lists that get_mean_loss fills with the averaged losses.
mean_d = []
mean_g = []

def get_mean_loss(losses, outlist, size):
    '''
    Averages ``losses`` over consecutive, non-overlapping windows.

    Parameters:
        losses: sequence of per-step losses (numbers or single-element tensors)
        outlist: list that receives one mean per complete window (appended in place)
        size: number of consecutive losses per window

    Returns: a tensor built from ``outlist``. Losses left over after the last
    complete window are ignored.
    '''
    window_total = 0.0
    for count, loss in enumerate(losses, start=1):
        window_total += float(loss)
        if count % size == 0:
            outlist.append(window_total / size)
            # Start the next window from zero; without this reset every value
            # would be a running total instead of a per-window mean.
            window_total = 0.0
    return torch.tensor(outlist)

# Average the recorded losses in windows of 4 steps (presumably one epoch:
# 20 samples at batch_size=5 - confirm), then convert the resulting tensors
# to NumPy arrays for the plotting cells.
mean_d = get_mean_loss(d_loss, mean_d, 4)
mean_g = get_mean_loss(g_loss, mean_g, 4)
mean_d, mean_g = (series.detach().cpu().numpy() for series in (mean_d, mean_g))

# Discriminator means first, generator means second.
for series in (mean_d, mean_g):
    print(series)

In [None]:
import pandas as pd 
# One row per averaged window: generator loss next to discriminator loss.
df = pd.DataFrame(list(zip(mean_g, mean_d)), columns=["G_Loss", "D_Loss"])
df.head()

In [None]:
import seaborn as sns 
# Pass the frame by keyword: whether the first positional argument of
# lineplot means ``data`` depends on the installed seaborn version, while
# ``data=`` draws one line per column in every version.
line_loss = sns.lineplot(data=df)

## Citation

The Vess2Image model is adapted from parts of Maurya’s PyTorch Pix2Pix notebook [2]. Additionally, the model definitions for the Pix2Pix network are taken from the Image-to-image translation with conditional adversarial networks paper [1]. 

[1] Isola, P., Zhu, J., Zhou, T., & Efros, A. A. (2018). Image-to-image translation with conditional adversarial networks. arXiv, accessed at: https://arxiv.org/abs/1611.07004

[2] Maurya, A. (2021). Pix2Pix-Image to image translation with conditional Adversarial Networks, accessed at: https://librecv.github.io/blog/gans/pytorch/2021/02/13/Pix2Pix-explained-with-code.html

[3] Grand Challenge. (2023). DRIVE: digital retinal images for vessel extraction. *Grand Challenge*, accessed at: https://drive.grand-challenge.org/ 