# Build a Conditional GAN

### Goals
In this notebook, you're going to build a conditional GAN that generates images of handwritten digits, conditioned on the digit to be generated (the class vector). This will let you choose which digit you want to generate.

You'll then do some exploration of the generated images to visualize what the noise and class vectors mean.  

### Learning Objectives
1.   Learn the technical difference between a conditional and unconditional GAN.
2.   Understand the distinction between the class and noise vector in a conditional GAN.



In [1]:
import torch
from torch import nn
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
torch.manual_seed(0) # Fixed PyTorch RNG seed so results are reproducible (set for our testing purposes) -- please do not change!


  from .autonotebook import tqdm as notebook_tqdm


<torch._C.Generator at 0x7f9280b9f930>

In [2]:
def show_tensor_images(image_tensor, num_images=25, size=(1, 28, 28), nrow=5, show=True):
    """Draw the first ``num_images`` images of a batch as one grid.

    Args:
        image_tensor: batch of images; values are shifted from [-1, 1] to
            [0, 1] before plotting. Assumed to be laid out as
            (batch, channels, height, width), as ``make_grid`` expects.
        num_images: maximum number of images taken from the start of the batch.
        size: unused; kept only so existing callers that pass it keep working.
        nrow: number of images per grid row (forwarded to ``make_grid``).
        show: when True (default) the figure is displayed immediately with
            ``plt.show()``; pass False to keep drawing on the current figure
            (e.g. to add a title or save it) before showing it yourself.
    """
    # Undo the [-1, 1] normalisation so matplotlib receives values in [0, 1].
    image_tensor = (image_tensor + 1) / 2
    image_unflat = image_tensor.detach().cpu()
    image_grid = make_grid(image_unflat[:num_images], nrow=nrow)
    # make_grid returns (C, H, W); matplotlib wants (H, W, C).
    plt.imshow(image_grid.permute(1, 2, 0).squeeze())
    # Previously the figure was shown unconditionally and `show` was ignored.
    if show:
        plt.show()

In [3]:
# Generator class
class Generator(nn.Module):
    """Transposed-convolution generator that maps an input vector to an image.

    Args:
        input_dim: length of the input vector fed to the generator (for a
            conditional GAN this is the noise length plus the class-vector
            length).
        im_chan: number of channels of the generated image.
        hidden_dim: base channel width of the hidden layers.
    """

    def __init__(self, input_dim=10, im_chan = 1, hidden_dim=64):
        super(Generator, self).__init__()
        self.input_dim = input_dim
        # Strides are passed explicitly so the spatial size grows
        # 1 -> 3 -> 6 -> 13 -> 28, i.e. the output is a 28x28 image.
        # (With stride 1 everywhere the output would only be 11x11.)
        self.gen = nn.Sequential(
            self.make_gen_block(input_dim, hidden_dim * 4, stride=2),
            self.make_gen_block(hidden_dim * 4, hidden_dim * 2, kernel_size=4, stride=1),
            self.make_gen_block(hidden_dim * 2, hidden_dim, stride=2),
            self.make_gen_block(hidden_dim, im_chan, kernel_size=4, stride=2, final_layer=True),
        )

    def make_gen_block(self, input_channels, output_channels, kernel_size = 3, stride = 1, final_layer=False):
        """Build one generator block.

        Hidden blocks are ConvTranspose2d -> BatchNorm2d -> ReLU. The final
        block is ConvTranspose2d -> Tanh so that the image values lie in
        [-1, 1].

        Args:
            input_channels: channels of the incoming feature map.
            output_channels: channels produced by the block.
            kernel_size: transposed-convolution kernel size.
            stride: transposed-convolution stride.
            final_layer: True for the output block, False for hidden blocks.

        Returns:
            An ``nn.Sequential`` holding the block's layers.
        """
        if not final_layer:
            return nn.Sequential(
                nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride),
                nn.BatchNorm2d(output_channels),
                nn.ReLU(inplace=True),
            )
        # Output block: no normalisation, Tanh squashes values into [-1, 1].
        return nn.Sequential(
            nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride),
            nn.Tanh(),
        )

    def unsqueeze_noise(self, noise):
        """Reshape a (n_samples, input_dim) tensor to (n_samples, input_dim, 1, 1)."""
        return noise.view(len(noise), self.input_dim, 1, 1)

    def forward(self, noise):
        """Generate images from a (n_samples, input_dim) input tensor."""
        x = self.unsqueeze_noise(noise)
        return self.gen(x)

def get_noise(n_sample, input_dim, device='cpu'):
    """Sample standard-normal noise vectors.

    Args:
        n_sample: number of vectors to draw.
        input_dim: length of each vector.
        device: device on which the tensor is created.

    Returns:
        A tensor of shape (n_sample, input_dim).
    """
    noise_shape = (n_sample, input_dim)
    return torch.randn(noise_shape, device=device)


In [4]:
# Discriminator class
class Discriminator(nn.Module):
    """Convolutional discriminator that scores an image with a raw logit.

    Args:
        im_chan: number of channels of the input image (for a conditional
            GAN this includes the extra class channels).
        hidden_dim: base channel width of the hidden layers.
    """

    def __init__(self, im_chan = 1, hidden_dim=16):
        super(Discriminator, self).__init__()
        # Strides are passed explicitly so a 28x28 input shrinks
        # 28 -> 13 -> 5 -> 1, leaving a single logit per image.
        self.disc = nn.Sequential(
            self.make_disc_block(im_chan, hidden_dim, stride=2),
            self.make_disc_block(hidden_dim, hidden_dim * 2, stride=2),
            self.make_disc_block(hidden_dim * 2, 1, stride=2, final_layer=True),
        )

    def make_disc_block(self, input_channels , output_channels, kernel_size=4, stride=1, final_layer=False):
        """Build one discriminator block.

        Hidden blocks are Conv2d -> BatchNorm2d -> LeakyReLU(0.2). The final
        block is a bare Conv2d so that the output stays an unnormalised
        logit, which is what a logits-based loss such as
        ``nn.BCEWithLogitsLoss`` expects.

        Args:
            input_channels: channels of the incoming feature map.
            output_channels: channels produced by the block.
            kernel_size: convolution kernel size.
            stride: convolution stride.
            final_layer: True for the output block, False for hidden blocks.

        Returns:
            An ``nn.Sequential`` holding the block's layers.
        """
        if not final_layer:
            return nn.Sequential(
                nn.Conv2d(input_channels, output_channels, kernel_size, stride),
                nn.BatchNorm2d(output_channels),
                nn.LeakyReLU(0.2, inplace=True),
            )
        # Output block: no normalisation or activation on the logits.
        return nn.Sequential(
            nn.Conv2d(input_channels, output_channels, kernel_size, stride),
        )

    def forward(self, image):
        """Return the predictions flattened to (n_samples, -1)."""
        disc_pred = self.disc(image)
        return disc_pred.view(len(disc_pred), -1)

In [8]:
# conditioned vector creation

import torch.nn.functional as F

def get_one_hot_labels(labels, n_classes):
    """One-hot encode a tensor of integer class labels.

    Args:
        labels: tensor of class indices (must be an integer/long tensor, as
            required by ``F.one_hot``).
        n_classes: total number of classes, i.e. the length of each one-hot
            vector.

    Returns:
        A tensor with one extra trailing dimension of size ``n_classes``.
    """
    one_hot = F.one_hot(labels, num_classes=n_classes)
    return one_hot

# Sanity check: labels 0, 2, 1 must select the matching rows of a 3x3 identity.
assert get_one_hot_labels(
    labels=torch.tensor([0, 2, 1], dtype=torch.long),
    n_classes=3,
).tolist() == [[1, 0, 0], [0, 0, 1], [0, 1, 0]]

print('Success...')

Success...


In [9]:
def combine_vectors(x, y):
    """Concatenate two tensors along dimension 1 after casting both to float.

    Args:
        x: first tensor; its entries come first along dimension 1.
        y: second tensor; must match ``x`` in every dimension except 1.

    Returns:
        A float tensor whose size along dimension 1 is the sum of the inputs'.
    """
    as_float = [part.float() for part in (x, y)]
    return torch.cat(as_float, dim=1)

# Sanity checks for combine_vectors.
# Integer inputs: values are concatenated column-wise and the result is float.
combined = combine_vectors(torch.tensor([[1, 2],[3, 4]]), torch.tensor([[5,6],[7, 8]]))
assert torch.all(combined == torch.tensor([[1, 2, 5,6], [3, 4, 7, 8]]))
assert (type(combined[0][0].item()) == float)
# 3-D inputs: only dimension 1 grows (4 + 8 = 12).
# NOTE: these torch.randn calls consume the global RNG, so keep their order and shapes unchanged.
combined = combine_vectors(torch.randn(1, 4, 5), torch.randn(1, 8, 5));
assert tuple(combined.shape) == (1, 12, 5)
# Long inputs are accepted as well (10 + 20 = 30 along dimension 1).
assert tuple(combine_vectors(torch.randn(1, 10, 12).long(), torch.randn(1, 20, 12).long()).shape) == (1, 30, 12)
print("Success!")

Success!


In [None]:
# Shape of one MNIST image as (channels, height, width) and the number of digit classes.
mnist_shape = (1, 28, 28)
n_classes = 10

# BCEWithLogitsLoss applies the sigmoid internally, so it must be fed raw logits.
criterion = nn.BCEWithLogitsLoss()
n_epochs = 200
z_dim = 64
display_step = 500
batch_size = 128
lr = 0.0002
# Use the GPU when one is present, otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# ToTensor scales pixels to [0, 1]; Normalize(0.5, 0.5) then maps them to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# download=True fetches MNIST into '.' only if it is not already there, so a
# fresh "Restart & Run All" works without a pre-populated dataset folder.
dataloader = DataLoader(
    MNIST('.', download=True, transform=transform),
    batch_size=batch_size,
    shuffle=True)

In [12]:
def get_input_dimension(z_dim, mnist_shape, n_classes):
    """Compute the input sizes of a class-conditioned generator and discriminator.

    Args:
        z_dim: length of the noise vector.
        mnist_shape: image shape whose first entry is the channel count.
        n_classes: number of classes (length of the one-hot class vector).

    Returns:
        A tuple ``(gen_input_dim, disc_im_chan)``: the noise length plus the
        number of classes, and the image channel count plus the number of
        classes.
    """
    image_channels = mnist_shape[0]
    return z_dim + n_classes, image_channels + n_classes

def test_get_input_dimension():
    """Check get_input_dimension on a small hand-computed example."""
    dims = get_input_dimension(23, (12, 23, 52), 9)
    generator_dim, discriminator_channels = dims
    assert generator_dim == 32  # 23 noise entries + 9 classes
    assert discriminator_channels == 21  # 12 image channels + 9 classes


test_get_input_dimension()