# MNIST GAN

A generative adversarial network (GAN) is composed of a generator and a discriminator (a binary classifier) trained against each other. A generator is essentially a normal network run in reverse: instead of determining the classification of a sample, it generates new samples, here starting from random noise. The generator and discriminator are trained together, so that the generator produces new samples and the discriminator tries to determine whether each sample is "authentic" or generated.

We will begin with MNIST data again. Here we load the data and use it to initialize data loaders.

In [1]:
import matplotlib.pyplot as plt
%matplotlib inline
import h5py as h5
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader, sampler
import torch.nn.functional as F
from torchvision.utils import save_image

# Read the MNIST arrays from HDF5. The file is opened read-only inside a
# context manager so the handle is released as soon as the arrays are in memory.
with h5.File("../Lecture_07/Files/mnist.h5", "r") as data:
    # Pixels are stored as float32 after dividing by 255
    # (presumably raw values are 0-255, giving a [0, 1] range -- confirm).
    X_train = data['train']['inputs'][()].astype(np.float32) / 255
    X_test = data['test']['inputs'][()].astype(np.float32) / 255
    # np.long was removed in NumPy 1.24; int64 is explicit and converts to a
    # torch LongTensor on every platform.
    y_train = data['train']['targets'][()].astype(np.int64)
    y_test = data['test']['targets'][()].astype(np.int64)

# Give every sample an explicit channel axis: (N, 1, 28, 28).
X_train = X_train.reshape(-1, 1, 28, 28)
X_test = X_test.reshape(-1, 1, 28, 28)

train_set = TensorDataset(torch.from_numpy(X_train), torch.from_numpy(y_train))
test_set = TensorDataset(torch.from_numpy(X_test), torch.from_numpy(y_test))

batch_size = 50
# Both loaders draw samples in random order.
train_loader = DataLoader(dataset=train_set, sampler=sampler.RandomSampler(train_set), batch_size=batch_size)
test_loader = DataLoader(dataset=test_set, sampler=sampler.RandomSampler(test_set), batch_size=batch_size)

We define a generator net to create new samples, and a discriminator net to tell samples apart.

In [6]:
class Discriminator(nn.Module):
    """Convolutional network that maps an image to a single sigmoid score.

    Two stride-2 convolutions halve the spatial size twice (28 -> 14 -> 7 for
    a 28x28 input) and a final 7x7 convolution collapses the remaining map to
    one value per sample.
    """

    def __init__(self):
        super().__init__()
        # Conv2d positional args: in_channels, out_channels, kernel_size, stride, padding
        # There is no BatchNorm after the first convolution.
        self.layers = nn.ModuleList([
            nn.Conv2d(1, 32, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2),
            nn.Conv2d(32, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 1, 7, 1, 0, bias=False),
        ])

    def forward(self, x):
        """Return a tensor with one row of sigmoid scores per input sample.

        For 1x28x28 inputs the result has shape (batch, 1).
        """
        for layer in self.layers:
            x = layer(x)
        # Flatten everything but the batch axis, then squash to (0, 1).
        # torch.sigmoid replaces the deprecated F.sigmoid.
        return torch.sigmoid(x.view(x.size(0), -1))

class Generator(nn.Module):
    """Transposed-convolution network that expands a noise tensor into an image.

    A (batch, 1, 1, 1) input grows to 7x7, then 14x14, then 28x28; the final
    Tanh bounds every output pixel to [-1, 1].
    """

    def __init__(self):
        super().__init__()
        # ConvTranspose2d positional args: in_channels, out_channels, kernel_size, stride, padding
        stages = [
            nn.ConvTranspose2d(1, 64, 7, 1, 0, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 4, 2, 1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 1, 4, 2, 1, bias=False),
            nn.Tanh(),
        ]
        self.layers = nn.ModuleList(stages)

    def forward(self, x):
        """Apply every stage in order and return the generated batch."""
        out = x
        for stage in self.layers:
            out = stage(out)
        return out

discriminator = Discriminator()
generator = Generator()
optimizer_discriminator = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_generator = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))

Now we train, and generate a set of images at the end of every epoch.

In [7]:
import os

n_epochs = 25
n_generator_steps = 10  # generator updates performed per discriminator update

# Create the output directory up front so save_image cannot fail after a
# whole epoch of training has already been spent.
os.makedirs('Figs', exist_ok=True)

# NOTE(review): the generator ends in Tanh ([-1, 1]) while the loaded images
# were only divided by 255 -- consider rescaling real samples to [-1, 1].

for epoch_n in range(n_epochs):
    for batch_n, (x, _) in enumerate(train_loader):
        # Size everything from the actual batch so a short final batch works.
        n_samples = x.size(0)
        # Targets shaped like the discriminator output: 1 = real, 0 = generated.
        truth_real = torch.ones(n_samples, 1)
        truth_fake = torch.zeros(n_samples, 1)

        # --- train discriminator ---
        # x is real samples, z is fake. z is detached because this step only
        # updates the discriminator; backpropagating into the generator here
        # would be wasted work.
        noise = torch.randn(n_samples, 1, 1, 1)
        z = generator(noise).detach()
        discriminator_real = discriminator(x)
        discriminator_fake = discriminator(z)
        optimizer_discriminator.zero_grad()
        # Same objective as -mean(log D(x) + log(1 - D(G(z)))), but
        # binary_cross_entropy clamps the log terms so a saturated sigmoid
        # cannot produce inf/NaN losses.
        loss_discriminator = (F.binary_cross_entropy(discriminator_real, truth_real)
                              + F.binary_cross_entropy(discriminator_fake, truth_fake))
        loss_discriminator.backward()
        optimizer_discriminator.step()

        # --- train generator ---
        for _ in range(n_generator_steps):
            noise = torch.randn(n_samples, 1, 1, 1)
            z = generator(noise)
            discriminator_fake = discriminator(z)
            optimizer_generator.zero_grad()
            # Non-saturating generator loss: -mean(log D(G(z))).
            loss_generator = F.binary_cross_entropy(discriminator_fake, truth_real)
            loss_generator.backward()
            optimizer_generator.step()

        if batch_n % 100 == 0:
            # D(G(z)) reports the value from the last generator step above.
            print('[{:02d}/{:02d}],[{:03d}/{:03d}], D(x): {:.4f}, D(G(z)): {:.4f}'.format(
                epoch_n+1, n_epochs, batch_n, len(train_loader),
                discriminator_real.mean().item(), discriminator_fake.mean().item()))

    # Save a grid of generated samples at the end of every epoch. No gradients
    # are needed for sampling, so skip building the autograd graph.
    with torch.no_grad():
        z = generator(torch.randn(batch_size, 1, 1, 1))
    save_image(z, 'Figs/mnist-fake-{:02d}.png'.format(epoch_n+1),
               normalize=True)

[01/25],[000/1200], errD: 0.0000, D(x): 0.4932, errG: 0.0000, D(G(z)): 0.7336
[01/25],[100/1200], errD: 0.0000, D(x): 0.4937, errG: 0.0000, D(G(z)): 0.5439
[01/25],[200/1200], errD: 0.0000, D(x): 0.5115, errG: 0.0000, D(G(z)): 0.5230
[01/25],[300/1200], errD: 0.0000, D(x): 0.5199, errG: 0.0000, D(G(z)): 0.5073
[01/25],[400/1200], errD: 0.0000, D(x): 0.5271, errG: 0.0000, D(G(z)): 0.5008
[01/25],[500/1200], errD: 0.0000, D(x): 0.5244, errG: 0.0000, D(G(z)): 0.4960
[01/25],[600/1200], errD: 0.0000, D(x): 0.5318, errG: 0.0000, D(G(z)): 0.4754
[01/25],[700/1200], errD: 0.0000, D(x): 0.5638, errG: 0.0000, D(G(z)): 0.4822
[01/25],[800/1200], errD: 0.0000, D(x): 0.5670, errG: 0.0000, D(G(z)): 0.4450
[01/25],[900/1200], errD: 0.0000, D(x): 0.5489, errG: 0.0000, D(G(z)): 0.4484
[01/25],[1000/1200], errD: 0.0000, D(x): 0.5627, errG: 0.0000, D(G(z)): 0.4297
[01/25],[1100/1200], errD: 0.0000, D(x): 0.6154, errG: 0.0000, D(G(z)): 0.3880
[02/25],[000/1200], errD: 0.0000, D(x): 0.5301, errG: 0.0000, 

KeyboardInterrupt: 