<a href="https://colab.research.google.com/github/Tien-Cheng/dele-generative-adversarial-networks/blob/main/SNGAN%20-%204%20Discriminator%20Steps%20Per%20Generator%20Step.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Background

## Problem Statement

My objective for this project is to create a GAN that can generate realistic images from information provided to it.

Such a task generally falls under two categories:

1. Conditional Image Generation: Given a class (e.g. a car), generate an image of that class
2. Controllable Image Generation: Given some description of features of the image (e.g. The car should be red in color), generate an image with those features

In this project, I hope to tackle the first category (Conditional Image Generation), with a view to eventually tackling the second category if time and resources permit.

So, what would I consider a success? Ideally, I want to reliably generate class-conditional images that look real enough, judged by:

- Eye Power 🕵️
- Metrics like FID (described later in the report)

By "real enough", I mean results close to those of the top-performing GAN models, with minimal visible artifacts.
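FID compares the mean and covariance of Inception-v3 features extracted from real and generated images, FID = ||μ_r − μ_g||² + Tr(Σ_r + Σ_g − 2(Σ_r Σ_g)^{1/2}), and lower is better. The numpy sketch below only illustrates that formula on feature arrays you already have; it is not torch-fidelity's implementation, and the function names are hypothetical.

```python
import numpy as np


def _sqrtm_psd(mat: np.ndarray) -> np.ndarray:
    """Square root of a symmetric positive semi-definite matrix via eigendecomposition."""
    eigvals, eigvecs = np.linalg.eigh(mat)
    eigvals = np.clip(eigvals, 0.0, None)  # drop tiny negative values caused by round-off
    return (eigvecs * np.sqrt(eigvals)) @ eigvecs.T


def frechet_distance(mu1, sigma1, mu2, sigma2) -> float:
    """Frechet distance between the Gaussians N(mu1, sigma1) and N(mu2, sigma2)."""
    mu1, mu2 = np.asarray(mu1, float), np.asarray(mu2, float)
    sigma1, sigma2 = np.asarray(sigma1, float), np.asarray(sigma2, float)
    # Tr((sigma1 sigma2)^{1/2}) equals Tr((s1h sigma2 s1h)^{1/2}) with s1h = sigma1^{1/2};
    # the second form is symmetric, so eigh can be used safely.
    s1_half = _sqrtm_psd(sigma1)
    covmean_trace = np.trace(_sqrtm_psd(s1_half @ sigma2 @ s1_half))
    diff = mu1 - mu2
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * covmean_trace)


def fid_from_features(feats_real: np.ndarray, feats_fake: np.ndarray) -> float:
    """FID from two (N, D) arrays of (e.g. Inception) features."""
    mu_r, mu_f = feats_real.mean(axis=0), feats_fake.mean(axis=0)
    sigma_r = np.cov(feats_real, rowvar=False)
    sigma_f = np.cov(feats_fake, rowvar=False)
    return frechet_distance(mu_r, sigma_r, mu_f, sigma_f)
```

For two unit-covariance Gaussians whose means differ by (3, 4), only the mean term survives and the distance is 25; identical feature sets give 0.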


# Google Colab Setup

Because training a GAN is computationally demanding, Google Colab is used as the compute platform. The dataset and utility functions are stored in a GitHub repository, so the repository is cloned first, and a few additional libraries are then installed.


In [1]:
! mkdir -p /root/.ssh
import os

with open("/root/.ssh/id_rsa", mode="w") as fp:  # Repository deploy key
    # Never hard-code a private key in a notebook: anyone who can read it can use the key.
    # Read it from an environment variable instead (DEPLOY_KEY is a hypothetical name).
    fp.write(os.environ["DEPLOY_KEY"].strip() + "\n")
! ssh-keyscan -t rsa github.com >> ~/.ssh/known_hosts
! chmod go-rwx /root/.ssh/id_rsa
! git clone git@github.com:Tien-Cheng/dele-generative-adversarial-networks.git
%cd /content/dele-generative-adversarial-networks

# github.com:22 SSH-2.0-babeld-00610741
Cloning into 'dele-generative-adversarial-networks'...
remote: Enumerating objects: 535, done.
remote: Counting objects: 100% (535/535), done.
remote: Compressing objects: 100% (358/358), done.
remote: Total 535 (delta 332), reused 360 (delta 165), pack-reused 0
Receiving objects: 100% (535/535), 518.73 KiB | 1.96 MiB/s, done.
Resolving deltas: 100% (332/332), done.
/content/dele-generative-adversarial-networks


In [2]:
%%capture
%pip install -U torch-fidelity wandb torch-summary pytorch-lightning

# Setup


### Imports


In [3]:
import gc
from collections import OrderedDict
from typing import *

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch_fidelity
from pytorch_lightning import LightningModule, Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint, ModelSummary
from pytorch_lightning.loggers import WandbLogger
from torch import nn
from torch.nn import functional as F
from torch.nn.utils import spectral_norm
from torch.optim import Adam
from torch.optim.lr_scheduler import *

from torchvision import transforms as T

import wandb
from data.dataset import CIFAR10DataModule
from utils.DiffAugment_pytorch import (
    DiffAugment,
)  # Make use of the official implementation of DiffAugment

from utils.layers import (
    NormalizeInverse,
    ResidualBlockDiscriminator,
    ResidualBlockDiscriminatorHead,
    ResidualBlockGenerator,
)
from utils.loss import R1, HingeGANLossGenerator, HingeGANLossDiscriminator
from utils.visualize import visualize


# Building an SNGAN with Projection Discriminator


## Generator
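`ResidualBlockGenerator` lives in `utils/layers.py` and is not shown here. Because it receives `num_classes`, it presumably conditions each block on the label through class-conditional BatchNorm, as the SNGAN projection reference code does; that is an assumption. The minimal numpy sketch below (training-mode forward pass only, hypothetical class name) shows the idea: normalisation statistics are shared, but each class has its own scale and shift.

```python
import numpy as np


class ConditionalBatchNorm2d:
    """Hypothetical sketch of class-conditional BatchNorm (training-mode forward only)."""

    def __init__(self, num_features: int, num_classes: int, eps: float = 1e-5):
        self.eps = eps
        self.gamma = np.ones((num_classes, num_features))   # per-class scale
        self.beta = np.zeros((num_classes, num_features))   # per-class shift

    def __call__(self, x: np.ndarray, y: np.ndarray) -> np.ndarray:
        # x: (N, C, H, W) feature maps, y: (N,) integer class labels
        mean = x.mean(axis=(0, 2, 3), keepdims=True)
        var = x.var(axis=(0, 2, 3), keepdims=True)
        x_hat = (x - mean) / np.sqrt(var + self.eps)
        # Look up each sample's class-specific affine parameters, broadcast over H and W
        g = self.gamma[y][:, :, None, None]
        b = self.beta[y][:, :, None, None]
        return g * x_hat + b
```

With the default parameters the output is simply batch-normalised; changing `beta[c]` shifts only the samples labelled `c`.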


In [7]:
class ResNetGenerator(nn.Module):
    def __init__(
        self,
        latent_dim: int = 128,
        num_filters: int = 256,
        num_classes: int = 10,
        activation: callable = F.relu,
        use_spectral_norm: bool = True,
    ):
        super().__init__()
        self.latent_dim = latent_dim
        self.num_filters = num_filters
        self.num_classes = num_classes
        self.use_spectral_norm = use_spectral_norm
        self.latent = nn.Linear(latent_dim, num_filters * 4 * 4)
        self.blocks = nn.ModuleList(
            [
                ResidualBlockGenerator(
                    in_ch=num_filters,
                    out_ch=num_filters,
                    activation=activation,
                    upsample=True,
                    num_classes=num_classes,
                    use_spectral_norm=use_spectral_norm,
                )
                for _ in range(3)  # Input: 4x4 | -> 8x8 -> 16x16 -> 32x32
            ]
        )
        self.output = [
            nn.BatchNorm2d(num_filters),
            nn.ReLU(),
            nn.Conv2d(num_filters, 3, kernel_size=3, padding=1, stride=1),
            nn.Tanh(),
        ]
        if use_spectral_norm:
            self.output[2] = spectral_norm(self.output[2])

        self.output = nn.Sequential(*self.output)

    def forward(self, x, y):
        h = self.latent(x)
        h = h.view(h.shape[0], self.num_filters, 4, 4)
        for block in self.blocks:
            h = block(h, y)
        output = self.output(h)
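        if not self.training:
            # In eval mode, map the tanh output from [-1, 1] to uint8 pixels in [0, 255],
            # the image format torch-fidelity expects from a generator
            output = 255 * (output.clamp(-1, 1) * 0.5 + 0.5)
            output = output.to(torch.uint8)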
        return output
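In eval mode the generator returns uint8 images, which is what torch-fidelity's `GenerativeModelModuleWrapper` consumes. The numpy sketch below mirrors that post-processing (hypothetical helper name) to make one detail visible: the cast truncates rather than rounds, so a tanh output of 0 becomes 127, not 128.

```python
import numpy as np


def to_uint8(images: np.ndarray) -> np.ndarray:
    """Mirror of the generator's eval-mode post-processing: [-1, 1] -> uint8 [0, 255]."""
    scaled = 255 * (np.clip(images, -1, 1) * 0.5 + 0.5)
    # Casting floats in range to an unsigned integer type drops the fractional part
    # (127.5 -> 127), which matches Tensor.to(torch.uint8) for in-range values.
    return scaled.astype(np.uint8)


pixels = to_uint8(np.array([-1.0, 0.0, 0.5, 1.0, 1.3]))
```

The off-by-one from truncation is negligible for FID; adding 0.5 before the cast would round instead.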


## Discriminator


In [8]:
class ResNetDiscriminator(nn.Module):
    def __init__(
        self,
        num_filters: int = 128,
        num_classes: int = 10,
        activation: callable = F.relu,
    ):
        """Implementation inspired by Projection Discriminator: https://github.com/pfnet-research/sngan_projection/blob/master/dis_models/snresnet_32.py

        :param num_filters: number of channels in every residual block, defaults to 128
        :type num_filters: int, optional
        :param num_classes: number of classes for the projection embedding, defaults to 10
        :type num_classes: int, optional
        :param activation: activation used inside the blocks and before global pooling, defaults to F.relu
        :type activation: callable, optional
        """
        super().__init__()
        self.num_filters = num_filters
        self.num_classes = num_classes
        self.activation = activation
        self.blocks = nn.Sequential(
            ResidualBlockDiscriminatorHead(3, num_filters, activation=activation),
            ResidualBlockDiscriminator(
                num_filters, num_filters, activation=activation, downsample=True
            ),
            ResidualBlockDiscriminator(
                num_filters, num_filters, activation=activation, downsample=False
            ),
            ResidualBlockDiscriminator(
                num_filters, num_filters, activation=activation, downsample=False
            ),
        )
        self.classifier = spectral_norm(nn.Linear(num_filters, 1, bias=False))
        self.embed = spectral_norm(nn.Embedding(num_classes, num_filters))

    def forward(self, x, y):
        h = self.blocks(x)
        h = self.activation(h)
        h = h.mean([2, 3])  # Global Avg Pooling
        out = self.classifier(h)
        # Projection term: inner product between the class embedding and the pooled features
        out = out + torch.sum(self.embed(y) * h, dim=1, keepdim=True)
        return out
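The label enters the discriminator only through an inner product with the pooled features, added to an unconditional linear score. A numpy sketch of that output layer (hypothetical function name; spectral normalisation and the residual blocks are omitted):

```python
import numpy as np


def projection_logit(h: np.ndarray, y: np.ndarray, w: np.ndarray, embed: np.ndarray) -> np.ndarray:
    """Projection discriminator output, mirroring ResNetDiscriminator.forward.

    h:     (N, C) globally average-pooled features
    y:     (N,)   integer class labels
    w:     (C,)   weights of the unconditional linear head (no bias)
    embed: (K, C) class embedding matrix
    Returns (N, 1) scores: w^T h + embed[y]^T h
    """
    unconditional = h @ w                        # class-agnostic realism score
    projection = np.sum(embed[y] * h, axis=1)    # agreement between features and the label
    return (unconditional + projection)[:, None]
```

Relabelling a sample changes only the projection term, which is how the discriminator can penalise a realistic image that does not match its class.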


## Conditional SNGAN


In [9]:
class SNGAN(LightningModule):
    def __init__(
        self,
        latent_dim: int = 128,
        num_classes: int = 10,
        g_lr: float = 0.0002,
        d_lr: float = 0.0002,
        adam_betas: Tuple[float, float] = (0.0, 0.999),
        d_steps: int = 1,
        r1_gamma: Optional[float] = 0.2,
        w_init_policy: str = "ortho",
        **kwargs,
    ):
        super().__init__()

        self.latent_dim = latent_dim
        self.num_classes = num_classes
        self.g_lr = g_lr
        self.d_lr = d_lr
        self.betas = adam_betas
        self.d_steps = d_steps  # Number of Discriminator steps per Generator Step
        self.r1_gamma = r1_gamma
        self.w_init_policy = w_init_policy
        self.save_hyperparameters()

        self.G = ResNetGenerator(
            latent_dim=latent_dim,
            num_classes=num_classes,
        )

        self.D = ResNetDiscriminator(num_classes=num_classes)

        if w_init_policy == "normal":
            init_func = self._normal_weights_init
        elif w_init_policy == "ortho":
            init_func = self._ortho_weights_init
        else:
            raise ValueError("Unknown Weight Init Policy")
        self.G.apply(init_func)
        self.D.apply(init_func)

        self.generator_loss = HingeGANLossGenerator()
        self.discriminator_loss = HingeGANLossDiscriminator()
        self.regularization_loss = R1(r1_gamma)

        # Fixed noise and labels so validation samples are comparable across epochs
        self.viz_z = torch.randn(64, self.latent_dim)
        self.viz_labels = torch.randint(0, self.num_classes, (64,))  # int64 by default

        self.unnormalize = NormalizeInverse((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

    @staticmethod
    def _normal_weights_init(m):
        if (
            isinstance(m, nn.Linear)
            or isinstance(m, nn.Conv2d)
            or isinstance(m, nn.Embedding)
        ):
            torch.nn.init.normal_(m.weight, 0.0, 0.02)
        elif isinstance(m, nn.BatchNorm2d):
            torch.nn.init.normal_(m.weight, 1.0, 0.02)
            torch.nn.init.zeros_(m.bias)

    @staticmethod
    def _ortho_weights_init(m):
        if (
            isinstance(m, nn.Linear)
            or isinstance(m, nn.Conv2d)
            or isinstance(m, nn.Embedding)
        ):
            torch.nn.init.orthogonal_(m.weight)

    def forward(self, z, labels):
        return self.G(z, labels)

    # Alternating optimizer-step schedule (as used for GANs)
    def optimizer_step(
        self,
        epoch,
        batch_idx,
        optimizer,
        optimizer_idx,
        optimizer_closure,
        on_tpu,
        using_native_amp,
        using_lbfgs,
    ):
        # update discriminator opt every step
        if optimizer_idx == 1:
            optimizer.step(closure=optimizer_closure)

        # update generator opt every d_steps steps
        if optimizer_idx == 0:
            if (batch_idx + 1) % self.d_steps == 0:
                optimizer.step(closure=optimizer_closure)
            else:
                # call the closure by itself to run `training_step` + `backward` without an optimizer step
                optimizer_closure()

    def training_step(self, batch, batch_idx, optimizer_idx):
        imgs, labels = batch  # Get real images and corresponding labels

        # Generate Noise Vector z
        z = torch.randn(imgs.shape[0], self.latent_dim)
        z = z.type_as(imgs)  # Ensure z runs on the same device as the images
        # Random class labels for the fake images, created directly on the model's device
        self.fake_labels = torch.randint(
            0, self.num_classes, (imgs.shape[0],), device=self.device
        )
        # Train Generator
        if optimizer_idx == 0:
            # Generate Images
            self.fakes = self.forward(z, self.fake_labels)

            # Classify Generated Images with Discriminator
            fake_preds = torch.squeeze(self.D(self.fakes, self.fake_labels))

            # Hinge generator loss: the Generator is rewarded when the Discriminator
            # assigns high scores to its samples (no target vector is involved)
            g_loss = self.generator_loss(fake_preds)

            self.log(
                "train_gen_loss",
                g_loss,
                on_epoch=True,
                on_step=False,
                prog_bar=True,
            )  # Log Generator Loss
            tqdm_dict = {
                "g_loss": g_loss,
            }
            output = OrderedDict(
                {"loss": g_loss, "progress_bar": tqdm_dict, "log": tqdm_dict}
            )
            return output

        # Train Discriminator
        if optimizer_idx == 1:
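            # Train on Real Data
            # requires_grad lets the R1 penalty differentiate D's output w.r.t. the real images
            imgs.requires_grad = True
            real_preds = torch.squeeze(self.D(imgs, labels))
            # Train on Generated Images; detach so no gradient flows into G on the D step
            self.fakes = self.forward(z, self.fake_labels).detach()
            fake_preds = torch.squeeze(self.D(self.fakes, self.fake_labels))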
            d_loss = self.discriminator_loss(real_preds, fake_preds)
            if self.r1_gamma is not None:
                d_loss = d_loss + self.regularization_loss(real_preds, imgs)
            self.log(
                "train_discriminator_loss",
                d_loss,
                on_epoch=True,
                on_step=False,
                prog_bar=True,
            )
            tqdm_dict = {
                "d_loss": d_loss,
            }
            output = OrderedDict(
                {"loss": d_loss, "progress_bar": tqdm_dict, "log": tqdm_dict}
            )
            return output

    def training_epoch_end(self, outputs):
        # Log Sampled Images
        sample_imgs = self.unnormalize(self.fakes[:64].detach()).cpu()
        sample_labels = self.fake_labels[:64].detach().cpu()
        num_rows = int(np.floor(np.sqrt(len(sample_imgs))))
        fig = visualize(sample_imgs, sample_labels, grid_shape=(num_rows, num_rows))
        self.logger.log_image(key="generated_images", images=[fig])
        plt.close(fig)
        del sample_imgs
        del sample_labels
        del fig
        gc.collect()

    def validation_step(self, batch, batch_idx):
        pass

    def validation_epoch_end(self, outputs):
        sample_imgs = (
            self.forward(self.viz_z.to(self.device), self.viz_labels.to(self.device))
            .cpu()
            .detach()
        )
        fig = visualize(sample_imgs, self.viz_labels.cpu().detach(), grid_shape=(8, 8))
        metrics = torch_fidelity.calculate_metrics(
            input1=torch_fidelity.GenerativeModelModuleWrapper(
                self.G, self.latent_dim, "normal", self.num_classes
            ),
            input1_model_num_samples=10000,
            input2="cifar10-val",
            isc=True,
            fid=True,
            kid=True,
        )
        self.logger.log_image(key="Validation Images", images=[fig])
        plt.close(fig)
        del sample_imgs
        del fig
        gc.collect()
        self.log("FID", metrics["frechet_inception_distance"], prog_bar=True)
        self.log("IS", metrics["inception_score_mean"], prog_bar=True)
        self.log("KID", metrics["kernel_inception_distance_mean"], prog_bar=True)

    def configure_optimizers(self):
        """Define the optimizers and schedulers for PyTorch Lightning

        :return: A tuple of two lists - a list of optimizers and a list of learning rate schedulers
        :rtype: Tuple[List, List]
        """
        opt_G = Adam(
            self.G.parameters(), lr=self.g_lr, betas=self.betas
        )  # optimizer_idx = 0
        opt_D = Adam(
            self.D.parameters(), lr=self.d_lr, betas=self.betas
        )  # optimizer_idx = 1
        return [opt_G, opt_D], []
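The loss classes (`HingeGANLossGenerator`, `HingeGANLossDiscriminator`, `R1`) come from `utils/loss.py`, which is not shown. Assuming they implement the standard hinge losses and the R1 penalty (γ/2 times the expected squared gradient norm of D on real images), a minimal torch sketch with hypothetical function names looks like this:

```python
import torch
from torch.nn import functional as F


def hinge_d_loss(real_preds: torch.Tensor, fake_preds: torch.Tensor) -> torch.Tensor:
    # Push real scores above +1 and fake scores below -1; no penalty beyond the margin
    return F.relu(1.0 - real_preds).mean() + F.relu(1.0 + fake_preds).mean()


def hinge_g_loss(fake_preds: torch.Tensor) -> torch.Tensor:
    # The generator tries to raise the discriminator's score on its samples
    return -fake_preds.mean()


def r1_penalty(real_preds: torch.Tensor, real_imgs: torch.Tensor, gamma: float) -> torch.Tensor:
    # gamma / 2 * E[||grad_x D(x)||^2] on real images; real_imgs must require grad.
    # create_graph=True keeps the graph so the penalty itself can be backpropagated into D.
    (grads,) = torch.autograd.grad(
        outputs=real_preds.sum(), inputs=real_imgs, create_graph=True
    )
    return 0.5 * gamma * grads.pow(2).flatten(1).sum(dim=1).mean()
```

In `training_step`, the discriminator loss would then correspond to `hinge_d_loss(real_preds, fake_preds) + r1_penalty(real_preds, imgs, 0.2)` for the `r1_gamma=0.2` used below.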


## Trainer


In [10]:
wandb.login()



[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [11]:
seed_everything(42)
model = SNGAN(g_lr=0.0002, d_lr=0.0002, r1_gamma=0.2, d_steps=4)

Global seed set to 42
Trainer already configured with model summary callbacks: [<class 'pytorch_lightning.callbacks.model_summary.ModelSummary'>]. Skipping setting a default `ModelSummary` callback.
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
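With `d_steps=4`, `optimizer_step` above steps the discriminator on every batch but the generator only when `(batch_idx + 1) % d_steps == 0`. The plain-Python sketch below (hypothetical helper name) lists the batch indices at which the generator is actually updated.

```python
def generator_steps(num_batches: int, d_steps: int) -> list:
    """0-based batch indices at which SNGAN.optimizer_step calls opt_G.step()."""
    return [batch_idx for batch_idx in range(num_batches) if (batch_idx + 1) % d_steps == 0]


print(generator_steps(num_batches=12, d_steps=4))  # [3, 7, 11]
```

Because `batch_idx` restarts at 0 each epoch, any leftover batches at the end of an epoch (when the batch count is not a multiple of `d_steps`) never trigger a generator step. On the other batches the generator closure still runs a forward and backward pass without stepping, so this schedule saves optimizer updates rather than compute.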


In [None]:
run = wandb.init()
artifact = run.use_artifact('tiencheng/DELE_CA2_GAN/model-19kepjip:v69', type='model')
artifact_dir = artifact.download()

In [None]:
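# load_from_checkpoint is a classmethod that returns a new model, so its result must be assigned.
# "model.ckpt" is the file name assumed for checkpoints logged as W&B artifacts by WandbLogger.
model = SNGAN.load_from_checkpoint(f"{artifact_dir}/model.ckpt")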

# GAN Evaluation