# Improved Training of Wasserstein GANs


In [None]:
import numpy as np
import torch.nn as nn
import torch
import matplotlib.pyplot as plt
import torchvision.transforms as T
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.utils import save_image
from collections import OrderedDict
import os

# Tensor <-> PIL conversion helpers.
to_img = T.Compose([T.ToPILImage()])
to_tensor = T.Compose([T.ToTensor()])
# Scale images to [-1, 1] so real data matches the generator's final nn.Tanh range.
# A single-element mean/std broadcasts over every channel. A 3-element mean/std
# (as for RGB) is rejected by torchvision for the 1-channel MNIST images loaded below.
load_norm = T.Compose([T.ToTensor(), T.Normalize((0.5,), (0.5,))])

In [None]:
class Parser:
    """Plain holder for the hyperparameters and output directories.

    Every setting is an ordinary instance attribute, so later cells can read or
    override it directly after construction (``opt.batch_size = 32``).
    """

    def __init__(self):
        self.n_epoch = 200            # passes over the training data
        self.batch_size = 64          # samples per dataloader batch
        self.lr = 0.0002              # learning rate handed to Adam
        self.b1 = 0.5                 # Adam beta coefficients; they only take
        self.b2 = 0.999               # effect where passed as betas=(b1, b2)
        self.latent_dim = 64          # side length of the square noise image fed to G
        self.img_size = 28            # height/width of one image
        self.channels = 1             # image channels (and noise channels)
        self.n_critic = 5             # critic iterations per generator update
        self.clip_value = 0.01        # not read by the gradient-penalty loop in this file
        self.sample_interval = 50     # steps between model/image checkpoints
        self.lambda_gp = 10           # weight of the gradient-penalty term
        self.show_freq = 50           # steps between inline image previews
        self.model_path = './Model/'  # checkpoint directory
        self.img_path = './Image/'    # sample-image directory


# Shared configuration instance used by the cells below.
opt = Parser()

# (C, H, W) shape of a single image.
img_shape = (opt.channels, opt.img_size, opt.img_size)

def show_img(model, input, value_range=(-1, 1)):
    """Display the first image that ``model`` produces for ``input``.

    Args:
        model: callable returning a batch of image tensors, presumably shaped
            (batch, channels, H, W) -- TODO confirm for models other than Generator.
        input: batch passed unchanged to ``model``.
        value_range: (low, high) range of the model's pixel values. The default
            matches the final ``nn.Tanh`` of ``Generator``. Pixels are rescaled to
            [0, 1] before conversion, because ``ToPILImage`` expects float
            tensors in [0, 1] and garbles values outside that range.

    NOTE(review): the model is called in whatever mode it is currently in; in
    train mode its BatchNorm layers update their running statistics.
    """
    low, high = value_range
    show_fn = T.Compose([T.ToPILImage()])
    # Display only: no autograd graph is needed.
    with torch.no_grad():
        img = model(input)[0].detach().cpu()
        img = ((img - low) / max(high - low, 1e-5)).clamp(0, 1)
    # cmap only affects single-channel (2-D) images; RGB images ignore it.
    plt.imshow(show_fn(img), cmap='gray')
    
def count_params(model):
    """Print and return the total number of parameter elements in ``model``.

    Args:
        model: an ``nn.Module`` (anything exposing ``parameters()``).

    Returns:
        int: sum of ``numel()`` over all parameters; 0 for a module without
        parameters.
    """
    # numel() is an exact int even for 0-dim (scalar) parameters. np.prod/np.sum
    # would produce floats for those and for modules with no parameters.
    param_count = sum(p.numel() for p in model.parameters())
    print('Number of parameters: ', param_count)
    return param_count

def compute_gradient_penalty(D, real_samples, fake_samples):
    """Calculates the gradient penalty loss for WGAN GP.

    The critic is evaluated at random points on straight lines between paired
    real and fake samples. The penalty is ``mean((||grad_x D(x_hat)||_2 - 1) ** 2)``,
    where the L2 norm is taken over ALL non-batch elements of each sample.

    Args:
        D: critic, called on a batch shaped like ``real_samples``. Its output may
            have any shape, e.g. ``(batch,)`` or ``(batch, 1)``.
        real_samples: batch of real data, batch dimension first.
        fake_samples: batch of generated data, same shape as ``real_samples``.

    Returns:
        Scalar tensor. It is built with ``create_graph=True`` so that calling
        ``backward`` on a loss containing it reaches the critic's parameters.
    """
    batch_size = real_samples.size(0)
    # One mixing coefficient per sample, broadcast over every remaining dim (works
    # for image batches and flat feature batches). It is created on the data's
    # device and dtype so GPU inputs do not hit a device mismatch.
    alpha_shape = (batch_size,) + (1,) * (real_samples.dim() - 1)
    alpha = torch.rand(alpha_shape, device=real_samples.device, dtype=real_samples.dtype)
    interpolates = (alpha * real_samples + (1 - alpha) * fake_samples).requires_grad_(True)
    d_interpolates = D(interpolates)
    # grad_outputs must have exactly the critic output's shape. A fixed (batch, 1)
    # tensor fails for a critic returning (batch,).
    grad_outputs = torch.ones_like(d_interpolates)
    gradients = torch.autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=grad_outputs,
        create_graph=True,
        retain_graph=True,
    )[0]
    # Flatten each sample's gradient so the norm covers channels AND pixels,
    # not just dim 1.
    gradients = gradients.reshape(batch_size, -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return gradient_penalty

# Wrapping G and D in a single class would mean loading the data every time, so the class below is left commented out.
# class WGAN(nn.Module):
#     def __init__(self):
#         super(WGAN, self).__init__()
#         self.g = Generator()
#         self.d = Discriminator()


In [None]:
# Model Classes

class LayerNorm(nn.Module):
    """Per-sample normalisation over all non-batch dims, with per-channel affine.

    Each sample ``x[i]`` is normalised by the mean and unbiased standard deviation
    of all its elements. ``eps`` is added to the standard deviation, not the
    variance. With ``affine=True`` a learnable per-channel scale ``gamma`` and
    shift ``beta`` are applied along dim 1.

    Args:
        num_features: size of dim 1 (the channel dim) of the input.
        eps: added to the standard deviation for numerical stability.
        affine: whether to learn ``gamma``/``beta``.
    """

    def __init__(self, num_features, eps=1e-5, affine=True):
        super(LayerNorm, self).__init__()
        self.num_features = num_features
        self.affine = affine
        self.eps = eps

        if self.affine:
            # One scale/shift per channel. gamma starts uniform in [0, 1) rather
            # than at 1; this initialisation is kept unchanged.
            self.gamma = nn.Parameter(torch.Tensor(num_features).uniform_())
            self.beta = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        # Per-sample statistics are broadcast back as (batch, 1, 1, ...).
        shape = [-1] + [1] * (x.dim() - 1)
        # reshape rather than view, so non-contiguous inputs (channels_last,
        # transposed tensors) are accepted instead of raising.
        flat = x.reshape(x.size(0), -1)
        mean = flat.mean(1).view(shape)
        std = flat.std(1).view(shape)
        y = (x - mean) / (std + self.eps)
        if self.affine:
            # Broadcast gamma/beta along dim 1: (1, C, 1, 1, ...).
            a_shape = [1, -1] + [1] * (x.dim() - 2)
            y = self.gamma.view(a_shape) * y + self.beta.view(a_shape)
        return y

class Flatten(nn.Module):
    """Collapse every dim after the batch dim: (B, d1, d2, ...) -> (B, d1*d2*...)."""

    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        # Explicit trailing size instead of view(B, -1):
        # - reshape accepts non-contiguous inputs;
        # - an empty batch (B == 0) stays well defined;
        # - a 1-D input still maps to (B, 1), since the empty Size has numel() == 1.
        return x.reshape(x.size(0), x.shape[1:].numel())

class Generator(nn.Module):
    """Fully convolutional generator mapping a noise image to a 1-channel image.

    For a 64x64 noise input (as drawn by the training loop), the spatial size goes
    64 -> 32 -> 32 -> 16 -> 16 -> 16 -> 32 (transposed conv) -> 30 -> 28.
    The last two blocks use dilation 2 and 8, widening the receptive field at
    constant size. The final ``nn.Tanh`` keeps pixel values in [-1, 1].
    """

    def __init__(self):
        super(Generator, self).__init__()

        def conv_unit(c_in, c_out, k=5, s=1, p=0, d=1, normalize=True, deconv=False):
            # Conv2d (or ConvTranspose2d) -> optional BatchNorm2d -> LeakyReLU(0.2).
            conv_cls = nn.ConvTranspose2d if deconv else nn.Conv2d
            unit = [conv_cls(c_in, c_out, kernel_size=k, stride=s, padding=p, dilation=d)]
            if normalize:
                unit.append(nn.BatchNorm2d(c_out))
            unit.append(nn.LeakyReLU(0.2, inplace=True))
            return unit

        layers = []
        layers += conv_unit(1, 64, 5, 2, 2)
        layers += conv_unit(64, 128, 3, 1, 1)
        layers += conv_unit(128, 256, 3, 2, 1)
        layers += conv_unit(256, 256, 3, 1, 2, 2)
        layers += conv_unit(256, 256, 3, 1, 8, 8)
        layers += conv_unit(256, 128, 4, 2, 1, deconv=True)
        layers += conv_unit(128, 64, 3, 1, 0)
        layers += [nn.Conv2d(64, 1, 3, 1, 0), nn.Tanh()]
        self.model = nn.Sequential(*layers)

    def forward(self, z):
        # The most recent output is also kept on the module as ``self.img``.
        out = self.model(z)
        self.img = out
        return out
    
class Discriminator(nn.Module):
    """Critic: three strided conv units, then a 2x2 conv down to one score.

    For a 28x28 input the spatial size goes 28 -> 13 -> 6 -> 2 -> 1.
    ``forward`` returns one score per sample, shape (batch,), with no output
    activation. The conv units are normalised with the per-sample ``LayerNorm``
    defined in this file.
    """

    def __init__(self):
        super(Discriminator, self).__init__()

        def conv_unit(c_in, c_out, k=5, s=1, p=0, normalize=True):
            # Conv2d -> optional LayerNorm -> LeakyReLU(0.2).
            unit = [nn.Conv2d(c_in, c_out, kernel_size=k, stride=s, padding=p)]
            if normalize:
                unit.append(LayerNorm(c_out))
            unit.append(nn.LeakyReLU(0.2, inplace=True))
            return unit

        layers = (
            conv_unit(1, 64, 5, 2, 1)
            + conv_unit(64, 128, 5, 2, 1)
            + conv_unit(128, 256, 5, 2, 1)
            + [nn.Conv2d(256, 1, 2, 1)]
        )
        self.model = nn.Sequential(*layers)

    def forward(self, img):
        scores = self.model(img)
        # Flatten (batch, 1, 1, 1) -> (batch,).
        return scores.view(-1)

In [None]:
# ---- WGAN-GP training run on MNIST ----
# Per-run overrides of the Parser defaults.
opt.sample_interval = 200
opt.batch_size = 64
opt.show_freq = 50
opt.n_epoch = 200

os.makedirs(opt.model_path, exist_ok=True)
os.makedirs(opt.img_path, exist_ok=True)

dataloader = torch.utils.data.DataLoader(datasets.MNIST('./data/mnist',
                                                        train=True, download=True,
                                                        transform=load_norm),
                                         batch_size=opt.batch_size, shuffle=True)

G = Generator()
D = Discriminator()
# Pass the configured Adam betas; without betas= Adam falls back to (0.9, 0.999)
# and opt.b1 / opt.b2 are silently ignored.
optim_G = torch.optim.Adam(G.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optim_D = torch.optim.Adam(D.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))


def _to_unit_range(x):
    """Detach and map images from [-1, 1] (Tanh / load_norm range) to [0, 1]."""
    # Equivalent to save_image(normalize=True, value_range=(-1, 1)), but independent
    # of the torchvision version (the old `range=` keyword no longer exists).
    return x.detach().add(1).div(2).clamp(0, 1)


def _sample_noise(batch_size):
    """Draw generator input: a (batch, channels, latent_dim, latent_dim) Gaussian image."""
    return torch.randn(batch_size, opt.channels, opt.latent_dim, opt.latent_dim)


# A bounded for-loop: the former while-loop never incremented `epoch` and never ended.
for epoch in range(opt.n_epoch):
    for i, (img, _) in enumerate(dataloader):

        # ---- critic update ----
        optim_D.zero_grad()

        # Detached: the critic step must not push gradients into G.
        fake_img = G(_sample_noise(img.shape[0])).detach()

        real_validity = D(img)
        fake_validity = D(fake_img)

        gradient_penalty = compute_gradient_penalty(D, img, fake_img)
        d_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + opt.lambda_gp * gradient_penalty
        d_loss.backward()
        optim_D.step()

        if i % opt.show_freq == 0 and i != 0:
            plt.figure()
            # Rescale first: ToPILImage expects [0, 1] and garbles negative pixels.
            plt.imshow(to_img(_to_unit_range(fake_img[0])), cmap='gray')
            plt.show()
        print("Step [%d] | Discriminator Loss: [%.4f]" % (i, d_loss.item()))

        # ---- generator update, once every n_critic critic updates ----
        if i % opt.n_critic == 0 and i != 0:

            optim_G.zero_grad()

            # Fresh latent batch for the generator step (WGAN-GP, Algorithm 1).
            fake_img = G(_sample_noise(img.shape[0]))
            fake_validity = D(fake_img)

            g_loss = -torch.mean(fake_validity)
            g_loss.backward()

            optim_G.step()

            print("Step [%d] | Generator Loss: [%.4f]" % (i, g_loss.item()))

        # ---- checkpoint models and a 5x5 sample grid ----
        if i % opt.sample_interval == 0 and i != 0:
            # NOTE: torch.save(module) pickles the whole module; loading it needs
            # these class definitions importable. A state_dict would be more portable.
            D_path = os.path.join(opt.model_path, "WGAN_D_Epoch" + str(epoch) + "Step" + str(i) + ".pt")
            torch.save(D, D_path)
            G_path = os.path.join(opt.model_path, "WGAN_G_Epoch" + str(epoch) + "Step" + str(i) + ".pt")
            torch.save(G, G_path)
            img_path = os.path.join(opt.img_path, "WGAN_img_Epoch" + str(epoch) + "Step" + str(i) + ".png")
            save_image(_to_unit_range(fake_img[:25]), img_path, nrow=5)
            print('Model & Image saved..')