# RealnessGAN

## Imports

### wandb (interactive cell)

In [1]:
!pip -qqq install wandb pytorch-lightning torchmetrics

import wandb
from pytorch_lightning.loggers import WandbLogger

wandb.login()

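# API key: do NOT paste it here. Supply it through the interactive prompt or the
# WANDB_API_KEY environment variable; a key that was previously committed in this
# cell must be treated as leaked and revoked in the wandb settings.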

[34m[1mwandb[0m: Currently logged in as: [33maryangarg019[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

### Lightning

In [2]:
try:
  import lightning.pytorch as pl
except:
  print("[!] Couldn't find pytorch-lightning.\nInstalling it...\n")
  !pip install lightning
  import lightning.pytorch as pl

In [3]:
from lightning.pytorch.utilities.model_summary import ModelSummary

In [4]:
from pytorch_lightning import seed_everything

### standard imports

In [5]:
import os
import shutil
import pathlib

from PIL import Image
import numpy as np
import cv2 as cv
import random
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim

from torch.utils.data import DataLoader, random_split
from torch.utils.data import Dataset

import torchvision
from torchvision import datasets

In [6]:
import torchvision.transforms as T
from torchvision.transforms import Compose, ToTensor, Resize
from torchvision.utils import make_grid

### Albumentations

In [7]:
try:
  import albumentations as A
  from albumentations.pytorch import ToTensorV2
except:
  print("[!] Couldn't find albumentations... installing it.")
  !pip install -U albumentations > /dev/null
  import albumentations as A
  from albumentations.pytorch import ToTensorV2

### Torchmetrics

In [8]:
try:
  import torchmetrics
except:
  print(f"[!] Torchmetrics couldn't be imported.\nInstalling...")
  !pip install torchmetrics > /dev/null
  import torchmetrics

### Custom Definitions

In [9]:
# Folder Utilities ----------------------------

## Create dir if it doesn't exist
def create_dir(dir_name):
  """Create ``/content/<dir_name>`` unless something already exists at that path."""
  target = f'/content/{dir_name}'
  if os.path.exists(target):
    return
  os.mkdir(target)

## Delete dir: checkpoints
def delete_dir(dir_name):
  """Recursively remove ``/content/<dir_name>`` when it is an existing directory."""
  target = f'/content/{dir_name}'
  if not os.path.isdir(target):
    return
  shutil.rmtree(target)

# ---------------------------------------------

## Config File, Seeds & Devices

In [10]:
# Log this config file to wandb
CONFIG = {
    # Run-level settings
    'seed': 42,
    'DATA_ROOT': '/content/',
    'checkpoint_path': '/content/checkpoints/',
    'G_LOSS_MODE': "EQ19_V2",
    'SAVE_FOLDER': "resultsRealnessGAN",
    'NUM_EPOCHS': 40,
    'BATCH_SIZE': 64,

    # Optimizer settings
    'D_LR': 2e-4,
    'G_LR': 2e-4,
    'BETA_1': 0.5,
    'BETA_2': 0.999,

    # Model hyperparameters
    'LATENT_DIM': 20,
    'HIDDEN_DIM': 256,
    'IMAGE_DIM': 784,
    'NUM_OUTCOMES': 10,
}

In [11]:
seed_everything(CONFIG['seed'])

INFO:lightning_fabric.utilities.seed:Global seed set to 42


42

In [12]:
# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

device  # bare expression so the notebook cell echoes the chosen device

device(type='cuda')

## Transforms

In [13]:
# train_transform = A.Compose(
#     [
#         A.SmallestMaxSize(max_size=160),
#         A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.05, rotate_limit=15, p=0.5),
#         A.RandomCrop(height=128, width=128),
#         A.RGBShift(r_shift_limit=15, g_shift_limit=15, b_shift_limit=15, p=0.5),
#         A.RandomBrightnessContrast(p=0.5),
#         A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
#         ToTensorV2(),
#     ]
# )
van_transform = T.Compose([T.ToTensor()])

## Dataset

In [14]:
train_dataset = datasets.MNIST('data', train=True, transform=van_transform, download=True)

## DataLoader

In [15]:
train_loader = DataLoader(train_dataset, batch_size=CONFIG['BATCH_SIZE'], shuffle=True)

## Some More Settings

In [16]:
# Global Settings
# Read from CONFIG (defined earlier in this notebook) instead of repeating the
# literals, so the values used for training are the same ones logged to wandb.
NUM_EPOCHS = CONFIG['NUM_EPOCHS']
BATCH_SIZE = CONFIG['BATCH_SIZE']
D_LR = CONFIG['D_LR']
G_LR = CONFIG['G_LR']
LR = G_LR

# Model Hyperparameters
LATENT_DIM = CONFIG['LATENT_DIM']
HIDDEN_DIM = CONFIG['HIDDEN_DIM']
IMAGE_DIM = CONFIG['IMAGE_DIM']
NUM_OUTCOMES = CONFIG['NUM_OUTCOMES']

## Model Arch

In [17]:
from torchvision import models
from torchsummary import summary

### Discriminator


In [18]:
class Discriminator(nn.Module):
    """MLP discriminator that outputs a probability distribution over outcomes.

    Three hidden ``Linear`` + ``LeakyReLU(0.02)`` stages are followed by a
    ``Linear`` projection to ``num_outcomes`` and a softmax over that last
    dimension.

    Args:
        in_dim: size of the last dimension of the input.
        hidden_dim: width of each hidden layer.
        num_outcomes: number of output bins the softmax normalises over.
    """

    def __init__(self, in_dim, hidden_dim, num_outcomes):

        super(Discriminator, self).__init__()

        self.model = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.LeakyReLU(0.02),

            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(0.02),

            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(0.02),

            nn.Linear(hidden_dim, num_outcomes),
            # Explicit dim: the implicit choice is deprecated (it emits a
            # warning) and picks dim 0 for 3-D inputs, i.e. it would normalise
            # across the batch instead of across the outcomes.
            nn.Softmax(dim=-1)
        )

    def forward(self, x):
        """Return outcome probabilities; the last dimension sums to 1."""
        return self.model(x)

In [19]:
d = Discriminator(IMAGE_DIM, HIDDEN_DIM, NUM_OUTCOMES).to(device)
summary(d, (1,784))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1               [-1, 1, 256]         200,960
         LeakyReLU-2               [-1, 1, 256]               0
            Linear-3               [-1, 1, 256]          65,792
         LeakyReLU-4               [-1, 1, 256]               0
            Linear-5               [-1, 1, 256]          65,792
         LeakyReLU-6               [-1, 1, 256]               0
            Linear-7                [-1, 1, 10]           2,570
           Softmax-8                [-1, 1, 10]               0
Total params: 335,114
Trainable params: 335,114
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.01
Params size (MB): 1.28
Estimated Total Size (MB): 1.29
----------------------------------------------------------------


  input = module(input)


### Generator

In [20]:
class Generator(nn.Module):
    """MLP generator: flatten, four Linear/BatchNorm1d/ReLU stages, Linear, Tanh.

    Args:
        latent_dim: number of features of the (flattened) latent input.
        hidden_dim: width of every hidden stage.
        out_dim: number of output features; values lie in [-1, 1] due to Tanh.
    """

    def __init__(self, latent_dim, hidden_dim, out_dim):
        super().__init__()

        stages = [nn.Flatten()]
        in_features = latent_dim
        # Four identical hidden stages; only the first one reads the latent size.
        for _ in range(4):
            stages += [
                nn.Linear(in_features, hidden_dim),
                nn.BatchNorm1d(hidden_dim),
                nn.ReLU(inplace=True),
            ]
            in_features = hidden_dim
        stages += [nn.Linear(hidden_dim, out_dim), nn.Tanh()]

        self.model = nn.Sequential(*stages)

    def forward(self, x):
        """Run the latent batch ``x`` through the sequential stack."""
        return self.model(x)

In [21]:
g = Generator(LATENT_DIM, HIDDEN_DIM, IMAGE_DIM).to(device)
summary(g, (1,LATENT_DIM))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
           Flatten-1                   [-1, 20]               0
            Linear-2                  [-1, 256]           5,376
       BatchNorm1d-3                  [-1, 256]             512
              ReLU-4                  [-1, 256]               0
            Linear-5                  [-1, 256]          65,792
       BatchNorm1d-6                  [-1, 256]             512
              ReLU-7                  [-1, 256]               0
            Linear-8                  [-1, 256]          65,792
       BatchNorm1d-9                  [-1, 256]             512
             ReLU-10                  [-1, 256]               0
           Linear-11                  [-1, 256]          65,792
      BatchNorm1d-12                  [-1, 256]             512
             ReLU-13                  [-1, 256]               0
           Linear-14                  [

## Utilities-2



In [22]:
def saveimg(image, savepath):
  """Reorder the axes of ``image`` to (1, 2, 0) and write it to ``savepath``."""
  # Moves axis 0 to the end, e.g. (C, H, W) -> (H, W, C), before saving.
  channels_last = image.transpose(1, 2, 0)
  plt.imsave(savepath, channels_last)


def scale(tensor, mini=-1, maxi=1):
  """Affinely map values so that 0 -> ``mini`` and 1 -> ``maxi``."""
  span = maxi - mini
  return tensor * span + mini


def scale_back(tensor, mini=-1, maxi=1):
  """Inverse of ``scale``: map ``mini`` -> 0 and ``maxi`` -> 1."""
  offset = tensor - mini
  span = maxi - mini
  return offset / span


def generate_latent(batch_size, latent_dim):
  """Sample a ``(batch_size, latent_dim)`` tensor uniformly from [-1, 1) and move it to ``device``."""
  noise = torch.empty(batch_size, latent_dim)
  noise.uniform_(-1, 1)  # filled in place on the CPU before the transfer
  return noise.to(device)


fixed_z = generate_latent(64, LATENT_DIM)
# print(fixed_z.shape)

## Lightning Recipe

In [23]:
from scipy.stats import skewnorm

In [24]:
class LIT_realnessGAN(pl.LightningModule):
  """RealnessGAN training recipe with manual optimization.

  Wraps a discriminator that outputs a distribution over ``NUM_OUTCOMES`` bins
  and a generator, and trains them against two fixed anchor distributions
  (``A0``, ``A1``) using KL-divergence losses. Both optimizers are stepped by
  hand in ``training_step`` (``automatic_optimization`` is switched off).

  Relies on module-level names defined elsewhere in this notebook: ``device``,
  ``NUM_OUTCOMES``, ``LATENT_DIM``, ``scale`` and ``generate_latent``.

  Args:
    discriminator_model: module mapping flattened images to outcome probabilities.
    generator_model: module mapping latent vectors to flattened images.
    lr: Adam learning rate shared by both optimizers.
    b1: Adam beta1.
    b2: Adam beta2.
    disc_steps: accepted and stored in ``hparams``; not otherwise used here.
  """

  def __init__(self,
               discriminator_model,
               generator_model,
               lr: float = 0.003,
               b1: float = 0.5,
               b2: float = 0.999,
               disc_steps: int = 1):

    super().__init__()

    # `ignore` expects argument *names*. Passing the module objects themselves
    # ignores nothing, so both networks ended up inside the hyperparameters.
    self.save_hyperparameters(ignore=['discriminator_model',
                                      'generator_model'])

    self.automatic_optimization = False

    self.d = discriminator_model
    self.g = generator_model

    # Anchor 0 = Skewed normal to the left, Anchor 1 = skewed to the right.
    # A0 is drawn first to keep the random-draw order unchanged.
    self.A0 = self._skewed_anchor(-5)
    self.A1 = self._skewed_anchor(5)

    # Print KLD between the anchors
    print("KLD(A0||A1): {}".format(self.KLD(self.A0.view(1, -1), self.A1)))

  @staticmethod
  def _skewed_anchor(skewness):
    """Return a normalised ``NUM_OUTCOMES``-bin histogram of 1000 skew-normal draws."""
    samples = skewnorm.rvs(skewness, size=1000)
    counts, _ = np.histogram(samples, NUM_OUTCOMES)
    probs = counts / counts.sum()
    return torch.from_numpy(probs).to(device).float()

  def configure_optimizers(self):
    """Build one Adam optimizer per network; order is (generator, discriminator)."""
    lr = self.hparams.lr
    b1 = self.hparams.b1
    b2 = self.hparams.b2

    optim_g = torch.optim.Adam(self.g.parameters(), lr=lr, betas=(b1,b2))
    optim_d = torch.optim.Adam(self.d.parameters(), lr=lr, betas=(b1,b2))

    return [optim_g, optim_d], []

  def KLD(self, P, Q):
    """Return KL(P || Q) summed over dim 1 and averaged over the batch.

    ``P`` and ``Q`` must broadcast to a 2-D ``(batch, outcomes)`` shape.
    """
    # Clamp only inside the logs, and only at the smallest normal float, so
    # ordinary values are untouched while exact zeros (an empty histogram bin,
    # an underflowed softmax) can no longer yield 0 * -inf = NaN or log(x / 0) = inf.
    log_p = P.clamp_min(torch.finfo(P.dtype).tiny).log()
    log_q = Q.clamp_min(torch.finfo(Q.dtype).tiny).log()
    return torch.mean(torch.sum(P * (log_p - log_q), dim=1))

  def forward(self, z):
    """Generate flattened images from latent vectors ``z``."""
    # The generator is stored as `self.g`; `self.generator` never existed.
    return self.g(z)

  def training_step(self, batch, batch_idx):
    """Run one discriminator update followed by one generator update."""
    real_images, _ = batch

    batch_size = real_images.shape[0]

    # Flatten each image and rescale with `scale` (0 -> -1, 1 -> 1).
    real_images = real_images.view(batch_size, -1).to(device)
    real_images = scale(real_images, -1, 1)

    optim_g, optim_d = self.optimizers()

    # ---------------- Discriminator update ----------------
    self.toggle_optimizer(optim_d)

    optim_d.zero_grad()

    # Discriminator Real Loss
    # NOTE(review): this is KL(D(x) || A1) while the fake term below is
    # KL(A0 || D(G(z))) -- confirm the intended argument order against the paper.
    d_real_out = self.d(real_images)
    d_real_loss = self.KLD(d_real_out, self.A1)

    # Discriminator Fake Loss
    z = generate_latent(batch_size, LATENT_DIM)
    fake_images = self.g(z)
    d_fake_out = self.d(fake_images)
    d_fake_loss = self.KLD(self.A0, d_fake_out)

    # Total Discriminator Loss, Backprop, and Gradient Descent.
    # The loss must stay attached to the graph: re-wrapping it in
    # `torch.autograd.Variable(..., requires_grad=True)` creates a fresh leaf,
    # so backward() never reaches the discriminator weights.
    d_loss = d_real_loss + d_fake_loss
    self.manual_backward(d_loss)

    optim_d.step()

    self.untoggle_optimizer(optim_d)

    # ---------------- Generator update ----------------
    self.toggle_optimizer(optim_g)
    optim_g.zero_grad()

    z = generate_latent(batch_size, LATENT_DIM)
    g_images = self.g(z)
    d_g_out = self.d(g_images)

    # Generator Loss
    # Line 12 in Paper=> Use: G_LOSS_MODE = "EQ19_V2"
    d_out = self.d(real_images)
    g_loss = -1. * self.KLD(self.A0, d_g_out) + self.KLD(d_out, d_g_out)    # -KL(A0 || D(G(z))) + KL(D(x) || D(G(z)))

    # Total Generator Loss, Backprop and Gradient Descent
    self.manual_backward(g_loss)
    optim_g.step()

    # Restore requires_grad on the discriminator parameters. Without this they
    # stay frozen and the next discriminator step has nothing to differentiate.
    self.untoggle_optimizer(optim_g)

    # Log the first generated sample of the batch as a 28x28 image.
    sample = torch.reshape(g_images[0].detach(), (28,28)).cpu()
    self.logger.experiment.log({"Gen_Image (during training)":[wandb.Image(sample,
                                                                           caption="RealnessG Out")]})
    self.log_dict({"g_loss": g_loss.item(), "d_loss": d_loss.item()},
                  on_step=True,
                  on_epoch=True,
                  prog_bar=True,
                  logger=True)


In [25]:
rgan = LIT_realnessGAN(
      discriminator_model = Discriminator(IMAGE_DIM, HIDDEN_DIM, NUM_OUTCOMES),
      generator_model = Generator(LATENT_DIM, HIDDEN_DIM, IMAGE_DIM),
      lr=LR)

# Use a distinct name: `summary` is torchsummary's function imported earlier,
# and rebinding it here breaks the `summary(d, ...)` / `summary(g, ...)` cells
# whenever they are re-run after this one.
model_summary = ModelSummary(rgan, max_depth=-1)
print(model_summary)

KLD(A0||A1): 2.6878838539123535
   | Name       | Type          | Params
----------------------------------------------
0  | d          | Discriminator | 335 K 
1  | d.model    | Sequential    | 335 K 
2  | d.model.0  | Linear        | 200 K 
3  | d.model.1  | LeakyReLU     | 0     
4  | d.model.2  | Linear        | 65.8 K
5  | d.model.3  | LeakyReLU     | 0     
6  | d.model.4  | Linear        | 65.8 K
7  | d.model.5  | LeakyReLU     | 0     
8  | d.model.6  | Linear        | 2.6 K 
9  | d.model.7  | Softmax       | 0     
10 | g          | Generator     | 406 K 
11 | g.model    | Sequential    | 406 K 
12 | g.model.0  | Flatten       | 0     
13 | g.model.1  | Linear        | 5.4 K 
14 | g.model.2  | BatchNorm1d   | 512   
15 | g.model.3  | ReLU          | 0     
16 | g.model.4  | Linear        | 65.8 K
17 | g.model.5  | BatchNorm1d   | 512   
18 | g.model.6  | ReLU          | 0     
19 | g.model.7  | Linear        | 65.8 K
20 | g.model.8  | BatchNorm1d   | 512   
21 | g.model.9  | R

  rank_zero_warn(
  rank_zero_warn(


## Logger: Proj, Run ... Names

In [26]:
wandb_logger = WandbLogger(project='M5-RealnessGAN', 
                           name='exp-1_40eps',
                           config=CONFIG,
                           job_type='train',
                           log_model="all")

## Trainer Callbacks

In [27]:
from pytorch_lightning import Callback
from lightning.pytorch.callbacks import DeviceStatsMonitor, TQDMProgressBar, ModelCheckpoint, EarlyStopping, LearningRateMonitor

# Checkpoint
checkpoint_callback = ModelCheckpoint(dirpath=CONFIG['checkpoint_path'],
                                      filename='{epoch}-{g_loss:.3f}',
                                      monitor='g_loss',
                                      save_top_k=-1,
                                      save_last=True,
                                      save_weights_only=True,
                                      verbose=True,
                                      mode='min')

# Exp2: Learning Rate Monitor
lr_monitor = LearningRateMonitor(logging_interval='step', log_momentum=False)

# Earlystopping
# earlystopping = EarlyStopping(monitor='val_d_acc', patience=3, mode='min')

## Trainer

In [28]:
trainer = pl.Trainer(fast_dev_run=False,    # For debugging purposes
                     log_every_n_steps=1,   # set the logging frequency
                     accelerator='auto',    # Precedence: tpu > gpu >> cpu
                     devices="auto",        # all
                     max_epochs= NUM_EPOCHS,         # CONFIG['NUM_EPOCHS'],
                     callbacks=[TQDMProgressBar(refresh_rate=25), 
                                checkpoint_callback, 
                                lr_monitor],
                     logger=wandb_logger,    # wandb <3
                     )

INFO: GPU available: True (cuda), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO:lightning.pytorch.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs


## Training

In [None]:
trainer.fit(rgan, train_loader)

  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO: 
  | Name | Type          | Params
---------------------------------------
0 | d    | Discriminator | 335 K 
1 | g    | Generator     | 406 K 
---------------------------------------
741 K     Trainable params
0         Non-trainable params
741 K     Total params
2.966     Total estimated model params size (MB)
INFO:lightning.pytorch.callbacks.model_summary:
  | Name | Type          | Params
---------------------------------------
0 | d    | Discriminator | 335 K 
1 | g    | Generator     | 406 K 
---------------------------------------
741 K     Trainable params
0         Non-trainable params
741 K     Total params
2.966     Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]

  input = module(input)


## Call Finish on Exp logger

In [None]:
wandb.finish()