In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter

In [None]:
### Critic Model ### 
class Critic(nn.Module):
  """WGAN-GP critic: maps a batch of images to one unbounded score per sample.

  Args:
    channels_img: number of channels of the input images.
    feature_d: base channel width; the hidden blocks use 2x, 4x and 8x this value.

  The size comments below assume 64 x 64 inputs, for which the output has
  shape (N, 1, 1, 1). No sigmoid is applied because a Wasserstein critic
  produces unbounded scores.

  Normalisation is done per sample (InstanceNorm) rather than per batch:
  batch normalisation makes every output depend on the whole batch, which
  invalidates a gradient penalty that is defined for each input independently.
  """
  def __init__(self, channels_img, feature_d):
    super(Critic, self).__init__()
    self.critic = nn.Sequential(
        # Input image: N * channels_img * 64 * 64
        nn.Conv2d(channels_img, feature_d, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)),  # 32 * 32
        nn.LeakyReLU(0.2),
        self._block(feature_d, feature_d*2, 4, 2, 1),    # feature size: 16 * 16
        self._block(feature_d*2, feature_d*4, 4, 2, 1),  # feature size: 8 * 8
        self._block(feature_d*4, feature_d*8, 4, 2, 1),  # feature size: 4 * 4
        # 4 * 4 -> 1 * 1: a single unbounded score per image (no Sigmoid for WGAN)
        nn.Conv2d(feature_d*8, 1, kernel_size=(4, 4), stride=(2, 2), padding=0),
    )

  def _block(self, in_channels, out_channels, kernel_size, stride, padding):
    """Conv -> per-sample norm -> LeakyReLU(0.2) down-sampling block."""
    return nn.Sequential(
        # No conv bias: the affine shift of the following norm layer makes it redundant
        nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),
        nn.InstanceNorm2d(out_channels, affine=True),
        nn.LeakyReLU(0.2)
    )

  def forward(self, x):
    """Return the critic score for every image in x."""
    return self.critic(x)


In [None]:
### Gen Model ### 
class Generator(nn.Module):
  """Transposed-convolution generator turning latent noise into images.

  Args:
    z_dim: number of channels of the latent input (batch * z_dim * 1 * 1).
    channels_img: number of channels of the generated images.
    feature_g: base channel width; the hidden blocks use 16x, 8x, 4x and 2x it.

  For a 1 x 1 latent input the output is batch * channels_img * 64 * 64,
  squashed to [-1, 1] by the final Tanh.
  """
  def __init__(self, z_dim, channels_img, feature_g):
    super().__init__()
    # Channel widths of the four up-sampling blocks, from latent to last hidden layer
    widths = [z_dim, feature_g*16, feature_g*8, feature_g*4, feature_g*2]

    # 1 x 1 -> 4 x 4: the first block uses stride 1 and no padding
    stages = [self._block(widths[0], widths[1], 4, 1, 0)]
    # 4 -> 8 -> 16 -> 32: every further block doubles the spatial size
    for c_in, c_out in zip(widths[1:-1], widths[2:]):
      stages.append(self._block(c_in, c_out, 4, 2, 1))
    # 32 x 32 -> 64 x 64 image with values in [-1, 1]
    stages.append(nn.ConvTranspose2d(widths[-1], channels_img, 4, 2, 1))
    stages.append(nn.Tanh())

    self.gen = nn.Sequential(*stages)

  def _block(self, in_channels, out_channels, kernel_size, stride, padding):
    """ConvTranspose -> BatchNorm -> ReLU up-sampling block."""
    upsample = nn.ConvTranspose2d(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        bias=False,
    )
    return nn.Sequential(upsample, nn.BatchNorm2d(out_channels), nn.ReLU())

  def forward(self, x):
    """Generate images from the latent batch x."""
    return self.gen(x)


In [None]:
# define  a function to initialize the weights
def initialize_weights(model):
  """Initialise the weights of `model` in place, DCGAN style.

  Args:
    model: an nn.Module; all of its sub-modules are visited.

  Convolution and transposed-convolution kernels are drawn from N(0, 0.02).
  BatchNorm scales are drawn from N(1, 0.02) and their shifts set to 0: a
  zero-centred scale would multiply every normalised activation by ~0 and
  starve the following layers of signal.
  """
  for m in model.modules():
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
      nn.init.normal_(m.weight, 0.0, 0.02)
    elif isinstance(m, nn.BatchNorm2d) and m.weight is not None:
      # weight/bias are None when the layer was built with affine=False
      nn.init.normal_(m.weight, 1.0, 0.02)
      nn.init.constant_(m.bias, 0.0)

In [None]:
def test():
  """Shape sanity check for Critic and Generator on random 64 x 64 inputs.

  Raises AssertionError if either network returns an unexpected shape and
  prints "Success" otherwise.
  """
  batch, channels, height, width = 10, 3, 64, 64
  latent_dim = 100

  # Critic: one score per image
  images = torch.randn((batch, channels, height, width))
  critic_net = Critic(channels, 8)
  initialize_weights(critic_net)
  scores = critic_net(images)
  assert scores.shape == (batch, 1, 1, 1)

  # Generator: latent vectors back to full-size images
  gen_net = Generator(latent_dim, channels, 8)
  initialize_weights(gen_net)
  latent = torch.randn((batch, latent_dim, 1, 1))
  generated = gen_net(latent)
  assert generated.shape == (batch, channels, height, width)

  print("Success")

In [None]:
# Run the shape sanity check: raises AssertionError on a mismatch, prints "Success" otherwise
test()

In [None]:
# Hyperparameter setting
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
seed = 42
torch.manual_seed(seed)  # make weight init, noise sampling and shuffling repeatable
lr = 1e-4
num_critic = 5    # critic updates per generator update
LAMBDA = 10       # weight of the gradient penalty in the critic loss
batch_size = 128
image_size = 64
channel_img = 1
num_epoches = 5
feature_disc = 16
feature_gen = 16
z_dim = 100
# Resize to image_size, convert to a tensor and normalise each channel with
# mean 0.5 / std 0.5, i.e. to [-1, 1], the range of the generator's Tanh output
Transforms = transforms.Compose([transforms.Resize(image_size),
                                 transforms.ToTensor(),
                                 transforms.Normalize([0.5] * channel_img, [0.5] * channel_img),
                                 ])
## Data Loading
# Download into the working directory ("/." pointed at the filesystem root,
# which is normally not writable)
train_dataset = datasets.MNIST(root="./", train=True, transform=Transforms, download=True)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

In [None]:
# Build both networks on the target device, then apply the custom weight init
gen = Generator(
    z_dim=z_dim,
    channels_img=channel_img,
    feature_g=feature_gen,
).to(device)
critic = Critic(
    channels_img=channel_img,
    feature_d=feature_disc,
).to(device)
initialize_weights(gen)
initialize_weights(critic)

In [None]:
# One Adam optimizer per network, both with beta1 = 0 and beta2 = 0.9
opt_critic = optim.Adam(
    params=critic.parameters(),
    lr=lr,
    betas=(0.0, 0.9),
)
opt_gen = optim.Adam(
    params=gen.parameters(),
    lr=lr,
    betas=(0.0, 0.9),
)

In [None]:
%load_ext tensorboard

In [None]:
# Disabled scratch experiment, kept inside a string literal so none of it executes:
# it prints the shapes produced by torch.autograd.grad, .view and .norm on a toy tensor.
# NOTE(review): the last step takes mean(N - 1) without squaring the difference.
'''
w = torch.rand(20,2,1,1, requires_grad=True)
x = w**3
y = x.repeat(20,1,64,64)
gradient = torch.autograd.grad(outputs=y, inputs=w, grad_outputs=torch.ones_like(y), 
                               retain_graph=True, create_graph=True)[0]
print("the Gradient shape is:", gradient.shape)
pp = gradient.view(gradient.shape[0], -1)
print("the view shape is:", pp.shape)

N = pp.norm(2,1)
print("the norm shape is:", N.shape)

gpp = torch.mean(N-1)
print(gpp)
'''

In [None]:
## Gradient penalty & 1-L norm satisfaction 
def gradient_penalty(critic, real, fake, device="cpu"):
    """Compute the WGAN-GP gradient penalty E[(||grad_x critic(x)||_2 - 1)^2].

    The expectation is taken over random per-sample interpolates x between
    `real` and `fake`, which softly enforces the 1-Lipschitz constraint.

    Args:
        critic: callable mapping a batch of inputs to one score per sample.
        real: batch of real samples, first dimension is the batch.
        fake: batch of generated samples with the same shape as `real`.
        device: kept for backward compatibility; the random mixing
            coefficients are always created on `real.device` so they can
            never end up on a different device than the data.

    Returns:
        A scalar tensor that is differentiable w.r.t. the critic parameters.
    """
    # One mixing coefficient per sample, broadcast over the remaining dimensions
    eps_shape = [real.shape[0]] + [1] * (real.dim() - 1)
    epsilon = torch.rand(eps_shape, device=real.device, dtype=real.dtype)

    # Detach the inputs so the penalty only trains the critic, then make the
    # interpolates a fresh leaf that autograd can differentiate with respect to
    interpolated = epsilon * real.detach() + (1 - epsilon) * fake.detach()
    interpolated.requires_grad_(True)
    critic_score = critic(interpolated)

    # create_graph=True keeps the gradient itself differentiable, so the
    # penalty can be back-propagated into the critic weights
    grad = torch.autograd.grad(
        outputs=critic_score,
        inputs=interpolated,
        grad_outputs=torch.ones_like(critic_score),
        retain_graph=True,
        create_graph=True,
    )[0]  # autograd.grad returns a tuple; the entry has the shape of `inputs`

    # Flatten to (batch, features) to obtain one L2 norm per sample
    grad = grad.reshape(grad.shape[0], -1)
    grad_norm = grad.norm(2, dim=1)
    # Penalise each sample's deviation from norm 1, then average over the batch
    return torch.mean((grad_norm - 1) ** 2)

In [None]:
# Training Process
fixed_noise = torch.randn(batch_size, z_dim, 1, 1).to(device=device) # used for Tensorboard
Write_real = SummaryWriter(f"Logs/real")
Writer_fake = SummaryWriter(f"Logs/fake")
step = 0

gen.train()
critic.train()

for epoch in range(num_epoches):
    for batch_idx, (real, _) in enumerate(train_loader):
        real = real.to(device=device)
        # The last batch of an epoch can be smaller than batch_size; the fakes
        # must match it or the real/fake interpolation fails on shape
        cur_batch_size = real.shape[0]

        # Train critic: minimise E[D(fake)] - E[D(real)] + LAMBDA * GP
        for t in range(num_critic):
            noise = torch.randn(cur_batch_size, z_dim, 1, 1, device=device)
            # The generator is not updated here, so no graph through it is needed
            with torch.no_grad():
                fake = gen(noise)
            # Mark the detached fakes as requiring grad so the interpolates built
            # inside gradient_penalty are part of the autograd graph
            fake.requires_grad_(True)

            critic_real = critic(real).view(-1)
            critic_fake = critic(fake).view(-1)

            gp = gradient_penalty(critic, real, fake, device=device)
            Loss_critic = torch.mean(critic_fake) - torch.mean(critic_real) + LAMBDA*gp
            critic.zero_grad()
            Loss_critic.backward()
            opt_critic.step()

        # Train Generator: maximise E[D(G(z))], i.e. minimise -E[D(G(z))]
        noise = torch.randn(cur_batch_size, z_dim, 1, 1, device=device)
        fake = gen(noise)
        gen_fake = critic(fake).view(-1)
        Loss_gen = -torch.mean(gen_fake)
        gen.zero_grad()
        Loss_gen.backward()
        opt_gen.step()

        # Print losses occasionally and print to tensorboard
        if batch_idx % 2 == 0 and batch_idx > 0:
            print(
                f"Epoch [{epoch}/{num_epoches}] Batch {batch_idx}/{len(train_loader)} \
                Loss D: {Loss_critic:.4f}, loss G: {Loss_gen:.4f}")

            with torch.no_grad():
                fake = gen(fixed_noise)
                # take out (up to) 32 examples
                img_grid_real = torchvision.utils.make_grid(real[:32], normalize=True)
                img_grid_fake = torchvision.utils.make_grid(fake[:32], normalize=True)

                Write_real.add_image("Real", img_grid_real, global_step=step)
                Writer_fake.add_image("Fake", img_grid_fake, global_step=step)

            step += 1

# Flush pending events to disk and release the log files
Write_real.close()
Writer_fake.close()