In [1]:
import os
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from torchvision.utils import save_image
import matplotlib.pyplot as plt

In [2]:
# Hyperparameters and run configuration.
EPOCHS = 500      # number of passes over the training set
BATCH_SIZE = 100  # samples per mini-batch
SEED = 42         # fixed seed so a fresh-kernel re-run starts from the same RNG state

# Seed PyTorch's global random number generator before any weights are
# initialised or any data is shuffled.
# NOTE(review): some GPU kernels may still be non-deterministic; the seed makes
# runs repeatable, not necessarily bit-identical — confirm if that matters.
torch.manual_seed(SEED)

# Use the GPU when CUDA reports one as available, otherwise the CPU.
USE_CUDA = torch.cuda.is_available()
DEVICE = torch.device("cuda" if USE_CUDA else "cpu")
print("다음 장치를 사용합니다 : ", DEVICE)

다음 장치를 사용합니다 :  cuda


In [3]:
# Training data: the Fashion-MNIST training split, stored under ./.data and
# downloaded if needed. Each image is converted to a tensor and then
# normalised with mean 0.5 and standard deviation 0.5.
trainset = datasets.FashionMNIST(
    root='./.data',
    train=True,
    download=True,
    transform=transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5, ), (0.5, ))]
    ),
)

# Shuffled mini-batches of BATCH_SIZE samples.
train_loader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)

In [4]:
# Generator: a fully connected network that turns a 64-dimensional input
# vector into 784 values, passing through two 256-unit ReLU hidden layers.
G = nn.Sequential(*[
    nn.Linear(64, 256), nn.ReLU(),
    nn.Linear(256, 256), nn.ReLU(),
    nn.Linear(256, 784), nn.Tanh(),  # Tanh bounds every output to (-1, 1)
])

In [5]:
# Discriminator: a fully connected network that maps a 784-value input to a
# single score, passing through two 256-unit LeakyReLU(0.2) hidden layers.
D = nn.Sequential(*[
    nn.Linear(784, 256), nn.LeakyReLU(0.2),
    nn.Linear(256, 256), nn.LeakyReLU(0.2),
    nn.Linear(256, 1), nn.Sigmoid(),  # Sigmoid squashes the score into (0, 1): real vs. fake
])

In [6]:
# Training setup: move both networks onto DEVICE, then build the loss
# function and one Adam optimizer per network.
G = G.to(DEVICE)
D = D.to(DEVICE)

# Binary cross-entropy between the discriminator's output and 0/1 targets.
criterion = nn.BCELoss()

# Separate optimizers so each network can be stepped independently.
g_optimizer = optim.Adam(params=G.parameters(), lr=0.0002)
d_optimizer = optim.Adam(params=D.parameters(), lr=0.0002)

In [None]:
# Adversarial training loop.
# For every mini-batch: (1) update D so that it scores real images as 1 and
# generated images as 0, then (2) update G so that the updated D scores its
# images as 1. One summary line is printed per epoch using the values from
# the last mini-batch of that epoch.
total_step = len(train_loader)
for epoch in range(EPOCHS):
    for i, (images, _) in enumerate(train_loader):
        # Size everything from the batch actually received rather than from
        # BATCH_SIZE, so a shorter final batch does not break the reshape or
        # mismatch the label tensors.
        batch_size = images.size(0)
        images = images.reshape(batch_size, -1).to(DEVICE)

        # Targets for the BCE loss, created directly on the training device.
        real_labels = torch.ones(batch_size, 1, device=DEVICE)
        fake_labels = torch.zeros(batch_size, 1, device=DEVICE)

        # --- Discriminator step ---
        # Discriminator output on real images.
        outputs = D(images)
        d_loss_real = criterion(outputs, real_labels)
        real_score = outputs.detach()  # kept for logging only; no graph needed

        z = torch.randn(batch_size, 64, device=DEVICE)
        fake_images = G(z)

        # Discriminator output on generated images. The fakes are detached so
        # this backward pass stops at D and does not compute gradients for G
        # that would be thrown away anyway.
        outputs = D(fake_images.detach())
        d_loss_fake = criterion(outputs, fake_labels)
        fake_score = outputs.detach()  # kept for logging only

        d_loss = d_loss_real + d_loss_fake

        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        # --- Generator step ---
        # G has not been updated since fake_images was produced, so the same
        # batch (with its graph through G still intact) is reused; only D's
        # forward pass is repeated because D's weights just changed.
        outputs = D(fake_images)
        g_loss = criterion(outputs, real_labels)

        # This backward pass also writes gradients into D's parameters; they
        # are cleared by d_optimizer.zero_grad() before D's next update.
        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()

    print('epoch [{}/{}] d_loss:{:.4f} g_loss:{:.4f} D(x):{:.2f} D(G(z)):{:.2f}'
    .format(epoch, EPOCHS, d_loss.item(), g_loss.item(), real_score.mean().item(), fake_score.mean().item()))

epoch [0/500] d_loss:0.0479 g_loss:5.1707 D(x):0.98 D(G(z)):0.02
epoch [1/500] d_loss:0.0322 g_loss:4.4349 D(x):0.99 D(G(z)):0.02
epoch [2/500] d_loss:0.0185 g_loss:7.3561 D(x):1.00 D(G(z)):0.01
epoch [3/500] d_loss:0.0155 g_loss:7.7072 D(x):0.99 D(G(z)):0.01
epoch [4/500] d_loss:0.0487 g_loss:5.3951 D(x):0.98 D(G(z)):0.02
epoch [5/500] d_loss:0.1227 g_loss:5.3373 D(x):0.97 D(G(z)):0.04
epoch [6/500] d_loss:0.2757 g_loss:3.0606 D(x):0.91 D(G(z)):0.06
epoch [7/500] d_loss:0.1307 g_loss:4.4753 D(x):0.94 D(G(z)):0.02
epoch [8/500] d_loss:0.1284 g_loss:6.1769 D(x):0.97 D(G(z)):0.02
epoch [9/500] d_loss:0.3016 g_loss:3.0754 D(x):0.95 D(G(z)):0.16
epoch [10/500] d_loss:0.2296 g_loss:3.9922 D(x):0.94 D(G(z)):0.10
epoch [11/500] d_loss:0.2658 g_loss:3.5925 D(x):0.94 D(G(z)):0.09
epoch [12/500] d_loss:0.3978 g_loss:4.3474 D(x):0.95 D(G(z)):0.21
epoch [13/500] d_loss:0.1894 g_loss:4.0140 D(x):0.94 D(G(z)):0.07
epoch [14/500] d_loss:0.3014 g_loss:3.2506 D(x):0.87 D(G(z)):0.09
epoch [15/500] d_los

epoch [124/500] d_loss:1.0807 g_loss:1.2777 D(x):0.67 D(G(z)):0.36
epoch [125/500] d_loss:0.9309 g_loss:1.6918 D(x):0.66 D(G(z)):0.26
epoch [126/500] d_loss:0.8587 g_loss:1.3489 D(x):0.75 D(G(z)):0.34
epoch [127/500] d_loss:0.9655 g_loss:1.7764 D(x):0.70 D(G(z)):0.30
epoch [128/500] d_loss:0.8872 g_loss:1.7294 D(x):0.78 D(G(z)):0.34
epoch [129/500] d_loss:0.9499 g_loss:1.7641 D(x):0.76 D(G(z)):0.36
epoch [130/500] d_loss:0.9633 g_loss:1.3170 D(x):0.71 D(G(z)):0.34
epoch [131/500] d_loss:1.0499 g_loss:1.2954 D(x):0.71 D(G(z)):0.40
epoch [132/500] d_loss:0.9741 g_loss:1.1222 D(x):0.78 D(G(z)):0.41
epoch [133/500] d_loss:0.7912 g_loss:1.5455 D(x):0.73 D(G(z)):0.29
epoch [134/500] d_loss:1.0574 g_loss:1.7179 D(x):0.65 D(G(z)):0.31
epoch [135/500] d_loss:0.8642 g_loss:1.6841 D(x):0.72 D(G(z)):0.29
epoch [136/500] d_loss:0.9691 g_loss:1.4694 D(x):0.73 D(G(z)):0.35
epoch [137/500] d_loss:0.8635 g_loss:1.5115 D(x):0.74 D(G(z)):0.33
epoch [138/500] d_loss:1.0026 g_loss:1.6559 D(x):0.63 D(G(z)):

epoch [247/500] d_loss:0.9713 g_loss:1.3822 D(x):0.64 D(G(z)):0.32
epoch [248/500] d_loss:1.1274 g_loss:1.3741 D(x):0.59 D(G(z)):0.32
epoch [249/500] d_loss:0.9079 g_loss:1.5106 D(x):0.69 D(G(z)):0.32
epoch [250/500] d_loss:1.4995 g_loss:1.2908 D(x):0.54 D(G(z)):0.41
epoch [251/500] d_loss:0.8176 g_loss:1.4299 D(x):0.74 D(G(z)):0.32
epoch [252/500] d_loss:1.3186 g_loss:1.2137 D(x):0.59 D(G(z)):0.40
epoch [253/500] d_loss:1.2776 g_loss:1.1679 D(x):0.58 D(G(z)):0.38
epoch [254/500] d_loss:0.9058 g_loss:1.4568 D(x):0.69 D(G(z)):0.32
epoch [255/500] d_loss:1.3642 g_loss:0.9437 D(x):0.59 D(G(z)):0.46
epoch [256/500] d_loss:1.3102 g_loss:1.1695 D(x):0.61 D(G(z)):0.42
epoch [257/500] d_loss:1.1795 g_loss:1.1952 D(x):0.70 D(G(z)):0.42
epoch [258/500] d_loss:1.1074 g_loss:1.1766 D(x):0.64 D(G(z)):0.38
epoch [259/500] d_loss:0.9971 g_loss:1.4626 D(x):0.66 D(G(z)):0.33
epoch [260/500] d_loss:0.9653 g_loss:1.3296 D(x):0.68 D(G(z)):0.34
epoch [261/500] d_loss:0.9032 g_loss:1.5731 D(x):0.65 D(G(z)):

In [None]:
# Visualization: sample a batch of latent vectors, run them through the
# generator and display the first ten results as 28x28 grayscale images.
z = torch.randn(BATCH_SIZE, 64).to(DEVICE)

# Inference only: skip autograd bookkeeping so the outputs can be converted
# to NumPy arrays directly.
with torch.no_grad():
    fake_images = G(z)

for i in range(10):
    # Each generated sample is a flat vector of 784 values; reshape it to
    # 28x28 and move it to the CPU for matplotlib.
    fake_images_img = fake_images[i].reshape(28, 28).cpu().numpy()
    plt.imshow(fake_images_img, cmap='gray')
    plt.show()