### Goal
In this notebook, you're going to create your first generative adversarial network (GAN) for this course! Specifically, you will build and train a GAN that can generate hand-written images of digits (0-9). You will be using PyTorch in this specialization, so if you're not familiar with this framework, you may find the [PyTorch documentation](https://pytorch.org/docs/stable/index.html) useful. The hints will also often include links to relevant documentation.

### Learning Objectives
1.   Build the generator and discriminator components of a GAN from scratch.
2.   Create generator and discriminator loss functions.
3.   Train your GAN and visualize the generated images.


In [2]:
import torch
from torch import nn
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

torch.manual_seed(34)



  from .autonotebook import tqdm as notebook_tqdm


<torch._C.Generator at 0x7f7bd947c150>

In [3]:
def show_tensor_images(image_tensor, num_images=25, size=(1,28,28)):
    '''
    
    '''
    image_unflat = image_tensor.detach().cpu().view(-1, *size)
    image_grid = make_grid(image_unflat[:num_images], nrow=5)
    plt.imshow(image_grid.permute(1, 2, 0).squeeze())
    plt.show()

# Let's create generator using pytorch

In [4]:
## Generator Block
def get_generator_block(input_dim, output_dim):
    return nn.Sequential(
        nn.Linear(input_dim, output_dim),
        nn.BatchNorm1d(output_dim),
        nn.ReLU(inplace=True)
    )

# unit test for generator block
def test_gen_block(in_features, out_features, num_test=1000):
    block = get_generator_block(in_features, out_features)

    assert len(block) == 3
    assert type(block[0]) == nn.Linear
    assert type(block[1]) == nn.BatchNorm1d
    assert type(block[2]) == nn.ReLU

    test_input = torch.randn(num_test, in_features)
    test_output = block(test_input)

    assert tuple(test_output.shape) == (num_test, out_features)
    assert test_output.std() > 0.55
    assert test_output.std() < 0.65

test_gen_block(25, 12)
test_gen_block(15, 28)
print("Gen Block test successful...")

Gen Block test successful...


In [5]:
## Generator class

class Generator(nn.Module):
    def __init__(self, z_dim = 10, im_dim = 784, hidden_dim = 128):
        super(Generator, self).__init__()

        self.gen = nn.Sequential(
            get_generator_block(z_dim, hidden_dim),
            get_generator_block(hidden_dim, hidden_dim*2),
            get_generator_block(hidden_dim*2, hidden_dim*4),
            get_generator_block(hidden_dim*4, hidden_dim*8),
            nn.Linear(hidden_dim*8, im_dim),
            nn.Sigmoid()
        )

    def forward(self, noise):
        return self.gen(noise)

    def get_gen(self):
        return self.gen


# unit test for generator class
def test_generator(z_dim, im_dim, hidden_dim, num_test=1000):
    gen = Generator(z_dim, im_dim, hidden_dim).get_gen()

    assert str(gen.__getitem__(4)).replace(' ', '') == f'Linear(in_features={hidden_dim * 8},out_features={im_dim},bias=True)'
    assert str(gen.__getitem__(5)).replace(' ', '') == 'Sigmoid()'

    test_input = torch.randn(num_test, z_dim)
    test_output = gen(test_input)

    assert tuple(test_output.shape) == (num_test, im_dim)
    assert test_output.max() < 1, "Make sure to use a sigmoid function"
    assert test_output.min() > 0, "Make sure to use a sigmoid function"
    assert test_output.std() > 0.05, "Don't use batchnorm here"
    assert test_output.std() < 0.15, "Don't use batchnorm here"

test_generator(5, 10, 20)
test_generator(20, 8, 24)
print('Gen class tested successfully...')

Gen class tested successfully...


# Let's create some noise...

In [7]:
def get_noise(n_sample, z_dim, device='cpu'):
    return torch.randn(n_sample, z_dim, device=device)

# unit test for noise generator

def test_get_noise(n_sample, z_dim, device='cpu'):
    print(torch.cuda.is_available())
    noise = get_noise(n_sample, z_dim, device)

    assert tuple(noise.shape) == (n_sample, z_dim)
    assert torch.abs(noise.std() - torch.tensor(1.0)) < 0.01
    assert str(noise.device).startswith(device)

test_get_noise(1000, 100, device='cpu')
if torch.cuda.is_available():
    test_get_noise(1000, 32, device='cuda')
print('Success...')

False
Success...


# Time to create discriminator...

In [9]:
# discriminator block
def get_discriminator_block(input_dim, output_dim):
    return nn.Sequential(
        nn.Linear(input_dim, output_dim),
        nn.LeakyReLU(0.2, inplace = True)

    )

# unit test for discriminator block
def test_disc_block(in_features, out_features, num_test = 1000):
    block = get_discriminator_block(in_features, out_features)

    assert len(block) == 2
    test_input = torch.randn(num_test, in_features)
    test_output = block(test_input)

    assert tuple(test_output.shape) == (num_test, out_features)

    # check LeakyReLU slop is about 0.2
    assert -test_output.min()/test_output.max() > 0.1
    assert -test_output.min()/test_output.max() < 0.3
    assert test_output.std() > 0.3
    assert test_output.std() < 0.5

    
    assert str(block.__getitem__(0)).replace(' ', '') == f'Linear(in_features={in_features},out_features={out_features},bias=True)'        
    assert str(block.__getitem__(1)).replace(' ', '').replace(',inplace=True', '') == 'LeakyReLU(negative_slope=0.2)'

test_disc_block(25, 12)
test_disc_block(15, 28)
print("Disc block success...")

Disc block success...


In [10]:
# Discriminator class
class Discriminator(nn.Module):
    def __init__(self, im_dim = 784, hidden_dim =128):
        super(Discriminator, self).__init__()
        self.disc = nn.Sequential(
            get_discriminator_block(im_dim, hidden_dim*4),
            get_discriminator_block(hidden_dim*4, hidden_dim*2),
            get_discriminator_block(hidden_dim*2, hidden_dim),
            nn.Linear(hidden_dim, 1)

        )

    def forward(self, image):
        return self.disc(image)

    def get_disc(self):
        return self.disc

# unit test for Discriminator class
def test_discriminator(z_dim, hidden_dim, num_test = 1000):
    disc = Discriminator(z_dim, hidden_dim).get_disc()

    assert len(disc) == 4
    assert type(disc.__getitem__(3)) == nn.Linear

    test_input = torch.randn(num_test, z_dim)
    test_ouput = disc(test_input)
    assert tuple(test_ouput.shape) == (num_test, 1)

test_discriminator(5, 10)
test_discriminator(20, 8)
print("Disc class sccusse...")

Disc class sccusse...


## Training
Now you can put it all together!
First, you will set your parameters:
  *   criterion: the loss function
  *   n_epochs: the number of times you iterate through the entire dataset when training
  *   z_dim: the dimension of the noise vector
  *   display_step: how often to display/visualize the images
  *   batch_size: the number of images per forward/backward pass
  *   lr: the learning rate
  *   device: the device type, here using a GPU (which runs CUDA), not CPU

Next, you will load the MNIST dataset as tensors using a dataloader.



In [None]:
# Set your parameters
criterion = nn.BCEWithLogitsLoss()
n_epochs = 200
z_dim = 64
display_step = 500
batch_size = 128
lr = 0.00001
if torch.cuda.is_available():
    device = 'cuda'
else: device = 'cpu'
im_dim = 784


In [None]:

# Load MNIST dataset as tensors
dataloader = DataLoader(
    MNIST('.', download=False, transform=transforms.ToTensor()),
    batch_size=batch_size,
    shuffle=True)

In [15]:
# Initilize generator and discriminator
gen = Generator(z_dim, im_dim).to(device)
gen_opt = torch.optim.Adam(gen.parameters(), lr = lr)

disc = Discriminator().to(device)
disc_opt = torch.optim.Adam(disc.parameters(), lr)


In [17]:
# create discriminator loss
def get_disc_loss(gen, disc, criterion, real, num_images, z_dim, device):
    fake_noise = get_noise(num_images, z_dim, device)
    fake = gen(fake_noise)

    disc_fake_pred = disc(fake.detach())
    disc_fake_loss = criterion(disc_fake_pred, torch.zeros_like(disc_fake_pred))

    disc_real_pred = disc(real)
    disc_real_loss = criterion(disc_real_pred, torch.ones_like(disc_real_pred))

    disc_loss = (disc_real_loss + disc_fake_loss)/2
    return disc_loss

# test disc reasonable
def test_disc_reasonable(num_images = 10):
    z_dim = 64
    gen = torch.zeros_like
    disc = nn.Identity()
    criterion = torch.mul
    real = torch.ones(num_images, 1)
    disc_loss = get_disc_loss(gen, disc, criterion, real, num_images, z_dim, 'cpu')

    assert tuple(disc_loss.shape) == (num_images, z_dim)
    assert torch.all(torch.abs(disc_loss - 0.5) < 1e-5)

    gen = torch.ones_like
    disc = nn.Identity()
    criterion = torch.mul
    real = torch.zeros(num_images, 1)
    assert torch.all(torch.abs(get_disc_loss(gen, disc, criterion, real, num_images, z_dim, 'cpu')) < 1e-5)

def test_disc_loss(max_tests = 10):
    z_dim = 64
    gen = Generator(z_dim, im_dim).to(device)
    gen_opt = torch.optim.Adam(gen.parameters(), lr = lr)

    disc = Discriminator().to(device)
    disc_opt = torch.optim.Adam(disc.parameters(), lr)
    num_steps = 0

    for real, _ in dataloader:
        cur_batch_size = len(real)
        real = real.view(cur_batch_size, -1).to(device)

        disc_opt.zero_grad()
        disc_loss = get_disc_loss(gen, disc, criterion, real, cur_batch_size, z_dim, 'cpu')
        assert (disc_loss - 0.68).abs() < 0.05

        disc_loss.backward(retain_graph = True)

        assert gen.gen[0][0].weight.grad is None

        old_weight = disc.disc[0][0].weight.data.clone()
        disc_opt.step()
        new_weight = disc.disc[0][0].weight.data

        assert not torch.all(torch.eq(old_weight, new_weight))
        num_steps+=1
        if num_steps >= max_tests: break

test_disc_reasonable()
# test_disc_loss()
print("(success...")


(success...


In [22]:
# create generator loss
def get_gen_loss(gen, disc, criterion, num_images, z_dim, device):
    fake_noise = get_noise(num_images, z_dim, device)
    fake = gen(fake_noise)

    disc_fake_pred = disc(fake)
    gen_loss = criterion(disc_fake_pred, torch.ones_like(disc_fake_pred))
    return gen_loss

# unit generator reasonable
def test_gen_reasonable(num_images = 10):
    z_dim = 64
    gen = torch.zeros_like
    disc = nn.Identity()

    criterion = torch.mul
    gen_loss_tensor = get_gen_loss(gen, disc, criterion, num_images, z_dim, device)
    assert torch.all(torch.abs(gen_loss_tensor ) < 1e-5)
    assert tuple(gen_loss_tensor.shape) == (num_images, z_dim)

    gen = torch.ones_like
    disc = nn.Identity()
    criterion = torch.mul
    gen_loss_tensor = get_gen_loss(gen, disc, criterion, num_images, z_dim, device)
    assert torch.all(torch.abs(gen_loss_tensor - 1) < 1e-5)
    assert tuple(gen_loss_tensor.shape) == (num_images, z_dim)

def test_gen_loss(num_images):
    z_dim = 64
    gen = Generator(z_dim, im_dim).to(device)
    gen_opt = torch.optim.Adam(gen.parameters(), lr = lr)

    disc = Discriminator().to(device)
    disc_opt = torch.optim.Adam(disc.parameters(), lr)

    gen_loss = get_gen_loss(gen, disc, criterion, num_images, z_dim, device)

    assert (gen_loss - 0.7).abs() < 0.1
    gen_loss.backward()
    old_weight = gen.gen[0][0].weight.clone()
    gen_opt.step()
    new_weight = gen.gen[0][0].weight
    assert not torch.all(torch.eq(old_weight, new_weight))

test_gen_reasonable()
test_gen_loss(18)
print("Gen loss success...")

Gen loss success...


# First GAN arch execution

In [23]:
cur_step = 0
mean_generator_loss = 0
mean_discriminator_loss = 0
test_generator = True # Whether the generator should be tested
gen_loss = False
error = False

In [None]:
for epoch in range(n_epochs):
    for real, _ in tqdm(dataloader):
        cur_batch_size = len(real)

        real = real.view(cur_batch_size, -1).to(device)

        disc_opt.zero_grad()
        disc_loss = get_disc_loss(gen, disc, criterion, real, cur_batch_size, z_dim, device)
        disc_loss.backward(retain_graph = True)
        disc_opt.step()

        if test_generator:
            old_generator_weights = gen.gen[0][0].weight.detach().clone()

        gen_opt.zero_grad()
        gen_loss = get_gen_loss(gen, disc, criterion, cur_batch_size, z_dim, device)
        gen_loss.backward()
        gen_opt.step()

        if test_generator:
            try:
                assert lr > 0.0000002 or (gen.gen[0][0].weight.grad.abs().max() < 0.0005 and epoch == 0)
                assert torch.any(gen.gen[0][0].weight.detach().clone() != old_generator_weights)
            except:
                error = True
                print("Runtime tests have failed")

        mean_discriminator_loss += disc_loss.item()/display_step
        mean_generator_loss += gen_loss.item()/display_step

        ### Visualization code ###
        if cur_step % display_step == 0 and cur_step > 0:
            print(f"Step {cur_step}: Generator loss: {mean_generator_loss}, discriminator loss: {mean_discriminator_loss}")
            fake_noise = get_noise(cur_batch_size, z_dim, device=device)
            fake = gen(fake_noise)
            show_tensor_images(fake)
            show_tensor_images(real)
            mean_generator_loss = 0
            mean_discriminator_loss = 0
        cur_step += 1
