# Mounting Drive

In [2]:
# Mount Google Drive inside the Colab VM at /content/drive; the model and
# generated-image folders configured below live under this mount point.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


# Setting Up Libraries

In [3]:
!pip install opendatasets --q

In [4]:
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torch.utils.data import TensorDataset

from torchvision.datasets import ImageFolder
from torchvision.utils import make_grid
from torchvision.transforms import Compose
from torchvision import transforms
from torchvision.utils import save_image

import numpy as np
import os
from PIL import Image
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from glob import glob
import imageio
import opendatasets as od

# Parameters + Global Variables

In [6]:
# Batch size of the training DataLoader; also the number of samples generated
# after every epoch in the training loop.
BATCH_SIZE = 128
# NOTE(review): not referenced in the visible cells — presumably a cap on the
# number of images to use; confirm or remove.
IMAGES_TO_USE = 63565
# Side length in pixels used by the Resize and CenterCrop transforms.
IMG_SIZE = 64
# Number of worker processes handed to the DataLoader.
NUM_WORKERS = 2
# NOTE(review): not referenced in the visible cells — confirm intended use.
STARTING_INDEX = 0
# Channel count of the latent noise tensor fed to the generator.
LATENT_SIZE = 512

# Google Drive folders for saved model weights and generated sample grids.
MODELS_PATH = "/content/drive/MyDrive/Jovian Project/Models"
GENERATED_IMAGES_PATH = "/content/drive/MyDrive/Jovian Project/Generated"
# Root folder handed to ImageFolder for the anime faces dataset.
ANIME_FACES_DATASET_PATH = "/content/animefacedataset"
# NOTE(review): not referenced in the visible cells — presumably for the
# human-faces step of the plan; confirm.
HUMAN_FACES_DATASET_PATH = "/content/ffhq-face-data-set"

# (mean, std) per channel, unpacked into transforms.Normalize.
STATS = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)

# Downloading Datasets

In [7]:
def download_data(url='https://www.kaggle.com/splcher/animefacedataset'):
    """Download the dataset at `url` via opendatasets.

    The default URL points at the Kaggle anime faces dataset.
    """
    od.download(url)


# Fetch the default (anime faces) dataset.
download_data()

Downloading animefacedataset.zip to ./animefacedataset


100%|██████████| 395M/395M [00:12<00:00, 34.3MB/s]





# Utility Functions

In [8]:
def get_default_device():
    """Return the CUDA device when one is available, otherwise the CPU device."""
    device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
    return torch.device(device_name)
    
def to_device(data, device):
    """Move a tensor, or a (nested) list/tuple of tensors, to `device`.

    Lists and tuples are processed recursively and always come back as lists;
    anything else is moved with a non-blocking `.to(device)` call.
    """
    if not isinstance(data, (list, tuple)):
        return data.to(device, non_blocking=True)

    moved = []
    for item in data:
        moved.append(to_device(item, device))
    return moved

class DeviceDataLoader():
    """Thin wrapper that moves every batch of a dataloader onto a device."""

    def __init__(self, dl, device):
        self.dl, self.device = dl, device

    def __len__(self):
        """Number of batches in the wrapped dataloader."""
        return len(self.dl)

    def __iter__(self):
        """Iterate the wrapped dataloader, yielding each batch on the device."""
        for batch in self.dl:
            on_device = to_device(batch, self.device)
            yield on_device

In [9]:
def denorm(img):
  """Undo the (mean=0.5, std=0.5) normalisation, mapping [-1, 1] back to [0, 1]."""
  means, stds = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)
  # Only the first channel's statistics are used; all channels share them.
  scale = stds[0]
  shift = means[0]
  return img * scale + shift

def show_images(images, nmax=64):
  """Plot up to `nmax` images from a normalised batch as an 8-column grid.

  The batch is detached and copied to the CPU first, so batches that live on
  the GPU (e.g. those yielded by DeviceDataLoader) can be plotted by
  matplotlib. Pixel values are denormalised before display.
  """
  fig, ax = plt.subplots(figsize=(8, 8))
  ax.set_xticks([]); ax.set_yticks([])
  # matplotlib needs host memory: move to CPU before building the grid.
  batch = images.detach().cpu()[:nmax]
  grid = make_grid(denorm(batch), nrow=8)
  # make_grid returns (C, H, W); imshow expects (H, W, C).
  ax.imshow(grid.permute(1, 2, 0))

In [10]:
def save_samples(index, generator, latent_tensors, show=True):
    """Generate images from `latent_tensors`, save them as a grid, optionally plot.

    Args:
        index: integer used to build the zero-padded file name.
        generator: module mapping latent tensors to images normalised with
            mean 0.5 / std 0.5 (i.e. Tanh output).
        latent_tensors: batch of latent inputs for `generator`.
        show: when True, also display the grid with matplotlib.

    The grid is written to GENERATED_IMAGES_PATH (created if missing). The
    generator's train/eval mode is left untouched.
    """
    # Sampling does not need an autograd graph.
    with torch.no_grad():
        # Denormalise once so the saved file and the preview show the same pixels.
        fake_images = denorm(generator(latent_tensors)).cpu()

    fake_fname = 'more-latent-generated-images-{0:0=4d}.png'.format(index)
    os.makedirs(GENERATED_IMAGES_PATH, exist_ok=True)
    save_image(fake_images, os.path.join(GENERATED_IMAGES_PATH, fake_fname), nrow=8)
    print('Saving', fake_fname)
    if show:
        fig, ax = plt.subplots(figsize=(8, 8))
        ax.set_xticks([]); ax.set_yticks([])
        # make_grid returns (C, H, W); imshow expects (H, W, C).
        ax.imshow(make_grid(fake_images, nrow=8).permute(1, 2, 0))

In [11]:
def save_model(model, model_name):
  """Save `model`'s state dict under MODELS_PATH using a whitelisted file name.

  Names outside the whitelist are rejected with a printed message and nothing
  is written.
  """
  allowed_models = ("generator.pth", "discriminator_anime.pth", "discriminator_human.pth")

  # Guard clause: refuse unknown file names.
  if model_name not in allowed_models:
    print("Invalid Model Name")
    return

  print("Saving ", model_name)
  destination = os.path.join(MODELS_PATH, model_name)
  torch.save(model.state_dict(), destination)

# Anime Dataloader

In [12]:
def AnimeDataLoader():
  """Build the shuffled training DataLoader over the anime faces folder.

  Images are resized, randomly flipped horizontally, centre-cropped to
  IMG_SIZE, converted to tensors and normalised with STATS.
  """
  steps = [
    transforms.Resize(IMG_SIZE),
    transforms.RandomHorizontalFlip(0.5),
    transforms.CenterCrop(IMG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(*STATS),
  ]
  dataset = ImageFolder(ANIME_FACES_DATASET_PATH, transform=Compose(steps))

  return DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=True,
  )

def GetSingleBatch(x):
  """Return the first element of the first (item, label) pair yielded by `x`.

  Returns None when `x` yields nothing.
  """
  pairs = iter(x)
  try:
    first_pair = next(pairs)
  except StopIteration:
    return None
  images, _ = first_pair
  return images

# Discriminator

In [13]:
class Discriminator(nn.Module):
  """Convolutional discriminator producing one probability per input image.

  Four stride-2 conv blocks (3 -> 64 -> 128 -> 256 -> 512 channels) are
  followed by a 4x4 convolution down to one channel, a flatten and a sigmoid.
  """

  def __init__(self):
    super().__init__()
    # Register block_1 .. block_4 in order so parameter names stay the same.
    widths = (3, 64, 128, 256, 512)
    for idx, (c_in, c_out) in enumerate(zip(widths, widths[1:]), start=1):
      setattr(self, f"block_{idx}", self.block(c_in, c_out))

    self.flatter = nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=0, bias=False)
    self.flatten = nn.Flatten()
    self.sigmoid = nn.Sigmoid()

  def forward(self, inputs):
    """Run the conv stack and return sigmoid scores, flattened per sample."""
    features = inputs
    for stage in (self.block_1, self.block_2, self.block_3, self.block_4):
      features = stage(features)

    score_map = self.flatter(features)
    return self.sigmoid(self.flatten(score_map))

  def block(self, in_channels, out_channels):
    """Conv (3x3, stride 2, no bias) -> BatchNorm -> LeakyReLU(0.2)."""
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1, bias=False),
        nn.BatchNorm2d(out_channels),
        nn.LeakyReLU(0.2, inplace=True),
    )

# Generator

In [14]:
class Generator(nn.Module):
  """Transposed-convolution generator mapping latent tensors to 3-channel images.

  A stride-1 block lifts LATENT_SIZE channels to 512, three stride-2 blocks
  reduce the channel count to 64, and a final transposed convolution with Tanh
  produces the 3-channel output.
  """

  def __init__(self):
    super().__init__()
    # Built in the same order as the Sequential indices 0..5.
    stages = [self.block(LATENT_SIZE, 512, stride=1, padding=0)]
    for c_in, c_out in ((512, 256), (256, 128), (128, 64)):
      stages.append(self.block(c_in, c_out))
    stages.append(nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1))
    stages.append(nn.Tanh())

    self.generator = nn.Sequential(*stages)

  def forward(self, inputs):
    """Generate images from a batch of latent tensors."""
    return self.generator(inputs)

  def block(self, in_channels, out_channels, stride=2, padding=1):
    """ConvTranspose (4x4, no bias) -> BatchNorm -> ReLU."""
    return nn.Sequential(
        nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, padding=padding, stride=stride, bias=False),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
    )

# Baby Steps

1. Make the GAN work for anime faces.
2. Make the GAN work for human faces.
3. Implement gradient clipping and weight decay for the generator (on one dataset only).
4. Make the GAN work for both.
5. Change the architecture of the GAN.

# Experimentation

In [15]:
# Pick CUDA when available, otherwise CPU; the bare name below displays the choice.
device = get_default_device()
device

device(type='cuda')

In [16]:
# Instantiate both networks with fresh (randomly initialised) weights.
netD = Discriminator()
netG = Generator()

# Uncomment to resume from the checkpoints written by save_model.
# netD.load_state_dict(torch.load(os.path.join(MODELS_PATH, "discriminator_anime.pth")))
# netG.load_state_dict(torch.load(os.path.join(MODELS_PATH, "generator.pth")))

# Move the networks to the selected device before building the optimizers.
netD = to_device(netD, device)
netG = to_device(netG, device)

In [17]:
# One Adam optimizer per network, both with lr=1e-4 and betas=(0.5, 0.999).
optimizerD = torch.optim.Adam(netD.parameters(), lr=1e-4, betas=(0.5, 0.999))
optimizerG = torch.optim.Adam(netG.parameters(), lr=1e-4, betas=(0.5, 0.999))

In [18]:
# Build the anime DataLoader and wrap it so each batch is moved onto `device`.
anime_loader = DeviceDataLoader(AnimeDataLoader(), device)

In [19]:
# Binary cross-entropy on probabilities; the discriminator already ends in a Sigmoid.
criterion = nn.BCELoss()

In [25]:
# Adversarial training: each batch performs one discriminator update (real +
# fake) followed by one generator update; per-epoch mean losses are printed
# and a grid of generated samples is saved.
for epoch in range(25):
  # Per-batch losses collected for this epoch's summary.
  losses_disc = []
  losses_gen = []
  for anime, _ in tqdm(anime_loader):
    # Targets are sized from the actual batch, so a smaller final batch works too.
    real_targets = torch.ones(anime.size(0), 1, device=device)
    fake_targets = torch.zeros(anime.size(0), 1, device=device)

    # Update D with real data
    netD.zero_grad()
    real_preds = netD(anime)
    loss_real = criterion(real_preds, real_targets)
    loss_real.backward()

    # Update D with fake data
    latent = torch.randn(anime.size(0), LATENT_SIZE, 1, 1, device=device)
    fake_images = netG(latent)
    
    # .detach() stops this backward pass from running through netG, so the
    # generator's graph is still available for the generator update below.
    fake_preds = netD(fake_images.detach())
    loss_fake = criterion(fake_preds, fake_targets)
    loss_fake.backward()
    # Real and fake gradients have accumulated; apply them in a single step.
    optimizerD.step()

    losses_disc.append(loss_real.item() + loss_fake.item())

    # Update G with fake data
    netG.zero_grad()
    # Re-score the same fakes with the just-updated discriminator, this time
    # keeping the graph back to netG.
    fake_preds_r = netD(fake_images)
    # The generator is trained to make D output "real" for its fakes.
    loss_g = criterion(fake_preds_r, real_targets)
    # This also accumulates gradients in netD; they are cleared by
    # netD.zero_grad() at the start of the next iteration.
    loss_g.backward()
    optimizerG.step()

    losses_gen.append(loss_g.item())

  # Fresh noise is drawn every epoch, so sample grids differ in content between epochs.
  new_latent = torch.randn(BATCH_SIZE, LATENT_SIZE, 1, 1, device=device)
  # NOTE(review): the +25 offset makes file indices start at 0025 — presumably
  # continuing the numbering of an earlier 25-epoch run; confirm.
  save_samples(epoch+25, netG, new_latent, show=False)
  losses_disc = np.array(losses_disc)
  losses_gen = np.array(losses_gen)

  print(f"Loss Disc: {losses_disc.mean()} | Loss Gen: {losses_gen.mean()}")

  0%|          | 0/497 [00:00<?, ?it/s]

Saving more-latent-generated-images-0025.png
Loss Disc: 0.1184992329603926 | Loss Gen: 7.549607227505812


  0%|          | 0/497 [00:00<?, ?it/s]

Saving more-latent-generated-images-0026.png
Loss Disc: 0.09365893738393237 | Loss Gen: 7.465178939178436


  0%|          | 0/497 [00:00<?, ?it/s]

Saving more-latent-generated-images-0027.png
Loss Disc: 0.11469843216419086 | Loss Gen: 8.135081278246414


  0%|          | 0/497 [00:00<?, ?it/s]

Saving more-latent-generated-images-0028.png
Loss Disc: 0.10081649195780108 | Loss Gen: 7.201686419711506


  0%|          | 0/497 [00:00<?, ?it/s]

Saving more-latent-generated-images-0029.png
Loss Disc: 0.09845530137011524 | Loss Gen: 8.046286297036369


  0%|          | 0/497 [00:00<?, ?it/s]

Saving more-latent-generated-images-0030.png
Loss Disc: 0.10389739647244936 | Loss Gen: 7.982014444512381


  0%|          | 0/497 [00:00<?, ?it/s]

Saving more-latent-generated-images-0031.png
Loss Disc: 0.10858059521961665 | Loss Gen: 7.463297779171519


  0%|          | 0/497 [00:00<?, ?it/s]

Saving more-latent-generated-images-0032.png
Loss Disc: 0.08561065091865494 | Loss Gen: 8.177092111326559


  0%|          | 0/497 [00:00<?, ?it/s]

Saving more-latent-generated-images-0033.png
Loss Disc: 0.09538833914460888 | Loss Gen: 8.277724314503507


  0%|          | 0/497 [00:00<?, ?it/s]

Saving more-latent-generated-images-0034.png
Loss Disc: 0.08802355632252719 | Loss Gen: 8.177424816538391


  0%|          | 0/497 [00:00<?, ?it/s]

Saving more-latent-generated-images-0035.png
Loss Disc: 0.10077578759238787 | Loss Gen: 7.443819753121082


  0%|          | 0/497 [00:00<?, ?it/s]

Saving more-latent-generated-images-0036.png
Loss Disc: 0.07770182722507957 | Loss Gen: 8.000115353336756


  0%|          | 0/497 [00:00<?, ?it/s]

Saving more-latent-generated-images-0037.png
Loss Disc: 0.09009428251150677 | Loss Gen: 9.183626091456269


  0%|          | 0/497 [00:00<?, ?it/s]

Saving more-latent-generated-images-0038.png
Loss Disc: 0.10411554468431994 | Loss Gen: 7.763638985708686


  0%|          | 0/497 [00:00<?, ?it/s]

Saving more-latent-generated-images-0039.png
Loss Disc: 0.08680149366167934 | Loss Gen: 7.727128451019226


  0%|          | 0/497 [00:00<?, ?it/s]

Saving more-latent-generated-images-0040.png
Loss Disc: 0.08345382745395676 | Loss Gen: 7.504169579722752


  0%|          | 0/497 [00:00<?, ?it/s]

Saving more-latent-generated-images-0041.png
Loss Disc: 0.09157579173501643 | Loss Gen: 8.680188257689448


  0%|          | 0/497 [00:00<?, ?it/s]

Saving more-latent-generated-images-0042.png
Loss Disc: 0.07856753355687929 | Loss Gen: 7.977927175326126


  0%|          | 0/497 [00:00<?, ?it/s]

Saving more-latent-generated-images-0043.png
Loss Disc: 0.10642238437303565 | Loss Gen: 8.948789401553286


  0%|          | 0/497 [00:00<?, ?it/s]

Saving more-latent-generated-images-0044.png
Loss Disc: 0.08559308871893462 | Loss Gen: 7.251653178836738


  0%|          | 0/497 [00:00<?, ?it/s]

Saving more-latent-generated-images-0045.png
Loss Disc: 0.059333935189793297 | Loss Gen: 7.063725275772678


  0%|          | 0/497 [00:00<?, ?it/s]

Saving more-latent-generated-images-0046.png
Loss Disc: 0.10417081180830516 | Loss Gen: 10.530666768191086


  0%|          | 0/497 [00:00<?, ?it/s]

Saving more-latent-generated-images-0047.png
Loss Disc: 0.062653462676996 | Loss Gen: 7.033066573996899


  0%|          | 0/497 [00:00<?, ?it/s]

Saving more-latent-generated-images-0048.png
Loss Disc: 0.08437154398576723 | Loss Gen: 10.139918416558617


  0%|          | 0/497 [00:00<?, ?it/s]

Saving more-latent-generated-images-0049.png
Loss Disc: 0.0877897756130154 | Loss Gen: 9.069741574331548


In [26]:
save_model(netG, "generator.pth")

Saving  generator.pth


In [27]:
save_model(netD, "discriminator_anime.pth")

Saving  discriminator_anime.pth
