# Handwritten Digits - MNIST GAN With Improved Training

Make Your First GAN With PyTorch, 2020

In [32]:
from rich import print

## Import Libraries

In [33]:
import random
import torch
import torch.nn as nn

import pandas as pd 
import numpy
import matplotlib.pyplot as plt

import dill as pickle

## Dataset Class

In [34]:
import mnist_data

In [35]:
# load training data

# path of the training CSV inside the project's data directory
train_csv = mnist_data.datadir.joinpath('mnist_train.csv')
# dataset object; later cells index it and iterate over it, unpacking each record
# as (label, image_data_tensor, target_tensor)
mnist_dataset = mnist_data.MnistDataset(train_csv)

## Data Functions

In [36]:
# functions to generate random data

def generate_random_image(size):
    """Return a tensor of shape `size` filled with uniform samples from [0, 1).

    Used as stand-in "fake image" data when testing the discriminator.
    """
    return torch.rand(size)


def generate_random_seed(size):
    """Return a tensor of shape `size` filled with standard-normal samples.

    Used as the random input ("seed") fed to the generator.
    """
    return torch.randn(size)

## Discriminator Network

In [37]:
from mnist_classifier import Classifier as Discriminator

# discriminator network: 784 input values -> 200 hidden units -> 1 output value
model = nn.Sequential(
    nn.Linear(784, 200),
    nn.LeakyReLU(0.02),  # leaky ReLU with negative slope 0.02

    nn.LayerNorm(200),  # normalise the 200 hidden activations

    nn.Linear(200, 1),
    nn.Sigmoid()  # keeps the single output in (0, 1), the range BCELoss expects
)

# wrap the network in the project's Classifier (imported as Discriminator), then
# attach the loss and optimiser; presumably D.train() uses both — TODO confirm
D = Discriminator(model)
D.loss_function = nn.BCELoss()
D.optimiser = torch.optim.Adam(D.parameters(), lr=0.0001)

In [38]:
# show the discriminator's network, loss and optimiser; the `=` inside each
# f-string prints the expression text alongside its value
print(f"{D.model = }")
print(f'{D.loss_function = }')
print(f'{D.optimiser = }')

## Test Discriminator

In [None]:
%%time
# test discriminator can separate real data from random noise

# one pass over the training set; only the image tensor of each record is used here
for label, image_data_tensor, target_tensor in mnist_dataset:
    # real data, trained towards a target of 1.0
    D.train(image_data_tensor, torch.FloatTensor([1.0]), print_counter=True)
    # fake data: 784 uniform random values, trained towards a target of 0.0
    D.train(generate_random_image(784), torch.FloatTensor([0.0]), print_counter=True)

In [None]:
# plot discriminator loss

# D.plot_progress()
# ylim=(0, 0.005) is passed through to plot_progress; presumably it zooms the
# y-axis onto very small loss values — TODO confirm against plot_progress
D.plot_progress(ylim=(0, 0.005))

In [None]:
# manually run discriminator to check it can tell real data from fake

# real images: a trained discriminator should print values close to 1
for i in range(4):
    # randrange() excludes its upper bound, so the index stays within 0..59999;
    # randint(0, 60000) could return 60000, one past the last record
    # (assumes the training set holds 60000 records — TODO confirm)
    image_data_tensor = mnist_dataset[random.randrange(60000)][1]
    print(f"{D.forward(image_data_tensor).item():.3f}")

# random noise: a trained discriminator should print values close to 0
for i in range(4):
    print(f"{D.forward(generate_random_image(784)).item():.2e}")

## Generator Network

In [None]:
from mnist_generator import Generator

# generator network: 100 seed values -> 200 hidden units -> 784 output values
gmodel = nn.Sequential(
    nn.Linear(100, 200),
    nn.LeakyReLU(0.02),  # leaky ReLU with negative slope 0.02

    nn.LayerNorm(200),  # normalise the 200 hidden activations

    nn.Linear(200, 784),
    nn.Sigmoid()  # keeps every output value in (0, 1)
)

# wrap the network in the project's Generator and attach its optimiser;
# no loss function is set here (G.train() is later given the discriminator D)
G = Generator(gmodel)
G.optimiser = torch.optim.Adam(G.parameters(), lr=0.0001)

In [None]:
# show the generator's network and optimiser; the `=` inside each f-string
# prints the expression text alongside its value
print(f"{G.model = }")
print(f'{G.optimiser = }')

## Test Generator Output

In [None]:
# check the generator output is of the right type and shape

# run 100 standard-normal seed values through the (still untrained) generator
output = G.forward(generate_random_seed(100))

# detach from the autograd graph, convert to numpy and view as a 28x28 image
img = output.detach().numpy().reshape(28,28)

plt.imshow(img, interpolation='none', cmap='Blues')

## Train GAN

In [None]:
%%time 

# create Discriminator and Generator

D = Discriminator(model)
D.loss_function = nn.BCELoss()
D.optimiser = torch.optim.Adam(D.parameters(), lr=0.0001)

G = Generator(gmodel)
G.optimiser = torch.optim.Adam(G.parameters(), lr=0.0001)

# train Discriminator and Generator

epochs = 4
# epochs = 8

for epoch in range(epochs):
    print(f'Starting epoch {epoch} of {epochs}')
    for label, image_data_tensor, target_tensor in mnist_dataset:

        # train discriminator on true
        D.train(image_data_tensor, torch.FloatTensor([1.0]), print_counter=True)

        # train discriminator on false
        # use detach() so gradients in G are not calculated
        D.train(G.forward(generate_random_seed(100)).detach(), torch.FloatTensor([0.0]), print_counter=True)
#         D.train(G.forward(generate_random_image(100)).detach(), torch.FloatTensor([0.0]), print_counter=True)

        # train generator
        G.train(D, generate_random_seed(100), torch.FloatTensor([1.0]))

In [None]:
# plot discriminator error

# D.plot_progress(yticks=(0, 0.25, 0.5, 5))
# default plot of the progress recorded by D during GAN training
D.plot_progress()

In [None]:
# plot generator error

# G.plot_progress(yticks=(0, 0.25, 0.5, 5))
# xlim=(0, 5000) is passed through to plot_progress; presumably it restricts the
# x-axis to the first 5000 recorded points — TODO confirm against plot_progress
G.plot_progress(xlim=(0, 5000))

## Run Generator

In [None]:
# plot several outputs from the trained generator

# plot a 3 column, 2 row array of generated images
f, axarr = plt.subplots(2,3, figsize=(16,8))
for i in range(2):
    for j in range(3):
        # every panel uses a freshly drawn random seed
        output = G.forward(generate_random_seed(100))
        # detach from autograd and view the generator output as a 28x28 image
        img = output.detach().numpy().reshape(28, 28)
        axarr[i,j].imshow(img, interpolation='none', cmap='Blues')

## Seed Experiments

In [None]:
# first reference seed; seed1 is kept and reused by the seed-arithmetic cells below
seed1 = generate_random_seed(100)
out1 = G.forward(seed1)
# detach from autograd and view the generator output as a 28x28 image
img1 = out1.detach().numpy().reshape(28,28)
plt.imshow(img1, interpolation='none', cmap='Blues')

In [None]:
# second reference seed; seed2 is kept and reused by the seed-arithmetic cells below
seed2 = generate_random_seed(100)
out2 = G.forward(seed2)
# detach from autograd and view the generator output as a 28x28 image
img2 = out2.detach().numpy().reshape(28,28)
plt.imshow(img2, interpolation='none', cmap='Blues')

In [None]:
# plot generator outputs for seeds interpolated between seed1 and seed2

rows, cols = 3, 4
steps = rows * cols

# plot a 4 column, 3 row array of generated images
f, axarr = plt.subplots(rows, cols, figsize=(16,8))
# axarr.flat walks the grid row by row (left to right, top to bottom)
for count, axis in enumerate(axarr.flat):
    # count/(steps - 1) runs from 0 to 1, so the first panel shows seed1 and
    # the last panel shows seed2, with evenly spaced blends in between
    seed = seed1 + (seed2 - seed1)/(steps - 1) * count
    output = G.forward(seed)
    img = output.detach().numpy().reshape(28,28)
    axis.imshow(img, interpolation='none', cmap='Blues')

In [None]:
# sum of seeds

# element-wise sum of the two reference seeds, fed through the generator
seed3 = seed1 + seed2
out3 = G.forward(seed3)
# detach from autograd and view the generator output as a 28x28 image
img3 = out3.detach().numpy().reshape(28,28)
plt.imshow(img3, interpolation='none', cmap='Blues')

In [None]:
# difference of seeds

# element-wise difference of the two reference seeds, fed through the generator
seed4 = seed1 - seed2
out4 = G.forward(seed4)
# detach from autograd and view the generator output as a 28x28 image
img4 = out4.detach().numpy().reshape(28,28)
plt.imshow(img4, interpolation='none', cmap='Blues')

In [None]:
# product of seeds

# element-wise product of the two reference seeds, fed through the generator
# NOTE(review): this reuses the names seed4/out4/img4, overwriting the values
# assigned by the difference-of-seeds cell
seed4 = seed1 * seed2
out4 = G.forward(seed4)
# detach from autograd and view the generator output as a 28x28 image
img4 = out4.detach().numpy().reshape(28,28)
plt.imshow(img4, interpolation='none', cmap='Blues')

## Pickle and Save GAN

In [None]:
# save each network through its wrapper's pickle() method; presumably this writes
# the object to the named file — the load cell below reads these same filenames
D.pickle('discriminator_mnist.pkl')
G.pickle('generator_mnist.pkl')

In [None]:
# bundle both networks into one dict so they can be saved to a single file
gan = dict(discriminator=D, generator=G)

# `pickle` here is dill (imported as pickle at the top of the notebook)
with open('gan_mnist.pkl', 'wb') as f:
    pickle.dump(gan, f)

## Load From Pickled GAN

In [None]:
# reload the separately saved networks under new names (Dn, Gn)
# NOTE: unpickling executes code stored in the file, so only load pickle files
# that come from a trusted source
with open('discriminator_mnist.pkl', 'rb') as f:
    Dn = pickle.load(f)
    
with open('generator_mnist.pkl', 'rb') as f:
    Gn = pickle.load(f)

In [None]:
# plot a 3 column, 2 row array of images from the reloaded generator Gn
f, axarr = plt.subplots(2,3, figsize=(16,8))
for i in range(2):
    for j in range(3):
        # every panel uses a freshly drawn random seed
        output = Gn.forward(generate_random_seed(100))
        # detach from autograd and view the generator output as a 28x28 image
        img = output.detach().numpy().reshape(28, 28)
        axarr[i,j].imshow(img, interpolation='none', cmap='Blues')

In [None]:
# reload the combined dict saved earlier (keys 'discriminator' and 'generator')
# NOTE: unpickling executes code stored in the file, so only load pickle files
# that come from a trusted source
with open('gan_mnist.pkl', 'rb') as f:
    new_gan = pickle.load(f)

In [None]:
# generate one image with the generator stored in the reloaded dict
output = new_gan['generator'].forward(generate_random_seed(100))
# detach from autograd and view the generator output as a 28x28 image
img = output.detach().numpy().reshape(28, 28)
plt.imshow(img, interpolation='none', cmap='Blues')