#### Libraries

In [28]:
import os
from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
import torchinfo

import albumentations as A
from albumentations.pytorch import ToTensorV2

import numpy as np
import matplotlib.pyplot as plt

#### Learning Module

In [29]:
class Learner(pl.LightningModule):
    """
    WGAN-GP training module using manual optimization.

    Each training step runs `critic_iters` critic updates (Wasserstein loss plus
    gradient penalty) followed by one generator update. At the end of every
    training epoch a grid of samples drawn from a fixed latent batch is saved.

    Args:
        generator (nn.Module): maps latent noise of shape [b, latent_dim, 1, 1] to images
        critic (nn.Module): scores images; its output is averaged into a scalar
        lr (float): learning rate used by both Adam optimizers
        gp_weight (float): multiplier of the gradient penalty in the critic loss
        critic_iters (int): number of critic updates per generator update
        img_save_path (str): directory that receives the per-epoch sample grids
        latent_dim (int, optional): number of latent channels. Defaults to 100.
    """

    def __init__(
        self,
        generator: nn.Module,
        critic: nn.Module,
        lr: float,
        gp_weight: float,
        critic_iters: int,
        img_save_path: str,
        latent_dim: int = 100,
    ):
        super().__init__()

        self.generator = generator
        self.critic = critic

        self.lr = lr
        self.gp_weight = gp_weight
        self.critic_iters = critic_iters
        self.latent_dim = latent_dim

        # for manual backward
        self.automatic_optimization = False

        self.img_save_path = img_save_path

        # fixed z for logging; registered as a buffer so it follows the module's device
        self.register_buffer("logging_z", torch.randn((4, latent_dim, 1, 1)))

    def forward(self, x: torch.Tensor):
        return self.generator(x)

    def _sample_z(self, batch_size: int) -> torch.Tensor:
        """Draws a standard-normal latent batch of shape [batch_size, latent_dim, 1, 1]."""
        return torch.randn((batch_size, self.latent_dim, 1, 1), device=self.device)

    def _gradient_penalty(self, x_real: torch.Tensor, x_fake: torch.Tensor):
        """
        Computes mean((||grad_x critic(x_hat)||_2 - 1)^2) on random interpolates
        x_hat between real and fake images.

        Args:
            x_real (torch.Tensor): real images, [b, c, h, w]
            x_fake (torch.Tensor): generated images, same shape as x_real

        Returns:
            torch.Tensor: scalar gradient penalty, differentiable w.r.t. the critic
        """
        batch_size = x_real.shape[0]

        # one epsilon per sample, broadcast over [c, h, w]
        epsilon = torch.rand((batch_size, 1, 1, 1), device=x_real.device, dtype=x_real.dtype)

        # The penalty only trains the critic: cut the graph back to the generator and
        # mark the interpolate as requiring grad so autograd.grad works even when
        # neither input carries a graph.
        x_interpolated = epsilon * x_real.detach() + (1 - epsilon) * x_fake.detach()
        x_interpolated.requires_grad_(True)

        crt_interpolated = self.critic(x_interpolated)

        # create_graph=True keeps the gradient itself differentiable for the critic update
        gradients = torch.autograd.grad(
            outputs=crt_interpolated,
            inputs=x_interpolated,
            grad_outputs=torch.ones_like(crt_interpolated),
            create_graph=True,
            retain_graph=True,
        )[0]

        # reshape (not view): the gradient is not guaranteed to be contiguous
        gradients = gradients.reshape(batch_size, -1)

        # per-sample gradient norm and its squared distance from 1
        grad_norm = torch.norm(gradients, p=2, dim=1)
        return torch.mean((grad_norm - 1) ** 2)

    def training_step(self, batch, batch_idx):
        # get optimizers
        opt_generator, opt_critic = self.optimizers()

        # get training data (labels are unused)
        x_real, _ = batch
        batch_size = x_real.shape[0]

        ##################################
        # CRITIC
        ##################################

        # critic update iterations per one generator update
        for _ in range(self.critic_iters):
            z = self._sample_z(batch_size)

            # the generator is not updated here, so do not build its graph
            with torch.no_grad():
                x_fake = self.generator(z)

            crt_real = self.critic(x_real)
            crt_fake = self.critic(x_fake)

            crt_wasserstein_distance = -torch.mean(crt_real - crt_fake)
            gp = self._gradient_penalty(x_real, x_fake)

            crt_loss = crt_wasserstein_distance + self.gp_weight * gp

            opt_critic.zero_grad()
            self.manual_backward(crt_loss)
            opt_critic.step()

        ##################################
        # GENERATOR
        ##################################

        z = self._sample_z(batch_size)

        x_fake = self.generator(z)
        crt_fake = self.critic(x_fake)

        gen_loss = -torch.mean(crt_fake)

        opt_generator.zero_grad()
        self.manual_backward(gen_loss)
        opt_generator.step()

    @torch.no_grad()
    def on_train_epoch_end(self):
        """Lightning hook: saves a grid generated from the fixed logging latents."""
        # save_image does not create missing directories
        os.makedirs(self.img_save_path, exist_ok=True)

        generated = self.generator(self.logging_z)

        # min-max rescale to [0, 1]; without it values outside [0, 1] are clipped on save
        grid = torchvision.utils.make_grid(generated, normalize=True)

        # save generated image grid
        torchvision.utils.save_image(grid, f"{self.img_save_path}/img_epoch_{self.current_epoch}.jpg")

    # Former method name, kept for callers that invoke it directly; Lightning itself
    # only calls `on_train_epoch_end`.
    train_epoch_end = on_train_epoch_end

    def configure_optimizers(self):
        """Returns (generator optimizer, critic optimizer), both Adam with `self.lr`."""
        opt_generator = torch.optim.Adam(self.generator.parameters(), lr=self.lr)
        opt_critic = torch.optim.Adam(self.critic.parameters(), lr=self.lr)

        return opt_generator, opt_critic

#### Model

In [30]:
class ConvBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, upscale: bool = False, **kwargs):
        """
        Convolution -> batch norm -> SiLU unit used by both the Generator and the Critic

        Args:
            in_channels (int): channels of the incoming feature map
            out_channels (int): channels produced by the convolution
            upscale (bool, optional): when True the input is resized by a factor of 2 before the convolution. Defaults to False.
            **kwargs: forwarded unchanged to nn.Conv2d (e.g. kernel_size, stride, padding)
        """
        super().__init__()

        self.upscale = upscale

        # kept under the attribute name `conv` so parameter names stay conv.0.*, conv.1.*
        stages = [
            nn.Conv2d(in_channels, out_channels, **kwargs),
            nn.BatchNorm2d(out_channels),
            nn.SiLU(),
        ]
        self.conv = nn.Sequential(*stages)

    def forward(self, x: torch.Tensor):
        # optionally double the spatial size (F.interpolate with its default mode)
        features = F.interpolate(x, scale_factor=2.0) if self.upscale else x
        return self.conv(features)

In [31]:
class Generator(nn.Module):
    def __init__(self, in_channels: int, intermediate_channels: List[int], out_channels: int):
        """
        Generator model

        Every intermediate layer doubles the spatial size, so a [b, in_channels, 1, 1]
        input yields an image of side 2 ** len(intermediate_channels).

        Args:
            in_channels (int): input (latent) channels
            intermediate_channels (List[int]): list containing number of intermediate channels in each layer of the generator
            out_channels (int): output (image) channels

        Raises:
            ValueError: if intermediate_channels is empty
        """
        super().__init__()

        if not intermediate_channels:
            raise ValueError("intermediate_channels must contain at least one entry")

        self.in_channels = in_channels
        self.intermediate_channels = intermediate_channels
        self.out_channels = out_channels

        self.generator = self._build_architecture()

    def _build_architecture(self):
        """Stacks the upscaling ConvBlocks followed by the image output head."""
        widths = [self.in_channels, *self.intermediate_channels]

        # one upscaling block per consecutive pair of channel widths
        layers = [
            ConvBlock(ins, outs, upscale=True, kernel_size=3, stride=1, padding=1)
            for ins, outs in zip(widths[:-1], widths[1:])
        ]

        # Output head: plain convolution + tanh. A ConvBlock here would end in
        # BatchNorm + SiLU, whose output is bounded below at about -0.28 and
        # therefore cannot match images normalised with mean 0.5 / std 0.5.
        layers.append(
            nn.Sequential(
                nn.Conv2d(self.intermediate_channels[-1], self.out_channels, kernel_size=3, stride=1, padding=1),
                nn.Tanh(),
            )
        )

        return nn.Sequential(*layers)

    def forward(self, z: torch.Tensor):
        return self.generator(z)


In [32]:
class Critic(nn.Module):
    def __init__(self, in_channels: int, intermediate_channels: List[int]):
        """
        Critic architecture

        The first layer keeps the spatial size; every further intermediate layer
        halves it (kernel 4, stride 2, padding 1). The output is a 1-channel map of
        unbounded scores.

        Args:
            in_channels (int): input (image) channels
            intermediate_channels (List[int]): list containing number of intermediate channels in each layer of the critic

        Raises:
            ValueError: if intermediate_channels is empty
        """
        super().__init__()

        if not intermediate_channels:
            raise ValueError("intermediate_channels must contain at least one entry")

        self.in_channels = in_channels
        self.intermediate_channels = intermediate_channels

        self.critic = self._build_architecture()

    def _build_architecture(self):
        """Stacks the feature ConvBlocks followed by a linear score head."""
        layers = [
            ConvBlock(self.in_channels, self.intermediate_channels[0], upscale=False, kernel_size=3, stride=1, padding=1)
        ]

        # NOTE(review): these blocks still contain BatchNorm2d; the WGAN-GP paper advises
        # against batch normalisation in the critic - consider a per-sample norm instead.
        layers += [
            ConvBlock(ins, outs, upscale=False, kernel_size=4, stride=2, padding=1)
            for ins, outs in zip(self.intermediate_channels[:-1], self.intermediate_channels[1:])
        ]

        # Score head: bare convolution. BatchNorm would re-centre the score map and
        # SiLU would bound it from below, both of which distort the Wasserstein estimate.
        layers.append(nn.Conv2d(self.intermediate_channels[-1], 1, kernel_size=3, stride=1, padding=1))

        return nn.Sequential(*layers)

    def forward(self, x: torch.Tensor):
        return self.critic(x)

#### Dataset

In [33]:
class ImgTransform:
    """Callable preprocessing pipeline: resize to 256x256, normalise, convert to a tensor."""

    def __init__(self):
        # per-channel normalisation with mean 0.5 / std 0.5
        pipeline = [
            A.Resize(256, 256),
            A.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
            ToTensorV2(),
        ]
        self.transforms = A.Compose(pipeline)

    def __call__(self, img):
        # the pipeline is called with a numpy array passed by keyword
        as_array = np.array(img)
        augmented = self.transforms(image=as_array)
        return augmented["image"]

transform = ImgTransform()

In [34]:
# CelebA images (downloaded on first use) passed through `transform`
dataset = torchvision.datasets.CelebA(
    root="../../datasets/celeb_a",
    transform=transform,
    download=True,
)
# NOTE(review): batch_size=1 looks like a debugging value - confirm before a real run
loader = DataLoader(dataset=dataset, batch_size=1, shuffle=True, num_workers=4)

#### Training models

In [35]:
# generator: 100 latent channels -> 3 image channels through eight intermediate layers
layers_gen = [128, 128, 256, 256, 256, 512, 512, 512]
generator = Generator(in_channels=100, intermediate_channels=layers_gen, out_channels=3)

# critic: 3 image channels through five intermediate layers
layers_crt = [64, 128, 128, 256, 256]
critic = Critic(in_channels=3, intermediate_channels=layers_crt)

In [36]:
# logger = WandbLogger("wgan")
# single-GPU run; the Learner performs 5 critic updates per generator update
trainer = pl.Trainer(
    accelerator="gpu",
    devices=1,
    max_epochs=100,
)
learner = Learner(
    generator=generator,
    critic=critic,
    lr=3e-4,
    gp_weight=10,
    critic_iters=5,
    img_save_path="sampled_imgs",
)
trainer.fit(model=learner, train_dataloaders=loader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type      | Params
----------------------------------------
0 | generator | Generator | 7.7 M 
1 | critic    | Critic    | 2.0 M 
----------------------------------------
9.6 M     Trainable params
0         Non-trainable params
9.6 M     Total params
38.518    Total estimated model params size (MB)
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=1` reached.
