**예제 1. Generator 구성하기**

**예제 2. Discriminator 구성하기**

**예제 3. 구성한 Generator와 Discriminator를 학습해 새로운 MNIST 이미지 생성하기**


*   필요한 Library Import




In [None]:
import os
import numpy as np
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim

from torchvision.utils import save_image
from torchvision import transforms, datasets
import matplotlib.pyplot as plt

*   학습에 사용될 Hyper Parameter 설정

In [None]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
USE_CUDA = torch.cuda.is_available()
DEVICE = torch.device("cuda") if USE_CUDA else torch.device("cpu")
print("Using Device:", DEVICE)

# Training hyper-parameters.
epochs = 200      # full passes over the training set
batch_size = 100  # samples per mini-batch
lr = 2e-4         # learning rate handed to both Adam optimizers

Using Device: cuda


*   학습에 필요한 MNIST 데이터셋 다운로드

In [None]:
# Preprocessing for every MNIST sample:
#   ToTensor  -> float tensor (torchvision documents the range as [0, 1])
#   Normalize -> (x - 0.5) / 0.5, i.e. [0, 1] -> [-1, 1], matching the
#                Tanh-bounded range of the Generator's output.
# Normalize's mean/std are documented as per-channel sequences; MNIST has a
# single channel, so one-element tuples are used (older torchvision versions
# iterate over them and reject a bare float).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

In [None]:
# MNIST training split, stored under ./.data (download=True fetches it if it
# is missing) and passed through `transform` on each access.
mnist_dataset = datasets.MNIST(
    root='./.data',
    train=True,
    download=True,
    transform=transform,
)

# Shuffled mini-batches of `batch_size` (image, label) pairs.
data_loader = torch.utils.data.DataLoader(
    mnist_dataset, batch_size=batch_size, shuffle=True)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./.data/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting ./.data/MNIST/raw/train-images-idx3-ubyte.gz to ./.data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./.data/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting ./.data/MNIST/raw/train-labels-idx1-ubyte.gz to ./.data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./.data/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting ./.data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./.data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./.data/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting ./.data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./.data/MNIST/raw



*   Generator 구성

In [None]:
# Generator: 64-dim latent vector -> 784 values (a flattened 28x28 image).
# Two hidden Linear+ReLU layers of width 256; the final Tanh bounds every
# output value to [-1, 1].
Generator = nn.Sequential(*[
    module
    for n_in, n_out in ((64, 256), (256, 256))
    for module in (nn.Linear(n_in, n_out), nn.ReLU())
], nn.Linear(256, 784), nn.Tanh())

*   Discriminator 구성

In [None]:
# Discriminator: flattened 28x28 image (784 values) -> a single score.
# Two hidden Linear+LeakyReLU(0.2) layers of width 256; the final Sigmoid
# squashes the score into (0, 1) so it can be fed to BCELoss.
Discriminator = nn.Sequential(*[
    module
    for n_in, n_out in ((784, 256), (256, 256))
    for module in (nn.Linear(n_in, n_out), nn.LeakyReLU(0.2))
], nn.Linear(256, 1), nn.Sigmoid())

*   Loss Function과 Optimizer 정의

In [None]:
# Place both networks on the selected device. PyTorch recommends doing this
# before constructing the optimizers that will update their parameters.
Generator, Discriminator = Generator.to(DEVICE), Discriminator.to(DEVICE)

# Binary cross-entropy on the Discriminator's sigmoid outputs.
criterion = nn.BCELoss()

# A separate Adam optimizer per network, both using the shared `lr`.
g_optimizer, d_optimizer = (
    optim.Adam(Generator.parameters(), lr=lr),
    optim.Adam(Discriminator.parameters(), lr=lr),
)

*   학습

In [None]:
# Adversarial training. For every mini-batch:
#   1. update the Discriminator to score real images as 1 and generated as 0;
#   2. update the Generator so the (just-updated) Discriminator scores its
#      images as 1.
total_step = len(data_loader)
for epoch in range(epochs):
    for i, (images, _) in enumerate(data_loader):
        # Use the actual batch size: a smaller final batch would otherwise
        # break reshape(batch_size, -1) and mismatch the label tensors.
        n = images.size(0)
        images = images.reshape(n, -1).to(DEVICE)

        real_label = torch.ones(n, 1, device=DEVICE)   # target for real: 1
        fake_label = torch.zeros(n, 1, device=DEVICE)  # target for fake: 0

        # ---- Discriminator step ----
        output = Discriminator(images)
        d_loss_real = criterion(output, real_label)
        real_score = output.detach()

        z = torch.randn(n, 64, device=DEVICE)
        fake_image = Generator(z)

        # detach(): the Discriminator loss must not backpropagate into the
        # Generator (avoids wasted gradient computation through G).
        output = Discriminator(fake_image.detach())
        d_loss_fake = criterion(output, fake_label)
        fake_score = output.detach()

        d_loss = d_loss_real + d_loss_fake

        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        # ---- Generator step ----
        # G's parameters were not changed above, so fake_image (and its graph
        # through G) is still valid; re-score it with the updated
        # Discriminator and push it toward the "real" label.
        output = Discriminator(fake_image)
        g_loss = criterion(output, real_label)

        # Also clears any gradient left on G; gradients this backward puts on
        # D are cleared by d_optimizer.zero_grad() before D's next update.
        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()

    # Report 1-based epoch numbers so the last line reads [epochs/epochs].
    print('Epoch [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}' 
          .format(epoch + 1, epochs, d_loss.item(), g_loss.item(), 
                  real_score.mean().item(), fake_score.mean().item()))

Epoch [0/20], d_loss: 0.0807, g_loss: 4.3716, D(x): 0.97, D(G(z)): 0.04
Epoch [1/20], d_loss: 0.0865, g_loss: 6.2725, D(x): 0.97, D(G(z)): 0.02
Epoch [2/20], d_loss: 0.6494, g_loss: 3.1902, D(x): 0.80, D(G(z)): 0.18
Epoch [3/20], d_loss: 0.0856, g_loss: 5.3925, D(x): 0.95, D(G(z)): 0.02
Epoch [4/20], d_loss: 0.4355, g_loss: 3.5795, D(x): 0.94, D(G(z)): 0.20
Epoch [5/20], d_loss: 0.1406, g_loss: 4.5352, D(x): 0.96, D(G(z)): 0.05
Epoch [6/20], d_loss: 0.1922, g_loss: 4.9406, D(x): 0.97, D(G(z)): 0.11
Epoch [7/20], d_loss: 0.2253, g_loss: 4.4206, D(x): 0.92, D(G(z)): 0.07
Epoch [8/20], d_loss: 0.2314, g_loss: 6.1533, D(x): 0.88, D(G(z)): 0.01
Epoch [9/20], d_loss: 0.1322, g_loss: 5.3609, D(x): 0.94, D(G(z)): 0.01
Epoch [10/20], d_loss: 0.2361, g_loss: 5.3786, D(x): 0.94, D(G(z)): 0.07
Epoch [11/20], d_loss: 0.0916, g_loss: 8.6695, D(x): 0.98, D(G(z)): 0.03
Epoch [12/20], d_loss: 0.2377, g_loss: 6.2796, D(x): 0.90, D(G(z)): 0.01
Epoch [13/20], d_loss: 0.1816, g_loss: 5.1865, D(x): 0.94, D(



*   학습 결과 확인



In [None]:
# Generate a batch of images from fresh latent noise and display a few.
# This is inference only, so no autograd graph is recorded.
with torch.no_grad():
    z = torch.randn(batch_size, 64).to(DEVICE)
    fake_image = Generator(z)

# Copy to host memory once rather than converting the whole batch per image.
samples = fake_image.cpu().numpy()
for i in range(min(5, len(samples))):
    fake_out = np.reshape(samples[i], (28, 28))
    # The Generator ends in Tanh, so pin the grey scale to [-1, 1] instead of
    # letting imshow stretch each image's own min/max to black/white.
    plt.imshow(fake_out, cmap='gray', vmin=-1, vmax=1)
    plt.show()