In [41]:
# prerequisites
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torchvision.utils import save_image

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed torch's RNG (CPU and CUDA) so weight init, dropout masks, latent noise and
# DataLoader shuffling are repeatable across Restart & Run All.
# NOTE: some CUDA kernels are still nondeterministic, so GPU runs may differ slightly.
SEED = 0
torch.manual_seed(SEED)

In [42]:
bs = 100  # mini-batch size; also used by the training helpers and sampling cells

# MNIST Dataset: ToTensor yields pixels in [0, 1]; Normalize(mean=0.5, std=0.5)
# maps them to [-1, 1], the same range as the generator's tanh output.
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])

# Only the training split requests a download; the test split is opened from the
# same root (presumably populated by that download — confirm on a fresh machine).
train_dataset, test_dataset = (
    datasets.MNIST(root='./mnist_data/', train=is_train, transform=transform, download=is_train)
    for is_train in (True, False)
)

# Data Loader (Input Pipeline): shuffle the training batches, keep test order fixed.
train_loader, test_loader = (
    torch.utils.data.DataLoader(dataset=split, batch_size=bs, shuffle=shuffle)
    for split, shuffle in ((train_dataset, True), (test_dataset, False))
)

In [43]:
class Generator(nn.Module):
    """Fully connected generator mapping latent vectors to flattened images.

    Hidden widths double from 256 to 512 to 1024 before the final projection to
    ``g_output_dim``; a tanh squashes the output into [-1, 1].

    Args:
        g_input_dim: size of the latent vector.
        g_output_dim: size of the flattened output image.
    """

    def __init__(self, g_input_dim, g_output_dim):
        super().__init__()
        base = 256
        # Attribute names fc1..fc4 are kept so saved state_dict checkpoints still load.
        self.fc1 = nn.Linear(g_input_dim, base)
        self.fc2 = nn.Linear(base, base * 2)
        self.fc3 = nn.Linear(base * 2, base * 4)
        self.fc4 = nn.Linear(base * 4, g_output_dim)

    def forward(self, x):
        """Apply three LeakyReLU(0.2) hidden layers, then a tanh output layer."""
        for hidden in (self.fc1, self.fc2, self.fc3):
            x = F.leaky_relu(hidden(x), 0.2)
        return torch.tanh(self.fc4(x))
    
class Discriminator(nn.Module):
    """Fully connected discriminator scoring flattened images as real (1) or fake (0).

    Hidden widths halve from 1024 to 512 to 256; each hidden layer uses LeakyReLU(0.2)
    followed by dropout (p=0.3). The output is a sigmoid probability of shape (N, 1).

    Args:
        d_input_dim: size of the flattened input image.
    """

    def __init__(self, d_input_dim):
        super(Discriminator, self).__init__()
        self.fc1 = nn.Linear(d_input_dim, 1024)
        self.fc2 = nn.Linear(self.fc1.out_features, self.fc1.out_features//2)
        self.fc3 = nn.Linear(self.fc2.out_features, self.fc2.out_features//2)
        self.fc4 = nn.Linear(self.fc3.out_features, 1)

    # forward method
    def forward(self, x):
        """Return the probability that each row of ``x`` is a real image."""
        # F.dropout defaults to training=True; tie it to the module mode so that
        # D.eval() actually disables dropout (deterministic scoring).
        x = F.leaky_relu(self.fc1(x), 0.2)
        x = F.dropout(x, 0.3, training=self.training)
        x = F.leaky_relu(self.fc2(x), 0.2)
        x = F.dropout(x, 0.3, training=self.training)
        x = F.leaky_relu(self.fc3(x), 0.2)
        x = F.dropout(x, 0.3, training=self.training)
        return torch.sigmoid(self.fc4(x))

In [44]:
# build network
z_dim = 100
mnist_dim = train_dataset.train_data.size(1) * train_dataset.train_data.size(2)

G = Generator(g_input_dim = z_dim, g_output_dim = mnist_dim).to(device)
D = Discriminator(mnist_dim).to(device)

In [45]:
G  # display the generator's layer summary (fc1-fc4)

Generator(
  (fc1): Linear(in_features=100, out_features=256, bias=True)
  (fc2): Linear(in_features=256, out_features=512, bias=True)
  (fc3): Linear(in_features=512, out_features=1024, bias=True)
  (fc4): Linear(in_features=1024, out_features=784, bias=True)
)

In [46]:
D  # display the discriminator's layer summary (fc1-fc4)

Discriminator(
  (fc1): Linear(in_features=784, out_features=1024, bias=True)
  (fc2): Linear(in_features=1024, out_features=512, bias=True)
  (fc3): Linear(in_features=512, out_features=256, bias=True)
  (fc4): Linear(in_features=256, out_features=1, bias=True)
)

In [47]:
# Binary cross-entropy on D's sigmoid output drives both the D and G objectives.
criterion = nn.BCELoss()

# Separate Adam optimizers (default betas) with a shared learning rate, so each
# training step only updates the parameters of its own network.
lr = 0.0002
G_optimizer, D_optimizer = (optim.Adam(net.parameters(), lr = lr) for net in (G, D))

In [48]:
def D_train(x):
    """Run one discriminator update on a batch of real images.

    Args:
        x: a batch of real images from ``train_loader``; it is flattened to
            ``(-1, mnist_dim)`` before being scored.

    Returns:
        float: the summed BCE loss on the real and generated samples.
    """
    #=======================Train the discriminator=======================#
    D.zero_grad()

    # Size labels and noise from the actual batch, so a short final batch still works.
    x_real = x.view(-1, mnist_dim).to(device)
    n = x_real.size(0)

    # train discriminator on real: real images should be scored as 1
    y_real = torch.ones(n, 1, device=device)
    D_real_loss = criterion(D(x_real), y_real)

    # train discriminator on fake: generated images should be scored as 0.
    # detach() stops this loss from back-propagating into G, which is not updated here.
    z = torch.randn(n, z_dim).to(device)
    x_fake = G(z).detach()
    y_fake = torch.zeros(n, 1, device=device)
    D_fake_loss = criterion(D(x_fake), y_fake)

    # gradient backprop & optimize ONLY D's parameters
    D_loss = D_real_loss + D_fake_loss
    D_loss.backward()
    D_optimizer.step()

    return D_loss.item()

In [49]:
def G_train(x):
    """Run one generator update.

    Args:
        x: the current real batch; only its batch size is used, so G generates as
            many samples as the matching discriminator step.

    Returns:
        float: the generator BCE loss (D's output on fakes vs. the "real" label).
    """
    #=======================Train the generator=======================#
    G.zero_grad()

    n = x.size(0)
    z = torch.randn(n, z_dim).to(device)
    # G is rewarded when D labels its samples as real (1).
    y = torch.ones(n, 1, device=device)

    G_output = G(z)
    D_output = D(G_output)
    G_loss = criterion(D_output, y)

    # gradient backprop & optimize ONLY G's parameters. Gradients also accumulate
    # in D here, but D_train calls D.zero_grad() before D's next step.
    G_loss.backward()
    G_optimizer.step()

    return G_loss.item()

In [50]:
n_epoch = 101
# save_image and torch.save below write into samples/; create it so a fresh run does not fail.
os.makedirs('samples', exist_ok=True)
for epoch in range(1, n_epoch+1):
    D_losses, G_losses = [], []
    for x, _ in train_loader:
        D_losses.append(D_train(x))
        G_losses.append(G_train(x))

    print('[%d/%d]: loss_d: %.3f, loss_g: %.3f' % (
            (epoch), n_epoch, torch.mean(torch.FloatTensor(D_losses)), torch.mean(torch.FloatTensor(G_losses))))

    # Save a 10x10 grid of fresh samples for this epoch (no gradients needed).
    with torch.no_grad():
        test_z = torch.randn(bs, z_dim).to(device)
        generated = G(test_z)
        save_image(generated.view(generated.size(0), 1, 28, 28), 'samples/sample_%d.png' % epoch, nrow=10, normalize=True)

    # Checkpoint at epochs 1, 51, 101; the file name carries the epoch that was just trained.
    if (epoch-1) % 50 == 0:
        torch.save(G.state_dict(), os.path.join('samples', 'G--{}.ckpt'.format(epoch)))
        torch.save(D.state_dict(), os.path.join('samples', 'D--{}.ckpt'.format(epoch)))

[1/101]: loss_d: 0.801, loss_g: 3.637
[2/101]: loss_d: 0.704, loss_g: 4.215
[3/101]: loss_d: 0.851, loss_g: 2.454
[4/101]: loss_d: 0.587, loss_g: 2.946
[5/101]: loss_d: 0.468, loss_g: 3.058
[6/101]: loss_d: 0.420, loss_g: 3.313
[7/101]: loss_d: 0.504, loss_g: 2.880
[8/101]: loss_d: 0.581, loss_g: 2.549
[9/101]: loss_d: 0.552, loss_g: 2.650
[10/101]: loss_d: 0.639, loss_g: 2.361
[11/101]: loss_d: 0.699, loss_g: 2.164
[12/101]: loss_d: 0.707, loss_g: 2.093
[13/101]: loss_d: 0.791, loss_g: 1.957
[14/101]: loss_d: 0.804, loss_g: 1.913
[15/101]: loss_d: 0.778, loss_g: 1.933
[16/101]: loss_d: 0.794, loss_g: 1.888
[17/101]: loss_d: 0.798, loss_g: 1.938
[18/101]: loss_d: 0.846, loss_g: 1.781
[19/101]: loss_d: 0.870, loss_g: 1.702
[20/101]: loss_d: 0.873, loss_g: 1.669
[21/101]: loss_d: 0.877, loss_g: 1.692
[22/101]: loss_d: 0.877, loss_g: 1.670
[23/101]: loss_d: 0.909, loss_g: 1.628
[24/101]: loss_d: 0.938, loss_g: 1.536
[25/101]: loss_d: 0.981, loss_g: 1.478
[26/101]: loss_d: 0.961, loss_g: 1

In [51]:
# Save one final 10x10 grid of samples from the trained generator.
os.makedirs('./samples', exist_ok=True)
with torch.no_grad():
    test_z = torch.randn(bs, z_dim).to(device)
    generated = G(test_z)

    # G ends in tanh (output in [-1, 1]) while save_image expects [0, 1];
    # normalize=True rescales, matching the per-epoch samples from training.
    save_image(generated.view(generated.size(0), 1, 28, 28), './samples/sample_' + '.png', nrow=10, normalize=True)


In [80]:
# Fixed batch of latent vectors reused (via clone) by the perturbation experiment below.
# Variable is a no-op wrapper since PyTorch 0.4, so a plain tensor is used.
ori_sample = torch.randn(bs, z_dim).to(device)

In [84]:
# Perturb fixed latent dimensions of ori_sample and save generated image #13 at each step.
# Cloning first leaves ori_sample untouched, so re-running this cell is repeatable.
sample = ori_sample.clone()
sample[:,0:100:4] -= 1  # initial offset on dims 0, 4, ..., 96
sample[:,1:100:5] += 1  # initial offset on dims 1, 6, ..., 96
with torch.no_grad():  # inference only: avoid building an autograd graph every step
    for ii in range(0, 21):
        # Sweep only dims below 50 back by 0.1 per step (21 steps in total).
        # NOTE(review): dims >= 50 keep their initial +/-1 offset — confirm this is intended.
        sample[:,0:50:4] += 0.1
        sample[:,1:51:5] -= 0.1
        print(torch.mean(sample))
        generated = G(sample)
        # normalize=True maps G's tanh output into [0, 1], as in the training-loop samples.
        save_image(generated.view(generated.size(0), 1, 28, 28)[13,:],
                   'samples/pertubation_test_' + str(ii) + '.png', nrow=10, normalize=True)




tensor(-0.0597)
tensor(-0.0567)
tensor(-0.0537)
tensor(-0.0507)
tensor(-0.0477)
tensor(-0.0447)
tensor(-0.0417)
tensor(-0.0387)
tensor(-0.0357)
tensor(-0.0327)
tensor(-0.0297)
tensor(-0.0267)
tensor(-0.0237)
tensor(-0.0207)
tensor(-0.0177)
tensor(-0.0147)
tensor(-0.0117)
tensor(-0.0087)
tensor(-0.0057)
tensor(-0.0027)
tensor(0.0003)
