In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from skimage import io, transform
import numpy as np
import os
import matplotlib.pyplot as plt
from skimage.transform import resize
from glob import glob
from tqdm import tqdm


import lightning.pytorch as pl
from lightning.pytorch.tuner import Tuner

import albumentations as A
from albumentations.augmentations.dropout.grid_dropout import GridDropout
import pickle

import torchvision

import cv2



In [2]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [3]:
def RetinaTransform(image):
    # resizing image because original does not fit in memory.

    image = resize(image, (64,64),anti_aliasing=False)
    # Bringing image to 0-1 range, #ATTENTION: CHECK IF ALL IMAGES min max is 0 and 255
    image = (image / 255)
    
    label = image.copy()
    # grid_drop = GridDropout(ratio=0.5, unit_size_min=10, unit_size_max=50, fill_value=1, random_offset=True, always_apply=True)
    # image = grid_drop(image=image)["image"]
    # Bringing data to CHW format
    # N is a batch size, C denotes a number of channels, 
    # H is a height of input planes in pixels, and W is width in pixels.
    image = image.transpose([2,0,1])
    label = label.transpose([2,0,1]) #No need to transpose lables
    
    #Fixing dtype to avoid runtime error and save memory
    image = torch.tensor(image ,dtype=torch.float32)
    label = torch.tensor(label ,dtype=torch.float32)
    
    return image,label

In [None]:
# Load dataset into memory
if 'image_array' not in globals():
    image_array = []
    data_paths= glob("../data/ODIR-5K/ODIR-5K/Testing Images/*")+glob("../data/ODIR-5K/ODIR-5K/Training Images/*")+glob("../data/REFUGE/Images_Square/*")
    data_paths = data_paths
    local_transform = RetinaTransform


    for image_name in tqdm(data_paths):
        image = cv2.imread(image_name)
        image,label = local_transform(image)
        image_array.append((image,label))
        
    with open("image_array.pkl", "wb") as f:
        pickle.dump(image_array,f)

In [4]:
with open("image_array.pkl", "rb") as f:
    image_array = pickle.load(f)

In [5]:
image_array = [(i*2-1,j*2-1) for i,j in image_array]

In [6]:
# plt.imshow((image_array[0][0].numpy().transpose(1,2,0)+1)*255/2)

In [7]:

# normalize the images between -1 and 1
# Tanh as the last layer of the generator output

# In GAN papers, the loss function to optimize G is min (log 1-D), but in practice folks practically use max log D

#     because the first formulation has vanishing gradients early on
#     Goodfellow et. al (2014)

# In practice, works well:

#     Flip labels when training generator: real = fake, fake = real

# the stability of the GAN game suffers if you have sparse gradients
# LeakyReLU = good (in both G and D)


# optim.Adam rules!
#     See Radford et. al. 2015
# Use SGD for discriminator and ADAM for generator



In [8]:
def SuperConv2d(x, supr_conv_arch):
    return torch.concat([nn.Conv2d(*conf).to(device)(x) for conf in supr_conv_arch], dim=1)

In [9]:
# class EMSNETGenerator(nn.Module):
#     def __init__(self):
#         super().__init__()
#         # self.super_conv = nn.Conv2d(in_channels=3, out_channels=10, kernel_size=3, stride=1,padding=1)
#         self.linear_encoder = nn.Linear(30*64*64, 128)
#         self.batch_norm2d = nn.BatchNorm2d(30)
#         self.batch_norm1d = nn.BatchNorm1d(128)
#         self.lrelu = nn.LeakyReLU(0.2)
#         self.linear_decoder = nn.Sequential(
#             nn.Linear(128,256),
#             self.lrelu,
#             nn.Linear(256,512),
#             self.lrelu,
#             nn.Linear(512,1024),
#             self.lrelu,
#             nn.Linear(1024,3*64*64),
#         )
#         self.tanh = nn.Tanh()
#         self.sigmoid = nn.Sigmoid()
#         self.dropout= nn.Dropout(p=0.8)
        
#     def forward(self, x):
#         x = SuperConv2d(x,[[3,10,3,1,1],[3,10,5,1,2],[3,10,7,1,3]])
#         x = self.lrelu(x)
#         x = self.batch_norm2d(x)
#         x = torch.flatten(x,start_dim=1)
#         x = self.linear_encoder(x)
#         x = self.batch_norm1d(x)
#         x = self.lrelu(x)
#         x = self.linear_decoder(x)
#         # x = self.batch_norm1d(x)
#         x = self.tanh(x)
#         x = x.reshape(-1,3,64,64)
#         return x

In [10]:
from torchvision.models import resnet50, ResNet50_Weights
resnet_model = resnet50(weights="IMAGENET1K_V1").to(device)

In [11]:
class EMSNETGenerator(nn.Module):
    def __init__(self):
        super().__init__()
        self.resnet_model = resnet50(weights="IMAGENET1K_V1")
        self.batch_norm1d = nn.BatchNorm1d(1000)
        self.batch_norm2d = nn.BatchNorm2d(3)
        self.lrelu = nn.LeakyReLU(0.2)
        self.linear_decoder = nn.Linear(1000,3*64*64)
        self.tanh = nn.Tanh()
        
    def forward(self, x):
        x = self.resnet_model(x)
        x = self.batch_norm1d(x)
        x = self.lrelu(x)
        x = torch.flatten(x,start_dim=1)
        x = self.linear_decoder(x)
        x = x.reshape(-1,3,64,64)
        x = self.batch_norm2d(x)   
        x = self.tanh(x)
        return x

In [12]:
class EMSNETDiscriminator(nn.Module):
    def __init__(self):
        super(EMSNETDiscriminator, self).__init__()
        self.ndf = 20
        self.predict = nn.Sequential(
            # input is (nc) x 64 x 64
            nn.Conv2d(3, self.ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(self.ndf, self.ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(self.ndf * 2, self.ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(self.ndf * 4, self.ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(self.ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        return self.predict(input)

In [13]:
class EMSNETGAN(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.generator = EMSNETGenerator()
        self.discriminator = EMSNETDiscriminator()
        self.automatic_optimization = False
        self.cntr = 0 
    def forward(self, img):
        return self.generator(img)

    def adversarial_loss(self, y_hat, y):
        return F.binary_cross_entropy(y_hat, y)
    
    def training_step(self, batch):
        imgs, labels = batch
        
        optimizer_g, optimizer_d = self.optimizers()
        
        self.toggle_optimizer(optimizer_d)
        
        valid = torch.ones(imgs.size(0))
        valid = valid.type_as(imgs)

        real_loss = self.adversarial_loss(self.discriminator(labels).view(-1), valid)

        fake = torch.zeros(imgs.size(0))
        fake = fake.type_as(imgs)

        fake_loss = self.adversarial_loss(self.discriminator(self.generator(imgs)).view(-1), fake)

        d_loss = (real_loss + fake_loss) / 2
        self.log("d_loss", d_loss, prog_bar=True)
        self.manual_backward(d_loss)
        optimizer_d.step()
        optimizer_d.zero_grad()

        self.untoggle_optimizer(optimizer_d)
        
        
        
        self.toggle_optimizer(optimizer_g)
        
        self.generated_imgs = self.generator(imgs)
        sample_imgs = self.generated_imgs[:6]
        grid = torchvision.utils.make_grid((sample_imgs+1)*255/2)
        self.logger.experiment.add_image("generated_images", grid, self.cntr)
        self.cntr+=1
        
        
        valid = torch.ones(imgs.size(0))
        valid = valid.type_as(imgs)
        
        #pixel-wise loss
        pixel_loss = self.adversarial_loss(torch.flatten((self.generator(imgs)+1)/2,start_dim=1),torch.flatten((imgs+1)/2,start_dim=1))
        # gd_loss = self.adversarial_loss(self.discriminator(self.generator(imgs)).view(-1), valid)
        gd_loss=0
        g_loss = (pixel_loss+gd_loss)/2
        self.log("g_loss", g_loss, prog_bar=True)
        self.manual_backward(g_loss)
        optimizer_g.step()
        optimizer_g.zero_grad()
        self.untoggle_optimizer(optimizer_g)
        
        
    def configure_optimizers(self):

        opt_g = torch.optim.Adam(self.generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
        opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
        return [opt_g, opt_d], []


In [14]:
# discriminant = EMSNETDiscriminator()
# generant = EMSNETGenerator()
# loss_ = nn.BCELoss()
# with torch.no_grad():
#     for idx,i in enumerate(val_dataloader):
#         # print(i[0].shape, generant(i[0]).shape)
#         print(i[0].shape, discriminant(i[0]))
#         real_label = torch.ones(i[0].size(0))
#         output = generant(i[1])
#         errD_real = loss_(output, real_label)
#         break

In [15]:
class RetinaDataset(Dataset):
    """Face Landmarks dataset."""

    def __init__(self, data_indices, transform=None):
        
        self.data = [image_array[idx] for idx in data_indices]
        
    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

In [16]:
train_dataset = RetinaDataset(data_indices=np.arange(0,len(image_array)-50))
# train_dataset = RetinaDataset(data_indices=np.arange(0,1200))
train_dataloader = DataLoader(train_dataset, batch_size=128, shuffle=False, num_workers=8)

val_dataset = RetinaDataset(data_indices=np.arange(len(image_array)-50,len(image_array)))
val_dataloader = DataLoader(val_dataset, batch_size=128, shuffle=False, num_workers=8)

In [17]:
# plt.imshow(train_dataset[0][0].numpy().transpose(1,2,0))

In [18]:
emsnet = EMSNETGAN()

In [19]:
# from pytorch_lightning.callbacks import TQDMProgressBar
# class LitProgressBar(TQDMProgressBar):
#     def init_validation_tqdm(self):
#         bar = super().init_validation_tqdm()
#         bar.set_description("running validation...")
#         return bar

In [20]:
trainer = pl.Trainer(max_epochs=10, accelerator='gpu')
# tuner = Tuner(trainer)
# tuner.scale_batch_size(emsnet, mode="binsearch")
# trainer.fit(model=emsnet, train_dataloaders=train_dataloader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [21]:
trainer.fit(model=emsnet, train_dataloaders=train_dataloader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type                | Params
------------------------------------------------------
0 | generator     | EMSNETGenerator     | 37.9 M
1 | discriminator | EMSNETDiscriminator | 272 K 
------------------------------------------------------
38.1 M    Trainable params
0         Non-trainable params
38.1 M    Total params
152.529   Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


In [95]:
# emsnet.generator

In [30]:
from torchvision.models import resnet50, ResNet50_Weights

# Using pretrained weights:
# resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
resnet_model = resnet50(weights="IMAGENET1K_V1").to(device)

In [33]:
with torch.no_grad():
    for idx,i in enumerate(train_dataloader):
        print(resnet_model(i[0].to(device)).shape)
        break

torch.Size([128, 1000])


In [63]:
emsnet.eval()
with torch.no_grad():
    for idx,i in enumerate(train_dataloader):
        print(i[0].shape,i[1].shape)
        output = emsnet.generator(i[0].to(device))
        # print(emsnet.discriminator(output))
        # print(emsnet.discriminator(i[1]))
        plt.imshow(output[0].cpu().detach().numpy().transpose(2,1,0)*255/2+255/2)
        plt.show()
        plt.imshow(i[1][0].permute(1,2,0)*255/2+255/2)
        plt.show()
        plt.imshow(i[0][0].permute(1,2,0)*255/2+255/2)
        plt.show()
        if idx == 4: break

torch.Size([128, 3, 64, 64]) torch.Size([128, 3, 64, 64])


RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument weight in method wrapper__cudnn_batch_norm)

In [37]:
with torch.no_grad():
    torch.cuda.empty_cache()