In [None]:
import numpy as np
from tqdm import tqdm

import torchvision.transforms as transforms
from torchvision.utils import save_image
import torchvision.datasets as dset
import torchvision.utils as vutils

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

import torch.nn as nn
import torch.nn.functional as F
import torch.autograd as autograd
import torch

import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML


# Hyperparameters

In [None]:
# Training schedule
n_epochs = 20
batch_size = 32
# Adam optimiser settings
lr = 0.0002
b1 = 0.5
b2 = 0.999
n_cpu = 8            # NOTE(review): not referenced in this notebook; `workers` below sets DataLoader workers
latent_dim = 100     # length of the generator's noise vector
img_size = 64        # images are resized and centre-cropped to img_size x img_size
channels = 3
n_critic = 5         # critic updates per generator update
clip_value = 0.01    # NOTE(review): weight-clipping bound of the original WGAN; not referenced in this notebook
sample_interval = 400  # presumably batches between logged samples — check the training cell
ngf = 64             # presumably generator feature-map width used by models/generator_wgan — confirm
ndf = 64             # presumably discriminator feature-map width used by models/discriminator_wgan — confirm


dataroot = "./data"
workers = 6

img_shape = (channels, img_size, img_size)

cuda = torch.cuda.is_available()

# Fix the random seeds so noise sampling (numpy) and weight init / shuffling (torch)
# are reproducible across Restart & Run All.
manual_seed = 999
np.random.seed(manual_seed)
torch.manual_seed(manual_seed)

# Visualize Dataset

In [None]:
# Training set: every image found under `dataroot` (torchvision ImageFolder layout,
# i.e. images grouped in sub-directories) is resized, centre-cropped to
# img_size x img_size and normalised from [0, 1] to [-1, 1] per channel.
image_transform = transforms.Compose([
    transforms.Resize(img_size),
    transforms.CenterCrop(img_size),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
dataset = dset.ImageFolder(root=dataroot, transform=image_transform)

# Shuffled mini-batches of (images, labels); the training loop ignores the labels.
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                         shuffle=True, num_workers=workers)

# Device for the models; reuses the CUDA check from the config cell.
device = torch.device("cuda:0" if cuda else "cpu")

# Preview one batch (at most 64 images). Plotting only needs CPU tensors,
# so the batch is not moved to `device` and back.
real_batch = next(iter(dataloader))
preview_grid = vutils.make_grid(real_batch[0][:64], padding=2, normalize=True)
plt.figure(figsize=(8, 8))
plt.axis("off")
plt.title("Training Images")
plt.imshow(np.transpose(preview_grid, (1, 2, 0)))
plt.show()

# Generator

In [None]:
from models.generator_wgan import * 

In [None]:
# Custom weight initialisation, applied to the generator and the discriminator via
# `module.apply(weights_init)` (apply visits every submodule).
def weights_init(m):
    """Re-initialise one submodule in place, DCGAN style.

    - Modules whose class name contains ``Conv`` (Conv*/ConvTranspose*):
      weight ~ N(0, 0.02); the bias, if any, is left untouched.
    - ``BatchNorm*``: weight ~ N(1, 0.02), bias = 0.
    - ``Linear``: weight ~ N(0, 0.02), bias = 0.

    Parameters a module does not have (``bias=False``, ``affine=False``, or a
    container whose class name merely contains one of the keywords) are skipped
    instead of raising ``AttributeError``.

    Args:
        m: the ``nn.Module`` currently visited by ``apply``.
    """
    classname = type(m).__name__
    weight = getattr(m, "weight", None)
    bias = getattr(m, "bias", None)

    if "Conv" in classname:
        if weight is not None:
            nn.init.normal_(weight, 0.0, 0.02)
    elif "BatchNorm" in classname:
        if weight is not None:
            nn.init.normal_(weight, 1.0, 0.02)
        if bias is not None:
            nn.init.constant_(bias, 0)
    elif "Linear" in classname:
        if weight is not None:
            nn.init.normal_(weight, 0.0, 0.02)
        if bias is not None:
            nn.init.constant_(bias, 0)

In [None]:
# Instantiate the generator (class imported from models.generator_wgan), move it to
# the training device, then re-initialise every submodule with weights_init.
generator = Generator()
generator = generator.to(device)
generator.apply(weights_init)

# Show the architecture
print(generator)

# Discriminator

In [None]:
from models.discriminator_wgan import *

In [None]:
# Create the discriminator (the WGAN-GP critic; class imported from
# models.discriminator_wgan) on the training device.
discriminator = Discriminator().to(device)
# Re-initialise all weights with weights_init: conv/linear weights ~ N(0, 0.02)
# (mean 0, std 0.02), batch-norm weights ~ N(1, 0.02).
# NOTE(review): if the critic contains BatchNorm, Gulrajani et al. (2017) advise
# against batch norm in a WGAN-GP critic (the penalty is defined per sample);
# layer norm is their suggested alternative — confirm against models/discriminator_wgan.py.
discriminator.apply(weights_init)
# Print the model
print(discriminator)

# Optimizers, Gradient-Penalty Weight and Fixed Noise

In [None]:
# Adam for both networks, using the hyperparameters from the config cell.
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2))

# Weight of the gradient-penalty term in the critic loss (lambda = 10 in Gulrajani et al., 2017).
lambda_gp = 10

# Legacy typed-tensor constructor that targets the GPU when `cuda` is True;
# kept because the training cell builds its tensors with it.
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

# Fixed latent batch, reused for every progress snapshot so that samples are
# comparable across training. `Variable` is a deprecated no-op wrapper since
# PyTorch 0.4, so a plain float32 tensor on the training device is used.
fixed_noise = torch.as_tensor(
    np.random.normal(0, 1, (batch_size, latent_dim)),
    dtype=torch.float32,
    device=device,
)

# Gradient Penalty

In [None]:
def compute_gradient_penalty(D, real_samples, fake_samples, gp_weight=None):
    """Gradient-penalty term of the WGAN-GP critic loss (Gulrajani et al., 2017).

    Evaluates the critic on random points
    x_hat = alpha * real + (1 - alpha) * fake, with one alpha ~ U[0, 1) per sample,
    and returns gp_weight * mean((||grad_{x_hat} D(x_hat)||_2 - 1) ** 2).

    Args:
        D: critic, called on a tensor shaped like ``real_samples``. Its output may
            have any shape, e.g. ``(N,)``, ``(N, 1)`` or ``(N, 1, 1, 1)``.
        real_samples: real batch of shape ``(N, ...)``.
        fake_samples: generated batch with the same shape as ``real_samples``.
        gp_weight: penalty weight; ``None`` (default) uses the notebook-level
            ``lambda_gp``.

    Returns:
        0-dim tensor built with ``create_graph=True``, so backpropagating through
        it (or through a loss containing it) reaches the critic's parameters.
    """
    if gp_weight is None:
        gp_weight = lambda_gp
    num_samples = real_samples.size(0)
    # One interpolation coefficient per sample, broadcast over the remaining dims and
    # created on the samples' own device/dtype (independent of the global `Tensor`).
    alpha_shape = (num_samples,) + (1,) * (real_samples.dim() - 1)
    alpha = torch.rand(alpha_shape, device=real_samples.device, dtype=real_samples.dtype)
    interpolates = (alpha * real_samples + (1 - alpha) * fake_samples).requires_grad_(True)
    d_interpolates = D(interpolates)
    # grad_outputs must match the critic output's shape, whatever that shape is.
    gradients = torch.autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(d_interpolates),
        create_graph=True,
        retain_graph=True,
    )[0]
    # reshape (not view): autograd does not guarantee a contiguous gradient.
    gradients = gradients.reshape(num_samples, -1)
    return ((gradients.norm(2, dim=1) - 1) ** 2).mean() * gp_weight


# Training

In [None]:
# ----------
#  Training
# ----------
# WGAN-GP schedule: for every batch drawn from the dataloader the critic
# (`discriminator`) is updated `n_critic` times, then the generator once.
# NOTE(review): all `n_critic` critic steps reuse the same real batch `imgs`,
# whereas Algorithm 1 of Gulrajani et al. (2017) draws a fresh real batch for
# every critic step; kept as-is so the number of generator updates per epoch
# is unchanged.

img_list = []      # grids of generator(fixed_noise) snapshots, used by the plotting cells
G_losses = []      # generator loss, one entry per generator step
D_losses = []      # critic loss (incl. gradient penalty) from the last critic step of each batch
W_distances = []   # critic estimate E[D(real)] - E[D(fake)], same cadence as D_losses
batches_done = 0   # total generator steps taken

for epoch in range(n_epochs):

    for i, (imgs, _) in enumerate(tqdm(dataloader, total=len(dataloader))):
        # The real batch is identical for every critic step below: convert it once.
        real_imgs = imgs.type(Tensor)

        # ---------------------
        #  Train Discriminator
        # ---------------------
        # Re-enable critic gradients (they are switched off for the generator step).
        for p in discriminator.parameters():
            p.requires_grad = True

        for _ in range(n_critic):
            discriminator.zero_grad()

            # The fake batch only trains the critic here, so generate it without
            # a generator graph: no generator gradients are computed or accumulated.
            z = Tensor(np.random.normal(0, 1, (imgs.shape[0], latent_dim)))
            with torch.no_grad():
                fake_imgs = generator(z)

            d_loss_real = discriminator(real_imgs).mean()
            d_loss_fake = discriminator(fake_imgs).mean()
            gradient_penalty = compute_gradient_penalty(discriminator, real_imgs.data, fake_imgs.data)

            # The critic minimises E[D(fake)] - E[D(real)] + GP. A single backward
            # on the sum yields the same gradients as separate backward calls.
            d_loss = d_loss_fake - d_loss_real + gradient_penalty
            d_loss.backward()
            optimizer_D.step()
            Wasserstein_D = d_loss_real - d_loss_fake

        # -----------------
        #  Train Generator
        # -----------------
        # Freeze the critic so backprop through it does not compute critic gradients.
        for p in discriminator.parameters():
            p.requires_grad = False

        generator.zero_grad()
        z = Tensor(np.random.normal(0, 1, (imgs.shape[0], latent_dim)))
        fake_imgs = generator(z)
        # The generator minimises -E[D(G(z))]; g_loss is that loss, not the raw critic score.
        g_loss = -discriminator(fake_imgs).mean()
        g_loss.backward()
        optimizer_G.step()
        batches_done += 1

        # Keep plain floats so the histories can be plotted after training.
        G_losses.append(g_loss.item())
        D_losses.append(d_loss.item())
        W_distances.append(Wasserstein_D.item())

        if i % sample_interval == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tWasserstein distance: %.4f\t'
                  % (epoch, n_epochs, i, len(dataloader),
                     D_losses[-1], G_losses[-1], W_distances[-1]))

        # Snapshot generator(fixed_noise) every `sample_interval` batches and on the final batch.
        is_last_batch = (epoch == n_epochs - 1) and (i == len(dataloader) - 1)
        if (i % sample_interval == 0) or is_last_batch:
            with torch.no_grad():
                fake = generator(fixed_noise).detach().cpu()
            img_list.append(vutils.make_grid(fake, padding=2, normalize=True))

    # Show the most recent snapshot at the end of every epoch.
    plt.figure(figsize=(8, 8))
    plt.axis("off")
    plt.imshow(np.transpose(img_list[-1], (1, 2, 0)))
    plt.show()

# Plotting the results

## Loss plot

In [None]:
# Plot the loss curves. The training cell's final `g_loss`, `d_loss` and
# `Wasserstein_D` are single scalar tensors (possibly on CUDA and requiring grad),
# which matplotlib cannot plot directly. If per-iteration histories named
# `G_losses`, `D_losses` and `W_distances` exist in the namespace they are plotted;
# otherwise each curve falls back to its single final value.
def _loss_series(history_name, last_value):
    """Return a 1-D float array: the named history if defined, else just `last_value`."""
    history = globals().get(history_name)
    if history is None:
        history = [last_value]
    # float() accepts Python numbers and one-element tensors on any device.
    return np.array([float(v) for v in history], dtype=float)


fig, ax = plt.subplots(figsize=(10, 5))
ax.set_title("Generator and Discriminator Loss During Training")
ax.plot(_loss_series("G_losses", g_loss), label="G")
ax.plot(_loss_series("D_losses", d_loss), label="D")
ax.plot(_loss_series("W_distances", Wasserstein_D), label="Wasserstein distance")
ax.set_xlabel("iterations")
ax.set_ylabel("Loss")
ax.legend()
plt.show()

## Images generated during training

In [None]:
# Animate the generator(fixed_noise) snapshots stored in img_list, one frame per snapshot.
fig = plt.figure(figsize=(8, 8))
plt.axis("off")
ims = [[plt.imshow(np.transpose(grid, (1, 2, 0)), animated=True)] for grid in img_list]
ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)

# Render the animation first, then close the figure so Jupyter does not also
# display it as a static image underneath the animation.
animation_html = HTML(ani.to_jshtml())
plt.close(fig)
animation_html

## Compare real and generated images

In [None]:
# Side-by-side comparison: a fresh batch of real images vs. the latest generator snapshot.
real_batch = next(iter(dataloader))

fig, (ax_real, ax_fake) = plt.subplots(1, 2, figsize=(15, 15))

# Real images (at most 64). Plotting only needs CPU tensors, so the batch is
# not moved to the training device and back.
real_grid = vutils.make_grid(real_batch[0][:64], padding=5, normalize=True)
ax_real.axis("off")
ax_real.set_title("Real Images")
ax_real.imshow(np.transpose(real_grid, (1, 2, 0)))

# Fake images: the last generator(fixed_noise) grid recorded during training.
ax_fake.axis("off")
ax_fake.set_title("Fake Images")
ax_fake.imshow(np.transpose(img_list[-1], (1, 2, 0)))
plt.show()