In [1]:
import os, sys; sys.path.append("../src")
import torch
import wandb

import torch.nn as nn
import torch.optim as optim

from tqdm import trange
from torchvision import transforms, models
from torch.utils.data import DataLoader
from utils.image_dataset import ImageDataset
from utils.find_size import model_output, decoder_input
from torchsummary import summary

# GAN

In [2]:
# Open the Weights & Biases run that the rest of the notebook logs to.
# `dir=".."` is handed to wandb as its local directory (the parent folder).
wandb.init(
    project="comic-character-generation",
    entity="lionel-polanski",
    name="Custom Architecture GAN ",
    dir="..",
)

wandb: Currently logged in as: kamwithk (use `wandb login --relogin` to force relogin)
wandb: wandb version 0.10.22 is available!  To upgrade, please run:
wandb:  $ pip install wandb --upgrade


## Parameters

In [3]:
# `size` feeds Resize/CenterCrop, `in_channels` and `batch_size` feed the
# model helpers and the DataLoader respectively.
size, in_channels, batch_size = 64, 3, 32

# Number of passes over the dataset.
epochs = 1000

# Channel widths of the generator stages, widest first; the latent input
# uses the width of the first stage.
hidden_dims = [512, 256, 128, 64, 32]
latent_vector = hidden_dims[0]

## Data

In [4]:
# `decoder_input` is a project helper (utils.find_size). The last entry of
# its first result is used below as the generator's input size — presumably
# a spatial side length; TODO confirm against the helper.
decoder_shapes = decoder_input(hidden_dims, in_channels, size)
decoder_size = decoder_shapes[0][-1]

In [5]:
# Per-image preprocessing, applied in the order listed.
channel_mean = channel_std = (0.5,) * 3
preprocessing_steps = [
    transforms.Resize(size),
    transforms.RandomHorizontalFlip(),
    transforms.CenterCrop(size),
    transforms.ToTensor(),
    # (x - 0.5) / 0.5 per channel, so [0, 1] inputs land in [-1, 1],
    # the same range as the generator's Tanh output.
    transforms.Normalize(channel_mean, channel_std),
]
transform = transforms.Compose(preprocessing_steps)

In [6]:
# ImageDataset is a project class (utils.image_dataset); the training loop
# unpacks each batch it yields as (noise, images).
dataset = ImageDataset(
    "../data/superhero_white_background",
    transform,
    decoder_size,
    latent_vector,
)
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=0,
)

## Model Architecture

In [7]:
def weights_init(m):
    """Re-initialise one module in place; intended for ``nn.Module.apply``.

    The module is matched by its class name:

    * names containing ``Conv``: weights drawn from N(0.0, 0.02); the bias
      is left untouched.
    * names containing ``BatchNorm``: weights drawn from N(1.0, 0.02) and
      the bias set to 0.
    * anything else is left as it is.

    Args:
        m: the module to initialise.

    Returns:
        None.
    """
    layer_name = m.__class__.__name__

    if 'Conv' in layer_name:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
        return

    if 'BatchNorm' in layer_name:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

In [8]:
class Generator(nn.Module):
    """Transposed-convolution generator that upsamples a latent feature map into an image.

    ``main`` holds one stage per consecutive pair in ``hidden_dims``
    (ConvTranspose2d -> BatchNorm2d -> LeakyReLU). ``final_layer`` adds one
    more upsampling stage followed by a 3x3 convolution and ``Tanh``, so the
    output values lie in [-1, 1]. Every transposed convolution uses
    kernel 3, stride 2, padding 1 and output_padding 1, which doubles the
    spatial size; an input of shape ``(N, hidden_dims[0], s, s)`` therefore
    comes out as ``(N, in_channels, s * 2 ** len(hidden_dims), ...)``.

    Args:
        size: not used by the network; kept so existing calls keep working.
        in_channels: number of channels of the generated image. Previously
            this argument was accepted but ignored and the output was always
            3 channels; the default of 3 preserves that behaviour.
        hidden_dims: channel widths of the stages, starting with the number
            of channels of the latent input. Must not be empty. The default
            list is only read, never modified.
    """

    def __init__(self, size=1, in_channels=3, hidden_dims=[512, 256, 128, 64, 32]):
        super(Generator, self).__init__()

        # The loop variables are deliberately not called `in_channels`, so the
        # constructor argument of that name is not shadowed.
        self.main = nn.Sequential(*[
            nn.Sequential(
                nn.ConvTranspose2d(stage_in, out_channels=stage_out, kernel_size=3, stride=2, padding=1, output_padding=1),
                nn.BatchNorm2d(stage_out),
                nn.LeakyReLU()
            ) for stage_in, stage_out in zip(hidden_dims[:-1], hidden_dims[1:])
        ])

        self.final_layer = nn.Sequential(
            nn.ConvTranspose2d(hidden_dims[-1], hidden_dims[-1], kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(hidden_dims[-1]),
            nn.LeakyReLU(),
            # Project to the requested number of image channels.
            nn.Conv2d(hidden_dims[-1], out_channels=in_channels, kernel_size=3, padding=1),
            nn.Tanh()
        )

    def forward(self, input):
        """Map a latent batch ``(N, hidden_dims[0], s, s)`` to an image batch in [-1, 1]."""
        return self.final_layer(self.main(input))

In [9]:
class Discriminator(nn.Module):
    """Strided-convolution discriminator producing per-location probabilities.

    ``main`` stacks one stage per entry of ``hidden_dims``
    (Conv2d kernel 3 / stride 2 / padding 1 -> BatchNorm2d -> LeakyReLU);
    the first stage expects 3 input channels. ``final_layer`` reduces the
    last feature map to a single channel with a ``final_conv_kernel``-sized
    convolution and squashes it with ``Sigmoid``.

    Args:
        hidden_dims: output channel width of each stage, in order. Must not
            be empty.
        final_conv_kernel: kernel size of the closing convolution.
    """

    def __init__(self, hidden_dims=[32, 64, 128, 256, 512], final_conv_kernel=1):
        super().__init__()

        stages = []
        previous_width = 3  # the network input is expected to have 3 channels
        for width in hidden_dims:
            stage = nn.Sequential(
                nn.Conv2d(previous_width, out_channels=width, kernel_size=3, stride=2, padding=1),
                nn.BatchNorm2d(width),
                nn.LeakyReLU(),
            )
            stages.append(stage)
            previous_width = width
        self.main = nn.Sequential(*stages)

        scoring_conv = nn.Conv2d(hidden_dims[-1], 1, final_conv_kernel)
        self.final_layer = nn.Sequential(scoring_conv, nn.Sigmoid())

    def forward(self, input):
        """Return sigmoid scores for ``input``; shape ``(N, 1, h, w)``."""
        features = self.main(input)
        return self.final_layer(features)

In [10]:
class ResnetDiscriminator(nn.Module):
    """Discriminator built by repurposing the head of a torchvision ResNet.

    The network's ``avgpool`` is replaced by a convolution that collapses
    the final feature map to one channel, and its ``fc`` is replaced by a
    ``Sigmoid``.

    Args:
        model: the ResNet to adapt. It is modified in place (its ``avgpool``
            and ``fc`` attributes are overwritten), so do not reuse the same
            instance elsewhere. When ``None`` (the default) a fresh
            pretrained ``resnet18`` is created for this instance. The
            default used to be a model built once at class-definition time,
            which was then shared and mutated by every default-constructed
            discriminator.
        size: input image size forwarded to the project helper
            ``model_output`` to probe the backbone's output shape.
    """

    def __init__(self, model=None, size=128):
        super(ResnetDiscriminator, self).__init__()

        if model is None:
            # Built lazily so that defining the class has no side effects and
            # instances never share a backbone.
            model = models.resnet18(pretrained=True)

        # Everything except the last two children (avgpool and fc).
        backbone = nn.Sequential(*list(model.children())[:-2])

        # `model_output` is a project helper; its first result is indexed
        # like an (N, C, H, W) shape here — presumably the backbone's
        # output shape for a `size`-sized input; TODO confirm.
        resnet_output_shape = model_output(backbone, size=size)
        final_conv_kernel, final_layer_input = resnet_output_shape[0][-1], resnet_output_shape[0][1]

        self.resnet = model
        # A kernel as large as the feature map turns it into a single value
        # per image, which the Sigmoid maps to a probability.
        self.resnet.avgpool, self.resnet.fc = nn.Conv2d(final_layer_input, 1, final_conv_kernel), nn.Sigmoid()

    def forward(self, input):
        """Run the adapted ResNet on ``input`` and return its sigmoid scores."""
        return self.resnet(input)

## Model Creation

In [11]:
# Both networks are placed on the GPU. Other discriminators that can be
# swapped in here: a plain `models.resnet18(pretrained=True)` or the custom
# `Discriminator(hidden_dims[::-1], decoder_size)`.
generator = Generator(decoder_size, in_channels, hidden_dims).to("cuda")
discriminator = ResnetDiscriminator(models.resnet101(pretrained=True), size).to("cuda")

# Only the generator is re-initialised with `weights_init`; the
# discriminator keeps the weights it was constructed with.
generator.apply(weights_init)

# Register both networks with wandb.
for network in (generator, discriminator):
    wandb.watch(network)

## Optimisers

In [12]:
# Binary cross-entropy between discriminator scores and the 0/1 labels.
criterion = nn.BCELoss()

# Both networks use Adam with identical betas; only the learning rates differ.
# NOTE(review): the generator's rate is 50x the discriminator's — confirm
# this is intentional.
adam_betas = (0.5, 0.999)
generator_optimiser = optim.Adam(generator.parameters(), lr=0.01, betas=adam_betas)
discriminator_optimiser = optim.Adam(discriminator.parameters(), lr=0.0002, betas=adam_betas)

## Training

In [13]:
# Use whichever device the generator currently lives on instead of a
# hard-coded device string.
device = next(generator.parameters()).device

# Shape of a single noise sample (first element of a dataset item); it is
# the same every epoch, so it is read from the dataset only once.
noise_shape = dataset[0][0].shape

for epoch in trange(epochs):
    for noise, imgs in dataloader:
        noise, imgs = noise.to(device), imgs.to(device)

        # Labels
        real_label = torch.ones(imgs.size(0), dtype=torch.float, device=device)
        fake_label = torch.zeros(imgs.size(0), dtype=torch.float, device=device)

        # One generator forward pass serves both updates: the generator's
        # weights do not change between the two steps, so recomputing the
        # batch would only repeat the same work.
        fake_imgs = generator(noise)

        # Discriminator: real images should score 1, generated ones 0.
        # `detach()` keeps this step from back-propagating into the generator.
        discriminator_optimiser.zero_grad()

        real_loss = criterion(discriminator(imgs).view(-1), real_label)
        fake_loss = criterion(discriminator(fake_imgs.detach()).view(-1), fake_label)

        (real_loss + fake_loss).backward()

        discriminator_optimiser.step()

        # Generator: rewarded when the freshly updated discriminator scores
        # its images as real.
        generator_optimiser.zero_grad()

        generator_loss = criterion(discriminator(fake_imgs).view(-1), real_label)
        generator_loss.backward()

        generator_optimiser.step()

        # Log Stats (as plain Python floats rather than tensors)
        wandb.log({
            "real_loss": real_loss.item(), "fake_loss": fake_loss.item(),
            "generator_loss": generator_loss.item()
        })

    # LOG SAMPLE IMAGES
    # No gradients are needed for previews, so skip building the autograd graph.
    with torch.no_grad():
        generations = generator(torch.randn(5, *noise_shape, device=device))
    wandb.log({"generations": [wandb.Image(sample_generation) for sample_generation in generations]})

 10%|█████                                               | 98/1000 [06:54<1:03:33,  4.23s/it]


KeyboardInterrupt: 

In [None]:
# Store both state dicts inside the active wandb run directory.
for checkpoint_name, network in (
    ("generator_model.pt", generator),
    ("discriminator_model.pt", discriminator),
):
    torch.save(network.state_dict(), os.path.join(wandb.run.dir, checkpoint_name))