Inspired by the PyTorch DCGAN tutorial.

In [None]:
import torch
import torchvision
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os
import cv2
%matplotlib inline

## Dataset
Using the Anime Face dataset from Kaggle.

In [None]:
# Root of the dataset; the images themselves live in its images/ sub-folder.
train_dir = './animefacedataset'
# Peek at the first ten file names to check the dataset is in place.
image_files = os.listdir(os.path.join(train_dir, 'images'))
print(image_files[:10])

## Hyperparameters
The important hyperparameters are kept together here; they could also be stored in a config file (e.g. managed with Hydra).

In [None]:
# Important parameters according to the DCGAN paper
lr = 0.0002  # Adam learning rate
batch_size = 128  # mini-batch size
beta_1 = 0.5  # Adam beta1 (momentum term)
beta_2 = 0.999  # Adam beta2
slope = 0.2  # negative slope of the LeakyReLU activations
num_epochs = 30  # number of training epochs
image_size = 64  # height and width the input images are resized to
random_seed = 35  # seed for random generation, for reproducibility
n = 30080  # number of images sampled from the dataset
# Use the GPU when one is available; otherwise fall back to the CPU so the
# notebook still runs (slowly) instead of failing at the first .to(device).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
noise = 100  # dimension of the latent noise vector z

## Dataset and DataLoader

In [None]:
from torch.utils.data import Dataset, DataLoader, Subset
from torchvision.datasets import ImageFolder
import torchvision.transforms as T
from torchvision.utils import make_grid

In [None]:
# Preprocessing pipeline applied to every image:
#  1. resize the PIL image to image_size x image_size (resizing before
#     ToTensor keeps the work on the smaller image and uses PIL's antialiased
#     resampling instead of the tensor path);
#  2. convert to a float tensor with values in [0, 1];
#  3. normalise each channel with mean 0.5 / std 0.5, bringing values to
#     [-1, 1] to match the generator's Tanh output range.
transforms = T.Compose([
    T.Resize((image_size, image_size)),
    T.ToTensor(),
    T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

The dataset contains about 60k+ images; to reduce training time, only a random subset of about 30k (`n`) images is used.

In [None]:
# Seed both RNGs: NumPy drives the subset selection below, while PyTorch's RNG
# drives layer initialisation, torch.randn noise and DataLoader shuffling.
np.random.seed(random_seed)
torch.manual_seed(random_seed)

data = ImageFolder(train_dir, transform=transforms)
# Sample the subset without replacement; cap at the dataset size so a smaller
# copy of the dataset does not make np.random.choice raise ValueError.
subset_size = min(n, len(data))
indices = np.random.choice(len(data), subset_size, replace=False).tolist()
train_data = Subset(data, indices)
# Pinned host memory only speeds up host-to-GPU copies, so enable it only on CUDA.
train = DataLoader(train_data, batch_size, shuffle=True, num_workers=2,
                   pin_memory=(device.type == 'cuda'))

## Model for DCGAN
First, the Generator, built from blocks of:
1. Transposed Conv2D
2. BatchNorm (all layers except the last)
3. ReLU (but Tanh for the last layer, to map the image to (-1, 1))

In [None]:
import torch.nn as nn
import torch.optim as optim

In [None]:
class Generator(nn.Module):
  """DCGAN generator: maps latent noise to a 3-channel image.

  Input is a batch of latent tensors shaped (batch, noise, 1, 1). Each
  transposed convolution enlarges the spatial size (1 -> 4 -> 8 -> 16 -> 32
  -> 64), so the output is (batch, 3, 64, 64) with values squashed into
  (-1, 1) by the final Tanh.
  """

  def __init__(self):
    super().__init__()
    # (in_channels, out_channels, kernel_size, stride, padding) per hidden block
    block_specs = [
        (noise, 512, 4, 1, 0),
        (512, 256, 4, 2, 1),
        (256, 128, 4, 2, 1),
        (128, 64, 4, 2, 1),
    ]
    hidden = [self.gen_layer(*spec) for spec in block_specs]
    # Output layer: no BatchNorm, Tanh instead of ReLU.
    to_rgb = nn.ConvTranspose2d(in_channels=64, out_channels=3,
                                kernel_size=4, stride=2, padding=1)
    self.main = nn.Sequential(*hidden, to_rgb, nn.Tanh())

  def gen_layer(self, in_channels, out_channels, kernel_size, stride, padding):
    """Build one upsampling block: ConvTranspose2d -> BatchNorm2d -> ReLU."""
    upsample = nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels,
                                  kernel_size=kernel_size, stride=stride,
                                  padding=padding)
    norm = nn.BatchNorm2d(out_channels)
    return nn.Sequential(upsample, norm, nn.ReLU(inplace=False))

  def forward(self, x):
    """Generate images from the latent batch ``x``."""
    return self.main(x)

Generator()

Then the Discriminator, built from blocks of:
1. Conv2D
2. BatchNorm (only for the middle layers)
3. Leaky ReLU (but Sigmoid in the last layer, to output the probability that the image is real)

In [None]:
class Discriminator(nn.Module):
  """DCGAN discriminator: scores 3-channel images with a Sigmoid probability.

  For a (batch, 3, 64, 64) input, four stride-2 convolutions shrink the
  spatial size 64 -> 32 -> 16 -> 8 -> 4, and the final 4x4 convolution
  collapses it to 1x1, giving a (batch, 1, 1, 1) tensor in (0, 1).
  """

  def __init__(self):
    super().__init__()
    # Input layer: convolution + LeakyReLU, with no BatchNorm.
    stem = [
        nn.Conv2d(in_channels=3, out_channels=64,
                  kernel_size=4, stride=2, padding=1),
        nn.LeakyReLU(negative_slope=slope),
    ]
    # Middle layers double the channels while halving the spatial size.
    middle = [self.disc_block(c_in, c_out, 4, 2, 1)
              for c_in, c_out in ((64, 128), (128, 256), (256, 512))]
    # Output layer: single-channel score squashed by a Sigmoid.
    head = [
        nn.Conv2d(in_channels=512, out_channels=1,
                  kernel_size=4, stride=1, padding=0),
        nn.Sigmoid(),
    ]
    self.main = nn.Sequential(*stem, *middle, *head)

  def disc_block(self, in_channels, out_channels, kernel_size, stride, padding):
    """Build one downsampling block: Conv2d -> BatchNorm2d -> LeakyReLU(slope)."""
    conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                     kernel_size=kernel_size, stride=stride, padding=padding)
    norm = nn.BatchNorm2d(out_channels)
    return nn.Sequential(conv, norm, nn.LeakyReLU(negative_slope=slope))

  def forward(self, x):
    """Return the discriminator's probability for each image in ``x``."""
    return self.main(x)

Discriminator()

Initializing the weights of the Generator and Discriminator

In [None]:
def initialise_weights(model):
  """Apply DCGAN weight initialisation, in place, to ``model`` and all its submodules.

  Conv2d / ConvTranspose2d weights are drawn from N(0, 0.02); BatchNorm2d
  weights from N(1, 0.02) with biases set to 0, as prescribed by the DCGAN
  paper. Convolution biases and all other layers keep PyTorch's defaults.

  Args:
    model: an ``nn.Module``; every module yielded by ``model.modules()``
      (including ``model`` itself) is visited.

  Returns:
    None.
  """
  for m in model.modules():
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
      # Zero-centred normal with std 0.02 (not 0.2) as in the DCGAN paper.
      nn.init.normal_(m.weight, mean=0.0, std=0.02)
    elif isinstance(m, nn.BatchNorm2d):
      nn.init.normal_(m.weight, mean=1.0, std=0.02)
      nn.init.zeros_(m.bias)

# Build both networks and initialise every layer exactly once.
# initialise_weights already walks model.modules(). Passing it to Module.apply
# (which itself calls it on every submodule) would re-sample each layer once
# per ancestor for no benefit.
netG = Generator()
initialise_weights(netG)
netD = Discriminator()
initialise_weights(netD)

Choosing the loss function (BCELoss) and the optimizers (Adam)

In [None]:
# Binary cross-entropy on the discriminator's Sigmoid outputs.
criterion = nn.BCELoss()
# One batch of latent vectors shaped (batch_size, noise, 1, 1), sampled once on the CPU.
fixed_random_noise = torch.randn(batch_size, noise, 1, 1)
# Both networks share the same Adam settings.
adam_kwargs = dict(lr=lr, betas=(beta_1, beta_2))
optim_D = optim.Adam(netD.parameters(), **adam_kwargs)
optim_G = optim.Adam(netG.parameters(), **adam_kwargs)

## Training Loop

In [None]:
netG.to(device)
netD.to(device)
# The fixed noise is only for visualising progress, never for training. Keep
# it on the training device so it can be fed straight to netG afterwards.
fixed_random_noise = fixed_random_noise.to(device)
netG.train()
netD.train()
for epoch in range(num_epochs):
  for img, _ in train:
    img = img.to(device)
    cur_batch = img.size(0)

    # --- Discriminator: maximise log D(x) + log(1 - D(G(z))) ---
    # Real images; target 0.95 instead of 1 (label smoothing).
    real_out = netD(img).view(-1)
    real_labels = torch.full_like(real_out, 0.95)  # inherits real_out's device
    loss_real = criterion(real_out, real_labels)

    # Sample fresh latent vectors every step. Reusing one fixed noise batch
    # would let G learn a mapping for only those few points.
    z = torch.randn(cur_batch, noise, 1, 1, device=device)
    fake_imgs = netG(z)
    # detach(): the discriminator update must not backpropagate into G.
    fake_out = netD(fake_imgs.detach()).view(-1)
    fake_labels = torch.full_like(fake_out, 0.05)  # target 0.05 instead of 0
    loss_fake = criterion(fake_out, fake_labels)

    loss_d = (loss_real + loss_fake) / 2
    netD.zero_grad()
    loss_d.backward()
    optim_D.step()

    # --- Generator: maximise log D(G(z)) by labelling fakes as real ---
    fake_img_gen = netD(fake_imgs).view(-1)
    make_it_real = torch.ones_like(fake_img_gen)
    gen_loss = criterion(fake_img_gen, make_it_real)

    # netD also receives gradients here; they are cleared by netD.zero_grad()
    # before the next discriminator update.
    netG.zero_grad()
    gen_loss.backward()
    optim_G.step()

  # Losses from the last batch of the epoch.
  print('Epoch', epoch + 1, 'loss_D', loss_d.item(), 'loss_G', gen_loss.item())

torch.save(netG.state_dict(), 'G.pth')
torch.save(netD.state_dict(), 'D.pth')

In [None]:
# Show a grid of up to 64 images generated from the fixed noise batch.
with torch.no_grad():
  plt.figure(figsize=(8, 8))
  # Move the noise to wherever the generator's parameters live, so this cell
  # works whether or not the training cell already moved it.
  gen_device = next(netG.parameters()).device
  fake = netG(fixed_random_noise.to(gen_device)).cpu()
  # normalize=True rescales the grid to [0, 1] for display.
  grid = make_grid(fake[:64], padding=2, normalize=True)
  # (C, H, W) -> (H, W, C), the layout imshow expects.
  plt.imshow(grid.permute(1, 2, 0).numpy())