In [None]:
!pip install torchvision tensorboardx numpy matplotlib jupyter

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
import torch
from torch import nn, optim
from torch.autograd.variable import Variable
from torchvision import transforms, datasets

In [None]:
from utils import Logger

In [None]:
def mnist_data(root='./dataset', train=True):
  """Return the MNIST dataset with pixel values scaled to [-1, 1].

  ToTensor() maps pixels to [0, 1]; Normalize((0.5,), (0.5,)) then applies
  (x - 0.5) / 0.5, giving [-1, 1], the same range as the generator's Tanh output.

  Args:
    root: directory where MNIST is stored (downloaded there if missing).
    train: True for the training split, False for the test split.
  """
  compose = transforms.Compose([
      transforms.ToTensor(),
      transforms.Normalize((0.5,), (0.5,)),
  ])
  return datasets.MNIST(root=root, train=train, transform=compose, download=True)

# Load the MNIST training set.
data = mnist_data()

# Seed torch's global RNG so the loader's shuffle order -- and, when the notebook
# is run top to bottom, the weight initialisation and noise sampled in later
# cells -- are reproducible.
torch.manual_seed(0)

# Shuffled mini-batches of 100 images; num_batches is used for progress logging.
data_loader = torch.utils.data.DataLoader(data, batch_size=100, shuffle=True)
num_batches = len(data_loader)

In [None]:
# Discriminator: scores a flattened image with the probability that it came from the real dataset.
class DiscriminatorNet(torch.nn.Module):
  """Classify flattened 28x28 images as real (output near 1) or generated (near 0).

  Three hidden stages of Linear -> LeakyReLU(0.2) -> Dropout narrow the
  784 inputs to 1024 -> 512 -> 256 units. A final Linear + Sigmoid produces
  one probability per sample. Inputs are expected to have a last dimension
  of 784, e.g. the (batch, 784) output of images_to_vectors.
  """

  def __init__(self):
    super().__init__()
    in_dim, out_dim = 784, 1

    def leaky_block(fan_in, fan_out, drop):
      # Submodule indices 0/1/2 keep state_dict keys such as "hidden0.0.weight".
      return nn.Sequential(
          nn.Linear(fan_in, fan_out),
          nn.LeakyReLU(0.2),
          nn.Dropout(drop),
      )

    self.hidden0 = leaky_block(in_dim, 1024, 0.3)
    self.hidden1 = leaky_block(1024, 512, 0.3)
    # NOTE(review): this stage uses dropout 0.2 while the others use 0.3 -- confirm intentional.
    self.hidden2 = leaky_block(512, 256, 0.2)
    self.out = nn.Sequential(nn.Linear(256, out_dim), nn.Sigmoid())

  def forward(self, x):
    for layer in (self.hidden0, self.hidden1, self.hidden2, self.out):
      x = layer(x)
    return x

discriminator = DiscriminatorNet()


In [None]:
def images_to_vectors(images, n_features=784):
  """Flatten a batch of images to shape (batch, n_features).

  The default 784 = 28 * 28 matches a single-channel 28x28 image.
  `reshape` (unlike `view`) also accepts non-contiguous inputs, copying
  only when a zero-copy view is impossible.
  """
  return images.reshape(images.size(0), n_features)


def vectors_to_images(vectors, channels=1, height=28, width=28):
  """Inverse of images_to_vectors: reshape (batch, C*H*W) vectors to (batch, C, H, W).

  The defaults produce single-channel 28x28 images.
  """
  return vectors.reshape(vectors.size(0), channels, height, width)

In [None]:
# Generator: maps a 100-dim latent vector to a 784-value vector, i.e. a flattened 28x28 image.
class GeneratorNet(torch.nn.Module):
  """Map latent noise vectors to flattened 28x28 images.

  Three hidden stages of Linear -> LeakyReLU(0.2) widen 100 -> 256 -> 512 -> 1024
  units. A Linear + Tanh output stage produces 784 values in [-1, 1], the same
  range as the normalised MNIST pixels.
  """

  def __init__(self):
    super().__init__()
    latent_dim, image_dim = 100, 784

    def leaky_linear(fan_in, fan_out):
      # Submodule indices 0/1 keep state_dict keys such as "hidden0.0.weight".
      return nn.Sequential(nn.Linear(fan_in, fan_out), nn.LeakyReLU(0.2))

    self.hidden0 = leaky_linear(latent_dim, 256)
    self.hidden1 = leaky_linear(256, 512)
    self.hidden2 = leaky_linear(512, 1024)
    self.out = nn.Sequential(nn.Linear(1024, image_dim), nn.Tanh())

  def forward(self, x):
    h = x
    for stage in (self.hidden0, self.hidden1, self.hidden2, self.out):
      h = stage(h)
    return h

generator = GeneratorNet()

In [None]:
# Latent noise sampler for the generator.
def noise(size, latent_dim=100):
  """Sample a (size, latent_dim) batch of standard-normal latent vectors.

  The default latent_dim of 100 matches GeneratorNet's input layer.
  A plain tensor is returned: `Variable` has been a no-op wrapper since
  PyTorch 0.4, so no wrapping is needed.
  """
  return torch.randn(size, latent_dim)

In [None]:
# One Adam optimizer per network (lr = 2e-4), created discriminator first, then generator.
d_optimizer, g_optimizer = (
    optim.Adam(net.parameters(), lr=2e-4) for net in (discriminator, generator)
)

In [None]:
# Binary cross-entropy on the discriminator's sigmoid outputs, averaged over the batch.
loss = nn.BCELoss(reduction='mean')

In [None]:
# Discriminator labels: real images -> 1, generated (fake) images -> 0.
def ones_target(size, device=None):
  """Return a (size, 1) tensor of ones.

  Used as the "real" label for the discriminator and as the generator's
  target. `device` lets the labels be created next to the model
  (e.g. on a GPU); None uses torch's default device.
  """
  return torch.ones(size, 1, device=device)


def zeros_target(size, device=None):
  """Return a (size, 1) tensor of zeros: the "fake" label for the discriminator."""
  return torch.zeros(size, 1, device=device)

In [None]:
def train_discriminator(optimizer, real_data, fake_data):
  """Run one discriminator update on a real batch and a (detached) fake batch.

  Args:
    optimizer: optimizer over the discriminator's parameters
      (d_optimizer in the training loop).
    real_data: flattened real images, labelled 1.
    fake_data: flattened generator samples, labelled 0. These should be
      detached so this update does not backpropagate into the generator.

  Returns:
    (total BCE loss, predictions on real_data, predictions on fake_data).
  """
  optimizer.zero_grad()

  # Real samples: push predictions toward 1.
  prediction_real = discriminator(real_data)
  error_real = loss(prediction_real, ones_target(real_data.size(0)))
  error_real.backward()

  # Fake samples: push predictions toward 0. The label count comes from
  # fake_data itself, so the two batches need not be the same size.
  prediction_fake = discriminator(fake_data)
  error_fake = loss(prediction_fake, zeros_target(fake_data.size(0)))
  error_fake.backward()  # accumulates onto the gradients from the real batch

  optimizer.step()

  return error_real + error_fake, prediction_real, prediction_fake

In [None]:
def train_generator(optimizer, fake_data):
  """Run one generator update and return its loss.

  Args:
    optimizer: optimizer over the generator's parameters (g_optimizer in the
      training loop).
    fake_data: generator output that is still attached to the generator's
      autograd graph (NOT detached), so gradients can reach the generator.

  The samples are scored against the "real" label (ones), so the generator
  is rewarded when the discriminator believes its samples are real.

  Returns:
    The BCE loss tensor.
  """
  N = fake_data.size(0)
  optimizer.zero_grad()

  # Score the caller-provided samples with the current discriminator;
  # no noise is sampled here.
  prediction = discriminator(fake_data)

  error = loss(prediction, ones_target(N))
  # backward() also writes .grad on the discriminator's parameters. Only
  # `optimizer` steps here; train_discriminator zeroes its optimizer's grads
  # before its next step.
  error.backward()

  optimizer.step()
  return error

Testing (fixed noise used to visualise generator samples during training)

In [None]:
num_test_samples = 16
test_noise = noise(num_test_samples)

Training

In [None]:
logger = Logger(model_name='VGAN', data_name='MNIST')

num_epochs = 200
log_every = 100  # batches between image snapshots / status lines

for epoch in range(num_epochs):
  for n_batch, (real_batch, _) in enumerate(data_loader):
    N = real_batch.size(0)

    # --- Train discriminator ---
    real_data = images_to_vectors(real_batch)
    # Detach so the discriminator update does not backpropagate into the generator.
    fake_data = generator(noise(N)).detach()
    d_error, d_pred_real, d_pred_fake = train_discriminator(d_optimizer, real_data, fake_data)

    # --- Train generator ---
    # Fresh samples that stay attached to the generator's graph.
    fake_data = generator(noise(N))
    g_error = train_generator(g_optimizer, fake_data)

    logger.log(d_error, g_error, epoch, n_batch, num_batches)

    # --- Display progress ---
    if n_batch % log_every == 0:
      # Snapshots need no autograd graph; this replaces building one and taking `.data`.
      with torch.no_grad():
        test_images = vectors_to_images(generator(test_noise))

      logger.log_images(
          test_images, num_test_samples,
          epoch, n_batch, num_batches
      )

      logger.display_status(
          epoch, num_epochs, n_batch, num_batches,
          d_error, g_error, d_pred_real, d_pred_fake
      )