In [1]:
import os, sys; sys.path.append("../src")
import torch
import wandb

import torch.nn as nn
import torch.optim as optim

from tqdm import trange
from torchvision import transforms, models
from torch.utils.data import DataLoader
from utils.image_dataset import ImageDataset
from torchsummary import summary

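# GAN

Train a DCGAN on the superhero images in `../data/superhero`, logging losses and sample generations to Weights & Biases.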

In [2]:
# Start the Weights & Biases run for this notebook; local run files go under the parent directory.
wandb.init(
    project="comic-character-generation",
    entity="lionel-polanski",
    name="Tutorial GAN",
    dir="..",
)

wandb: Currently logged in as: kamwithk (use `wandb login --relogin` to force relogin)
wandb: wandb version 0.10.22 is available!  To upgrade, please run:
wandb:  $ pip install wandb --upgrade


## Parameters

In [3]:
# Number of full passes over the dataset.
epochs = 200
# Size of the latent noise: passed to ImageDataset and used as the Generator's input channels.
latent_vector = 100
# Base number of Generator feature maps (scaled by 8/4/2/1 from input to output).
ngf = 100

# Base number of Discriminator feature maps (scaled by 1/2/4/8 from input to output).
ndf = 64

# Seed torch's global RNG so torch-driven randomness (weight init, DataLoader shuffling,
# torch.randn sampling) repeats between runs. ImageDataset's noise source is not visible
# here -- TODO confirm it draws from torch. GPU kernels may still be nondeterministic.
seed = 999
torch.manual_seed(seed);

## Data

In [4]:
# Per-channel mean/std of 0.5 maps ToTensor's [0, 1] range onto [-1, 1],
# the same range as the Generator's Tanh output.
_channel_stats = (0.5,) * 3
transform = transforms.Compose(
    [
        transforms.Resize(64),              # shorter side -> 64 px
        transforms.RandomHorizontalFlip(),  # augmentation
        transforms.CenterCrop(64),          # square 64x64 crop
        transforms.ToTensor(),
        transforms.Normalize(_channel_stats, _channel_stats),
    ]
)

In [5]:
# The training loop unpacks every batch as (noise, imgs).
# NOTE(review): the positional `1` is ImageDataset-specific -- TODO document its meaning.
dataset = ImageDataset("../data/superhero", transform, 1, latent_vector=latent_vector)
dataloader = DataLoader(
    dataset=dataset,
    batch_size=32,
    shuffle=True,
    num_workers=0,
)

## Model Architecture
Note: this model architecture is adapted from the [official PyTorch DCGAN tutorial](https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html).

In [6]:
def weights_init(m):
    """DCGAN-style re-initialisation, meant to be passed to ``nn.Module.apply``.

    Modules whose class name contains "Conv" get weights drawn from N(0, 0.02).
    Modules whose class name contains "BatchNorm" get weights drawn from
    N(1, 0.02) and a zero bias. Every other module (and any conv bias) is
    left untouched.
    """
    layer_type = type(m).__name__
    if "Conv" in layer_type:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
        return
    if "BatchNorm" in layer_type:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

In [7]:
class Generator(nn.Module):
    """DCGAN generator mapping latent noise to an image.

    Args:
        latent_dim: channels of the input noise. When omitted, the notebook-level
            ``latent_vector`` is used (the original behaviour).
        feature_maps: base width of the hidden layers, scaled by 8/4/2/1 from input
            to output. When omitted, the notebook-level ``ngf`` is used.
        channels: channels of the generated image (default 3, as before).

    The layer arithmetic turns a (N, latent_dim, 1, 1) input into a
    (N, channels, 64, 64) output; the final Tanh bounds values to [-1, 1].
    Module indices inside ``self.main`` are unchanged, so existing state dicts load.
    """

    def __init__(self, latent_dim=None, feature_maps=None, channels=3):
        super().__init__()
        # Defaults resolve to the notebook globals at construction time, exactly as the
        # original did; passing explicit values removes that hidden dependency.
        nz = latent_vector if latent_dim is None else latent_dim
        nf = ngf if feature_maps is None else feature_maps

        self.main = nn.Sequential(
            # input is Z, (nz) x 1 x 1, going into a transposed convolution
            nn.ConvTranspose2d(nz, nf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(nf * 8),
            nn.ReLU(True),
            # state size. (nf*8) x 4 x 4
            nn.ConvTranspose2d(nf * 8, nf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(nf * 4),
            nn.ReLU(True),
            # state size. (nf*4) x 8 x 8
            nn.ConvTranspose2d(nf * 4, nf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(nf * 2),
            nn.ReLU(True),
            # state size. (nf*2) x 16 x 16
            nn.ConvTranspose2d(nf * 2, nf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(nf),
            nn.ReLU(True),
            # state size. (nf) x 32 x 32
            nn.ConvTranspose2d(nf, channels, 4, 2, 1, bias=False),
            nn.Tanh(),
            # state size. (channels) x 64 x 64
        )

    def forward(self, input):
        """Generate a batch of images (values in [-1, 1]) from a batch of noise."""
        return self.main(input)

In [8]:
class Discriminator(nn.Module):
    """DCGAN discriminator scoring 64x64 images.

    Args:
        feature_maps: base width of the hidden layers, scaled by 1/2/4/8 from input
            to output. When omitted, the notebook-level ``ndf`` is used (the original
            behaviour).
        channels: channels of the input image (default 3, as before).

    Maps a (N, channels, 64, 64) batch to a (N, 1, 1, 1) tensor of sigmoid scores in
    [0, 1]. Module indices inside ``self.main`` are unchanged, so existing state dicts load.
    """

    def __init__(self, feature_maps=None, channels=3):
        super().__init__()
        # Default resolves to the notebook global at construction time, as the original did.
        nf = ndf if feature_maps is None else feature_maps

        self.main = nn.Sequential(
            # input is (channels) x 64 x 64
            nn.Conv2d(channels, nf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (nf) x 32 x 32
            nn.Conv2d(nf, nf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(nf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (nf*2) x 16 x 16
            nn.Conv2d(nf * 2, nf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(nf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (nf*4) x 8 x 8
            nn.Conv2d(nf * 4, nf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(nf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (nf*8) x 4 x 4
            nn.Conv2d(nf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
            # state size. 1 x 1 x 1
        )

    def forward(self, input):
        """Score a batch of images; returns sigmoid outputs shaped (N, 1, 1, 1)."""
        return self.main(input)

In [9]:
# Build both networks on the GPU first, then re-initialise their weights DCGAN-style
# (same order as before: construct G, construct D, init G, init D).
generator = Generator().to("cuda")
discriminator = Discriminator().to("cuda")
for network in (generator, discriminator):
    network.apply(weights_init)  # apply() works in place and returns the same module

# Register both models with the active W&B run.
for network in (generator, discriminator):
    wandb.watch(network)

## Optimisers

In [10]:
# Binary cross-entropy on the discriminator's sigmoid scores.
criterion = nn.BCELoss()

# Both networks share one Adam configuration (lr 2e-4, beta1 = 0.5).
adam_settings = {"lr": 2e-4, "betas": (0.5, 0.999)}
generator_optimiser = optim.Adam(generator.parameters(), **adam_settings)
discriminator_optimiser = optim.Adam(discriminator.parameters(), **adam_settings)

## Training

In [11]:
# Batches follow whatever device the models were placed on.
device = next(generator.parameters()).device

# Shape of one noise sample (item 0 of a dataset entry, matching the `noise, imgs`
# unpacking below). It is looked up once instead of re-reading a dataset item every epoch.
noise_shape = dataset[0][0].shape
# Fixed noise, so the images logged each epoch show the same latent points and are comparable.
sample_noise = torch.randn(5, *noise_shape, device=device)

for epoch in trange(epochs):
    for noise, imgs in dataloader:
        noise, imgs = noise.to(device), imgs.to(device)

        # BCE targets: 1 for real images, 0 for generated ones
        real_label = torch.full((imgs.size(0),), 1., dtype=torch.float, device=device)
        fake_label = torch.full((imgs.size(0),), 0., dtype=torch.float, device=device)

        # Generate once per batch. The discriminator step uses a detached copy (no
        # generator gradients); the generator step reuses the same graph.
        fakes = generator(noise)

        # Discriminator step
        discriminator_optimiser.zero_grad()
        real_loss = criterion(discriminator(imgs).view(-1), real_label)
        fake_loss = criterion(discriminator(fakes.detach()).view(-1), fake_label)
        (real_loss + fake_loss).backward()
        discriminator_optimiser.step()

        # Generator step: push the (just updated) discriminator to call fakes real
        generator_optimiser.zero_grad()
        generator_loss = criterion(discriminator(fakes).view(-1), real_label)
        generator_loss.backward()
        generator_optimiser.step()

        # Log plain floats rather than graph-carrying tensors
        wandb.log({
            "real_loss": real_loss.item(), "fake_loss": fake_loss.item(),
            "generator_loss": generator_loss.item(),
        })

    # LOG SAMPLE IMAGES (no autograd graph needed). The generator stays in train mode,
    # as before, so BatchNorm uses the statistics of these 5 samples.
    with torch.no_grad():
        generations = generator(sample_noise)
    wandb.log({"generations": [wandb.Image(sample_generation) for sample_generation in generations], "epoch": epoch})

 18%|█████████████████████████████████████████                                                                                                                                                                                     | 37/200 [03:00<13:16,  4.88s/it]


KeyboardInterrupt: 

In [None]:
# Write both state dicts into the active W&B run directory.
checkpoints = (
    ("generator_model.pt", generator),
    ("discriminator_model.pt", discriminator),
)
for filename, network in checkpoints:
    torch.save(network.state_dict(), os.path.join(wandb.run.dir, filename))