# 0. Wasserstein GAN

## Reference
- https://www.youtube.com/watch?v=tKQwlf-DAl0
- https://github.com/eriklindernoren/PyTorch-GAN/blob/master/implementations/wgan/wgan.py
- Paper: https://arxiv.org/abs/1701.07875

# 1. Library Imports

In [1]:
!pip install easydict

Collecting easydict
  Downloading easydict-1.9.tar.gz (6.4 kB)
Building wheels for collected packages: easydict
  Building wheel for easydict (setup.py) ... [?25l- \ done
[?25h  Created wheel for easydict: filename=easydict-1.9-py3-none-any.whl size=6350 sha256=19c0a85eb4e77208b71a7da7da393ac351e4f0fbb7b57a224b4c50e35862f071
  Stored in directory: /root/.cache/pip/wheels/88/96/68/c2be18e7406804be2e593e1c37845f2dd20ac2ce1381ce40b0
Successfully built easydict
Installing collected packages: easydict
Successfully installed easydict-1.9


In [2]:
import easydict
import os
import numpy as np
import math
import itertools

import torchvision.transforms as transforms
from torchvision.utils import save_image

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

import torch.nn as nn
import torch.nn.functional as F
import torch

In [3]:
# Output folder for the sample grids written during training; no error if it already exists.
os.makedirs(name="images", exist_ok=True)

# 2. Hyperparameter Settings

In [4]:
# Training hyperparameters, wrapped in an EasyDict for attribute access (opt.lr, ...).
opt = easydict.EasyDict(
    dict(
        n_epochs=5,           # passes over the training set
        batch_size=64,        # images per mini-batch
        lr=5e-05,             # RMSprop learning rate for both networks
        n_cpu=8,              # NOTE(review): not used by the DataLoader visible in this notebook
        latent_dim=100,       # length of the generator's noise vector
        img_size=28,          # side length images are resized to
        channels=1,           # image channels (MNIST is single-channel)
        n_critic=5,           # generator is updated once every n_critic batches
        clip_value=0.01,      # critic weights are clamped to [-clip_value, clip_value]
        sample_interval=400,  # batches between saved sample grids
    )
)

In [5]:
# Echo the full hyperparameter set so the run's configuration is recorded in the output.
print(str(opt))

{'n_epochs': 5, 'batch_size': 64, 'lr': 5e-05, 'n_cpu': 8, 'latent_dim': 100, 'img_size': 28, 'channels': 1, 'n_critic': 5, 'clip_value': 0.01, 'sample_interval': 400}


In [6]:
# Whether a CUDA device is available; used below to move models and data to the GPU.
# torch.cuda.is_available() already returns a bool, so no ternary is needed.
cuda = torch.cuda.is_available()

In [7]:
# Report whether training will run on the GPU.
print(f"{cuda}")

True


In [8]:
# Per-image shape (channels, height, width): the generator's output and the critic's input.
img_shape = (opt.channels,) + (opt.img_size,) * 2

# 3. Model

## 3.1 Generator

In [9]:
class Generator(nn.Module):
    """MLP generator mapping latent noise vectors to images.

    Architecture: hidden stages Linear -> (BatchNorm1d) -> LeakyReLU(0.2) widening
    128 -> 256 -> 512 -> 1024 (no BatchNorm on the first stage), followed by a
    Linear projection to prod(image_shape) values, a Tanh squashing them into
    [-1, 1], and a reshape to an image batch.

    Args:
        latent_dim: length of the input noise vector. When omitted, falls back to
            the notebook config value ``opt.latent_dim``.
        image_shape: (channels, height, width) of generated images. When omitted,
            falls back to the notebook-level ``img_shape``.
    """

    def __init__(self, latent_dim=None, image_shape=None):
        super().__init__()

        # Resolve defaults lazily so Generator() behaves exactly as before, while
        # explicit arguments make the class usable without notebook globals.
        if latent_dim is None:
            latent_dim = opt.latent_dim
        if image_shape is None:
            image_shape = img_shape
        self.img_shape = tuple(image_shape)

        def block(in_feat, out_feat, normalize=True):
            """Return the [Linear, (BatchNorm1d), LeakyReLU] layers for one hidden stage."""
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                # BatchNorm1d's second positional parameter is `eps`, so the
                # original positional 0.8 set eps=0.8 (not momentum). The value is
                # preserved; it is passed by keyword to make that explicit.
                layers.append(nn.BatchNorm1d(out_feat, eps=0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(self.img_shape))),
            nn.Tanh(),
        )

    def forward(self, z):
        """Map a (batch, latent_dim) noise tensor to a (batch, *img_shape) image batch in [-1, 1]."""
        img = self.model(z)
        return img.view(img.shape[0], *self.img_shape)

## 3.2 Discriminator

In [10]:
class Discriminator(nn.Module):
    """MLP critic that assigns one unbounded real-valued score to each image.

    Architecture: flatten -> Linear(prod(image_shape), 512) -> LeakyReLU(0.2)
    -> Linear(512, 256) -> LeakyReLU(0.2) -> Linear(256, 1). There is no final
    sigmoid; the training loop uses the raw scores directly in its losses.

    Args:
        image_shape: (channels, height, width) of input images. When omitted,
            falls back to the notebook-level ``img_shape``.
    """

    def __init__(self, image_shape=None):
        super().__init__()

        # Resolve the default lazily so Discriminator() behaves exactly as before.
        if image_shape is None:
            image_shape = img_shape

        self.model = nn.Sequential(
            nn.Linear(int(np.prod(image_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
        )

    def forward(self, img):
        """Return a (batch, 1) tensor of critic scores for an image batch."""
        # flatten(1) keeps the batch dim and, unlike .view(B, -1), also accepts
        # non-contiguous inputs and empty batches.
        img_flat = img.flatten(1)
        return self.model(img_flat)

# 4. Data Loader and Model Training

## 4.1 Model Initialization

In [11]:
# Build both networks (generator first, then critic) with their default configuration.
generator, discriminator = Generator(), Discriminator()

In [12]:
# Place both networks' parameters on the GPU when one is available.
# Module.cuda() moves in place and returns the same module, so rebinding is harmless.
if cuda:
    generator = generator.cuda()
    discriminator = discriminator.cuda()

## 4.2 Data Loader

In [13]:
!wget www.di.ens.fr/~lelarge/MNIST.tar.gz
!tar -zxvf MNIST.tar.gz

--2021-05-03 14:02:17--  http://www.di.ens.fr/~lelarge/MNIST.tar.gz
Resolving www.di.ens.fr (www.di.ens.fr)... 129.199.99.14
Connecting to www.di.ens.fr (www.di.ens.fr)|129.199.99.14|:80... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://www.di.ens.fr/~lelarge/MNIST.tar.gz [following]
--2021-05-03 14:02:17--  https://www.di.ens.fr/~lelarge/MNIST.tar.gz
Connecting to www.di.ens.fr (www.di.ens.fr)|129.199.99.14|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: unspecified [application/x-gzip]
Saving to: ‘MNIST.tar.gz’

MNIST.tar.gz            [        <=>         ]  33.20M  22.0MB/s    in 1.5s    

2021-05-03 14:02:19 (22.0 MB/s) - ‘MNIST.tar.gz’ saved [34813078]

MNIST/
MNIST/raw/
MNIST/raw/train-labels-idx1-ubyte
MNIST/raw/t10k-labels-idx1-ubyte.gz
MNIST/raw/t10k-labels-idx1-ubyte
MNIST/raw/t10k-images-idx3-ubyte.gz
MNIST/raw/train-images-idx3-ubyte
MNIST/raw/train-labels-idx1-ubyte.gz
MNIST/raw/t10k-images-i

In [14]:
# Configure the MNIST training data loader.
# Images are resized to opt.img_size, converted to tensors in [0, 1], then
# normalized with mean 0.5 / std 0.5 into [-1, 1]. This matches the generator's Tanh range.
# (The old os.makedirs("./") call was removed: the current directory always exists.)
dataloader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "./",
        train=True,
        download=True,
        transform=transforms.Compose(
            [
                transforms.Resize(opt.img_size),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        ),
    ),
    batch_size=opt.batch_size,
    shuffle=True,
)

## 4.3 Optimizers

In [15]:
# One RMSprop optimizer per network, both using the configured learning rate.
optimizer_G, optimizer_D = (
    torch.optim.RMSprop(net.parameters(), lr=opt.lr)
    for net in (generator, discriminator)
)

In [16]:
# Float tensor type used to cast input batches: a CUDA tensor type when a GPU is available.
if cuda:
    Tensor = torch.cuda.FloatTensor
else:
    Tensor = torch.FloatTensor

## 4.4 Model Training

In [17]:
# Global batch counter across epochs; drives the sample-saving schedule.
batches_done = 0

for epoch in range(opt.n_epochs):
    for i, (imgs, _) in enumerate(dataloader):
        # Cast the real batch to the models' tensor type/device.
        # (Variable wrappers are no-ops in modern PyTorch and are omitted.)
        real_imgs = imgs.type(Tensor)

        # ---------------------
        #  Train Discriminator (critic)
        # ---------------------
        optimizer_D.zero_grad()

        # Sample standard-normal noise directly on the target device, instead of
        # building a float64 NumPy array and copying it from the host every batch.
        z = torch.randn(imgs.shape[0], opt.latent_dim, device=real_imgs.device)

        # Generate a batch of images; detach so the critic update does not
        # backpropagate into the generator.
        fake_imgs = generator(z).detach()

        # Critic loss: mean score on fakes minus mean score on reals (minimized).
        loss_D = -torch.mean(discriminator(real_imgs)) + torch.mean(discriminator(fake_imgs))

        loss_D.backward()
        optimizer_D.step()

        # Clip critic weights to [-clip_value, clip_value] after every critic update.
        with torch.no_grad():
            for p in discriminator.parameters():
                p.clamp_(-opt.clip_value, opt.clip_value)

        # Train the generator only once every n_critic batches (per-epoch index i).
        if i % opt.n_critic == 0:
            # ---------------
            # Train Generator
            # ---------------
            optimizer_G.zero_grad()

            # Regenerate from the same noise, this time keeping the graph.
            gen_imgs = generator(z)

            # Generator loss: raise the critic's mean score on generated images.
            loss_G = -torch.mean(discriminator(gen_imgs))

            loss_G.backward()
            optimizer_G.step()

            # batches_done % len(dataloader) == i here, so report i directly.
            print(
                "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
                % (epoch, opt.n_epochs, i, len(dataloader), loss_D.item(), loss_G.item())
            )

        # Save a 5x5 sample grid. Use fake_imgs, which is regenerated every batch;
        # gen_imgs is only refreshed on generator steps and can be stale here.
        if batches_done % opt.sample_interval == 0:
            save_image(fake_imgs[:25], "images/%d.png" % batches_done, nrow=5, normalize=True)
        batches_done += 1

[Epoch 0/5] [Batch 0/938] [D loss: 0.250233] [G loss: -0.010514]
[Epoch 0/5] [Batch 5/938] [D loss: -0.110639] [G loss: -0.011509]
[Epoch 0/5] [Batch 10/938] [D loss: -0.380757] [G loss: -0.015162]
[Epoch 0/5] [Batch 15/938] [D loss: -0.823990] [G loss: -0.028114]
[Epoch 0/5] [Batch 20/938] [D loss: -1.371024] [G loss: -0.056308]
[Epoch 0/5] [Batch 25/938] [D loss: -1.943619] [G loss: -0.097142]
[Epoch 0/5] [Batch 30/938] [D loss: -2.666347] [G loss: -0.159748]
[Epoch 0/5] [Batch 35/938] [D loss: -3.317324] [G loss: -0.247081]
[Epoch 0/5] [Batch 40/938] [D loss: -4.093291] [G loss: -0.352397]
[Epoch 0/5] [Batch 45/938] [D loss: -4.756479] [G loss: -0.491172]
[Epoch 0/5] [Batch 50/938] [D loss: -5.490008] [G loss: -0.627197]
[Epoch 0/5] [Batch 55/938] [D loss: -6.103287] [G loss: -0.818531]
[Epoch 0/5] [Batch 60/938] [D loss: -6.527742] [G loss: -1.053035]
[Epoch 0/5] [Batch 65/938] [D loss: -7.293803] [G loss: -1.248935]
[Epoch 0/5] [Batch 70/938] [D loss: -7.785821] [G loss: -1.525797