- reference
    - https://github.com/dreamgonfly/GAN-tutorial/blob/master/GAN.ipynb

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import torchvision.utils as utils
from torchvision import datasets, transforms
import torch.nn.functional as F

from matplotlib import pyplot as plt
import random
import numpy as np

is_cuda = torch.cuda.is_available()
device = torch.device('cuda' if is_cuda else 'cpu')

# 0~1로 standardize -> G에서 tanh 써야함
transform = transforms.Compose([
    transforms.ToTensor(), # change data as tensor
    transforms.Normalize(mean=[0.5], std=[0.5])]) # 0~1의 pixel값(grey scale)을 -1~1로 바꾼다

# MNIST dataset
train_data = datasets.MNIST(root='data/', train=True, transform=transform, download=True)
test_data  = datasets.MNIST(root='data/', train=False, transform=transform, download=True)


batch_size = 100
dim_z = 100
train_data_loader = torch.utils.data.DataLoader(train_data, batch_size, shuffle=True)
test_data_loader  = torch.utils.data.DataLoader(test_data, batch_size, shuffle=True)

In [2]:
example_mini_batch_img, example_mini_batch_label  = next(iter(train_data_loader))

In [3]:
# input notse generator for Generator
def input_noise_generator(batch_size, dim):
    return torch.randn(batch_size, dim, device=device)

In [4]:
def visualize_generated_mnist(x, epoch):
    # sample one
    sample_index = random.randint(0,batch_size-1)
    
    # reshape
    x_reshape = x.data[sample_index].numpy().squeeze()
    min_x = x_reshape.min()
    max_x = x_reshape.max()
    
    # normalize
    x_normalize = (x_reshape - min_x)/(max_x - min_x)
    
    # save image
    plt.imsave(f'img/G_epoch_{epoch}.png', x_normalize)

In [5]:
# define generator
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(dim_z, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 28*28),
            nn.Tanh() #since normalized to -1~1
        )
    
    def forward(self, x):
        # to make visualization easy, change dimension
        return self.model(x).view(-1,1,28,28)
        

In [6]:
# define discriminator
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(28*28, 1024),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.5),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.5),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.5),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )
        
    def forward(self, x):
        # reshape dim (batch_size,1,28,28) ->( batch_size,28*28)
        x = x.view(-1, 28*28)
        return self.model(x)

In [7]:
# make instance and allocate device
G = Generator()
D = Discriminator()

if is_cuda:
    G.cuda()
    D.cuda()

In [8]:
# loss function
criterion = nn.BCELoss()

G_optimizer = optim.SGD(G.parameters(), lr = 1e-3, momentum=0.8)
D_optimizer = optim.SGD(G.parameters(), lr = 1e-3, momentum=0.8)

In [9]:
for epoch in range(10):
    D_loss = []
    G_loss = []
    
    for mini_batch_img, mini_batch_label in train_data_loader:
        D_loss_batch = []
        G_loss_batch = []
        
        # convert data as tensor
        mini_batch = Variable(mini_batch_img)
        
        # make label
        target_real = Variable(torch.ones(batch_size, 1))
        target_fake = Variable(torch.zeros(batch_size, 1))
        
        # push all tensor to cuda
        if is_cuda:
            mini_batch = mini_batch.cuda()
            target_real = target_real.cuda()
            target_fake = target_fake.cuda()
        
        # Generate fake image
        random_noise = input_noise_generator(batch_size, dim_z)
        fake_batch = G(random_noise) 
        
        # Discriminate real and fake images
        D_result_real = D(mini_batch)
        D_result_fake = D(fake_batch)
        
        # calculate loss for discriminator
        D_loss_real = criterion(D_result_real, target_real)
        D_loss_fake = criterion(D_result_fake, target_fake)
        D_loss_total = D_loss_real + D_loss_fake
        
        # backprop discriminator
        D.zero_grad()
        D_loss_total.backward()
        D_optimizer.step()
        
        # calculate loss for generator
        ## loss gets lower if discriminator was fooled
        random_noise = input_noise_generator(batch_size, dim_z)
        fake_batch = G(random_noise) 
        D_result_fake = D(fake_batch)
        G_loss_total = criterion(D_result_fake, target_real)
        
        # backprop generator
        G.zero_grad()
        G_loss_total.backward()
        G_optimizer.step()
        
        G_loss_batch.append(G_loss_total.data.item())
        D_loss_batch.append(D_loss_total.data.item())
        
    # print error
    print(f"----------epoch {epoch}----------")
    print(f"Generator loss: {np.mean(G_loss_batch)}")
    print(f"Discriminator loss: {np.mean(D_loss_batch)}")
    
    # sample out
    visualize_generated_mnist(fake_batch, epoch)

----------epoch 0----------
Generator loss: 0.7086819410324097
Discriminator loss: 1.3958896398544312
----------epoch 1----------
Generator loss: 0.7062671184539795
Discriminator loss: 1.3908379077911377
----------epoch 2----------
Generator loss: 0.7030912041664124
Discriminator loss: 1.3893771171569824
----------epoch 3----------
Generator loss: 0.7001062631607056
Discriminator loss: 1.3905487060546875
----------epoch 4----------
Generator loss: 0.6956173777580261
Discriminator loss: 1.4048317670822144
----------epoch 5----------
Generator loss: 0.6904615163803101
Discriminator loss: 1.4054063558578491
----------epoch 6----------
Generator loss: 0.6849323511123657
Discriminator loss: 1.410360336303711
----------epoch 7----------
Generator loss: 0.6752581596374512
Discriminator loss: 1.4230632781982422
----------epoch 8----------
Generator loss: 0.6600431799888611
Discriminator loss: 1.4275026321411133
----------epoch 9----------
Generator loss: 0.648216187953949
Discriminator loss: 1