In [14]:
from IPython import display
import os 
import numpy as np

import easydict
import torch
from torch import nn, optim
from torch.autograd.variable import Variable
from torchvision import transforms, datasets
from torchvision.utils import save_image

import torchvision.utils as vutils

In [2]:
# Directory that receives the sample images written by the training loop.
os.makedirs("images", exist_ok=True)

# Hyper-parameters, wrapped so they can be read as attributes (opt.lr, opt.latent_dim, ...).
opt = easydict.EasyDict(
    dict(
        # training schedule
        n_epochs=50,
        batch_size=64,
        # Adam learning rate and (beta1, beta2)
        lr=2e-4,
        b1=0.5,
        b2=0.999,
        # NOTE(review): n_cpu is not referenced by any cell visible here
        n_cpu=8,
        # size of the noise vector and geometry of one image
        latent_dim=100,
        img_size=28,
        channels=1,
        # a sample is written when the global batch counter is a multiple of this
        sample_interval=400,
    )
)


In [3]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [9]:
# (channels, height, width) of one image; images are square, so img_size is used twice.
img_shape = (opt.channels,) + (opt.img_size,) * 2
img_shape

(1, 28, 28)

### Generator & Discriminator Class

![GAN architecture](jupyer_images_file/GAN.png)

In [5]:

class Generator(nn.Module):
    """Fully connected generator that maps noise vectors to images.

    The network is a stack of Linear -> (BatchNorm1d) -> LeakyReLU(0.2) blocks
    widening 128 -> 256 -> 512 -> 1024, followed by a Linear layer with one
    output per pixel and a Tanh, so every output value lies in [-1, 1].

    Args:
        latent_dim: Length of the input noise vector. Defaults to
            ``opt.latent_dim`` from the notebook configuration.
        out_shape: Shape of one generated image, without the batch axis.
            Defaults to the notebook-level ``img_shape``.
    """

    def __init__(self, latent_dim=None, out_shape=None):
        super().__init__()

        # Fall back to the notebook-level configuration when not given explicitly.
        self.latent_dim = opt.latent_dim if latent_dim is None else latent_dim
        self.img_shape = tuple(img_shape if out_shape is None else out_shape)

        def block(in_feat, out_feat, normalize=True):
            """Return the layers of one Linear -> (BatchNorm1d) -> LeakyReLU block."""
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                # BatchNorm1d's second positional parameter is `eps`, so the value
                # must be passed by keyword to reach `momentum`; `eps` then keeps
                # its small default instead of damping the normalisation.
                # NOTE(review): PyTorch's momentum is the weight given to the
                # *new* batch statistic; if 0.8 was ported from a framework with
                # the opposite convention, 0.2 may be the intended value - confirm.
                layers.append(nn.BatchNorm1d(out_feat, momentum=0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(self.latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(self.img_shape))),
            nn.Tanh(),
        )

    def forward(self, z):
        """Generate one image per row of ``z``.

        Args:
            z: Noise tensor of shape (batch, latent_dim).

        Returns:
            Tensor of shape (batch, *img_shape) with values in [-1, 1].
        """
        flat = self.model(z)
        # Restore the image layout from the flat per-pixel output.
        return flat.view(flat.size(0), *self.img_shape)


class Discriminator(nn.Module):
    """Fully connected binary classifier over flattened images.

    Two hidden Linear layers (512 and 256 units) with LeakyReLU(0.2) feed a
    single-unit Linear layer followed by a Sigmoid, so the output for each
    image is one value in [0, 1].
    """

    def __init__(self):
        super().__init__()

        # The first layer takes one input per pixel of an image.
        fan_in = int(np.prod(img_shape))

        stack = []
        for width in (512, 256):
            stack += [nn.Linear(fan_in, width), nn.LeakyReLU(0.2, inplace=True)]
            fan_in = width
        stack += [nn.Linear(fan_in, 1), nn.Sigmoid()]

        self.model = nn.Sequential(*stack)

    def forward(self, img):
        """Score a batch of images.

        Args:
            img: Image batch; every axis after the first is flattened.

        Returns:
            Tensor of shape (batch, 1) holding the sigmoid outputs.
        """
        return self.model(img.flatten(start_dim=1))


### Loss function
$$
\begin{aligned}
\min_{G}\max_{D} V(D, G) &= \mathbb{E}_{x \sim p_{data}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_{z}(z)}[\log(1 - D(G(z)))] \\
&= \mathbb{E}_{x \sim p_{data}(x)}[\log D(x)] + \mathbb{E}_{x \sim p_{g}(x)}[\log(1 - D(x))] \\
&= \int_{x} p_{data}(x)\log D(x) + p_{g}(x)\log(1 - D(x)) \, dx
\end{aligned}
$$

where $D(x) \in [0, 1]$.

In [6]:
# Build both networks and move their parameters to the selected device.
generator = Generator().to(device)
discriminator = Discriminator().to(device)

# Binary cross-entropy between the discriminator's output and the 0/1 targets.
# Created without a weight, the loss module holds no tensors, so it is not moved to the device.
adversarial_loss = nn.BCELoss()

### Optimization function

- GAN 논문에서는 SGD를 사용함.
- [Adam](https://arxiv.org/abs/1412.6980) 링크 참고.
- [Adam 설명](https://mjgim.me/2018/01/22/adam.html)

In [7]:
# One Adam optimizer per network, both configured from the same lr / beta settings.
optimizer_G, optimizer_D = (
    optim.Adam(net.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
    for net in (generator, discriminator)
)

### Setting Fixed noise & Dataloader

In [11]:
# Configure the data loader: MNIST training split, downloaded on first use.
data_root = "../../data/mnist"
os.makedirs(data_root, exist_ok=True)
dataloader = torch.utils.data.DataLoader(
    datasets.MNIST(
        data_root,
        train=True,
        download=True,
        transform=transforms.Compose(
            [
                transforms.Resize(opt.img_size),
                transforms.ToTensor(),
                # (x - 0.5) / 0.5 maps ToTensor's [0, 1] pixels to [-1, 1],
                # the same range as the generator's Tanh output.
                transforms.Normalize([0.5], [0.5]),
            ]
        ),
    ),
    batch_size=opt.batch_size,
    shuffle=True,
)

# A fixed batch of noise, reused at every sampling step so successive samples are
# comparable. Its width follows opt.latent_dim so it always matches the generator input.
fixed_noise = torch.randn(opt.batch_size, opt.latent_dim, device=device)

### Run model

In [19]:
# Alternating GAN training: one discriminator step, then one generator step, per batch.
for epoch in range(opt.n_epochs):
    for i, data in enumerate(dataloader):
        input_data = data[0].to(device)
        n_samples = input_data.size(0)

        # BCELoss targets shaped (batch, 1) like the discriminator output:
        # 1 marks real images, 0 marks generated ones.
        real_label = torch.full((n_samples, 1), 1, dtype=input_data.dtype, device=device)
        fake_label = torch.full((n_samples, 1), 0, dtype=input_data.dtype, device=device)

        #################################
        # Discriminator update
        #################################
        discriminator.zero_grad()

        # Real batch: push D(x) towards 1.
        output = discriminator(input_data)
        D_error_real = adversarial_loss(output, real_label)
        D_error_real.backward()
        D_x = output.mean().item()

        # Generated batch: push D(G(z)) towards 0. `fake` is detached so this
        # backward pass does not reach the generator's parameters.
        noise_z = torch.randn(n_samples, opt.latent_dim, device=device)
        fake = generator(noise_z)
        output = discriminator(fake.detach())
        D_error_fake = adversarial_loss(output, fake_label)
        D_error_fake.backward()
        D_G_z1 = output.mean().item()

        # Gradients of both halves are already accumulated; the sum is only reported.
        D_error = D_error_real + D_error_fake
        optimizer_D.step()

        #################################
        # Generator update
        #################################
        generator.zero_grad()
        # Score the same fakes with the updated discriminator and label them as
        # real, so the generator is trained to make D(G(z)) approach 1.
        output = discriminator(fake)
        G_error = adversarial_loss(output, real_label)
        G_error.backward()
        D_G_z2 = output.mean().item()
        optimizer_G.step()

        #################################
        # Logging and sample images
        #################################
        # NOTE(review): this prints every batch of every 10th epoch, and all samples
        # saved within one epoch share a file name (the last one wins) - confirm intent.
        if epoch % 10 == 0:
            print(
                "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
                % (epoch, opt.n_epochs, i, len(dataloader), D_error.item(), G_error.item())
            )

            batch_done = epoch * len(dataloader) + i
            if batch_done % opt.sample_interval == 0:
                # Only run the fixed-noise forward pass when a sample is actually
                # written, and without building an autograd graph.
                with torch.no_grad():
                    fixed_image = generator(fixed_noise)

                # Generator outputs lie in [-1, 1] (Tanh) while save_image clamps to
                # [0, 1]; map back with x * 0.5 + 0.5 so dark tones are not clipped.
                save_image(
                    fake.detach()[:25] * 0.5 + 0.5,
                    os.path.join("images", f"train_images{epoch + 1}.bmp"),
                )
                save_image(
                    fixed_image[:25] * 0.5 + 0.5,
                    os.path.join("images", f"fixd_fake_images{epoch + 1}.bmp"),
                )

# Leave samples from the final generator state in `fixed_image` for later inspection.
with torch.no_grad():
    fixed_image = generator(fixed_noise)
                
        

[Epoch 0/50] [Batch 0/938] [D loss: 0.880350] [G loss: 2.157125]
[Epoch 0/50] [Batch 1/938] [D loss: 0.562746] [G loss: 2.543644]
[Epoch 0/50] [Batch 2/938] [D loss: 0.941509] [G loss: 1.649195]
[Epoch 0/50] [Batch 3/938] [D loss: 0.700183] [G loss: 2.922287]
[Epoch 0/50] [Batch 4/938] [D loss: 0.920608] [G loss: 1.098843]
[Epoch 0/50] [Batch 5/938] [D loss: 0.938781] [G loss: 3.296962]
[Epoch 0/50] [Batch 6/938] [D loss: 0.645916] [G loss: 2.342561]
[Epoch 0/50] [Batch 7/938] [D loss: 0.838843] [G loss: 1.406083]
[Epoch 0/50] [Batch 8/938] [D loss: 0.886356] [G loss: 2.786447]
[Epoch 0/50] [Batch 9/938] [D loss: 0.697495] [G loss: 1.496172]
[Epoch 0/50] [Batch 10/938] [D loss: 0.774830] [G loss: 2.705410]
[Epoch 0/50] [Batch 11/938] [D loss: 0.732244] [G loss: 1.518692]
[Epoch 0/50] [Batch 12/938] [D loss: 0.722799] [G loss: 2.808869]
[Epoch 0/50] [Batch 13/938] [D loss: 0.830408] [G loss: 1.726742]
[Epoch 0/50] [Batch 14/938] [D loss: 0.839554] [G loss: 1.540067]
[Epoch 0/50] [Batch 