In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
import torchvision.datasets as vdatasets
import torchvision.transforms as transforms
from torchvision.utils import save_image
import matplotlib.pyplot as plt
import random
import numpy as np
import os
DATA_PATH = os.environ['DATA_PATH']
%matplotlib inline

Reference: tips and tricks for training GANs (ganhacks) — https://github.com/soumith/ganhacks

In [2]:
# Hyperparameters shared by the model definitions and the training loop.
LATENT_SIZE = 100       # dimensionality of the noise vector fed to the Generator
HIDDEN_SIZE = 256       # width of every hidden layer in both networks
INPUT_SIZE = 28 * 28    # flattened image size (784 values per image)
BATCH_SIZE = 100        # samples per mini-batch drawn from the data loader
EPOCH = 200             # number of full passes over the training set

In [3]:
def denorm(x):
    """Rescale a tensor from the [-1, 1] range to [0, 1].

    Computes ``(x + 1) / 2`` and clamps the result to ``[0, 1]`` so that
    inputs outside ``[-1, 1]`` still produce values in the unit interval.
    """
    return torch.clamp((x + 1) / 2, 0, 1)

In [4]:
# Preprocessing: ToTensor yields values in [0, 1]; Normalize with mean 0.5 and
# std 0.5 then maps them to [-1, 1], the same range as the Generator's Tanh
# output. MNIST images have a single channel, so exactly one mean/std value is
# given (a 3-value tuple does not match the 1-channel tensor and is rejected
# by current torchvision releases).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

# MNIST training split (images and labels), downloaded on first use.
# os.path.join keeps the path valid whether or not DATA_PATH ends with a
# path separator.
train_dataset = vdatasets.MNIST(root=os.path.join(DATA_PATH, 'MNIST'),
                                train=True,
                                transform=transform,
                                download=True)

# Dataset loader (input pipeline): shuffled mini-batches of BATCH_SIZE samples.
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=BATCH_SIZE,
                                           shuffle=True)

In [7]:
# Generator: maps a LATENT_SIZE noise vector to an INPUT_SIZE (flattened) image.
# The final Tanh bounds every output value to [-1, 1].
Generator = nn.Sequential(
    nn.Linear(LATENT_SIZE, HIDDEN_SIZE),
    nn.LeakyReLU(negative_slope=0.2),
    nn.Linear(HIDDEN_SIZE, HIDDEN_SIZE),
    nn.LeakyReLU(negative_slope=0.2),
    nn.Linear(HIDDEN_SIZE, INPUT_SIZE),
    nn.Tanh(),
)

# Discriminator: maps an INPUT_SIZE (flattened) image to a single value.
# The final Sigmoid squashes that value into (0, 1) so it can be fed to a
# binary cross-entropy loss.
Discriminator = nn.Sequential(
    nn.Linear(INPUT_SIZE, HIDDEN_SIZE),
    nn.LeakyReLU(negative_slope=0.2),
    nn.Linear(HIDDEN_SIZE, HIDDEN_SIZE),
    nn.LeakyReLU(negative_slope=0.2),
    nn.Linear(HIDDEN_SIZE, 1),
    nn.Sigmoid(),
)

In [8]:
# Optimisation setup: both networks use Adam with the same learning rate and
# are scored with binary cross-entropy on the Discriminator's output.
LR = 0.0002

loss_function = nn.BCELoss()

# One optimizer per network so each can be stepped independently.
d_optimizer = optim.Adam(params=Discriminator.parameters(), lr=LR)
g_optimizer = optim.Adam(params=Generator.parameters(), lr=LR)

In [9]:
# Make sure the output directory for sample grids exists up front, so that
# save_image does not fail after several epochs of training.
os.makedirs('./images', exist_ok=True)

for epoch in range(EPOCH):
    for i, (inputs, _) in enumerate(train_loader):
        batch_size = inputs.size(0)

        # Flatten each image to a vector of INPUT_SIZE values.
        real_img = inputs.view(batch_size, INPUT_SIZE)
        # BCE targets: 1 for real images, 0 for generated ones.
        real_label = torch.ones(batch_size)
        fake_label = torch.zeros(batch_size)

        # ---- Train the Discriminator ----
        # It should output 1 for real images and 0 for generated images.
        real_preds = Discriminator(real_img)

        latent = torch.randn(batch_size, LATENT_SIZE)
        fake_img = Generator(latent)
        # detach(): this step only updates the Discriminator, so there is no
        # need to backpropagate through the Generator.
        fake_preds = Discriminator(fake_img.detach())

        d_loss_1 = loss_function(real_preds.squeeze(1), real_label)
        d_loss_2 = loss_function(fake_preds.squeeze(1), fake_label)
        d_loss = d_loss_1 + d_loss_2

        Discriminator.zero_grad()
        Generator.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        # ---- Train the Generator ----
        # Fresh noise; the Generator is rewarded when the Discriminator labels
        # its output as real, hence the real_label target.
        latent = torch.randn(batch_size, LATENT_SIZE)
        fake_img = Generator(latent)
        preds = Discriminator(fake_img)

        g_loss = loss_function(preds.squeeze(1), real_label)

        # Clear both networks: the backward pass below also deposits gradients
        # in the Discriminator, but only the Generator is stepped.
        Discriminator.zero_grad()
        Generator.zero_grad()
        g_loss.backward()
        g_optimizer.step()

        if (i+1) % 300 == 0:
            # .item() extracts the Python float from a 0-dim loss tensor
            # (indexing with .data[0] fails on 0-dim tensors).
            print ('Epoch: [%d/%d], Step: [%d/%d], D Loss: %.4f, G Loss: %.4f'
                   % (epoch+1, EPOCH, i+1, len(train_loader), d_loss.item(), g_loss.item()))

    # Save a sample grid of generated images (the last batch produced in this
    # epoch) every 10 epochs.
    if (epoch+1) % 10 == 0:
        fake_img = fake_img.view(fake_img.size(0), 1, 28, 28)
        save_image(denorm(fake_img.data), './images/gan_gen_images-%d.png' %(epoch+1))

Epoch: [1/200], Step: [300/600], D Loss: 1.1531, G Loss: 1.5152
Epoch: [1/200], Step: [600/600], D Loss: 0.2564, G Loss: 2.8510
Epoch: [2/200], Step: [300/600], D Loss: 0.0152, G Loss: 5.4922
Epoch: [2/200], Step: [600/600], D Loss: 1.3797, G Loss: 2.3368
Epoch: [3/200], Step: [300/600], D Loss: 0.4456, G Loss: 2.7325
Epoch: [3/200], Step: [600/600], D Loss: 0.3635, G Loss: 2.9107
Epoch: [4/200], Step: [300/600], D Loss: 0.5053, G Loss: 2.6856
Epoch: [4/200], Step: [600/600], D Loss: 1.1031, G Loss: 1.5296
Epoch: [5/200], Step: [300/600], D Loss: 0.4972, G Loss: 2.4258
Epoch: [5/200], Step: [600/600], D Loss: 0.5457, G Loss: 3.1082
Epoch: [6/200], Step: [300/600], D Loss: 0.3753, G Loss: 3.6293
Epoch: [6/200], Step: [600/600], D Loss: 0.2915, G Loss: 3.2657
Epoch: [7/200], Step: [300/600], D Loss: 0.3205, G Loss: 4.4791
Epoch: [7/200], Step: [600/600], D Loss: 0.5279, G Loss: 2.2808
Epoch: [8/200], Step: [300/600], D Loss: 0.3635, G Loss: 4.2301
Epoch: [8/200], Step: [600/600], D Loss:

In [None]:
# Generate one sample from a fresh noise vector. torch.no_grad() disables
# gradient tracking for inference (the `volatile=True` flag no longer has any
# effect in PyTorch >= 0.4).
with torch.no_grad():
    latent = torch.randn(1, LATENT_SIZE)
    fake_img = Generator(latent)

# Map the Tanh output from [-1, 1] back to [0, 1] and show it as a 28x28 image.
plt.matshow(np.reshape(denorm(fake_img).numpy(), (28, 28)), cmap=plt.get_cmap('gray'))
# `epoch` is the zero-based loop variable left over from training; show it
# one-based to match the training log and the saved image file names.
plt.title("[" + str(epoch + 1) + "] Generated Image\n")
plt.colorbar()
plt.show()