<a href="https://colab.research.google.com/github/Tien-Cheng/dele-generative-adversarial-networks/blob/main/DELE_CA2_Part_I_Controllable_%26_Conditional_GANS.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>


# Problem Statement

> How can I create a GAN model that allows me to control the generated output, so as to enable me to express my creativity?

## What I Want to Do

- Create a Conditional GAN that:
  - Achieves a good FID and KID score (i.e. good image quality)
  - Is able to successfully generate images based on class
  - Potentially allows me to control the features of the image (e.g. color)

## How will I achieve it?

- I will try various network architectures

  - DCGAN
  - SNGAN
  - StyleGAN2

- To improve the quality of images, I apply the following techniques
  - FreezeD (Transfer Learning)
  - Truncation Trick
  - Differentiable Augmentation
  - Hinge Loss
  - Two Time-Scale Update Rule (TTUR)
  - Skip Connections from Noise to Generator (which could allow for GANSpace edits)


# Google Colab Setup

Due to the heavy computational requirements of training a GAN, Google Colab is used as the compute platform. The dataset and utility functions are stored in a GitHub repository, so we need to clone that repository. In addition, some extra libraries have to be installed.


In [None]:
! mkdir -p /root/.ssh
with open("/root/.ssh/id_rsa", mode="w") as fp: # Repository Deploy Key
    fp.write("""
-----BEGIN RSA PRIVATE KEY-----
MIIJKgIBAAKCAgEAt4IzVm1w9r7xKuS+zYuBb6UNB2NFHQBaRbhih7+HBK+yNSbz
lGu9P/sWbFarsY68zKCISb8+K+hulP0ay9OdnCLat9z96eOZ0gX6Iqsh6+szfNvm
8m1SJeXc7C6UGmyNIpr33TUpf56y28UFa656rIjff1w20SRKjL2rgu8rx+lxiASL
+hXiZi2t1PA6oLD3puD9TOwN85Ct5mutmTjBYQKbmk04Sp8jE9DqloJPpkCJHVh6
cJ0bzyVCx4njzdoQeWwPtVa67wyHIXDqH1xZBAkAqt2WAx4npLGgTotPSUaFkDLw
co6SnpOLx8ZGrpggX1k2Oh7FOH75nZXHKjrfWtX5pbkw8bxYmNLTErq/t19ULBQi
dVyv406ARf1rDOUFoMfsOsc1pd/wf66mcZUn3s3ogI6it5zGrpzCpfrlxHgJQ/Uh
gvyWA88J6BRVgMA5cUS4gb/OEmuHvdM9CRY6HELAS35tS85zcRQXipYqngx/dgaV
GhHHIWZh1bOkbn1dRV3xQau6KxYOyLI/i+eBFJA3jxvDKlRMVfy1DMSn0DffFlFt
ApOAqOccYo124vUthsiWJ9qExdCJ36+tAHpFelfyjAygJMWCZhaWprxvY9VG/lcR
6vjNtjaUySWx3l4GTadmbSwARK9gl5Xgbp0qojx1FMpsFcnCsz3y8yB7TYECAwEA
AQKCAgBVmHSrxqafYVcKg+H/7Cd21QzrukEdkvGIfcXvvcWTyQQdyMprG4oN0ueV
pyO00Xh9FhAcHgk439Tcx+Z81ns4vgU5J+qD8zbngQQ4sYxEB9RfVA84WwerR7mx
rNRGMwXt80zUMJznuzWATzkFDkCIQ9vEA1ZKXVwso7fhff/04o2jPUOxZg3RTVM8
9MTT+Ve6zk04WQ704jJLPUSfKJsCzf2YjpZIMExjTNpvU98lFAsg1gleh9nV2HJ6
snXAqgtvJ5l4IzlUkYpibdG2yRN4T16xVGRJlgI1zuiQWmikLDHWnfwL4za+ouHb
UD/d5nWLJAioOXwSqx9xgtCAgS92211ydmKWmsdThHdRDN63ncFPNg7I3ZVdRKOK
bdAMev6sP3CXnWO7aRO0sGT4wg7tGbnyE53I4RJXck1aZZGSyOvswucCxEQsnbSP
Hr78/kc+5+DJX8pbc0NuLAspkUVoSU75Idrv5B+2UQSb1ZXspfp923s0voRw0sJ8
ydOg1n173QOwKnAE++tXQrdPZyU2cHkuvg426snCjlpbogfmmlj8cGGb8EOZcdJv
I3r+w2V+9bC2Z7O4OJhe2HlwM0N6F+KBnyJHsxbdP08OqZYzVMiDmg5Rfbkwf8W9
arkt9+pAWSix0nkp7qNgD+qkjfrtOxIX//mFbIBWhhq1gSby5QKCAQEA63P0DsuC
APhI0/GeSFJ8FXYtVFrX8/DJjOH441VEaQQMlAub6KnlhsOo0nS10A5GqV9/EeZ7
ss2w22JwIh+Wk6PtxPU7lMEQRy1eV0GUrQdrE8StlLs36zszMswMafnm4yA4ki6g
Lr3BR6Ps38TzLba6mctOFt9T8wV6+/YB2PFi3r5tmX0zYNQi6mbTFnlDv/dzaFOG
fT823OuOzuwVvgu90651PutVfPrNhUTTykuGyDee5kkn21HQeguoLDWEIfeV+ujH
l/AAT7rpNmxtSl0m+iwYsbKDbX28DCGnWXgeMuFgMUblRvKumChbno81JkJIOOwV
+DcWryqTLp17VwKCAQEAx4XN5+PGMg9na7vZW6zU8r92RFKsjjhueygtpxdDUjdV
MSYPu4mgO9ab62nf9LQJE2JCN6e8zHWEAooaIt83TzCa6SaYbTEnzin2M9gSYtW6
MQ429zq49MOdZfwMfRgfnFAnA8KDIfYqqcPcmnQWHWhNGXyS3CccYw+2+gmRHLoM
ohcoVZne6VuMqkEzf8SDaR8k9gwVjqxVqpQN8p81PE00a02k+QDwyNsrcnM19plB
kntb9FLuqQf+lmDhe0/9fDqcjIEDz4eonLlFaTrFegGybTQcKD+3uyC0k9njUFwJ
Y77I3kJiaoDuXXVxWETS3KvaE2rmjXAEcrN5rkfO5wKCAQEAl+41kQputBOCYwjp
Ov/Gw86DB4irCuTYGYmDIaZWw3DycOFg1Gw1CJXerRbUbxGXNRnDFBjmvwUNVzMY
6lv5vQEtn0cjECTYTSWQV7ugpVpBFPt3ip6YQbjsm52hcQzpmKuk9WcSw7Z8Lq8v
XWFoDZp4pF7U39tx/0INDuK6ZHO2ecblUALDEXsxoJGDKmBLgGa7WJl1EgKlcz6o
4wriKMTI0/wh+dy/SCtKTPGRvFqp+S4y4aRZDKOpY+d7uDM8NPLfG43zpS4f9VLF
w/GJQFAFo66qrJdlSVS18BoTM59X1Tsq6AE4V2SnltWL8S+1ex+QHPLyZj2d7KAL
YywJdwKCAQEApWUG3j6T0nWwfr82nGc2E5ChgluiTTb8Zr1Ustl25hWWWmq5yfV5
TYFGuSyICTqg91+Rkr9Ko5aa+tvudI/jMpMRJ0rmOkXwQFfKjwmDnEid0wJ8kA8u
uT/bH2qEE8LGmXZcESLSP3nnvdjt619l4bTPjNwWhccqIfgp7zW1BEI6LLfTqLon
7fwFLDFmdni5ko/NvOUhjabQUNnwgfp2T+mUFYtEwWGFOItuha55wlUi5UG7ZVrG
GnrVEWV4JReXAr83fMWKGiPToy92GZgtkUkM1rfGy5qePNIMvy903u2cnwHNU2lm
WfFNJ04uykQrI+CVo1kPi5mbJlYe/VjrawKCAQEA5Pmjb8/MdAUEkb3zAD7GJIKC
HnUAA4mwk8xVdsGN6xvUL8RYgi+VjSKvzNsUln5sPXdtZbP//gQOF7KgLPFFe+mf
Xok7fGSTQ1DgVWEErFynAYxu+Uu4xtjRbPyCXjyoHianXkn3QDf1ggpF+y2R0Ivu
oyxsDvMArFalbmK4q/+Q6/z/DtnirfjUnxiYEPEBZtP3Gz74KQK/AhForVlCiSz6
MbDp30cxPy/8/pimJ9xUR6re9Xuw/EFWp0ifHXv6IGNOd8UQGejyI82KnJZPNTde
tHO70d3zFdhrpJO63Elrw6c9bxeZrcJTT1e3wFpX2z1aE4dybdNqrI/IbzcdVA==
-----END RSA PRIVATE KEY-----
""")
! ssh-keyscan -t rsa github.com >> ~/.ssh/known_hosts
! chmod go-rwx /root/.ssh/id_rsa
! git clone git@github.com:Tien-Cheng/dele-generative-adversarial-networks.git
%cd /content/dele-generative-adversarial-networks

# github.com:22 SSH-2.0-babeld-e47cd09f
chmod: cannot access '/root/.ssh/rsa': No such file or directory
Cloning into 'dele-generative-adversarial-networks'...
remote: Enumerating objects: 26, done.[K
remote: Counting objects: 100% (26/26), done.[K
remote: Compressing objects: 100% (17/17), done.[K
remote: Total 26 (delta 9), reused 20 (delta 6), pack-reused 0[K
Receiving objects: 100% (26/26), 7.52 KiB | 7.52 MiB/s, done.
Resolving deltas: 100% (9/9), done.
/content/dele-generative-adversarial-networks


In [None]:
%%capture
# Install/upgrade the extra libraries imported below; %%capture hides pip's output.
# NOTE(review): versions are not pinned, so a re-run may pull in newer, incompatible releases.
%pip install -U torchmetrics[image] wandb torch-summary pytorch-lightning

# Setup


### Imports


In [2]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import wandb
import torchvision.utils as vutils
import gc

from typing import *
from collections import OrderedDict
from torch import nn
from torch.nn import functional as F
from torch.nn.utils import spectral_norm
from torch.optim.lr_scheduler import *
from torch.optim import Adam
from data.dataset import CIFAR10DataModule
from torchvision import transforms as T
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.inception import InceptionScore
from pytorch_lightning import LightningModule, Trainer, seed_everything
from pytorch_lightning.callbacks import ModelSummary, ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from utils.DiffAugment_pytorch import (
    DiffAugment,
)  # Make use of the official implementation of DiffAugment
from utils.visualize import visualize
from utils.layers import (
    ResidualBlockDiscriminator,
    ResidualBlockGenerator,
    NormalizeInverse,
    ConditionalBatchNorm2d,
    ResidualBlockDiscriminatorHead,
)
from utils.ema import EMA


### Basic Hyperparameters


In [3]:
# @title Basic Hyperparameters { run: "auto" }
# Settings for data loading; the trailing `# @param` comments are Colab form annotations.
DATA_DIR = "./data"  # @param {type:"string"}
BATCH_SIZE = 50  # @param {type:"integer"}
NUM_WORKERS = 2  # @param {type:"integer"}


## Data Ingestion and Preprocessing


In [4]:
# Convert images to tensors, then apply (x - 0.5) / 0.5 per channel.
# For inputs in [0, 1] this yields values in [-1, 1], the same range as a Tanh output.
preprocessing = [T.ToTensor(), T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
# Augmentations kept separate from the preprocessing; both lists go to the data module.
basic_aug = [
    T.RandomHorizontalFlip()
]

In [5]:
# Build the project's CIFAR10DataModule (imported from data.dataset) with the
# hyperparameters and transform lists defined in the cells above.
dm = CIFAR10DataModule(
    data_dir=DATA_DIR,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    transforms=preprocessing,
    augments=basic_aug
)


# Building a Conditional LSGAN


## Generator


In [6]:
class LSGANGenerator(nn.Module):
    """DCGAN-style conditional generator that maps noise and a class label to an image.

    The label is embedded into a vector of size ``latent_dim``, multiplied
    element-wise with the noise, passed through a linear layer and then
    upsampled by four transposed convolutions to a 3-channel 32x32 image with
    values in [-1, 1] (Tanh output).

    :param latent_dim: Size of the input noise vector, defaults to 100
    :type latent_dim: int, optional
    :param num_filters: Base number of filters in the generator blocks, defaults to 64
    :type num_filters: int, optional
    :param num_classes: Number of classes the label embedding supports, defaults to 10
    :type num_classes: int, optional
    """

    def __init__(
        self, latent_dim: int = 100, num_filters: int = 64, num_classes: int = 10
    ):
        super().__init__()
        self.latent_dim = latent_dim
        self.num_filters = num_filters
        self.num_classes = num_classes

        # Conditioning: one learned vector per class, same size as the noise.
        self.label_embedding = nn.Embedding(num_classes, latent_dim)
        self.latent = nn.Linear(latent_dim, latent_dim)

        wide, mid, narrow = num_filters * 4, num_filters * 2, num_filters
        layers = [
            # (latent_dim) x 1 x 1 -> (num_filters*4) x 4 x 4
            nn.ConvTranspose2d(latent_dim, wide, 4, 1, 0, bias=False),
            nn.BatchNorm2d(wide),
            nn.ReLU(True),
            # -> (num_filters*2) x 8 x 8
            nn.ConvTranspose2d(wide, mid, 4, 2, 1, bias=False),
            nn.BatchNorm2d(mid),
            nn.ReLU(True),
            # -> (num_filters) x 16 x 16
            nn.ConvTranspose2d(mid, narrow, 4, 2, 1, bias=False),
            nn.BatchNorm2d(narrow),
            nn.ReLU(True),
            # -> 3 x 32 x 32
            nn.ConvTranspose2d(narrow, 3, 4, 2, 1, bias=False),
            nn.Tanh(),
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor, y: list = None):
        """Generate a batch of images.

        :param x: Noise of shape (batch, latent_dim)
        :param y: Class indices, one per sample; they are fed to ``nn.Embedding``,
            so in practice this must be an integer tensor despite the ``list`` hint
        :return: Generated images of shape (batch, 3, 32, 32)
        """
        class_vectors = self.label_embedding(y)
        mixed = self.latent(x * class_vectors)
        # Reshape to a 1x1 "image" with latent_dim channels for the conv stack.
        seed_map = mixed.view(mixed.size(0), self.latent_dim, 1, 1)
        return self.main(seed_map)


## Discriminator


In [7]:
class LSGANDiscriminator(nn.Module):
    """DCGAN-style conditional discriminator for 3-channel 32x32 images.

    The class label is embedded into a 32x32 map that is multiplied with the
    input image (broadcast over the colour channels) before a strided
    convolution stack reduces it to a single sigmoid score per sample.

    :param num_filters: Base number of convolution filters, defaults to 64
    :type num_filters: int, optional
    :param num_classes: Number of classes the label embedding supports, defaults to 10
    :type num_classes: int, optional
    """

    def __init__(self, num_filters: int = 64, num_classes: int = 10):
        super().__init__()
        self.num_filters = num_filters
        self.num_classes = num_classes
        # One learned 32x32 map (stored flat) per class.
        self.label_embedding = nn.Embedding(num_classes, 32 * 32)

        base, double, quad = num_filters, num_filters * 2, num_filters * 4
        layers = [
            # 3 x 32 x 32 -> base x 16 x 16
            nn.Conv2d(3, base, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # -> double x 8 x 8
            nn.Conv2d(base, double, 4, 2, 1, bias=False),
            nn.BatchNorm2d(double),
            nn.LeakyReLU(0.2, inplace=True),
            # -> quad x 4 x 4
            nn.Conv2d(double, quad, 4, 2, 1, bias=False),
            nn.BatchNorm2d(quad),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout2d(0.2),
            # A 4x4 convolution on the 4x4 map leaves a single 1x1 score.
            nn.Conv2d(quad, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor, y: list = None):
        """Score a batch of images for the given class labels.

        :param x: Images of shape (batch, 3, 32, 32)
        :param y: Class indices, one per sample; they are fed to ``nn.Embedding``,
            so in practice this must be an integer tensor despite the ``list`` hint
        :return: Scores in [0, 1] of shape (batch, 1, 1, 1)
        """
        label_map = self.label_embedding(y)
        label_map = label_map.view(label_map.size(0), 1, 32, 32)
        # Element-wise product; the single-channel map broadcasts over RGB.
        return self.main(x * label_map)


## Conditional LSGAN


In [13]:
class ConditionalLSGAN(LightningModule):
    """Conditional Least Squares GAN as a PyTorch Lightning module.

    Owns an ``LSGANGenerator`` (``self.G``) and an ``LSGANDiscriminator``
    (``self.D``), trains them with an MSE adversarial loss using two Adam
    optimizers (index 0: generator, index 1: discriminator), and accumulates
    FID and Inception Score during validation.

    :param latent_dim: Size of the generator's noise vector, defaults to 100
    :param num_classes: Number of class labels, defaults to 10
    :param g_lr: Generator learning rate, defaults to 0.0002
    :param d_lr: Discriminator learning rate, defaults to 0.0002
    :param adam_betas: Betas used for both Adam optimizers
    :param batch_size: Batch size, used to size the fake batch in validation
    :param validation_size: Target number of generated samples for the metrics
    :param d_steps: Number of discriminator steps per generator step
    :param kwargs: Accepted so subclasses can forward extra arguments; unused here
    """

    def __init__(
        self,
        latent_dim: int = 100,
        num_classes: int = 10,
        g_lr: float = 0.0002,
        d_lr: float = 0.0002,
        adam_betas: Tuple[float, float] = (0.5, 0.999),
        batch_size: int = 64,
        validation_size: int = 10000,
        d_steps: int = 1,
        **kwargs,
    ):
        super().__init__()

        self.latent_dim = latent_dim
        self.num_classes = num_classes
        self.g_lr = g_lr
        self.d_lr = d_lr
        self.betas = adam_betas
        self.batch_size = batch_size
        self.validation_size = validation_size  # For FID, minimum number of samples is 10K for accurate result
        self.d_steps = d_steps  # Number of Discriminator steps per Generator Step
        self.save_hyperparameters()

        self.G = LSGANGenerator(
            latent_dim=self.latent_dim,
            num_classes=num_classes,
        )

        self.D = LSGANDiscriminator(num_classes=num_classes)

        self.G.apply(self._weights_init)
        self.D.apply(self._weights_init)

        self.adversarial_loss = nn.MSELoss()  # Least Squares Loss

        self.fid = FrechetInceptionDistance()
        self.inception_score = InceptionScore()
        # Inverse of the Normalize((0.5,)*3, (0.5,)*3) preprocessing step.
        self.unnormalize = NormalizeInverse((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

    @staticmethod
    def _weights_init(m):
        """DCGAN-style init: N(0, 0.02) for conv weights, N(1, 0.02) for BatchNorm scales.

        Modules are matched by class name, so layers whose class name contains
        neither "Conv" nor "BatchNorm" (e.g. Linear, Embedding) keep their defaults.
        """
        classname = m.__class__.__name__
        if classname.find("Conv") != -1:
            torch.nn.init.normal_(m.weight, 0.0, 0.02)
        elif classname.find("BatchNorm") != -1:
            torch.nn.init.normal_(m.weight, 1.0, 0.02)
            torch.nn.init.zeros_(m.bias)

    def forward(self, z, labels):
        """Generate images from noise ``z`` conditioned on ``labels``."""
        return self.G(z, labels)

    # Alternating schedule for optimizer steps (i.e.: GANs)
    def optimizer_step(
        self,
        epoch,
        batch_idx,
        optimizer,
        optimizer_idx,
        optimizer_closure,
        on_tpu,
        using_native_amp,
        using_lbfgs,
    ):
        """Step the discriminator every batch and the generator every ``d_steps`` batches."""
        # update discriminator opt every step
        if optimizer_idx == 1:
            optimizer.step(closure=optimizer_closure)

        # update generator opt once every `self.d_steps` batches
        if optimizer_idx == 0:
            if (batch_idx + 1) % self.d_steps == 0:
                optimizer.step(closure=optimizer_closure)
            else:
                # call the closure by itself to run `training_step` + `backward` without an optimizer step
                optimizer_closure()

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Compute the generator loss (``optimizer_idx == 0``) or discriminator loss (``== 1``).

        The latest fakes and their labels are kept on ``self.fakes`` /
        ``self.fake_labels`` so ``training_epoch_end`` can log a sample grid.

        :return: Dict with the loss under ``"loss"`` plus progress-bar/log copies
        """
        imgs, labels = batch  # Get real images and corresponding labels

        # Generate Noise Vector z
        z = torch.randn(imgs.shape[0], self.latent_dim)
        z = z.type_as(imgs)  # Ensure z runs on the same device as the images
        # Random target classes for the generated images
        self.fake_labels = torch.LongTensor(
            torch.randint(0, self.num_classes, (imgs.shape[0],))
        ).to(self.device)
        # Train Generator
        if optimizer_idx == 0:
            # Generate Images
            self.fakes = self.forward(z, self.fake_labels)

            # Classify Generated Images with Discriminator
            fake_preds = torch.squeeze(self.D(self.fakes, self.fake_labels))

            # We want to penalize the Generator if the Discriminator predicts it as fake
            # Hence, set the target as a 1's vector
            target = torch.ones(imgs.shape[0]).type_as(imgs)

            g_loss = self.adversarial_loss(fake_preds, target)

            self.log(
                "train_gen_loss",
                g_loss,
                on_epoch=True,
                on_step=False,
                prog_bar=True,
            )  # Log Generator Loss
            tqdm_dict = {
                "g_loss": g_loss,
            }
            output = OrderedDict(
                {"loss": g_loss, "progress_bar": tqdm_dict, "log": tqdm_dict}
            )
            return output

        # Train Discriminator
        if optimizer_idx == 1:
            # Train on Real Data (target 1)
            real_preds = torch.squeeze(self.D(imgs, labels))
            target = torch.ones(imgs.shape[0]).type_as(imgs)
            d_real_loss = self.adversarial_loss(real_preds, target)

            # Train on Generated Images (target 0)
            self.fakes = self.forward(z, self.fake_labels)
            target = torch.zeros(imgs.shape[0]).type_as(imgs)
            fake_preds = torch.squeeze(self.D(self.fakes, self.fake_labels))
            d_fake_loss = self.adversarial_loss(fake_preds, target)
            d_loss = (d_real_loss + d_fake_loss) / 2

            self.log(
                "train_discriminator_loss",
                d_loss,
                on_epoch=True,
                on_step=False,
                prog_bar=True,
            )
            tqdm_dict = {
                "d_loss": d_loss,
            }
            output = OrderedDict(
                {"loss": d_loss, "progress_bar": tqdm_dict, "log": tqdm_dict}
            )
            return output

    def training_epoch_end(self, outputs):
        """Log a square grid of up to 64 of the most recently generated images."""
        # Log Sampled Images
        sample_imgs = self.unnormalize(self.fakes[:64]).cpu().detach()
        sample_labels = self.fake_labels[:64].cpu().detach()
        num_rows = int(np.floor(np.sqrt(len(sample_imgs))))
        fig = visualize(sample_imgs, sample_labels, grid_shape=(num_rows, num_rows))
        self.logger.log_image(key="generated_images", images=[fig])
        # Release the figure and tensors explicitly to keep memory flat across epochs
        plt.close(fig)
        del sample_imgs
        del sample_labels
        del fig
        gc.collect()

    def validation_step(self, batch, batch_idx):
        """
        Update FID with the real batch, then update FID and IS with a fresh fake batch
        """
        real, _ = batch
        # Calculate Metrics for Real Data
        torch.cuda.empty_cache()
        real = self.unnormalize(real)
        # The metrics are fed uint8 images in [0, 255]
        real = (real * 255).type(torch.uint8).to(self.device)
        self.fid.update(
            real, real=True
        )  # The metric value itself is only computed at the end of the epoch
        # Number of fakes generated per validation batch
        validation_size = self.validation_size // self.batch_size
        validation_z = torch.randn(validation_size, self.latent_dim)
        validation_labels = torch.LongTensor(
            torch.randint(0, self.num_classes, (validation_size,))
        ).to(self.device)
        # Generate Images. A minimum of 10K is recommended for accurate FID
        validation_z = validation_z.type_as(
            self.G.latent.weight
        )  # Ensure z has the same dtype/device as the generator weights
        fakes = self.forward(validation_z, validation_labels)
        fakes = (self.unnormalize(fakes) * 255).type(torch.uint8).to(self.device)
        self.inception_score.update(fakes)
        self.fid.update(
            fakes,
            real=False,
        )

    def validation_epoch_end(self, outputs):
        """Compute and log IS and FID for the epoch, then clear both metrics."""
        self.log(
            "IS",
            self.inception_score.compute()[0],  # Compute Mean IS
            prog_bar=True,
        )
        # Log the computed value rather than the Metric object: logging the object
        # defers compute() until after this hook returns, by which point the manual
        # reset() below has already cleared the accumulated statistics.
        self.log("FID", self.fid.compute(), prog_bar=True)
        self.fid.reset()
        self.inception_score.reset()

    def configure_optimizers(self):
        """Define the optimizers and schedulers for PyTorch Lightning

        :return: A tuple of two lists - a list of optimizers and a list of learning rate schedulers
        :rtype: Tuple[List, List]
        """
        opt_G = Adam(
            self.G.parameters(), lr=self.g_lr, betas=self.betas
        )  # optimizer_idx = 0
        opt_D = Adam(
            self.D.parameters(), lr=self.d_lr, betas=self.betas
        )  # optimizer_idx = 1
        return [opt_G, opt_D], []


## Trainer


In [9]:
# Authenticate with Weights & Biases; the WandbLogger instances created below log to it.
wandb.login()


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mtiencheng[0m (use `wandb login --relogin` to force relogin)


True

In [14]:
# seed_everything(42)
# wandb_logger = WandbLogger(project="DELE_CA2_GAN", log_model="all")
# model = ConditionalLSGAN(g_lr=0.0001, d_lr=0.0003, batch_size=BATCH_SIZE)
# trainer = Trainer(
#     check_val_every_n_epoch=5,
#     logger=wandb_logger,
#     max_epochs=1000,
#     callbacks=[
#         ModelCheckpoint(monitor="FID", mode="min", every_n_epochs=1),
#         ModelSummary(3),
#     ],
#     gpus=1,
# )


Global seed set to 42
Trainer already configured with model summary callbacks: [<class 'pytorch_lightning.callbacks.model_summary.ModelSummary'>]. Skipping setting a default `ModelSummary` callback.
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


In [15]:
# trainer.fit(model, dm)


  rank_zero_deprecation(
  rank_zero_warn(

  | Name              | Type               | Params
---------------------------------------------------------
0 | G                 | LSGANGenerator     | 3.7 M 
1 | G.label_embedding | Embedding          | 100   
2 | G.main            | Sequential         | 3.7 M 
3 | D                 | LSGANDiscriminator | 2.8 M 
4 | D.label_embedding | Embedding          | 100   
5 | D.main            | Sequential         | 2.8 M 
6 | adversarial_loss  | MSELoss            | 0     
7 | fid               | FID                | 23.9 M
8 | fid.inception     | NoTrainInceptionV3 | 23.9 M
9 | unnormalize       | NormalizeInverse   | 0     
---------------------------------------------------------
6.4 M     Trainable params
23.9 M    Non-trainable params
30.3 M    Total params
121.101   Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

  rank_zero_warn(


In [None]:
# wandb.finish()


## Model Improvement


### Improving the Generator and Discriminator


### Improving the Generator

Introducing

- Spectral Normalization
- Upsampling instead of Transposed Convolutions
- Residual Connections


In [None]:
class ResNetGenerator(nn.Module):
    """Class-conditional residual generator that outputs 3-channel images in [-1, 1].

    Noise is projected by a linear layer to a ``num_filters x 4 x 4`` feature
    map, passed through three upsampling ``ResidualBlockGenerator`` blocks that
    each also receive the label embedding, and converted to an image by a
    BatchNorm -> ReLU -> 3x3 Conv -> Tanh head.

    :param latent_dim: Size of the input noise vector, defaults to 128
    :type latent_dim: int, optional
    :param num_filters: Channel width used throughout the network, defaults to 256
    :type num_filters: int, optional
    :param num_classes: Number of class labels, defaults to 10
    :type num_classes: int, optional
    :param activation: Activation passed to every residual block, defaults to F.relu
    :type activation: callable, optional
    :param use_spectral_norm: Forwarded to the residual blocks and applied to the
        output convolution, defaults to True
    :type use_spectral_norm: bool, optional
    """

    def __init__(
        self,
        latent_dim: int = 128,
        num_filters: int = 256,
        num_classes: int = 10,
        activation: callable = F.relu,
        use_spectral_norm: bool = True,
    ):
        super().__init__()
        self.latent_dim = latent_dim
        self.num_filters = num_filters
        self.num_classes = num_classes
        self.use_spectral_norm = use_spectral_norm

        # NOTE(review): the embedding is spectrally normalised regardless of
        # `use_spectral_norm` - confirm whether that is intentional.
        self.embed = spectral_norm(nn.Embedding(num_classes, num_filters))
        self.latent = nn.Linear(latent_dim, num_filters * 4 * 4)

        # Per the original author's note: 4x4 -> 8x8 -> 16x16 -> 32x32
        stages = []
        for _ in range(3):
            stages.append(
                ResidualBlockGenerator(
                    in_ch=num_filters,
                    out_ch=num_filters,
                    activation=activation,
                    upsample=True,
                    num_classes=num_classes,
                    use_spectral_norm=use_spectral_norm,
                )
            )
        self.blocks = nn.ModuleList(stages)

        head_norm = nn.BatchNorm2d(num_filters)
        to_rgb = nn.Conv2d(num_filters, 3, kernel_size=3, padding=1, stride=1)
        if use_spectral_norm:
            to_rgb = spectral_norm(to_rgb)
        self.output = nn.Sequential(head_norm, nn.ReLU(), to_rgb, nn.Tanh())

    def forward(self, x, y):
        """Generate images from noise ``x`` for the class indices ``y``."""
        class_vectors = self.embed(y)
        features = self.latent(x)
        features = features.view(features.size(0), self.num_filters, 4, 4)
        for stage in self.blocks:
            features = stage(features, class_vectors)
        return self.output(features)


In [None]:
class ResNetDiscriminator(nn.Module):
    def __init__(
        self,
        num_filters: int = 128,
        num_classes: int = 10,
        activation: callable = F.relu,
    ):
        """Implementation inspired by Projection Discriminator: https://github.com/pfnet-research/sngan_projection/blob/master/dis_models/snresnet_32.py

        A residual feature extractor followed by global average pooling; the
        score is a linear output plus the inner product between the pooled
        features and a learned class embedding.

        :param num_filters: Channel width of every residual block, defaults to 128
        :type num_filters: int, optional
        :param num_classes: Number of class labels, defaults to 10
        :type num_classes: int, optional
        :param activation: Activation used in the blocks and before pooling, defaults to F.relu
        :type activation: callable, optional
        """
        super().__init__()
        self.num_filters = num_filters
        self.num_classes = num_classes
        self.activation = activation

        # Head block takes the RGB image; of the following blocks only the
        # first one is asked to downsample.
        trunk = [ResidualBlockDiscriminatorHead(3, num_filters, activation=activation)]
        for should_downsample in (True, False, False):
            trunk.append(
                ResidualBlockDiscriminator(
                    num_filters,
                    num_filters,
                    activation=activation,
                    downsample=should_downsample,
                )
            )
        self.blocks = nn.Sequential(*trunk)

        self.classifier = spectral_norm(nn.Linear(num_filters, 1, bias=False))
        self.embed = spectral_norm(nn.Embedding(num_classes, num_filters))

    def forward(self, x, y):
        """Return one unbounded score per image ``x`` for class indices ``y``, shape (batch, 1)."""
        features = self.activation(self.blocks(x))
        pooled = features.mean(dim=(2, 3))  # Global Avg Pooling
        unconditional = self.classifier(pooled)
        # Projection term: <class embedding, pooled features>
        projection = (self.embed(y) * pooled).sum(dim=1, keepdim=True)
        return unconditional + projection


### Introducing Differentiable Augmentations to Reduce Overfitting of Discriminator


In [None]:
class ConditionalLSGAN_Aug(ConditionalLSGAN):
    """Conditional LSGAN with ResNet networks and Differentiable Augmentation.

    Reuses the training/validation machinery of ``ConditionalLSGAN`` but swaps
    in ``ResNetGenerator`` / ``ResNetDiscriminator`` and applies ``DiffAugment``
    with the given policy to every image batch the discriminator sees.

    :param policy: Comma-separated DiffAugment policy string
    :param kwargs: Forwarded to ``ConditionalLSGAN.__init__``

    The remaining parameters have the same meaning as in ``ConditionalLSGAN``.
    """

    def __init__(
        self,
        latent_dim: int = 128,
        num_classes: int = 10,
        g_lr: float = 0.0002,
        d_lr: float = 0.0002,
        adam_betas: Tuple[float, float] = (0.0, 0.9),
        batch_size: int = 64,
        validation_size: int = 10000,
        policy: str = "color,translation,cutout",
        d_steps: int = 1,
        **kwargs,
    ):
        # `d_steps` and `**kwargs` are separate arguments; the missing comma
        # previously turned them into `d_steps ** kwargs` (int ** dict), which
        # raised a TypeError on every instantiation.
        super().__init__(
            latent_dim,
            num_classes,
            g_lr,
            d_lr,
            adam_betas,
            batch_size,
            validation_size,
            d_steps,
            **kwargs,
        )

        self.policy = policy
        self.save_hyperparameters("policy")

        # Replace the DCGAN-style networks created by the parent constructor.
        self.G = ResNetGenerator(
            latent_dim=self.latent_dim,
            num_classes=num_classes,
        )

        self.D = ResNetDiscriminator(num_classes=num_classes)

        self.G.apply(self._weights_init)
        self.D.apply(self._weights_init)

    @staticmethod
    def _weights_init(m):
        """N(0, 0.02) init for Linear/Conv2d/Embedding weights; N(1, 0.02) scale and zero bias for BatchNorm2d."""
        if (
            isinstance(m, nn.Linear)
            or isinstance(m, nn.Conv2d)
            or isinstance(m, nn.Embedding)
        ):
            torch.nn.init.normal_(m.weight, 0.0, 0.02)
        elif isinstance(m, nn.BatchNorm2d):
            torch.nn.init.normal_(m.weight, 1.0, 0.02)
            torch.nn.init.zeros_(m.bias)

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Least-squares generator (``optimizer_idx == 0``) or discriminator (``== 1``) step.

        Real images are augmented once on entry; generated images are augmented
        right before being passed to the discriminator.

        :return: Dict with the loss under ``"loss"`` plus detached progress-bar/log copies
        """
        imgs, labels = batch  # Get real images and corresponding labels
        imgs = DiffAugment(imgs, policy=self.policy)
        # Generate Noise Vector z
        z = torch.randn(imgs.shape[0], self.latent_dim)
        z = z.type_as(imgs)  # Ensure z runs on the same device as the images
        # Random target classes for the generated images
        self.fake_labels = torch.LongTensor(
            torch.randint(0, self.num_classes, (imgs.shape[0],))
        ).to(self.device)
        # Train Generator
        if optimizer_idx == 0:
            # Generate Images
            self.fakes = self.forward(z, self.fake_labels)

            # Classify Generated Images with Discriminator
            fake_preds = torch.squeeze(
                self.D(DiffAugment(self.fakes, policy=self.policy), self.fake_labels)
            )

            # We want to penalize the Generator if the Discriminator predicts it as fake
            # Hence, set the target as a 1's vector
            target = torch.ones(imgs.shape[0]).type_as(imgs)

            g_loss = self.adversarial_loss(fake_preds, target)

            self.log(
                "train_gen_loss",
                g_loss.detach(),
                on_epoch=True,
                on_step=False,
                prog_bar=True,
            )  # Log Generator Loss
            tqdm_dict = {
                "g_loss": g_loss.detach(),
            }
            output = OrderedDict(
                {"loss": g_loss, "progress_bar": tqdm_dict, "log": tqdm_dict}
            )
            return output

        # Train Discriminator
        if optimizer_idx == 1:
            # Train on Real Data (target 1)
            real_preds = torch.squeeze(self.D(imgs, labels))
            target = torch.ones(imgs.shape[0]).type_as(imgs)
            d_real_loss = self.adversarial_loss(real_preds, target)

            # Train on Generated Images (target 0)
            self.fakes = self.forward(z, self.fake_labels)
            target = torch.zeros(imgs.shape[0]).type_as(imgs)
            fake_preds = torch.squeeze(
                self.D(DiffAugment(self.fakes, policy=self.policy), self.fake_labels)
            )
            d_fake_loss = self.adversarial_loss(fake_preds, target)
            d_loss = (d_real_loss + d_fake_loss) / 2

            self.log(
                "train_discriminator_loss",
                d_loss.detach(),
                on_epoch=True,
                on_step=False,
                prog_bar=True,
            )
            tqdm_dict = {
                "d_loss": d_loss.detach(),
            }
            output = OrderedDict(
                {"loss": d_loss, "progress_bar": tqdm_dict, "log": tqdm_dict}
            )
            return output


- DiffAugment was too powerful, causing the model to collapse


In [None]:
# seed_everything(42)
# wandb_logger = WandbLogger(project="DELE_CA2_GAN", log_model="all")
# model = ConditionalLSGAN_Aug(g_lr=0.0001, d_lr=0.0004, batch_size=BATCH_SIZE)
# trainer = Trainer(
#     check_val_every_n_epoch=5,
#     logger=wandb_logger,
#     max_epochs=1000,
#     callbacks=[
#         # ModelCheckpoint(monitor="FID", mode="min", every_n_epochs=1),
#         ModelSummary(3),
#     ],
#     gpus=1,
# )


In [None]:
# trainer.fit(model, dm)


In [None]:
# wandb.finish()


### Hinge Loss


In [None]:
class ConditionalSNGAN_Aug(ConditionalLSGAN):
    """Conditional GAN with ResNet networks, DiffAugment and the hinge loss.

    Shares the optimizers, validation and logging of ``ConditionalLSGAN`` but
    replaces the networks with ``ResNetGenerator`` / ``ResNetDiscriminator``
    and the least-squares objective with the hinge objective.

    :param policy: Comma-separated DiffAugment policy string
    :param w_init_policy: Weight initialisation scheme, ``"normal"`` or ``"ortho"``
    :param kwargs: Forwarded to ``ConditionalLSGAN.__init__``
    :raises ValueError: If ``w_init_policy`` is not one of the known schemes

    The remaining parameters have the same meaning as in ``ConditionalLSGAN``.
    """

    def __init__(
        self,
        latent_dim: int = 128,
        num_classes: int = 10,
        g_lr: float = 0.0002,
        d_lr: float = 0.0002,
        adam_betas: Tuple[float, float] = (0.0, 0.9),
        batch_size: int = 64,
        validation_size: int = 10000,
        policy: str = "color,translation,cutout",
        d_steps: int = 1,
        w_init_policy: str = "ortho",
        **kwargs,
    ):
        super().__init__(
            latent_dim,
            num_classes,
            g_lr,
            d_lr,
            adam_betas,
            batch_size,
            validation_size,
            d_steps,
            **kwargs,
        )

        self.policy = policy
        self.w_init_policy = w_init_policy

        self.save_hyperparameters("policy", "w_init_policy")

        # Replace the DCGAN-style networks created by the parent constructor.
        self.G = ResNetGenerator(
            latent_dim=self.latent_dim,
            num_classes=num_classes,
        )
        self.D = ResNetDiscriminator(num_classes=num_classes)

        # Pick the initialiser matching the requested policy.
        known_policies = (
            ("normal", self._normal_weights_init),
            ("ortho", self._ortho_weights_init),
        )
        for policy_name, initialiser in known_policies:
            if w_init_policy == policy_name:
                init_func = initialiser
                break
        else:
            raise ValueError("Unknown Weight Init Policy")

        self.G.apply(init_func)
        self.D.apply(init_func)

    @staticmethod
    def _normal_weights_init(m):
        """N(0, 0.02) for Linear/Conv2d/Embedding weights; N(1, 0.02) scale and zero bias for BatchNorm2d."""
        if isinstance(m, (nn.Linear, nn.Conv2d, nn.Embedding)):
            torch.nn.init.normal_(m.weight, 0.0, 0.02)
            return
        if isinstance(m, nn.BatchNorm2d):
            torch.nn.init.normal_(m.weight, 1.0, 0.02)
            torch.nn.init.zeros_(m.bias)

    @staticmethod
    def _ortho_weights_init(m):
        """Orthogonal init for Linear/Conv2d/Embedding weights; other modules are left untouched."""
        if isinstance(m, (nn.Linear, nn.Conv2d, nn.Embedding)):
            torch.nn.init.orthogonal_(m.weight)

    def _score_current_fakes(self):
        """Augment ``self.fakes`` and return the squeezed discriminator scores for them."""
        augmented = DiffAugment(self.fakes, policy=self.policy)
        return torch.squeeze(self.D(augmented, self.fake_labels))

    def _log_and_wrap(self, log_name, bar_name, loss):
        """Log ``loss`` per epoch under ``log_name`` and build the training_step return dict."""
        self.log(
            log_name,
            loss,
            on_epoch=True,
            on_step=False,
            prog_bar=True,
        )
        tqdm_dict = {bar_name: loss}
        return OrderedDict(
            {"loss": loss, "progress_bar": tqdm_dict, "log": tqdm_dict}
        )

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Hinge-loss generator (``optimizer_idx == 0``) or discriminator (``== 1``) step.

        :return: Dict with the loss under ``"loss"`` plus progress-bar/log copies
        """
        real_imgs, real_labels = batch
        real_imgs = DiffAugment(real_imgs, policy=self.policy)
        num_samples = real_imgs.shape[0]

        # Noise on the same device/dtype as the images, plus random target classes
        noise = torch.randn(num_samples, self.latent_dim).type_as(real_imgs)
        self.fake_labels = torch.LongTensor(
            torch.randint(0, self.num_classes, (num_samples,))
        ).to(self.device)

        if optimizer_idx == 0:
            # Generator: raise the discriminator's score on the fakes
            self.fakes = self.forward(noise, self.fake_labels)
            g_loss = -self._score_current_fakes().mean()
            return self._log_and_wrap("train_gen_loss", "g_loss", g_loss)

        if optimizer_idx == 1:
            # Discriminator: hinge on real scores first, then on fake scores
            real_scores = torch.squeeze(self.D(real_imgs, real_labels))
            d_real_loss = F.relu(1.0 - real_scores).mean()

            self.fakes = self.forward(noise, self.fake_labels)
            d_fake_loss = F.relu(1.0 + self._score_current_fakes()).mean()

            d_loss = d_real_loss + d_fake_loss
            return self._log_and_wrap("train_discriminator_loss", "d_loss", d_loss)


In [None]:
# seed_everything(42)
# wandb_logger = WandbLogger(project="DELE_CA2_GAN", log_model="all")
# model = ConditionalSNGAN_Aug(
#     g_lr=0.0001, d_lr=0.0004, batch_size=BATCH_SIZE, policy="", adam_betas=(0, 0.999)
# )
# trainer = Trainer(
#     check_val_every_n_epoch=5,
#     logger=wandb_logger,
#     max_epochs=1000,
#     callbacks=[
#         # ModelCheckpoint(monitor="FID", mode="min", every_n_epochs=1),
#         ModelSummary(3),
#     ],
#     gpus=1,
# )


In [None]:
# trainer.fit(model, dm)


In [None]:
# wandb.finish()


### Exponential Moving Average of Model Parameters


In [None]:
# Fix the random seed before building the model.
seed_everything(42)
wandb_logger = WandbLogger(project="DELE_CA2_GAN", log_model="all")
# Hinge-loss model with translation-only DiffAugment and 4 discriminator
# steps per generator step.
model = ConditionalSNGAN_Aug(
    g_lr=0.0002,
    d_lr=0.0002,
    batch_size=BATCH_SIZE,
    policy="translation",
    adam_betas=(0.0, 0.999),
    d_steps=4,
)
# Validation (FID / IS) runs every 5 epochs.
trainer = Trainer(
    check_val_every_n_epoch=5,
    logger=wandb_logger,
    max_epochs=1000,
    callbacks=[
        # ModelCheckpoint(monitor="FID", mode="min", every_n_epochs=1),
        ModelSummary(3),
        # Project callback from utils.ema; presumably keeps an exponential moving
        # average of the model parameters on the GPU - verify against utils/ema.py.
        EMA(ema_device="cuda"),
    ],
    gpus=1,
)


In [None]:
# Train `model` on the CIFAR-10 data module with the trainer configured above.
trainer.fit(model, dm)


In [None]:
# Close the active Weights & Biases run.
wandb.finish()
