# Conditional GAN - MNIST

Make Your First GAN With PyTorch, 2020

In [None]:
# mount Drive to access data files

# from google.colab import drive
# drive.mount('./mount')

In [None]:
# import libraries

import torch
import torch.nn as nn
from torch.utils.data import Dataset

import pandas, numpy, random
import matplotlib.pyplot as plt

## Dataset Class

In [None]:
# dataset class


class MnistDataset(Dataset):
    """MNIST records loaded from a header-less CSV file.

    Column 0 of every row is read as the digit label and the remaining
    columns as the pixel values of that image.
    """

    def __init__(self, csv_file):
        """Load ``csv_file`` into ``self.data_df``; every row is treated as data."""
        self.data_df = pandas.read_csv(csv_file, header=None)

    def __len__(self):
        """Return the number of rows (images) in the loaded table."""
        return len(self.data_df)

    def __getitem__(self, index):
        """Return ``(label, image_values, target)`` for row ``index``.

        ``label`` is the raw value from column 0, ``image_values`` is a float
        tensor of the pixel columns scaled from 0-255 down to 0-1, and
        ``target`` is a length-10 one-hot float tensor with 1.0 at ``label``.
        """
        digit = self.data_df.iloc[index, 0]

        # one-hot encoding of the digit
        one_hot = torch.zeros(10)
        one_hot[digit] = 1.0

        # pixel columns, rescaled to the 0-1 range
        raw_pixels = self.data_df.iloc[index, 1:].values
        scaled_pixels = torch.FloatTensor(raw_pixels) / 255.0

        return digit, scaled_pixels, one_hot

    def plot_image(self, index):
        """Draw row ``index`` as a 28x28 image titled with its label."""
        digit = self.data_df.iloc[index, 0]
        pixel_grid = self.data_df.iloc[index, 1:].values.reshape(28, 28)
        plt.title(f"label = {digit}")
        plt.imshow(pixel_grid, interpolation="none", cmap="Blues")

In [None]:
# load data
# the path is relative to the current working directory; MnistDataset reads the
# whole CSV into a pandas DataFrame when it is constructed
path = "./mnist_train.csv"
mnist_dataset = MnistDataset(path)

In [None]:
# check data contains images:
# draw row 17 of the CSV as a 28x28 image, with its label shown in the title

mnist_dataset.plot_image(17)

## Data Functions

In [None]:
# functions to generate random data


def generate_random_image(size):
    """Return a tensor of shape ``size`` filled with uniform samples (torch.rand)."""
    return torch.rand(size)


def generate_random_seed(size):
    """Return a tensor of shape ``size`` filled with standard-normal samples (torch.randn)."""
    return torch.randn(size)


def generate_random_one_hot(size):
    """Return a length-``size`` float tensor that is 1.0 at one random index.

    ``size`` must be a plain integer: it is used both as the tensor length
    and to derive the upper bound of the randomly chosen index.
    """
    one_hot = torch.zeros(size)
    # randint includes both end points, hence size - 1 as the last valid index
    hot_position = random.randint(0, size - 1)
    one_hot[hot_position] = 1.0
    return one_hot

## Discriminator Network

In [None]:
# discriminator class


class Discriminator(nn.Module):
    """Conditional discriminator that scores an (image, label) pair.

    The network takes 784 image values concatenated with a 10-value label
    vector and produces a single sigmoid output. It owns its loss function
    (binary cross-entropy) and its Adam optimiser, and records its loss
    history in ``self.progress``.

    NOTE: ``train`` below shadows ``nn.Module.train(mode)``, so the stock
    mode-switching helpers (``.train()`` / ``.eval()``) are not usable on
    this class. The signature is kept because the training cells call it.
    """

    def __init__(self):
        # initialise parent pytorch class
        super().__init__()

        # 784 image values + 10 label values in, one probability out
        self.model = nn.Sequential(
            nn.Linear(784 + 10, 200),
            nn.LeakyReLU(0.02),
            nn.LayerNorm(200),
            nn.Linear(200, 1),
            nn.Sigmoid(),
        )

        # binary cross-entropy between the sigmoid output and the 0/1 target
        self.loss_function = nn.BCELoss()

        # Adam optimiser over all parameters of this network
        self.optimiser = torch.optim.Adam(self.parameters(), lr=0.0001)

        # number of train() calls so far, and the losses sampled from them
        self.counter = 0
        self.progress = []

    def forward(self, image_tensor, label_tensor):
        """Return the network output for an image joined with its label.

        The two tensors are concatenated along their last dimension, so both
        a single sample (shapes ``(784,)`` and ``(10,)``) and a batch (shapes
        ``(B, 784)`` and ``(B, 10)``) are accepted.
        """
        inputs = torch.cat((image_tensor, label_tensor), dim=-1)
        return self.model(inputs)

    def train(self, inputs, label_tensor, targets):
        """Run one optimisation step on a single (image, label, target) example.

        ``targets`` must have the same shape as the network output. The loss
        is appended to ``self.progress`` on every 10th call and the call
        count is printed on every 10000th call.
        """
        # calculate the output of the network
        outputs = self.forward(inputs, label_tensor)

        # calculate loss
        loss = self.loss_function(outputs, targets)

        # increase counter and record the loss every 10 calls
        self.counter += 1
        if self.counter % 10 == 0:
            self.progress.append(loss.item())
        if self.counter % 10000 == 0:
            print("counter = ", self.counter)

        # zero gradients, perform a backward pass, update weights
        self.optimiser.zero_grad()
        loss.backward()
        self.optimiser.step()

    def plot_progress(self):
        """Plot the recorded loss values as a faint scatter on a fixed grid."""
        df = pandas.DataFrame(self.progress, columns=["loss"])
        df.plot(
            ylim=0,  # only the lower y limit is fixed
            figsize=(16, 8),
            alpha=0.1,
            marker=".",
            grid=True,
            yticks=(0, 0.25, 0.5, 1.0, 5.0),
        )

## Test Discriminator

In [None]:
%%time
# test discriminator can separate real data from random noise

D = Discriminator()

for label, image_data_tensor, label_tensor in mnist_dataset:
    # genuine image paired with its own one-hot label -> target 1.0
    D.train(
        inputs=image_data_tensor,
        label_tensor=label_tensor,
        targets=torch.FloatTensor([1.0]),
    )
    # uniform-noise "image" paired with an arbitrary one-hot label -> target 0.0
    D.train(
        inputs=generate_random_image(784),
        label_tensor=generate_random_one_hot(10),
        targets=torch.FloatTensor([0.0]),
    )

In [None]:
# plot discriminator loss
# (one point per 10 train() calls, as recorded in D.progress)

D.plot_progress()

In [None]:
# manually run discriminator to check it can tell real data from fake

# four random real samples (the discriminator was trained towards 1.0 for these)
for i in range(4):
    # random.randint includes BOTH end points, so the upper bound has to be the
    # last valid row index rather than the number of rows
    label, image_data_tensor, label_tensor = mnist_dataset[
        random.randint(0, len(mnist_dataset) - 1)
    ]
    print(D.forward(image_data_tensor, label_tensor).item())

# four random-noise samples (the discriminator was trained towards 0.0 for these)
for i in range(4):
    print(D.forward(generate_random_image(784), generate_random_one_hot(10)).item())

## Generator Network

In [None]:
# generator class


class Generator(nn.Module):
    """Conditional generator that maps a (seed, label) pair to 784 image values.

    The network takes a 100-value seed concatenated with a 10-value label
    vector and produces 784 sigmoid outputs. It owns its Adam optimiser and
    records its loss history in ``self.progress``; the loss itself is
    computed by the discriminator passed to ``train``.

    NOTE: ``train`` below shadows ``nn.Module.train(mode)``, so the stock
    mode-switching helpers (``.train()`` / ``.eval()``) are not usable on
    this class. The signature is kept because the training cells call it.
    """

    def __init__(self):
        # initialise parent pytorch class
        super().__init__()

        # 100 seed values + 10 label values in, 784 pixel values out
        self.model = nn.Sequential(
            nn.Linear(100 + 10, 200),
            nn.LeakyReLU(0.02),
            nn.LayerNorm(200),
            nn.Linear(200, 784),
            nn.Sigmoid(),
        )

        # Adam optimiser over the generator's parameters only
        self.optimiser = torch.optim.Adam(self.parameters(), lr=0.0001)

        # number of train() calls so far, and the losses sampled from them
        self.counter = 0
        self.progress = []

    def forward(self, seed_tensor, label_tensor):
        """Return the generated image values for a seed joined with a label.

        The two tensors are concatenated along their last dimension, so both
        a single sample (shapes ``(100,)`` and ``(10,)``) and a batch (shapes
        ``(B, 100)`` and ``(B, 10)``) are accepted.
        """
        inputs = torch.cat((seed_tensor, label_tensor), dim=-1)
        return self.model(inputs)

    def train(self, D, inputs, label_tensor, targets):
        """Run one generator optimisation step against discriminator ``D``.

        The generated image is passed through ``D.forward`` together with the
        same label, and ``D.loss_function`` compares the result with
        ``targets``. Only this generator's optimiser is stepped. The loss is
        appended to ``self.progress`` on every 10th call.
        """
        # calculate the output of the network
        g_output = self.forward(inputs, label_tensor)

        # pass onto Discriminator
        d_output = D.forward(g_output, label_tensor)

        # calculate error
        loss = D.loss_function(d_output, targets)

        # increase counter and record the loss every 10 calls
        self.counter += 1
        if self.counter % 10 == 0:
            self.progress.append(loss.item())

        # zero gradients, perform a backward pass, update weights
        self.optimiser.zero_grad()
        loss.backward()
        self.optimiser.step()

    def plot_images(self, label):
        """Plot a 2-row, 3-column grid of images generated for digit ``label``.

        Each panel uses a fresh random seed and the same one-hot label.
        """
        label_tensor = torch.zeros(10)
        label_tensor[label] = 1.0

        _, axarr = plt.subplots(2, 3, figsize=(16, 8))
        for axis in axarr.flat:
            # use this instance, not whatever object a global named G refers to
            image = self.forward(generate_random_seed(100), label_tensor)
            axis.imshow(
                image.detach().cpu().numpy().reshape(28, 28),
                interpolation="none",
                cmap="Blues",
            )

    def plot_progress(self):
        """Plot the recorded loss values as a faint scatter on a fixed grid."""
        df = pandas.DataFrame(self.progress, columns=["loss"])
        df.plot(
            ylim=0,  # only the lower y limit is fixed
            figsize=(16, 8),
            alpha=0.1,
            marker=".",
            grid=True,
            yticks=(0, 0.25, 0.5, 1.0, 5.0),
        )

## Test Generator Output

In [None]:
# check the generator output is of the right type and shape

# fresh, untrained generator
G = Generator()

# one forward pass with a 100-value random seed and a random one-hot label
output = G.forward(generate_random_seed(100), generate_random_one_hot(10))

# detach from the autograd graph before converting; reshape only succeeds if
# the output holds exactly 784 values
img = output.detach().numpy().reshape(28, 28)

plt.imshow(img, interpolation="none", cmap="Blues")

## Train GAN

In [None]:
# create Discriminator and Generator
# (rebinding D and G replaces the instances used in the test cells above, so
# GAN training starts from untrained networks with empty progress histories)

D = Discriminator()
G = Generator()

In [None]:
%%time

# train Discriminator and Generator
# this took around 5 hours for me on the CPU

epochs = 12

for epoch in range(epochs):
    print("epoch = ", epoch + 1)

    # one pass over the dataset = three optimisation steps per real image

    for label, image_data_tensor, label_tensor in mnist_dataset:
        # step 1: discriminator on a real image with its true label -> target 1.0
        D.train(
            inputs=image_data_tensor,
            label_tensor=label_tensor,
            targets=torch.FloatTensor([1.0]),
        )

        # step 2: discriminator on a generated image -> target 0.0
        # the generated image is detached so no gradients flow back into G
        random_label = generate_random_one_hot(10)
        D.train(
            inputs=G.forward(generate_random_seed(100), random_label).detach(),
            label_tensor=random_label,
            targets=torch.FloatTensor([0.0]),
        )

        # step 3: generator, with a newly drawn random label, tries to make
        # the discriminator output 1.0
        random_label = generate_random_one_hot(10)
        G.train(
            D,
            inputs=generate_random_seed(100),
            label_tensor=random_label,
            targets=torch.FloatTensor([1.0]),
        )

In [None]:
# plot discriminator error
# (D.progress holds one loss value per 10 D.train() calls)

D.plot_progress()

In [None]:
# plot generator error
# (G.progress holds one loss value per 10 G.train() calls)

G.plot_progress()

## Run Generator

In [None]:
# plot six outputs from the trained generator, all conditioned on the digit 9

G.plot_images(9)

In [None]:
# plot six outputs from the trained generator, all conditioned on the digit 3

G.plot_images(3)

In [None]:
# plot six outputs from the trained generator, all conditioned on the digit 1

G.plot_images(1)

In [None]:
# plot six outputs from the trained generator, all conditioned on the digit 5

G.plot_images(5)