#Generative Adversarial Networks (GANs)

In this notebook, you will train a simple GAN on a subset of the CelebA dataset.

This notebook is based upon PyTorch's DCGAN Tutorial, that can be found [here](https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html).

In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
from PIL import Image

## Data Preparation

In [2]:
# Load the zip file available in Moodle and unzip it
!tar -xvf data/subset_celeba.zip


x subset_celeba/
x subset_celeba/019359.jpg
x subset_celeba/174823.jpg
x subset_celeba/082179.jpg
x subset_celeba/191111.jpg
x subset_celeba/032862.jpg
x subset_celeba/141126.jpg
x subset_celeba/044738.jpg
x subset_celeba/193074.jpg
x subset_celeba/098653.jpg
x subset_celeba/026803.jpg
x subset_celeba/060856.jpg
x subset_celeba/061590.jpg
x subset_celeba/010369.jpg
x subset_celeba/152628.jpg
x subset_celeba/154259.jpg
x subset_celeba/105016.jpg
x subset_celeba/038383.jpg
x subset_celeba/161584.jpg
x subset_celeba/089038.jpg
x subset_celeba/053522.jpg
x subset_celeba/132692.jpg
x subset_celeba/099203.jpg
x subset_celeba/002023.jpg
x subset_celeba/109604.jpg
x subset_celeba/013890.jpg
x subset_celeba/116042.jpg
x subset_celeba/023953.jpg
x subset_celeba/192342.jpg
x subset_celeba/156466.jpg
x subset_celeba/046107.jpg
x subset_celeba/092430.jpg
x subset_celeba/019403.jpg
x subset_celeba/082637.jpg
x subset_celeba/040562.jpg
x subset_celeba/116056.jpg
x subset_celeba/117348.jpg
x subset_ce

In [2]:
class CelebADataset(Dataset):
  def __init__(self, root_dir, transform=None):
    self.root_dir = root_dir
    self.transform = transform 
    self.image_names = os.listdir(self.root_dir)

  def __len__(self): 
    return len(self.image_names)

  def __getitem__(self, idx):
    img_path = os.path.join(self.root_dir, self.image_names[idx])
    img = Image.open(img_path).convert('RGB')
    if self.transform:
      img = self.transform(img)

    return img

In [5]:
workers = 1
batch_size = 128
image_size = 64
data_path = "./subset_celeba_2/"

transform = transforms.Compose([
                               transforms.Resize(image_size),
                               transforms.CenterCrop(image_size),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                           ])
dataset = CelebADataset(data_path, transform=transform)
print("Total number of images:", len(dataset))

dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=workers)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Plot some training images
real_batch = next(iter(dataloader))
plt.figure(figsize=(8,8))
plt.axis("off")
plt.title("Training Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch.to(device)[:64], padding=2, normalize=True).cpu(),(1,2,0)))

Total number of images: 372


RuntimeError: DataLoader worker (pid(s) 3204) exited unexpectedly

## Model Definition

In [None]:
# Custom weights initialization called on netG and netD
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

In [None]:
class Generator(nn.Module):
    def __init__(self, nz=100, ngf=64, nc=3):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. (nc) x 64 x 64
        )

    def forward(self, input):
        return self.main(input)

In [None]:
class Discriminator(nn.Module):
    def __init__(self, ndf=64, nc=3):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            # input is (nc) x 64 x 64
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        return self.main(input)

In [None]:
# Create the generator
netG = Generator().to(device)

# Apply the weights_init function to randomly initialize all weights
#  to mean=0, stdev=0.02.
netG.apply(weights_init)

print(netG)

In [None]:
# Create the discriminator
netD = Discriminator().to(device)

# Apply the weights_init function to randomly initialize all weights
#  to mean=0, stdev=0.2.
netD.apply(weights_init)

print(netD)

## Model Training

The first step is to train the discriminator. Remember that, for each training epoch, half of a batch should contain real images and the other half synthetic images. The synthetic images are created by the generator. Then, train the generator.


In [None]:
lr = 0.0002
num_epochs = 50

criterion = nn.BCELoss()

# Size of z latent vector (i.e. size of generator input)
nz = 100

# Create batch of latent vectors that we will use to visualize
# the progression of the generator
fixed_noise = torch.randn(64, nz, 1, 1, device=device)

real_label = 1.
fake_label = 0.

optimizerD = optim.Adam(netD.parameters(), lr=lr)
optimizerG = optim.Adam(netG.parameters(), lr=lr)

In [None]:
# Training Loop

img_list = []
G_losses = []
D_losses = []
iters = 0

print("Starting Training Loop...")
# For each epoch
for epoch in range(num_epochs):
    # For each batch in the dataloader
    for i, data in enumerate(dataloader, 0):
        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        ## Train with all-real batch
        ## TODO

        ## Train with all-fake batch
        ## TODO

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        ## TODO

        # Output training stats at every 50 iterations
        if i % 50 == 0:
            

        # Save Losses for plotting later
        G_losses.append(#TODO)
        D_losses.append(#TODO)

        # Check how the generator is doing by saving G's output on fixed_noise
        if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
            with torch.no_grad():
                fake = netG(fixed_noise).detach().cpu()
            img_list.append(vutils.make_grid(fake, padding=2, normalize=True))

        iters += 1

## Training Evolution

In [None]:
# Plot the evolution of the generator and discriminator losses
plt.figure(figsize=(10,5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_losses,label="G")
plt.plot(D_losses,label="D")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()

In [None]:
# Plot the evolution of the generated images on fixed noise for fair comparison
#%%capture
import matplotlib.animation as animation
from IPython.display import HTML

fig = plt.figure(figsize=(8,8))
plt.axis("off")
ims = [[plt.imshow(np.transpose(i,(1,2,0)), animated=True)] for i in img_list]
ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)

HTML(ani.to_jshtml())

## Challenge

Train the GAN on the entire CelebA dataset.