Inspired by the algorithm in the WGAN-GP paper (Gulrajani et al., 2017, *Improved Training of Wasserstein GANs*) and [this GitHub code](https://github.com/arturml/pytorch-wgan-gp/blob/master/wgangp.py)

In [None]:
import torch
import torchvision
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os
import cv2
%matplotlib inline

## Dataset
Using the Anime Face dataset from Kaggle (images expected under `./animefacedataset/images`).

In [None]:
train_dir = './animefacedataset'
# Sanity check: show the first ten file names found in the images sub-folder.
print(os.listdir(os.path.join(train_dir, 'images'))[:10])

## Hyperparameters
The important hyperparameters are kept together here; they could also be stored in a config file (e.g. with Hydra).

In [None]:
# Important parameters. lr, Adam betas, Lambda and n_critic follow the defaults
# of Algorithm 1 in the WGAN-GP paper (Gulrajani et al., 2017):
# alpha=1e-4, beta1=0, beta2=0.9, lambda=10, n_critic=5.
lr = 0.0001  # learning rate (alpha in the paper)
batch_size = 64  # batch size
beta_1 = 0.0  # Adam beta1; the paper uses 0 (no first-moment momentum)
beta_2 = 0.9  # Adam beta2; the paper uses 0.9
slope = 0.2  # Leaky ReLU negative slope
num_epochs = 50  # Number of epochs
image_size = 64  # Image size of inputs
random_seed = 35  # Seed for random generation for reproducibility
n = 30080  # Number of pictures to be taken from the dataset
# Fall back to the CPU so the notebook still runs (slowly) without a GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
noise = 128  # Noise (latent) dimension
Lambda = 10  # Gradient penalty coefficient (lambda in the paper)
n_critic = 5  # No. of critic steps per generator step

## Dataset and DataLoader


In [None]:
from torch.utils.data import Dataset, DataLoader, Subset
from torchvision.datasets import ImageFolder
import torchvision.transforms as T
from torchvision.utils import make_grid

In [None]:
# Resize the PIL image *before* converting to a tensor: PIL resizing always
# antialiases, whereas tensor Resize in older torchvision does not (and warns
# about the changing antialias default in 0.15/0.16).
transforms = T.Compose([
    T.Resize((image_size, image_size)),
    T.ToTensor(),  # PIL image -> float CHW tensor in [0, 1]
    # (x - 0.5) / 0.5 maps [0, 1] to [-1, 1], matching the generator's Tanh output.
    T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

Since the dataset contains 60k+ images, only a random subset of 30,080 images (`n`) is used, to reduce training time.

In [None]:
# Seed NumPy (subset selection) and PyTorch (shuffling, weight init, noise)
# so that runs are reproducible.
np.random.seed(random_seed)
torch.manual_seed(random_seed)
data = ImageFolder(train_dir, transform=transforms)
# choice(..., replace=False) raises ValueError if asked for more items than
# exist, so never request more than the dataset holds.
subset_size = min(n, len(data))
subset_indices = np.random.choice(len(data), subset_size, replace=False).tolist()
train_data = Subset(data, subset_indices)
# drop_last keeps every batch exactly batch_size images; pinned memory only
# helps host-to-GPU copies, so enable it only when training on CUDA.
train = DataLoader(train_data, batch_size, shuffle=True, num_workers=2,
                   pin_memory=(device.type == 'cuda'), drop_last=True)

## Models for WGAN-GP

In [None]:
import torch.nn as nn
import torch.optim as optim

First, the Generator, built from blocks of:
1. ConvTranspose2d
2. BatchNorm2d
3. ReLU (but Tanh for the last layer, to map the output image to (-1, 1))


In [None]:
class Generator(nn.Module):
  """DCGAN-style generator: latent tensor -> image in (-1, 1).

  An input of shape (B, noise_dim, 1, 1) is upsampled 1 -> 4 -> 8 -> 16 -> 32
  -> 64, giving an output of shape (B, channels, 64, 64).

  Args:
    noise_dim: number of latent channels. Defaults to the notebook-level
      ``noise`` hyperparameter when omitted.
    channels: number of output image channels (3 for RGB).
  """

  def __init__(self, noise_dim=None, channels=3):
    super(Generator, self).__init__()
    if noise_dim is None:
      noise_dim = noise  # fall back to the hyperparameter cell
    self.main=nn.Sequential(
        self.gen_layer(noise_dim, 512, 4, 1, 0),  # 1x1 -> 4x4
        self.gen_layer(512, 256, 4, 2, 1),        # 4x4 -> 8x8
        self.gen_layer(256, 128, 4, 2, 1),        # 8x8 -> 16x16
        self.gen_layer(128, 64, 4, 2, 1),         # 16x16 -> 32x32
        nn.ConvTranspose2d(in_channels=64, out_channels=channels,
                           kernel_size=4, stride=2, padding=1),  # 32x32 -> 64x64
        nn.Tanh())  # squash to (-1, 1), matching the normalized training images

  def gen_layer(self, in_channels, out_channels, kernel_size, stride, padding):
    """Return one ConvTranspose2d -> BatchNorm2d -> ReLU upsampling block."""
    return nn.Sequential(
          nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels,
                             kernel_size=kernel_size, stride=stride, padding=padding),
          nn.BatchNorm2d(out_channels),
          nn.ReLU(inplace=False))

  def forward(self, x):
    """Map a (B, noise_dim, 1, 1) latent batch to (B, channels, 64, 64) images."""
    return self.main(x)

# Instantiate once so the cell output shows the generator's layer structure.
Generator()

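Then the Discriminator (called the critic in WGAN), built from blocks of:
1. Conv2d
2. InstanceNorm2d instead of BatchNorm (only for the middle layers): the gradient penalty is computed per sample, and batch normalization would make each sample's critic output depend on the rest of the batch
3. LeakyReLU, with no final Sigmoid: the critic outputs an unbounded score, not a probability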

In [None]:
class Critic(nn.Module):
  """WGAN critic: image -> unbounded realness score (no Sigmoid).

  A (B, channels, 64, 64) input is downsampled 64 -> 32 -> 16 -> 8 -> 4 -> 1,
  giving a (B, 1, 1, 1) score tensor.

  Args:
    channels: number of input image channels (3 for RGB).
    negative_slope: LeakyReLU slope. Defaults to the notebook-level ``slope``
      hyperparameter when omitted.
  """

  def __init__(self, channels=3, negative_slope=None):
    super(Critic, self).__init__()
    if negative_slope is None:
      negative_slope = slope  # fall back to the hyperparameter cell
    # Stored (as a plain float, not a parameter) so disc_block can read it.
    self.negative_slope = negative_slope
    self.main=nn.Sequential(
        # No normalization on the first layer, as in DCGAN.
        nn.Conv2d(in_channels=channels, out_channels=64,
                  kernel_size=4, stride=2, padding=1),  # 64x64 -> 32x32
        nn.LeakyReLU(negative_slope),
        self.disc_block(64, 128, 4, 2, 1),   # 32x32 -> 16x16
        self.disc_block(128, 256, 4, 2, 1),  # 16x16 -> 8x8
        self.disc_block(256, 512, 4, 2, 1),  # 8x8 -> 4x4
        nn.Conv2d(in_channels=512, out_channels=1,
                  kernel_size=4, stride=1, padding=0)  # 4x4 -> 1x1 score
    )

  def disc_block(self, in_channels, out_channels, kernel_size, stride, padding):
    """Return one Conv2d -> InstanceNorm2d -> LeakyReLU downsampling block.

    Instance norm (not batch norm) keeps each sample's score independent of
    the rest of the batch, which the per-sample gradient penalty relies on.
    """
    return nn.Sequential(
          nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                    kernel_size=kernel_size, stride=stride, padding=padding),
          nn.InstanceNorm2d(out_channels),
          nn.LeakyReLU(self.negative_slope))

  def forward(self, x):
    """Score a (B, channels, 64, 64) image batch; returns (B, 1, 1, 1)."""
    return self.main(x)

# Instantiate once so the cell output shows the critic's layer structure.
Critic()

Initializing the weights of the Generator and the Critic

In [None]:
def initialise_weights(model):
  """Initialise conv and batch-norm parameters of ``model`` in place (DCGAN scheme).

  - Conv2d / ConvTranspose2d weights ~ N(0, 0.02).
  - BatchNorm2d scale ~ N(1, 0.02), shift = 0.

  Walks every submodule via ``model.modules()``, so one direct call covers
  the whole network. Returns None.
  """
  for m in model.modules():
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
      # std 0.02 (not 0.2): large initial weights saturate the Tanh output.
      nn.init.normal_(m.weight, 0.0, 0.02)
    elif isinstance(m, nn.BatchNorm2d):
      nn.init.normal_(m.weight, 1.0, 0.02)
      nn.init.constant_(m.bias, 0.0)

# initialise_weights already walks every submodule, so call it once per
# network. Wrapping it in .apply() would run that full walk again for every
# submodule, re-initialising each layer many times.
netG = Generator()
initialise_weights(netG)
netC = Critic()
initialise_weights(netC)

In [None]:
#Other important choices to be made
optim_C=optim.Adam(netC.parameters(), lr=lr, betas=(beta_1, beta_2))
optim_G=optim.Adam(netG.parameters(), lr=lr, betas=(beta_1, beta_2))
fixed_random_noise=torch.randn(batch_size, noise, 1,1)

## Training loop

In [None]:
netC.to(device)
netG.to(device)
netG.train()
netC.train()
# Keep the fixed latent batch on the training device for the sample-grid cell.
fixed_random_noise = fixed_random_noise.to(device)

for epoch in range(num_epochs):
  for img, _ in train:
    img = img.to(device)
    cur_batch = img.size(0)  # may be smaller than batch_size for a final partial batch

    # ---- Critic: n_critic updates per generator update (WGAN-GP Algorithm 1) ----
    # NOTE: the same real batch is reused for all n_critic steps; the paper
    # samples a new real batch each step. Kept for simplicity.
    for _ in range(n_critic):
      # Fresh latent samples every step; training on one fixed noise batch
      # lets G collapse onto just those outputs.
      z = torch.randn(cur_batch, noise, 1, 1, device=device)
      # Detached: the critic step must not backpropagate into the generator.
      fake_imgs = netG(z).detach()
      fake_scores = netC(fake_imgs).view(-1)
      real_scores = netC(img).view(-1)

      # Random points on straight lines between real and fake samples.
      epsilon = torch.rand(cur_batch, 1, 1, 1, device=device)
      interpolation = (epsilon * img + (1 - epsilon) * fake_imgs).requires_grad_(True)
      new_scores = netC(interpolation)

      # Gradient of the critic score w.r.t. the interpolated images.
      # create_graph=True so the penalty itself can be differentiated
      # w.r.t. the critic's parameters.
      interpolated_grad = torch.autograd.grad(
          inputs=interpolation,
          outputs=new_scores,
          grad_outputs=torch.ones_like(new_scores),
          create_graph=True,
      )[0]
      grad_inter = interpolated_grad.reshape(cur_batch, -1)
      inter_avg = torch.mean((grad_inter.norm(2, dim=1) - 1) ** 2)

      # Critic minimises E[C(fake)] - E[C(real)] + lambda * GP.
      avg_criloss = torch.mean(fake_scores) - torch.mean(real_scores) + Lambda * inter_avg

      netC.zero_grad()
      avg_criloss.backward()
      optim_C.step()

    # ---- Generator: maximise E[C(G(z))], i.e. minimise -E[C(G(z))] ----
    z = torch.randn(cur_batch, noise, 1, 1, device=device)
    fake_checking = netC(netG(z)).view(-1)
    avg_genloss = -torch.mean(fake_checking)

    netG.zero_grad()
    avg_genloss.backward()
    optim_G.step()

  print(f'Epoch {epoch + 1}/{num_epochs}  critic loss {avg_criloss.item():.4f}'
        f'  generator loss {avg_genloss.item():.4f}')

torch.save(netG.state_dict(), 'G.pth')
torch.save(netC.state_dict(), 'C.pth')

In [None]:
# Show the generator's output for the fixed latent batch as an image grid.
with torch.no_grad():
  plt.figure(figsize=(8,8))
  # Move the noise to whatever device G lives on, so this cell does not rely
  # on an earlier cell having done it.
  g_device = next(netG.parameters()).device
  fake = netG(fixed_random_noise.to(g_device)).cpu()
  # normalize=True min-max rescales the grid into [0, 1] for display.
  grid = make_grid(fake[:64], padding=2, normalize=True)
  plt.axis('off')
  plt.imshow(grid.permute(1, 2, 0).numpy())  # CHW -> HWC for matplotlib