# PokeGAN

In [None]:
# Optional: execute this to save (commit) a new version of this notebook with Jovian.
# Requires the `jovian` package; NOTE(review): presumably also needs a Jovian account /
# API key on first use — confirm. Nothing else in the notebook depends on this cell,
# so it can be skipped when running locally.
import jovian
jovian.commit(filename="PokeGAN.ipynb")

# GitHub Repository

I have a full GitHub repository available for this project. I found my workflow to be quite a bit smoother using the run.py script available in the repository, but this notebook offers the same functionality if you prefer. 

To clone the repository to your machine, use the following Bash command:

`$ git clone https://github.com/Kodlak15/PokeGAN`

Then navigate to the local repository in your terminal, and enter the following command:

`$ python3 run.py`

You will be prompted for the number of epochs you would like to train for as well as the learning rate. 

# Device Management

In [None]:
import torch

In [None]:
def get_default_device():
    """ Return the CUDA device when a GPU is available, otherwise the CPU device """
    device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
    return torch.device(device_name)

def to_device(data, device):
    """
    Recursively move ``data`` onto ``device``.

    Lists and tuples are returned as *lists* of moved elements, dicts keep their
    keys with moved values, and any other object is moved with
    ``.to(device, non_blocking=True)`` (tensors, modules).
    """
    is_sequence = isinstance(data, (list, tuple))
    if is_sequence:
        return [to_device(element, device) for element in data]

    is_mapping = isinstance(data, dict)
    if is_mapping:
        return {key: to_device(value, device) for key, value in data.items()}

    return data.to(device, non_blocking=True)

class DeviceDataLoader():
    """ Iterate over a wrapped dataloader, moving every batch to ``device`` on the fly """
    def __init__(self, dl, device):
        # dl: any iterable of batches that also supports len(); device: target for to_device
        self.dl, self.device = dl, device

    def __iter__(self):
        """ Lazily yield each batch of the wrapped loader after moving it to the device """
        yield from map(lambda batch: to_device(batch, self.device), self.dl)

    def __len__(self):
        """ Number of batches in the wrapped loader """
        return len(self.dl)

# Utilities 

In [None]:
from typing import Tuple, Union
from torch import Tensor
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision.utils import make_grid, save_image
import matplotlib.pyplot as plt
from PIL import Image
import os
from os.path import join
import cv2
import json
import numpy as np
import re

In [None]:
def format_path(path: str):
    """Return ``path`` with every Windows backslash separator turned into a forward slash."""
    return "/".join(path.split("\\"))

def prepare_paths():
    """
    Create the project directory layout under the current working directory and
    record it in ``paths.txt``.

    ``paths.txt`` gets one path per line, in this order: root, images,
    images/train-images, fakes, weights, history. Every path is normalised with
    ``format_path`` so all six use forward slashes.

    Missing directories are created (re-running is harmless). An empty
    ``history/history.json`` (``{"history": []}``) is written whenever that file
    is missing, not only when the history directory itself is new. This matters
    because training code reads and backs up that file.
    """
    root_dir = format_path(os.getcwd())
    img_dir = format_path(join(root_dir, "images"))
    train_dir = format_path(join(img_dir, "train-images"))
    fake_dir = format_path(join(root_dir, "fakes"))
    weights_dir = format_path(join(root_dir, "weights"))
    history_dir = format_path(join(root_dir, "history"))

    paths = [root_dir, img_dir, train_dir, fake_dir, weights_dir, history_dir]
    with open("paths.txt", 'w') as f:
        f.write('\n'.join(paths))

    # exist_ok makes directory creation idempotent across re-runs of the cell.
    for directory in paths[1:]:
        os.makedirs(directory, exist_ok=True)

    history_file = join(history_dir, "history.json")
    if not os.path.exists(history_file):
        with open(history_file, 'w') as f:
            json.dump({"history": []}, f)

def get_paths():
    """
    Make sure the directory layout exists, then return the six paths recorded in
    ``paths.txt`` (root, images, train-images, fakes, weights, history) with
    their line endings stripped.
    """
    prepare_paths()
    with open("paths.txt", 'r') as paths_file:
        raw_lines = paths_file.readlines()
    return [line.replace('\n', '') for line in raw_lines]

# Module-level path constants used by the rest of the notebook.
root_dir, img_dir, train_dir, fake_dir, weights_dir, history_dir = get_paths()

In [None]:
# Per-channel (R, G, B) means and standard deviations. The same values are used by
# T.Normalize in the training transform and undone by denormalize() before display.
# With stds of 1.0, normalisation only subtracts the mean.
# NOTE(review): the means were presumably measured on the training images — confirm
# before changing the dataset.
train_stats = [0.1874, 0.1779, 0.1681], [1.0, 1.0, 1.0]

@torch.no_grad()
def show_images(batch: Union[DataLoader, Tensor]):
    """
    Display one grid (8 images per row) of normalised training images.

    Args:
        batch: either a single batch tensor (B, C, H, W), or an iterable of such
            batches (e.g. a DataLoader / DeviceDataLoader). For an iterable, only
            its first batch is shown; if it is empty, nothing is drawn.
    """
    # Iterating a Tensor yields individual images, not batches, so a tensor is
    # used as-is instead of being looped over.
    images = batch if isinstance(batch, Tensor) else next(iter(batch), None)
    if images is None:
        return

    fig, ax = plt.subplots(figsize=(32,32))
    ax.set_xticks([]); ax.set_yticks([])
    images = images.to("cpu")
    images = denormalize(images, *train_stats)
    ax.imshow(make_grid(images, nrow=8).permute(1, 2, 0))

@torch.no_grad()
def show_fakes(images, num_to_show=64):
    """
    Plot the first ``num_to_show`` generated images as a grid with 8 images per row.

    The batch is copied to the CPU and de-normalized with ``train_stats`` before plotting.
    """
    _, axis = plt.subplots(figsize=(32, 32))
    axis.set_xticks([])
    axis.set_yticks([])
    cpu_images = images.to("cpu")
    visible = denormalize(cpu_images[:num_to_show], *train_stats)
    grid = make_grid(visible, nrow=8)
    axis.imshow(grid.permute(1, 2, 0))

def transform_image(img: "Image.Image"):
    """
    Flatten a PIL image onto an opaque black background and return it in RGB mode.

    Transparent pixels become black. The input is converted to RGBA first, so any
    mode is accepted (RGBA, RGB, palette, greyscale, CMYK, ...). Pasting an image
    that has no usable alpha band, with itself as the mask, raises
    "bad transparency mask".
    """
    rgba = img.convert("RGBA")
    background = Image.new("RGBA", rgba.size, "BLACK")
    # The RGBA image's alpha band acts as the paste mask.
    background.paste(rgba, (0, 0), rgba)
    return background.convert("RGB")

@torch.no_grad()
def save_samples(G: nn.Module, index: int, x: Tensor):
    """
    Generate images from the latent batch ``x`` with ``G`` and save the first 64 of
    them as an 8-column grid named ``generated-images-<index, 4+ digits>.png`` in ``fake_dir``.

    Runs under ``torch.no_grad()``, so sampling does not build an autograd graph
    through the generator. ``G`` is left in whatever mode it is in. In train mode,
    BatchNorm layers use (and update their running statistics from) this batch.
    """
    fake_images = denormalize(to_device(G(x), device="cpu"), *train_stats)
    filename = join(fake_dir, "generated-images-{0:0=4d}.png".format(index))
    print(f"Saving {filename}")
    save_image(fake_images[:64], filename, nrow=8)

def make_video(fps=30):
    vid_fname = "pokeGAN.avi"

    files = [join(fake_dir, f) for f in os.listdir(fake_dir) if 'generated' in f]
    files.sort()

    out = cv2.VideoWriter(vid_fname, cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'), fps, (1042, 1042))
    [out.write(cv2.imread(fname)) for fname in files]
    out.release()

def denormalize(images, means, stds):
    """
    Undo a per-channel normalisation: return ``images * stds + means``.

    Args:
        images: tensor with channels on dim 1 (e.g. (B, C, H, W)). A (C, H, W)
            tensor also broadcasts, giving a (1, C, H, W) result.
        means, stds: per-channel sequences with one value per channel.

    The statistics are created on ``images``' device, so GPU tensors can be
    de-normalized without first copying them to the CPU. The channel count
    comes from ``means``/``stds`` rather than being fixed at 3.
    """
    means = torch.as_tensor(means, device=images.device).reshape(1, -1, 1, 1)
    stds = torch.as_tensor(stds, device=images.device).reshape(1, -1, 1, 1)
    return images * stds + means

In [None]:
def parse_history(history_path=None):
    """
    Load the per-epoch training log and split it into plottable series.

    Args:
        history_path: JSON file of the form ``{"history": [str, ...]}``. Defaults to
            ``history.json`` in ``history_dir``, the same file that ``fit`` writes.

    Each entry is expected to look like
    ``"loss_g: 1.2345, loss_d: 0.6789, real_score: 0.5000, fake_score: 0.4000"``.
    Values are looked up by name rather than by position, so non-finite values
    such as ``nan`` or ``inf`` also parse. A digits-only pattern cannot match them.

    Returns:
        ``(epochs, losses_d, losses_g, real_scores, fake_scores)``: ``epochs`` is
        ``np.arange(len(history))``; the other four are lists of floats.
    """
    if history_path is None:
        history_path = join(history_dir, "history.json")
    with open(history_path, 'r') as f:
        history = json.load(f)["history"]

    losses_d, losses_g, real_scores, fake_scores = [], [], [], []
    for entry in history:
        values = dict(re.findall(r"(\w+):\s*([^,\s]+)", entry))
        losses_g.append(float(values["loss_g"]))
        losses_d.append(float(values["loss_d"]))
        real_scores.append(float(values["real_score"]))
        fake_scores.append(float(values["fake_score"]))

    return np.arange(len(history)), losses_d, losses_g, real_scores, fake_scores

def _plot_series(ax, epochs, values, ylabel, ylim=None):
    """Plot one metric on ``ax`` with a dashed red quadratic trend line when it can be fitted."""
    ax.plot(epochs, values)
    # A degree-2 polynomial fit needs at least 3 points; skip the trend for short histories.
    if len(values) >= 3:
        trend = np.poly1d(np.polyfit(epochs, values, deg=2))
        ax.plot(epochs, trend(epochs), color="red", linestyle="--")
    if ylim is not None:
        ax.set_ylim(ylim)
    ax.set_ylabel(ylabel)


def plot_results():
    """
    Plot the training history as a 2x2 grid. The left column shows
    discriminator/generator losses (y-limits [-0.2, 5]); the right column shows
    the discriminator's mean scores on real/fake images. Each curve has a
    quadratic trend line.
    """
    epochs, losses_d, losses_g, real_scores, fake_scores = parse_history()

    # Matplotlib 3.6 renamed the bundled "seaborn" style to "seaborn-v0_8" and 3.8 removed the old name.
    style = "seaborn-v0_8" if "seaborn-v0_8" in plt.style.available else "seaborn"
    plt.style.use(style)
    fig, ax = plt.subplots(nrows=2, ncols=2, figsize=(16, 8))
    ax[0][0].title.set_text("Losses")
    ax[0][1].title.set_text("Scores")
    # Losses
    _plot_series(ax[0][0], epochs, losses_d, "Discriminator", ylim=[-0.2, 5])
    _plot_series(ax[1][0], epochs, losses_g, "Generator", ylim=[-0.2, 5])
    # Scores
    _plot_series(ax[0][1], epochs, real_scores, "Real")
    _plot_series(ax[1][1], epochs, fake_scores, "Fake")
    ax[1][0].set_xlabel("Epoch")
    ax[1][1].set_xlabel("Epoch")

    plt.show()

# Get Data

In [None]:
import requests
from bs4 import BeautifulSoup
from urllib.parse import urljoin
import re
from PIL import Image
import os
from os.path import join
from tqdm import tqdm
import random
import json

In [None]:
def get_image_urls(refresh=False, seed=15):
    """
    Return the list of Pokemon artwork URLs, scraping pokemon.com only when needed.

    The list is cached in ``img_dir/img-urls.json``. When ``refresh`` is True or the
    cache is missing, the Pokedex index page is fetched. Each linked Pokemon page
    is then visited to read the ``src`` of its ``<img class="active">``. Pages
    without one are skipped with a message. The URLs are shuffled
    deterministically with ``seed`` and written to the cache. Otherwise, the
    cached list is returned with no network access.

    Raises:
        requests.HTTPError: if the Pokedex index page returns an error status.
    """
    filename = join(img_dir, "img-urls.json")

    if refresh or not os.path.exists(filename):
        url = "https://www.pokemon.com/us/pokedex/"
        r = requests.get(url, timeout=30)
        r.raise_for_status()
        body = BeautifulSoup(r.content, "html.parser").find("body")
        # Raw string avoids the invalid "\/" escape. [a-zA-Z] rather than A-z,
        # because A-z also spans the characters [ \ ] ^ _ `.
        pattern = r"/us/pokedex/[a-zA-Z]+"
        pokemon = body.find_all('a', {"href": re.compile(pattern)})

        img_urls = []
        for p in tqdm(pokemon):
            p_url = urljoin(url, p["href"])
            r = requests.get(p_url, timeout=30)
            page_body = BeautifulSoup(r.content, "html.parser").find("body")
            img = page_body.find("img", {"class": "active"}) if page_body is not None else None
            if img is None:
                print(f"No artwork found at {p_url}, skipping")
                continue
            img_urls.append(img["src"])

        print(f"{len(img_urls)} Pokemon images found.")
        # A private RNG gives the same permutation as random.seed(seed); random.shuffle(...)
        # without resetting the global random state.
        random.Random(seed).shuffle(img_urls)

        with open(filename, 'w') as f:
            json.dump(img_urls, f)

    with open(filename, 'r') as f:
        return json.load(f)

def get_images(refresh=False, seed=15):
    """
    Download every Pokemon image listed by ``get_image_urls`` into ``train_dir``.

    Args:
        refresh: when True, re-scrape the URL list even if ``img-urls.json`` is cached.
        seed: shuffle seed forwarded to ``get_image_urls``.

    Files already present in ``train_dir`` (matched by the last URL path
    component) are skipped before any request is made. Re-running the function
    therefore resumes an interrupted download instead of doing nothing.
    """
    img_urls = get_image_urls(refresh=refresh, seed=seed)

    existing = set(os.listdir(train_dir))
    missing = [url for url in img_urls if url.split('/')[-1] not in existing]

    if missing:
        print("Creating training set...")
    for url in tqdm(missing):
        pID = url.split('/')[-1]
        if pID in existing:  # duplicate URL already fetched in this loop
            continue
        r = requests.get(url, timeout=30)
        if r.status_code != 200:
            print(f"Image {pID} could not be retrieved")
            continue
        with open(join(train_dir, pID), 'wb') as f:
            f.write(r.content)
        existing.add(pID)

    print("Finished!")

In [None]:
# Set refresh = True to re-scrape the image URL list from pokemon.com.
refresh = False
seed = 15
get_images(refresh=refresh, seed=seed)

# Dataset

In [None]:
from torch.utils.data import Dataset

In [None]:
class PokemonDataset(Dataset):
    """
    Pokemon images dataset.

    Indexes every ``.png`` file in ``directory`` (sorted by file name). Each item
    is the fully loaded PIL image, passed through ``transform`` when one is given.
    """
    def __init__(self, directory: str, transform=None):
        self.directory = directory
        self.transform = transform
        # Only PNG files are indexed, so __len__ and __getitem__ agree on the same set
        # of files; stray non-image files would otherwise be indexable but not counted.
        self.image_paths = [format_path(join(directory, img_name))
                            for img_name in sorted(os.listdir(directory))
                            if img_name.lower().endswith(".png")]

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx: int):
        path = self.image_paths[idx]
        with Image.open(path) as img:
            # Image.open is lazy. copy() forces a full decode while the file is still
            # open, so the returned image stays usable after the file is closed.
            img = img.copy()

        if self.transform:
            img = self.transform(img)

        return img

# Models

In [None]:
class Discriminator(nn.Module):
    """
    DCGAN-style discriminator. It maps a batch of 3 x 128 x 128 images to sigmoid
    scores of shape (B, 1), the estimated probability that each image is real.

    If ``weights_dir/Discriminator.pth`` exists, those weights are loaded on construction.
    """
    def __init__(self):
        super().__init__()

        self.net = nn.Sequential(
            # in: 3 x 128 x 128

            nn.Conv2d(3, 128, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            # out: 128 x 64 x 64

            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            # out: 256 x 32 x 32

            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            # out: 512 x 16 x 16

            nn.Conv2d(512, 1024, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(1024),
            nn.LeakyReLU(0.2, inplace=True),
            # out: 1024 x 8 x 8

            nn.Conv2d(1024, 2048, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(2048),
            nn.LeakyReLU(0.2, inplace=True),
            # out: 2048 x 4 x 4

            nn.Conv2d(2048, 1, kernel_size=4, stride=1, padding=0, bias=False),
            # out: 1 x 1 x 1

            nn.Flatten(),
            nn.Sigmoid()
        )

        if "Discriminator.pth" in os.listdir(weights_dir):
            self.load()

    def forward(self, x):
        """Return sigmoid scores of shape (B, 1) for a batch of 3 x 128 x 128 images."""
        return self.net(x)

    def save(self):
        """Save this model's state_dict to ``weights_dir/Discriminator.pth``."""
        torch.save(self.state_dict(), join(weights_dir, "Discriminator.pth"))

    def load(self):
        """Load weights from ``weights_dir/Discriminator.pth`` (AssertionError if absent)."""
        assert "Discriminator.pth" in os.listdir(weights_dir), "No discriminator weights found"
        # map_location="cpu" lets weights saved from a CUDA model load on a CPU-only machine.
        # load_state_dict then copies them into this module's parameters wherever they live.
        # NOTE: torch.load unpickles the file; only load checkpoints you trust.
        state = torch.load(join(weights_dir, "Discriminator.pth"), map_location="cpu")
        self.load_state_dict(state)
        print("Discriminator weights loaded successfully!")

In [None]:
class Generator(nn.Module):
    """
    DCGAN-style generator. It maps latent noise of shape (B, 128, 1, 1) to images
    of shape (B, 3, 128, 128) with values in [-1, 1] (tanh output).

    If ``weights_dir/Generator.pth`` exists, those weights are loaded on construction.
    """
    def __init__(self):
        super().__init__()
        latent_size = 128

        self.net = nn.Sequential(
            # in: latent_size x 1 x 1

            nn.ConvTranspose2d(latent_size, 2048, kernel_size=4, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(2048),
            nn.ReLU(True),
            # out: 2048 x 4 x 4

            nn.ConvTranspose2d(2048, 1024, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(1024),
            nn.ReLU(True),
            # out: 1024 x 8 x 8

            nn.ConvTranspose2d(1024, 512, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            # out: 512 x 16 x 16

            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            # out: 256 x 32 x 32

            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            # out: 128 x 64 x 64

            nn.ConvTranspose2d(128, 3, kernel_size=4, stride=2, padding=1, bias=False),
            nn.Tanh()
            # out: 3 x 128 x 128
        )

        if "Generator.pth" in os.listdir(weights_dir):
            self.load()

    def forward(self, x):
        """Return generated images (B, 3, 128, 128) for latent input ``x`` of shape (B, 128, 1, 1)."""
        return self.net(x)

    def save(self):
        """Save this model's state_dict to ``weights_dir/Generator.pth``."""
        torch.save(self.state_dict(), join(weights_dir, "Generator.pth"))

    def load(self):
        """Load weights from ``weights_dir/Generator.pth`` (AssertionError if absent)."""
        assert "Generator.pth" in os.listdir(weights_dir), "No generator weights found"
        # map_location="cpu" lets weights saved from a CUDA model load on a CPU-only machine.
        # load_state_dict then copies them into this module's parameters wherever they live.
        # NOTE: torch.load unpickles the file; only load checkpoints you trust.
        state = torch.load(join(weights_dir, "Generator.pth"), map_location="cpu")
        self.load_state_dict(state)
        print("Generator weights loaded successfully!")

# Train

In [None]:
# Size of the generator's input noise vector. It must equal the latent_size
# hard-coded in Generator.__init__ (128), which sets the first layer's input channels.
latent_size = 128
# Training batch size. fit() also uses it as the number of noise vectors
# sampled per epoch when saving sample grids.
batch_size = 256

In [None]:
def train_discriminator(D: nn.Module, G: nn.Module, images: Tensor, opt_d: torch.optim.Optimizer, device: torch.device):
    """
    Perform one discriminator update on a batch of real images and an equal-sized batch of fakes.

    Uses noisy, smoothed labels. Real targets are drawn uniformly from (0.7, 1.2]
    and fake targets from (0.0, 0.3]. NOTE(review): targets above 1 are presumably
    deliberate (the "noisy labels" GAN trick).

    Returns:
        (loss, real_score, fake_score): the summed BCE loss as a float, and the mean
        discriminator output on the real and on the fake images.
    """
    opt_d.zero_grad()
    n = images.size(0)

    real_preds = D(images)
    real_noisy_targets = (0.7 - 1.2) * torch.rand(n, 1, device=device) + 1.2
    real_loss = F.binary_cross_entropy(real_preds, real_noisy_targets)
    real_score = torch.mean(real_preds).item()

    # Only D is updated here, so fakes are generated without tracking gradients.
    # This avoids backpropagating into G, whose gradients train_generator clears anyway.
    x = torch.randn(n, latent_size, 1, 1, device=device)
    with torch.no_grad():
        fake_images = G(x)

    fake_preds = D(fake_images)
    fake_noisy_targets = (0.0 - 0.3) * torch.rand(n, 1, device=device) + 0.3
    fake_loss = F.binary_cross_entropy(fake_preds, fake_noisy_targets)
    fake_score = torch.mean(fake_preds).item()

    loss = real_loss + fake_loss
    loss.backward()
    opt_d.step()

    return loss.item(), real_score, fake_score

In [None]:
def train_generator(D: nn.Module, G: nn.Module, batch_size: int, opt_g: torch.optim.Optimizer, device: torch.device):
    """
    Perform one generator update. It samples ``batch_size`` latent vectors and
    trains G so that D labels the resulting fakes as real (target 1).

    Returns:
        The generator's BCE loss as a float.
    """
    opt_g.zero_grad()

    x = torch.randn(batch_size, latent_size, 1, 1, device=device)
    fake_images = G(x)

    preds = D(fake_images)
    targets = torch.ones(batch_size, 1, device=device)
    loss = F.binary_cross_entropy(preds, targets)

    loss.backward()
    # Only G's parameters are stepped; gradients that land on D are cleared by
    # opt_d.zero_grad() at the start of the next discriminator step.
    opt_g.step()

    return loss.item()

In [None]:
def _backup_training_files():
    """Copy history.json and any saved D/G weights to backup files before this run overwrites them."""
    history_path = join(history_dir, "history.json")
    if os.path.exists(history_path):
        shutil.copy(history_path, join(history_dir, "history_backup.json"))

    for name in ("Discriminator", "Generator"):
        weights_path = join(weights_dir, f"{name}.pth")
        if os.path.exists(weights_path):
            shutil.copy(weights_path, join(weights_dir, f"{name}-backup.pth"))


def _load_history():
    """Return the epoch summaries stored in history.json, or an empty list if the file is missing."""
    history_path = join(history_dir, "history.json")
    if not os.path.exists(history_path):
        return []
    with open(history_path, 'r') as f:
        return json.load(f)["history"]


def _save_history(history):
    """Overwrite history.json with ``history``."""
    with open(join(history_dir, "history.json"), 'w') as f:
        json.dump({"history": history}, f)


def fit(D: nn.Module, G: nn.Module, train_dl: DataLoader, epochs: int, lr: float, device: torch.device, start_idx=1):
    """
    Train D and G adversarially for ``epochs`` epochs with Adam (betas=(0.5, 0.999)).

    Before training, history.json and any saved weights are copied to backup files.
    After every epoch, the function:

    - prints the epoch means of loss_g, loss_d, real_score and fake_score over
      that epoch's batches, and appends that summary to the history;
    - saves a sample grid (numbered from ``start_idx``) generated from a fixed
      noise batch, so successive grids show the same latent points evolving;
    - saves both models' weights and rewrites history.json, so an interrupted
      run leaves history, weights and samples in step.

    Finally, the progress video is built.

    Raises:
        ValueError: if ``train_dl`` yields no batches.

    Returns:
        The full history list (earlier runs followed by this one).
    """
    _backup_training_files()
    history = _load_history()

    opt_d = torch.optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))
    opt_g = torch.optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))

    fixed_noise = torch.randn(batch_size, latent_size, 1, 1, device=device)

    for epoch in range(epochs):
        batch_stats = []
        for real_images in tqdm(train_dl):
            loss_d, real_score, fake_score = train_discriminator(D, G, real_images, opt_d, device)
            loss_g = train_generator(D, G, real_images.size(0), opt_g, device)
            batch_stats.append((loss_g, loss_d, real_score, fake_score))

        if not batch_stats:
            raise ValueError("train_dl yielded no batches; is the training image directory empty?")

        # Epoch means are far less noisy than the last batch's values.
        loss_g, loss_d, real_score, fake_score = (sum(col) / len(col) for col in zip(*batch_stats))

        epoch_results = "Epoch [{}/{}], loss_g: {:.4f}, loss_d: {:.4f}, real_score: {:.4f}, fake_score: {:.4f}".format(
                        epoch+1, epochs, loss_g, loss_d, real_score, fake_score)

        print(epoch_results)
        # Store everything after "Epoch [i/n], " so entries are comparable across runs.
        history.append(epoch_results[epoch_results.find(',')+2:])

        save_samples(G, epoch+start_idx, fixed_noise)
        D.save()
        G.save()
        _save_history(history)

    print("Saving results...")
    _save_history(history)

    make_video()

    return history

# Run

In [None]:
import torch.nn.functional as F
from random import shuffle
import torchvision.transforms as T
from torch.utils.data import DataLoader
import shutil
from tqdm import tqdm

In [None]:
def run():
    try:
        epochs = int(input("Enter the number of epochs to train for (1-5000): "))
        assert epochs in range(1, 5001), "Enter a number between 1 and 5000"
        lr = float(input("Enter the learning rate (5e-6 - 5e-4): "))
        assert lr >= 5e-6 and lr <= 5e-4, "Enter a number between 5e-6 and 5e-4"

    except AssertionError:
        print("Restarting program...")
        run()

    device = get_default_device()
    train_ds = PokemonDataset(train_dir, transform=train_transform)
    train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    train_dl = DeviceDataLoader(train_dl, device)
    D = to_device(Discriminator(), device)
    G = to_device(Generator(), device)

    return fit(D, G, train_dl, epochs, lr, device, start_idx=len(os.listdir(fake_dir))+1)

In [None]:
img_size = 128
# batch_size and train_stats are defined once in earlier cells (Train / Utilities).

train_transform = T.Compose([
    T.Lambda(transform_image),  # flatten transparency onto black, output RGB
    T.Resize(img_size),
    T.CenterCrop(img_size),
    T.RandomHorizontalFlip(0.2),
    T.RandomRotation(3, fill=0),
    T.ToTensor(),
    T.Normalize(*train_stats)
])

# Assigned so the (long) returned history list is not echoed as cell output.
history = run()

# Visualize results

In [None]:
# Plot the losses/scores stored in history/history.json. fit() appends to that file,
# so it covers every training run so far. The trailing ";" suppresses the cell's repr output.
plot_results();

In [None]:
# Build a generator (loading saved weights if present) and show a grid of samples.
num_to_show = 64
G = Generator()

# Sampling only: no_grad avoids building an autograd graph through the generator.
with torch.no_grad():
    x = torch.randn(batch_size, latent_size, 1, 1)
    fake_images = G(x)

show_fakes(fake_images, num_to_show=num_to_show);

# My results

Here is an example batch of Pokemon this model created after training for 10,000 epochs on my PC. The results are not perfect, but there are definitely some interesting samples!

<img src="example.png" alt="Example batch of Pokemon generated by PokeGAN">