# Generating Hand Written Digits With a GAN in PyTorch from MNIST

A Generative Adversarial Network (GAN) is a pair of networks: the generator network and the discriminator network.

The goal is for the generator to learn the data distribution of a dataset. The discriminator learns to tell samples from the dataset apart from samples produced by the generator, and its classification loss is backpropagated through the discriminator into the generator, so that the distribution of the generator's output converges to the dataset distribution.

In this notebook, we will make the discriminator classify between each of the 10 digits, plus whether the image is generated from the generator.

This allows us to have control over the generator to generate a sample from a specified class. This type of GAN is called a Conditional GAN.

In [1]:
import torch
import torch.optim as optim
from torch import nn
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
import torchvision.utils as utils
import matplotlib.pyplot as plt
from IPython import display
import numpy as np
import os

# Defining Our Discriminator

Since many people have tried to make good discriminators over the years and made the architecture available for everyone, we will be using one for our discriminator. We will be using the Resnet 18 network.

To make it work for the MNIST dataset, we redefine the input to accept images with a single channel, and output 11 classes. 1 class for each of the 10 classes in the MNIST dataset, and 1 class for whether the image is from the generator.

In [2]:
def discriminator():
    """Build the discriminator: a ResNet-18 adapted to single-channel images.

    The first convolution is swapped for one that reads 1 input channel,
    the average-pooling stage is replaced by a plain flatten, and the final
    fully-connected layer is resized to emit 11 logits (the 10 digit
    classes plus one class for generated images).

    Returns:
        The modified ResNet-18 module, freshly initialised.
    """
    from torchvision.models import resnet18

    net = resnet18()
    # 7x7, stride-2, padding-3, bias-free stem that takes 1 channel.
    net.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
    # No pooling: the feature map is flattened directly into the classifier.
    net.avgpool = nn.Flatten()
    net.fc = nn.Linear(512, 11)
    return net

# Defining Our Generator

The generator will take noise and an encoding of which digit to generate as a concatenated vector, and output an image with the size 28 x 28 x 1.

In [3]:
class ReZeroShortcut(nn.Module):
    """Residual connection whose branch is scaled by a learnable scalar.

    Computes ``shortcut + alpha * x`` where ``alpha`` is a single trainable
    parameter.

    Args:
        alpha: Initial value of the gate. With the default of 0.0 the
            module initially returns ``shortcut`` unchanged.
    """

    def __init__(self, alpha=0.0):
        super().__init__()
        # One-element tensor so the gate broadcasts over any input shape.
        self.alpha = nn.Parameter(alpha * torch.ones(1))

    def forward(self, shortcut, x):
        """Return ``shortcut`` plus the gated residual branch ``x``."""
        gated = self.alpha * x
        return shortcut + gated

class ReZeroBlock(nn.Module):
    """Gated residual block followed by a channel-changing convolution.

    The input passes through two 3x3 convolutions (each followed by ELU)
    that keep ``cin`` channels; that branch is added back to the input
    through a ``ReZeroShortcut``. A final 3x3 convolution maps ``cin`` to
    ``cout`` channels. All convolutions use padding 1, so the spatial size
    is preserved.

    Args:
        cin: Number of input channels.
        cout: Number of output channels.
    """

    def __init__(self, cin, cout):
        super().__init__()
        self.rezero = ReZeroShortcut()
        self.conv1 = nn.Conv2d(cin, cin, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(cin, cin, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(cin, cout, kernel_size=3, padding=1)
        self.elu = nn.ELU(alpha=1)

    def forward(self, x):
        """Apply the gated residual branch, then project to ``cout`` channels."""
        branch = self.elu(self.conv1(x))
        branch = self.elu(self.conv2(branch))
        # Gate the branch onto the untouched input before the projection.
        merged = self.rezero(x, branch)
        return self.conv3(merged)

class Generator(nn.Module):
    """Conditional generator mapping (noise + class encoding) to an image.

    Three unpadded 3x3 transposed convolutions grow the spatial size by 2
    each (1x1 -> 3x3 -> 5x5 -> 7x7 for a 1x1 input). Three ``ReZeroBlock``
    stages then reduce the channels 128 -> 64 -> 32 -> 1, with the feature
    map resized to 14x14 and 28x28 in between, giving a single-channel
    28x28 output.

    Args:
        classes: Number of channels used by the class encoding.
        input_noise_dim: Number of noise channels.
    """

    def __init__(self, classes, input_noise_dim):
        super().__init__()
        latent_channels = input_noise_dim + classes

        self.conv1 = nn.ConvTranspose2d(latent_channels, 512, kernel_size=3)
        self.conv2 = nn.ConvTranspose2d(512, 256, kernel_size=3)
        self.conv3 = nn.ConvTranspose2d(256, 128, kernel_size=3)

        self.rezero1 = ReZeroBlock(128, 64)
        self.rezero2 = ReZeroBlock(64, 32)
        self.rezero3 = ReZeroBlock(32, 1)

        self.elu = nn.ELU(alpha=1)
        self.upsample1 = nn.Upsample(size=14)
        self.upsample2 = nn.Upsample(size=28)
        self.batch1 = nn.BatchNorm2d(512)
        self.batch2 = nn.BatchNorm2d(256)

    def forward(self, x):
        """Generate images from the concatenated noise/label tensor ``x``."""
        # First two transposed convolutions: conv -> batch norm -> ELU.
        stem = ((self.conv1, self.batch1), (self.conv2, self.batch2))
        for deconv, norm in stem:
            x = self.elu(norm(deconv(x)))

        # The third transposed convolution has no norm or activation.
        x = self.conv3(x)

        x = self.upsample1(self.rezero1(x))
        x = self.upsample2(self.rezero2(x))
        return self.rezero3(x)

# Defining Our Utility Functions

In [4]:
def one_hot_encode(label, device, n_class=10):
    """Convert class indices into one-hot encodings.

    Args:
        label: A class index, or a tensor of class indices.
        device: Device on which the encoding is created.
        n_class: Number of classes, i.e. the length of each encoding.

    Returns:
        A tensor of shape ``(-1, n_class, 1, 1)`` with one encoding per
        index in ``label``.
    """
    # Row k of the identity matrix is the one-hot vector for class k.
    identity = torch.eye(n_class, device=device)
    rows = identity[label]
    return rows.view(-1, n_class, 1, 1)

def concat_noise_label(noise, label, device):
    """Append the one-hot encoding of ``label`` to ``noise``.

    Args:
        noise: Noise tensor; it is joined with the encoding along dim 1.
        label: Class indices, encoded with ``one_hot_encode``.
        device: Device on which the encoding is created.

    Returns:
        ``noise`` and the encoding concatenated along dimension 1.
    """
    encoded = one_hot_encode(label, device)
    return torch.cat((noise, encoded), dim=1)

def make_noise(return_label=False, sequence=None):
    """Build generator input: random noise joined with a one-hot label.

    Reads the module-level ``batch_size``, ``input_noise_dim`` and
    ``device``.

    Args:
        return_label: When true, also return the label tensor.
        sequence: Optional non-empty sequence of digits (for example the
            string ``'123'``). One sample is created per element, labelled
            with ``int(element)``. When omitted or empty, ``batch_size``
            samples with random labels in ``[0, 10)`` are created.

    Returns:
        The concatenated noise/label tensor, or a ``(tensor, label)`` tuple
        when ``return_label`` is true.
    """
    if not sequence:
        # Random batch: the labels are drawn after the noise.
        noise = torch.randn(batch_size, input_noise_dim, 1, 1, device=device)
        label = torch.randint(10, (batch_size,), dtype=torch.long, device=device)
    else:
        # Caller-chosen digits: one noise vector per requested digit.
        noise = torch.randn(len(sequence), input_noise_dim, 1, 1, device=device)
        label = torch.tensor([int(digit) for digit in sequence], device=device)

    conditioned = concat_noise_label(noise, label, device)
    return (conditioned, label) if return_label else conditioned

def save_image(name):
    """Save the generator's output for the fixed noise batch as an image grid.

    The module-level ``fixed_noise`` is fed on every call, so successive
    images differ only through the generator's training.

    Args:
        name: File stem; the grid is written to ``output/{name}.png`` with
            10 images per row.
    """
    # Visualisation only: run the forward pass without autograd bookkeeping
    # instead of building a graph and detaching the result afterwards.
    with torch.no_grad():
        images = generator(fixed_noise)
    # NOTE(review): the grid of real images is saved with normalize=True but
    # this one is not -- confirm whether the difference is intended.
    utils.save_image(images, f'output/{name}.png', nrow=10)

def save_sequence(sequence):
    """Generate the given digits and save them side by side.

    Args:
        sequence: Non-empty sequence of digits (for example the string
            ``'987654321'``). One image is generated per element and all of
            them are laid out in a single row in ``output/sequence.png``.
    """
    # Visualisation only: no gradients are needed, so skip building the
    # autograd graph for this forward pass.
    with torch.no_grad():
        images = generator(make_noise(sequence=sequence))
    utils.save_image(images, 'output/sequence.png', nrow=len(sequence))

def plot(*data):
    """Redraw the loss curves in the notebook output.

    Args:
        *data: Sequences to draw, one line each. The legend labels them, in
            order, 'Real', 'Fake' and 'Generator'.
    """
    # Start from an empty figure so the curves are not drawn twice.
    plt.clf()

    y_axis = plt.gca().yaxis
    y_axis.tick_right()
    y_axis.set_ticks_position('both')
    y_axis.grid(True)

    for series in data:
        plt.plot(series)

    plt.legend(['Real', 'Fake', 'Generator'], loc='lower left')

    # Replace the previous cell output with the refreshed figure.
    display.clear_output(wait=True)
    display.display(plt.gcf())

# Set up our dataset
We download the MNIST dataset from PyTorch, and define how we load the dataset.

We will batch 100 images at a time, transform the image into a tensor and normalize with the mean and the standard deviation of 0.5.

In [5]:
# Sizes shared by the data pipeline and the noise helpers.
batch_size = 100
input_noise_dim = 128

# Convert each image to a tensor, then normalize its single channel with
# mean 0.5 and standard deviation 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# MNIST training split, downloaded on first use.
dataset = torchvision.datasets.MNIST(
    root='',
    train=True,
    download=True,
    transform=transform,
)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)


# Set up both our models and the optimizers for the generator and the discriminator

We initialize the generator and the discriminator network, and define the optimizer. 

Notice that when we train the generator, we only want the generator's parameters to update, even though the gradient flows through the discriminator.

We can handle that by defining 2 optimizers, one for updating the generator, and one for updating the discriminator.

In [6]:
# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

generator = Generator(10, input_noise_dim=input_noise_dim).to(device)
# NOTE(review): this rebinds `discriminator` from the factory function to the
# model instance, so running this cell a second time would call the model
# rather than the factory -- re-run the cell defining the function first.
discriminator = discriminator().to(device)
criterion = nn.CrossEntropyLoss()

lr = 0.0001
beta1 = 0.5
# One optimizer per network, so each update step touches only that
# network's parameters.
optimizerG = optim.Adam(
    generator.parameters(),
    lr=lr,
    betas=(beta1, 0.999),
    weight_decay=1e-5,
)
optimizerD = optim.Adam(
    discriminator.parameters(),
    lr=lr,
    betas=(beta1, 0.999),
    weight_decay=1e-5,
)

# Final Setup Before Training

We define a fixed noise to use during the training, to see how well our generator is doing. 

Since the noise is unchanged, the only difference we see between images is the effect the training had on the generator.

In [7]:
# 100 fixed samples whose labels cycle 0..9 ten times; saved 10 per row, so
# every row of the grid shows the digits 0 to 9.
fixed_noise = concat_noise_label(
    torch.randn(100, input_noise_dim, 1, 1, device=device),
    torch.arange(10, dtype=torch.long, device=device).repeat(10),
    device,
)

# Target for generated images: class index 10, once per sample in a batch.
fake_label = torch.tensor([10] * batch_size, device=device)

os.makedirs('output', exist_ok=True)
# Reference grid built from the first 100 images of one real batch.
utils.save_image(
    next(iter(dataloader))[0][:100],
    'output/reals.png',
    nrow=10,
    normalize=True,
)

# Train Our Networks One After the Other

To make our GAN work together, we need to make the networks learn one after the other. 

To train the discriminator, we first load images from our dataset and classify the image classes, the same way a classifier is trained. Then we generate samples from our generator, and train the discriminator to classify them as the fake class, so that it can distinguish between our dataset and the generator's distribution.

We then combine the losses and back propagate so the loss is minimized.

To train the generator, we sample from the generator and feed it to the discriminator, then train only the generator so that the discriminator classifies it as the class that we want.

We continue this until the generator distribution matches the dataset distribution, and the discriminator can no longer tell the two apart.

After every training step, the generator's output for the fixed noise is saved in the output directory.

In [None]:
# Per-step loss history: discriminator loss on real images, discriminator
# loss on generated images, and generator loss.
reals,fakes,gens=[],[],[]
for epoch in range(10):
    for i, (image,label) in enumerate(dataloader):
        # ---- Discriminator update ----
        # Also clears the gradients that the generator updates of the
        # previous iteration accumulated on the discriminator.
        discriminator.zero_grad()

        image=image.to(device)
        label=label.to(device)

        # Real images are scored against their digit labels.
        output1 = discriminator(image)
        real_loss = criterion(output1, label)

        # Generated images are scored against the fake class (index 10).
        # detach() keeps this backward pass out of the generator.
        noise = make_noise()
        fake_image = generator(noise)
        output = discriminator(fake_image.detach())
        fake_loss = criterion(output, fake_label)

        (real_loss + fake_loss).backward()
        optimizerD.step()

        # ---- Generator update: 5 steps per discriminator step ----
        for j in range(5):
          generator.zero_grad()
          # `label` is rebound here to the labels the generator was asked for.
          noise, label = make_noise(return_label=True)
          fake_image = generator(noise)
          # No detach: the loss is backpropagated through the discriminator
          # into the generator, but only optimizerG takes a step.
          output = discriminator(fake_image)
          gen_loss = criterion(output, label)

          gen_loss.backward()
          optimizerG.step()


        # Only the generator loss of the last of the 5 steps is recorded.
        reals.append(real_loss.item())
        fakes.append(fake_loss.item())
        gens.append(gen_loss.item())
        plot(reals,fakes,gens)
        
        # plot() clears the cell output, so the whole history is printed
        # again after every step.
        # NOTE(review): this loop rebinds `i`; after it, `i` is
        # len(reals) - 1 (the running step count across epochs), not the
        # batch index of the enclosing loop.
        for i,(real,fake,gen) in enumerate(zip(reals,fakes,gens)):
            print(i,real,fake,gen)

        # Saved under the running step count left in `i` by the loop above,
        # so files from earlier epochs are not overwritten.
        save_image(str(i))
        

Check the generator output by giving it digits, and confirm it's able to generate the digits you give.

In [None]:
# Generate the digits 9 down to 1 and write them, side by side in a single
# row, to output/sequence.png.
save_sequence('987654321')