In [None]:
from PIL import Image
from torch.utils import data
import glob
import torchvision.transforms as transforms
import numpy as np

import torch
from torch import nn, optim
from torchvision.models import vgg19
from torch.autograd import Variable
from torch.utils.data import DataLoader

from PIL import Image
import matplotlib.pyplot as plt

## Download and extract the CelebA dataset ([source](https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html))

In [None]:
# Install gdown (a Google Drive downloader) and fetch the CelebA archive from Drive.
# The downloaded file is presumably named celeb_dataset.zip, which is the archive the
# next cell unzips. Verify this if the Drive file changes.
# NOTE(review): unpinned install; pin a gdown version for reproducible runs.
!pip install gdown
!gdown 'https://drive.google.com/uc?id=13IvLXTmWc4hj4Tx4YpBW1LesN30UZusj'

In [None]:
# Extract the archive quietly (-q), overwriting existing files without prompting (-o).
# Presumably this produces img_align_celeba/, the folder Config.data_folder points at.
!unzip -oq celeb_dataset.zip

## Set up dataset loading for PyTorch

In [None]:
class Dataset(data.Dataset):
    """Image folder split 80/20 into train/val that yields (LR, HR) tensor pairs.

    Every item is one image file opened with PIL and passed through two transform
    pipelines: ``transform_lr`` builds the low-resolution network input and
    ``transform_hr`` the high-resolution target.

    Args:
        data_folder: directory whose entries (``<data_folder>/*``) are image files.
        transform_lr: list of torchvision transforms, composed for the LR image.
        transform_hr: list of torchvision transforms, composed for the HR image.
        stage: ``'train'`` selects the first 80% of the (sorted) files; any other
            value, e.g. ``'val'``, selects the remaining 20%.
        max_samples: upper bound on ``len(self)``, so training fits in a notebook
            session. ``None`` uses every file in the split.
    """

    def __init__(self, data_folder, transform_lr, transform_hr, stage, max_samples=12800):
        # Sort because glob order is filesystem-dependent; an unsorted list would give a
        # different train/val split on different machines or runs.
        file_list = sorted(glob.glob('{}/*'.format(data_folder)))
        train_size = int(len(file_list) * 0.8)
        self.images = file_list[:train_size] if stage == 'train' else file_list[train_size:]
        self.max_samples = max_samples

        self.transform_lr = transforms.Compose(transform_lr)
        self.transform_hr = transforms.Compose(transform_hr)

    def __len__(self):
        """Number of samples: the split size, capped at ``max_samples``."""
        if self.max_samples is None:
            return len(self.images)
        # Never report more samples than exist, or __getitem__ would raise IndexError.
        return min(len(self.images), self.max_samples)

    def __getitem__(self, index):
        """Return ``(lr_tensor, hr_tensor)`` built from the image at ``index``."""
        # Load eagerly inside ``with`` so the file handle is closed right away.
        # convert('RGB') also guarantees 3 channels, which the 3-channel Normalize
        # transforms used in this notebook require (a grayscale file would fail there).
        with Image.open(self.images[index]) as img:
            hr = img.convert('RGB')
        return self.transform_lr(hr), self.transform_hr(hr)

## SRGAN Model Definition

![Architecture](https://production-media.paperswithcode.com/methods/Screen_Shot_2020-07-19_at_11.13.45_AM_zsF2pa7.png)

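Source: Ledig et al., *Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network* ([arXiv:1609.04802](https://arxiv.org/abs/1609.04802v5))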

In [None]:
def weights_init_normal(m):
    """Re-initialise one module's parameters in place, for use with ``model.apply``.

    * Class name containing ``'Conv'``: weight drawn from N(0, 0.02); bias untouched.
    * Class name containing ``'BatchNorm2d'``: weight drawn from N(1, 0.02), bias set to 0.
    * Anything else (Linear, PReLU, containers, ...): left unchanged.
    """
    layer_type = m.__class__.__name__
    if 'Conv' in layer_type:
        torch.nn.init.normal_(m.weight.data, mean=0.0, std=0.02)
        return
    if 'BatchNorm2d' in layer_type:
        torch.nn.init.normal_(m.weight.data, mean=1.0, std=0.02)
        torch.nn.init.constant_(m.bias.data, val=0.0)

class ResidualBlock(nn.Module):
    """SRGAN residual block computing ``x + BN(Conv(PReLU(BN(Conv(x)))))``.

    Both convolutions map ``n_output`` channels to ``n_output`` channels. The identity
    skip connection therefore only lines up when ``k_size``/``stride``/``padding`` keep
    the spatial size unchanged, as the defaults (3x3, stride 1, padding 1) do.
    """

    def __init__(self, n_output=64, k_size=3, stride=1, padding=1):
        super().__init__()
        conv_args = (n_output, n_output, k_size, stride, padding)
        # Keep the submodule name ``model`` and this layer order: they define the
        # state_dict keys (model.0.weight, model.1.weight, ...).
        self.model = nn.Sequential(
            nn.Conv2d(*conv_args),
            nn.BatchNorm2d(n_output),
            nn.PReLU(),
            nn.Conv2d(*conv_args),
            nn.BatchNorm2d(n_output),
        )

    def forward(self, x):
        residual = self.model(x)
        return x + residual

class ShuffleBlock(nn.Module):
    """Sub-pixel upsampling block: Conv -> PixelShuffle -> PReLU.

    The convolution maps ``n_input`` to ``n_output`` channels. ``nn.PixelShuffle(r)``
    (with ``r = upscale_factor``) then rearranges ``(N, C * r**2, H, W)`` into
    ``(N, C, H * r, W * r)``. The block therefore outputs ``n_output // r**2`` channels
    at ``r`` times the spatial resolution. For example, ``ShuffleBlock(64, 256)`` yields
    64 channels at 2x resolution.

    Args:
        n_input: channels of the incoming feature map.
        n_output: channels produced by the convolution; must be divisible by
            ``upscale_factor ** 2``.
        k_size, stride, padding: convolution hyper-parameters.
        upscale_factor: spatial upscaling factor of the pixel shuffle (default 2).

    Raises:
        ValueError: if ``n_output`` is not divisible by ``upscale_factor ** 2``.
    """

    def __init__(self, n_input, n_output, k_size=3, stride=1, padding=1, upscale_factor=2):
        super(ShuffleBlock, self).__init__()
        if n_output % (upscale_factor ** 2) != 0:
            raise ValueError(
                'n_output ({}) must be divisible by upscale_factor**2 ({})'.format(
                    n_output, upscale_factor ** 2))
        # Submodule indices (model.0 = conv, model.2 = PReLU) are unchanged, so
        # existing state_dicts still load.
        self.model = nn.Sequential(
            nn.Conv2d(n_input, n_output, k_size, stride, padding),  # N, C*r^2, H, W
            nn.PixelShuffle(upscale_factor),                        # N, C, r*H, r*W
            nn.PReLU(),
        )

    def forward(self, x):
        return self.model(x)


# n_fmap = number of feature maps,
# num_residual_blocks = B, the number of cascaded residual blocks.
class Generator(nn.Module):
    """SRGAN generator that upsamples its input 4x (two 2x sub-pixel blocks).

    Pipeline:
      1. Conv 9x9 + PReLU (``l1``).
      2. B residual blocks, then Conv 3x3 + BN (``l2``), added back to the ``l1``
         output (long skip connection).
      3. Two ``ShuffleBlock`` s (``px``), each doubling height and width.
      4. Conv 9x9 + Tanh (``conv_final``), so every output value lies in [-1, 1].

    Args:
        n_input: channels of the low-resolution input image.
        n_output: channels of the generated image.
        n_fmap: number of feature maps used throughout the trunk and upsampling path.
        num_residual_blocks: number of cascaded ``ResidualBlock`` s (B).
    """

    def __init__(self, n_input=3, n_output=3, n_fmap=64, num_residual_blocks=16):
        super(Generator, self).__init__()

        self.l1 = nn.Sequential(
            nn.Conv2d(n_input, n_fmap, 9, 1, 4),
            nn.PReLU(),
        )

        # A cascade of B residual blocks.
        self.residual_blocks = nn.Sequential(
            *[ResidualBlock(n_fmap) for _ in range(num_residual_blocks)]
        )

        self.l2 = nn.Sequential(
            nn.Conv2d(n_fmap, n_fmap, 3, 1, 1),
            nn.BatchNorm2d(n_fmap),
        )

        # Each ShuffleBlock expands to 4 * n_fmap channels, and its 2x pixel shuffle
        # folds them back to n_fmap channels. Deriving the counts from n_fmap (instead
        # of a fixed 64/256) keeps non-default widths working.
        self.px = nn.Sequential(
            ShuffleBlock(n_fmap, n_fmap * 4),
            ShuffleBlock(n_fmap, n_fmap * 4),
        )

        self.conv_final = nn.Sequential(
            nn.Conv2d(n_fmap, n_output, 9, 1, 4),
            nn.Tanh(),
        )

    def forward(self, img_in):
        """Map an ``(N, n_input, H, W)`` batch to ``(N, n_output, 4H, 4W)``."""
        out_1 = self.l1(img_in)
        out_2 = self.residual_blocks(out_1)
        out_3 = out_1 + self.l2(out_2)  # long skip connection
        out_4 = self.px(out_3)
        return self.conv_final(out_4)


class Discriminator(nn.Module):
    """SRGAN discriminator: 8 conv blocks, then a 2-layer MLP with a sigmoid output.

    Four of the conv blocks use stride 2, and each maps a side length ``s`` to
    ``ceil(s / 2)``. The first Linear layer's input size is derived from ``img_size``,
    so the model works for any square input size, not only 256x256.

    Args:
        lr_channels: channels of the images being classified. Despite the name, this
            notebook calls the discriminator on high-resolution (real/generated) images.
        img_size: side length of the square images passed to ``forward``. The default
            256 reproduces the original ``512 * 16 * 16`` Linear input size.
    """

    def __init__(self, lr_channels=3, img_size=256):
        super(Discriminator, self).__init__()

        def convblock(n_input, n_output, k_size=3, stride=1, padding=1, bn=True):
            """Conv (no bias) -> optional BatchNorm -> LeakyReLU(0.2), as a list of layers."""
            block = [nn.Conv2d(n_input, n_output, kernel_size=k_size, stride=stride, padding=padding, bias=False)]
            if bn:
                block.append(nn.BatchNorm2d(n_output))
            block.append(nn.LeakyReLU(0.2, inplace=True))
            return block

        self.conv = nn.Sequential(
            *convblock(lr_channels, 64, 3, 1, 1, bn=False),
            *convblock(64, 64, 3, 2, 1),
            *convblock(64, 128, 3, 1, 1),
            *convblock(128, 128, 3, 2, 1),
            *convblock(128, 256, 3, 1, 1),
            *convblock(256, 256, 3, 2, 1),
            *convblock(256, 512, 3, 1, 1),
            *convblock(512, 512, 3, 2, 1),
        )

        # Spatial side after the four stride-2 blocks (k=3, p=1): s -> ceil(s / 2) each.
        feat_size = img_size
        for _ in range(4):
            feat_size = (feat_size + 1) // 2

        self.fc = nn.Sequential(
            nn.Linear(512 * feat_size * feat_size, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, 1),
            nn.Sigmoid()
        )

    def forward(self, img):
        """Return an ``(N, 1)`` tensor of sigmoid probabilities that ``img`` is real."""
        features = self.conv(img)
        features = torch.flatten(features, 1)
        return self.fc(features)


class VGGFeatures(nn.Module):
    """Frozen VGG19 feature extractor used for the perceptual (content) loss.

    Keeps ``vgg19(pretrained=True).features`` up to, but not including, its fifth
    ``MaxPool2d`` (``max_pool_indices[4]``, 0-based). For torchvision's VGG19 this
    should end on the activation after the last conv of block 5. That is intended as
    the SRGAN paper's "VGG5,4" feature map. All parameters are frozen.

    No input normalisation is applied here. Callers must pass tensors in the space the
    pretrained VGG expects. This notebook's data pipeline applies the ImageNet mean/std
    ``Normalize`` to the HR images. NOTE(review): the generator's Tanh output is not in
    that space, so confirm the two inputs of the content loss are comparable.
    """

    def __init__(self):
        super(VGGFeatures, self).__init__()
        # NOTE(review): ``pretrained=True`` is deprecated since torchvision 0.13 in favour
        # of ``weights=VGG19_Weights.IMAGENET1K_V1``; kept for older torchvision installs.
        model = vgg19(pretrained=True)

        children = list(model.features.children())
        max_pool_indices = [index for index, m in enumerate(children) if isinstance(m, nn.MaxPool2d)]
        # Slice stops just before the fifth max-pool layer.
        self.features = nn.Sequential(*children[:max_pool_indices[4]])
        self.features.requires_grad_(False)

    def forward(self, input):
        """Return the frozen VGG feature map of ``input``."""
        return self.features(input)

## Set parameters

In [None]:
class Config:
    """Hyper-parameters and paths shared by the training and evaluation cells."""

    def __init__(self) -> None:
        # Training schedule.
        self.n_epochs = 10
        self.batch_size = 16
        # Adam optimiser: learning rate and (beta1, beta2).
        self.lr = 0.0001
        self.b1 = 0.9
        self.b2 = 0.999
        # Data: side length of the square HR images, and the image folder.
        self.img_size = 256
        self.data_folder = 'img_align_celeba'

In [None]:
opt = Config()

# Seed torch so weight initialisation and DataLoader shuffling are reproducible.
SEED = 0
torch.manual_seed(SEED)

# ImageNet statistics, matching the pretrained VGG used for the content loss.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def _resize_normalize(size):
    """Return the transform list: bicubic resize to ``size`` x ``size`` -> tensor -> normalise."""
    return [
        transforms.Resize((size, size), Image.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ]


transform_image_hr = _resize_normalize(opt.img_size)
transform_image_lr = _resize_normalize(opt.img_size // 4)  # 4x super-resolution input
data_train = Dataset(opt.data_folder, transform_image_lr, transform_image_hr, stage='train')
data_val = Dataset(opt.data_folder, transform_image_lr, transform_image_hr, stage='val')

# Train on the training split. Loading data_val here would train on the very images
# later shown as "validation" results.
dataloader_train = DataLoader(data_train, batch_size=opt.batch_size, shuffle=True)
dataloader_val = DataLoader(data_val, batch_size=5, shuffle=True)

cuda = torch.cuda.is_available()
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

# Loss
gan_loss = nn.BCELoss()
content_loss = nn.MSELoss()

generator = Generator()
discriminator = Discriminator()
vgg = VGGFeatures()

if cuda:
    generator = generator.cuda()
    discriminator = discriminator.cuda()
    vgg = vgg.cuda()
    gan_loss = gan_loss.cuda()
    content_loss = content_loss.cuda()

# Network weight init
generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal)

# Build the optimisers only after the models are on their final device, as the
# torch.optim documentation recommends.
optimizer_D = optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_G = optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

# Loss record.
g_losses = []
d_losses = []
epochs = []
loss_legend = ['Discriminator', 'Generator']

from tqdm import tqdm

# Weights of the two generator loss terms. The 1/12.75 VGG rescaling and the 1e-3
# adversarial weight follow the values reported in the SRGAN paper.
CONTENT_WEIGHT = 1 / 12.75
ADVERSARIAL_WEIGHT = 1e-3

for epoch in range(opt.n_epochs):
    # Force training mode so BatchNorm uses batch statistics, even if an evaluation
    # cell left the models in eval mode before this cell is re-run.
    generator.train()
    discriminator.train()
    for i, (batch_lr, batch_hr) in tqdm(enumerate(dataloader_train), total=len(dataloader_train), desc='Step'):
        batch_size = batch_lr.size(0)
        real = Tensor(batch_size, 1).fill_(1.0)
        fake = Tensor(batch_size, 1).fill_(0.0)

        imgs_real_lr = batch_lr.type(Tensor)
        imgs_real_hr = batch_hr.type(Tensor)

        # One generator forward pass is shared by both updates. The generator is not
        # modified by the discriminator step, so its output would be the same anyway.
        imgs_fake_hr = generator(imgs_real_lr)

        # == Discriminator update == #
        optimizer_D.zero_grad()
        # .detach() stops the discriminator loss from back-propagating into the
        # generator. Wrapping in ``Variable`` does NOT detach in PyTorch >= 0.4.
        d_loss = (gan_loss(discriminator(imgs_real_hr), real)
                  + gan_loss(discriminator(imgs_fake_hr.detach()), fake))
        d_loss.backward()
        optimizer_D.step()

        # == Generator update == #
        optimizer_G.zero_grad()
        g_loss = (CONTENT_WEIGHT * content_loss(vgg(imgs_fake_hr), vgg(imgs_real_hr))
                  + ADVERSARIAL_WEIGHT * gan_loss(discriminator(imgs_fake_hr), real))
        g_loss.backward()
        optimizer_G.step()

        epochs.append(epoch + i / len(dataloader_train))
        g_losses.append(g_loss.item())
        d_losses.append(d_loss.item())

In [None]:
def _segment_means(values, n_points=10):
    """Average ``values`` over ``n_points`` consecutive, near-equal segments.

    Turns the noisy per-step loss log into a short, smoothed curve for plotting.
    Segment ``k`` starts at index ``floor(k * len(values) / n_points)``. Every segment
    is non-empty as long as ``values`` is, so fewer than ``n_points`` values still work.

    Raises:
        ValueError: if ``values`` is empty (e.g. the training cell has not been run).
    """
    if not values:
        raise ValueError('no losses recorded; run the training cell first')
    bounds = np.linspace(0, len(values), n_points + 1).astype(int)
    return [float(np.mean(values[start:max(stop, start + 1)]))
            for start, stop in zip(bounds[:-1], bounds[1:])]


g_losses_2 = _segment_means(g_losses)
d_losses_2 = _segment_means(d_losses)

In [None]:
# Place each point at the (fractional) epoch where its slice of the loss log starts.
# That way the x-axis really is in epochs, whatever n_epochs and the number of points.
n_loss_points = len(g_losses_2)
loss_x = [epochs[int(k * len(epochs) / n_loss_points)] for k in range(n_loss_points)]

fig, ax = plt.subplots()
ax.plot(loss_x, g_losses_2, marker='o', label='Generator Loss')
ax.plot(loss_x, d_losses_2, marker='o', color='green', label='Discriminator Loss')
ax.set_title('Generator & Discriminator Loss')
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
ax.legend(loc='best')
ax.grid(True)
plt.show()

In [None]:
# Generate 5 images from the validation set.
batch_lr, batch_hr = next(iter(dataloader_val))

imgs_val_lr = batch_lr.type(Tensor)
imgs_val_hr = batch_hr.type(Tensor)

# Inference: use BatchNorm running statistics and skip autograd bookkeeping. Then
# restore the previous mode so re-running the training cell still trains properly.
was_training = generator.training
generator.eval()
with torch.no_grad():
    imgs_fake_hr = generator(imgs_val_lr)
generator.train(was_training)

# For visualization purposes: bring the LR input up to the generated image size.
imgs_val_lr = torch.nn.functional.interpolate(
    imgs_val_lr, size=(imgs_fake_hr.size(2), imgs_fake_hr.size(3)),
    mode='bilinear', align_corners=False)

# Undo the ImageNet Normalize of the dataset transforms (x * std + mean), then clamp
# to the [0, 1] range imshow expects for float RGB images.
imagenet_std = Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
imagenet_mean = Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
imgs_val_lr = (imgs_val_lr * imagenet_std + imagenet_mean).clamp(0, 1)
imgs_val_hr = (imgs_val_hr * imagenet_std + imagenet_mean).clamp(0, 1)

# Min-max rescale the generator output (Tanh, in [-1, 1]) to [0, 1] over the batch.
# Subtract the minimum: adding |min| is only equivalent when min < 0.
fake_min, fake_max = imgs_fake_hr.min(), imgs_fake_hr.max()
imgs_fake_hr = (imgs_fake_hr - fake_min) / (fake_max - fake_min).clamp(min=1e-8)
fake_val = torch.cat((imgs_val_lr, imgs_val_hr, imgs_fake_hr), dim=2)


In [None]:
# One row per validation sample: upsampled LR input, ground-truth HR, SRGAN output.
columns = (
    (imgs_val_lr, 'Original LR Image'),
    (imgs_val_hr, 'Original HR Image'),
    (imgs_fake_hr, 'SRGAN Generated Image'),
)
# squeeze=False keeps ``ax`` 2-D even when the batch holds a single image.
fig, ax = plt.subplots(nrows=len(batch_lr), ncols=len(columns), figsize=(20, 35), squeeze=False)
for row in range(len(batch_lr)):
    for col, (images, title) in enumerate(columns):
        panel = ax[row, col]
        panel.imshow(images[row].detach().cpu().permute(1, 2, 0))
        panel.set_xticks([])
        panel.set_yticks([])
        panel.set_title(title)


In [None]:
# Export the trained generator to ONNX with a fixed 1 x 3 x (img_size/4) x (img_size/4)
# LR-sized dummy input. The dummy input is created on the generator's own device,
# so the export also works on CPU-only machines.
export_device = next(generator.parameters()).device
x = torch.randn(1, 3, opt.img_size // 4, opt.img_size // 4, dtype=torch.float32, device=export_device)
torch.onnx.export(generator,               # model being run
                  x,                         # model input (or a tuple for multiple inputs)
                  "super_resolution.onnx",   # where to save the model (can be a file or file-like object)
                  export_params=True,        # store the trained parameter weights inside the model file
                  opset_version=10,          # the ONNX version to export the model to
                  do_constant_folding=True,  # whether to execute constant folding for optimization
                  input_names=['l1'],        # the model's input names
                  output_names=['conv_final'],  # the model's output names
                  # For a variable batch size, add the argument below. Its keys must
                  # match input_names / output_names:
                  # dynamic_axes={'l1': {0: 'batch_size'}, 'conv_final': {0: 'batch_size'}}
                  )