## Generative adversarial network (GAN)

The following cells contain the code for building a generative adversarial network in PyTorch and training it on CIFAR-10. Running this in Google Colab is recommended, as training will take a very long time without a GPU.

In [None]:
# Setting up
!pip install -q torch torchvision altair matplotlib pandas
!git clone -q https://github.com/afspies/icl_dl_cw2_utils
from icl_dl_cw2_utils.utils.plotting import plot_tsne
%load_ext google.colab.data_table

In [None]:
# Mount google drive
# Colab-only step: exposes the user's Google Drive at /content/drive. The
# later cells write their images and saved models under
# /content/drive/MyDrive, so this must run before them.
from google.colab import drive
drive.mount('/content/drive') # Outputs will be saved in your google drive

In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.data import sampler
from torchvision import datasets, transforms
from torchvision.utils import save_image, make_grid
from torch.optim.lr_scheduler import StepLR, MultiStepLR
import torch.nn.functional as F
import matplotlib.pyplot as plt


In [None]:

def denorm(x, channels=None, w=None, h=None, resize=False):
    """Map a tensor from the normalised [-1, 1] range back to [0, 1].

    Applies ``0.5 * (x + 1)`` and clamps the result to [0, 1], so values
    outside [-1, 1] saturate instead of producing invalid pixel values.

    Args:
        x: Tensor to de-normalise.
        channels, w, h: Target trailing dimensions; only used (and then all
            required) when ``resize`` is True.
        resize: If True, reshape the result to ``(x.size(0), channels, w, h)``.

    Returns:
        The de-normalised (and optionally reshaped) tensor.

    Raises:
        ValueError: If ``resize`` is True but ``channels``, ``w`` or ``h`` is
            missing.
    """
    if resize and (channels is None or w is None or h is None):
        # Fail loudly up front: previously this only printed a message and
        # then crashed inside ``view`` with an unrelated TypeError.
        raise ValueError(
            'Number of channels, width and height must be provided for resize.'
        )

    x = (0.5 * (x + 1)).clamp(0, 1)
    if resize:
        # reshape (rather than view) also accepts non-contiguous inputs.
        x = x.reshape(x.size(0), channels, w, h)
    return x

def show(img):
    """Draw a 3-D image tensor on the current matplotlib axes.

    The tensor is moved to the CPU, converted to NumPy and its axes are
    reordered from (0, 1, 2) to (1, 2, 0) before being handed to ``imshow``
    (presumably channel-first -> channel-last; confirm against callers).
    """
    channels_last = img.cpu().numpy().transpose(1, 2, 0)
    plt.imshow(channels_last)

# Create the Google Drive folders this notebook uses.
# The training, sampling and plotting cells all save under .../MyDrive/GAN,
# which was never created here (only .../icl_dl_cw2/CW_GAN was); writing an
# image into a missing folder fails, so both folders are created up front.
# exist_ok=True makes re-running the cell harmless.
for drive_dir in ('/content/drive/MyDrive/icl_dl_cw2/CW_GAN',
                  '/content/drive/MyDrive/GAN'):
    os.makedirs(drive_dir, exist_ok=True)

GPU = True  # Set to False to force CPU execution even when CUDA is present
# CUDA is selected only when it is both requested and actually available.
device = torch.device("cuda" if GPU and torch.cuda.is_available() else "cpu")
print(f'Using {device}')

# Fix the random seed so that results are reproducible; on CUDA machines also
# ask cuDNN for deterministic algorithms.
torch.manual_seed(0)
if torch.cuda.is_available():
    torch.backends.cudnn.deterministic = True

In [None]:
# Load data
batch_size = 128  # change that

# ToTensor followed by Normalize(mean=0.5, std=0.5) computes (x - 0.5) / 0.5
# per channel; denorm() above applies the inverse 0.5 * (x + 1).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

data_dir = './datasets'

cifar10_train = datasets.CIFAR10(data_dir, train=True, download=True, transform=transform)
cifar10_test = datasets.CIFAR10(data_dir, train=False, download=True, transform=transform)
# Shuffle the training set so each epoch sees differently composed
# mini-batches (the DataLoader default is a fixed, sequential order).
# The test loader keeps its deterministic order for visualisation.
loader_train = DataLoader(cifar10_train, batch_size=batch_size, shuffle=True)
loader_test = DataLoader(cifar10_test, batch_size=batch_size)

In [None]:
# Visualise data
# Show the first batch of the test loader as an 8-column image grid.
samples, _ = next(iter(loader_test))

samples = samples.cpu()
# The former `range=None` argument is omitted: it was only the default, and
# newer torchvision releases renamed the keyword to `value_range`, so passing
# `range` raises a TypeError there.
samples = make_grid(denorm(samples), nrow=8, padding=2, normalize=False,
                    scale_each=False, pad_value=0)
plt.figure(figsize = (15,15))
plt.axis('off')
show(samples)

In [None]:
# Training hyper-parameters: number of epochs, optimizer learning rate,
# and the size of the Generator's input noise vector.

num_epochs = 20
learning_rate = 1e-4
latent_vector_size = 100

# Base channel width; the networks below use it at x1, x2 and x4.
num_feature_maps = 128


In [None]:
# Define GAN

class Generator(nn.Module):
    """Transposed-convolution generator.

    Three ConvTranspose2d -> BatchNorm2d -> ReLU stages followed by a final
    ConvTranspose2d -> Tanh that emits 3 channels. With a 1x1 spatial input
    the kernel/stride/padding choices grow the feature map 1 -> 4 -> 8 -> 16
    -> 32. Channel widths come from the module-level ``latent_vector_size``
    and ``num_feature_maps``.
    """

    def __init__(self):
        super().__init__()

        nf = num_feature_maps
        # (in_channels, out_channels, stride, padding) of each hidden stage
        hidden_stages = (
            (latent_vector_size, nf * 4, 1, 0),
            (nf * 4, nf * 2, 2, 1),
            (nf * 2, nf, 2, 1),
        )

        layers = []
        for c_in, c_out, stride, pad in hidden_stages:
            layers.append(nn.ConvTranspose2d(c_in, c_out, kernel_size=4,
                                             stride=stride, padding=pad, bias=False))
            layers.append(nn.BatchNorm2d(c_out))
            layers.append(nn.ReLU())
        # Output stage: 3 channels squashed to [-1, 1] by Tanh
        layers.append(nn.ConvTranspose2d(nf, 3, kernel_size=4, stride=2,
                                         padding=1, bias=False))
        layers.append(nn.Tanh())

        self.gen = nn.Sequential(*layers)

    def forward(self, z, label = None):
        """Run ``z`` through the network; ``label`` is accepted but unused."""
        return self.gen(z)


class Discriminator(nn.Module):
    """Convolutional discriminator.

    Three stride-2 Conv2d -> BatchNorm2d -> LeakyReLU(0.2) stages followed by
    a 4x4 Conv2d -> Sigmoid that outputs one channel. For a 32x32 input the
    spatial size shrinks 32 -> 16 -> 8 -> 4 -> 1. Channel widths come from
    the module-level ``num_feature_maps``.
    """

    def __init__(self):
        super().__init__()

        nf = num_feature_maps
        # (in_channels, out_channels) of each down-sampling stage
        channel_plan = ((3, nf), (nf, nf * 2), (nf * 2, nf * 4))

        layers = []
        for c_in, c_out in channel_plan:
            layers.append(nn.Conv2d(c_in, c_out, kernel_size=4, stride=2,
                                    padding=1, bias=False))
            layers.append(nn.BatchNorm2d(c_out))
            layers.append(nn.LeakyReLU(0.2))
        # Final stage collapses to a single Sigmoid-activated channel
        layers.append(nn.Conv2d(nf * 4, 1, kernel_size=4, stride=1,
                                padding=0, bias=False))
        layers.append(nn.Sigmoid())

        self.dis = nn.Sequential(*layers)

    def forward(self, x, label = None):
        """Run ``x`` through the network; ``label`` is accepted but unused."""
        return self.dis(x)


In [None]:
# custom weights initialization called on netG and netD
def weights_init(m):
    """Re-initialise one module in place; meant to be passed to ``Module.apply``.

    Modules whose class name contains 'Conv' get weights drawn from
    N(0, 0.02). Modules whose class name contains 'BatchNorm' get weights
    drawn from N(1, 0.02) and a zero bias. Anything else is left untouched.
    """
    layer_kind = type(m).__name__
    if 'Conv' in layer_kind:
        nn.init.normal_(m.weight, 0.0, 0.02)
        return
    if 'BatchNorm' in layer_kind:
        nn.init.normal_(m.weight, 1.0, 0.02)
        nn.init.constant_(m.bias, 0)

In [None]:
use_weights_init = True


def _count_trainable(model):
    """Return the number of scalar parameters in `model` that require grad."""
    return sum(t.numel() for t in model.parameters() if t.requires_grad)


# Generator: build on the chosen device, optionally re-initialise, report size.
model_G = Generator().to(device)
if use_weights_init:
    model_G.apply(weights_init)
params_G = _count_trainable(model_G)
print("Total number of parameters in Generator is: {}".format(params_G))
print(model_G)
print('\n')

# Discriminator: same procedure.
model_D = Discriminator().to(device)
if use_weights_init:
    model_D.apply(weights_init)
params_D = _count_trainable(model_D)
print("Total number of parameters in Discriminator is: {}".format(params_D))
print(model_D)
print('\n')

print("Total number of parameters is: {}".format(params_G + params_D))

In [None]:
def loss_function(out, label):
    """Return the binary cross-entropy between predictions and targets.

    Thin wrapper around ``F.binary_cross_entropy`` with its default
    arguments; ``out`` holds the predicted probabilities and ``label`` the
    targets.
    """
    return F.binary_cross_entropy(out, label)

In [None]:
# setup optimizer
# Both networks get their own Adam optimizer with identical settings.
beta1 = 0.5
adam_settings = dict(lr=learning_rate, betas=(beta1, 0.999))
optimizerD = torch.optim.Adam(model_D.parameters(), **adam_settings)
optimizerG = torch.optim.Adam(model_G.parameters(), **adam_settings)

In [None]:
# Target values used in the BCE loss: 1 for real images, 0 for generated ones.
real_label, fake_label = 1, 0

In [None]:
# Train the GAN. Each mini-batch performs one discriminator update followed by
# one generator update; the mean loss of every epoch is recorded for plotting.

# record the losses
train_losses_D = []
train_losses_D_real = []
train_losses_D_fake = []
train_losses_G = []

# Folder that receives the images written during training. It is created here
# because writing an image into a missing folder fails.
output_dir = '/content/drive/MyDrive/GAN'
os.makedirs(output_dir, exist_ok=True)

# Latent vectors sampled ONCE and reused after every epoch, so the saved
# sample grids show the same latent points and are comparable across epochs.
# (Previously this tensor was re-sampled on every iteration.)
fixed_noise = torch.randn(batch_size, latent_vector_size, 1, 1, device=device)


for epoch in range(num_epochs):

    # running sums over the epoch, averaged after the inner loop
    train_loss_D = 0
    train_loss_D_real = 0
    train_loss_D_fake = 0
    train_loss_G = 0

    for i, data in enumerate(loader_train, 0):

        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################

        # train with real
        model_D.zero_grad()

        real_cpu = data[0].to(device = device)
        # create the labels (the last batch may be smaller than batch_size)
        batch = real_cpu.shape[0]
        label = torch.full((batch,), real_label, dtype = torch.float, device = device)
        out = model_D(real_cpu).view(-1)
        # calculate loss and propagate back
        real_error = loss_function(out, label)
        real_error.backward()

        D_x = out.mean().item()


        # train with fake
        noise = torch.randn(batch, latent_vector_size, 1, 1, device=device)
        # generate fake samples
        fake = model_G(noise)
        label.fill_(fake_label)
        # detach() keeps this backward pass from reaching the generator
        out = model_D(fake.detach()).view(-1)

        fake_error = loss_function(out, label)
        fake_error.backward()

        D_G_z1 = out.mean().item()

        errD = real_error + fake_error

        # gradients of the real and fake passes have accumulated in D
        optimizerD.step()
        # keep track of the losses for plotting
        train_loss_D += errD.item()
        train_loss_D_real += real_error.item()
        train_loss_D_fake += fake_error.item()


        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        model_G.zero_grad()
        # generate labels to be real
        label.fill_(real_label)
        # re-score the same fakes with the just-updated discriminator
        out = model_D(fake).view(-1)

        # calculate loss and propagate back
        errG = loss_function(out, label)
        errG.backward()

        D_G_z2 = out.mean().item()
        # update optimiser
        optimizerG.step()
        # keep track of the loss for plotting
        train_loss_G += errG.item()


        print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f'
              % (epoch, num_epochs, i, len(loader_train),
                 errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

    # Once, after the first epoch: save the last real batch for reference.
    if epoch == 0:
        save_image(denorm(real_cpu.cpu()).float(), os.path.join(output_dir, 'real_samples.png'))
    # After every epoch: save the generator's output for the fixed latents.
    with torch.no_grad():
        fake = model_G(fixed_noise)
        save_image(denorm(fake.cpu()).float(), os.path.join(output_dir, 'fake_samples_epoch_%03d.png' % epoch))
    train_losses_D.append(train_loss_D / len(loader_train))
    train_losses_D_real.append(train_loss_D_real / len(loader_train))
    train_losses_D_fake.append(train_loss_D_fake / len(loader_train))
    train_losses_G.append(train_loss_G / len(loader_train))


   
    
# save  models
# Each network is traced with an example input and the traced module is then
# serialised with torch.jit.save.
# if your discriminator/generator are conditional you'll want to change the inputs here
traced_G = torch.jit.trace(model_G, fixed_noise)
torch.jit.save(traced_G, '/content/drive/MyDrive/GAN/GAN_G_model.pth')
traced_D = torch.jit.trace(model_D, fake)
torch.jit.save(traced_D, '/content/drive/MyDrive/GAN/GAN_D_model.pth')


In [None]:
# Generator samples
# Draw 100 fresh latent vectors, render the generated images as a 10x10 grid,
# then show 64 real test images for comparison.

input_noise = torch.randn(100, latent_vector_size, 1, 1, device=device)
with torch.no_grad():
    # visualize the generated images
    generated = model_G(input_noise).cpu()
    # The former `range=None` argument is omitted: it was only the default,
    # and newer torchvision releases renamed the keyword to `value_range`,
    # so passing `range` raises a TypeError there.
    generated = make_grid(denorm(generated)[:100], nrow=10, padding=2, normalize=False,
                          scale_each=False, pad_value=0)
    plt.figure(figsize=(15,15))
    save_image(generated,'/content/drive/MyDrive/GAN/final.png')
    # The samples are unconditional (Generator.forward ignores `label`), so
    # the grid columns do not correspond to classes.
    show(generated)

it = iter(loader_test)
sample_inputs, _ = next(it)
fixed_input = sample_inputs[0:64, :, :, :]
# visualize the first 64 images of the first test batch for comparison
img = make_grid(denorm(fixed_input), nrow=8, padding=2, normalize=False,
                scale_each=False, pad_value=0)
plt.figure(figsize=(15,15))
show(img)

In [None]:
# Plotting loss
# One point per recorded epoch average; both panels use a log-scaled y axis.

iterations = list(range(len(train_losses_D)))

fig = plt.figure(figsize=(8,8))

# Left panel: total discriminator loss and its real / fake components.
ax1 = fig.add_subplot(1,2,1)
discriminator_curves = (
    ("d_loss", train_losses_D),
    ("d_loss_real", train_losses_D_real),
    ("d_loss_fake", train_losses_D_fake),
)
for curve_name, curve_values in discriminator_curves:
    ax1.plot(iterations, curve_values, label=curve_name)
ax1.set_yscale('log')
ax1.set_title("discriminator")
ax1.legend(loc='best')

# Right panel: generator loss.
ax2 = fig.add_subplot(1,2,2)
ax2.plot(iterations, train_losses_G, label="g_loss")
ax2.set_yscale('log')
ax2.set_title("generator")
ax2.legend(loc='best')

plt.savefig("/content/drive/MyDrive/GAN/loss.png")