In [12]:
import os

import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torchvision.datasets as dsets
import torchvision.transforms as transforms
from torch.autograd import Variable

# Image preprocessing.
# MNIST images are single-channel (grayscale), so Normalize takes a single
# mean/std value; a 3-element tuple fails on modern torchvision because the
# (3, 1, 1) statistics cannot be applied in place to a (1, 28, 28) tensor.
# ToTensor yields pixels in [0, 1]; Normalize(0.5, 0.5) maps them to [-1, 1],
# the same range as the generator's tanh output.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,))])

# MNIST training split (downloaded into `root` if not already present).
train_dataset = dsets.MNIST(root='../../dataset/mnist',
                            train=True,
                            transform=transform,
                            download=True)

# Data loader (input pipeline): shuffled mini-batches of 100 images.
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=100,
                                           shuffle=True)

In [13]:
# Discriminator Model
class Discriminator(nn.Module):
    """Two-hidden-layer MLP that scores flattened images as real or fake.

    Args:
        input_size: number of features per input row (784 = 28 * 28 by default).
        hidden_size: width of both hidden layers (256 by default).
    """

    def __init__(self, input_size=784, hidden_size=256):
        super(Discriminator, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, 1)

    def forward(self, x):
        """Return a probability in (0, 1) per row of ``x``.

        Args:
            x: tensor whose last dimension is ``input_size``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of 1, e.g. (batch, 1) for a (batch, input_size) input.
        """
        h = F.relu(self.fc1(x))
        h = F.relu(self.fc2(h))
        # torch.sigmoid replaces the deprecated F.sigmoid.
        return torch.sigmoid(self.fc3(h))

# Generator Model
class Generator(nn.Module):
    """Two-hidden-layer MLP that maps latent noise vectors to flattened images.

    Args:
        latent_size: width of the input noise vector (128 by default).
        hidden_size: width of both hidden layers (256 by default).
        image_size: number of output features (784 = 28 * 28 by default).
    """

    def __init__(self, latent_size=128, hidden_size=256, image_size=784):
        super(Generator, self).__init__()
        self.fc1 = nn.Linear(latent_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, image_size)

    def forward(self, x):
        """Generate one flattened image per row of noise ``x``.

        Args:
            x: tensor whose last dimension is ``latent_size``.

        Returns:
            Tensor with last dimension ``image_size`` and values in [-1, 1].
        """
        h = F.leaky_relu(self.fc1(x))
        h = F.leaky_relu(self.fc2(h))
        # tanh bounds outputs to [-1, 1]. Negative pixel values are intended:
        # the real images are normalized from [0, 1] to [-1, 1] by
        # Normalize(mean=0.5, std=0.5), so generated samples must use the same
        # range. Map back with (x + 1) / 2 before saving or viewing.
        # torch.tanh replaces the deprecated F.tanh.
        return torch.tanh(self.fc3(h))

# Build both networks (the discriminator is constructed first).
discriminator = Discriminator()
generator = Generator()

# Loss and optimizers: binary cross-entropy on the discriminator's probability
# output, and one Adam optimizer per network so each training phase only
# updates the weights it owns.
criterion = nn.BCELoss()
learning_rate = 0.0005
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=learning_rate)
g_optimizer = torch.optim.Adam(generator.parameters(), lr=learning_rate)

In [None]:
# Training
# Each iteration does one discriminator update followed by one generator update.
num_epochs = 200
latent_size = 128  # must match the Generator's input width
sample_dir = '../../outputs/gan_fake_image'
os.makedirs(sample_dir, exist_ok=True)  # save_image does not create directories
total_step = len(train_loader)


def denorm(x):
    """Map tensors from the [-1, 1] training range back to [0, 1] for saving."""
    return ((x + 1) / 2).clamp(0, 1)


for epoch in range(num_epochs):
    for i, (images, _) in enumerate(train_loader):
        # Build mini-batch: flatten each image into one row.
        batch_size = images.size(0)
        images = images.view(batch_size, -1)
        # Targets are (batch, 1) to match the discriminator's (batch, 1) output;
        # BCELoss raises an error when target and input sizes differ.
        real_labels = torch.ones(batch_size, 1)   # label for real images
        fake_labels = torch.zeros(batch_size, 1)  # label for fake images

        # ---- Train the discriminator ----
        discriminator.zero_grad()
        outputs = discriminator(images)
        real_loss = criterion(outputs, real_labels)
        real_score = outputs

        noise = torch.randn(batch_size, latent_size)
        fake_images = generator(noise)
        # detach() cuts the graph at the generator output so the discriminator
        # loss does not backpropagate into the generator during D's update.
        outputs = discriminator(fake_images.detach())
        fake_loss = criterion(outputs, fake_labels)
        fake_score = outputs

        d_loss = real_loss + fake_loss
        d_loss.backward()
        d_optimizer.step()

        # ---- Train the generator ----
        generator.zero_grad()
        noise = torch.randn(batch_size, latent_size)
        fake_images = generator(noise)
        outputs = discriminator(fake_images)
        # Non-saturating generator loss (heuristic): rather than minimizing
        # log(1 - D(G(z))), label fakes as real, i.e. maximize log D(G(z)).
        # This avoids vanishing gradients when D confidently rejects fakes.
        g_loss = criterion(outputs, real_labels)
        g_loss.backward()
        g_optimizer.step()

        if (i + 1) % 300 == 0:
            print('Epoch [%d/%d], Step[%d/%d], d_loss: %.4f, g_loss: %.4f, '
                  'D(x): %.2f, D(G(z)): %.2f'
                  % (epoch, num_epochs, i + 1, total_step,
                     d_loss.item(), g_loss.item(),
                     real_score.mean().item(), fake_score.mean().item()))

    # Save the epoch's last generated batch as an image grid, mapped back from
    # the [-1, 1] tanh range to [0, 1] (save_image otherwise clips negatives).
    fake_images = fake_images.view(fake_images.size(0), 1, 28, 28)
    torchvision.utils.save_image(
        denorm(fake_images.detach()),
        os.path.join(sample_dir, 'fake_samples_%d.png' % (epoch + 1)))

Epoch [0/200], Step[300/600], d_loss: 0.5339, g_loss: 3.1429, D(x): 0.91, D(G(z)): 0.27
Epoch [0/200], Step[600/600], d_loss: 0.8170, g_loss: 2.3807, D(x): 0.65, D(G(z)): 0.18
Epoch [1/200], Step[300/600], d_loss: 0.5749, g_loss: 2.0113, D(x): 0.84, D(G(z)): 0.30
Epoch [1/200], Step[600/600], d_loss: 1.7680, g_loss: 0.8988, D(x): 0.50, D(G(z)): 0.57
Epoch [2/200], Step[300/600], d_loss: 0.6656, g_loss: 1.4308, D(x): 0.81, D(G(z)): 0.34
Epoch [2/200], Step[600/600], d_loss: 2.1962, g_loss: 0.3265, D(x): 0.50, D(G(z)): 0.76
Epoch [3/200], Step[300/600], d_loss: 1.1974, g_loss: 0.9678, D(x): 0.55, D(G(z)): 0.41
Epoch [3/200], Step[600/600], d_loss: 0.5859, g_loss: 2.5625, D(x): 0.80, D(G(z)): 0.24
Epoch [4/200], Step[300/600], d_loss: 1.2288, g_loss: 0.7887, D(x): 0.59, D(G(z)): 0.48
Epoch [4/200], Step[600/600], d_loss: 1.1990, g_loss: 0.8362, D(x): 0.61, D(G(z)): 0.48
Epoch [5/200], Step[300/600], d_loss: 1.6334, g_loss: 0.9066, D(x): 0.50, D(G(z)): 0.58
Epoch [5/200], Step[600/600], d_