# Sources

In [None]:
# Overall code structure: based on https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html
# Dataset loading: based on https://pytorch.org/tutorials/beginner/basics/data_tutorial.html
# Conditional GAN aspect: based on https://github.com/znxlwm/pytorch-MNIST-CelebA-cGAN-cDCGAN
# Wasserstein aspect: based on https://github.com/eriklindernoren/PyTorch-GAN/blob/master/implementations/wgan/wgan.py
# GAN stability improvements: based on https://machinelearningmastery.com/how-to-train-stable-generative-adversarial-networks/

# Imports

In [None]:
import os
import math
import random
import numpy as np
import cv2 as cv
import json
import torch
import torch.nn as nn
import torchvision
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
from torch.autograd import Variable
from skimage import io
from torch.utils.data import Dataset

# Hyperparameters

In [None]:
# --- Dataset locations -------------------------------------------------------
# Root directory for the dataset; the image folder and label file live inside it.
dataroot = "training_data/"
imageroot = os.path.join(dataroot, "layout_images")
labelroot = os.path.join(dataroot, "layout_labels_all_4_text_proportion_classes.json")

# Number of conditioning classes (text-proportion categories in the label file)
num_classes = 4

# --- Data loading ------------------------------------------------------------
workers = 2       # worker processes used by the DataLoader
batch_size = 64   # images per training batch

# --- Image geometry ----------------------------------------------------------
image_size = 32   # images are resized / centre-cropped to image_size x image_size
nc = 3            # channels per training image (3 for colour)

# --- Model sizes -------------------------------------------------------------
nz = 100          # length of the latent vector z (generator input)
ngf = 128         # base feature-map count in the generator
ndf = 128         # base feature-map count in the discriminator (critic)

# --- Optimisation ------------------------------------------------------------
num_epochs = 5000
lr = 5e-5
clip_value = 0.01  # WGAN weight clipping: critic weights are clamped to [-clip_value, clip_value]
n_critic = 5       # critic updates per generator update
beta1 = 0.5        # Adam beta1; NOTE(review): the optimizers are RMSprop, so this looks unused
ngpu = 1           # number of GPUs (0 = CPU); NOTE(review): not referenced in the visible code

# --- Reproducibility and device ---------------------------------------------
manualSeed = 999
random.seed(manualSeed)
torch.manual_seed(manualSeed)

# Prefer the GPU whenever one is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


# Google Drive integration (optional)

In [None]:
# Optional Google Drive integration: when running in Google Colab, mount Drive and
# work from the project folder so checkpoints and the dataset persist between
# sessions (and re-downloading the dataset may be faster). Outside Colab the
# `google.colab` module is not importable, and the notebook keeps using the
# current working directory instead of crashing.
try:
    from google.colab import drive
except ImportError:
    drive = None

if drive is not None:
    drive.mount('/content/drive')
    # Absolute path under the mount point, so re-running this cell after the
    # working directory has already been changed still resolves correctly.
    os.chdir("/content/drive/My Drive/Uni/L3/Project/Coded")

MessageError: ignored

# Load dataset

In [None]:
# Load dataset
class MagazineDataset(Dataset):
    """Magazine layout images paired with their text-proportion class index.

    Each ``.png`` file in the image directory is matched to the entry in the
    JSON label file whose key is the file name without its extension. That
    entry's ``"textProportion"`` value is a list in which the position of the
    first ``1`` is the class index (e.g. ``[0, 1, 0, 0]`` -> class 1).

    Source: based on https://pytorch.org/tutorials/beginner/basics/data_tutorial.html

    Args:
        transform: optional callable applied to each image array returned by
            ``skimage.io.imread``.
        image_dir: directory holding the layout images; defaults to the
            notebook-level ``imageroot``.
        label_path: path of the JSON label file; defaults to the
            notebook-level ``labelroot``.
    """

    def __init__(self, transform=None, image_dir=None, label_path=None):
        self.transform = transform
        self.image_dir = image_dir if image_dir is not None else imageroot
        label_path = label_path if label_path is not None else labelroot

        # Keep only PNG files: any other directory entry (hidden metadata files,
        # notes, ...) would otherwise be mapped to a non-existent image. Sort so
        # the index -> file mapping is the same on every machine; os.listdir
        # returns entries in arbitrary order, which would defeat the seeding.
        self.image_files = sorted(
            name
            for name in os.listdir(self.image_dir)
            if os.path.splitext(name)[1].lower() == ".png"
        )

        with open(label_path, encoding="utf-8") as infile:
            self.labels = json.load(infile)

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        """Return ``(image, class_index)`` for the ``idx``-th image file."""
        file_name = self.image_files[idx]
        # Label keys are the image file names without their extension
        label_key = os.path.splitext(file_name)[0]
        img_path = os.path.join(self.image_dir, file_name)

        # Read image and apply the optional transform
        image = io.imread(img_path)
        if self.transform:
            image = self.transform(image)

        label = self.labels[label_key]["textProportion"].index(1)
        return image, label

# Per-image preprocessing: ToTensor turns the H x W x C image array into a
# C x H x W float tensor (uint8 pixels are scaled to [0, 1]), Resize + CenterCrop
# bring it to image_size x image_size, and Normalize maps [0, 1] to [-1, 1],
# the same range as the generator's Tanh output.
transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Resize(image_size),
        torchvision.transforms.CenterCrop(image_size),
        torchvision.transforms.Normalize(mean=0.5, std=0.5),
    ]
)

# Shuffled mini-batches of (image, class index) pairs for training
dataset = MagazineDataset(transform)
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=workers,
)

# Example training images

In [None]:
# Pull the first (shuffled) batch from the dataloader for a quick visual check
examples = enumerate(dataloader)
batch_idx, (example_data, example_targets) = next(examples)

# Show the first 16 images as a grid. make_grid(normalize=True) rescales the
# pixel values into [0, 1]; the transpose moves channels last,
# (C, H, W) -> (H, W, C), which is the layout matplotlib's imshow expects.
plt.figure(figsize=(8, 8))
plt.axis("off")
plt.title("Training Images")
plt.imshow(
    np.transpose(
        torchvision.utils.make_grid(example_data.to("cpu")[:16], padding=2, normalize=True),
        (1, 2, 0),
    )
)

# Print the class index of each displayed image
for i in range(16):
    print(example_targets[i])

# Network weights initialisation function

In [None]:
def weights_init(m):
    """DCGAN-style weight initialisation, intended for ``module.apply(weights_init)``.

    Source: https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html

    Modules whose class name contains "Conv" (Conv2d, ConvTranspose2d, ...) get
    weights drawn from N(0, 0.02). Modules whose class name contains "BatchNorm"
    get weights drawn from N(1, 0.02) and a zero bias. Every other module, and
    every convolution bias, is left untouched.
    """
    layer_type = type(m).__name__
    if "Conv" in layer_type:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif "BatchNorm" in layer_type:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

# Generator

In [None]:
class Generator(nn.Module):
    """Conditional DCGAN generator.

    The latent vector and the one-hot class label are each upsampled by their
    own transposed-convolution branch to an (ngf*4)-channel 4x4 map. The two
    maps are concatenated along the channel axis and decoded to an image with
    values in [-1, 1] (Tanh). For (batch, nz, 1, 1) / (batch, num_classes, 1, 1)
    inputs, the output is (batch, nc, 32, 32).
    """

    def __init__(self):
        super(Generator, self).__init__()
        # NOTE(review): plain in-place ReLU. If a 0.2 negative slope was intended
        # (as in the discriminator), switch to nn.LeakyReLU(0.2, inplace=True).

        # Latent branch: (b, nz, 1, 1) -> (b, ngf*4, 4, 4)
        self.deconv_z = nn.Sequential(
            nn.ConvTranspose2d(nz, ngf*4, 4, 1, 0),
            nn.BatchNorm2d(ngf*4),
            nn.ReLU(inplace=True),
            nn.Dropout(0.4),
        )

        # Label branch: (b, num_classes, 1, 1) -> (b, ngf*4, 4, 4)
        self.deconv_label = nn.Sequential(
            nn.ConvTranspose2d(num_classes, ngf*4, 4, 1, 0),
            nn.BatchNorm2d(ngf*4),
            nn.ReLU(inplace=True),
            nn.Dropout(0.4),
        )

        # Decoder on the concatenated maps: 4x4 -> 8x8 -> 16x16 -> 32x32
        self.main = nn.Sequential(
            nn.ConvTranspose2d((ngf*4)*2, ngf*4, 4, 2, 1),   # -> 8x8
            nn.BatchNorm2d(ngf*4),
            nn.ReLU(inplace=True),
            nn.Dropout(0.4),

            nn.ConvTranspose2d(ngf*4, ngf*2, 4, 2, 1),       # -> 16x16
            nn.BatchNorm2d(ngf*2),
            nn.ReLU(inplace=True),

            nn.ConvTranspose2d(ngf*2, nc, 4, 2, 1),          # -> 32x32
            nn.Tanh()
        )

    def forward(self, z, labels):
        """Generate a batch of images.

        Args:
            z: latent noise; the callers in this notebook shape it (batch, nz, 1, 1).
            labels: one-hot class labels, shaped (batch, num_classes, 1, 1).

        Returns:
            Images in [-1, 1]; (batch, nc, 32, 32) for the input shapes above.
        """
        z_features = self.deconv_z(z)
        label_features = self.deconv_label(labels)
        stacked = torch.cat([z_features, label_features], dim=1)
        return self.main(stacked)

# Create the generator
netG = Generator().to(device)

# Initialise weights with weights_init: (transposed) conv weights ~ N(0, 0.02);
# BatchNorm weights ~ N(1, 0.02) with zero bias.
netG.apply(weights_init)

Generator(
  (deconv_z): Sequential(
    (0): ConvTranspose2d(100, 512, kernel_size=(4, 4), stride=(1, 1))
    (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Dropout(p=0.4, inplace=False)
  )
  (deconv_label): Sequential(
    (0): ConvTranspose2d(4, 512, kernel_size=(4, 4), stride=(1, 1))
    (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Dropout(p=0.4, inplace=False)
  )
  (main): Sequential(
    (0): ConvTranspose2d(1024, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Dropout(p=0.4, inplace=False)
    (4): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (5): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ReLU(inplace=True)
   

# Discriminator

In [None]:
class Discriminator(nn.Module):
    """Conditional critic for the Wasserstein GAN.

    The image and the label "fill" map each pass through their own stride-2
    convolution. The results are concatenated along the channel axis and
    reduced to one score per sample. There is no final sigmoid: the output is
    a raw, unbounded critic score, not a probability. The layer sizes assume
    32x32 inputs (image_size == 32), which the final 4x4 convolution reduces
    to a 1x1 score.
    """
    def __init__(self):
        super(Discriminator, self).__init__()

        # Image branch: (b, nc, 32, 32) -> (b, ndf//2, 16, 16)
        self.embed_x = nn.Sequential(
            nn.Conv2d(nc, ndf//2, 4, 2, 1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.4)
        )

        # Label branch: (b, num_classes, 32, 32) -> (b, ndf//2, 16, 16). The training
        # loop feeds fill[label]: channel `label` all ones, the other channels zero.
        self.embed_labels = nn.Sequential(
            nn.Conv2d(num_classes, ndf//2, 4, 2, 1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.4)
        )

        # Trunk on the concatenated maps: 16x16 -> 8x8 -> 4x4 -> 1x1 score
        self.main = nn.Sequential(
            nn.Conv2d((ndf//2)*2, ndf*2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf*2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.4),

            nn.Conv2d(ndf*2, ndf*4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf*4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.4),

            # 4x4 kernel, stride 1, no padding: collapses the 4x4 map to a single score
            nn.Conv2d(ndf*4, 1, 4, 1, 0, bias=False),
        )

    def forward(self, x, labels):
        """Score a batch of (image, label) pairs.

        Args:
            x: images, (batch, nc, 32, 32) in this notebook.
            labels: label fill maps, (batch, num_classes, 32, 32).

        Returns:
            Critic scores of shape (batch, 1, 1, 1); the training loop flattens
            them with ``.view(-1)``.
        """
        # Embed the image and the label map separately, then stack them channel-wise
        x = self.embed_x(x)
        labels = self.embed_labels(labels)
        x = torch.cat([x, labels], dim = 1)

        return self.main(x)

# Create the Discriminator
netD = Discriminator().to(device)

# Apply weights_init: conv weights ~ N(0, 0.02); BatchNorm weights ~ N(1, 0.02) with zero bias.
netD.apply(weights_init)

Discriminator(
  (embed_x): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Dropout(p=0.4, inplace=False)
  )
  (embed_labels): Sequential(
    (0): Conv2d(4, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Dropout(p=0.4, inplace=False)
  )
  (main): Sequential(
    (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.2, inplace=True)
    (3): Dropout(p=0.4, inplace=False)
    (4): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (5): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): LeakyReLU(negative_slope=0.2, inplace=True)
    (7): Dropout(p=0.4, inplace=False)
    (8): Conv2d(512, 1, ke

# Checkpoint functions

In [None]:
def _checkpoint_path(epoch):
    """Return the checkpoint file path for ``epoch``, relative to the working directory."""
    return os.path.join('checkpoints', 'text_proportion_wasserstein_checkpoint' + str(epoch) + '.pth')

def save_checkpoint(epoch):
    """Save the generator, discriminator and both optimizer states for ``epoch``.

    Reads the notebook-level ``netG``, ``netD``, ``optimizerG`` and ``optimizerD``.
    The ``checkpoints`` directory is created if it is missing, because
    ``torch.save`` does not create parent directories.
    """
    state = {
        'netG': netG.state_dict(),
        'netD': netD.state_dict(),
        'optimizerG': optimizerG.state_dict(),
        'optimizerD': optimizerD.state_dict()
    }

    path = _checkpoint_path(epoch)
    os.makedirs(os.path.dirname(path), exist_ok=True)
    torch.save(state, path)

def load_checkpoint(epoch):
    """Restore, in place, the states saved by ``save_checkpoint(epoch)``.

    Tensors are mapped onto the notebook-level ``device``, so a checkpoint
    saved on one device can be resumed on another.
    """
    state = torch.load(_checkpoint_path(epoch), map_location=torch.device(device))
    netG.load_state_dict(state['netG'])
    netD.load_state_dict(state['netD'])
    optimizerG.load_state_dict(state['optimizerG'])
    optimizerD.load_state_dict(state['optimizerD'])

# Training initialisations

In [None]:
# Binary cross-entropy criterion (carried over from the DCGAN base code). The
# Wasserstein training loop optimises raw critic scores and does not use it.
criterion = nn.BCELoss()


# RMSprop optimizers for G and D, as in the referenced WGAN implementation
# (weight clipping + RMSprop). Because of this, beta1 is not used.
optimizerG = torch.optim.RMSprop(netG.parameters(), lr=lr)
optimizerD = torch.optim.RMSprop(netD.parameters(), lr=lr)


# Preprocessed label encodings, indexed by class id:
#   onehot[c] -> (num_classes, 1, 1) one-hot vector: the generator's label input
#   fill[c]   -> (num_classes, image_size, image_size) map whose channel c is all
#                ones and whose other channels are zero: the discriminator's label input
# Both tables live on `device`. A tensor on `device` can be indexed with label
# tensors on the CPU or on `device`, whereas a CPU tensor cannot be indexed
# with CUDA index tensors.
onehot = torch.eye(num_classes).view(num_classes, num_classes, 1, 1).to(device)
fill = onehot.repeat(1, 1, image_size, image_size)

# +1 / -1 gradient seeds for backward(): they select the sign of each loss term
one = torch.ones(1, device=device)
mone = -one

def random_labels(size, min, max):
  """Return ``size`` floats drawn uniformly from [min, max) on ``device``.

  NOTE(review): not referenced by the visible training loop.
  """
  vals = torch.rand((size,), dtype=torch.float, device=device) * (max - min) + min
  return vals

# Fixed noise & labels for progress snapshots: 4 latent vectors, each repeated
# once per class, so in the 4-wide snapshot grid every row shares one latent
# vector and every column shares one class (assumes num_classes == 4).
temp_z0_ = torch.randn(nz).repeat(4)
temp_z1_ = torch.randn(nz).repeat(4)
temp_z2_ = torch.randn(nz).repeat(4)
temp_z3_ = torch.randn(nz).repeat(4)
fixed_z_ = torch.cat([temp_z0_, temp_z1_, temp_z2_, temp_z3_], 0)
fixed_z_ = fixed_z_.view(-1, nz, 1, 1)

fixed_y_ = torch.tensor([0,1,2,3]).repeat(4)
fixed_y_label_ = onehot[fixed_y_]

# Plain tensors: `Variable(..., volatile=True)` has been a no-op since PyTorch 0.4
# and only emitted a warning. Snapshots should be generated under torch.no_grad().
fixed_noise, fixed_labels = fixed_z_.to(device), fixed_y_label_.to(device)

# Lists to keep track of progress
img_list = []
G_losses = []
D_losses = []
iters = 0

# Training loop

In [None]:
print("Starting Training Loop...")

# Resume from this epoch's checkpoint (0 = start from scratch)
checkpoint = 0
if checkpoint:
    load_checkpoint(checkpoint)

# Snapshot images are written here; plt.savefig does not create directories
os.makedirs("anim", exist_ok=True)

# Run exactly num_epochs epochs, numbered checkpoint+1 .. checkpoint+num_epochs
for epoch in range(checkpoint + 1, checkpoint + num_epochs + 1):
    for i, data in enumerate(dataloader, 0):
        # Real images go to the device. Real labels stay on the CPU: they only
        # index the label tables, and CPU indices work whichever device those
        # tables are on.
        x_ = data[0].to(device)
        y_real_ = data[1]

        b_size = x_.size(0)

        # Sign convention used below: the critic is trained to give LOW scores to
        # real images and HIGH scores to fakes (it minimises
        # mean D(real) - mean D(fake)); the generator minimises mean D(fake).

        ############################
        # (1) Update the critic D
        ############################
        netD.zero_grad()

        # Real batch: backward with +1 adds the gradient of mean D(real)
        y_fill_ = fill[y_real_].to(device)
        outputReal = netD(x_, y_fill_).view(-1)
        d_loss_real = outputReal.mean(0).view(1)
        d_loss_real.backward(one)

        # Fake batch: random latent vectors and random class labels
        z_ = torch.randn((b_size, nz)).view(-1, nz, 1, 1).to(device)
        # squeeze(1) rather than squeeze(): keeps the batch dimension when b_size == 1
        y_ = (torch.rand(b_size, 1) * num_classes).type(torch.LongTensor).squeeze(1)
        y_label_ = onehot[y_].to(device)
        y_fill_ = fill[y_].to(device)

        fake = netG(z_, y_label_)

        # Backward with -1 subtracts the gradient of mean D(fake); the detach keeps
        # this critic update from reaching the generator.
        outputFake = netD(fake.detach(), y_fill_).view(-1)
        d_loss_fake = outputFake.mean(0).view(1)
        d_loss_fake.backward(mone)

        # Reported critic loss: mean D(fake) - mean D(real), i.e. the negated
        # objective; it grows as the critic separates real from fake.
        errD = d_loss_fake - d_loss_real
        optimizerD.step()

        # WGAN weight clipping: crude enforcement of the critic's Lipschitz constraint
        for p in netD.parameters():
            p.data.clamp_(-clip_value, clip_value)

        ############################
        # (2) Update G every n_critic iterations: minimise mean D(G(z))
        ############################
        if i % n_critic == 0:
            netG.zero_grad()
            # Re-score the same fake batch with the just-updated (and clipped)
            # critic; `fake` is not detached here, so gradients flow into G.
            output = netD(fake, y_fill_).view(-1)
            errG = output.mean().view(1)
            errG.backward(one)
            optimizerG.step()

        # Once per epoch: print losses and save/show a snapshot of the fixed-noise grid
        if i == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f'% (epoch, num_epochs, i, len(dataloader), errD.item(), errG.item()))

            # No autograd graph is needed for the snapshot. G is left in train
            # mode, so dropout is active and BatchNorm uses this batch's statistics.
            with torch.no_grad():
                fake_visu = netG(fixed_noise, fixed_labels)
            grid = torchvision.utils.make_grid(fake_visu, normalize=True, nrow=4).cpu()
            plt.gca().get_xaxis().set_visible(False)
            plt.gca().get_yaxis().set_visible(False)
            # (C, H, W) -> (H, W, C) for imshow
            plt.imshow(grid.permute(1, 2, 0), cmap=plt.cm.binary)
            plt.savefig("anim/"+str(iters)+".png")
            plt.show()

        # Record losses every iteration. errG only changes on generator steps, so
        # between them G_losses repeats the latest generator loss.
        G_losses.append(errG.item())
        D_losses.append(errD.item())

        iters += 1

    # Save a network checkpoint every 50 epochs (epoch numbers start at 1)
    if epoch % 50 == 0:
        save_checkpoint(epoch)