In [2]:
import torch
from torch import nn, optim
from torch.autograd.variable import Variable
from torchvision import transforms, datasets

from utils import Logger

## Helper Functions

In [None]:
def images_to_vector(image):
    '''
    Flattens a batch of images into a batch of 784-element row vectors.

    Args:
        image: tensor whose first dimension is the batch size and whose
            remaining dimensions hold 784 values per sample
            (presumably (N, 1, 28, 28) MNIST batches -- confirm with caller).

    Returns:
        A view of `image` with shape (N, 784).
    '''
    # The batch size must be taken from the argument itself, not from any
    # outer-scope variable, so the function works for every caller.
    return image.view(image.size(0), 784)

def vector_to_images(vector):
    '''
    Reshapes a batch of flat vectors into single-channel 28x28 images.

    The first dimension of `vector` is kept as the batch dimension; the
    result is a view with shape (N, 1, 28, 28).
    '''
    batch_size = vector.size(0)
    image_shape = (batch_size, 1, 28, 28)
    return vector.view(*image_shape)

def noise(size, dim=100):
    '''
    Samples a batch of random vectors from a standard normal distribution.

    Args:
        size: number of vectors to draw (the batch dimension).
        dim: length of each vector. Defaults to 100.

    Returns:
        Tensor of shape (size, dim) filled with gaussian samples.
    '''
    # A plain tensor is enough: the deprecated Variable wrapper only hands
    # back the tensor it was given.
    return torch.randn(size, dim)

def ones_target(size):
    '''
    Builds a column of ones with one row per sample.

    Args:
        size: number of rows (the batch dimension).

    Returns:
        Tensor of shape (size, 1) in which every entry is 1.
    '''
    # Plain tensor instead of the deprecated Variable wrapper; the values
    # and shape are unchanged.
    return torch.ones(size, 1)

def zeros_target(size):
    '''
    Builds a column of zeros with one row per sample.

    Args:
        size: number of rows (the batch dimension).

    Returns:
        Tensor of shape (size, 1) in which every entry is 0.
    '''
    # Plain tensor instead of the deprecated Variable wrapper; the values
    # and shape are unchanged.
    return torch.zeros(size, 1)

## Importing and Transforming the Data

In [5]:
def mnist_data(out_dir='./dataset'):
    '''
    Builds the MNIST training dataset, downloading it when not yet present.

    Every image is converted to a tensor and then normalized with mean 0.5
    and standard deviation 0.5.

    Args:
        out_dir: directory the dataset is stored in. Defaults to './dataset'.

    Returns:
        A torchvision `datasets.MNIST` instance for the training split.
    '''
    compose = transforms.Compose(
        [transforms.ToTensor(),
         # MNIST images have a single channel, so mean and std must be
         # 1-tuples; 3-channel statistics do not match the tensor shape.
         transforms.Normalize((.5,), (.5,))
        ])
    return datasets.MNIST(root=out_dir, train=True, transform=compose, download=True)

# Build the MNIST training dataset (fetched to disk on first use)
data = mnist_data()

# Wrap the dataset in a shuffling loader that serves mini-batches of 100 samples
data_loader = torch.utils.data.DataLoader(dataset=data, shuffle=True, batch_size=100)

# Number of mini-batches in one full pass over the loader
num_batches = len(data_loader)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./dataset\MNIST\raw\train-images-idx3-ubyte.gz


100%|█████████████████████████████████████████████████████████████████████▉| 9895936/9912422 [02:47<00:03, 4615.15it/s]

Extracting ./dataset\MNIST\raw\train-images-idx3-ubyte.gz to ./dataset\MNIST\raw



0it [00:00, ?it/s][A

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./dataset\MNIST\raw\train-labels-idx1-ubyte.gz



  0%|                                                                                        | 0/28881 [00:01<?, ?it/s][A
 28%|████████████████████▉                                                     | 8192/28881 [00:01<00:00, 30421.66it/s][A
 57%|█████████████████████████████████████████▍                               | 16384/28881 [00:01<00:00, 28571.28it/s][A
 85%|██████████████████████████████████████████████████████████████           | 24576/28881 [00:02<00:00, 28504.40it/s][A

0it [00:00, ?it/s][A[A

Extracting ./dataset\MNIST\raw\train-labels-idx1-ubyte.gz to ./dataset\MNIST\raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./dataset\MNIST\raw\t10k-images-idx3-ubyte.gz




  0%|                                                                                      | 0/1648877 [00:01<?, ?it/s][A[A

  1%|▋                                                                      | 16384/1648877 [00:01<00:54, 30108.23it/s][A[A

  1%|█                                                                      | 24576/1648877 [00:02<00:54, 29550.20it/s][A[A

  2%|█▍                                                                     | 32768/1648877 [00:02<00:59, 27259.12it/s][A[A

  2%|█▊                                                                     | 40960/1648877 [00:02<00:58, 27260.54it/s][A[A

  3%|██                                                                     | 49152/1648877 [00:03<00:59, 26860.35it/s][A[A

  3%|██▍                                                                    | 57344/1648877 [00:03<00:54, 29271.30it/s][A[A

  4%|██▊                                                                    | 65536/1648877 [00:04<01:12, 218

Extracting ./dataset\MNIST\raw\t10k-images-idx3-ubyte.gz to ./dataset\MNIST\raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./dataset\MNIST\raw\t10k-labels-idx1-ubyte.gz





8192it [00:01, 4273.89it/s]                                                                                            [A[A[A


Extracting ./dataset\MNIST\raw\t10k-labels-idx1-ubyte.gz to ./dataset\MNIST\raw
Processing...
Done!




1654784it [02:19, 162231.86it/s]                                                                                       [A[A

## Defining the Discriminator Network

In [8]:
class Discriminator(torch.nn.Module):
    '''
    Fully connected network that maps a 784-feature input to one sigmoid output.

    Three hidden blocks (Linear -> LeakyReLU(0.2) -> Dropout(0.3)) shrink the
    width 784 -> 1024 -> 512 -> 256, followed by a Linear + Sigmoid head that
    yields a single value in [0, 1] per sample.
    '''
    def __init__(self):
        super(Discriminator, self).__init__()
        INPUT_FEATURES = 784
        OUTPUT_NODES   = 1

        self.layer1 = nn.Sequential(
                        nn.Linear(INPUT_FEATURES, 1024),
                        nn.LeakyReLU(0.2),
                        nn.Dropout(0.3)
                        )

        self.layer2 = nn.Sequential(
                        nn.Linear(1024, 512),
                        nn.LeakyReLU(0.2),
                        nn.Dropout(0.3)
                        )

        self.layer3 = nn.Sequential(
                        nn.Linear(512, 256),
                        nn.LeakyReLU(0.2),
                        nn.Dropout(0.3)
                        )

        self.output = nn.Sequential(
                        nn.Linear(256, OUTPUT_NODES),
                        nn.Sigmoid()
                        )

    # forward must be a method of the class (taking self) for nn.Module to
    # dispatch calls such as `model(x)` to it.
    def forward(self, x):
        '''
        Runs the input through the three hidden blocks and the output head.

        Args:
            x: tensor of shape (N, 784).

        Returns:
            Tensor of shape (N, 1) with values in [0, 1].
        '''
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.output(x)
        return x

## Defining the Generative Network

In [None]:
class Generative(torch.nn.Module):
    '''
    Fully connected network that maps a 100-feature input to 784 tanh outputs.

    Three hidden blocks (Linear -> LeakyReLU(0.2)) grow the width
    100 -> 256 -> 512 -> 1024, followed by a Linear + Tanh head that yields
    784 values in [-1, 1] per sample.
    '''
    def __init__(self):
        super(Generative, self).__init__()
        INPUT_FEATURES = 100
        OUTPUT_NODES   = 784

        self.layer1 = nn.Sequential(
                        nn.Linear(INPUT_FEATURES, 256),
                        nn.LeakyReLU(0.2)
                        )

        self.layer2 = nn.Sequential(
                        nn.Linear(256, 512),
                        nn.LeakyReLU(0.2)
                        )

        self.layer3 = nn.Sequential(
                        nn.Linear(512, 1024),
                        nn.LeakyReLU(0.2)
                        )

        self.output = nn.Sequential(
                        nn.Linear(1024, OUTPUT_NODES),
                        nn.Tanh()
                        )

    # forward must be a method of the class (taking self) for nn.Module to
    # dispatch calls such as `model(z)` to it.
    def forward(self, x):
        '''
        Runs the input through the three hidden blocks and the output head.

        Args:
            x: tensor of shape (N, 100).

        Returns:
            Tensor of shape (N, 784) with values in [-1, 1].
        '''
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.output(x)
        return x