In [1]:
import os
import torch
import torchvision
import torch.nn as nn
from torchvision import transforms
from torchvision.utils import save_image
from torch.autograd import Variable


In [2]:
# Hyper-parameters
latent_size = 64
hidden_size = 256
image_size = 784
num_epochs = 200
batch_size = 100
sample_dir = 'samples/mymnist_no_normalize'

In [3]:

# Create a directory if not exists
if not os.path.exists(sample_dir):
    os.makedirs(sample_dir)
    print("created folder")
#Image processing
# MNIST dataset (images and labels)
train_dataset = torchvision.datasets.MNIST(root='../../data', 
                                           train=True, 
                                           transform=transforms.ToTensor(),
                                           download=True)

test_dataset = torchvision.datasets.MNIST(root='../../data', 
                                          train=False, 
                                          transform=transforms.ToTensor())

# Data loader (input pipeline)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, 
                                           batch_size=batch_size, 
                                           shuffle=True)
# Data loader (input pipeline)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, 
                                          batch_size=batch_size, 
                                          shuffle=False)

In [4]:
class discriminator(nn.Module):
    def __init__(self):
        super(discriminator,self).__init__()
        self.layer1=nn.Sequential(
            nn.Linear(image_size,hidden_size),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_size,hidden_size),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_size,1),
            nn.Sigmoid())
    def forward(self,x):
#         x=x.view(batch_size,-1)
        out=self.layer1(x)
#         out=out.view(batch_size,-1,28,28)
        return out
class generator(nn.Module):
    def __init__(self):
        super(generator,self).__init__()
        self.layer1=nn.Sequential(nn.Linear(latent_size,hidden_size),
                                  nn.ReLU(),
                                  nn.Linear(hidden_size,hidden_size),
                                  nn.ReLU(),
                                  nn.Linear(hidden_size,image_size),
                                  nn.Tanh())
    def forward(self,x):
#         x=x.view(batch_size,-1)
        out=self.layer1(x)
#         out=out.view(batch_size,-1,28,28)
        return out

In [5]:
D=discriminator().cuda()
G=generator().cuda()

In [6]:

loss_function=nn.BCELoss()
d_optimizer=torch.optim.Adam(D.parameters(),lr=2e-4)
g_optimizer=torch.optim.Adam(G.parameters(),lr=2e-4)
def denorm(x):
    out=(x+1)/2
    return out.clamp(0,1)
def reset_grad():
    d_optimizer.zero_grad()
    g_optimizer.zero_grad()

total_step=len(train_loader)
for epochs in range(num_epochs):
    for i, (images,_)in enumerate(train_loader):
        images = images.reshape(batch_size, -1)
        images=Variable(images).cuda()
#         print("images shape : ",images.shape)
        
#         images=Variable(images).cuda()
        
        real_labels = torch.ones(batch_size, 1).cuda()
        fake_labels = torch.zeros(batch_size, 1).cuda()
        
        ##discriminator##
        outputs=D.forward(images)
        d_loss_real=loss_function(outputs,real_labels)
        real_score=outputs
        
        z=torch.randn(batch_size,latent_size).cuda()
        fake_images=G.forward(z)
        outputs=D.forward(fake_images)
        d_loss_fake=loss_function(outputs,fake_labels)
        fake_score=outputs
        
        d_loss=d_loss_real+d_loss_fake
        reset_grad()
        d_loss.backward()
        d_optimizer.step()
        
        
        ##generator##
        z = torch.randn(batch_size, latent_size).cuda()
        fake_images = G.forward(z)
        outputs = D.forward(fake_images)
        
        # We train G to maximize log(D(G(z)) instead of minimizing log(1-D(G(z)))
        # For the reason, see the last paragraph of section 3. https://arxiv.org/pdf/1406.2661.pdf
        g_loss = loss_function(outputs, real_labels)
        
        # Backprop and optimize
        reset_grad()
        g_loss.backward()
        g_optimizer.step()
        
        if (i+1) % 200 == 0:
            print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}' 
                  .format(epochs, num_epochs, i+1, total_step, d_loss.item(), g_loss.item(), 
                          real_score.mean().item(), fake_score.mean().item()))
    
    # Save real images
    if (epochs+1) == 1:
        images = images.reshape(images.size(0), 1, 28, 28)
        save_image(images, os.path.join(sample_dir, 'real_images.png'))
    
    # Save sampled images
    fake_images = fake_images.reshape(fake_images.size(0), 1, 28, 28)
    save_image(fake_images, os.path.join(sample_dir, 'fake_images-{}.png'.format(epochs+1)))
    if epochs%50==0:        
        # Save the model checkpoints 
        torch.save(G.state_dict(), './saved_data/mymnist_no_normalize/G_mnist-{}.ckpt'.format(epochs+1))
        torch.save(D.state_dict(), './saved_data/mymnist_no_normalize/D_mnist-{}.ckpt'.format(epochs+1))    
print("training finished!")

Epoch [0/200], Step [200/600], d_loss: 0.0992, g_loss: 4.3716, D(x): 0.94, D(G(z)): 0.04
Epoch [0/200], Step [400/600], d_loss: 0.1684, g_loss: 5.2993, D(x): 0.94, D(G(z)): 0.09
Epoch [0/200], Step [600/600], d_loss: 0.2402, g_loss: 3.4396, D(x): 0.91, D(G(z)): 0.12
Epoch [1/200], Step [200/600], d_loss: 0.1493, g_loss: 3.5891, D(x): 0.95, D(G(z)): 0.08
Epoch [1/200], Step [400/600], d_loss: 0.3307, g_loss: 2.8284, D(x): 0.88, D(G(z)): 0.15
Epoch [1/200], Step [600/600], d_loss: 1.4584, g_loss: 1.3360, D(x): 0.57, D(G(z)): 0.47
Epoch [2/200], Step [200/600], d_loss: 0.2470, g_loss: 3.1583, D(x): 0.90, D(G(z)): 0.11
Epoch [2/200], Step [400/600], d_loss: 0.3220, g_loss: 2.9576, D(x): 0.86, D(G(z)): 0.11
Epoch [2/200], Step [600/600], d_loss: 0.2651, g_loss: 2.8623, D(x): 0.90, D(G(z)): 0.10
Epoch [3/200], Step [200/600], d_loss: 0.2741, g_loss: 3.2151, D(x): 0.88, D(G(z)): 0.07
Epoch [3/200], Step [400/600], d_loss: 0.5055, g_loss: 2.8829, D(x): 0.82, D(G(z)): 0.12
Epoch [3/200], Step [

Epoch [30/200], Step [600/600], d_loss: 0.0277, g_loss: 7.6945, D(x): 0.99, D(G(z)): 0.01
Epoch [31/200], Step [200/600], d_loss: 0.0724, g_loss: 5.4254, D(x): 0.98, D(G(z)): 0.02
Epoch [31/200], Step [400/600], d_loss: 0.0177, g_loss: 7.8128, D(x): 0.99, D(G(z)): 0.01
Epoch [31/200], Step [600/600], d_loss: 0.0878, g_loss: 6.2918, D(x): 0.99, D(G(z)): 0.04
Epoch [32/200], Step [200/600], d_loss: 0.0593, g_loss: 6.7530, D(x): 0.98, D(G(z)): 0.02
Epoch [32/200], Step [400/600], d_loss: 0.1101, g_loss: 5.8640, D(x): 0.96, D(G(z)): 0.03
Epoch [32/200], Step [600/600], d_loss: 0.0933, g_loss: 6.7390, D(x): 0.95, D(G(z)): 0.01
Epoch [33/200], Step [200/600], d_loss: 0.0828, g_loss: 5.6088, D(x): 0.98, D(G(z)): 0.04
Epoch [33/200], Step [400/600], d_loss: 0.1356, g_loss: 6.4481, D(x): 0.94, D(G(z)): 0.02
Epoch [33/200], Step [600/600], d_loss: 0.1973, g_loss: 6.2196, D(x): 0.92, D(G(z)): 0.01
Epoch [34/200], Step [200/600], d_loss: 0.0899, g_loss: 6.6481, D(x): 0.97, D(G(z)): 0.03
Epoch [34/

Epoch [61/200], Step [400/600], d_loss: 0.1062, g_loss: 5.6986, D(x): 0.95, D(G(z)): 0.03
Epoch [61/200], Step [600/600], d_loss: 0.0433, g_loss: 6.7850, D(x): 0.97, D(G(z)): 0.01
Epoch [62/200], Step [200/600], d_loss: 0.3062, g_loss: 6.1203, D(x): 0.90, D(G(z)): 0.04
Epoch [62/200], Step [400/600], d_loss: 0.1546, g_loss: 7.1739, D(x): 0.94, D(G(z)): 0.03
Epoch [62/200], Step [600/600], d_loss: 0.1525, g_loss: 5.4063, D(x): 0.95, D(G(z)): 0.03
Epoch [63/200], Step [200/600], d_loss: 0.1531, g_loss: 5.4588, D(x): 0.95, D(G(z)): 0.06
Epoch [63/200], Step [400/600], d_loss: 0.1306, g_loss: 4.9130, D(x): 0.97, D(G(z)): 0.07
Epoch [63/200], Step [600/600], d_loss: 0.1379, g_loss: 5.6477, D(x): 0.97, D(G(z)): 0.04
Epoch [64/200], Step [200/600], d_loss: 0.0645, g_loss: 6.4385, D(x): 0.97, D(G(z)): 0.02
Epoch [64/200], Step [400/600], d_loss: 0.1342, g_loss: 6.2625, D(x): 0.94, D(G(z)): 0.03
Epoch [64/200], Step [600/600], d_loss: 0.1517, g_loss: 5.9939, D(x): 0.95, D(G(z)): 0.05
Epoch [65/

Epoch [92/200], Step [200/600], d_loss: 0.3248, g_loss: 4.0983, D(x): 0.92, D(G(z)): 0.11
Epoch [92/200], Step [400/600], d_loss: 0.1834, g_loss: 4.5705, D(x): 0.94, D(G(z)): 0.06
Epoch [92/200], Step [600/600], d_loss: 0.2208, g_loss: 4.3761, D(x): 0.94, D(G(z)): 0.09
Epoch [93/200], Step [200/600], d_loss: 0.2975, g_loss: 5.1934, D(x): 0.89, D(G(z)): 0.04
Epoch [93/200], Step [400/600], d_loss: 0.1438, g_loss: 4.9112, D(x): 0.96, D(G(z)): 0.07
Epoch [93/200], Step [600/600], d_loss: 0.2318, g_loss: 4.9976, D(x): 0.92, D(G(z)): 0.07
Epoch [94/200], Step [200/600], d_loss: 0.2703, g_loss: 5.2297, D(x): 0.92, D(G(z)): 0.08
Epoch [94/200], Step [400/600], d_loss: 0.1827, g_loss: 4.0689, D(x): 0.92, D(G(z)): 0.06
Epoch [94/200], Step [600/600], d_loss: 0.1834, g_loss: 5.1473, D(x): 0.95, D(G(z)): 0.08
Epoch [95/200], Step [200/600], d_loss: 0.2227, g_loss: 4.8565, D(x): 0.93, D(G(z)): 0.09
Epoch [95/200], Step [400/600], d_loss: 0.2940, g_loss: 5.1980, D(x): 0.89, D(G(z)): 0.04
Epoch [95/

Epoch [122/200], Step [400/600], d_loss: 0.2555, g_loss: 3.9821, D(x): 0.96, D(G(z)): 0.15
Epoch [122/200], Step [600/600], d_loss: 0.3594, g_loss: 2.8170, D(x): 0.94, D(G(z)): 0.18
Epoch [123/200], Step [200/600], d_loss: 0.2055, g_loss: 4.2121, D(x): 0.94, D(G(z)): 0.08
Epoch [123/200], Step [400/600], d_loss: 0.2263, g_loss: 3.8649, D(x): 0.94, D(G(z)): 0.11
Epoch [123/200], Step [600/600], d_loss: 0.2471, g_loss: 4.2896, D(x): 0.91, D(G(z)): 0.09
Epoch [124/200], Step [200/600], d_loss: 0.2480, g_loss: 4.8761, D(x): 0.91, D(G(z)): 0.09
Epoch [124/200], Step [400/600], d_loss: 0.4369, g_loss: 3.4862, D(x): 0.84, D(G(z)): 0.10
Epoch [124/200], Step [600/600], d_loss: 0.2967, g_loss: 4.1490, D(x): 0.88, D(G(z)): 0.06
Epoch [125/200], Step [200/600], d_loss: 0.2136, g_loss: 4.2738, D(x): 0.92, D(G(z)): 0.08
Epoch [125/200], Step [400/600], d_loss: 0.4389, g_loss: 5.0767, D(x): 0.89, D(G(z)): 0.10
Epoch [125/200], Step [600/600], d_loss: 0.3236, g_loss: 3.9386, D(x): 0.87, D(G(z)): 0.08

Epoch [152/200], Step [600/600], d_loss: 0.2352, g_loss: 3.7178, D(x): 0.91, D(G(z)): 0.08
Epoch [153/200], Step [200/600], d_loss: 0.4163, g_loss: 3.7907, D(x): 0.84, D(G(z)): 0.10
Epoch [153/200], Step [400/600], d_loss: 0.2737, g_loss: 4.0776, D(x): 0.92, D(G(z)): 0.11
Epoch [153/200], Step [600/600], d_loss: 0.2929, g_loss: 3.8470, D(x): 0.89, D(G(z)): 0.09
Epoch [154/200], Step [200/600], d_loss: 0.3684, g_loss: 3.7482, D(x): 0.86, D(G(z)): 0.09
Epoch [154/200], Step [400/600], d_loss: 0.4129, g_loss: 3.7666, D(x): 0.92, D(G(z)): 0.16
Epoch [154/200], Step [600/600], d_loss: 0.4356, g_loss: 3.5622, D(x): 0.91, D(G(z)): 0.17
Epoch [155/200], Step [200/600], d_loss: 0.4678, g_loss: 4.2068, D(x): 0.82, D(G(z)): 0.10
Epoch [155/200], Step [400/600], d_loss: 0.2659, g_loss: 3.7377, D(x): 0.93, D(G(z)): 0.10
Epoch [155/200], Step [600/600], d_loss: 0.3891, g_loss: 3.8845, D(x): 0.91, D(G(z)): 0.16
Epoch [156/200], Step [200/600], d_loss: 0.2967, g_loss: 3.8455, D(x): 0.92, D(G(z)): 0.14

Epoch [183/200], Step [200/600], d_loss: 0.3621, g_loss: 3.1314, D(x): 0.86, D(G(z)): 0.12
Epoch [183/200], Step [400/600], d_loss: 0.3689, g_loss: 3.9336, D(x): 0.90, D(G(z)): 0.14
Epoch [183/200], Step [600/600], d_loss: 0.3381, g_loss: 3.5289, D(x): 0.88, D(G(z)): 0.10
Epoch [184/200], Step [200/600], d_loss: 0.4093, g_loss: 3.9425, D(x): 0.84, D(G(z)): 0.08
Epoch [184/200], Step [400/600], d_loss: 0.3813, g_loss: 4.1597, D(x): 0.88, D(G(z)): 0.13
Epoch [184/200], Step [600/600], d_loss: 0.2984, g_loss: 4.2489, D(x): 0.93, D(G(z)): 0.14
Epoch [185/200], Step [200/600], d_loss: 0.5106, g_loss: 4.2111, D(x): 0.83, D(G(z)): 0.09
Epoch [185/200], Step [400/600], d_loss: 0.4657, g_loss: 3.8376, D(x): 0.84, D(G(z)): 0.08
Epoch [185/200], Step [600/600], d_loss: 0.3851, g_loss: 3.6006, D(x): 0.91, D(G(z)): 0.16
Epoch [186/200], Step [200/600], d_loss: 0.3437, g_loss: 3.8057, D(x): 0.88, D(G(z)): 0.11
Epoch [186/200], Step [400/600], d_loss: 0.3793, g_loss: 4.0379, D(x): 0.87, D(G(z)): 0.11