**예제 1. Generator 구성하기**

**예제 2. Discriminator 구성하기**

**예제 3. 구성한 Generator와 Discriminator를 학습시켜 새로운 MNIST 이미지 생성**


*   필요한 Library Import




In [None]:
import os
import numpy as np
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim

from torchvision.utils import save_image
from torchvision import transforms, datasets
import matplotlib.pyplot as plt

*   학습에 사용될 Hyper Parameter 설정

In [None]:
# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
USE_CUDA = torch.cuda.is_available()
DEVICE = torch.device("cuda") if USE_CUDA else torch.device("cpu")
print("Using Device:", DEVICE)

# Training hyper-parameters shared by the cells below.
epochs = 200       # number of full passes over the training set
batch_size = 100   # images per mini-batch
lr = 0.0002        # Adam learning rate (used for both networks)

Using Device: cuda


*   학습에 필요한 MNIST 데이터셋 다운로드

In [None]:
# ToTensor() maps pixels to [0, 1]; Normalize(0.5, 0.5) then applies
# (x - 0.5) / 0.5, giving inputs in [-1, 1].
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(0.5, 0.5),
    ]
)

# MNIST training split, stored under ./.data (downloaded if missing).
mnist_dataset = datasets.MNIST(
    root='./.data',
    train=True,
    transform=transform,
    download=True,
)

# Shuffled mini-batches of `batch_size` images for training.
data_loader = torch.utils.data.DataLoader(
    mnist_dataset,
    batch_size=batch_size,
    shuffle=True,
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./.data/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting ./.data/MNIST/raw/train-images-idx3-ubyte.gz to ./.data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./.data/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting ./.data/MNIST/raw/train-labels-idx1-ubyte.gz to ./.data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./.data/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting ./.data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./.data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./.data/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting ./.data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./.data/MNIST/raw



*   Generator 구성

In [None]:
class Generator(nn.Module):
    """Conditional MLP generator: (noise, class label) -> flattened image.

    The label is looked up in an embedding table and concatenated with the
    noise vector before passing through the MLP. The final Tanh keeps every
    output value in [-1, 1].

    Args:
        latent_dim: Size of the noise vector ``z`` (default 100).
        num_classes: Number of distinct labels the embedding accepts (default 10).
        embed_dim: Size of the label embedding (default 10).
        img_dim: Length of the flattened output image (default 784 = 28 * 28).
    """

    def __init__(self, latent_dim=100, num_classes=10, embed_dim=10, img_dim=784):
        super().__init__()

        # Learned label embedding; concatenated with z in forward().
        self.embed = nn.Embedding(num_classes, embed_dim)

        self.model = nn.Sequential(
            # Input width is noise + label embedding (100 + 10 = 110 by default).
            nn.Linear(latent_dim + embed_dim, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, img_dim),
            nn.Tanh())

    def forward(self, z, label):
        """Generate images for the given noise and labels.

        Args:
            z: Noise tensor of shape (batch, latent_dim).
            label: Integer class indices of shape (batch,).

        Returns:
            Tensor of shape (batch, img_dim) with values in [-1, 1].
        """
        c = self.embed(label)
        x = torch.cat([z, c], dim=1)
        return self.model(x)

*   Discriminator 구성

In [None]:
class Discriminator(nn.Module):
    """Conditional MLP discriminator: (flattened image, class label) -> probability.

    The label is embedded and concatenated with the flattened image. The final
    Sigmoid yields a value in [0, 1], suitable as input to ``nn.BCELoss``.

    Args:
        num_classes: Number of distinct labels the embedding accepts (default 10).
        embed_dim: Size of the label embedding (default 10).
        img_dim: Length of the flattened input image (default 784 = 28 * 28).
        dropout: Dropout probability after each hidden layer (default 0.2).
    """

    def __init__(self, num_classes=10, embed_dim=10, img_dim=784, dropout=0.2):
        super().__init__()

        # Learned label embedding; concatenated with the image in forward().
        self.embed = nn.Embedding(num_classes, embed_dim)

        self.model = nn.Sequential(
            # Input width is image + label embedding (784 + 10 = 794 by default).
            nn.Linear(img_dim + embed_dim, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(dropout),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(dropout),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(dropout),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, x, label):
        """Score images conditioned on their labels.

        Args:
            x: Flattened images of shape (batch, img_dim).
            label: Integer class indices of shape (batch,).

        Returns:
            Tensor of shape (batch, 1) with values in [0, 1].
        """
        c = self.embed(label)
        x = torch.cat([x, c], dim=1)
        return self.model(x)

*   Loss Function과 Optimizer 정의

In [None]:
# Build both networks (Generator first, then Discriminator) on the chosen device.
# NOTE(review): this rebinds the class names to instances, shadowing the
# classes; later cells call these instances by the same names.
Generator, Discriminator = Generator().to(DEVICE), Discriminator().to(DEVICE)

# Binary cross-entropy, applied to the Discriminator's sigmoid output.
criterion = nn.BCELoss()

# A separate Adam optimizer per network, both with learning rate `lr`.
d_optimizer = optim.Adam(Discriminator.parameters(), lr=lr)
g_optimizer = optim.Adam(Generator.parameters(), lr=lr)

*   학습

In [None]:
# Conditional GAN training: for each mini-batch, one Discriminator update
# followed by one Generator update. Both networks also receive digit labels.
for epoch in range(epochs):
    for images, label in data_loader:
        # Use the real batch length so a short final batch does not break
        # reshape or the target tensors.
        cur_batch = images.size(0)
        images = images.reshape(cur_batch, -1).to(DEVICE)
        label = label.to(DEVICE)

        # BCE targets: 1 for real samples, 0 for generated samples.
        real_label = torch.ones(cur_batch, 1, device=DEVICE)
        fake_label = torch.zeros(cur_batch, 1, device=DEVICE)

        # ---- Discriminator step ----
        output = Discriminator(images, label)
        d_loss_real = criterion(output, real_label)
        real_score = output

        # Random noise and random target digits for the generator.
        z = torch.randn(cur_batch, 100, device=DEVICE)
        g_label = torch.randint(0, 10, (cur_batch,), device=DEVICE)
        fake_image = Generator(z, g_label)

        # detach(): the D loss must not backpropagate into the Generator.
        output = Discriminator(fake_image.detach(), g_label)
        d_loss_fake = criterion(output, fake_label)
        fake_score = output

        d_loss = d_loss_real + d_loss_fake
        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        # ---- Generator step ----
        # Re-score the same fakes with the freshly updated Discriminator.
        # The Generator is rewarded when they are classified as real.
        output = Discriminator(fake_image, g_label)
        g_loss = criterion(output, real_label)

        # Clear D grads too: g_loss.backward() also writes into D's parameters.
        d_optimizer.zero_grad()
        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()

    # Report the last mini-batch's losses and mean D outputs (1-based epoch).
    print('Epoch [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}' 
          .format(epoch + 1, epochs, d_loss.item(), g_loss.item(), 
                  real_score.mean().item(), fake_score.mean().item()))

Epoch [0/1], d_loss: 0.8489, g_loss: 1.7388, D(x): 0.74, D(G(z)): 0.25


*   생성하고 싶은 숫자를 정한 후 이미지를 시각화하여 학습 결과 확인

In [None]:
# Digit class (0-9) to ask the trained Generator for.
item_num = 3

# Inference only: no_grad() skips building an autograd graph, which replaces
# the discouraged `.data` access on the result.
with torch.no_grad():
    z = torch.randn(1, 100, device=DEVICE)
    g_label = torch.full((1,), item_num, dtype=torch.long, device=DEVICE)
    fake_image = Generator(z, g_label)

# Take the single sample and reshape the flat 784-vector into a 28x28 image.
fake_out = fake_image[0].cpu().numpy().reshape(28, 28)

# The Generator ends in Tanh, so pin the gray scale to [-1, 1] rather than
# letting imshow stretch each sample's own min/max.
plt.imshow(fake_out, cmap='gray', vmin=-1, vmax=1)
plt.show()