## Implementing the Wasserstein GAN (WGAN) paper (Arjovsky et al., 2017)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter

import import_ipynb # import another nbs like modules .py
from WGAN_models import Critic, Generator, initialize_weights


# Prefer CUDA, then Apple-silicon MPS (the original macOS target), then CPU,
# so the notebook is not tied to a single kind of machine.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")  # macOS
else:
    device = torch.device("cpu")

# Learning rate, batch size, critic iterations and the clip value 'c' follow
# the WGAN paper's defaults.
LEARNING_RATE = 5e-5
BATCH_SIZE = 64
IMAGE_SIZE = 64
CHANNELS_IMG = 3
Z_DIM = 100
NUM_EPOCHS = 5
FEATURES_DISC = 64
FEATURES_GEN = 64
CRITIC_ITERATIONS = 5  # critic updates per generator update (n_critic)
WEIGHT_CLIP = 0.01  # weight-clipping bound 'c' from the paper

# Named `transform` so the `torchvision.transforms` module is not shadowed.
# Resize(int) only matches the shorter edge, so CenterCrop is needed to get
# square IMAGE_SIZE x IMAGE_SIZE inputs from non-square photos.
transform = transforms.Compose(
    [
        transforms.Resize(IMAGE_SIZE),
        transforms.CenterCrop(IMAGE_SIZE),
        transforms.ToTensor(),
        # Map pixel values from [0, 1] to [-1, 1].
        transforms.Normalize([0.5] * CHANNELS_IMG, [0.5] * CHANNELS_IMG),
    ]
)

dataset = datasets.ImageFolder(root="./dataset_GAN/celeba/", transform=transform)
loader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

critic = Critic(CHANNELS_IMG, FEATURES_DISC).to(device)
gen = Generator(Z_DIM, CHANNELS_IMG, FEATURES_GEN).to(device)

# Initial weights for the critic and the generator.
initialize_weights(critic)
initialize_weights(gen)

# Optimizer and learning rate as in the WGAN paper (RMSProp).
opt_gen = optim.RMSprop(gen.parameters(), lr=LEARNING_RATE)
opt_critic = optim.RMSprop(critic.parameters(), lr=LEARNING_RATE)

# Fixed latent vectors so logged samples are comparable across steps.
fixed_noise = torch.randn(32, Z_DIM, 1, 1, device=device)

# TensorBoard
writer_real = SummaryWriter("runs/GAN_CelebA/real_WGAN")
writer_fake = SummaryWriter("runs/GAN_CelebA/fake_WGAN")
step = 0

gen.train()
critic.train()

try:
    for epoch in range(NUM_EPOCHS):
        for batch_idx, (real, _) in enumerate(loader):  # class labels are unused
            real = real.to(device)
            # The last batch can be smaller than BATCH_SIZE (the loader does not
            # drop it), so size the fake batch from the real one.
            cur_batch_size = real.shape[0]

            ### Train the critic CRITIC_ITERATIONS times per generator step:
            ### maximize E[critic(real)] - E[critic(fake)]
            for _ in range(CRITIC_ITERATIONS):
                noise = torch.randn(cur_batch_size, Z_DIM, 1, 1, device=device)
                fake = gen(noise)
                critic_real = critic(real).reshape(-1)
                # Detached: the critic update must not backprop into the
                # generator; `fake` itself keeps its graph for the generator
                # step below, so no retain_graph is needed.
                critic_fake = critic(fake.detach()).reshape(-1)
                # Optimizers minimize, so negate the objective the critic maximizes.
                loss_critic = -(torch.mean(critic_real) - torch.mean(critic_fake))
                critic.zero_grad()
                loss_critic.backward()
                opt_critic.step()

                # Clip every critic weight to [-c, c] after each update.
                with torch.no_grad():
                    for p in critic.parameters():
                        p.clamp_(-WEIGHT_CLIP, WEIGHT_CLIP)

            ### Train the generator: minimize -E[critic(gen(z))]
            # Reuses the last `fake` from the critic loop (still attached to gen).
            output = critic(fake).reshape(-1)
            loss_gen = -torch.mean(output)
            gen.zero_grad()
            loss_gen.backward()
            opt_gen.step()

            ### Print losses and log image grids to TensorBoard
            if batch_idx % 1000 == 0:
                print(
                    f"Epoch [{epoch}/{NUM_EPOCHS}] Batch {batch_idx}/{len(loader)} "
                    f"Loss D: {loss_critic.item():.4f}, Loss G: {loss_gen.item():.4f}"
                )
                with torch.no_grad():
                    fake = gen(fixed_noise)
                    # Show (up to) 32 examples of each.
                    img_grid_real = torchvision.utils.make_grid(real[:32], normalize=True)
                    img_grid_fake = torchvision.utils.make_grid(fake[:32], normalize=True)

                    writer_real.add_image("Real_WGAN", img_grid_real, global_step=step)
                    writer_fake.add_image("Fake_WGAN", img_grid_fake, global_step=step)

                step += 1
finally:
    # Flush pending TensorBoard events even if training is interrupted.
    writer_real.close()
    writer_fake.close()

## **WGAN-GP (Wasserstein GAN with Gradient Penalty, Gulrajani et al., 2017)**

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter

import import_ipynb # import another nbs like modules .py
from WGAN_models import Generator, initialize_weights, Critic_WGANGP


### Gradient penalty as defined in the WGAN-GP paper
def gradient_penalty(critic, real, fake, device="cpu"):
    """Return the WGAN-GP gradient penalty of ``critic`` on real/fake interpolates.

    Draws one uniform mixing weight per sample, builds points on the straight
    line between each real and fake sample, differentiates the critic's scores
    with respect to those points, and returns the batch mean of
    ``(||grad||_2 - 1) ** 2``.

    Args:
        critic: Callable (e.g. the critic network) that scores a batch.
        real: Batch of real samples, shape ``(B, ...)``.
        fake: Batch of generated samples with the same shape as ``real``.
            It may be attached to the generator graph or already detached.
        device: Kept for backward compatibility only; the mixing weights are
            always created on ``real.device`` so they match the inputs.

    Returns:
        A scalar tensor that is differentiable w.r.t. the critic's parameters
        (``create_graph=True``), so it can be added to the critic loss.
    """
    batch_size = real.shape[0]
    # One mixing weight per sample, broadcast over all remaining dimensions.
    alpha = torch.rand((batch_size,) + (1,) * (real.dim() - 1), device=real.device)

    # Detach and re-enable grad: the interpolates become a leaf we can always
    # differentiate against (even when `fake` was detached), and the penalty's
    # backward no longer runs through the generator graph.
    interpolated_images = (real * alpha + fake * (1 - alpha)).detach().requires_grad_(True)

    # Critic scores on the interpolated samples.
    mixed_scores = critic(interpolated_images)

    # create_graph=True makes the gradient itself differentiable, so the
    # penalty's gradient reaches the critic's parameters.
    gradient = torch.autograd.grad(
        outputs=mixed_scores,
        inputs=interpolated_images,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0]

    # Per-sample L2 norm of the flattened gradient.
    gradient_norm = gradient.reshape(batch_size, -1).norm(2, dim=1)
    return torch.mean((gradient_norm - 1) ** 2)


# Prefer CUDA, then Apple-silicon MPS (the original macOS target), then CPU,
# so the notebook is not tied to a single kind of machine.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")  # macOS
else:
    device = torch.device("cpu")

# Learning rate, Adam betas, critic iterations and lambda follow the
# WGAN-GP paper's defaults.
LEARNING_RATE = 1e-4
BATCH_SIZE = 64
IMAGE_SIZE = 64
CHANNELS_IMG = 3
Z_DIM = 100
NUM_EPOCHS = 5
FEATURES_DISC = 64
FEATURES_GEN = 64
CRITIC_ITERATIONS = 5  # critic updates per generator update (n_critic)
LAMBDA_GP = 10  # weight of the gradient penalty (lambda in the paper)

# Named `transform` so the `torchvision.transforms` module is not shadowed.
# Resize(int) only matches the shorter edge, so CenterCrop is needed to get
# square IMAGE_SIZE x IMAGE_SIZE inputs. Real and fake images must have the
# same shape, because the gradient penalty interpolates between them.
transform = transforms.Compose(
    [
        transforms.Resize(IMAGE_SIZE),
        transforms.CenterCrop(IMAGE_SIZE),
        transforms.ToTensor(),
        # Map pixel values from [0, 1] to [-1, 1].
        transforms.Normalize([0.5] * CHANNELS_IMG, [0.5] * CHANNELS_IMG),
    ]
)

dataset = datasets.ImageFolder(root="./dataset_GAN/celeba/", transform=transform)
loader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

# Alternative dataset (grayscale, so set CHANNELS_IMG = 1 when using it):
# dataset = datasets.MNIST(root="dataset_GAN/", train=True, transform=transform, download=True)
# loader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

critic = Critic_WGANGP(CHANNELS_IMG, FEATURES_DISC).to(device)
gen = Generator(Z_DIM, CHANNELS_IMG, FEATURES_GEN).to(device)

# Initial weights for the critic and the generator.
initialize_weights(critic)
initialize_weights(gen)

# Optimizer settings as in the WGAN-GP paper (Adam, beta1 = 0, beta2 = 0.9).
opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.9))
opt_critic = optim.Adam(critic.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.9))

# Fixed latent vectors so logged samples are comparable across steps.
fixed_noise = torch.randn(32, Z_DIM, 1, 1, device=device)

# TensorBoard
writer_real = SummaryWriter("runs/GAN_CelebA/real_WGAN-GP")
writer_fake = SummaryWriter("runs/GAN_CelebA/fake_WGAN-GP")
step = 0

gen.train()
critic.train()

try:
    for epoch in range(NUM_EPOCHS):
        for batch_idx, (real, _) in enumerate(loader):  # class labels are unused
            real = real.to(device)
            # The last batch can be smaller than BATCH_SIZE (the loader does not
            # drop it). The fake batch must match it, or the real/fake
            # interpolation in the gradient penalty fails on shape mismatch.
            cur_batch_size = real.shape[0]

            ### Train the critic CRITIC_ITERATIONS times per generator step
            for _ in range(CRITIC_ITERATIONS):
                noise = torch.randn(cur_batch_size, Z_DIM, 1, 1, device=device)
                fake = gen(noise)
                critic_real = critic(real).reshape(-1)
                critic_fake = critic(fake).reshape(-1)

                ### Gradient penalty on real/fake interpolates
                gp = gradient_penalty(critic=critic, real=real, fake=fake, device=device)
                # -(E[critic(real)] - E[critic(fake)]) + lambda * GP
                loss_critic = (
                    -(torch.mean(critic_real) - torch.mean(critic_fake)) + LAMBDA_GP * gp
                )

                critic.zero_grad()
                # `fake` is not detached, so this backward also walks the
                # generator graph. retain_graph=True keeps that graph for the
                # generator update below, whose gen.zero_grad() discards the
                # stray generator gradients first.
                loss_critic.backward(retain_graph=True)
                opt_critic.step()

            ### Train the generator: minimize -E[critic(gen(z))]
            # Reuses the last `fake` from the critic loop (still attached to gen).
            output = critic(fake).reshape(-1)
            loss_gen = -torch.mean(output)
            gen.zero_grad()
            loss_gen.backward()
            opt_gen.step()

            ### Print losses and log image grids to TensorBoard
            if batch_idx % 1000 == 0:
                print(
                    f"Epoch [{epoch}/{NUM_EPOCHS}] Batch {batch_idx}/{len(loader)} "
                    f"Loss D: {loss_critic.item():.4f}, Loss G: {loss_gen.item():.4f}"
                )
                with torch.no_grad():
                    fake = gen(fixed_noise)
                    # Show (up to) 32 examples of each.
                    img_grid_real = torchvision.utils.make_grid(real[:32], normalize=True)
                    img_grid_fake = torchvision.utils.make_grid(fake[:32], normalize=True)

                    writer_real.add_image("REAL_WGAN-GP", img_grid_real, global_step=step)
                    writer_fake.add_image("FAKE_WGAN-GP", img_grid_fake, global_step=step)

                step += 1
finally:
    # Flush pending TensorBoard events even if training is interrupted.
    writer_real.close()
    writer_fake.close()