In [1]:
import numpy as np
import random
import torch
from torch import nn
from torchvision import datasets, transforms
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
%matplotlib inline

# HYPERPARAMETERS
BATCH_SIZE = 128
LEARNING_RATE = 0.0002
GENERATOR_NUM_FEATURES = 100
GENERATOR_HIDDEN_SIZE, DISCRIMINATOR_HIDDEN_SIZE = 128, 128
NUM_EPOCHS = 100
DROPOUT_PROB = 0.4

device = 'cuda'

In [2]:
class Generator(nn.Module):

  def __init__(self, input_size, hidden_dim):
    super(Generator, self).__init__()
    output_size = 28 * 28

    self.out = nn.Sequential(
      nn.Linear(input_size, hidden_dim),
      nn.LeakyReLU(0.2),
      # nn.Dropout(DROPOUT_PROB),
      nn.Linear(hidden_dim, 2*hidden_dim),
      nn.LeakyReLU(0.2),
      # nn.Dropout(DROPOUT_PROB),
      nn.Linear(2*hidden_dim, 4*hidden_dim),
      nn.LeakyReLU(0.2),
      # nn.Dropout(DROPOUT_PROB),
      nn.Linear(4*hidden_dim, output_size),
      nn.Tanh()
    )

  def forward(self, x):
    return self.out(x)


class Discriminator(nn.Module):

  def __init__(self, hidden_dim):
    super(Discriminator, self).__init__()
    input_size = 28 * 28
    output_size = 1
    
    self.out = nn.Sequential(
        nn.Linear(input_size, 4*hidden_dim),
        nn.LeakyReLU(0.2),  
        nn.Dropout(DROPOUT_PROB),
        nn.Linear(4*hidden_dim, 2*hidden_dim),
        nn.LeakyReLU(0.2),
        nn.Dropout(DROPOUT_PROB),
        nn.Linear(2*hidden_dim, hidden_dim),
        nn.LeakyReLU(0.2),
        nn.Dropout(DROPOUT_PROB),
        nn.Linear(hidden_dim, output_size)
    )
      
      
  def forward(self, x):
    # flatten image
    x = x.view(-1, 28*28)
    return self.out(x)

D = Discriminator(DISCRIMINATOR_HIDDEN_SIZE)
G = Generator(GENERATOR_NUM_FEATURES, GENERATOR_HIDDEN_SIZE)
print(D)
print(G)

#Passing to the GPU
D = D.to(device)
G = G.to(device)

D = D.float()
G = G.float()

Discriminator(
  (out): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): LeakyReLU(negative_slope=0.2)
    (2): Dropout(p=0.4, inplace=False)
    (3): Linear(in_features=512, out_features=256, bias=True)
    (4): LeakyReLU(negative_slope=0.2)
    (5): Dropout(p=0.4, inplace=False)
    (6): Linear(in_features=256, out_features=128, bias=True)
    (7): LeakyReLU(negative_slope=0.2)
    (8): Dropout(p=0.4, inplace=False)
    (9): Linear(in_features=128, out_features=1, bias=True)
  )
)
Generator(
  (out): Sequential(
    (0): Linear(in_features=100, out_features=128, bias=True)
    (1): LeakyReLU(negative_slope=0.2)
    (2): Linear(in_features=128, out_features=256, bias=True)
    (3): LeakyReLU(negative_slope=0.2)
    (4): Linear(in_features=256, out_features=512, bias=True)
    (5): LeakyReLU(negative_slope=0.2)
    (6): Linear(in_features=512, out_features=784, bias=True)
    (7): Tanh()
  )
)


In [3]:
def mnist_data():
  compose = transforms.Compose(
      [transforms.ToTensor(),
       transforms.Normalize((.5), (.5))
      ])
  out_dir = './dataset'
  return datasets.MNIST(root=out_dir, train=True, transform=compose, download=True)

data = mnist_data()
# train_data = torch.utils.data.Subset(data, random.sample(range(0, len(data)), 10000))

# Create loader with data, so that we can iterate over it
data_loader = torch.utils.data.DataLoader(data, batch_size=BATCH_SIZE, shuffle=True)

num_batches = len(data_loader)
print("Number of batches = {} with batch size = {}".format(num_batches, BATCH_SIZE))

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./dataset/MNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./dataset/MNIST/raw/train-images-idx3-ubyte.gz to ./dataset/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./dataset/MNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./dataset/MNIST/raw/train-labels-idx1-ubyte.gz to ./dataset/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./dataset/MNIST/raw/t10k-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./dataset/MNIST/raw/t10k-images-idx3-ubyte.gz to ./dataset/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz




HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./dataset/MNIST/raw
Processing...
Done!
Number of batches = 469 with batch size = 128


  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


In [4]:
# Calculate losses
def real_loss(D_out):
  batch_size = D_out.size(0)
  labels = torch.ones(batch_size).to(device) # real labels = 1
  # numerically stable loss
  criterion = nn.BCEWithLogitsLoss()
  # calculate loss
  loss = criterion(D_out.squeeze(), labels)
  return loss

def fake_loss(D_out):
  batch_size = D_out.size(0)
  labels = torch.zeros(batch_size).to(device) # fake labels = 0
  criterion = nn.BCEWithLogitsLoss()
  # calculate loss
  loss = criterion(D_out.squeeze(), labels)
  return loss

def calc_real_acc(D_out):
  labels = torch.ones(D_out.size()[0]).to(device) # real labels = 1
  # print("real", D_out.squeeze()[:10])
  correct_pred = ((D_out.squeeze() >= 0.5) == labels).float().sum()
  return correct_pred.item()

def calc_fake_acc(D_out):
  labels = torch.zeros(D_out.size()[0]).to(device) # fake labels = 0
  # print("fake", D_out.squeeze()[:10])
  correct_pred = ((D_out.squeeze() >= 0.5) == labels).float().sum()
  return correct_pred.item()

def sample_input(batch_size=-1):
  if batch_size != -1:
    return torch.normal(mean=torch.zeros((batch_size, GENERATOR_NUM_FEATURES)),
                      std = torch.ones((batch_size, GENERATOR_NUM_FEATURES)))
  else:
    return torch.normal(mean=torch.zeros(GENERATOR_NUM_FEATURES),
                        std = torch.ones(GENERATOR_NUM_FEATURES))
    
# Get some fixed data for sampling. These are images that are held
# constant throughout training, and allow us to inspect the model's performance
sample_size = 25
fixed_z = sample_input(sample_size).to(device)

d_optimizer = optim.Adam(D.parameters(), LEARNING_RATE)
g_optimizer = optim.Adam(G.parameters(), LEARNING_RATE)

In [5]:
# helper function for viewing a list of passed in sample images
def view_samples(samples, epoch):
    samples = samples.to('cpu')
    fig, axes = plt.subplots(figsize=(5,5), nrows=5, ncols=samples.size()[0]//5, sharey=True, sharex=True)
    for ax, img in zip(axes.flatten(), samples):
        img = img.detach()
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)
        im = ax.imshow(img.reshape((28,28)), cmap='Greys_r')
    plt.savefig('graphs/GAN Epoch ' + str(epoch) + '.png')

In [None]:
# keep track of loss and generated, "fake" samples
samples = []
losses = []
real_acc, fake_acc = [], []

print_every = 60000 // BATCH_SIZE + 1

# train the network
D.train()
G.train()
for epoch in range(NUM_EPOCHS):
  real_count, fake_count = 0, 0
  for batch_i, (real_images, _) in enumerate(data_loader):
        
    batch_size = real_images.size(0)
    real_images = real_images.to(device)

    # Generate fake images
    z = sample_input(batch_size).to(device)

    # calculate discriminator accuracy with generated fakes
    D.eval()
    G.eval()
    fake_images = G(z)
    real_count += calc_real_acc(D(real_images))
    fake_count_in_batch = calc_fake_acc(D(fake_images))
    fake_count += fake_count_in_batch
    G.train()
    D.train()
  
    # TRAIN THE DISCRIMINATOR
    d_optimizer.zero_grad()
    
    # 1. Train with real images

    # Compute the discriminator losses on real images 
    # smooth the real labels
    D_real = D(real_images)
    d_real_loss = real_loss(D_real)
    
    # 2. Train with fake images
    
    # Compute the discriminator losses on fake images
    fake_images = G(z)        
    D_fake = D(fake_images)
    d_fake_loss = fake_loss(D_fake)

    d_loss = d_real_loss

    # if (epoch <= 20) or (epoch > 20 and epoch % 2 == 0):
    # if epoch < 20 or fake_count_in_batch <= 0.99 * batch_size:
      # add up loss and perform backprop
    # if batch_i % 2 == 0:
    d_loss += d_fake_loss

    d_loss.backward()
    d_optimizer.step()
  
    # TRAIN THE GENERATOR
    g_optimizer.zero_grad()
    
    # 1. Train with fake images and flipped labels
    
    # Generate fake images
    z = sample_input(batch_size).to(device)
    fake_images = G(z)
    
    # Compute the discriminator losses on fake images 
    # using flipped labels!
    D_fake = D(fake_images)
    g_loss = real_loss(D_fake) # use real loss to flip labels

    # if fake_count_in_batch >= 0.2 * batch_size:
      # perform backprop
    g_loss.backward()
    g_optimizer.step()

    # Print some loss stats
    if batch_i % print_every == 0:
        # print discriminator and generator loss
        print('Epoch [{:5d}/{:5d}]'.format(epoch+1, NUM_EPOCHS))
        print("\td_real_loss: {:6.4f} | d_fake_loss: {:6.4f} | g_loss: {:6.4f}".format(d_real_loss.item(), d_fake_loss.item(), g_loss.item()))
        print("\td_fake_acc: {:6.4f} | d_real_acc: {:6.4f}".format(fake_count/batch_size, real_count/batch_size))

    
    ## AFTER EACH EPOCH##
    # append discriminator loss and generator loss
  losses.append((d_real_loss.item() + d_fake_loss.item(), g_loss.item()))
  real_acc.append(real_count/60000)
  fake_acc.append(fake_count/60000)
  
  # generate and save sample, fake images
  G.eval() # eval mode for generating samples
  samples_z = G(fixed_z)
  samples.append(samples_z)
  if epoch % 5 == 0:
    view_samples(samples_z, epoch)
  G.train() # back to train mode

Epoch [    1/  100]
	d_real_loss: 1.3834 | d_fake_loss: 0.6808 | g_loss: 0.7067
	d_fake_acc: 1.0000 | d_real_acc: 0.0000
Epoch [    2/  100]
	d_real_loss: 2.3757 | d_fake_loss: 1.6749 | g_loss: 0.4232
	d_fake_acc: 0.0391 | d_real_acc: 0.5547
Epoch [    3/  100]
	d_real_loss: 0.7351 | d_fake_loss: 0.3757 | g_loss: 1.6803
	d_fake_acc: 1.0000 | d_real_acc: 0.8672
Epoch [    4/  100]
	d_real_loss: 0.4016 | d_fake_loss: 0.1912 | g_loss: 3.0913
	d_fake_acc: 1.0000 | d_real_acc: 0.9844
Epoch [    5/  100]
	d_real_loss: 0.8315 | d_fake_loss: 0.3841 | g_loss: 1.5413
	d_fake_acc: 1.0000 | d_real_acc: 0.9062
Epoch [    6/  100]
	d_real_loss: 0.4195 | d_fake_loss: 0.1483 | g_loss: 2.6775
	d_fake_acc: 1.0000 | d_real_acc: 0.9609
Epoch [    7/  100]
	d_real_loss: 0.2941 | d_fake_loss: 0.1801 | g_loss: 3.0209
	d_fake_acc: 1.0000 | d_real_acc: 0.9922
Epoch [    8/  100]
	d_real_loss: 0.6539 | d_fake_loss: 0.1514 | g_loss: 3.0117
	d_fake_acc: 0.9922 | d_real_acc: 0.9062
Epoch [    9/  100]
	d_real_loss

In [None]:
fig, (ax1, ax2) = plt.subplots(figsize=(10,4), nrows=1, ncols=2)
losses = np.array(losses)
ax1.plot(losses.T[0], label='Discriminator Loss')
ax1.plot(losses.T[1], label='Generator Loss')
ax1.set_title("Training Stats")
ax2.plot(fake_acc, label='Fake Accuracies')
ax2.plot(real_acc, label='Real Accuracies')
ax2.set_title("Discriminator Accuracy")
ax1.legend()
ax2.legend()

In [None]:
# rows = 10 # split epochs into 10, so 100/10 = every 10 epochs
# cols = 8
# fig, axes = plt.subplots(figsize=(7,12), nrows=rows, ncols=cols, sharex=True, sharey=True)

# for sample, ax_row in zip(samples[::int(len(samples)/rows)], axes):
#     for img, ax in zip(sample[::int(len(sample)/cols)], ax_row):
#         img = img.detach().cpu()
#         ax.imshow(img.reshape((28,28)), cmap='Greys_r', vmax=1.0, vmin=-1.0)
#         ax.xaxis.set_visible(False)
#         ax.yaxis.set_visible(False)

In [None]:
# fig, axes = plt.subplots(figsize=(7,3), nrows=1, ncols=cols, sharex=True, sharey=True)
# imgs = []

# for epoch, sample in enumerate(samples):
#   ims = []
#   ims.append(axes[0].text(cols//2, -0.5, "Epoch {}".format(epoch), transform=axes[0].transAxes))
#   for image, ax in zip(sample, axes):
#     image = image.detach().cpu()
#     im = ax.imshow(image.reshape((28,28)), cmap='Greys_r', vmax=1.0, vmin=-1.0, animated=True)
#     ims.append(im)
#     ax.xaxis.set_visible(False)
#     ax.yaxis.set_visible(False)
#   imgs.append(ims)
  
# im_ani = animation.ArtistAnimation(fig, imgs, interval=50, repeat=False, blit=True)
# plt.close()
# HTML(im_ani.to_html5_video())

In [None]:
# G.eval()
# # z = sample_input(25).to(device)
# sample = G(fixed_z)
# view_samples(sample)