In [1]:
! pip install  --quiet "lightning-bolts" "torchvision" "torchmetrics"

ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
datascience 0.10.6 requires folium==0.2.1, but you have folium 0.8.3 which is incompati

In [9]:
from argparse import ArgumentParser
from pathlib import Path
from typing import Any, List, Optional, Tuple
from warnings import warn

import pytorch_lightning as pl
from pytorch_lightning import Trainer
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import vgg19

from pl_bolts.callbacks import SRImageLoggerCallback
from pl_bolts.datamodules import TVTDataModule
from pl_bolts.datasets.utils import prepare_sr_datasets

from pl_bolts.models.gans import SRGAN
from pl_bolts.models.gans.srgan.components import SRGANDiscriminator, SRGANGenerator, VGG19FeatureExtractor

# Super-resolution GAN (SRGAN)
Credit: https://github.com/https-deeplearning-ai/GANs-Public  

*Please note that this is meant to introduce more advanced concepts. If you’re up for a challenge, take a look and don’t worry if you can’t follow everything. There is no code to implement—only some cool code for you to learn and run!*

You should already be familiar with:
 - Residual blocks, from [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385) (He et al. 2015)
 - Perceptual loss, from [Perceptual Losses for Real-Time Style Transfer and Super-Resolution](https://arxiv.org/abs/1603.08155) (Johnson et al. 2016)
 - VGG architecture, from [Very Deep Convolutional Networks for Large-Scale Image Recognition](https://arxiv.org/abs/1409.1556) (Simonyan and Zisserman 2015)

### Goals

In this notebook, you will learn about Super-Resolution GAN (SRGAN), a GAN that enhances the resolution of images by 4x, proposed in [Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network](https://arxiv.org/abs/1609.04802) (Ledig et al. 2017). You will also walk through the full architecture and training code and be able to train it on one of the datasets configured later in the notebook (CelebA, MNIST, or STL10).

### Background

The authors first train a super-resolution residual network (SRResNet) with standard pixel-wise loss that achieves state-of-the-art metrics. They then insert this as the generator in the SRGAN framework, which is trained with a combination of pixel-wise, perceptual, and adversarial losses.

## SRGAN Submodules

Before jumping into SRGAN, let's first take a look at some components that will be useful later.  

### Parametric ReLU (PReLU)

As you already know, ReLU is one of the simplest activation functions that can be described as

\begin{align*}
    x_{\text{ReLU}} := \max(0, x),
\end{align*}

where negative values of $x$ are thresholded at $0$. However, this blocks the gradient for negative inputs, which can hinder training. The authors of [Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification](https://arxiv.org/abs/1502.01852) (He et al. 2015) addressed this with a more general ReLU that scales negative values by a learnable coefficient $a$ (typically a small positive value):

\begin{align*}
    x_{\text{PReLU}} := \max(0, x) + a \cdot \min(0, x).
\end{align*}

Conveniently, this is implemented in PyTorch as [torch.nn.PReLU](https://pytorch.org/docs/stable/generated/torch.nn.PReLU.html).
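
As a quick sanity check of the formula above, the sketch below compares a hand-written PReLU against `torch.nn.PReLU`. The helper `prelu_manual` is a hypothetical name introduced here for illustration. It relies on `nn.PReLU()` defaulting to a single learnable slope initialised to 0.25.

```python
import torch
import torch.nn as nn


def prelu_manual(x: torch.Tensor, a: float) -> torch.Tensor:
    """PReLU written directly from the formula: max(0, x) + a * min(0, x)."""
    zero = torch.zeros_like(x)
    return torch.maximum(zero, x) + a * torch.minimum(zero, x)


prelu = nn.PReLU()  # one learnable slope shared by all channels, initialised to 0.25
x = torch.tensor([-2.0, -0.5, 0.0, 1.5])

with torch.no_grad():
    y_module = prelu(x)
y_manual = prelu_manual(x, a=0.25)

# Negative inputs are scaled by 0.25, non-negative inputs pass through unchanged.
print(y_module)
print(torch.allclose(y_module, y_manual))  # True
```

During training the slope is updated by the optimizer like any other weight, so the value 0.25 only describes the starting point.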

### Residual Blocks

The residual block, which appears in many state-of-the-art computer vision models, is used in all parts of SRGAN and is similar to the ones used in Pix2PixHD (see optional notebook). If you're not familiar with residual blocks, please take a look [here](https://paperswithcode.com/method/residual-block). You'll start by implementing a basic residual block.

In [3]:
class ResidualBlock(nn.Module):
    def __init__(self, feature_maps: int = 64) -> None:
        super().__init__()

        self.block = nn.Sequential(
            nn.Conv2d(feature_maps, feature_maps, kernel_size=3, padding=1),
            nn.BatchNorm2d(feature_maps),
            nn.PReLU(),
            nn.Conv2d(feature_maps, feature_maps, kernel_size=3, padding=1),
            nn.BatchNorm2d(feature_maps),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.block(x)
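
The skip connection `x + self.block(x)` only works if the block leaves the tensor shape unchanged. The sketch below rebuilds the same layer stack outside the class (the input size of 24x24 and the batch size are arbitrary choices for illustration) and checks that 3x3 convolutions with `padding=1` preserve height and width.

```python
import torch
import torch.nn as nn

feature_maps = 64

# Same layer stack as in the residual block: conv -> BN -> PReLU -> conv -> BN.
block = nn.Sequential(
    nn.Conv2d(feature_maps, feature_maps, kernel_size=3, padding=1),
    nn.BatchNorm2d(feature_maps),
    nn.PReLU(),
    nn.Conv2d(feature_maps, feature_maps, kernel_size=3, padding=1),
    nn.BatchNorm2d(feature_maps),
)

x = torch.randn(2, feature_maps, 24, 24)  # arbitrary batch of feature maps
out = x + block(x)  # the addition requires identical shapes on both sides

print(out.shape)  # torch.Size([2, 64, 24, 24])
```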

### PixelShuffle

Proposed in [Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network](https://arxiv.org/abs/1609.05158) (Shi et al. 2016), PixelShuffle, also called sub-pixel convolution, is another way to upsample an image.

PixelShuffle simply reshapes a $r^2C \times H \times W$ tensor into a $C \times rH \times rW$ tensor, essentially trading channel information for spatial information. Instead of convolving with stride $1/r$ as in deconvolution, the authors think about the weights in the kernel as being spaced $1/r$ pixels apart. When sliding this kernel over an input, the weights that fall between pixels aren't activated and don't need to be calculated. The total number of activation patterns is thus increased by a factor of $r^2$. This operation is illustrated in the figure below.

Don't worry if this is confusing! The algorithm is conveniently implemented as `torch.nn.PixelShuffle` in PyTorch, so as long as you have a general idea of how this works, you're set.

> ![Efficient Sub-pixel CNN](https://github.com/https-deeplearning-ai/GANs-Public/blob/master/SRGAN-PixelShuffle.png?raw=true)
*Efficient sub-pixel CNN, taken from Figure 1 of [Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network](https://arxiv.org/abs/1609.05158) (Shi et al. 2016). The PixelShuffle operation (also known as sub-pixel convolution) is shown as the last step on the right.*
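
To make the reshaping concrete, here is a minimal sketch that applies `torch.nn.PixelShuffle` to a small tensor and reproduces the result with an explicit reshape and permute. The sizes (`r = 2`, `C = 3`, `H = W = 4`) are arbitrary values chosen for illustration.

```python
import torch
import torch.nn as nn

r, C, H, W = 2, 3, 4, 4

# Input with r^2 * C channels; arange makes every element easy to track.
x = torch.arange(r * r * C * H * W, dtype=torch.float32).reshape(1, r * r * C, H, W)
y = nn.PixelShuffle(r)(x)
print(x.shape, "->", y.shape)  # torch.Size([1, 12, 4, 4]) -> torch.Size([1, 3, 8, 8])

# Equivalent manual version: split the channel axis into (C, r, r), then
# interleave the two r-sized axes with the height and width axes.
manual = (
    x.reshape(1, C, r, r, H, W)
    .permute(0, 1, 4, 2, 5, 3)  # -> (N, C, H, r, W, r)
    .reshape(1, C, H * r, W * r)
)
print(torch.equal(y, manual))  # True
```

No values are computed or discarded here. Each group of $r^2$ channels is rearranged into an $r \times r$ patch of output pixels.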

## SRGAN Parts

Now that you've learned about the various SRGAN submodules, you can use them to build the generator and discriminator!

### Generator (SRResNet)

The super-resolution residual network (SRResNet) and the generator are the same thing. The generator network architecture is actually quite simple - just a bunch of convolutional layers, residual blocks, and pixel shuffling layers!

> ![SRGAN Generator](https://github.com/https-deeplearning-ai/GANs-Public/blob/master/SRGAN-Generator.png?raw=true)
*SRGAN Generator, taken from Figure 4 of [Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network](https://arxiv.org/abs/1609.04802) (Ledig et al. 2017).*

In [4]:
class SRGANGenerator(nn.Module):
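    '''
    Generator Class
    Values:
        image_channels: number of channels in the input and output images, a scalar
        feature_maps: number of feature maps used throughout the generator, a scalar
        num_res_blocks: number of residual blocks, a scalar
        num_ps_blocks: number of PixelShuffle blocks (each upsamples by 2x), a scalar
    '''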

    def __init__(
        self,
        image_channels: int,
        feature_maps: int = 64,
        num_res_blocks: int = 16,
        num_ps_blocks: int = 2,
    ) -> None:
        super().__init__()
        # Input block 
        self.input_block = nn.Sequential(
            nn.Conv2d(image_channels, feature_maps, kernel_size=9, padding=4),
            nn.PReLU(),
        )

        # B residual blocks 
        res_blocks = []
        for _ in range(num_res_blocks):
            res_blocks += [ResidualBlock(feature_maps)]

        # k3n64s1
        res_blocks += [
            nn.Conv2d(feature_maps, feature_maps, kernel_size=3, padding=1),
            nn.BatchNorm2d(feature_maps),
        ]
        self.res_blocks = nn.Sequential(*res_blocks)

        # PixelShuffle blocks
        ps_blocks = []
        for _ in range(num_ps_blocks):
            ps_blocks += [
                nn.Conv2d(feature_maps, 4 * feature_maps, kernel_size=3, padding=1),
                nn.PixelShuffle(2),
                nn.PReLU(),
            ]
        self.ps_blocks = nn.Sequential(*ps_blocks)

        # Output block 
        self.output_block = nn.Sequential(
            nn.Conv2d(feature_maps, image_channels, kernel_size=9, padding=4),
            nn.Tanh(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_res = self.input_block(x)
        x = x_res + self.res_blocks(x_res)
        x = self.ps_blocks(x)
        x = self.output_block(x)
        return x
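
Each PixelShuffle stage in the generator doubles height and width, so the overall upscaling factor is `2 ** num_ps_blocks`. The sketch below rebuilds only the upsampling stack to show this. `make_ps_blocks` is a hypothetical helper introduced for illustration and the 8x8 input is an arbitrary choice.

```python
import torch
import torch.nn as nn


def make_ps_blocks(feature_maps: int, num_ps_blocks: int) -> nn.Sequential:
    """Stack of conv -> PixelShuffle(2) -> PReLU stages, mirroring the generator."""
    layers = []
    for _ in range(num_ps_blocks):
        layers += [
            # The conv expands to 4*F channels so PixelShuffle(2) can fold them
            # back into F channels at twice the spatial resolution.
            nn.Conv2d(feature_maps, 4 * feature_maps, kernel_size=3, padding=1),
            nn.PixelShuffle(2),
            nn.PReLU(),
        ]
    return nn.Sequential(*layers)


x = torch.randn(1, 64, 8, 8)
with torch.no_grad():
    shapes = {n: tuple(make_ps_blocks(64, n)(x).shape) for n in (1, 2)}

print(shapes)  # {1: (1, 64, 16, 16), 2: (1, 64, 32, 32)}
```

With the default `num_ps_blocks=2`, an input of size $H \times W$ therefore leaves the generator at $4H \times 4W$.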

### Discriminator

The discriminator architecture is also relatively straightforward, just one big sequential model - see the diagram below for reference!

> ![SRGAN Discriminator](https://github.com/https-deeplearning-ai/GANs-Public/blob/master/SRGAN-Discriminator.png?raw=true)
*SRGAN Discriminator, taken from Figure 4 of [Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network](https://arxiv.org/abs/1609.04802) (Ledig et al. 2017).*

In [5]:
class SRGANDiscriminator(nn.Module):
    def __init__(self, image_channels: int, feature_maps: int = 64) -> None:
        super().__init__()

        self.conv_blocks = nn.Sequential(
            self._make_double_conv_block(image_channels, feature_maps, first_batch_norm=False),
            self._make_double_conv_block(feature_maps, feature_maps * 2),
            self._make_double_conv_block(feature_maps * 2, feature_maps * 4),
            self._make_double_conv_block(feature_maps * 4, feature_maps * 8),
        )

        self.mlp = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(feature_maps * 8, feature_maps * 16, kernel_size=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(feature_maps * 16, 1, kernel_size=1),
            nn.Flatten(),
        )

    def _make_double_conv_block(
        self,
        in_channels: int,
        out_channels: int,
        first_batch_norm: bool = True,
    ) -> nn.Sequential:
        return nn.Sequential(
            self._make_conv_block(in_channels, out_channels, batch_norm=first_batch_norm),
            self._make_conv_block(out_channels, out_channels, stride=2),
        )

    @staticmethod
    def _make_conv_block(
        in_channels: int,
        out_channels: int,
        stride: int = 1,
        batch_norm: bool = True,
    ) -> nn.Sequential:
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1),
            nn.BatchNorm2d(out_channels) if batch_norm else nn.Identity(),
            nn.LeakyReLU(0.2, inplace=True),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv_blocks(x)
        x = self.mlp(x)
        return x
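
The spatial bookkeeping of the discriminator can be followed with the standard convolution output-size formula. The sketch below is plain arithmetic, not part of the model. The function names and the example input sizes are assumptions made for illustration.

```python
def conv_out_size(size: int, kernel: int = 3, stride: int = 1, padding: int = 1) -> int:
    """Standard convolution output size: floor((size + 2p - k) / s) + 1."""
    return (size + 2 * padding - kernel) // stride + 1


def discriminator_feature_size(size: int, num_double_blocks: int = 4) -> int:
    """Spatial size after the four double conv blocks of the discriminator."""
    for _ in range(num_double_blocks):
        size = conv_out_size(size, stride=1)  # first conv keeps the size
        size = conv_out_size(size, stride=2)  # second conv halves it
    return size


sizes = {s: discriminator_feature_size(s) for s in (32, 96, 128)}
print(sizes)  # {32: 2, 96: 6, 128: 8}
```

Whatever spatial size remains, `nn.AdaptiveAvgPool2d(1)` collapses it to 1x1, so the two 1x1 convolutions act like a small MLP and the discriminator returns one logit per image regardless of input resolution.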


In [6]:
class VGG19FeatureExtractor(nn.Module):
    def __init__(self, image_channels: int = 3) -> None:
        super().__init__()

        assert image_channels in [1, 3]
        self.image_channels = image_channels

        vgg = vgg19(pretrained=True)
        self.vgg = nn.Sequential(*list(vgg.features)[:-1]).eval()
        for p in self.vgg.parameters():
            p.requires_grad = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.image_channels == 1:
            x = x.repeat(1, 3, 1, 1)

        return self.vgg(x)

## SRGAN

In [7]:
class SRGAN(pl.LightningModule):
  def __init__(
      self,
      image_channels: int = 3,
      feature_maps_gen: int = 64,
      feature_maps_disc: int = 64,
      num_res_blocks: int = 16,
      scale_factor: int = 4,
      generator_checkpoint: Optional[str] = None,
      learning_rate: float = 1e-4,
      scheduler_step: int = 100,
      **kwargs: Any,
  ) -> None:
      """
      Args:
          image_channels: Number of channels of the images from the dataset
          feature_maps_gen: Number of feature maps to use for the generator
          feature_maps_disc: Number of feature maps to use for the discriminator
          num_res_blocks: Number of res blocks to use in the generator
          scale_factor: Scale factor for the images (either 2 or 4)
          generator_checkpoint: Generator checkpoint created with SRResNet module
          learning_rate: Learning rate
          scheduler_step: Number of epochs after which the learning rate gets decayed
      """
      super().__init__()
      self.save_hyperparameters()

      if generator_checkpoint:
          self.generator = torch.load(generator_checkpoint)
      else:
          assert scale_factor in [2, 4]
          num_ps_blocks = scale_factor // 2
          self.generator = SRGANGenerator(image_channels, feature_maps_gen, num_res_blocks, num_ps_blocks)

      self.discriminator = SRGANDiscriminator(image_channels, feature_maps_disc)
      self.vgg_feature_extractor = VGG19FeatureExtractor(image_channels)

  def configure_optimizers(self) -> Tuple[List[torch.optim.Adam], List[torch.optim.lr_scheduler.MultiStepLR]]:
      opt_disc = torch.optim.Adam(self.discriminator.parameters(), lr=self.hparams.learning_rate)
      opt_gen = torch.optim.Adam(self.generator.parameters(), lr=self.hparams.learning_rate)

      sched_disc = torch.optim.lr_scheduler.MultiStepLR(opt_disc, milestones=[self.hparams.scheduler_step], gamma=0.1)
      sched_gen = torch.optim.lr_scheduler.MultiStepLR(opt_gen, milestones=[self.hparams.scheduler_step], gamma=0.1)
      return [opt_disc, opt_gen], [sched_disc, sched_gen]

  def forward(self, lr_image: torch.Tensor) -> torch.Tensor:
      """Generates a high resolution image given a low resolution image.
      Example::
          srgan = SRGAN.load_from_checkpoint(PATH)
          hr_image = srgan(lr_image)
      """
      return self.generator(lr_image)
    
  def training_step(
      self,
      batch: Tuple[torch.Tensor, torch.Tensor],
      batch_idx: int,
      optimizer_idx: int,
  ) -> torch.Tensor:
      hr_image, lr_image = batch

      # Train discriminator
      result = None
      if optimizer_idx == 0:
          result = self._disc_step(hr_image, lr_image)

      # Train generator
      if optimizer_idx == 1:
          result = self._gen_step(hr_image, lr_image)

      return result

  def _disc_step(self, hr_image: torch.Tensor, lr_image: torch.Tensor) -> torch.Tensor:
      disc_loss = self._disc_loss(hr_image, lr_image)
      self.log("loss/disc", disc_loss, on_step=True, on_epoch=True)
      return disc_loss

  def _gen_step(self, hr_image: torch.Tensor, lr_image: torch.Tensor) -> torch.Tensor:
      gen_loss = self._gen_loss(hr_image, lr_image)
      self.log("loss/gen", gen_loss, on_step=True, on_epoch=True)
      return gen_loss

  def _disc_loss(self, hr_image: torch.Tensor, lr_image: torch.Tensor) -> torch.Tensor:
      real_pred = self.discriminator(hr_image)
      real_loss = self._adv_loss(real_pred, ones=True)

      _, fake_pred = self._fake_pred(lr_image)
      fake_loss = self._adv_loss(fake_pred, ones=False)

      disc_loss = 0.5 * (real_loss + fake_loss)

      return disc_loss

  def _gen_loss(self, hr_image: torch.Tensor, lr_image: torch.Tensor) -> torch.Tensor:
      fake, fake_pred = self._fake_pred(lr_image)

      perceptual_loss = self._perceptual_loss(hr_image, fake)
      adv_loss = self._adv_loss(fake_pred, ones=True)
      content_loss = self._content_loss(hr_image, fake)

      gen_loss = 0.006 * perceptual_loss + 0.001 * adv_loss + content_loss

      return gen_loss

  def _fake_pred(self, lr_image: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
      fake = self(lr_image)
      fake_pred = self.discriminator(fake)
      return fake, fake_pred

  @staticmethod
  def _adv_loss(pred: torch.Tensor, ones: bool) -> torch.Tensor:
      target = torch.ones_like(pred) if ones else torch.zeros_like(pred)
      adv_loss = F.binary_cross_entropy_with_logits(pred, target)
      return adv_loss

  def _perceptual_loss(self, hr_image: torch.Tensor, fake: torch.Tensor) -> torch.Tensor:
      real_features = self.vgg_feature_extractor(hr_image)
      fake_features = self.vgg_feature_extractor(fake)
      perceptual_loss = self._content_loss(real_features, fake_features)
      return perceptual_loss

  @staticmethod
  def _content_loss(hr_image: torch.Tensor, fake: torch.Tensor) -> torch.Tensor:
      return F.mse_loss(hr_image, fake)
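
The way the loss terms combine can be traced with a few made-up numbers. In the sketch below the discriminator logits and the perceptual and content values are invented for illustration. Only the formulas (binary cross-entropy on logits, the 0.5 average for the discriminator, and the 0.006 / 0.001 weights for the generator) come from the class above.

```python
import torch
import torch.nn.functional as F


def adv_loss(pred: torch.Tensor, ones: bool) -> torch.Tensor:
    """BCE on raw logits against an all-ones (real) or all-zeros (fake) target."""
    target = torch.ones_like(pred) if ones else torch.zeros_like(pred)
    return F.binary_cross_entropy_with_logits(pred, target)


# Made-up discriminator logits; a positive logit means "looks real".
real_pred = torch.tensor([[2.0], [1.0]])
fake_pred = torch.tensor([[-1.0], [0.5]])

# Discriminator: real images should score high, generated images low.
disc_loss = 0.5 * (adv_loss(real_pred, ones=True) + adv_loss(fake_pred, ones=False))

# Generator: wants the discriminator to label its output as real.
gen_adv_loss = adv_loss(fake_pred, ones=True)

# Made-up stand-ins for the VGG-feature MSE and the pixel-wise MSE.
perceptual_loss = torch.tensor(1.2)
content_loss = torch.tensor(0.05)
gen_loss = 0.006 * perceptual_loss + 0.001 * gen_adv_loss + content_loss

print(float(disc_loss), float(gen_adv_loss), float(gen_loss))
```

Note that the same `fake_pred` enters the discriminator loss with a zeros target and the generator loss with a ones target, which is what makes the two objectives adversarial.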

## Parameters
* image_channels (int) – Number of channels of the images from the dataset

* feature_maps_gen (int) – Number of feature maps to use for the generator

* feature_maps_disc (int) – Number of feature maps to use for the discriminator

* num_res_blocks (int) – Number of res blocks to use in the generator

* scale_factor (int) – Scale factor for the images (either 2 or 4)

* generator_checkpoint (Optional[str]) – Generator checkpoint created with SRResNet module

* learning_rate (float) – Learning rate

* scheduler_step (int) – Number of epochs after which the learning rate gets decayed
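
To see what `scheduler_step` does in practice, the sketch below runs `MultiStepLR` with the same `gamma=0.1` used in `configure_optimizers`, on a dummy parameter. The milestone of 3 epochs is an arbitrary small value chosen so the decay is visible in a short loop.

```python
import torch

scheduler_step = 3  # illustration only; the notebook default is 100
param = torch.nn.Parameter(torch.zeros(1))  # dummy parameter to optimize
opt = torch.optim.Adam([param], lr=1e-4)
sched = torch.optim.lr_scheduler.MultiStepLR(opt, milestones=[scheduler_step], gamma=0.1)

lrs = []
for epoch in range(5):
    lrs.append(opt.param_groups[0]["lr"])  # learning rate used during this epoch
    opt.step()    # stands in for a full training epoch
    sched.step()  # the scheduler advances once per epoch

print(lrs)  # roughly [1e-4, 1e-4, 1e-4, 1e-5, 1e-5]
```

The learning rate stays at its initial value for the first `scheduler_step` epochs and is then multiplied by 0.1 once.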

In [14]:
datasets=["celeba", "mnist", "stl10"]

DATASET="mnist"
DATA_DIR="./"
AVAIL_GPUS=1

IMAGE_CHANNELS=3
FEATURE_MAPS_GEN=64
FEATURE_MAPS_DISC=64
NUM_RES_BLOCKS=16
SCALE_FACTOR=4
LEARNING_RATE=1e-4
SCHEDULER_STEP=100

In [15]:
GENERATOR_CHECKPOINT = Path(f"./srgan-{DATASET}-scale_factor={SCALE_FACTOR}.pt")
if not GENERATOR_CHECKPOINT.exists():
    warn(
        "No generator checkpoint found. Training generator from scratch."
    )
    GENERATOR_CHECKPOINT = None

In [None]:

datasets = prepare_sr_datasets(DATASET, SCALE_FACTOR, DATA_DIR)
dm = TVTDataModule(*datasets)
model = SRGAN(generator_checkpoint=GENERATOR_CHECKPOINT,
              image_channels=dm.dataset_test.image_channels,
              feature_maps_gen=64,
              feature_maps_disc=64,
              num_res_blocks=NUM_RES_BLOCKS,
              scale_factor=SCALE_FACTOR,
              learning_rate=LEARNING_RATE,
              scheduler_step=SCHEDULER_STEP)
trainer = Trainer(gpus=AVAIL_GPUS)
trainer.fit(model, dm)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
  rank_zero_warn("You passed in a `val_dataloader` but have no `validation_step`. Skipping val loop.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name                  | Type                  | Params
----------------------------------------------------------------
0 | generator             | SRGANGenerator        | 1.5 M 
1 | discriminator         | SRGANDiscriminator    | 5.2 M 
2 | vgg_feature_extractor | VGG19FeatureExtractor | 20.0 M
----------------------------------------------------------------
6.7 M     Trainable params
20.0 M    Non-trainable params
26.8 M    Total params
107.070   Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]

In [None]:
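from torchvision.io import read_image

PATH = "CHECKPOINT_PATH"  # placeholder: path to a saved Lightning checkpoint
# The model expects a (1, C, H, W) float tensor, not a file path; scaling to [0, 1] is an assumption
lr_image = read_image("LOW_RES_IMAGE_PATH").float().div(255).unsqueeze(0)
srgan = SRGAN.load_from_checkpoint(PATH).eval()
hr_image = srgan(lr_image)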
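
Because the generator ends with `nn.Tanh`, its output lies in $[-1, 1]$. A minimal sketch of one way to map such an output to $[0, 1]$ for viewing or saving is shown below. `to_unit_range` is a hypothetical helper, and the random tensor stands in for a real generator output.

```python
import torch


def to_unit_range(img: torch.Tensor) -> torch.Tensor:
    """Map a Tanh-range image from [-1, 1] to [0, 1], clamping any overshoot."""
    return ((img + 1.0) / 2.0).clamp(0.0, 1.0)


fake_hr = torch.tanh(torch.randn(1, 3, 32, 32))  # stand-in for a generator output
img01 = to_unit_range(fake_hr)

print(float(img01.min()) >= 0.0, float(img01.max()) <= 1.0)  # True True
```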