In [1]:
import os, sys; sys.path.append("../src")
import torch
import wandb

import torch.nn as nn
import torch.optim as optim

from tqdm import trange
from models.gan_trainer import GANTrainer
from models.custom_generator import Generator
from models.custom_discriminator import Discriminator
from models.resnet_discriminator import ResnetDiscriminator
from torchvision import transforms, models
from torch.utils.data import DataLoader
from utils.image_dataset import ImageDataset
from torchsummary import summary

# GAN

In [2]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed PyTorch's global RNG so that a Restart & Run All draws the same random
# numbers from torch. Other RNGs (Python's `random`, NumPy) and
# non-deterministic CUDA kernels are not covered by this call.
SEED = 42
torch.manual_seed(SEED)

# SIZE: side length the images are resized/cropped to; CHANNELS: colour channels.
SIZE, CHANNELS = 64, 3
# The generator input is summarised below as (LATENT_DIMS, NOISE_SIZE, NOISE_SIZE).
NOISE_SIZE, LATENT_DIMS = 1, 128
BATCH_SIZE = 64

EPOCHS = 300
# MAX_IMAGES is handed to ImageDataset (presumably None means "no cap" - confirm
# in utils/image_dataset.py); sample images are logged every LOG_IMAGES_EVERY batches.
MAX_IMAGES, LOG_IMAGES_EVERY = None, 200
GENERATOR_HIDDEN_DIMS = [800, 400, 200, 100, CHANNELS]
# The last entry is the discriminator's output channel count (one score per image).
# It used to be spelled NOISE_SIZE, which only happened to equal 1; keeping it
# separate means changing the noise shape cannot silently widen the output.
DISCRIMINATOR_OUT_CHANNELS = 1
DISCRIMINATOR_HIDDEN_DIMS = [64, 128, 256, 512, DISCRIMINATOR_OUT_CHANNELS]

GENERATOR_LR, DISCRIMINATOR_LR = 0.0002, 0.0002

In [3]:
# Start the Weights & Biases run that all later wandb.log / wandb.watch calls report to.
wandb.init(
    project="gan-demo",
    name=f"Abstract Art {SIZE} Standard Test",
    mode="online",
)
# Alternative run configuration, kept for reference:
# wandb.init(project="comic-character-generation", entity="lionel-polanski", name=f"RaGAN {SIZE} PDSH", dir="..", mode="online")

wandb: Currently logged in as: kamwithk (use `wandb login --relogin` to force relogin)
wandb: wandb version 0.10.30 is available!  To upgrade, please run:
wandb:  $ pip install wandb --upgrade


## Data

In [4]:
# Per-channel statistics used for normalisation (ImageNet values).
imagenet_mean = (0.485, 0.456, 0.406)
imagenet_std = (0.229, 0.224, 0.225)

# Preprocessing applied to every training image: resize, random horizontal flip,
# centre-crop to SIZE x SIZE, convert to a tensor, then normalise per channel.
# NOTE(review): the generator ends in Tanh (outputs in [-1, 1]) while
# ImageNet-normalised pixels span roughly [-2.1, 2.6]; consider normalising with
# mean = std = 0.5 and changing the `denormalise` transform used for logging to match.
transform = transforms.Compose(
    [
        transforms.Resize(SIZE),
        transforms.RandomHorizontalFlip(),
        transforms.CenterCrop(SIZE),
        transforms.ToTensor(),
        transforms.Normalize(imagenet_mean, imagenet_std),
    ]
)

In [5]:
# Dataset over the cleaned image folder; it also receives the noise settings
# (the training loop unpacks each batch as a (noise, images) pair).
dataset = ImageDataset(
    "../data/cleaned",
    MAX_IMAGES,
    transform,
    NOISE_SIZE,
    LATENT_DIMS,
)

# Shuffled mini-batches, loaded in the main process (num_workers=0).
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,
)

## Model Creation

In [6]:
def weights_init(m):
    """Re-initialise one module's parameters in place; meant for ``Module.apply``.

    Modules whose class name contains "Conv" get weights drawn from a normal
    distribution with mean 0.0 and std 0.02. Modules whose class name contains
    "BatchNorm" get weights drawn with mean 1.0 and std 0.02 and their bias set
    to 0. Every other module is left untouched.

    Matching is by class name, so a module can match without owning the
    parameter (e.g. ``BatchNorm`` built with ``affine=False``, or a container
    class that merely has "Conv" in its name). Such modules are skipped instead
    of raising ``AttributeError``.

    Args:
        m: the module to initialise.
    """
    classname = m.__class__.__name__
    # Either attribute may be absent or None; see the docstring.
    weight = getattr(m, "weight", None)
    bias = getattr(m, "bias", None)

    if "Conv" in classname:
        if weight is not None:
            nn.init.normal_(weight, 0.0, 0.02)
    elif "BatchNorm" in classname:
        if weight is not None:
            nn.init.normal_(weight, 1.0, 0.02)
        if bias is not None:
            nn.init.constant_(bias, 0)

In [7]:
# Build both networks and move them to the target device.
generator = Generator(LATENT_DIMS, GENERATOR_HIDDEN_DIMS).to(DEVICE)
discriminator = Discriminator(CHANNELS, DISCRIMINATOR_HIDDEN_DIMS).to(DEVICE)
# Alternative discriminator, kept for reference:
# discriminator = ResnetDiscriminator(models.resnet18(pretrained=False), SIZE).to(DEVICE)

# Module.apply runs weights_init on every submodule in place and returns the
# module itself, so the result does not need to be reassigned.
for network in (generator, discriminator):
    network.apply(weights_init)

# Register both networks with Weights & Biases.
for network in (generator, discriminator):
    wandb.watch(network)

In [8]:
# Display the generator's module tree (layers and their hyper-parameters).
generator

Generator(
  (main): Sequential(
    (0): ConvTranspose2d(128, 800, kernel_size=(4, 4), stride=(1, 1))
    (1): Sequential(
      (0): ConvTranspose2d(800, 400, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(400, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (2): Sequential(
      (0): ConvTranspose2d(400, 200, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (3): Sequential(
      (0): ConvTranspose2d(200, 100, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (4): ConvTranspose2d(100, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
  )
  (final_layer): Sequential(
    (0): Tanh()
  )
)

In [9]:
# Print per-layer output shapes and parameter counts for a single
# (LATENT_DIMS, NOISE_SIZE, NOISE_SIZE) noise input.
summary(generator, (LATENT_DIMS, NOISE_SIZE, NOISE_SIZE))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
   ConvTranspose2d-1            [-1, 800, 4, 4]       1,639,200
   ConvTranspose2d-2            [-1, 400, 8, 8]       5,120,000
       BatchNorm2d-3            [-1, 400, 8, 8]             800
              ReLU-4            [-1, 400, 8, 8]               0
   ConvTranspose2d-5          [-1, 200, 16, 16]       1,280,000
       BatchNorm2d-6          [-1, 200, 16, 16]             400
              ReLU-7          [-1, 200, 16, 16]               0
   ConvTranspose2d-8          [-1, 100, 32, 32]         320,000
       BatchNorm2d-9          [-1, 100, 32, 32]             200
             ReLU-10          [-1, 100, 32, 32]               0
  ConvTranspose2d-11            [-1, 3, 64, 64]           4,800
             Tanh-12            [-1, 3, 64, 64]               0
Total params: 8,365,400
Trainable params: 8,365,400
Non-trainable params: 0
---------------------------

In [10]:
# Display the discriminator's module tree (layers and their hyper-parameters).
discriminator

Discriminator(
  (main): Sequential(
    (0): Sequential(
      (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (1): Sequential(
      (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (2): Sequential(
      (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (3): Sequential(
      (0): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): 

In [11]:
# Print per-layer output shapes and parameter counts for a single
# (CHANNELS, SIZE, SIZE) image input.
summary(discriminator, (CHANNELS, SIZE, SIZE))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 32, 32]           3,072
       BatchNorm2d-2           [-1, 64, 32, 32]             128
         LeakyReLU-3           [-1, 64, 32, 32]               0
            Conv2d-4          [-1, 128, 16, 16]         131,072
       BatchNorm2d-5          [-1, 128, 16, 16]             256
         LeakyReLU-6          [-1, 128, 16, 16]               0
            Conv2d-7            [-1, 256, 8, 8]         524,288
       BatchNorm2d-8            [-1, 256, 8, 8]             512
         LeakyReLU-9            [-1, 256, 8, 8]               0
           Conv2d-10            [-1, 512, 4, 4]       2,097,152
      BatchNorm2d-11            [-1, 512, 4, 4]           1,024
        LeakyReLU-12            [-1, 512, 4, 4]               0
           Conv2d-13              [-1, 1, 1, 1]           8,192
Total params: 2,765,696
Trainable param

## Optimisers

In [12]:
# Both networks use Adam with the same beta coefficients; only the learning
# rate is configured per network.
adam_betas = (0.5, 0.999)

generator_optimiser = optim.Adam(
    generator.parameters(),
    lr=GENERATOR_LR,
    betas=adam_betas,
)
discriminator_optimiser = optim.Adam(
    discriminator.parameters(),
    lr=DISCRIMINATOR_LR,
    betas=adam_betas,
)

In [13]:
# Bundle both networks and their optimisers into the project's trainer, which
# exposes train_discriminator / train_generator for the training loop.
# NOTE(review): `relavistic` is the keyword exactly as passed here; it presumably
# toggles a relativistic GAN loss - confirm against models/gan_trainer.py.
gan_trainer = GANTrainer(generator, discriminator, generator_optimiser, discriminator_optimiser, relavistic=False)

## Training

In [14]:
# Inverse of the per-channel normalisation applied to the training images:
# the first step multiplies by the std, the second adds the mean back.
denormalise = transforms.Compose([
    transforms.Normalize(mean=[0., 0., 0.], std=[1/0.229, 1/0.224, 1/0.225]),
    transforms.Normalize(mean=[-0.485, -0.456, -0.406], std=[1., 1., 1.]),
])


def generations(num_samples):
    """Sample ``num_samples`` images from the current generator for logging.

    Fresh Gaussian noise is drawn with the same per-sample shape as the noise
    tensor returned by ``dataset[0]``, passed through ``gan_trainer.generator``
    and mapped back through ``denormalise``.

    The forward pass runs under ``torch.no_grad()``: these samples are only
    logged, so no autograd graph needs to be built or kept alive. The
    generator's train/eval mode is deliberately left unchanged.

    Args:
        num_samples: number of images to generate.

    Returns:
        The denormalised generator output for the sampled noise batch.
    """
    with torch.no_grad():
        noise = torch.randn(num_samples, *dataset[0][0].shape, device=DEVICE)
        return denormalise(gan_trainer.generator(noise))


def wandb_images(images):
    """Wrap each image in ``wandb.Image`` under the "generations" key for ``wandb.log``.

    Args:
        images: an iterable of images (e.g. a batch tensor, iterated along its first axis).

    Returns:
        A dict ``{"generations": [wandb.Image, ...]}``.
    """
    return {"generations": [wandb.Image(image) for image in images]}

In [None]:
# Adversarial training: for every batch, update the discriminator, then the
# generator, and log the resulting losses to Weights & Biases.
for epoch in trange(EPOCHS):
    for index, (noise, imgs) in enumerate(dataloader):
        # Labels: one target per image in the batch (1 = real, 0 = fake). Sized
        # from the batch itself because the last batch of an epoch can be smaller.
        batch_size = imgs.size(0)
        real_label = torch.ones(batch_size, dtype=torch.float, device=DEVICE)
        fake_label = torch.zeros(batch_size, dtype=torch.float, device=DEVICE)

        noise, imgs = noise.to(DEVICE), imgs.to(DEVICE)

        real_loss, fake_loss = gan_trainer.train_discriminator(noise, imgs, real_label, fake_label)
        generator_loss = gan_trainer.train_generator(noise, imgs, real_label, fake_label)

        # Log stats. Every wandb.log call advances the run's step counter, so the
        # periodic sample images go into the same call as the losses; that keeps
        # both on one step instead of creating an extra image-only step.
        step_log = {
            "real_loss": real_loss, "fake_loss": fake_loss,
            "generator_loss": generator_loss
        }

        # LOG SAMPLE IMAGES AFTER LOG_IMAGES_EVERY STEPS
        if index % LOG_IMAGES_EVERY == 0:
            step_log.update(wandb_images(generations(5)))

        wandb.log(step_log)

    # LOG SAMPLE IMAGES AFTER EACH EPOCH
    wandb.log(wandb_images(generations(5)))

In [None]:
# Write both networks' weights (state dicts) into the current wandb run directory.
checkpoints = {
    "generator_model.pt": generator,
    "discriminator_model.pt": discriminator,
}
for filename, network in checkpoints.items():
    torch.save(network.state_dict(), os.path.join(wandb.run.dir, filename))