# Pokemon Generation using GANs [WGAN-GP]

## Imports

In [0]:
from __future__ import print_function
#%matplotlib inline
import os
import random
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torch.autograd as autograd
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchvision.utils as vutils

## Connect to drive

To use the Pokemon data, download the images from [this GitHub project](https://github.com/rileynwong/pokemon-images-dataset-by-type).\
Then convert them to `.png` and save them in a folder named `data` in your Google Drive.\
**Do not forget to change the `prefix` variable to the path of the folder that contains `data`.**

In [2]:
# Mount Google Drive so the dataset stored there can be read from this notebook.
from google.colab import drive
drive.mount('/content/gdrive')
# Project folder inside the Drive. The path is relative, so it resolves under
# the mount point above only when the working directory is /content.
prefix = "gdrive/My Drive/ECP/DL"

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


## Parameters input

In [3]:
# Seed Python's `random` module and PyTorch so that runs are repeatable.

seed = 999 # Use 'random.randint(1, 10000)' to deactivate reproducibility
print("seed: ", seed)
random.seed(seed)
torch.manual_seed(seed)

seed:  999


<torch._C.Generator at 0x7f0734c32a90>

In [0]:
# Root directory of the dataset
PATH_DATA = os.path.join(prefix, "data")
# Folder of this notebook (in the code shown it is only referenced by the
# commented-out `savefig` call)
PATH_NOTEBOOK = os.path.join(prefix, "GAN_WASSERSTEIN")
# Number of channels per image
IMG_N_CHANNELS = 3 #@param {type: "integer", default: 3}
# Rescale size before random crop
IMG_RESCALE_SIZE = 160 #@param {type: "integer", default: 180}
# Final size of the image (side of the random crop). The Generator and
# Discriminator layer stacks are sized for 128.
IMG_SIZE = 128 #@param {type: "integer", default: 128}
# Number of workers for dataloader
N_WORKERS =  8#@param {type: "integer", default: 2}

# Size of the latent vector
LATENT_VECTOR_SIZE =  64 #@param {type:'integer', default: 64}
# Number of channels of the generator's first 4x4 feature map
G_FEATURE_SIZE = 256 #@param {type:'integer', default: 64}
# Base channel count of the discriminator (its first convolution outputs half of it)
D_FEATURE_SIZE =  256#@param {type:'integer', default: 64}

# Batch size during training
BATCH_SIZE = 64 #@param {type:'integer', default: 64}
# Number of training epochs
NUM_EPOCHS = 200 #@param {type:'integer', default: 300}
# Learning rate for optimizers
LEARNING_RATE = 0.0002  #@param {type:'number', default: 0.00005}
# The discriminator is trained on every batch; the generator is trained once
# every D_N_TRAIN batches
D_N_TRAIN =  5# @param {type:'integer', default:5}
# Clip value for weights (not referenced elsewhere in the code shown, which
# uses a gradient penalty instead of weight clipping)
CLIP_VALUE = 0.01 #@param {type:'number', default: 0.01}
# Adam optimizer parameters
BETA1 = 0.5 #@param {type: 'number', default: 0.5}
BETA2 = 0.999 #@param {type: 'number', default: 0.999}

## Data preparation

In [0]:
# The data is served through a `torch.utils.data.DataLoader`, so the random
# transformations below are re-drawn for each image at every epoch.

# Per-image preprocessing, applied in the listed order
preprocess = transforms.Compose([
    transforms.Resize(size=IMG_RESCALE_SIZE),
    transforms.RandomCrop(size=IMG_SIZE),
    transforms.RandomHorizontalFlip(),
    # Corners uncovered by the rotation are filled with white
    transforms.RandomRotation(degrees=20, fill=(255, 255, 255)),
    transforms.ToTensor(),
    # (x - 0.5) / 0.5 on every channel
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# Dataset: one sub-folder of PATH_DATA per class (the labels are not used for training)
dataset = datasets.ImageFolder(root=PATH_DATA, transform=preprocess)

# Shuffled mini-batches
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=N_WORKERS,
)


In [6]:
# Select the compute device: the first CUDA GPU when one is available,
# otherwise the CPU.

if torch.cuda.is_available():
  device = torch.device("cuda:0")
else:
  device = torch.device("cpu")
print(device)

cuda:0


In [0]:
# Shows up to 32 preprocessed images from one batch of the training data.

real_batch = next(iter(dataloader))
# real_batch[0] holds the images, real_batch[1] the class labels
sample_grid = vutils.make_grid(real_batch[0].to(device)[:32], padding=2, normalize=True)

plt.figure(figsize=(8, 8))
plt.axis("off")
plt.title("Training Images")
# Move the channel axis last before handing the grid to imshow
plt.imshow(np.transpose(sample_grid.cpu(), (1, 2, 0)))

<matplotlib.image.AxesImage at 0x7f06e0414a90>

## Models definition

In [0]:
# Generator definition

class Generator(nn.Module):
  """Maps latent vectors to images with values in (-1, 1).

  Two fully connected layers expand each latent vector into a
  (G_FEATURE_SIZE, 4, 4) feature map. Five stride-2 transposed convolutions
  (four in `conv_blocks`, one in `out_block`) then double the spatial size
  each time, so the output is IMG_N_CHANNELS x 128 x 128.
  """
  def __init__(self):
    super(Generator, self).__init__()

    def conv_block(in_features, out_features, normalize=True):
      """Upsampling block: transposed conv -> optional batch norm -> LeakyReLU."""
      # kernel 3, stride 2, padding 1, output_padding 1: doubles height and width
      layers = [nn.ConvTranspose2d(in_features, out_features, 3, 2, 1, 1, bias=False)]
      if normalize:
        # `momentum` must be passed by keyword: the second positional argument
        # of BatchNorm is `eps`, so a bare 0.8 would add 0.8 to the variance.
        # In PyTorch, momentum is the weight given to the newest batch statistic.
        layers.append(nn.BatchNorm2d(out_features, momentum=0.8))
      layers.append(nn.LeakyReLU(0.2, inplace=True))
      return layers

    # 4x4 -> 64x64, halving the channel count at every step
    self.conv_blocks = nn.Sequential(
        *conv_block(G_FEATURE_SIZE, G_FEATURE_SIZE // 2),
        *conv_block(G_FEATURE_SIZE // 2, G_FEATURE_SIZE // 4),
        *conv_block(G_FEATURE_SIZE // 4 , G_FEATURE_SIZE // 8),
        *conv_block(G_FEATURE_SIZE // 8 , G_FEATURE_SIZE // 16),
    )

    # Latent vector -> flattened (G_FEATURE_SIZE, 4, 4) feature map
    self.in_block = nn.Sequential(
        nn.Linear(LATENT_VECTOR_SIZE, G_FEATURE_SIZE // 2 * 4 * 4, bias=False),
        nn.BatchNorm1d(G_FEATURE_SIZE // 2 * 4 * 4, momentum=0.8),
        nn.LeakyReLU(0.2, inplace=True),
        nn.Linear(G_FEATURE_SIZE // 2 * 4 * 4, G_FEATURE_SIZE * 4 * 4, bias=False),
        nn.BatchNorm1d(G_FEATURE_SIZE * 4 * 4, momentum=0.8),
        nn.LeakyReLU(0.2, inplace=True)
    )

    # Last upsampling (64x64 -> 128x128) down to the image channels; Tanh
    # bounds the output like the (0.5, 0.5) normalization of the real images.
    self.out_block = nn.Sequential(
        nn.ConvTranspose2d(G_FEATURE_SIZE // 16, IMG_N_CHANNELS, 3, 2, 1, 1, bias=False),
        nn.Tanh()
    )

  def forward(self, z):
    """Generate one image per latent vector.

    Args:
      z: tensor that can be viewed as (batch, LATENT_VECTOR_SIZE), e.g.
        (batch, LATENT_VECTOR_SIZE, 1, 1).

    Returns:
      Tensor of shape (batch, IMG_N_CHANNELS, 128, 128).
    """
    z = z.view(-1, LATENT_VECTOR_SIZE)
    x = self.in_block(z)
    x = x.view(-1, G_FEATURE_SIZE, 4, 4)
    x = self.conv_blocks(x)
    img = self.out_block(x)
    return img

In [0]:
# Discriminator definition

class Discriminator(nn.Module):
  """Critic: maps each image to a single unbounded score.

  Five stride-2 convolutions halve the spatial size each time
  (128x128 -> 4x4), followed by one linear layer. There is no final sigmoid
  because the training loss uses the raw scores.
  """
  def __init__(self):
    super(Discriminator, self).__init__()

    def conv_block(in_features, out_features, normalize=True):
      """Downsampling block: conv -> optional per-sample norm -> LeakyReLU."""
      # kernel 3, stride 2, padding 1: halves height and width
      layers = [nn.Conv2d(in_features, out_features, 3, 2, 1, bias=True)]
      if normalize:
        # The gradient penalty is defined per sample, so the critic must not
        # mix statistics across the batch as batch norm does. GroupNorm with
        # a single group is a layer norm over (C, H, W) of each sample.
        layers.append(nn.GroupNorm(1, out_features))
      layers.append(nn.LeakyReLU(0.2, inplace=True))
      return layers

    self.model = nn.Sequential(
        *conv_block(IMG_N_CHANNELS, D_FEATURE_SIZE // 2),
        *conv_block(D_FEATURE_SIZE // 2 , D_FEATURE_SIZE // 4),
        *conv_block(D_FEATURE_SIZE // 4, D_FEATURE_SIZE // 8),
        *conv_block(D_FEATURE_SIZE // 8, D_FEATURE_SIZE // 16),
        *conv_block(D_FEATURE_SIZE // 16, D_FEATURE_SIZE // 32),
    )

    # Flattened (D_FEATURE_SIZE // 32, 4, 4) feature map -> one score
    self.out_block = nn.Sequential(
        nn.Linear(D_FEATURE_SIZE // 32 * 4 * 4, 1)
    )

  def forward(self, img):
    """Score a batch of images.

    Args:
      img: tensor of shape (batch, IMG_N_CHANNELS, 128, 128).

    Returns:
      Tensor of shape (batch, 1).
    """
    x = self.model(img)
    # Keep the batch axis explicit: an unexpected image size then fails in the
    # linear layer instead of being silently folded into the batch dimension.
    x = x.view(x.size(0), -1)
    out = self.out_block(x)
    return out

## Training

In [0]:
# Builds both networks on the selected device and prints their layer structure

generator = Generator().to(device)
discriminator = Discriminator().to(device)

for network in (generator, discriminator):
  print(network)

In [0]:
# One Adam optimizer per network, both with the same learning rate and betas

adam_settings = {"lr": LEARNING_RATE, "betas": (BETA1, BETA2)}
optimizer_G = optim.Adam(generator.parameters(), **adam_settings)
optimizer_D = optim.Adam(discriminator.parameters(), **adam_settings)

In [0]:
def compute_gradient_penalty(D, real_samples, fake_samples):
    """Calculates the gradient penalty loss for WGAN GP.

    The critic is evaluated at random points on the straight line between
    each real sample and its fake counterpart, and the squared deviation of
    its gradient norm from 1 at those points is averaged over the batch.

    Args:
        D: callable mapping a batch of samples to one score per sample.
        real_samples: tensor of real samples, batch axis first.
        fake_samples: tensor of generated samples, same shape as
            `real_samples`.

    Returns:
        Scalar tensor holding the mean penalty. It stays attached to the
        autograd graph so it can be added to the critic loss.
    """
    batch_size = real_samples.size(0)
    # One mixing coefficient per sample, uniform in [0, 1): each interpolate
    # must lie between its real and fake sample (a normal draw would not).
    # The shape broadcasts over all non-batch axes.
    alpha_shape = (batch_size,) + (1,) * (real_samples.dim() - 1)
    alpha = torch.rand(alpha_shape, device=real_samples.device, dtype=real_samples.dtype)
    # Get random interpolation between real and fake samples
    interpolates = (alpha * real_samples + ((1 - alpha) * fake_samples)).requires_grad_(True)
    d_interpolates = D(interpolates)
    # Ones as upstream gradient: yields d(score_i)/d(interpolate_i) per sample
    grad_outputs = torch.ones_like(d_interpolates)
    # Get gradient w.r.t. interpolates; create_graph keeps the result
    # differentiable so the penalty can be back-propagated into D
    gradients = autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=grad_outputs,
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )[0]
    gradients = gradients.reshape(batch_size, -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return gradient_penalty

In [0]:
# Training Loop
#
# The discriminator (critic) is updated on every batch with the Wasserstein
# loss plus a gradient penalty; the generator is updated once every
# D_N_TRAIN batches. Generator outputs on a fixed noise batch are stored in
# `img_list` at regular epochs to visualize progress.

# Create batch of latent vectors that we will use to visualize
#  the progression of the generator
fixed_noise = torch.randn(64, LATENT_VECTOR_SIZE, 1, 1, device=device)
img_list = []

# Aim for about 50 snapshots over the run. Integer arithmetic keeps the
# interval at least 1 and avoids comparing a float remainder with 0.
snapshot_every = max(1, NUM_EPOCHS // 50)

batches_done = 0
print("Starting Training Loop...")

# For each epoch
for epoch in range(NUM_EPOCHS):

  for i, (imgs, _) in enumerate(dataloader, 0):

    # Configure input
    real_imgs = imgs.to(device)
    b_size = real_imgs.size(0)
    # ---------------------
    #  Train Discriminator
    # ---------------------

    optimizer_D.zero_grad()

    # Sample noise as generator input
    z = torch.randn(b_size, LATENT_VECTOR_SIZE, 1, 1, device=device)

    # Generate a batch of images. The generator is not updated in this step,
    # so no graph is recorded and nothing is back-propagated through it.
    with torch.no_grad():
      fake_imgs = generator(z)

    real_validity = discriminator(real_imgs)
    fake_validity = discriminator(fake_imgs)
    # Compute gradient penalty
    gradient_penalty = compute_gradient_penalty(discriminator, real_imgs, fake_imgs)
    # Adversarial loss; 10 is the weight of the gradient penalty
    loss_D = -torch.mean(real_validity) + torch.mean(fake_validity) + 10 * gradient_penalty
    loss_D.backward()
    optimizer_D.step()

    # Train the generator every n_critic iterations
    if i % D_N_TRAIN == 0:
      # -----------------
      #  Train Generator
      # -----------------

      optimizer_G.zero_grad()

      # Generate a batch of images, this time recording the graph
      gen_imgs = generator(z)
      # Adversarial loss
      loss_G = -torch.mean(discriminator(gen_imgs))

      loss_G.backward()
      optimizer_G.step()

      # `i` is the index of the current batch within the epoch
      print(
        "[Epoch %d/%d][Batch %d/%d][D loss: %f] [G loss: %f]"
        % (epoch, NUM_EPOCHS, i, len(dataloader), loss_D.item(), loss_G.item())
      )
      batches_done += D_N_TRAIN

  # Check how the generator is doing by saving G's output on fixed_noise,
  # at regular epochs and always after the last one
  if (epoch % snapshot_every == 0) or (epoch == NUM_EPOCHS - 1):
    with torch.no_grad():
      fake = generator(fixed_noise).cpu()
    img_list.append(vutils.make_grid(fake, padding=2, normalize=True))

## Results

In [0]:
# Side-by-side comparison of a batch of real images (left) and the most
# recent generator snapshot stored in img_list (right)
real_batch = next(iter(dataloader))
real_grid = vutils.make_grid(real_batch[0].to(device)[:64], padding=5, normalize=True)

plt.figure(figsize=(15, 15))

panels = (
    ("Real Images", real_grid.cpu()),
    ("Fake Images", img_list[-1]),
)
for position, (panel_title, grid) in enumerate(panels, start=1):
  plt.subplot(1, 2, position)
  plt.axis("off")
  plt.title(panel_title)
  # Move the channel axis last before handing the grid to imshow
  plt.imshow(np.transpose(grid, (1, 2, 0)))
plt.show()

In [0]:
#%%capture
# Animates the generator snapshots stored in img_list, one frame per snapshot.

# Imported here because the notebook's import cell provides neither name
import matplotlib.animation as animation
from IPython.display import HTML

fig = plt.figure(figsize=(15,15))
plt.axis("off")
# Each frame is a list holding the single image artist drawn for that snapshot
ims = [[plt.imshow(np.transpose(grid,(1,2,0)), animated=True)] for grid in img_list]
ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)

# Render the animation as an interactive HTML/JavaScript player in the cell output
HTML(ani.to_jshtml())

In [0]:
# Save images
#plt.figure(figsize=(15,15))
#plt.imshow(np.transpose(img_list[-1],(1,2,0)))
#plt.savefig(os.path.join(PATH_NOTEBOOK, "wow.png"))

In [0]:
# Generate random pokemons: draw fresh latent vectors and show each generated
# image in its own figure
num_pokemons = 3
latent_vectors = torch.randn(num_pokemons, LATENT_VECTOR_SIZE, 1, 1, device=device)
# Inference only: no autograd graph is needed
with torch.no_grad():
  pokemons = generator(latent_vectors).cpu()

# The images are already on the CPU, so they are plotted directly instead of
# being moved to the device and back for every figure
for pokemon in pokemons:
  plt.figure()
  # normalize=True rescales the image to [0, 1] for display
  plt.imshow(np.transpose(vutils.make_grid(pokemon, padding=5, normalize=True), (1, 2, 0)))