# 0. GAN
- Reference implementation: https://github.com/eriklindernoren/PyTorch-GAN/blob/master/implementations/gan/gan.py
- Paper: https://arxiv.org/abs/1406.2661 (Goodfellow et al., "Generative Adversarial Networks", 2014)

# 1. Library Import

In [27]:
# argparse does not work well inside a Jupyter notebook, so easydict holds the options instead.
import easydict

import os
import numpy as np
import math

import torchvision.transforms as transforms
from torchvision.utils import save_image

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

import torch.nn as nn
import torch.nn.functional as F
import torch

In [11]:
# Create the `images` output folder. exist_ok=True makes re-running this cell a no-op
# instead of raising FileExistsError.
os.makedirs(name="images", mode=0o777, exist_ok=True)

# 2. Parameter Setting

In [28]:
# Notebook-wide options. EasyDict allows attribute access (e.g. `opt.batch_size`),
# standing in for an argparse namespace.
opt = easydict.EasyDict(
    dict(
        n_epochs=200,         # training epochs (not used in the visible cells)
        batch_size=64,        # mini-batch size passed to the DataLoader
        lr=0.0002,            # learning rate (not used in the visible cells)
        b1=0.5,               # presumably Adam beta1 -- confirm where the optimizer is built
        # NOTE(review): presumably Adam beta2. PyTorch's Adam default is 0.999; confirm 0.99 is intended.
        b2=0.99,
        n_cpu=8,              # CPU worker count (not used in the visible cells)
        latent_dim=100,       # length of the noise vector fed to the generator
        img_size=28,          # images are resized to this height/width
        channels=1,           # number of image channels
        sample_interval=400,  # not used in the visible cells
    )
)

In [29]:
# Echo the configuration so the notebook output records the exact options used.
print(f"{opt}")

{'n_epochs': 200, 'batch_size': 64, 'lr': 0.0002, 'b1': 0.5, 'b2': 0.99, 'n_cpu': 8, 'latent_dim': 100, 'img_size': 28, 'channels': 1, 'sample_interval': 400}


In [30]:
# Shape of one image as (channels, height, width); height and width are both opt.img_size.
img_shape = (opt.channels,) + (opt.img_size,) * 2

In [32]:
# Show the per-sample image shape used by both networks.
print(f"{img_shape}")

(1, 28, 28)


In [33]:
# Use the GPU when PyTorch can see one. torch.cuda.is_available() already returns a
# bool, so no conditional expression is needed.
cuda = torch.cuda.is_available()

In [34]:
# Report whether the models will be moved to the GPU.
print(f"{cuda}")

False


# 3. Model

### 3.1 Generator

In [35]:
class Generator(nn.Module):
    """MLP generator mapping a batch of latent vectors to a batch of images.

    Args:
        latent_dim: Length of each input noise vector. Defaults to ``opt.latent_dim``.
        output_shape: Shape of one generated sample, e.g. ``(C, H, W)``.
            Defaults to the notebook-level ``img_shape``.

    The final ``Tanh`` bounds every output value to [-1, 1]. That is the same
    range the data loader's ``Normalize([0.5], [0.5])`` maps real images into.
    """

    def __init__(self, latent_dim=None, output_shape=None):
        super().__init__()
        # Fall back to the notebook globals so a bare `Generator()` behaves as before.
        if latent_dim is None:
            latent_dim = opt.latent_dim
        if output_shape is None:
            output_shape = img_shape
        # Keep the shape on the instance so forward() does not depend on a global
        # that could change after construction.
        self.img_shape = tuple(output_shape)

        def block(in_feat, out_feat, normalize=True):
            """Return the layers of one hidden stage: Linear, optional BatchNorm1d, LeakyReLU."""
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                # BatchNorm1d's second positional parameter is `eps`, not `momentum`,
                # so 0.8 is passed by keyword to make that visible.
                # TODO(review): 0.8 is unusually large for eps; confirm momentum was not intended.
                layers.append(nn.BatchNorm1d(out_feat, eps=0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(self.img_shape))),
            nn.Tanh(),
        )

    def forward(self, z):
        """Map ``z`` of shape (N, latent_dim) to images of shape (N, *self.img_shape)."""
        img = self.model(z)
        return img.view(img.size(0), *self.img_shape)

### 3.2 Discriminator

In [36]:
class Discriminator(nn.Module):
    """MLP discriminator that outputs one probability per input image.

    Args:
        input_shape: Shape of one input sample, e.g. ``(C, H, W)``. Defaults to
            the notebook-level ``img_shape``.

    ``forward`` returns a tensor of shape (N, 1). The values come from a ``Sigmoid``
    and so lie in [0, 1], which is the input range ``nn.BCELoss`` expects.
    """

    def __init__(self, input_shape=None):
        super().__init__()
        # Fall back to the notebook global so a bare `Discriminator()` behaves as before.
        if input_shape is None:
            input_shape = img_shape
        in_features = int(np.prod(input_shape))

        self.model = nn.Sequential(
            nn.Linear(in_features, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, img):
        """Flatten each sample in ``img`` and return its probability, shape (N, 1)."""
        # torch.flatten copies when needed, so non-contiguous inputs (e.g. transposed
        # batches) work. Tensor.view would raise a RuntimeError on them.
        img_flat = torch.flatten(img, start_dim=1)
        return self.model(img_flat)

### 3.3 Loss Function

In [37]:
# Binary cross-entropy on probabilities in [0, 1], averaged over the batch
# ("mean" is BCELoss's default reduction).
adversarial_loss = nn.BCELoss(reduction="mean")

# 4. Model Training

In [38]:
# Build one generator and one discriminator with their default (notebook-level) settings.
generator, discriminator = Generator(), Discriminator()

In [39]:
# When a GPU was detected, move both networks and the loss module onto it.
if cuda:
    for module in (generator, discriminator, adversarial_loss):
        module.cuda()

In [41]:
# Data loader over the MNIST training split. Images are resized to opt.img_size,
# converted to tensors in [0, 1] by ToTensor, then mapped to [-1, 1] by
# Normalize(mean=0.5, std=0.5). That matches the generator's Tanh output range.
os.makedirs("../data/mnist", exist_ok=True)
dataloader = DataLoader(
    datasets.MNIST(
        "../data/mnist",
        train=True,
        download=True,
        transform=transforms.Compose(
            [
                transforms.Resize(opt.img_size),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        ),
    ),
    batch_size=opt.batch_size,
    # Reshuffle each epoch. The keyword must be `shuffle`; DataLoader raises
    # TypeError on unknown keywords.
    shuffle=True,
)