In [20]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torchvision.transforms as T
import torchvision.models as models
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import ImageFolder
import torch.optim as optim
from torchvision.utils import save_image
import albumentations as A
from PIL import Image
import cv2
from albumentations.pytorch.transforms import ToTensorV2
from tqdm import tqdm
import os

In [2]:
class Conv(nn.Module):
    def __init__(self, in_channels, out_channels, is_disc = False, use_act = True, use_bn = True, **kwargs):
        super(Conv, self).__init__()
        self.use_act = use_act
        self.conv = nn.Conv2d(in_channels, out_channels, **kwargs, bias = not use_bn)
        self.bn = nn.BatchNorm2d(out_channels) if use_bn else nn.Identity()
        self.act = (nn.LeakyReLU(0.2, inplace = True) if is_disc else nn.PReLU(num_parameters = out_channels))

    def forward(self, x):
        return self.act(self.bn(self.conv(x))) if self.use_act else self.bn(self.conv(x))

In [3]:
class Upsample(nn.Module):
    def __init__(self, in_channels, upscaling_factor = 2):
        super(Upsample, self).__init__()
        self.conv = nn.Conv2d(in_channels, in_channels * upscaling_factor ** 2, 3, 1, 1)
        self.pixelshuffle = nn.PixelShuffle(upscaling_factor)
        self.act = nn.PReLU(num_parameters = in_channels)

    def forward(self, x):
        return self.act(self.pixelshuffle(self.conv(x)))

In [4]:
class ResidualBlock(nn.Module):
    def __init__(self, in_channels):
        super(ResidualBlock, self).__init__()
        self.block1 = Conv(in_channels, in_channels, is_disc = False, use_act = True, use_bn = True, kernel_size = 3, stride = 1, padding = 1)
        self.block2 = Conv(in_channels, in_channels, is_disc = False, use_act = False, use_bn = True, kernel_size = 3, stride = 1, padding = 1)

    def forward(self, x):
        return self.block2(self.block1(x)) + x
        

In [5]:
class Generator(nn.Module):
    def __init__(self, in_channels = 3, channels = 64, num_blocks = 16):
        super(Generator, self).__init__()
        self.start = Conv(in_channels, channels, is_disc = False, use_act = True, use_bn = False, kernel_size = 9, stride = 1, padding = 4)
        self.residuals = nn.Sequential(*[ResidualBlock(channels) for _ in range(num_blocks)])
        self.block = Conv(channels, channels, is_disc = False, use_act = False, use_bn = True, kernel_size = 3, stride = 1, padding = 1)
        self.upsample = nn.Sequential(
            Upsample(channels, 2),
            Upsample(channels, 2),
            nn.Conv2d(channels, in_channels, kernel_size = 9, stride = 1, padding = 4),
            nn.Tanh()
        )

    def forward(self, x):
        initial = self.start(x)
        x = self.residuals(initial)
        x = self.block(x) + initial
        return self.upsample(x)
            

In [6]:
class Discriminator(nn.Module):
    def __init__(self, in_channels = 3, channels = 64):
        super(Discriminator, self).__init__()
        self.block = nn.Sequential(
            Conv(in_channels, channels, is_disc = True, use_act = True, use_bn = False, kernel_size = 3, stride = 1, padding = 1),
            Conv(channels, channels, is_disc = True, use_act = True, use_bn = True, kernel_size = 3, stride = 2, padding = 1),
            Conv(channels, channels * 2, is_disc = True, use_act = True, use_bn = True, kernel_size = 3, stride = 1, padding = 1),
            Conv(channels * 2, channels * 2, is_disc = True, use_act = True, use_bn = True, kernel_size = 3, stride = 2, padding = 1),
            Conv(channels * 2, channels * 4, is_disc = True, use_act = True, use_bn = True, kernel_size = 3, stride = 1, padding = 1),
            Conv(channels * 4, channels * 4, is_disc = True, use_act = True, use_bn = True, kernel_size = 3, stride = 2, padding = 1),
            Conv(channels * 4, channels * 8, is_disc = True, use_act = True, use_bn = True, kernel_size = 3, stride = 1, padding = 1),
            Conv(channels * 8, channels * 8, is_disc = True, use_act = True, use_bn = True, kernel_size = 3, stride = 2, padding = 1),
        )
        self.linear = nn.Sequential(
            nn.AdaptiveAvgPool2d((6, 6)),
            nn.Flatten(),
            nn.Linear(36 * channels * 8, channels * 16),
            nn.LeakyReLU(0.2, inplace = True),
            nn.Linear(channels * 16, 1)
        )

    def forward(self, x):
        return self.linear(self.block(x))
        

In [7]:
def test():
    img = torch.randn((3, 3, 50, 50))
    gen = Generator(3, 64, 16)
    disc = Discriminator(3, 64)
    gen_out = gen(img)
    disc_out = disc(img)
    print(gen_out.shape)
    print(disc_out.shape)
test()

torch.Size([3, 3, 200, 200])
torch.Size([3, 1])


In [8]:
vgg = models.vgg19(weights = models.VGG19_Weights.DEFAULT)
vgg

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padd

In [9]:
#phi 5, 4 = 4th conv after activation before 5th maxpool -> after 35

In [10]:
class VGGLoss(nn.Module):
    def __init__(self):
        super(VGGLoss, self).__init__()
        self.MSELoss = nn.MSELoss()
        self.vgg = models.vgg19(weights = models.VGG19_Weights.DEFAULT).features[: 36].eval().to('cuda' if torch.cuda.is_available() else 'cpu')

        for param in vgg.parameters():
            param.requires_grad = False 

    def forward(self, gen_img, real_img):
        return self.MSELoss(self.vgg(gen_img), self.vgg(real_img))
        
        

In [11]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
batch_size = 16
epochs = 10000
gen_lr = 1e-4
disc_lr = 1e-4
high_res = 96
low_res = high_res // 4

high_res_dir = r'D:\Codes\MLDS\SRGAN\DIV2K_train_HR\DIV2K_train_HR'

In [12]:
low_res_transform = A.Compose([A.Resize(height = low_res, width = low_res, interpolation = 3),
                               A.Normalize(mean = [0.5, 0.5, 0.5], std = [0.5, 0.5, 0.5]),
                               A.pytorch.transforms.ToTensorV2()])

high_res_transform = A.Compose([A.Normalize(mean = [0.5, 0.5, 0.5], std = [0.5, 0.5, 0.5]),
                                A.pytorch.transforms.ToTensorV2()])

augmentation = A.Compose([A.RandomCrop(height = high_res, width = high_res)
                          ])

test_transform = A.Compose([A.Normalize(mean = [0, 0, 0], std = [1, 1, 1]),
                            A.pytorch.transforms.ToTensorV2()])

In [13]:
class Data(Dataset):
    def __init__(self, high_res_dir, augmentation, low_res_transform, high_res_transform):
        self.high_res_dir = high_res_dir
        self.augmentation = augmentation
        self.low_res_transform = low_res_transform
        self.high_res_transform = high_res_transform
        self.high_res_images = sorted(os.listdir(high_res_dir))

    def __len__(self):
        return len(self.high_res_images)

    def __getitem__(self, idx):
        high_res_image_path = os.path.join(self.high_res_dir, self.high_res_images[idx])
        image = np.array(Image.open(high_res_image_path))
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = self.augmentation(image = image)['image']
        low_res_img = self.low_res_transform(image = image)['image']
        high_res_img = self.high_res_transform(image = image)['image']
        return low_res_img, high_res_img
        

In [14]:
def datatest():
    dataset = Data(high_res_dir, augmentation, low_res_transform, high_res_transform)
    dataloader = DataLoader(dataset, batch_size = batch_size)
    for low_res, high_res in dataloader:
        print(low_res.shape)
        print(high_res.shape)
        break
datatest()

torch.Size([16, 3, 24, 24])
torch.Size([16, 3, 96, 96])


In [15]:
disc = Discriminator(3, 64).to(device)
gen = Generator(3, 64, 16).to(device)

dataset = Data(high_res_dir, augmentation, low_res_transform, high_res_transform)
loader = DataLoader(dataset, batch_size = batch_size, shuffle = True)

gen_optimizer = optim.Adam(gen.parameters(), lr = gen_lr)
disc_optimizer = optim.Adam(disc.parameters(), lr = disc_lr)

vggloss = VGGLoss()
advloss = nn.BCEWithLogitsLoss()

test_image_path = 'tiger.png'

In [16]:
def plot_examples(image_path, gen, test_transform):
    gen.eval()
    image = Image.open(image_path)
    with torch.no_grad():
        upscaled_img = gen(test_transform(image=np.asarray(image))["image"].unsqueeze(0).to(device))
    save_image(upscaled_img * 0.5 + 0.5, 'test.png')
    gen.train()

In [17]:
def show_images(image_path, gen, test_transform):
    gen.eval()
    image = Image.open(image_path)
    with torch.no_grad():
        upscaled_img = gen(test_transform(image = np.asarray(image))["image"].unsqueeze(0).to(device))
        upscaled_img = upscaled_img.squeeze(0).cpu()
        upscaled_img = upscaled_img.permute(1, 2, 0).numpy()
        upscaled_img = np.clip(upscaled_img, 0, 1)
    gen.train()
    fig, ax = plt.subplots(1, 2, figsize=(12, 6))
    ax[0].imshow(image)
    ax[0].set_title("Real Image")
    ax[0].axis('off')

    ax[1].imshow(upscaled_img)
    ax[1].set_title("Generated Image")
    ax[1].axis('off')
    plt.show()

In [18]:
def train(loader, gen, disc, gen_optimizer, disc_optimizer, vggloss, advloss, device):
    for epoch in range(epochs):
        for batch_idx, (low_res_img, high_res_img) in enumerate(tqdm(loader)):
            low_res_img, high_res_img = low_res_img.to(device), high_res_img.to(device)
            #disc training
            fake = gen(low_res_img)
            disc_real = disc(high_res_img)
            disc_fake = disc(fake.detach())
            disc_loss_real = advloss(disc_real, torch.ones_like(disc_real) - 0.1 * torch.rand_like(disc_real))

            disc_loss_fake = advloss(disc_fake, torch.zeros_like(disc_fake))
            disc_loss = disc_loss_fake + disc_loss_real

            disc_optimizer.zero_grad()
            disc_loss.backward()
            disc_optimizer.step()

        #train gen
            disc_fake = disc(fake)
            adv_loss = 1e-3 * advloss(disc_fake, torch.ones_like(disc_fake))
            vgg_loss = 0.006 * vggloss(fake, high_res_img)
            gen_loss = adv_loss + vgg_loss

            gen_optimizer.zero_grad()
            gen_loss.backward()
            gen_optimizer.step()
        print(f'Epoch: {epoch}')
        print(f'genloss: {gen_loss} \t discloss: {disc_loss}')
        plot_examples(test_image_path, gen, test_transform)

In [19]:
train(loader, gen, disc, gen_optimizer, disc_optimizer, vggloss, advloss, device)

100%|██████████████████████████████████████████████████████████████████████████████████| 50/50 [01:39<00:00,  1.98s/it]


Epoch: 0
genloss: 0.007854166440665722 	 discloss: 0.32232749462127686


  return F.conv2d(input, weight, bias, self.stride,
100%|██████████████████████████████████████████████████████████████████████████████████| 50/50 [01:31<00:00,  1.83s/it]


Epoch: 1
genloss: 0.005060066934674978 	 discloss: 0.6687721610069275


100%|██████████████████████████████████████████████████████████████████████████████████| 50/50 [01:31<00:00,  1.83s/it]


Epoch: 2
genloss: 0.003890642896294594 	 discloss: 0.7465443015098572


KeyboardInterrupt: 