In [0]:
import numpy as np
import matplotlib.pyplot as plt

In [2]:
# Installing required libraries
! pip install torch torchvision tensorboardx



In [0]:
import torch
from torch import nn, optim
from torch.autograd.variable import Variable
from torchvision import transforms, datasets

In [0]:
def mnist_data():
  """Return the MNIST training split with pixel values rescaled to [-1, 1].

  The dataset is stored under './dataset' and downloaded on first use.

  Returns:
    A torchvision `datasets.MNIST` object yielding (image, label) pairs,
    where each image has passed through ToTensor and Normalize.
  """
  # MNIST images have a single channel, so Normalize must receive exactly one
  # mean and one std. Three-element tuples are broadcast against the channel
  # axis by current torchvision releases and fail with a shape mismatch.
  # (x - 0.5) / 0.5 maps ToTensor's [0, 1] output onto [-1, 1], the same range
  # the generator's final Tanh produces.
  compose = transforms.Compose([
      transforms.ToTensor(),
      transforms.Normalize((0.5,), (0.5,)),
  ])
  out_dir = './dataset'

  return datasets.MNIST(root=out_dir, train=True, transform=compose, download=True)

In [0]:
# Build the MNIST training dataset via mnist_data()
data = mnist_data()

# Creating data loader: mini-batches of 100 samples, reshuffled on every pass
data_loader = torch.utils.data.DataLoader(data, batch_size=100, shuffle=True)

# Number of mini-batches yielded per pass over the dataset
num_batches = len(data_loader)

In [0]:
# Discriminator Class

class DiscriminatorNet(torch.nn.Module):
  """Fully connected discriminator for flattened 28x28 images.

  Takes a batch of 784-element vectors and returns one score per sample in
  [0, 1]. The training code compares that score against all-ones / all-zeros
  targets with binary cross entropy.

  Three Linear + LeakyReLU(0.2) + Dropout(0.3) stages narrow the input
  (1024, 512, 256) before a Linear + Sigmoid head.
  """

  def __init__(self):

    super(DiscriminatorNet, self).__init__()
    n_features = 784
    n_out = 1

    self.hidden0 = nn.Sequential(nn.Linear(n_features, 1024), nn.LeakyReLU(0.2), nn.Dropout(0.3))

    self.hidden1 = nn.Sequential(nn.Linear(1024, 512), nn.LeakyReLU(0.2), nn.Dropout(0.3))

    self.hidden2 = nn.Sequential(nn.Linear(512, 256), nn.LeakyReLU(0.2), nn.Dropout(0.3))

    # The head must emit a probability: nn.BCELoss only accepts inputs in
    # [0, 1]. A LeakyReLU + Dropout head produces unbounded (and negative)
    # values, and dropout on a single output unit randomly zeroes the whole
    # prediction. Sigmoid has no parameters, so state_dict keys are unchanged.
    self.out = nn.Sequential(nn.Linear(256, n_out), nn.Sigmoid())


  def forward(self, x):
    """Return the per-sample score for a batch `x` of 784-element vectors."""

    x = self.hidden0(x)
    x = self.hidden1(x)
    x = self.hidden2(x)
    x = self.out(x)
    return x
    

In [0]:
# Module-level discriminator instance. train_discriminator and train_generator
# call this global directly instead of receiving the network as an argument.

discriminator = DiscriminatorNet()

In [0]:
# Functions to convert images to vectors and vice-versa

def images_to_vectors(images):
  """Flatten a batch of images into one 784-element row per sample.

  The result is a view on `images`, so no data is copied.
  """
  batch_size = images.size(0)
  flat_shape = (batch_size, 784)
  return images.view(*flat_shape)

def vectors_to_images(vectors):
  """Reshape 784-element rows back into single-channel 28x28 images.

  The result is a view on `vectors` with shape (batch, 1, 28, 28).
  """
  image_shape = (1, 28, 28)
  return vectors.view(vectors.size(0), *image_shape)

In [0]:
# Generator Class

class GeneratorNet(torch.nn.Module):
  """Maps 100-dimensional noise vectors to flattened 784-value images.

  Three Linear + LeakyReLU(0.2) stages widen the input (256, 512, 1024)
  before a Linear + Tanh head bounds every output value to [-1, 1].
  """

  def __init__(self):
    super(GeneratorNet, self).__init__()
    latent_dim, image_dim = 100, 784
    slope = 0.2

    # Stage names and creation order are part of the contract: they determine
    # the state_dict keys and the order in which weights are initialised.
    self.hidden0 = nn.Sequential(nn.Linear(latent_dim, 256), nn.LeakyReLU(slope))
    self.hidden1 = nn.Sequential(nn.Linear(256, 512), nn.LeakyReLU(slope))
    self.hidden2 = nn.Sequential(nn.Linear(512, 1024), nn.LeakyReLU(slope))
    self.out = nn.Sequential(nn.Linear(1024, image_dim), nn.Tanh())

  def forward(self, x):
    """Pass `x` through the four stages in order and return the result."""
    for stage in (self.hidden0, self.hidden1, self.hidden2, self.out):
      x = stage(x)
    return x


In [0]:
# Module-level generator instance. Its parameters are handed to g_optimizer,
# and the training loop draws fake samples from it.
generator = GeneratorNet()

In [0]:
# Introducing some random noise sampled from a normal distribution 
# with mean 0 and variance 1

def noise(size):
  """Return `size` latent vectors of length 100 drawn from N(0, 1)."""
  return Variable(torch.randn(size, 100))

In [0]:
# Optimizers for both generator and discriminator.
# Each network gets its own Adam instance (learning rate 2e-4), so the two can
# be stepped independently.

d_optimizer = optim.Adam(discriminator.parameters(), lr=0.0002)
g_optimizer = optim.Adam(generator.parameters(), lr=0.0002)

In [0]:
# Binary Cross Entropy Loss Function.
# Used by train_discriminator and train_generator to compare discriminator
# predictions against the targets built by ones_target / zeroes_target.
# NOTE(review): nn.BCELoss expects inputs in [0, 1] -- confirm the
# discriminator's final layer produces values in that range.

loss = nn.BCELoss()

In [0]:
def ones_target(size):
  """Return a (size, 1) tensor of ones: the target for 'real' predictions."""
  target_shape = (size, 1)
  return Variable(torch.ones(*target_shape))

def zeroes_target(size):
  """Return a (size, 1) tensor of zeros: the target for 'fake' predictions."""
  target_shape = (size, 1)
  return Variable(torch.zeros(*target_shape))

In [0]:
# Training Discriminator

def train_discriminator(optimizer, real_data, fake_data):
  """Run one discriminator update on a real batch and a fake batch.

  Gradients from both halves are accumulated before a single optimizer step.
  Uses the module-level `discriminator`, `loss`, `ones_target` and
  `zeroes_target`.

  Args:
    optimizer: optimizer holding the discriminator's parameters.
    real_data: batch of real samples; its first dimension sets the size of
      both target tensors.
    fake_data: batch of generated samples.

  Returns:
    (total_loss, real_scores, fake_scores): the summed loss of both halves
    and the discriminator's predictions for each batch.
  """
  batch_size = real_data.size(0)

  # Start from clean gradients
  optimizer.zero_grad()

  # Real half: every sample should be scored as 1
  real_scores = discriminator(real_data)
  real_loss = loss(real_scores, ones_target(batch_size))
  real_loss.backward()

  # Fake half: every sample should be scored as 0
  fake_scores = discriminator(fake_data)
  fake_loss = loss(fake_scores, zeroes_target(batch_size))
  fake_loss.backward()

  # One weight update using the gradients accumulated from both halves
  optimizer.step()

  total_loss = real_loss + fake_loss
  return total_loss, real_scores, fake_scores
  

In [0]:
# Training Generator 

def train_generator(optimizer, fake_data):
  """Run one generator update from a batch of generated samples.

  The fakes are scored by the module-level `discriminator` and compared
  against all-ones targets, so the loss is low when the discriminator takes
  them for real. Only `optimizer` is stepped.

  Args:
    optimizer: optimizer holding the generator's parameters.
    fake_data: generated batch, still attached to the generator's graph.

  Returns:
    The loss tensor for this step.
  """
  batch_size = fake_data.size(0)

  # Start from clean gradients
  optimizer.zero_grad()

  # Score the fakes and measure how far they are from being judged real
  fooling_loss = loss(discriminator(fake_data), ones_target(batch_size))
  fooling_loss.backward()

  optimizer.step()

  return fooling_loss

In [0]:
# Testing: a fixed batch of 16 latent vectors, sampled once up front.
# NOTE(review): presumably reused to render comparable sample images as
# training progresses -- the code consuming test_noise is not visible here.

num_test_samples = 16
test_noise = noise(num_test_samples)

In [0]:
import os
import errno
import torchvision.utils as vutils
from tensorboardX import SummaryWriter
from IPython import display

'''
    TensorBoard Data will be stored in './runs' path
'''


class Logger:
    """Collects GAN training diagnostics.

    Writes scalars and image grids to TensorBoard, saves PNG snapshots under
    './data/images/<model_name>/<data_name>', prints console status lines and
    stores model checkpoints under './data/models/<model_name>/<data_name>'.
    """

    def __init__(self, model_name, data_name):
        """Create the TensorBoard writer and derive tag / directory names.

        Args:
            model_name: label for the model; used in tags and output paths.
            data_name: label for the dataset; used in tags and output paths.
        """
        self.model_name = model_name
        self.data_name = data_name

        self.comment = '{}_{}'.format(model_name, data_name)
        self.data_subdir = '{}/{}'.format(model_name, data_name)

        # TensorBoard
        self.writer = SummaryWriter(comment=self.comment)

    def log(self, d_error, g_error, epoch, n_batch, num_batches):
        """Write discriminator and generator errors as TensorBoard scalars.

        Args:
            d_error: discriminator loss (tensor/Variable or plain number).
            g_error: generator loss (tensor/Variable or plain number).
            epoch: current epoch index.
            n_batch: current batch index within the epoch.
            num_batches: number of batches per epoch.
        """
        # Tensors are converted to numpy before being handed to the writer
        if isinstance(d_error, torch.autograd.Variable):
            d_error = d_error.data.cpu().numpy()
        if isinstance(g_error, torch.autograd.Variable):
            g_error = g_error.data.cpu().numpy()

        step = Logger._step(epoch, n_batch, num_batches)
        self.writer.add_scalar(
            '{}/D_error'.format(self.comment), d_error, step)
        self.writer.add_scalar(
            '{}/G_error'.format(self.comment), g_error, step)

    def log_images(self, images, num_images, epoch, n_batch, num_batches, format='NCHW', normalize=True):
        '''
        Log a batch of images to TensorBoard and save it to disk as PNGs.

        Args:
            images: torch tensor or numpy array of images, laid out as
                described by `format`.
            num_images: number of images; the square grid uses
                int(sqrt(num_images)) images per row.
            epoch, n_batch, num_batches: position in training, used for the
                TensorBoard step and the output file names.
            format: 'NCHW' (default) or 'NHWC'; NHWC input is converted to
                NCHW before the grids are built.
            normalize: forwarded to make_grid for the horizontal grid. The
                square grid is always built with normalize=True.
        '''
        if isinstance(images, np.ndarray):
            images = torch.from_numpy(images)

        if format == 'NHWC':
            # Move channels to axis 1 while keeping H before W. A plain
            # transpose(1, 3) yields NCWH, i.e. every image transposed.
            images = images.permute(0, 3, 1, 2)

        step = Logger._step(epoch, n_batch, num_batches)
        img_name = '{}/images{}'.format(self.comment, '')

        # Make horizontal grid from image tensor
        horizontal_grid = vutils.make_grid(
            images, normalize=normalize, scale_each=True)
        # Make square grid from image tensor
        nrows = int(np.sqrt(num_images))
        grid = vutils.make_grid(
            images, nrow=nrows, normalize=True, scale_each=True)

        # Add horizontal images to tensorboard
        self.writer.add_image(img_name, horizontal_grid, step)

        # Save plots
        self.save_torch_images(horizontal_grid, grid, epoch, n_batch)

    def save_torch_images(self, horizontal_grid, grid, epoch, n_batch, plot_horizontal=True):
        """Render both grids with matplotlib and save them as PNG files.

        Args:
            horizontal_grid: CHW grid tensor saved with the 'hori' prefix.
            grid: CHW grid tensor saved with an empty prefix.
            epoch, n_batch: used in the output file names.
            plot_horizontal: when True, also display the horizontal figure
                inline via IPython before saving it.
        """
        out_dir = './data/images/{}'.format(self.data_subdir)
        Logger._make_dir(out_dir)

        # Plot and save horizontal
        fig = plt.figure(figsize=(16, 16))
        # imshow expects channels last, so move axis 0 (C) to the end
        plt.imshow(np.moveaxis(horizontal_grid.numpy(), 0, -1))
        plt.axis('off')
        if plot_horizontal:
            display.display(plt.gcf())
        self._save_images(fig, epoch, n_batch, 'hori')
        plt.close()

        # Save squared
        fig = plt.figure()
        plt.imshow(np.moveaxis(grid.numpy(), 0, -1))
        plt.axis('off')
        self._save_images(fig, epoch, n_batch)
        plt.close()

    def _save_images(self, fig, epoch, n_batch, comment=''):
        """Save `fig` as '<comment>_epoch_<epoch>_batch_<n_batch>.png'."""
        out_dir = './data/images/{}'.format(self.data_subdir)
        Logger._make_dir(out_dir)
        fig.savefig('{}/{}_epoch_{}_batch_{}.png'.format(out_dir,
                                                         comment, epoch, n_batch))

    def display_status(self, epoch, num_epochs, n_batch, num_batches, d_error, g_error, d_pred_real, d_pred_fake):
        """Print training progress, both losses and the mean predictions.

        Args:
            epoch, num_epochs: current epoch and total number of epochs.
            n_batch, num_batches: current batch and batches per epoch.
            d_error, g_error: discriminator and generator losses.
            d_pred_real, d_pred_fake: discriminator outputs for the real and
                fake batches; their means are printed as D(x) and D(G(z)).
        """
        # Strip autograd wrappers so the values can be formatted
        if isinstance(d_error, torch.autograd.Variable):
            d_error = d_error.data.cpu().numpy()
        if isinstance(g_error, torch.autograd.Variable):
            g_error = g_error.data.cpu().numpy()
        if isinstance(d_pred_real, torch.autograd.Variable):
            d_pred_real = d_pred_real.data
        if isinstance(d_pred_fake, torch.autograd.Variable):
            d_pred_fake = d_pred_fake.data


        print('Epoch: [{}/{}], Batch Num: [{}/{}]'.format(
            epoch,num_epochs, n_batch, num_batches)
             )
        print('Discriminator Loss: {:.4f}, Generator Loss: {:.4f}'.format(d_error, g_error))
        print('D(x): {:.4f}, D(G(z)): {:.4f}'.format(d_pred_real.mean(), d_pred_fake.mean()))

    def save_models(self, generator, discriminator, epoch):
        """Save both networks' state_dicts as 'G_epoch_<epoch>' / 'D_epoch_<epoch>'."""
        out_dir = './data/models/{}'.format(self.data_subdir)
        Logger._make_dir(out_dir)
        torch.save(generator.state_dict(),
                   '{}/G_epoch_{}'.format(out_dir, epoch))
        torch.save(discriminator.state_dict(),
                   '{}/D_epoch_{}'.format(out_dir, epoch))

    def close(self):
        """Close the TensorBoard writer."""
        self.writer.close()

    # Private Functionality

    @staticmethod
    def _step(epoch, n_batch, num_batches):
        """Return the global step: batches completed across all epochs so far."""
        return epoch * num_batches + n_batch

    @staticmethod
    def _make_dir(directory):
        """Create `directory` (and parents); an already existing path is not an error."""
        try:
            os.makedirs(directory)
        except OSError as e:
            if e.errno != errno.EEXIST:
                raise

In [0]:
# Creating a Logger instance. The two names form the TensorBoard tag prefix
# ('VGAN_MNIST') and the output sub-directory ('VGAN/MNIST').
logger = Logger(model_name='VGAN', data_name='MNIST')

In [0]:
# Num of epochs to train: how many full passes over data_loader the training
# loop performs
num_epochs = 200

In [0]:
for epoch in range(num_epochs):
  
  for n_batch, (real_batch,_) in enumerate(data_loader):
    N = real_batch.size(0)
    
    # Train Discriminator
    real_data = Variable(images_to_vectors(real_batch))
    
    # Generate fake data and detach
    fake_data = generator(noise(N)).detach()
    
    #Train D
    d_error, d_pred_real, d_pred_fake = train_discriminator(d_optimizer, real_data, fake_data)
    
    
    # Train Generator
    fake_data = generator(noise(N))
    
    # Train G
    