# PokeGAN

In [1]:
# Execute this to save new versions of the notebook
import jovian
jovian.commit(filename="PokeGAN.ipynb")

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

[jovian] Updating notebook "kodlak15/pokegan" on https://jovian.ai/[0m
[jovian] Committed successfully! https://jovian.ai/kodlak15/pokegan[0m


'https://jovian.ai/kodlak15/pokegan'

In [2]:
from PIL import Image
import os
from os import path
from utils import *
from device_mgmt import *

%matplotlib inline

In [3]:
root_dir = os.getcwd()
img_dir = format_path(join(root_dir, "images"))
train_dir = format_path(join(img_dir, "train-images"))
fake_dir = join(root_dir, "fakes")
weights_dir = join(root_dir, "weights")
history_dir = join(root_dir, "history")

# Get images

In [4]:
from PokeScraper import get_images
get_images()

Finished!


# Create datasets and dataloaders

In [5]:
from torch.utils.data import DataLoader
import torchvision.transforms as T
from dataset import PokemonDataset

In [6]:
train_stats = [0.8577, 0.8482, 0.8384], [0.0579, 0.0565, 0.0608]
img_size = 128

train_transform = T.Compose([
    T.Lambda(lambda img: transform_image(img)),
    T.ColorJitter(brightness=0, contrast=0, saturation=(1.0, 1.5), hue=(-0.15, 0.15)),
    T.Resize(img_size),
    T.CenterCrop(img_size),
    T.RandomHorizontalFlip(0.2),
    T.RandomRotation(3, fill=0),
    T.ToTensor()
])

# T.Normalize(*train_stats)

train_ds = PokemonDataset(train_dir, transform=train_transform)
print(f"There are {len(train_ds)} images in the train dataset")

There are 905 images in the train dataset


In [7]:
batch_size = 64

train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True)

# Preview data

In [8]:
for batch in train_dl:
    print(f"Batch tensor shape: {batch.shape}")
    break

show_images(train_dl);

Batch tensor shape: torch.Size([64, 3, 128, 128])


TypeError: show_images() takes 1 positional argument but 2 were given

# Set up devices

In [None]:
device = get_default_device()
print(f"Current device: {device}")

In [None]:
train_dl = DeviceDataLoader(train_dl, device)
clear_cache_and_get_info(device)

# Build model

In [None]:
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from models import Discriminator, Generator

In [None]:
D = to_device(Discriminator(), device)
G = to_device(Generator(), device)

# Train model

How training works: 
- The discriminator is trained before the generator. It computes the scores (% predicted to be real) and loss values for real and fake images based on predictions made by the generator. 
- The generator is trained second. It creates fake images, then feeds them to the (new) discriminator, returning a loss value. 
- Since the discriminator is trained using the generator from the previous epoch, you can sometimes infer how well the fake images will score by the generator loss from the previous epoch. 
    - ie: if the generator loss last epoch was exceptionally low, you can generally expect the discriminator to struggle more to correctly label the image fake.
- The discriminator eventually regains the upper hand, challenging the generator to create better fakes. 

In [None]:
from train import fit

In [None]:
lr = 1e-4
epochs = 100

In [None]:
# history = fit(D, G, train_dl, epochs, lr, device, start_idx=len(os.listdir(fake_dir))+1)

In [None]:
latent_size = 128
x = torch.randn(1, latent_size, 1, 1, device="cpu")
G = to_device(G, device="cpu")
fake = G(x)

In [None]:
with torch.no_grad():
    show_images(fake)