In [1]:
import os
import torch
from torch import nn
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from torchvision import transforms
import numpy as np
import matplotlib.pyplot as plt
import uuid

In [2]:
# Configurable variables
NUM_EPOCHS = 50
NOISE_DIMENSION = 50
BATCH_SIZE = 128
TRAIN_ON_GPU = True
# Fresh random ID per kernel run (uuid4 draws from the OS, so SEED below does not affect it).
UNIQUE_RUN_ID = str(uuid.uuid4())
PRINT_STATS_AFTER_BATCH = 50
OPTIMIZER_LR = 0.0002
OPTIMIZER_BETAS = (0.5, 0.999)

# MNIST images are 28 x 28 pixels with a single channel; the models work on
# them flattened into one vector of GENERATOR_OUTPUT_IMAGE_SHAPE values.
IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNELS = 28, 28, 1
GENERATOR_OUTPUT_IMAGE_SHAPE = IMAGE_HEIGHT * IMAGE_WIDTH * IMAGE_CHANNELS

# Seed NumPy and PyTorch (torch.manual_seed also seeds all CUDA devices) so a
# Restart & Run All reproduces the same weights and samples. Bit-exact GPU runs
# may additionally require disabling torch.backends.cudnn.benchmark.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

In [3]:
# Speed ups
# Anomaly detection makes autograd record extra bookkeeping on every op;
# set_detect_anomaly also works as a plain function call, so this turns it off globally.
torch.autograd.set_detect_anomaly(False)
# Let cuDNN benchmark and cache the fastest convolution algorithm per input shape.
# It only affects cuDNN-backed ops and can make GPU results non-deterministic.
torch.backends.cudnn.benchmark = True
# The autograd profilers (torch.autograd.profiler.profile / emit_nvtx) are
# context managers that are inactive unless entered, so nothing needs disabling here.

In [5]:
class Generator(nn.Module):
    """
    Vanilla GAN Generator.

    Maps a batch of noise vectors to a batch of flattened images with a stack
    of fully connected layers. Each hidden layer is
    Linear -> BatchNorm1d -> LeakyReLU(0.25). The output layer is
    Linear -> Tanh, so every output value lies in [-1, 1].

    Args:
        noise_dimension: length of each input noise vector. Defaults to the
            module-level ``NOISE_DIMENSION`` when None.
        output_size: length of each flattened output image. Defaults to the
            module-level ``GENERATOR_OUTPUT_IMAGE_SHAPE`` when None.
        hidden_sizes: widths of the hidden layers, in order.

    Input is expected as (batch, noise_dimension). Because of BatchNorm1d, a
    batch of size 1 raises an error in training mode.
    """
    def __init__(self, noise_dimension=None, output_size=None,
                 hidden_sizes=(128, 256, 512)):
        super().__init__()
        # Resolve defaults lazily so the constructor stays usable with explicit sizes.
        if noise_dimension is None:
            noise_dimension = NOISE_DIMENSION
        if output_size is None:
            output_size = GENERATOR_OUTPUT_IMAGE_SHAPE

        modules = []
        in_features = noise_dimension
        for width in hidden_sizes:
            modules += [
                # No bias: the BatchNorm that follows has its own learnable shift.
                nn.Linear(in_features, width, bias=False),
                # Keep PyTorch's default eps (1e-5). BatchNorm1d's second positional
                # argument is eps, not momentum (unlike Keras), so pass momentum
                # by keyword if it ever needs tuning.
                nn.BatchNorm1d(width),
                nn.LeakyReLU(0.25),
            ]
            in_features = width
        modules += [
            # NOTE(review): unlike the hidden layers, this Linear is not followed by
            # a normalization layer, so bias=False leaves the output without a
            # learnable offset. Kept for state_dict compatibility; consider bias=True.
            nn.Linear(in_features, output_size, bias=False),
            nn.Tanh(),
        ]
        # Same module order as a hand-written Sequential, so state_dict keys are unchanged.
        self.layers = nn.Sequential(*modules)
    
    def forward(self, x):
        """Forward pass: map noise ``x`` to flattened images in [-1, 1]."""
        return self.layers(x)

In [6]:
class Discriminator(nn.Module):
    """
    Vanilla GAN Discriminator.

    Scores a batch of flattened images with a stack of fully connected layers.
    Each hidden layer is Linear -> LeakyReLU(0.25). The output layer is
    Linear(…, 1) -> Sigmoid, giving one score in [0, 1] per image.

    Args:
        input_size: length of each flattened input image. Defaults to the
            module-level ``GENERATOR_OUTPUT_IMAGE_SHAPE`` when None.
        hidden_sizes: widths of the hidden layers, in order.

    Input is expected as (batch, input_size); output has shape (batch, 1).
    """
    def __init__(self, input_size=None, hidden_sizes=(1024, 512, 256)):
        super().__init__()
        # Resolve the default lazily so explicit sizes work without module globals.
        if input_size is None:
            input_size = GENERATOR_OUTPUT_IMAGE_SHAPE

        modules = []
        in_features = input_size
        for width in hidden_sizes:
            modules += [nn.Linear(in_features, width), nn.LeakyReLU(0.25)]
            in_features = width
        # NOTE(review): if this is trained with nn.BCELoss, dropping the Sigmoid and
        # using nn.BCEWithLogitsLoss is more numerically stable. That requires a
        # matching change in the training loop, so the Sigmoid is kept here.
        modules += [nn.Linear(in_features, 1), nn.Sigmoid()]
        # Same module order as a hand-written Sequential, so state_dict keys are unchanged.
        self.layers = nn.Sequential(*modules)

    def forward(self, x):
        """Forward pass: map flattened images ``x`` to scores in [0, 1]."""
        return self.layers(x)