# Mounting Drive

In [1]:
# Mount Google Drive at /content/drive (Colab only). MODELS_PATH and
# GENERATED_IMAGES_PATH in the config cell point inside this mount.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


# Setting Up Libraries

In [2]:
!pip install opendatasets --q

In [3]:
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torch.utils.data import TensorDataset

from torchvision.datasets import ImageFolder
from torchvision.utils import make_grid
from torchvision.transforms import Compose
from torchvision import transforms
from torchvision.utils import save_image

import numpy as np
import os
from PIL import Image
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from glob import glob
import imageio
import opendatasets as od

# Parameters + Global Variables

In [4]:
BATCH_SIZE = 128
IMAGES_TO_USE = 63565   # currently unused in this notebook
IMG_SIZE = 64           # images are resized and center-cropped to IMG_SIZE x IMG_SIZE
NUM_WORKERS = 2         # DataLoader worker processes
STARTING_INDEX = 0      # offset added to the epoch number in sample file names
LATENT_SIZE = 128       # channels of the (LATENT_SIZE, 1, 1) noise fed to the generator

# Seed torch's RNGs (CPU and all CUDA devices) so weight initialisation, data
# shuffling and latent sampling are reproducible between runs (up to cuDNN
# non-determinism).
SEED = 42
torch.manual_seed(SEED)

MODELS_PATH = "/content/drive/MyDrive/Jovian Project/Models"
GENERATED_IMAGES_PATH = "/content/drive/MyDrive/Jovian Project/Generated"
ANIME_FACES_DATASET_PATH = "/content/animefacedataset"
HUMAN_FACES_DATASET_PATH = "/content/ffhq-face-data-set"

# Per-channel (mean, std) for transforms.Normalize: maps [0, 1] pixels to [-1, 1],
# the same range as the generator's Tanh output.
STATS = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)

# Downloading Datasets

In [5]:
def download_data(url='https://www.kaggle.com/splcher/animefacedataset'):
  """Download a Kaggle dataset with opendatasets.

  For the default URL the cell output shows the archive being downloaded to
  ./animefacedataset in the current working directory.
  """
  dataset_url = url
  od.download(dataset_url)

download_data(url='https://www.kaggle.com/splcher/animefacedataset')

Downloading animefacedataset.zip to ./animefacedataset


100%|██████████| 395M/395M [00:02<00:00, 181MB/s]





### TBD: Download the human-faces dataset (FFHQ).

# Utility Functions

In [6]:
def get_default_device():
    """Return the CUDA device when one is available, otherwise the CPU device."""
    device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
    return torch.device(device_name)
    
def to_device(data, device):
    """Recursively move a tensor/module, or a list/tuple of them, to `device`.

    Lists and tuples come back as lists; leaves are moved with `.to(device, non_blocking=True)`.
    """
    if not isinstance(data, (list, tuple)):
        return data.to(device, non_blocking=True)
    moved = []
    for item in data:
        moved.append(to_device(item, device))
    return moved

class DeviceDataLoader():
    """Iterable wrapper around a loader that transfers every batch to `device` as it is produced."""
    def __init__(self, dl, device):
        self.device = device
        self.dl = dl

    def __iter__(self):
        """Lazily yield each batch of the wrapped loader after passing it through `to_device`."""
        yield from (to_device(batch, self.device) for batch in self.dl)

    def __len__(self):
        """Delegate the batch count to the wrapped loader."""
        return len(self.dl)

In [7]:
def denorm(img):
  """Undo Normalize(mean=0.5, std=0.5), mapping values in [-1, 1] back to [0, 1].

  Only the first channel's mean/std are applied, which is exact here because
  all three channels share the same statistics.
  """
  mean, std = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)
  scale, shift = std[0], mean[0]
  return img * scale + shift

def show_images(images, nmax=64):
  """Display up to `nmax` normalised images from a batch as an 8-column grid.

  `images` may live on any device: it is detached and copied to the CPU before
  plotting, and `denorm` maps it back to [0, 1] for matplotlib.
  """
  fig, ax = plt.subplots(figsize=(8, 8))
  ax.set_xticks([]); ax.set_yticks([])
  # matplotlib converts through numpy, which fails for CUDA tensors; slice first
  # so only the displayed images are copied.
  batch = images.detach()[:nmax].cpu()
  grid = make_grid(denorm(batch), nrow=8)
  # make_grid returns (C, H, W); imshow expects (H, W, C).
  ax.imshow(grid.permute(1, 2, 0))

In [8]:
def save_samples(index, generator, latent_tensors, show=True):
    """Generate images from `latent_tensors`, save them as a PNG grid, and optionally display them.

    Args:
        index: integer embedded (zero-padded to 4 digits) in the file name
            ``generated-images-XXXX.png`` under GENERATED_IMAGES_PATH.
        generator: module called directly on `latent_tensors`.
        latent_tensors: latent batch fed to `generator`.
        show: if True, also draw the same grid with matplotlib.
    """
    # Sampling only, so no autograd graph is needed. The module's train/eval mode
    # is left untouched, so BatchNorm behaves exactly as it did before.
    with torch.no_grad():
        fake_images = generator(latent_tensors)
    fake_fname = 'generated-images-{0:0=4d}.png'.format(index)
    # save_image does not create parent folders (e.g. on a fresh Drive).
    os.makedirs(GENERATED_IMAGES_PATH, exist_ok=True)
    save_image(denorm(fake_images), os.path.join(GENERATED_IMAGES_PATH, fake_fname), nrow=8)
    print('Saving', fake_fname)
    if show:
        fig, ax = plt.subplots(figsize=(8, 8))
        ax.set_xticks([]); ax.set_yticks([])
        # Denormalise like the saved PNG; raw generator output is in [-1, 1] and
        # imshow would clip the negative half.
        grid = make_grid(denorm(fake_images.cpu()), nrow=8)
        ax.imshow(grid.permute(1, 2, 0))

In [9]:
def save_model(model, model_name):
  """Save `model.state_dict()` to MODELS_PATH/model_name.

  Only the three known checkpoint names are accepted; any other name prints
  "Invalid Model Name" and returns without saving.
  """
  allowed_models = ["generator.pth", "discriminator_anime.pth", "discriminator_human.pth"]
  if model_name not in allowed_models:
    print("Invalid Model Name")
    return
  print("Saving ", model_name)
  # torch.save does not create parent folders; the Drive folder may not exist yet.
  os.makedirs(MODELS_PATH, exist_ok=True)
  torch.save(model.state_dict(), os.path.join(MODELS_PATH, model_name))

# Anime Dataloader

In [10]:
def AnimeDataLoader():
  """Build a shuffled DataLoader over the anime-faces ImageFolder.

  Each image is resized (smaller edge) to IMG_SIZE, randomly flipped
  horizontally, center-cropped to IMG_SIZE x IMG_SIZE, converted to a tensor
  and normalised with STATS. Batches are (images, folder_labels).
  """
  pipeline = [
    transforms.Resize(IMG_SIZE),
    transforms.RandomHorizontalFlip(0.5),
    transforms.CenterCrop(IMG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(*STATS),
  ]
  dataset = ImageFolder(ANIME_FACES_DATASET_PATH, transform=Compose(pipeline))
  return DataLoader(dataset, BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True)

def GetSingleBatch(x):
  """Return the inputs of the first (inputs, labels) pair yielded by `x`, or None if `x` is empty."""
  _missing = object()
  first = next(iter(x), _missing)
  if first is _missing:
    return None
  inputs, _ = first
  return inputs

# Human DataLoader

### TBD: DataLoader for the human-faces dataset.

# Discriminator

In [12]:
class Discriminator(nn.Module):
  """Convolutional critic that outputs one probability per image.

  Four conv stages (3x3, stride 2) halve the spatial size each time
  (64 -> 32 -> 16 -> 8 -> 4 for 64x64 inputs). A 4x4 conv then reduces this to
  one value, which is flattened to shape (batch, 1) and passed through a sigmoid.
  The attribute names below are the keys of saved state_dicts, so they must stay fixed.
  """
  def __init__(self):
    super().__init__()
    # Registered in the same order as before so parameter initialisation consumes
    # the RNG identically (nn.Module.__setattr__ registers the submodules).
    channel_pairs = [(3, 64), (64, 128), (128, 256), (256, 512)]
    for idx, (c_in, c_out) in enumerate(channel_pairs, start=1):
      setattr(self, f"block_{idx}", self.block(c_in, c_out))

    self.flatter = nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=0, bias=False)
    self.flatten = nn.Flatten()
    self.sigmoid = nn.Sigmoid()

  def forward(self, inputs):
    x = inputs
    stages = (self.block_1, self.block_2, self.block_3, self.block_4, self.flatter, self.flatten)
    for stage in stages:
      x = stage(x)
    return self.sigmoid(x)

  def block(self, in_channels, out_channels):
    """Conv(3x3, stride 2, no bias) -> BatchNorm -> LeakyReLU(0.2)."""
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1, bias=False),
        nn.BatchNorm2d(out_channels),
        nn.LeakyReLU(0.2, inplace=True),
    )

# Generator

In [13]:
class Generator(nn.Module):
  """Transposed-conv generator from latent noise to 3-channel images in [-1, 1] (Tanh).

  For latents shaped (N, LATENT_SIZE, 1, 1), the spatial size grows
  1 -> 4 -> 8 -> 16 -> 32 -> 64. All layers live in `self.generator`, whose
  indices are the keys of saved state_dicts.
  """
  def __init__(self):
    super().__init__()
    widths = [LATENT_SIZE, 512, 256, 128, 64]
    # First stage: 4x4 kernel, stride 1, no padding turns the 1x1 latent into 4x4.
    stages = [self.block(widths[0], widths[1], stride=1, padding=0)]
    # Remaining stages double the resolution each time (default stride 2, padding 1).
    stages += [self.block(c_in, c_out) for c_in, c_out in zip(widths[1:-1], widths[2:])]
    stages.append(nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1))
    stages.append(nn.Tanh())
    self.generator = nn.Sequential(*stages)

  def forward(self, inputs):
    return self.generator(inputs)

  def block(self, in_channels, out_channels, stride=2, padding=1):
    """ConvTranspose(4x4, no bias) -> BatchNorm -> ReLU."""
    upsample = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, padding=padding, stride=stride, bias=False)
    return nn.Sequential(upsample, nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True))

# Discriminator Training

In [24]:
def training_discriminator(discriminator, generator, optimizer_discriminator, real_images, device):
  """Perform one optimisation step of `discriminator` on a real batch plus an equal-sized fake batch.

  Args:
    discriminator: module whose output is compared with (batch, 1) targets by BCE.
    generator: module called on latents of shape (batch, LATENT_SIZE, 1, 1).
    optimizer_discriminator: optimizer stepping the discriminator's parameters.
    real_images: batch of real images (expected on `device`); dim 0 is the batch size.
    device: device on which targets and latents are created.

  Returns:
    (loss, real_score, fake_score) as Python floats: the summed real + fake BCE
    loss, and the mean discriminator output on the real and on the fake batch.
  """
  optimizer_discriminator.zero_grad()
  curr_batch_size = real_images.shape[0]

  # Real images should be scored as 1.
  real_predictions = discriminator(real_images)
  real_targets = torch.ones(curr_batch_size, 1, device=device)
  real_loss = F.binary_cross_entropy(real_predictions, real_targets)
  real_score = real_predictions.mean().item()

  # Fake images should be scored as 0. Detach them: this loss must only update
  # the discriminator. Without detach, backward() runs through the whole
  # generator for nothing and leaves stale gradients in its .grad buffers.
  latent = torch.randn(curr_batch_size, LATENT_SIZE, 1, 1, device=device)
  fake_images = generator(latent).detach()

  fake_predictions = discriminator(fake_images)
  fake_targets = torch.zeros(curr_batch_size, 1, device=device)
  fake_loss = F.binary_cross_entropy(fake_predictions, fake_targets)
  fake_score = fake_predictions.mean().item()

  loss = fake_loss + real_loss

  loss.backward()
  optimizer_discriminator.step()

  return loss.item(), real_score, fake_score

# Generator Training

In [25]:
def training_generator(generator, discriminator, optimizer_generator, fixed_latent, device):
  """Perform one optimisation step of `generator` so that its fakes are scored as real.

  Args:
    generator: module mapping latents to images.
    discriminator: module whose output is compared with (batch, 1) targets by BCE.
    optimizer_generator: optimizer stepping the generator's parameters.
    fixed_latent: template latent batch. Only its shape, dtype and device are
      used; fresh noise of the same shape is drawn every call. (Training on one
      fixed batch would teach the generator to map only those points.)
      `fixed_latent` itself is not modified.
    device: device on which the targets are created.

  Returns:
    The generator's BCE loss as a Python float.
  """
  optimizer_generator.zero_grad()

  latent = torch.randn_like(fixed_latent)
  fake_images = generator(latent)
  curr_batch_size = latent.shape[0]

  fake_predictions = discriminator(fake_images)
  # Targets of 1: the generator is rewarded when the discriminator says "real".
  fake_targets = torch.ones(curr_batch_size, 1, device=device)
  fake_loss = F.binary_cross_entropy(fake_predictions, fake_targets)

  fake_loss.backward()
  optimizer_generator.step()

  return fake_loss.item()

# Training Loop

In [40]:
def training_loop(retrain=False, epochs=30, lr=0.0001):
  """Train the generator against the anime-face discriminator, then save both models.

  Args:
    retrain: if truthy, first load ``generator.pth`` and ``discriminator_anime.pth``
      from MODELS_PATH and continue training from them.
    epochs: number of passes over the anime dataset (default: the previously
      hard-coded 30).
    lr: Adam learning rate for both networks (default: the previously
      hard-coded 1e-4).

  Side effects: writes one sample grid per epoch (via save_samples) and the two
  checkpoints to MODELS_PATH (via save_model).
  """
  # Setting up device
  device = get_default_device()
  print("Device: ", device)

  # Setting up models
  generator = Generator()
  discriminator_anime = Discriminator()

  # Setting up loaders
  anime_loader = AnimeDataLoader()

  # (Re)Loading models
  if retrain:
    print("Loading models")
    # map_location lets checkpoints written on a GPU runtime load on a CPU-only one.
    generator.load_state_dict(
        torch.load(os.path.join(MODELS_PATH, "generator.pth"), map_location=device))
    discriminator_anime.load_state_dict(
        torch.load(os.path.join(MODELS_PATH, "discriminator_anime.pth"), map_location=device))

  # Shifting to device
  generator = to_device(generator, device)
  discriminator_anime = to_device(discriminator_anime, device)

  anime_loader = DeviceDataLoader(anime_loader, device)

  # Fixed latent: passed to training_generator and reused for every per-epoch
  # sample grid, so the grids of different epochs are comparable.
  fixed_latent = torch.randn(BATCH_SIZE, LATENT_SIZE, 1, 1, device=device)

  def fit(epochs, lr, starting_index=0):
    """Run `epochs` epochs of alternating D/G steps; sample files are numbered from `starting_index`."""
    generator_optimizer = torch.optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
    anime_optimizer = torch.optim.Adam(discriminator_anime.parameters(), lr=lr, betas=(0.5, 0.999))

    for epoch in range(epochs):
      # Pre-set so the report below cannot raise NameError if the loader is empty.
      anime_loss = real_anime_score = fake_anime_score = generator_loss = float("nan")
      for anime, _ in tqdm(anime_loader):
        anime_loss, real_anime_score, fake_anime_score = training_discriminator(
          discriminator_anime,
          generator,
          anime_optimizer,
          anime,
          device,
        )
        generator_loss = training_generator(
            generator,
            discriminator_anime,
            generator_optimizer,
            fixed_latent,
            device,
        )
      print(f"Epoch: {epoch}")
      print(f"Discriminator Anime Loss: {anime_loss} | Discriminator Anime Score: {real_anime_score} | Generator Score From Anime: {fake_anime_score}")
      print(f"Generator Loss: {generator_loss}\n")

      # Honour the parameter (it used to be ignored in favour of the global).
      save_samples(epoch + starting_index, generator, fixed_latent, show=False)

  def save_all_models():
    """Write both networks' state_dicts via save_model."""
    print("\n\nSaving All Models")
    save_model(generator, "generator.pth")
    save_model(discriminator_anime, "discriminator_anime.pth")
    print("\n")

  fit(epochs, lr, starting_index=STARTING_INDEX)
  save_all_models()

In [42]:
# Train from freshly initialised weights (pass retrain=True to resume from MODELS_PATH).
training_loop(retrain=False)

Device:  cuda


  0%|          | 0/497 [00:00<?, ?it/s]

KeyboardInterrupt: ignored

# Baby Steps

1. Make the GAN work for anime faces.
2. Make the GAN work for human faces.
3. Implement gradient clipping and weight decay for the generator (on one dataset only).
4. Make the GAN work for both.
5. Change the architecture of the GAN.

# Experimentation

In [14]:
# Select the compute device for the experimentation cells below; the last line displays it.
device = get_default_device()
device

device(type='cuda')

In [15]:
# Freshly initialised networks for the experiments, moved to `device`.
# The discriminator is still constructed before the generator, so weight init draws from the RNG in the same order.
netD = to_device(Discriminator(), device)
netG = to_device(Generator(), device)

In [16]:
# Identical Adam settings for both networks: lr 2e-4, betas (0.5, 0.999).
optimizerD, optimizerG = (
    torch.optim.Adam(net.parameters(), lr=2e-4, betas=(0.5, 0.999)) for net in (netD, netG)
)

In [17]:
# Anime-face batches, transferred to `device` as they are iterated.
anime_loader = DeviceDataLoader(AnimeDataLoader(), device)

In [18]:
# Binary cross-entropy on probabilities (the Discriminator already ends in a sigmoid), averaged over the batch.
criterion = nn.BCELoss(reduction='mean')

In [25]:
# One latent batch, drawn once, is used for all per-epoch sample grids so that
# the grids show the same latent points across epochs. Previously the last
# training `latent` was reused: it differs every epoch, and the final batch of
# an epoch can be smaller than BATCH_SIZE because the loader keeps the last
# partial batch.
sample_latent = torch.randn(BATCH_SIZE, LATENT_SIZE, 1, 1, device=device)
SAMPLE_INDEX_OFFSET = 50  # sample files from this run are numbered 0050, 0051, ...

for epoch in range(25):
  for anime, _ in tqdm(anime_loader):
    real_targets = torch.ones(anime.size(0), 1, device=device)
    fake_targets = torch.zeros(anime.size(0), 1, device=device)

    # Update D with real data
    netD.zero_grad()
    real_preds = netD(anime)
    loss_real = criterion(real_preds, real_targets)
    loss_real.backward()

    # Update D with fake data
    latent = torch.randn(anime.size(0), LATENT_SIZE, 1, 1, device=device)
    fake_images = netG(latent)

    # detach(): D's loss must not backpropagate into netG. fake_images itself is
    # kept attached for the G step below.
    fake_preds = netD(fake_images.detach())
    loss_fake = criterion(fake_preds, fake_targets)
    loss_fake.backward()
    optimizerD.step()

    # Update G with fake data: no detach, so gradients flow through netD into netG.
    # The gradients this leaves on netD are cleared by netD.zero_grad() next iteration.
    netG.zero_grad()
    fake_preds_r = netD(fake_images)
    loss_g = criterion(fake_preds_r, real_targets)
    loss_g.backward()
    optimizerG.step()

  save_samples(epoch + SAMPLE_INDEX_OFFSET, netG, sample_latent, show=False)

  0%|          | 0/497 [00:00<?, ?it/s]

Saving generated-images-0050.png


  0%|          | 0/497 [00:00<?, ?it/s]

Saving generated-images-0051.png


  0%|          | 0/497 [00:00<?, ?it/s]

Saving generated-images-0052.png


  0%|          | 0/497 [00:00<?, ?it/s]

Saving generated-images-0053.png


  0%|          | 0/497 [00:00<?, ?it/s]

Saving generated-images-0054.png


  0%|          | 0/497 [00:00<?, ?it/s]

Saving generated-images-0055.png


  0%|          | 0/497 [00:00<?, ?it/s]

Saving generated-images-0056.png


  0%|          | 0/497 [00:00<?, ?it/s]

Saving generated-images-0057.png


  0%|          | 0/497 [00:00<?, ?it/s]

Saving generated-images-0058.png


  0%|          | 0/497 [00:00<?, ?it/s]

Saving generated-images-0059.png


  0%|          | 0/497 [00:00<?, ?it/s]

Saving generated-images-0060.png


  0%|          | 0/497 [00:00<?, ?it/s]

Saving generated-images-0061.png


  0%|          | 0/497 [00:00<?, ?it/s]

Saving generated-images-0062.png


  0%|          | 0/497 [00:00<?, ?it/s]

Saving generated-images-0063.png


  0%|          | 0/497 [00:00<?, ?it/s]

Saving generated-images-0064.png


  0%|          | 0/497 [00:00<?, ?it/s]

Saving generated-images-0065.png


  0%|          | 0/497 [00:00<?, ?it/s]

Saving generated-images-0066.png


  0%|          | 0/497 [00:00<?, ?it/s]

Saving generated-images-0067.png


  0%|          | 0/497 [00:00<?, ?it/s]

Saving generated-images-0068.png


  0%|          | 0/497 [00:00<?, ?it/s]

Saving generated-images-0069.png


  0%|          | 0/497 [00:00<?, ?it/s]

Saving generated-images-0070.png


  0%|          | 0/497 [00:00<?, ?it/s]

Saving generated-images-0071.png


  0%|          | 0/497 [00:00<?, ?it/s]

Saving generated-images-0072.png


  0%|          | 0/497 [00:00<?, ?it/s]

Saving generated-images-0073.png


  0%|          | 0/497 [00:00<?, ?it/s]

Saving generated-images-0074.png


In [26]:
# Persist the experimentally trained generator under MODELS_PATH.
save_model(model=netG, model_name="generator.pth")

Saving  generator.pth


In [27]:
# Persist the experimentally trained anime discriminator under MODELS_PATH.
save_model(model=netD, model_name="discriminator_anime.pth")

Saving  discriminator_anime.pth
