In [1]:
# prerequisites
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torchvision import datasets, transforms
from torchvision.utils import save_image

# Device configuration: prefer the GPU when CUDA is available, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [2]:
bs = 100  # mini-batch size for the loaders; later cells also read it

# MNIST Dataset: ToTensor yields pixels in [0, 1]; Normalize((0.5,), (0.5,)) then maps
# them to [-1, 1], the same range as the generator's tanh output.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

train_dataset = datasets.MNIST(
    root='./mnist_data/', train=True, transform=transform, download=True)
# download=False: relies on the train-split download above having fetched the files.
test_dataset = datasets.MNIST(
    root='./mnist_data/', train=False, transform=transform, download=False)

# Data Loader (Input Pipeline): only the training split is shuffled.
train_loader, test_loader = (
    torch.utils.data.DataLoader(dataset=split, batch_size=bs, shuffle=shuffle)
    for split, shuffle in ((train_dataset, True), (test_dataset, False))
)

In [3]:
class Generator(nn.Module):
    """Fully-connected generator mapping a latent vector to a flattened image.

    Args:
        g_input_dim: size of the latent (noise) input vector.
        g_output_dim: number of output features (the flattened image size).
    """

    def __init__(self, g_input_dim, g_output_dim):
        super().__init__()
        base = 256
        # Hidden width doubles at every layer: 256 -> 512 -> 1024, then projects out.
        self.fc1 = nn.Linear(g_input_dim, base)
        self.fc2 = nn.Linear(base, base * 2)
        self.fc3 = nn.Linear(base * 2, base * 4)
        self.fc4 = nn.Linear(base * 4, g_output_dim)

    def forward(self, x):
        """Return generated samples squashed into [-1, 1] by a final tanh."""
        h = x
        for hidden_layer in (self.fc1, self.fc2, self.fc3):
            h = F.leaky_relu(hidden_layer(h), 0.2)
        return torch.tanh(self.fc4(h))
    
class Discriminator(nn.Module):
    """Fully-connected discriminator scoring flattened images as real (-> 1) or fake (-> 0).

    Args:
        d_input_dim: number of input features (the flattened image size).
    """

    def __init__(self, d_input_dim):
        super(Discriminator, self).__init__()
        self.fc1 = nn.Linear(d_input_dim, 1024)  # d_input_dim -> 1024
        self.fc2 = nn.Linear(self.fc1.out_features, self.fc1.out_features // 2)  # 1024 -> 512
        self.fc3 = nn.Linear(self.fc2.out_features, self.fc2.out_features // 2)  # 512 -> 256
        self.fc4 = nn.Linear(self.fc3.out_features, 1)  # 256 -> 1

    # forward method
    def forward(self, x):
        """Return sigmoid probabilities, one per input row (assumes x is (batch, d_input_dim)).

        Dropout is tied to ``self.training``: ``F.dropout`` defaults to
        ``training=True``, so without it dropout would stay active after
        ``D.eval()`` and scores would never be deterministic.
        """
        x = F.leaky_relu(self.fc1(x), 0.2)
        x = F.dropout(x, 0.3, training=self.training)
        x = F.leaky_relu(self.fc2(x), 0.2)
        x = F.dropout(x, 0.3, training=self.training)
        x = F.leaky_relu(self.fc3(x), 0.2)
        x = F.dropout(x, 0.3, training=self.training)
        return torch.sigmoid(self.fc4(x))

In [4]:
# build network
z_dim = 100  # size of the latent noise vector fed to the generator
# Flattened image size (28 * 28 = 784 for MNIST). `.data` replaces the deprecated
# `.train_data` attribute, which torchvision keeps only as a warning-emitting alias.
mnist_dim = train_dataset.data.size(1) * train_dataset.data.size(2)

G = Generator(g_input_dim=z_dim, g_output_dim=mnist_dim).to(device)
D = Discriminator(mnist_dim).to(device)



In [5]:
# Binary cross-entropy on D's sigmoid outputs; shared by the D and G objectives.
criterion = nn.BCELoss()

# One Adam optimizer per network, both with the same learning rate.
lr = 2e-4
G_optimizer, D_optimizer = [optim.Adam(net.parameters(), lr=lr) for net in (G, D)]

In [6]:
def D_train(x):
    """Run one discriminator update on a batch of real images.

    Args:
        x: batch of real images from ``train_loader``; it is flattened to
           ``(batch, mnist_dim)`` before being scored.

    Returns:
        The discriminator loss (real BCE + fake BCE) as a Python float.
    """
    #=======================Train the discriminator=======================#
    D.zero_grad()
    # Use the real batch size so a smaller final batch cannot mismatch the targets.
    batch_size = x.size(0)

    # train discriminator on real (target 1)
    x_real = x.view(-1, mnist_dim).to(device)
    y_real = torch.ones(batch_size, 1, device=device)
    D_real_loss = criterion(D(x_real), y_real)

    # train discriminator on fake (target 0)
    z = torch.randn(batch_size, z_dim, device=device)
    # detach(): D's loss must not backpropagate into G. This avoids wasted compute
    # and keeps stray gradients out of G's parameters.
    x_fake = G(z).detach()
    y_fake = torch.zeros(batch_size, 1, device=device)
    D_fake_loss = criterion(D(x_fake), y_fake)

    # gradient backprop & optimize ONLY D's parameters
    D_loss = D_real_loss + D_fake_loss
    D_loss.backward()
    D_optimizer.step()

    return D_loss.item()

In [7]:
def G_train(x):
    """Run one generator update.

    Args:
        x: the current batch of real images. Only its batch size is used, so G
           produces as many samples as D just scored.

    Returns:
        The generator loss as a Python float.
    """
    #=======================Train the generator=======================#
    G.zero_grad()
    batch_size = x.size(0)

    z = torch.randn(batch_size, z_dim, device=device)
    # Label the fakes as real (1): BCE then gives -log D(G(z)), the non-saturating G loss.
    y = torch.ones(batch_size, 1, device=device)

    G_output = G(z)  # (batch, mnist_dim)
    D_output = D(G_output)  # (batch, 1)
    G_loss = criterion(D_output, y)

    # gradient backprop & optimize ONLY G's parameters. D also receives gradients
    # here, but D_train calls D.zero_grad() before D's next step.
    G_loss.backward()
    G_optimizer.step()

    return G_loss.item()

In [9]:
n_epoch = 301
sample_dir = './samples'
# save_image fails if the target directory is missing, so create it before training.
os.makedirs(sample_dir, exist_ok=True)

for epoch in range(n_epoch):
    D_losses, G_losses = [], []
    for x, _ in train_loader:
        D_losses.append(D_train(x))
        G_losses.append(G_train(x))

    print('[%d/%d]: loss_d: %.3f, loss_g: %.3f' % (
            epoch, n_epoch, torch.mean(torch.FloatTensor(D_losses)), torch.mean(torch.FloatTensor(G_losses))))

    # Every 100 epochs, save a grid of samples drawn from fresh noise.
    if epoch % 100 == 0:
        with torch.no_grad():
            test_z = torch.randn(bs, z_dim, device=device)
            generated = G(test_z)
            # G ends in tanh, so outputs are in [-1, 1]. Undo Normalize((0.5,), (0.5,))
            # so save_image (which expects [0, 1]) does not clip every value below 0 to black.
            images = generated.view(generated.size(0), 1, 28, 28) * 0.5 + 0.5
            save_image(images, os.path.join(sample_dir, 'result_' + str(epoch) + '.png'))

[0/301]: loss_d: 1.143, loss_g: 1.412
[1/301]: loss_d: 1.060, loss_g: 2.083
[2/301]: loss_d: 0.972, loss_g: 1.863
[3/301]: loss_d: 0.874, loss_g: 1.718
[4/301]: loss_d: 0.564, loss_g: 2.707
[5/301]: loss_d: 0.575, loss_g: 2.600
[6/301]: loss_d: 0.662, loss_g: 2.298
[7/301]: loss_d: 0.690, loss_g: 2.229
[8/301]: loss_d: 0.695, loss_g: 2.136
[9/301]: loss_d: 0.753, loss_g: 2.088
[10/301]: loss_d: 0.751, loss_g: 2.116
[11/301]: loss_d: 0.666, loss_g: 2.339
[12/301]: loss_d: 0.745, loss_g: 2.122
[13/301]: loss_d: 0.766, loss_g: 1.977
[14/301]: loss_d: 0.783, loss_g: 1.937
[15/301]: loss_d: 0.872, loss_g: 1.743
[16/301]: loss_d: 0.896, loss_g: 1.685
[17/301]: loss_d: 0.859, loss_g: 1.713
[18/301]: loss_d: 0.877, loss_g: 1.671
[19/301]: loss_d: 0.847, loss_g: 1.777
[20/301]: loss_d: 0.856, loss_g: 1.718
[21/301]: loss_d: 0.897, loss_g: 1.669
[22/301]: loss_d: 0.914, loss_g: 1.596
[23/301]: loss_d: 0.953, loss_g: 1.490
[24/301]: loss_d: 0.975, loss_g: 1.462
[25/301]: loss_d: 1.014, loss_g: 1.