In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from skimage import io, transform
import numpy as np
import os
import matplotlib.pyplot as plt
from skimage.transform import resize
from glob import glob
from tqdm import tqdm

import lightning.pytorch as pl
from lightning.pytorch.tuner import Tuner

import albumentations as A
from albumentations.augmentations.dropout.grid_dropout import GridDropout
import pickle



In [2]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [3]:
def RetinaTransform(image):
    # Bringing image to 0-1 range, #ATTENTION: CHECK IF ALL IMAGES min max is 0 and 255
    image = (image / 255)*2-1
    
    
    image = resize(image, (64,64),anti_aliasing=False) # resizing image because original does not fit in memory.
    
    label = image.copy()
    grid_drop = GridDropout(ratio=0.5, unit_size_min=10, unit_size_max=50, fill_value=1, random_offset=True, always_apply=True)
    image = grid_drop(image=image)["image"]
    # Bringing data to CHW format
    # N is a batch size, C denotes a number of channels, 
    # H is a height of input planes in pixels, and W is width in pixels.
    image = image.transpose([2,0,1])
    label = label.transpose([2,0,1]) #No need to transpose lables
    
    #Fixing dtype to avoid runtime error and save memory
    image = torch.tensor(image ,dtype=torch.float32)
    label = torch.tensor(label ,dtype=torch.float32)
    
    return image,label

In [4]:
# Load dataset into memory
if 'image_array' not in globals():
    image_array = []
    data_paths= glob("../data/ODIR-5K/ODIR-5K/Testing Images/*")+glob("../data/ODIR-5K/ODIR-5K/Training Images/*")+glob("../data/REFUGE/Images_Square/*")
    data_paths = data_paths
    local_transform = RetinaTransform


    for image_name in tqdm(data_paths):
        image = io.imread(image_name)
        image,label = local_transform(image)
        image_array.append((image,label))
        
    with open("image_array.pkl", "wb") as f:
        pickle.dump(image_array,f)

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 9200/9200 [13:36<00:00, 11.27it/s]


In [14]:
# with open("image_array.pkl", "rb") as f:
#     image_array = pickle.load(f)

In [5]:

# normalize the images between -1 and 1
# Tanh as the last layer of the generator output

# In GAN papers, the loss function to optimize G is min (log 1-D), but in practice folks practically use max log D

#     because the first formulation has vanishing gradients early on
#     Goodfellow et. al (2014)

# In practice, works well:

#     Flip labels when training generator: real = fake, fake = real

# the stability of the GAN game suffers if you have sparse gradients
# LeakyReLU = good (in both G and D)


# optim.Adam rules!
#     See Radford et. al. 2015
# Use SGD for discriminator and ADAM for generator



In [7]:
class EMSNETGenerator(nn.Module):
    def __init__(self):
        super().__init__()
        self.feature_conv = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, stride=1,padding=1)
        self.linear_encoder = nn.Linear(3*64*64, 2048)
        self.linear_decoder = nn.Linear(2048, 3*64*64)
        self.tanh = nn.Tanh()
        self.sigmoid = nn.Sigmoid()
        
    def forward(self, x):
        x = self.feature_conv(x)
        x = self.tanh(x)
        x = torch.flatten(x,start_dim=1)
        x = self.linear_encoder(x)
        x = self.tanh(x)
        x = self.linear_decoder(x)
        x = self.tanh(x)
        x = x.reshape(-1,3,64,64)
        return x

In [8]:
class EMSNETDiscriminator(nn.Module):
    def __init__(self):
        super(EMSNETDiscriminator, self).__init__()
        self.ndf = 10
        self.predict = nn.Sequential(
            # input is (nc) x 64 x 64
            nn.Conv2d(3, self.ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(self.ndf, self.ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(self.ndf * 2, self.ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(self.ndf * 4, self.ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(self.ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        return self.predict(input)

In [9]:
class EMSNETGAN(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.generator = EMSNETGenerator()
        self.discriminator = EMSNETDiscriminator()
        self.criterion = nn.BCELoss()
        self.automatic_optimization = False
        
    def forward(self, img):
        return self.generator(img)

    def training_step(self, batch, batch_idx, optimizer_idx):
        imgs, labels = batch
        
        optimizer_g, optimizer_d = self.optimizers()
        
        self.toggle_optimizer(optimizer_g)
        valid = torch.ones(imgs.size(0), 1)
        valid = valid.type_as(imgs)
        
        
        if optimizer_idx == 0:
            
            real_label_ = torch.ones(labels.size(0),device=device)
            real_output_ = self.discriminator(labels).view(-1)
            err_real = self.criterion(real_output_, real_label_)
            fake_label_ = torch.zeros(imgs.size(0),device=device)
            fake_output_ = self.discriminator(self.generator(imgs)).view(-1)
            err_fake = self.criterion(fake_output_, fake_label_)
            
            self.log("dlos", err_real+err_fake, prog_bar=True)
            return err_real+err_fake

        # train generator
        if optimizer_idx == 1:
            
            #maximize log(D(G(z))) 
            real_label_ = torch.ones(labels.size(0),device=device)
            fake_output_ = self.discriminator(self.generator(imgs)).view(-1)
            err_real = self.criterion(fake_output_, real_label_)
            self.log("glos", err_real, prog_bar=True)
            
            return err_real
        
    def configure_optimizers(self):

        opt_g = torch.optim.Adam(self.generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
        opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
        return [opt_d, opt_g], []


In [10]:
# discriminant = EMSNETDiscriminator()
# generant = EMSNETGenerator()
# loss_ = nn.BCELoss()
# with torch.no_grad():
#     for idx,i in enumerate(val_dataloader):
#         # print(i[0].shape, generant(i[0]).shape)
#         print(i[0].shape, discriminant(i[0]))
#         real_label = torch.ones(i[0].size(0))
#         output = generant(i[1])
#         errD_real = loss_(output, real_label)
#         break

In [11]:
class RetinaDataset(Dataset):
    """Face Landmarks dataset."""

    def __init__(self, data_indices, transform=None):
        
        self.data = [image_array[idx] for idx in data_indices]
        
    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

In [12]:
train_dataset = RetinaDataset(data_indices=np.arange(0,len(image_array)-50))
train_dataloader = DataLoader(train_dataset, batch_size=10, shuffle=False, num_workers=8)

val_dataset = RetinaDataset(data_indices=np.arange(len(image_array)-50,len(image_array)))
val_dataloader = DataLoader(val_dataset, batch_size=10, shuffle=False, num_workers=8)

In [13]:
# plt.imshow(train_dataset[0][0].numpy().transpose(1,2,0))

In [14]:
emsnet = EMSNETGAN()

In [15]:
# from pytorch_lightning.callbacks import TQDMProgressBar
# class LitProgressBar(TQDMProgressBar):
#     def init_validation_tqdm(self):
#         bar = super().init_validation_tqdm()
#         bar.set_description("running validation...")
#         return bar

In [16]:
trainer = pl.Trainer(max_epochs=1, accelerator='gpu', gradient_clip_val=0.5)
# tuner = Tuner(trainer)
# tuner.scale_batch_size(emsnet, mode="binsearch")
# trainer.fit(model=emsnet, train_dataloaders=train_dataloader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [17]:
trainer.fit(model=emsnet, train_dataloaders=train_dataloader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


RuntimeError: Training with multiple optimizers is only supported with manual optimization. Set `self.automatic_optimization = False`, then access your optimizers in `training_step` with `opt1, opt2, ... = self.optimizers()`.

In [None]:
# emsnet.generator

In [None]:
emsnet.eval()
with torch.no_grad():
    for idx,i in enumerate(train_dataloader):
        print(i[0].shape,i[1].shape)
        output = emsnet.generator(i[0])
        # print(emsnet.discriminator(output))
        # print(emsnet.discriminator(i[1]))
        plt.imshow(output[0].cpu().detach().numpy().transpose(2,1,0))
        plt.show()
        plt.imshow(i[1][0].permute(1,2,0))
        plt.show()
        plt.imshow(i[0][0].permute(1,2,0))
        plt.show()
        if idx == 4: break

In [29]:
with torch.no_grad():
    torch.cuda.empty_cache()