In [1]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets
from torchvision import transforms
from torchvision.utils import save_image
from torch.autograd import Variable

In [8]:
def to_var(x):
    if torch.cuda.is_available():
        x = x.cuda()
    return Variable(x)
def denorm(x):
    out = (x+1)/2
    return out.clamp(0,1)

In [3]:
# Image processing
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,0.5,0.5),
                        std = (0.5,0.5,0.5))])


In [4]:
# MNIST dataset
mnist = datasets.MNIST(root='../data/',
                      train=True,
                      transform=transform,
                      download=True)
# Data loader
data_loader = torch.utils.data.DataLoader(dataset=mnist,
                                         batch_size=100,
                                         shuffle=True)

In [5]:
# Discriminator
D = nn.Sequential(
    nn.Linear(784,256),
    nn.LeakyReLU(0.2),
    nn.Linear(256,256),
    nn.LeakyReLU(0.2),
    nn.Linear(256,1),
    nn.Sigmoid())
# Generator
G = nn.Sequential(
    nn.Linear(64,256),
    nn.LeakyReLU(0.2),
    nn.Linear(256,256),
    nn.LeakyReLU(0.2),
    nn.Linear(256,784),
    nn.Tanh())
if torch.cuda.is_available():
    D.cuda()
    G.cuda()

In [6]:
# Binary cross entropy loss and optimizer
criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(D.parameters(),lr=0.0003)
g_optimizer = torch.optim.Adam(G.parameters(),lr=0.0003)

In [14]:
# Start Training
for epoch in range(200):
    for i,(images,_)in enumerate(data_loader):
        # Build mini-batch dataset
        batch_size = images.size(0)
        images = to_var(images.view(batch_size,-1))
        real_labels = to_var(torch.ones(batch_size))
        fake_labels = to_var(torch.zeros(batch_size))
        
        #======================Train the discrimitor=============#
        # Computer loss with real images
        outputs = D(images)
        d_loss_real = criterion(outputs,real_labels)
        real_score = outputs
        
        # compute loss with fake images
        z = to_var(torch.randn(batch_size,64))
        fake_images = G(z)
        outputs = D(fake_images)
        d_loss_fake = criterion(outputs,fake_labels)
        fake_score = outputs
        
        # backprop + Optimize
        d_loss = d_loss_real + d_loss_fake
        D.zero_grad()
        d_loss.backward()
        d_optimizer.step()
        
        #===================Train the generator===================#
        # Computer loss with fake images
        z = to_var(torch.randn(batch_size,64))
        fake_images = G(z)
        outputs = D(fake_images)
        g_loss = criterion(outputs,real_labels)
        
        # Backprop + Optimize
        D.zero_grad()
        G.zero_grad()
        g_loss.backward()
        g_optimizer.step()
        
        if(i+1)% 300 ==0:
            print('Epoch [%d/%d],Step[%d/%d],d_loss:%.4f, g_loss:%.4f,D(x):%.2f,D(G(z)):%.2f'
                 % (epoch,200,i+1,600,d_loss.data[0],g_loss.data[0],real_score.data.mean(),fake_score.data.mean()))
    # Save real images
    if(epoch+1)==1:
        images = images.view(images.size(0),1,28,28)
        save_image(denorm(images.data),'../images/real_images.png')
    # Save sampled images
    fake_images = fake_images.view(fake_images.size(0),1,28,28)
    save_image(denorm(fake_images.data),'../images/fake_images-%d.png'%(epoch+1))

Epoch [0/200],Step[300/600],d_loss:0.6234, g_loss:2.3670,D(x):0.79,D(G(z)):0.25
Epoch [0/200],Step[600/600],d_loss:0.6449, g_loss:5.0696,D(x):0.77,D(G(z)):0.14
Epoch [1/200],Step[300/600],d_loss:1.6345, g_loss:2.6599,D(x):0.65,D(G(z)):0.47
Epoch [1/200],Step[600/600],d_loss:1.3311, g_loss:1.0783,D(x):0.65,D(G(z)):0.43
Epoch [2/200],Step[300/600],d_loss:0.9720, g_loss:1.1220,D(x):0.73,D(G(z)):0.44
Epoch [2/200],Step[600/600],d_loss:2.1859, g_loss:0.9923,D(x):0.39,D(G(z)):0.50
Epoch [3/200],Step[300/600],d_loss:0.3760, g_loss:2.4420,D(x):0.85,D(G(z)):0.16
Epoch [3/200],Step[600/600],d_loss:2.1222, g_loss:1.7192,D(x):0.47,D(G(z)):0.55
Epoch [4/200],Step[300/600],d_loss:0.6718, g_loss:3.2403,D(x):0.77,D(G(z)):0.22
Epoch [4/200],Step[600/600],d_loss:1.2558, g_loss:1.3773,D(x):0.62,D(G(z)):0.36
Epoch [5/200],Step[300/600],d_loss:1.6368, g_loss:1.1408,D(x):0.49,D(G(z)):0.46
Epoch [5/200],Step[600/600],d_loss:1.0178, g_loss:2.0896,D(x):0.61,D(G(z)):0.28
Epoch [6/200],Step[300/600],d_loss:1.399

Epoch [51/200],Step[300/600],d_loss:1.0255, g_loss:1.5405,D(x):0.67,D(G(z)):0.31
Epoch [51/200],Step[600/600],d_loss:0.9503, g_loss:1.6531,D(x):0.67,D(G(z)):0.29
Epoch [52/200],Step[300/600],d_loss:0.9208, g_loss:1.6262,D(x):0.74,D(G(z)):0.33
Epoch [52/200],Step[600/600],d_loss:0.9338, g_loss:1.5726,D(x):0.69,D(G(z)):0.26
Epoch [53/200],Step[300/600],d_loss:0.9830, g_loss:1.7337,D(x):0.65,D(G(z)):0.27
Epoch [53/200],Step[600/600],d_loss:0.7764, g_loss:1.6641,D(x):0.71,D(G(z)):0.26
Epoch [54/200],Step[300/600],d_loss:0.6787, g_loss:1.7954,D(x):0.80,D(G(z)):0.30
Epoch [54/200],Step[600/600],d_loss:0.8345, g_loss:1.7479,D(x):0.74,D(G(z)):0.31
Epoch [55/200],Step[300/600],d_loss:0.7980, g_loss:1.6103,D(x):0.74,D(G(z)):0.25
Epoch [55/200],Step[600/600],d_loss:0.9030, g_loss:1.9118,D(x):0.69,D(G(z)):0.26
Epoch [56/200],Step[300/600],d_loss:0.7530, g_loss:1.6481,D(x):0.78,D(G(z)):0.29
Epoch [56/200],Step[600/600],d_loss:1.0073, g_loss:1.8752,D(x):0.62,D(G(z)):0.26
Epoch [57/200],Step[300/600]

Epoch [102/200],Step[300/600],d_loss:0.8524, g_loss:1.8159,D(x):0.77,D(G(z)):0.32
Epoch [102/200],Step[600/600],d_loss:0.8424, g_loss:1.2771,D(x):0.71,D(G(z)):0.27
Epoch [103/200],Step[300/600],d_loss:0.8944, g_loss:1.6614,D(x):0.71,D(G(z)):0.29
Epoch [103/200],Step[600/600],d_loss:0.8649, g_loss:1.5770,D(x):0.70,D(G(z)):0.27
Epoch [104/200],Step[300/600],d_loss:0.9164, g_loss:1.7395,D(x):0.70,D(G(z)):0.28
Epoch [104/200],Step[600/600],d_loss:0.7496, g_loss:1.7423,D(x):0.71,D(G(z)):0.22
Epoch [105/200],Step[300/600],d_loss:0.9663, g_loss:1.4489,D(x):0.71,D(G(z)):0.35
Epoch [105/200],Step[600/600],d_loss:0.8657, g_loss:1.5713,D(x):0.69,D(G(z)):0.28
Epoch [106/200],Step[300/600],d_loss:0.9943, g_loss:1.4290,D(x):0.70,D(G(z)):0.35
Epoch [106/200],Step[600/600],d_loss:0.8036, g_loss:1.9721,D(x):0.75,D(G(z)):0.28
Epoch [107/200],Step[300/600],d_loss:0.9004, g_loss:1.7372,D(x):0.68,D(G(z)):0.25
Epoch [107/200],Step[600/600],d_loss:0.9421, g_loss:1.6292,D(x):0.62,D(G(z)):0.23
Epoch [108/200],

Epoch [152/200],Step[300/600],d_loss:0.7658, g_loss:1.9150,D(x):0.78,D(G(z)):0.30
Epoch [152/200],Step[600/600],d_loss:0.8719, g_loss:1.9717,D(x):0.69,D(G(z)):0.24
Epoch [153/200],Step[300/600],d_loss:0.9775, g_loss:1.6470,D(x):0.75,D(G(z)):0.34
Epoch [153/200],Step[600/600],d_loss:0.8524, g_loss:1.6357,D(x):0.69,D(G(z)):0.27
Epoch [154/200],Step[300/600],d_loss:0.8000, g_loss:1.6660,D(x):0.72,D(G(z)):0.23
Epoch [154/200],Step[600/600],d_loss:1.1299, g_loss:1.7950,D(x):0.72,D(G(z)):0.41
Epoch [155/200],Step[300/600],d_loss:0.8779, g_loss:1.8537,D(x):0.67,D(G(z)):0.25
Epoch [155/200],Step[600/600],d_loss:0.7828, g_loss:1.7690,D(x):0.72,D(G(z)):0.24
Epoch [156/200],Step[300/600],d_loss:0.9477, g_loss:1.6534,D(x):0.68,D(G(z)):0.29
Epoch [156/200],Step[600/600],d_loss:0.8099, g_loss:1.5407,D(x):0.77,D(G(z)):0.31
Epoch [157/200],Step[300/600],d_loss:0.9359, g_loss:1.7328,D(x):0.66,D(G(z)):0.24
Epoch [157/200],Step[600/600],d_loss:0.8195, g_loss:1.6322,D(x):0.75,D(G(z)):0.30
Epoch [158/200],

In [15]:
# Save the trained parameters
torch.save(G.state_dict(),'../checkpoint/generator.pkl')
torch.save(D.state_dict(),'../checkpoint/discriminator.pkl')