In [1]:
import os
import torch
import torchvision
import torch.nn as nn
from torchvision import transforms
from torchvision.utils import save_image
from torch.autograd import Variable


In [2]:
# Hyper-parameters
# -- network dimensions
image_size = 28 * 28   # 784: a 28x28 image flattened into one vector
hidden_size = 256      # width of the hidden linear layers in both networks
latent_size = 64       # length of the noise vector fed to the generator
# -- optimisation schedule
batch_size = 100
num_epochs = 200
# -- output location for the image grids written during training
sample_dir = 'samples/mymnist_normalize'

In [3]:

# Create the sample output directory if it does not exist yet.
if not os.path.exists(sample_dir):
    os.makedirs(sample_dir, exist_ok=True)
    print("created folder")

# Image processing: convert to a tensor in [0, 1], then shift/scale to [-1, 1]
# so that real images share the value range of the generator's Tanh output.
# MNIST images have a single (grayscale) channel, so exactly one mean/std
# value is supplied; three-channel statistics do not broadcast onto a
# 1-channel tensor and make Normalize fail on current torchvision releases.
transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean=(0.5,),
                                     std=(0.5,))])

# MNIST dataset (images and labels)
train_dataset = torchvision.datasets.MNIST(root='../../data',
                                           train=True,
                                           transform=transform,
                                           download=True)

test_dataset = torchvision.datasets.MNIST(root='../../data',
                                          train=False,
                                          transform=transform)

# Training data loader (shuffled every epoch)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)
# Test data loader (fixed order)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=batch_size,
                                          shuffle=False)

In [4]:
class discriminator(nn.Module):
    """Fully connected real/fake classifier.

    Three linear layers with LeakyReLU(0.2) between them, finished by a
    Sigmoid. The input's last dimension must be ``image_size``; the output
    has one sigmoid-activated value per input row.
    """

    def __init__(self):
        super(discriminator, self).__init__()
        # The attribute name `layer1` and the order of the modules define the
        # state_dict keys, so both are kept exactly as they were.
        stack = [
            nn.Linear(image_size, hidden_size),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_size, hidden_size),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_size, 1),
            nn.Sigmoid(),
        ]
        self.layer1 = nn.Sequential(*stack)

    def forward(self, x):
        """Run ``x`` through the layer stack and return the sigmoid score."""
        return self.layer1(x)
class generator(nn.Module):
    """Fully connected network turning latent noise into a flat image.

    Three linear layers with ReLU between them, finished by a Tanh, so every
    output value lies in [-1, 1]. The input's last dimension must be
    ``latent_size``; the output's last dimension is ``image_size``.
    """

    def __init__(self):
        super(generator, self).__init__()
        # The attribute name `layer1` and the order of the modules define the
        # state_dict keys, so both are kept exactly as they were.
        stack = [
            nn.Linear(latent_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, image_size),
            nn.Tanh(),
        ]
        self.layer1 = nn.Sequential(*stack)

    def forward(self, x):
        """Run the latent batch ``x`` through the layer stack."""
        return self.layer1(x)

In [5]:
# Build the discriminator first, then the generator, and move each network's
# parameters to the GPU right after it is constructed.
D = discriminator()
D = D.cuda()
G = generator()
G = G.cuda()

In [6]:

# Binary cross-entropy between the discriminator's sigmoid output and the
# 0/1 target labels.
loss_function = nn.BCELoss()

# One Adam optimizer per network, both with a learning rate of 2e-4.
d_optimizer = torch.optim.Adam(params=D.parameters(), lr=0.0002)
g_optimizer = torch.optim.Adam(params=G.parameters(), lr=0.0002)
def denorm(x):
    """Map values from the [-1, 1] range back to [0, 1].

    Computes ``(x + 1) / 2`` and clamps the result to [0, 1], so inputs
    outside [-1, 1] saturate at the nearest bound.
    """
    return ((x + 1) / 2).clamp(min=0, max=1)
def reset_grad():
    """Zero the gradients held by the discriminator and generator optimizers."""
    for optimizer in (d_optimizer, g_optimizer):
        optimizer.zero_grad()

# ---------------------------------------------------------------------------
# Adversarial training loop.
#
# Each iteration performs one discriminator update (real batch + generated
# batch) followed by one generator update. After every epoch a grid of
# generated samples is written to `sample_dir`; model weights are
# checkpointed periodically and after the final epoch.
# ---------------------------------------------------------------------------
total_step = len(train_loader)

# Create every tensor on whatever device the discriminator lives on instead
# of hard-coding .cuda() on each one.
device = next(D.parameters()).device

# torch.save does not create missing parent folders, so make sure the
# checkpoint directory exists before the first checkpoint is written.
checkpoint_dir = './saved_data/mymnist_normalize'
os.makedirs(checkpoint_dir, exist_ok=True)

for epochs in range(num_epochs):
    for i, (images, _) in enumerate(train_loader):
        # Size everything from the batch actually received, so a final batch
        # smaller than `batch_size` does not break the reshape or the labels.
        current_batch = images.size(0)
        images = images.reshape(current_batch, -1).to(device)

        # BCE targets: 1 for real images, 0 for generated ones.
        real_labels = torch.ones(current_batch, 1, device=device)
        fake_labels = torch.zeros(current_batch, 1, device=device)

        ## discriminator ##
        # Modules are called directly (not via .forward) so hooks still run.
        outputs = D(images)
        d_loss_real = loss_function(outputs, real_labels)
        real_score = outputs

        z = torch.randn(current_batch, latent_size, device=device)
        # detach(): the discriminator update needs no gradient w.r.t. G, so
        # skip the backward pass through the generator.
        fake_images = G(z).detach()
        outputs = D(fake_images)
        d_loss_fake = loss_function(outputs, fake_labels)
        fake_score = outputs

        d_loss = d_loss_real + d_loss_fake
        reset_grad()
        d_loss.backward()
        d_optimizer.step()

        ## generator ##
        z = torch.randn(current_batch, latent_size, device=device)
        fake_images = G(z)
        outputs = D(fake_images)

        # We train G to maximize log(D(G(z)) instead of minimizing log(1-D(G(z)))
        # For the reason, see the last paragraph of section 3. https://arxiv.org/pdf/1406.2661.pdf
        g_loss = loss_function(outputs, real_labels)

        # Backprop and optimize
        reset_grad()
        g_loss.backward()
        g_optimizer.step()

        if (i+1) % 200 == 0:
            # Epoch is reported 1-based, matching the sample/checkpoint file names.
            print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}'
                  .format(epochs+1, num_epochs, i+1, total_step, d_loss.item(), g_loss.item(),
                          real_score.mean().item(), fake_score.mean().item()))

    # Save real images (once, from the last batch of the first epoch)
    if (epochs+1) == 1:
        images = images.reshape(images.size(0), 1, 28, 28)
        save_image(denorm(images), os.path.join(sample_dir, 'real_images.png'))

    # Save sampled images (the last generated batch of this epoch)
    fake_images = fake_images.reshape(fake_images.size(0), 1, 28, 28)
    save_image(denorm(fake_images), os.path.join(sample_dir, 'fake_images-{}.png'.format(epochs+1)))

    # Save the model checkpoints on the original schedule (epochs 1, 51, 101, ...)
    # and additionally after the last epoch, so the fully trained weights are
    # not lost.
    if epochs % 50 == 0 or epochs + 1 == num_epochs:
        torch.save(G.state_dict(), os.path.join(checkpoint_dir, 'G_mnist-{}.ckpt'.format(epochs+1)))
        torch.save(D.state_dict(), os.path.join(checkpoint_dir, 'D_mnist-{}.ckpt'.format(epochs+1)))
print("training finished!")

KeyboardInterrupt: 