In [2]:
import numpy as np
import torch
import matplotlib.pyplot as plt

In [3]:
# Mount Google Drive at /content/gdrive; the MNIST data root used by this
# notebook lives under that mount point.
from google.colab import drive
drive.mount(mountpoint='/content/gdrive')

KeyboardInterrupt: ignored

In [None]:
from torchvision import datasets
import torchvision.transforms as transforms

# number of subprocesses to use for data loading (0 = load in the main process)
num_workers = 0

# how many samples per batch to load
batch_size = 64

# convert each image to a torch.FloatTensor
transform = transforms.ToTensor()

# get the training split of MNIST, downloading it to the mounted Drive if needed
train_data = datasets.MNIST(root='/content/gdrive/My Drive/GAN/data', train=True,
                            download=True, transform=transform)

# prepare data loader; shuffle=True reshuffles the samples every epoch so the
# networks are not shown the exact same sequence of batches each time
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size,
                                           shuffle=True, num_workers=num_workers)

In [None]:
# Pull one batch from the loader and display how many images it holds.
images, labels = next(iter(train_loader))
len(images)

In [None]:

# Fetch one batch of training images and display its first sample.
images, labels = next(iter(train_loader))

# Lay the first sample out as a 2-D 28x28 grid so imshow draws it as a
# grayscale picture.
img = images[0].reshape(28, 28)
plt.figure(figsize=(3, 3))
plt.imshow(img, cmap="gray")
plt.show()


# Define the Model
A GAN is comprised of two adversarial networks, a discriminator and a generator.

## Discriminator
**The discriminator network is going to be a pretty typical linear classifier.** 
To make this network a universal function approximator, we'll need at least one hidden layer, and these hidden layers should have one key attribute:

All hidden layers will have a Leaky ReLU activation function applied to their outputs.

In [None]:
 
import torch.nn as nn
import torch.nn.functional as F

class Discriminator(nn.Module):
    """Fully-connected discriminator that scores flattened images.

    Three hidden linear layers narrow from ``hidden_dim*4`` down to
    ``hidden_dim``; each is followed by a Leaky ReLU (negative slope 0.2)
    and dropout (p=0.3). The last layer returns raw logits: no sigmoid is
    applied here.

    Args:
        input_size: number of features in one flattened input sample.
        hidden_dim: width of the last hidden layer.
        output_size: number of logits produced per sample.
    """

    def __init__(self, input_size, hidden_dim, output_size):
        super(Discriminator, self).__init__()
        
        # define hidden linear layers
        self.fc1 = nn.Linear(input_size, hidden_dim*4)
        self.fc2 = nn.Linear(hidden_dim*4, hidden_dim*2)
        self.fc3 = nn.Linear(hidden_dim*2, hidden_dim)
        
        # final fully-connected layer
        self.fc4 = nn.Linear(hidden_dim, output_size)
        
        # dropout layer 
        self.dropout = nn.Dropout(0.3)
        
        
    def forward(self, x):
        """Return one row of logits per input sample.

        ``x`` may have any shape whose per-sample element count equals the
        ``input_size`` given to the constructor; it is flattened first.
        """
        # flatten each sample to the width fc1 expects, instead of a
        # hard-coded 28*28 that silently ignored the input_size argument
        x = x.view(-1, self.fc1.in_features)
        # all hidden layers
        x = F.leaky_relu(self.fc1(x), 0.2) # (input, negative_slope=0.2)
        x = self.dropout(x)
        x = F.leaky_relu(self.fc2(x), 0.2)
        x = self.dropout(x)
        x = F.leaky_relu(self.fc3(x), 0.2)
        x = self.dropout(x)
        # final layer (raw logits)
        out = self.fc4(x)

        return out

# Generator
The generator network will be almost exactly the same as the discriminator network, except that we're applying a tanh activation function to our output layer.

## tanh Output
The generator has been found to perform best with $\tanh$ as its output activation, which scales the output to be between -1 and 1 instead of between 0 and 1.

In [None]:
class Generator(nn.Module):
    """Fully-connected generator that maps latent vectors to flat images.

    Three hidden linear layers widen from ``hidden_dim`` up to
    ``hidden_dim*4``; each is followed by a Leaky ReLU (negative slope 0.2)
    and dropout (p=0.3). The output layer is squashed with tanh, so every
    output value lies in [-1, 1].

    Args:
        input_size: length of the latent vector fed to the network.
        hidden_dim: width of the first hidden layer.
        output_size: number of values produced per sample.
    """

    def __init__(self, input_size, hidden_dim, output_size):
        super(Generator, self).__init__()
        
        # define hidden linear layers
        self.fc1 = nn.Linear(input_size, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim*2)
        self.fc3 = nn.Linear(hidden_dim*2, hidden_dim*4)
        
        # final fully-connected layer
        self.fc4 = nn.Linear(hidden_dim*4, output_size)
        
        # dropout layer 
        self.dropout = nn.Dropout(0.3)

    def forward(self, x):
        """Map a batch of latent vectors to a batch of tanh-scaled outputs."""
        # all hidden layers
        x = F.leaky_relu(self.fc1(x), 0.2) # (input, negative_slope=0.2)
        x = self.dropout(x)
        x = F.leaky_relu(self.fc2(x), 0.2)
        x = self.dropout(x)
        x = F.leaky_relu(self.fc3(x), 0.2)
        x = self.dropout(x)
        # final layer with tanh applied; torch.tanh is used because the
        # F.tanh alias is flagged as deprecated (it emits a warning on the
        # PyTorch releases that still carry the notice)
        out = torch.tanh(self.fc4(x))

        return out

# Model hyperparameters


In [None]:
# --- Discriminator hyperparameters ---
input_size = 28 * 28   # flattened input image
d_hidden_size = 32     # size of the last hidden layer in the discriminator
d_output_size = 1      # a single output: 0 (fake) or 1 (real)

# --- Generator hyperparameters ---
z_size = 100           # size of the latent vector given to the generator; a free hyperparameter
g_hidden_size = 32     # size of the first hidden layer in the generator
g_output_size = 784    # flattened generated image (28*28)

# Build the full network

In [None]:

# Build both networks from the hyperparameters defined earlier in the notebook.
D = Discriminator(input_size=input_size, hidden_dim=d_hidden_size, output_size=d_output_size)
G = Generator(input_size=z_size, hidden_dim=g_hidden_size, output_size=g_output_size)

# Sanity check: print both layer summaries, separated by a blank line.
print(D, G, sep='\n\n')

# Discriminator and Generator Losses
Now we need to calculate the losses.

## Discriminator Losses
* For the discriminator, the total loss is the sum of the losses for real and fake images, d_loss = d_real_loss + d_fake_loss.
* Remember that we want the discriminator to output 1 for real images and 0 for fake images, so we need to set up the losses to reflect that.

The losses will be binary cross-entropy losses with logits, which we can get with BCEWithLogitsLoss. This combines a sigmoid activation function and binary cross-entropy loss in one function.

For the real images, we want D(real_images) = 1. That is, we want the discriminator to classify the real images with a label = 1, indicating that these are real. To help the discriminator generalize better, **the labels are reduced a bit from 1.0 to 0.9.** 

For this, we'll use the parameter smooth; if True, then we should smooth our labels. 

In PyTorch, this looks like labels = torch.ones(size) * 0.9

The discriminator loss for the fake data is similar. We want D(fake_images) = 0, where the fake images are the generator output, fake_images = G(z).

# Generator Loss
The generator loss will look similar only with flipped labels. The generator's goal is to get D(fake_images) = 1. In this case, the labels are flipped to represent that the generator is trying to fool the discriminator into thinking that the images it generates (fakes) are real!

In [None]:

# Calculate losses
def real_loss(D_out, smooth=False):
    """Binary cross-entropy (with logits) against "real" target labels.

    Args:
        D_out: raw discriminator logits, one per sample, with the batch on
            dimension 0 (e.g. shape ``(batch, 1)``).
        smooth: if truthy, use 0.9 instead of 1.0 as the target value
            (label smoothing).

    Returns:
        Scalar tensor with the mean loss over the batch.
    """
    batch_size = D_out.size(0)
    # Flatten to exactly (batch,). A bare .squeeze() would also remove the
    # batch dimension when batch_size == 1, and the loss would then fail on
    # a logits/labels shape mismatch.
    logits = D_out.reshape(batch_size)

    # Build the targets from the logits so they share device and dtype;
    # otherwise the loss breaks as soon as the model is moved off the CPU.
    labels = torch.full_like(logits, 0.9 if smooth else 1.0)

    criterion = nn.BCEWithLogitsLoss()
    loss = criterion(logits, labels)
    return loss

def fake_loss(D_out):
    """Binary cross-entropy (with logits) against "fake" (all-zero) labels.

    Args:
        D_out: raw discriminator logits, one per sample, with the batch on
            dimension 0 (e.g. shape ``(batch, 1)``).

    Returns:
        Scalar tensor with the mean loss over the batch.
    """
    batch_size = D_out.size(0)
    # Flatten to exactly (batch,). A bare .squeeze() would also remove the
    # batch dimension when batch_size == 1, and the loss would then fail on
    # a logits/labels shape mismatch.
    logits = D_out.reshape(batch_size)

    # zeros_like keeps the targets on the same device/dtype as the logits
    labels = torch.zeros_like(logits)

    criterion = nn.BCEWithLogitsLoss()
    loss = criterion(logits, labels)
    return loss

# Optimizers
We want to update the generator and discriminator variables separately. So we'll define two separate Adam optimizers.

In [None]:
import torch.optim as optim

# Learning rate shared by both Adam optimizers.
lr = 0.002

# One optimizer per network, so the discriminator and the generator are
# updated separately.
d_optimizer = optim.Adam(params=D.parameters(), lr=lr)
g_optimizer = optim.Adam(params=G.parameters(), lr=lr)

# Training
Training will involve alternating between training the discriminator and the generator. We'll use our functions real_loss and fake_loss to help us calculate the discriminator losses in all of the following cases.

## Discriminator training
* Compute the discriminator loss on real, training images
* Generate fake images
* Compute the discriminator loss on fake, generated images
* Add up real and fake loss
* Perform backpropagation + an optimization step to update the discriminator's weights

## Generator training
* Generate fake images
* Compute the discriminator loss on fake images, using flipped labels!
* Perform backpropagation + an optimization step to update the generator's weights
# Saving Samples

As we train, we'll also print out some loss statistics and save some generated "fake" samples.

In [None]:
import pickle as pkl

# training hyperparams
num_epochs = 100

# keep track of loss and generated, "fake" samples
samples = []
losses = []

# print loss statistics every `print_every` batches
print_every = 400


def _sample_z(n):
    """Return an (n, z_size) float tensor of latent vectors, uniform in [-1, 1)."""
    z = np.random.uniform(-1, 1, size=(n, z_size))
    return torch.from_numpy(z).float()


# Get some fixed data for sampling. These are latent vectors that are held
# constant throughout training, and allow us to inspect the model's performance
sample_size = 16
fixed_z = _sample_z(sample_size)

# train the network
D.train()
G.train()
for epoch in range(num_epochs):
    
    for batch_i, (real_images, _) in enumerate(train_loader):
        
        # Size of *this* batch. It is kept under its own name so the global
        # `batch_size` loader setting is not overwritten by the loop.
        n_real = real_images.size(0)
        
        ## Important rescaling step ## 
        # map the inputs from [0, 1] to [-1, 1], the range of the generator's tanh output
        real_images = real_images*2 - 1
        
        # ============================================
        #            TRAIN THE DISCRIMINATOR
        # ============================================
        
        d_optimizer.zero_grad()
        
        # 1. Train with real images

        # Compute the discriminator losses on real images 
        # smooth the real labels
        D_real = D(real_images)
        d_real_loss = real_loss(D_real, smooth=True)
        
        # 2. Train with fake images
        
        # Generate fake images
        # gradients don't have to flow into G during the discriminator step
        with torch.no_grad():
            z = _sample_z(n_real)
            fake_images = G(z)
        
        # Compute the discriminator losses on fake images        
        D_fake = D(fake_images)
        d_fake_loss = fake_loss(D_fake)
        
        # add up loss and perform backprop
        d_loss = d_real_loss + d_fake_loss
        d_loss.backward()
        d_optimizer.step()
        
        
        # =========================================
        #            TRAIN THE GENERATOR
        # =========================================
        g_optimizer.zero_grad()
        
        # 1. Train with fake images and flipped labels
        
        # Generate fake images (this time with gradients, so G can learn)
        z = _sample_z(n_real)
        fake_images = G(z)
        
        # Compute the discriminator losses on fake images 
        # using flipped labels!
        D_fake = D(fake_images)
        g_loss = real_loss(D_fake) # use real loss to flip labels
        
        # perform backprop; only G's parameters are stepped here
        g_loss.backward()
        g_optimizer.step()

        # Print some loss stats
        if batch_i % print_every == 0:
            # print discriminator and generator loss
            print('Epoch [{:5d}/{:5d}] | d_loss: {:6.4f} | g_loss: {:6.4f}'.format(
                    epoch+1, num_epochs, d_loss.item(), g_loss.item()))

    
    ## AFTER EACH EPOCH##
    # append the discriminator and generator loss of the epoch's last batch
    losses.append((d_loss.item(), g_loss.item()))
    
    # generate and save sample, fake images
    G.eval() # eval mode for generating samples
    # no_grad: the stored samples are for inspection only, so they should not
    # drag an autograd graph along for every epoch kept in `samples`
    with torch.no_grad():
        samples_z = G(fixed_z)
    samples.append(samples_z)
    G.train() # back to train mode


# Save training generator samples
with open('train_samples.pkl', 'wb') as f:
    pkl.dump(samples, f)

# Training loss
Here we'll plot the training losses for the generator and discriminator, recorded after each epoch.

In [None]:
# Draw the per-epoch discriminator and generator losses on one set of axes.
fig, ax = plt.subplots()
losses = np.array(losses)  # rows = epochs, columns = (d_loss, g_loss)
ax.plot(losses[:, 0], label='Discriminator')
ax.plot(losses[:, 1], label='Generator')
ax.set_title("Training Losses")
ax.legend()

# Test our generator

In [None]:
# Generate a single image from a fresh latent vector.
z = np.random.uniform(-1, 1, size=(1, z_size))
z = torch.from_numpy(z).float()

# Switch to eval mode first: the training loop leaves G in train mode, and
# with dropout still active the generated sample would be randomly degraded.
G.eval()
with torch.no_grad():  # inference only, no autograd graph needed
    fake_images = G(z)
print(fake_images.shape)
img = fake_images.view(28, 28)
plt.imshow(img, cmap="gray")

In [None]:

# helper function for viewing a list of passed in sample images
def view_samples(epoch, samples):
    """Show the images stored in ``samples[epoch]`` on a 4x4 grid.

    Args:
        epoch: index into ``samples`` selecting which batch to display.
        samples: indexable collection of image batches; every image must
            hold 28*28 values so it can be reshaped to 28x28.
    """
    _, grid = plt.subplots(figsize=(7,7), nrows=4, ncols=4, sharey=True, sharex=True)
    for cell, image in zip(grid.flatten(), samples[epoch]):
        # detach from autograd before handing the data to matplotlib
        pixels = image.detach().reshape((28,28))
        cell.xaxis.set_visible(False)
        cell.yaxis.set_visible(False)
        cell.imshow(pixels, cmap='Greys_r')

In [None]:
# Read back the generator samples that were pickled during training.
# NOTE: pickle.load can execute arbitrary code - only open files you created yourself.
with open('train_samples.pkl', 'rb') as sample_file:
    samples = pkl.load(sample_file)

In [None]:
# Index -1 picks the last entry in the list, i.e. the final epoch's samples.
view_samples(epoch=-1, samples=samples)

In [None]:
rows = 10 # split epochs into 10, so 100/10 = every 10 epochs
cols = 6
fig, axes = plt.subplots(figsize=(7,12), nrows=rows, ncols=cols, sharex=True, sharey=True)

# Stride through the saved epochs (and through the images of each epoch) so
# they spread evenly over the grid. The stride is clamped to 1: with fewer
# saved epochs than rows, or fewer images than cols, the integer division
# gives 0, and a slice step of 0 raises ValueError.
epoch_step = max(1, len(samples) // rows)
for sample, ax_row in zip(samples[::epoch_step], axes):
    image_step = max(1, len(sample) // cols)
    for img, ax in zip(sample[::image_step], ax_row):
        img = img.detach()
        ax.imshow(img.reshape((28,28)), cmap='Greys_r')
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)

In [None]:

# Draw a fresh batch of latent vectors, uniform in [-1, 1).
sample_size = 16
rand_z = torch.from_numpy(
    np.random.uniform(-1, 1, size=(sample_size, z_size))
).float()

# Put the generator in eval mode, then generate the images.
G.eval()
rand_images = G(rand_z)

# view_samples indexes into a list of batches; we have a single batch here,
# so wrap it in a list and show entry 0.
view_samples(0, [rand_images])