## Importing Packages

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd.variable import Variable
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import imageio

## Loading the MNIST Dataset

In [2]:
# Map pixel values from [0, 1] (ToTensor) to [-1, 1] so real images share the
# value range of the generator's Tanh output.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])
# Converts an image tensor back into a PIL image (used when exporting the GIF).
to_image = transforms.ToPILImage()
trainset = MNIST(root='./data/', train=True, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=100, shuffle=True)

# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA instead of failing at the first `.to(device)`.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

## Creating the Generator & Discriminator Classes

In [3]:
class Generator(nn.Module):
    """Fully connected generator that turns a latent vector into an image.

    A 128-dimensional input is widened through three Linear + LeakyReLU(0.2)
    stages (256, 512, 1024 units) and projected to 784 values squashed by
    Tanh, then reshaped to a (N, 1, 28, 28) tensor.
    """

    def __init__(self):
        super().__init__()
        self.n_features = 128
        self.n_out = 784

        def hidden(in_dim, out_dim):
            # One hidden stage: affine map followed by a leaky ReLU.
            return nn.Sequential(nn.Linear(in_dim, out_dim), nn.LeakyReLU(0.2))

        # Stages are created in fc0..fc3 order so parameter names and
        # initialisation order match the layer order.
        self.fc0 = hidden(self.n_features, 256)
        self.fc1 = hidden(256, 512)
        self.fc2 = hidden(512, 1024)
        self.fc3 = nn.Sequential(nn.Linear(1024, self.n_out), nn.Tanh())

    def forward(self, x):
        """Map latent vectors of width ``n_features`` to (N, 1, 28, 28) images in [-1, 1]."""
        for stage in (self.fc0, self.fc1, self.fc2, self.fc3):
            x = stage(x)
        return x.view(-1, 1, 28, 28)

class Discriminator(nn.Module):
    """Fully connected discriminator scoring images with a value in (0, 1).

    The input is flattened to ``n_in`` (784) features and passed through three
    Linear + LeakyReLU(0.2) + Dropout(0.3) stages (1024, 512, 256 units),
    followed by a single-unit Linear layer with a Sigmoid.
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        self.n_in = 784
        self.n_out = 1
        self.fc0 = nn.Sequential(
                    nn.Linear(self.n_in, 1024),
                    nn.LeakyReLU(0.2),
                    nn.Dropout(0.3)
                    )
        self.fc1 = nn.Sequential(
                    nn.Linear(1024, 512),
                    nn.LeakyReLU(0.2),
                    nn.Dropout(0.3)
                    )
        self.fc2 = nn.Sequential(
                    nn.Linear(512, 256),
                    nn.LeakyReLU(0.2),
                    nn.Dropout(0.3)
                    )
        self.fc3 = nn.Sequential(
                    nn.Linear(256, self.n_out),
                    nn.Sigmoid()
                    )

    def forward(self, x):
        """Return a (N, 1) tensor of scores for a batch of images.

        Args:
            x: Tensor whose elements can be regrouped into rows of ``n_in``
                values, e.g. (N, 1, 28, 28) or (N, 784).
        """
        # reshape (rather than view) also accepts non-contiguous inputs such as
        # transposed batches; self.n_in keeps the flatten width tied to the
        # first Linear layer instead of repeating the literal 784.
        x = x.reshape(-1, self.n_in)
        x = self.fc0(x)
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        return x

## Creating Models, Optimizers, BCE Loss & Helper Tensors

In [4]:
# Build both networks (generator first, then discriminator) and place them on
# the training device.
generator = Generator().to(device)
discriminator = Discriminator().to(device)

# Binary cross-entropy applied to the discriminator's sigmoid output.
criterion = nn.BCELoss()

# One Adam optimizer per network, both using the same learning rate.
g_optim = optim.Adam(generator.parameters(), lr=2e-4)
d_optim = optim.Adam(discriminator.parameters(), lr=2e-4)

# Per-epoch history filled in by the training loop: average losses and a
# grid of generated samples.
g_losses, d_losses, images = [], [], []

def noise(n, n_features=128):
    """Sample a batch of standard-normal latent vectors on the training device.

    Args:
        n: Number of vectors to draw.
        n_features: Length of each vector; the default of 128 matches the
            generator's input width.

    Returns:
        A float tensor of shape (n, n_features) located on ``device``.
    """
    # Sample directly on the target device: this avoids a CPU allocation plus
    # a host-to-device copy on every call, and drops the deprecated Variable
    # wrapper (plain tensors already work with autograd).
    return torch.randn(n, n_features, device=device)

def make_ones(size):
    """Return a (size, 1) tensor of ones on ``device``, used as "real" targets.

    Args:
        size: Number of rows, i.e. the batch size the targets are paired with.
    """
    # Allocate on the target device directly instead of building the tensor on
    # the CPU, wrapping it in the deprecated Variable, and copying it over.
    return torch.ones(size, 1, device=device)

def make_zeros(size):
    """Return a (size, 1) tensor of zeros on ``device``, used as "fake" targets.

    Args:
        size: Number of rows, i.e. the batch size the targets are paired with.
    """
    # Allocate on the target device directly instead of building the tensor on
    # the CPU, wrapping it in the deprecated Variable, and copying it over.
    return torch.zeros(size, 1, device=device)

## Creating Functions to Train Generator and Discriminator

In [5]:
def train_discriminator(optimizer, real_data, fake_data):
    """Run one discriminator update on a real batch and a generated batch.

    Args:
        optimizer: Optimizer that updates the discriminator's parameters.
        real_data: Batch of real images; its first dimension determines how
            many target labels are created for both batches.
        fake_data: Batch of generated images of the same batch size. The
            training loop passes it detached, so the backward pass here does
            not reach the generator.

    Returns:
        A scalar tensor holding the sum of the real-batch and fake-batch BCE
        losses, detached from the autograd graph.
    """
    n = real_data.size(0)

    optimizer.zero_grad()

    # Real images are pushed towards the label 1.
    prediction_real = discriminator(real_data)
    error_real = criterion(prediction_real, make_ones(n))
    error_real.backward()

    # Generated images are pushed towards the label 0; gradients from both
    # backward calls accumulate before the single optimizer step.
    prediction_fake = discriminator(fake_data)
    error_fake = criterion(prediction_fake, make_zeros(n))
    error_fake.backward()

    optimizer.step()

    # Return a detached value: the caller only logs/accumulates it, and a
    # graph-attached sum would keep both autograd graphs alive for as long as
    # the running total exists.
    return (error_real + error_fake).detach()

def train_generator(optimizer, fake_data):
    """Run one generator update using the discriminator's verdict on fakes.

    Args:
        optimizer: Optimizer that updates the generator's parameters.
        fake_data: Batch of generated images that is still attached to the
            generator's autograd graph, so the loss can backpropagate into it.

    Returns:
        A scalar tensor holding the generator's BCE loss, detached from the
        autograd graph.
    """
    n = fake_data.size(0)
    optimizer.zero_grad()

    # The generator is rewarded when the discriminator labels its samples as
    # real, hence the all-ones targets.
    prediction = discriminator(fake_data)
    error = criterion(prediction, make_ones(n))

    error.backward()
    optimizer.step()

    # Detach so callers that accumulate the loss do not retain the graph.
    return error.detach()

## Training the Networks

In [9]:
num_epochs = 250
k = 1  # discriminator updates per generator update
test_noise = noise(64)  # fixed latent batch so the per-epoch sample grids are comparable

generator.train()
discriminator.train()
for epoch in range(num_epochs):
    g_error = 0.0
    d_error = 0.0
    n_batches = 0
    print('epoch :',epoch)
    for imgs, _ in trainloader:
        n = len(imgs)
        # The real batch does not change across the k discriminator steps, so
        # move it to the device once.
        real_data = imgs.to(device)
        for _step in range(k):
            # Detach so the discriminator update does not backpropagate into
            # the generator.
            fake_data = generator(noise(n)).detach()
            # .detach() keeps the running total free of autograd history.
            d_error += train_discriminator(d_optim, real_data, fake_data).detach()
        fake_data = generator(noise(n))
        g_error += train_generator(g_optim, fake_data).detach()
        n_batches += 1

    # Snapshot of the generator on the fixed noise; no gradients are needed.
    with torch.no_grad():
        img = generator(test_noise).cpu()
    images.append(make_grid(img))

    # Average over the number of updates actually performed. Dividing by the
    # last batch index would be off by one, and the discriminator performs k
    # updates per batch. max(..., 1) guards against an empty loader.
    g_loss = float(g_error) / max(n_batches, 1)
    d_loss = float(d_error) / max(n_batches * k, 1)
    g_losses.append(g_loss)
    d_losses.append(d_loss)
    print('Epoch {}: g_loss: {:.8f} d_loss: {:.8f}\r'.format(epoch, g_loss, d_loss))

print('Training Finished')
torch.save(generator.state_dict(), 'mnist_generator.pth')

epoch : 0
Epoch 0: g_loss: 1.81372511 d_loss: 0.82537574
epoch : 1
Epoch 1: g_loss: 1.62572193 d_loss: 0.89407593
epoch : 2
Epoch 2: g_loss: 1.62729192 d_loss: 0.89458758
epoch : 3
Epoch 3: g_loss: 1.58510482 d_loss: 0.91025281
epoch : 4
Epoch 4: g_loss: 1.61445010 d_loss: 0.89686686
epoch : 5
Epoch 5: g_loss: 1.49626625 d_loss: 0.95805651
epoch : 6
Epoch 6: g_loss: 1.53261042 d_loss: 0.94009465
epoch : 7
Epoch 7: g_loss: 1.61846125 d_loss: 0.89994645
epoch : 8
Epoch 8: g_loss: 1.46427107 d_loss: 0.96931857
epoch : 9
Epoch 9: g_loss: 1.39981699 d_loss: 0.99399817
epoch : 10
Epoch 10: g_loss: 1.47141731 d_loss: 0.97583044
epoch : 11
Epoch 11: g_loss: 1.50972044 d_loss: 0.95519459
epoch : 12
Epoch 12: g_loss: 1.50925016 d_loss: 0.96258694
epoch : 13
Epoch 13: g_loss: 1.44958758 d_loss: 0.97162688
epoch : 14
Epoch 14: g_loss: 1.35003328 d_loss: 1.02520931
epoch : 15
Epoch 15: g_loss: 1.35375488 d_loss: 1.02467811
epoch : 16
Epoch 16: g_loss: 1.34986687 d_loss: 1.02283359
epoch : 17
Epoch 

Epoch 138: g_loss: 0.88277233 d_loss: 1.28530073
epoch : 139
Epoch 139: g_loss: 0.87827438 d_loss: 1.28813374
epoch : 140
Epoch 140: g_loss: 0.88712174 d_loss: 1.28225029
epoch : 141
Epoch 141: g_loss: 0.88274938 d_loss: 1.28116226
epoch : 142
Epoch 142: g_loss: 0.88291293 d_loss: 1.28964508
epoch : 143
Epoch 143: g_loss: 0.87464017 d_loss: 1.28876066
epoch : 144
Epoch 144: g_loss: 0.86869216 d_loss: 1.29263985
epoch : 145
Epoch 145: g_loss: 0.87818778 d_loss: 1.28494143
epoch : 146
Epoch 146: g_loss: 0.88480002 d_loss: 1.28607011
epoch : 147
Epoch 147: g_loss: 0.88023472 d_loss: 1.29062140
epoch : 148
Epoch 148: g_loss: 0.88072199 d_loss: 1.28811395
epoch : 149
Epoch 149: g_loss: 0.88218266 d_loss: 1.28882337
epoch : 150
Epoch 150: g_loss: 0.87344891 d_loss: 1.29368985
epoch : 151
Epoch 151: g_loss: 0.87263119 d_loss: 1.29173255
epoch : 152
Epoch 152: g_loss: 0.88275421 d_loss: 1.28550565
epoch : 153
Epoch 153: g_loss: 0.87430954 d_loss: 1.28538072
epoch : 154
Epoch 154: g_loss: 0.878

In [10]:
import numpy as np
from matplotlib import pyplot as plt
# The stored grids are in the generator's [-1, 1] range, while ToPILImage
# scales float tensors as if they were in [0, 1]. Rescale (and clamp) first so
# negative pixel values are not corrupted by the conversion to 8-bit.
imgs = [np.array(to_image(((grid + 1) / 2).clamp(0, 1))) for grid in images]
imageio.mimsave('progress.gif', imgs)

