### Goals
In this notebook, you're going to build a Wasserstein GAN with Gradient Penalty (WGAN-GP) that solves some of the stability issues with the GANs that you have been using up until this point. Specifically, you'll use a special kind of loss function known as the W-loss, where W stands for Wasserstein, and gradient penalties to prevent mode collapse.

*Fun Fact: Wasserstein is named after a mathematician at Penn State, Leonid Vaseršteĭn. You'll see it abbreviated to W (e.g. WGAN, W-loss, W-distance).*

### Learning Objectives
1.   Get hands-on experience building a more stable GAN: Wasserstein GAN with Gradient Penalty (WGAN-GP).
2.   Train the more advanced WGAN-GP model.



## Generator and Critic

You will begin by importing some useful packages, defining visualization functions, building the generator, and building the critic. Since the changes for WGAN-GP are done to the loss function during training, you can simply reuse your previous GAN code for the generator and critic class. Remember that in WGAN-GP, you no longer use a discriminator that classifies fake and real as 0 and 1 but rather a critic that scores images with real numbers.

In [1]:
import torch
from torch import nn
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
torch.manual_seed(0)


  from .autonotebook import tqdm as notebook_tqdm


<torch._C.Generator at 0x7fc6808f89b0>

In [2]:
def show_tensor_images(image_tensor, num_images=25, size=(1, 28, 28)):
    """Plot the first ``num_images`` images of a batch as a grid, five per row.

    Args:
        image_tensor: batch of images; each value ``x`` is displayed as
            ``(x + 1) / 2``, which maps the range [-1, 1] onto [0, 1].
        num_images: how many images from the front of the batch to show.
        size: accepted for interface compatibility; not used by this function.
    """
    rescaled = (image_tensor + 1) / 2
    # Drop autograd history and move to host memory before plotting.
    cpu_images = rescaled.detach().cpu()
    grid = make_grid(cpu_images[:num_images], nrow=5)
    # Move the first axis of the grid last, then drop any size-1 axes for imshow.
    channels_last = grid.permute(1, 2, 0)
    plt.imshow(channels_last.squeeze())
    plt.show()

def make_grad_hook():
    """Create a gradient-collecting callback together with the list it fills.

    Returns:
        A ``(grads, grad_hook)`` pair. Calling ``grad_hook(m)`` with an
        ``nn.Conv2d`` or ``nn.ConvTranspose2d`` appends ``m.weight.grad`` to
        ``grads``; any other module is ignored.
    """
    collected = []
    conv_types = (nn.Conv2d, nn.ConvTranspose2d)

    def grad_hook(m):
        if not isinstance(m, conv_types):
            return
        collected.append(m.weight.grad)

    return collected, grad_hook

# Create generator class

In [3]:
class Generator(nn.Module):
    """Transposed-convolution generator mapping noise vectors to images.

    Args:
        z_dim: length of the input noise vector.
        im_chan: number of channels in the generated image.
        hidden_dim: base channel width of the hidden layers.
    """

    def __init__(self, z_dim=10, im_chan=1, hidden_dim=64):
        super().__init__()
        self.z_dim = z_dim

        # Four stacked up-sampling blocks; the last one ends in Tanh.
        blocks = [
            self.make_gen_block(z_dim, hidden_dim * 4),
            self.make_gen_block(hidden_dim * 4, hidden_dim * 2, kernel_size=4, stride=1),
            self.make_gen_block(hidden_dim * 2, hidden_dim),
            self.make_gen_block(hidden_dim, im_chan, kernel_size=4, final_layer=True),
        ]
        self.gen = nn.Sequential(*blocks)

    def make_gen_block(self, input_channels, output_channels, kernel_size=3, stride=2, final_layer=False):
        """Build one generator block.

        Hidden blocks are ConvTranspose2d -> BatchNorm2d -> ReLU; the final
        block is ConvTranspose2d -> Tanh.
        """
        upsample = nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride)
        if final_layer:
            return nn.Sequential(upsample, nn.Tanh())
        return nn.Sequential(
            upsample,
            nn.BatchNorm2d(output_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, noise):
        """Reshape ``noise`` to (batch, z_dim, 1, 1) and run it through the network."""
        batch = len(noise)
        return self.gen(noise.view(batch, self.z_dim, 1, 1))

def get_noise(n_sample, z_dim, device='cpu'):
    """Return an ``(n_sample, z_dim)`` tensor of standard-normal noise on ``device``."""
    noise_shape = (n_sample, z_dim)
    return torch.randn(*noise_shape, device=device)


# Creating Critic class

In [4]:
class Critic(nn.Module):
    """Convolutional critic that scores images with unbounded real numbers.

    Args:
        im_chan: number of channels in the input image.
        hidden_dim: base channel width of the hidden layers.
    """

    def __init__(self, im_chan=1, hidden_dim=64):
        super().__init__()
        # Two hidden down-sampling blocks followed by a bare convolution.
        stages = [
            self.make_critic_block(im_chan, hidden_dim),
            self.make_critic_block(hidden_dim, hidden_dim * 2),
            self.make_critic_block(hidden_dim * 2, 1, final_layer=True),
        ]
        self.critic = nn.Sequential(*stages)

    def make_critic_block(self, input_channels, output_channels, kernel_size=3, stride=2, final_layer=False):
        """Build one critic block.

        Hidden blocks are Conv2d -> BatchNorm2d -> LeakyReLU(0.2); the final
        block is a single Conv2d with no normalisation or activation.
        """
        downsample = nn.Conv2d(input_channels, output_channels, kernel_size, stride)
        if final_layer:
            return nn.Sequential(downsample)
        return nn.Sequential(
            downsample,
            nn.BatchNorm2d(output_channels),
            nn.LeakyReLU(0.2, inplace=True),
        )

    def forward(self, image):
        """Score ``image`` and flatten the result to ``(batch, -1)``.

        With the default kernel size and stride, a 28x28 input leaves a 2x2
        map after the last convolution, i.e. four scores per image.
        """
        scores = self.critic(image)
        batch = len(scores)
        return scores.view(batch, -1)


## Training Initializations
Now you can start putting it all together.
As usual, you will start by setting the parameters:
  *   n_epochs: the number of times you iterate through the entire dataset when training
  *   z_dim: the dimension of the noise vector
  *   display_step: how often to display/visualize the images
  *   batch_size: the number of images per forward/backward pass
  *   lr: the learning rate
  *   beta_1, beta_2: the momentum terms
  *   c_lambda: weight of the gradient penalty
  *   critic_repeats: number of times to update the critic per generator update - there are more details about this in the *Putting It All Together* section
  *   device: the device type

You will also load and transform the MNIST dataset to tensors.




In [5]:
# Training hyperparameters (all read by the cells below).
n_epochs = 100  # number of passes over the dataloader
z_dim = 64  # length of the generator's noise vector
display_step = 50  # visualise losses/images every this many steps
batch_size = 128  # images per batch drawn from the dataloader
lr = 0.0002  # Adam learning rate for both generator and critic
beta_1 = 0.5  # first Adam beta
beta_2 = 0.999  # second Adam beta
c_lambda = 10  # weight of the gradient penalty in the critic loss
critic_repeats = 5  # critic updates per generator update
device = 'cpu'  # device that models and tensors are placed on



In [None]:
# Convert images to tensors, then apply (x - 0.5) / 0.5 to every pixel.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

# MNIST is read from the current directory; download=False means it must
# already be present there.
dataloader = DataLoader(
    dataset=MNIST(root='.', download=False, transform=transform),
    batch_size=batch_size,
    shuffle=True,
)

In [15]:
# Build the generator and critic, then give each its own Adam optimizer
# sharing the same learning rate and betas.

gen = Generator(z_dim=z_dim).to(device)
critic = Critic().to(device)

gen_opt = torch.optim.Adam(gen.parameters(), lr=lr, betas=(beta_1, beta_2))
critic_opt = torch.optim.Adam(critic.parameters(), lr=lr, betas=(beta_1, beta_2))

def weights_init(m):
    """Re-initialise one module in place; intended for ``Module.apply``.

    Conv2d / ConvTranspose2d weights and BatchNorm2d weights are drawn from
    N(0, 0.02); BatchNorm2d biases are set to zero. Other modules are left
    untouched.
    """
    conv_layers = (nn.Conv2d, nn.ConvTranspose2d)
    if isinstance(m, conv_layers):
        nn.init.normal_(m.weight, mean=0.0, std=0.02)
    if isinstance(m, nn.BatchNorm2d):
        nn.init.normal_(m.weight, mean=0.0, std=0.02)
        nn.init.constant_(m.bias, val=0)

# Module.apply visits every submodule and returns the module itself.
gen = gen.apply(weights_init)
critic = critic.apply(weights_init)

## Gradient Penalty
Calculating the gradient penalty can be broken into two functions: 

> * (1) compute the gradient with respect to the images
> * (2) compute the gradient penalty given the gradient

You can start by getting the gradient. The gradient is computed by first creating a mixed image. This is done by weighting the real and fake images with epsilon and (1 - epsilon) respectively and then adding them together. Once you have the intermediate image, you can get the critic's output on the image. Finally, you compute the gradient of the critic's scores on the mixed images (output) with respect to the pixels of the mixed images (input). You will need to fill in the code to get the gradient wherever you see *None*. There is a test function in the next block for you to test your solution.

## Gradient wrt Images

In [8]:
def get_gradient(critic, real, fake, epsilon):
    """Return the gradient of the critic's scores w.r.t. interpolated images.

    Args:
        critic: callable mapping a batch of images to a tensor of scores.
        real: batch of real images.
        fake: batch of generated images, same shape as ``real``.
        epsilon: interpolation weights, broadcastable against ``real``;
            the weight of ``real`` is ``epsilon`` and of ``fake`` is
            ``1 - epsilon``.

    Returns:
        A tensor shaped like the mixed images holding d(scores)/d(mixed).
        It stays attached to the autograd graph so that a penalty computed
        from it can be back-propagated into the critic's parameters.
    """
    # Interpolate between real and fake: a weighted SUM, not a product.
    mixed_images = real * epsilon + fake * (1 - epsilon)
    # autograd.grad needs an input that tracks gradients; if none of the
    # operands did, the mix is a plain leaf tensor and can be switched on.
    if not mixed_images.requires_grad:
        mixed_images.requires_grad_(True)

    mixed_scores = critic(mixed_images)

    gradient = torch.autograd.grad(
        inputs=mixed_images,
        outputs=mixed_scores,
        # Ones as upstream gradient: differentiate the sum of all scores.
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0]
    return gradient

In [9]:
# Unit test for the gradient w.r.t. the mixed images

def test_get_gradient(image_shape):
    """Check ``get_gradient`` on random data and return the gradient.

    Args:
        image_shape: shape of the image batch, batch dimension first.

    Returns:
        The gradient produced by ``get_gradient`` for the module-level
        ``critic``, so other tests can reuse it.

    Raises:
        AssertionError: if the gradient has the wrong shape or does not
            contain both positive and negative entries.
    """
    real = torch.randn(*image_shape, device=device) + 1
    fake = torch.randn(*image_shape, device=device) - 1
    # One epsilon per image, broadcast over all remaining axes.
    epsilon_shape = [image_shape[0]] + [1] * (len(image_shape) - 1)
    # Epsilon is an interpolation weight, so sample it uniformly from [0, 1)
    # exactly as the training loop does.
    epsilon = torch.rand(epsilon_shape, device=device).requires_grad_()
    gradient = get_gradient(critic, real, fake, epsilon)

    # tuple() on both sides so a list-valued image_shape also compares equal.
    assert tuple(gradient.shape) == tuple(image_shape)
    assert gradient.max() > 0
    assert gradient.min() < 0
    return gradient

# The (misspelled) module-level name is kept so later cells that may read it
# continue to work.
gradinet = test_get_gradient(image_shape=(256, 1, 28, 28))
print("Success...")


Success...


## Gradient penalty wrt given gradient

In [12]:
def gradient_penalty(gradient):
    """Return the mean squared deviation of per-sample gradient norms from 1.

    Args:
        gradient: tensor whose first axis indexes samples; every other axis
            is flattened before the L2 norm is taken.

    Returns:
        Scalar tensor ``mean((||g_i||_2 - 1) ** 2)`` over the samples.
    """
    per_sample = gradient.view(gradient.shape[0], -1)
    norms = per_sample.norm(p=2, dim=1)
    return torch.mean((norms - 1) ** 2)

# Unit test for the gradient penalty computed from a given gradient

def test_gradient_penalty(image_shape):
    """Check ``gradient_penalty`` on three reference gradients.

    Raises:
        AssertionError: if any of the three checks below fails.
    """
    # All-zero gradient: every norm is 0, so the penalty is (0 - 1)^2 = 1.
    zero_penalty = gradient_penalty(torch.zeros(*image_shape))
    assert torch.isclose(zero_penalty, torch.tensor(1.))

    # Constant gradient scaled so each sample has unit norm: penalty is 0.
    values_per_image = torch.prod(torch.Tensor(image_shape[1:]))
    unit_norm_gradient = torch.ones(*image_shape) / torch.sqrt(values_per_image)
    unit_penalty = gradient_penalty(unit_norm_gradient)
    assert torch.isclose(unit_penalty, torch.tensor(0.))

    # Gradient of the module-level critic on random data: the test expects
    # its penalty to lie within 0.1 of 1.
    critic_gradient = test_get_gradient(image_shape)
    critic_penalty = gradient_penalty(critic_gradient)
    assert torch.abs(critic_penalty - 1) < 0.1

test_gradient_penalty((256, 1, 28, 28))
print('Success...')

Success...


# Losses

In [13]:
def get_gen_loss(critic_fake_pred):
    """Return the generator loss: the negated mean critic score on fakes.

    Args:
        critic_fake_pred: critic scores for generated images.

    Returns:
        Scalar tensor ``-mean(critic_fake_pred)``.
    """
    mean_fake_score = critic_fake_pred.mean()
    return -1. * mean_fake_score

# Sanity checks for the generator loss
# A single score of 1 must give a loss of exactly -1.
assert torch.isclose(get_gen_loss(torch.tensor(1.)), torch.tensor(-1.))

# Scores drawn uniformly from [0, 1) should give a loss near -0.5
# (relative tolerance 0.05).
assert torch.isclose(get_gen_loss(torch.rand(1000)), torch.tensor(-0.5), rtol=0.05)

In [14]:
def get_critic_loss(critic_fake_pred, critic_real_pred, gp, c_lambda):
    """Return the critic loss with gradient penalty.

    Args:
        critic_fake_pred: critic scores for generated images.
        critic_real_pred: critic scores for real images.
        gp: gradient penalty term.
        c_lambda: weight applied to ``gp``.

    Returns:
        ``mean(critic_fake_pred) - mean(critic_real_pred) + c_lambda * gp``.
    """
    score_gap = critic_fake_pred.mean() - critic_real_pred.mean()
    return score_gap + c_lambda * gp

# Sanity checks for the critic loss
# 1 - 2 + 0.1 * 3 = -0.7
assert torch.isclose(
    get_critic_loss(torch.tensor(1.), torch.tensor(2.), torch.tensor(3.), 0.1),
    torch.tensor(-0.7),
)

# 20 - (-20) + 10 * 2 = 60
assert torch.isclose(
    get_critic_loss(torch.tensor(20.), torch.tensor(-20.), torch.tensor(2.), 10),
    torch.tensor(60.),
)

print('Success...')

Success...


# Training

In [None]:
# WGAN-GP training loop: for every batch of real images the critic is updated
# `critic_repeats` times, then the generator is updated once.
cur_step = 0
gen_losses = []
critic_losses = []

for epoch in range(n_epochs):
    for real, _ in tqdm(dataloader):
        cur_batch_size = len(real)
        real = real.to(device)

        ### Update the critic ###
        mean_iteration_critic_loss = 0
        for _ in range(critic_repeats):
            critic_opt.zero_grad()

            fake_noise = get_noise(cur_batch_size, z_dim, device=device)
            # Detach once: the critic update must not back-propagate into
            # the generator.
            fake = gen(fake_noise).detach()
            critic_fake_pred = critic(fake)
            critic_real_pred = critic(real)

            # One interpolation weight per image. The keyword is
            # `requires_grad`; `require_grad` raises a TypeError.
            epsilon = torch.rand(len(real), 1, 1, 1, device=device, requires_grad=True)
            gradient = get_gradient(critic, real, fake, epsilon)
            gp = gradient_penalty(gradient)
            critic_loss = get_critic_loss(critic_fake_pred, critic_real_pred, gp, c_lambda)

            # Track the average critic loss over the repeats of this batch.
            mean_iteration_critic_loss += critic_loss.item() / critic_repeats
            # The graph is rebuilt on every repeat, so it need not be retained.
            critic_loss.backward()
            critic_opt.step()
        critic_losses.append(mean_iteration_critic_loss)

        ### Update the generator ###
        gen_opt.zero_grad()
        fake_noise_2 = get_noise(cur_batch_size, z_dim, device=device)
        fake_2 = gen(fake_noise_2)
        critic_fake_pred = critic(fake_2)

        gen_loss = get_gen_loss(critic_fake_pred)
        gen_loss.backward()
        # Apply the gradients; without this step the generator never learns.
        gen_opt.step()

        gen_losses.append(gen_loss.item())

        ### Visualization code ###
        if cur_step % display_step == 0 and cur_step > 0:
            gen_mean = sum(gen_losses[-display_step:]) / display_step
            crit_mean = sum(critic_losses[-display_step:]) / display_step
            print(f"Step {cur_step}: Generator loss: {gen_mean}, critic loss: {crit_mean}")
            show_tensor_images(fake)
            show_tensor_images(real)
            # Plot both loss curves averaged over bins of `step_bins` steps;
            # trailing steps that do not fill a whole bin are dropped.
            step_bins = 20
            num_examples = (len(gen_losses) // step_bins) * step_bins
            plt.plot(
                range(num_examples // step_bins),
                torch.Tensor(gen_losses[:num_examples]).view(-1, step_bins).mean(1),
                label="Generator Loss"
            )
            plt.plot(
                range(num_examples // step_bins),
                torch.Tensor(critic_losses[:num_examples]).view(-1, step_bins).mean(1),
                label="Critic Loss"
            )
            plt.legend()
            plt.show()

        cur_step += 1