In [17]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.init as init
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import os
from torchvision.utils import save_image

In [18]:
# Create a directory if not exists
if not os.path.exists(sample_dir):
    os.makedirs(sample_dir)

In [5]:
batch_size = 256
learning_rate = 0.001
num_epoch = 10

In [6]:
mnist_train = dset.MNIST("./", train=True, transform=transforms.ToTensor(), target_transform=None, download=True)
mnist_test = dset.MNIST("./", train=False, transform=transforms.ToTensor(), target_transform=None, download=True)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./MNIST/raw/train-images-idx3-ubyte.gz


100.1%

Extracting ./MNIST/raw/train-images-idx3-ubyte.gz to ./MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./MNIST/raw/train-labels-idx1-ubyte.gz


113.5%

Extracting ./MNIST/raw/train-labels-idx1-ubyte.gz to ./MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./MNIST/raw/t10k-images-idx3-ubyte.gz


100.4%

Extracting ./MNIST/raw/t10k-images-idx3-ubyte.gz to ./MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./MNIST/raw/t10k-labels-idx1-ubyte.gz


180.4%

Extracting ./MNIST/raw/t10k-labels-idx1-ubyte.gz to ./MNIST/raw
Processing...
Done!


  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


In [7]:
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyper-parameters
latent_size = 64
hidden_size = 256
image_size = 784
num_epochs = 200
batch_size = 100


  return torch._C._cuda_getDeviceCount() > 0


In [8]:
train_loader = torch.utils.data.DataLoader(mnist_train,batch_size=batch_size, shuffle=True,num_workers=2,drop_last=True)
test_loader = torch.utils.data.DataLoader(mnist_test,batch_size=batch_size, shuffle=False,num_workers=2,drop_last=True)

In [9]:
# Discriminator
D = nn.Sequential(
    nn.Linear(image_size, hidden_size),
    nn.LeakyReLU(0.2),
    nn.Linear(hidden_size, hidden_size),
    nn.LeakyReLU(0.2),
    nn.Linear(hidden_size, 1),
    nn.Sigmoid())

# Generator 
G = nn.Sequential(
    nn.Linear(latent_size, hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size, hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size, image_size),
    nn.Tanh())

In [10]:
# Device setting
D = D.to(device)
G = G.to(device)

# Binary cross entropy loss and optimizer
criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0002)
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0002)

In [11]:
def denorm(x):
    out = (x + 1) / 2
    return out.clamp(0, 1)

def reset_grad():
    d_optimizer.zero_grad()
    g_optimizer.zero_grad()

In [19]:
# Start training
total_step = len(train_loader)
for epoch in range(num_epochs):
    for i, (images, _) in enumerate(train_loader):
        images = images.reshape(batch_size, -1).to(device)
        
        # Create the labels which are later used as input for the BCE loss
        real_labels = torch.ones(batch_size, 1).to(device)
        fake_labels = torch.zeros(batch_size, 1).to(device)

        # ================================================================== #
        #                      Train the discriminator                       #
        # ================================================================== #

        # Compute BCE_Loss using real images where BCE_Loss(x, y): - y * log(D(x)) - (1-y) * log(1 - D(x))
        # Second term of the loss is always zero since real_labels == 1
        outputs = D(images)
        d_loss_real = criterion(outputs, real_labels)
        real_score = outputs
        
        # Compute BCELoss using fake images
        # First term of the loss is always zero since fake_labels == 0
        z = torch.randn(batch_size, latent_size).to(device)
        fake_images = G(z)
        outputs = D(fake_images)
        d_loss_fake = criterion(outputs, fake_labels)
        fake_score = outputs
        
        # Backprop and optimize
        d_loss = d_loss_real + d_loss_fake
        reset_grad()
        d_loss.backward()
        d_optimizer.step()
        
        # ================================================================== #
        #                        Train the generator                         #
        # ================================================================== #

        # Compute loss with fake images
        z = torch.randn(batch_size, latent_size).to(device)
        fake_images = G(z)
        outputs = D(fake_images)
        
        # We train G to maximize log(D(G(z)) instead of minimizing log(1-D(G(z)))
        # For the reason, see the last paragraph of section 3. https://arxiv.org/pdf/1406.2661.pdf
        g_loss = criterion(outputs, real_labels)
        
        # Backprop and optimize
        reset_grad()
        g_loss.backward()
        g_optimizer.step()
        
        if (i+1) % 200 == 0:
            print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}' 
                  .format(epoch, num_epochs, i+1, total_step, d_loss.item(), g_loss.item(), 
                          real_score.mean().item(), fake_score.mean().item()))
    
    # Save real images
    if (epoch+1) == 1:
        images = images.reshape(images.size(0), 1, 28, 28)
        save_image(denorm(images), os.path.join(sample_dir, 'real_images.png'))
    
    # Save sampled images
    fake_images = fake_images.reshape(fake_images.size(0), 1, 28, 28)
    save_image(denorm(fake_images), os.path.join(sample_dir, 'fake_images-{}.png'.format(epoch+1)))



Epoch [0/200], Step [200/600], d_loss: 0.8435, g_loss: 2.1398, D(x): 0.68, D(G(z)): 0.25
Epoch [0/200], Step [400/600], d_loss: 0.4781, g_loss: 2.0817, D(x): 0.87, D(G(z)): 0.25
Epoch [0/200], Step [600/600], d_loss: 0.6044, g_loss: 2.7023, D(x): 0.80, D(G(z)): 0.20
Epoch [1/200], Step [200/600], d_loss: 0.3841, g_loss: 2.8319, D(x): 0.88, D(G(z)): 0.15
Epoch [1/200], Step [400/600], d_loss: 0.1320, g_loss: 3.0393, D(x): 0.97, D(G(z)): 0.08
Epoch [1/200], Step [600/600], d_loss: 0.1103, g_loss: 3.5090, D(x): 0.95, D(G(z)): 0.04
Epoch [2/200], Step [200/600], d_loss: 0.2212, g_loss: 4.4116, D(x): 0.90, D(G(z)): 0.06
Epoch [2/200], Step [400/600], d_loss: 0.1708, g_loss: 7.8073, D(x): 0.93, D(G(z)): 0.03
Epoch [2/200], Step [600/600], d_loss: 0.2290, g_loss: 4.1177, D(x): 0.94, D(G(z)): 0.09
Epoch [3/200], Step [200/600], d_loss: 0.2843, g_loss: 3.9767, D(x): 0.92, D(G(z)): 0.08
Epoch [3/200], Step [400/600], d_loss: 0.0629, g_loss: 5.5980, D(x): 0.98, D(G(z)): 0.03
Epoch [3/200], Step [

Epoch [30/200], Step [600/600], d_loss: 0.1274, g_loss: 7.2565, D(x): 0.95, D(G(z)): 0.02
Epoch [31/200], Step [200/600], d_loss: 0.1631, g_loss: 9.0017, D(x): 0.95, D(G(z)): 0.01
Epoch [31/200], Step [400/600], d_loss: 0.0532, g_loss: 6.8048, D(x): 0.99, D(G(z)): 0.02
Epoch [31/200], Step [600/600], d_loss: 0.0670, g_loss: 6.7751, D(x): 0.99, D(G(z)): 0.05
Epoch [32/200], Step [200/600], d_loss: 0.0644, g_loss: 7.8982, D(x): 0.97, D(G(z)): 0.01
Epoch [32/200], Step [400/600], d_loss: 0.0390, g_loss: 6.3682, D(x): 0.98, D(G(z)): 0.01
Epoch [32/200], Step [600/600], d_loss: 0.0699, g_loss: 7.2290, D(x): 0.97, D(G(z)): 0.01
Epoch [33/200], Step [200/600], d_loss: 0.1272, g_loss: 6.6487, D(x): 0.97, D(G(z)): 0.05
Epoch [33/200], Step [400/600], d_loss: 0.0276, g_loss: 6.8470, D(x): 0.99, D(G(z)): 0.02
Epoch [33/200], Step [600/600], d_loss: 0.0544, g_loss: 7.9243, D(x): 0.97, D(G(z)): 0.01
Epoch [34/200], Step [200/600], d_loss: 0.0415, g_loss: 8.3903, D(x): 0.98, D(G(z)): 0.00
Epoch [34/

Epoch [61/200], Step [400/600], d_loss: 0.1049, g_loss: 6.7182, D(x): 0.98, D(G(z)): 0.05
Epoch [61/200], Step [600/600], d_loss: 0.1527, g_loss: 5.8578, D(x): 0.98, D(G(z)): 0.07
Epoch [62/200], Step [200/600], d_loss: 0.0314, g_loss: 6.2382, D(x): 0.98, D(G(z)): 0.01
Epoch [62/200], Step [400/600], d_loss: 0.0863, g_loss: 6.6353, D(x): 0.97, D(G(z)): 0.03
Epoch [62/200], Step [600/600], d_loss: 0.0788, g_loss: 7.4992, D(x): 0.96, D(G(z)): 0.01
Epoch [63/200], Step [200/600], d_loss: 0.1099, g_loss: 6.8768, D(x): 0.96, D(G(z)): 0.03
Epoch [63/200], Step [400/600], d_loss: 0.1221, g_loss: 7.1033, D(x): 0.98, D(G(z)): 0.07
Epoch [63/200], Step [600/600], d_loss: 0.0762, g_loss: 6.5134, D(x): 0.99, D(G(z)): 0.05
Epoch [64/200], Step [200/600], d_loss: 0.0619, g_loss: 7.0185, D(x): 0.98, D(G(z)): 0.02
Epoch [64/200], Step [400/600], d_loss: 0.1104, g_loss: 7.2423, D(x): 0.98, D(G(z)): 0.06
Epoch [64/200], Step [600/600], d_loss: 0.0415, g_loss: 6.3719, D(x): 0.98, D(G(z)): 0.02
Epoch [65/

Epoch [92/200], Step [200/600], d_loss: 0.1667, g_loss: 5.3766, D(x): 0.94, D(G(z)): 0.06
Epoch [92/200], Step [400/600], d_loss: 0.2034, g_loss: 4.8308, D(x): 0.94, D(G(z)): 0.08
Epoch [92/200], Step [600/600], d_loss: 0.1295, g_loss: 4.7230, D(x): 0.96, D(G(z)): 0.06
Epoch [93/200], Step [200/600], d_loss: 0.1644, g_loss: 5.3340, D(x): 0.95, D(G(z)): 0.04
Epoch [93/200], Step [400/600], d_loss: 0.2400, g_loss: 5.1778, D(x): 0.92, D(G(z)): 0.06
Epoch [93/200], Step [600/600], d_loss: 0.2325, g_loss: 5.0930, D(x): 0.89, D(G(z)): 0.03
Epoch [94/200], Step [200/600], d_loss: 0.1328, g_loss: 5.2200, D(x): 0.96, D(G(z)): 0.06
Epoch [94/200], Step [400/600], d_loss: 0.1248, g_loss: 5.6021, D(x): 0.97, D(G(z)): 0.07
Epoch [94/200], Step [600/600], d_loss: 0.1393, g_loss: 5.9512, D(x): 0.94, D(G(z)): 0.03
Epoch [95/200], Step [200/600], d_loss: 0.1727, g_loss: 5.0618, D(x): 0.94, D(G(z)): 0.06
Epoch [95/200], Step [400/600], d_loss: 0.0848, g_loss: 5.8091, D(x): 0.97, D(G(z)): 0.04
Epoch [95/

Epoch [122/200], Step [400/600], d_loss: 0.1825, g_loss: 5.2426, D(x): 0.94, D(G(z)): 0.07
Epoch [122/200], Step [600/600], d_loss: 0.1845, g_loss: 3.5461, D(x): 0.93, D(G(z)): 0.08
Epoch [123/200], Step [200/600], d_loss: 0.3190, g_loss: 5.1424, D(x): 0.95, D(G(z)): 0.12
Epoch [123/200], Step [400/600], d_loss: 0.3631, g_loss: 4.4159, D(x): 0.88, D(G(z)): 0.08
Epoch [123/200], Step [600/600], d_loss: 0.1980, g_loss: 4.3460, D(x): 0.91, D(G(z)): 0.06
Epoch [124/200], Step [200/600], d_loss: 0.3237, g_loss: 4.8942, D(x): 0.89, D(G(z)): 0.06
Epoch [124/200], Step [400/600], d_loss: 0.2985, g_loss: 5.2325, D(x): 0.88, D(G(z)): 0.07
Epoch [124/200], Step [600/600], d_loss: 0.2699, g_loss: 5.3754, D(x): 0.92, D(G(z)): 0.09
Epoch [125/200], Step [200/600], d_loss: 0.2224, g_loss: 4.4776, D(x): 0.89, D(G(z)): 0.04
Epoch [125/200], Step [400/600], d_loss: 0.2778, g_loss: 5.4525, D(x): 0.88, D(G(z)): 0.03
Epoch [125/200], Step [600/600], d_loss: 0.2428, g_loss: 4.4189, D(x): 0.93, D(G(z)): 0.09

Epoch [152/200], Step [600/600], d_loss: 0.2261, g_loss: 4.0652, D(x): 0.93, D(G(z)): 0.09
Epoch [153/200], Step [200/600], d_loss: 0.3377, g_loss: 4.9911, D(x): 0.87, D(G(z)): 0.09
Epoch [153/200], Step [400/600], d_loss: 0.2234, g_loss: 4.2265, D(x): 0.93, D(G(z)): 0.10
Epoch [153/200], Step [600/600], d_loss: 0.2709, g_loss: 4.2836, D(x): 0.94, D(G(z)): 0.12
Epoch [154/200], Step [200/600], d_loss: 0.4216, g_loss: 3.8422, D(x): 0.84, D(G(z)): 0.06
Epoch [154/200], Step [400/600], d_loss: 0.2647, g_loss: 4.2118, D(x): 0.92, D(G(z)): 0.08
Epoch [154/200], Step [600/600], d_loss: 0.5325, g_loss: 3.8726, D(x): 0.95, D(G(z)): 0.23
Epoch [155/200], Step [200/600], d_loss: 0.2695, g_loss: 3.8230, D(x): 0.91, D(G(z)): 0.10
Epoch [155/200], Step [400/600], d_loss: 0.2855, g_loss: 4.5045, D(x): 0.92, D(G(z)): 0.10
Epoch [155/200], Step [600/600], d_loss: 0.2740, g_loss: 4.0927, D(x): 0.89, D(G(z)): 0.08
Epoch [156/200], Step [200/600], d_loss: 0.2595, g_loss: 3.8336, D(x): 0.93, D(G(z)): 0.10

Epoch [183/200], Step [200/600], d_loss: 0.3495, g_loss: 3.9578, D(x): 0.90, D(G(z)): 0.12
Epoch [183/200], Step [400/600], d_loss: 0.3408, g_loss: 4.2706, D(x): 0.85, D(G(z)): 0.07
Epoch [183/200], Step [600/600], d_loss: 0.4588, g_loss: 3.7730, D(x): 0.92, D(G(z)): 0.20
Epoch [184/200], Step [200/600], d_loss: 0.2343, g_loss: 4.2822, D(x): 0.91, D(G(z)): 0.08
Epoch [184/200], Step [400/600], d_loss: 0.3386, g_loss: 4.0082, D(x): 0.94, D(G(z)): 0.15
Epoch [184/200], Step [600/600], d_loss: 0.3573, g_loss: 3.9543, D(x): 0.87, D(G(z)): 0.10
Epoch [185/200], Step [200/600], d_loss: 0.5586, g_loss: 3.8090, D(x): 0.79, D(G(z)): 0.11
Epoch [185/200], Step [400/600], d_loss: 0.4086, g_loss: 4.7703, D(x): 0.86, D(G(z)): 0.08
Epoch [185/200], Step [600/600], d_loss: 0.2836, g_loss: 3.7219, D(x): 0.94, D(G(z)): 0.12
Epoch [186/200], Step [200/600], d_loss: 0.4046, g_loss: 4.2986, D(x): 0.85, D(G(z)): 0.11
Epoch [186/200], Step [400/600], d_loss: 0.3105, g_loss: 3.5741, D(x): 0.87, D(G(z)): 0.07

In [13]:
# Save the model checkpoints 
torch.save(G.state_dict(), 'G.ckpt')
torch.save(D.state_dict(), 'D.ckpt')