In [1]:
import os, sys; sys.path.append("../src")
import torch
import wandb

import torch.nn as nn
import torch.optim as optim

from tqdm import trange
from models.gan_trainer import GANTrainer
from models.custom_generator import Generator
from models.custom_discriminator import Discriminator
from models.resnet_discriminator import ResnetDiscriminator
from torchvision import transforms, models
from torch.utils.data import DataLoader
from utils.image_dataset import ImageDataset
from torchsummary import summary

# GAN — Comic Character Generation (RaGAN, 64×64, white backgrounds)

In [2]:
# Start the wandb run that every later wandb.watch / wandb.log call attaches to.
wandb.init(
    project="comic-character-generation",
    entity="lionel-polanski",
    name="RaGAN 64 White Backgrounds",
    dir="..",
    mode="online",
)

wandb: Currently logged in as: kamwithk (use `wandb login --relogin` to force relogin)
wandb: wandb version 0.10.22 is available!  To upgrade, please run:
wandb:  $ pip install wandb --upgrade


## Parameters

In [3]:
# Image geometry: images are `size` x `size` pixels with `channels` colour channels.
size = 64
channels = 3
batch_size = 32

epochs = 1000
# Layer widths handed to the Generator; the Discriminator receives this list reversed.
hidden_dims = [512, 256, 128, 64, 32]
# Forwarded to ImageDataset and Generator (exact semantics defined in those modules — TODO document).
noise_size, latent_dims = 1, 128

# Seed torch's RNGs (CPU and CUDA) so weight init, DataLoader shuffling, random flips and
# noise sampling are repeatable across Restart & Run All (cuDNN kernels may still be nondeterministic).
seed = 0
torch.manual_seed(seed)

## Data

In [4]:
# Preprocessing pipeline: resize the shorter side to `size`, randomly mirror for augmentation,
# take a square centre crop, convert to a [0, 1] tensor and rescale it to [-1, 1].
transform = transforms.Compose(
    [
        transforms.Resize(size),
        transforms.RandomHorizontalFlip(),
        transforms.CenterCrop(size),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

In [5]:
# Each item presumably yields a (noise, image) pair — the training loop unpacks batches that way.
dataset = ImageDataset(
    "../data/superhero_white_background",
    transform,
    noise_size,
    latent_dims,
)
# num_workers=0 loads batches in the notebook process itself.
dataloader = DataLoader(dataset, shuffle=True, batch_size=batch_size, num_workers=0)

## Model Creation

In [6]:
def weights_init(m):
    """Initialise one module's parameters as in DCGAN; meant for ``nn.Module.apply``.

    Modules whose class name contains "Conv" get weights ~ N(0, 0.02).
    Modules whose class name contains "BatchNorm" get weights ~ N(1, 0.02) and zero bias.
    Matching modules without a learnable weight/bias (e.g. ``BatchNorm2d(affine=False)``
    or a custom container named like ``ConvBlock``) are skipped instead of raising.
    All other modules are left untouched.

    Args:
        m: the module currently visited by ``apply``.
    """
    classname = m.__class__.__name__
    # Class-name matching covers Conv1d/2d/3d and ConvTranspose*, plus BatchNorm1d/2d/3d.
    if "Conv" in classname:
        if getattr(m, "weight", None) is not None:
            nn.init.normal_(m.weight, 0.0, 0.02)
    elif "BatchNorm" in classname:
        if getattr(m, "weight", None) is not None:
            nn.init.normal_(m.weight, 1.0, 0.02)
        if getattr(m, "bias", None) is not None:
            nn.init.constant_(m.bias, 0)

In [7]:
# Build both networks and move them to the GPU.
generator = Generator(noise_size, size, latent_dims, channels, hidden_dims).to("cuda")
# The discriminator is given the generator's hidden_dims in reverse order.
discriminator = Discriminator(size, channels, hidden_dims[::-1]).to("cuda")
# Alternative backbone: discriminator = ResnetDiscriminator(models.resnet18(pretrained=False), size).to("cuda")

# nn.Module.apply mutates in place and returns the same module, so no reassignment is needed.
generator.apply(weights_init)
discriminator.apply(weights_init)

# Register both networks with wandb (a for-loop produces no cell output).
for network in (generator, discriminator):
    wandb.watch(network)

In [8]:
# summary(discriminator, (3, 64, 64))

## Optimisers

In [9]:
# Adam with beta1 = 0.5 for both players; the discriminator's learning rate is 4x the
# generator's (presumably a two time-scale update rule setup — confirm intent).
generator_optimiser = optim.Adam(
    params=generator.parameters(),
    lr=1e-4,
    betas=(0.5, 0.999),
)
discriminator_optimiser = optim.Adam(
    params=discriminator.parameters(),
    lr=4e-4,
    betas=(0.5, 0.999),
)

In [10]:
# `relavistic` (sic) is the keyword GANTrainer accepts; "average" presumably selects the
# relativistic-average loss (RaGAN, per the run name) — confirm in models/gan_trainer.py.
gan_trainer = GANTrainer(
    generator,
    discriminator,
    generator_optimiser,
    discriminator_optimiser,
    relavistic="average",
)

## Training

In [11]:
# LOG SAMPLE IMAGES BEFORE TRAINING STARTS
# Baseline: 5 samples from the freshly initialised generator. The noise shape is taken from the
# first element of a dataset item (the "noise" half of the (noise, imgs) pairs the training loop unpacks).
# Sampling is for visualisation only, so no_grad avoids building an autograd graph for it.
with torch.no_grad():
    generations = gan_trainer.generator(torch.randn(5, *dataset[0][0].shape, device="cuda"))
wandb.log({"generations": [wandb.Image(sample_generation) for sample_generation in generations]})

In [None]:
for epoch in trange(epochs):
    for noise, imgs in dataloader:
        # Labels: rebuilt per batch because the final batch can be smaller than batch_size
        # (the DataLoader above does not set drop_last).
        real_label = torch.full((imgs.size(0),), 1., dtype=torch.float, device="cuda")
        fake_label = torch.full((imgs.size(0),), 0., dtype=torch.float, device="cuda")

        noise, imgs = noise.to("cuda"), imgs.to("cuda")

        # Discriminator update, then generator update, on the same batch (as implemented by GANTrainer).
        real_loss, fake_loss = gan_trainer.train_discriminator(noise, imgs, real_label, fake_label)
        generator_loss = gan_trainer.train_generator(noise, imgs, real_label, fake_label)

        # Log Stats
        wandb.log({
            "real_loss": real_loss, "fake_loss": fake_loss,
            "generator_loss": generator_loss
        })

    # LOG SAMPLE IMAGES: 5 fresh samples per epoch. They are only visualised, so skip autograd
    # to avoid building (and holding) a graph through the generator.
    with torch.no_grad():
        generations = gan_trainer.generator(torch.randn(5, *dataset[0][0].shape, device="cuda"))
    wandb.log({"generations": [wandb.Image(sample_generation) for sample_generation in generations]})

In [None]:
# Write the final weights into the wandb run directory so wandb uploads them with the run.
for network, filename in (
    (generator, "generator_model.pt"),
    (discriminator, "discriminator_model.pt"),
):
    torch.save(network.state_dict(), os.path.join(wandb.run.dir, filename))