# Common parts

In [2]:
import torch
from torch import nn
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
from torch.nn import functional as F
import torchvision
torch.manual_seed(0) # Set for testing purposes, please do not change!

<torch._C.Generator at 0x25c7f20e210>

In [3]:
class Generator(nn.Module):
    def __init__(self, C_noise, C_hidden, C_image):
        super(Generator, self).__init__()

        self.gen = nn.Sequential(
            self.gen_block(C_noise   , C_hidden*4, K=3, S=2),
            self.gen_block(C_hidden*4, C_hidden*2, K=4, S=1),
            self.gen_block(C_hidden*2, C_hidden*1, K=3, S=2),
            self.gen_block(C_hidden*1  , C_image , K=4, S=2, final_layer=True)
        )

    def gen_block(self, C_in, C_out, K, S, final_layer=False):
        if final_layer:
            return nn.Sequential(
                nn.ConvTranspose2d(C_in, C_out, kernel_size = K, stride = S),
                nn.Tanh()
            )
        else:
            return nn.Sequential(
                nn.ConvTranspose2d(C_in, C_out, kernel_size = K, stride = S),
                nn.BatchNorm2d(C_out),
                nn.ReLU(inplace=True)
            )

    def forward(self, noise):
        return self.gen(noise)

In [4]:
class Discriminator(nn.Module):
    def __init__(self, C_image, C_hidden):
        super(Discriminator, self).__init__()

        self.dis = nn.Sequential(
            self.dis_block(C_image   , C_hidden*1, K=4, S=2),
            self.dis_block(C_hidden*1, C_hidden*2, K=4, S=2),
            self.dis_block(C_hidden*2, 1         , K=4, S=2, final_layer=True),
        )

    def dis_block(self, C_in, C_out, K, S, final_layer=False):
        if final_layer:
            return nn.Sequential(
                nn.Conv2d(C_in, C_out, kernel_size = K, stride = S)
            )
        else:
            return nn.Sequential(
                nn.Conv2d(C_in, C_out, kernel_size = K, stride = S),
                nn.BatchNorm2d(C_out),
                nn.LeakyReLU(inplace=True, negative_slope=0.2)
            )

    def forward(self, x):
        return self.dis(x).view(x.shape[0], -1)

In [5]:
def get_noise(N_noise, C_noise, device='cpu'):
    return torch.randn(N_noise, C_noise, device=device).view(-1, C_noise, 1, 1)

In [17]:
def weights_init(m):
    if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    if isinstance(m, nn.BatchNorm2d):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
        torch.nn.init.constant_(m.bias, 0)

In [56]:
def show_tensor_images(image_tensor, num_images=25, size=(1, 28, 28)):
    '''
    Function for visualizing images: Given a tensor of images, number of images, and
    size per image, plots and prints the images in an uniform grid.
    '''
    image_tensor = (image_tensor + 1) / 2
    image_unflat = image_tensor.detach().cpu()
    image_grid = make_grid(image_unflat[:num_images], nrow=5)
    plt.imshow(image_grid.permute(1, 2, 0).squeeze())
    plt.show()

In [18]:
n_epochs = 100
C_noise = 64
display_step = 500
batch_size = 128
lr = 0.0002
beta_1 = 0.5
beta_2 = 0.999
c_lambda = 10
crit_repeats = 5
device = 'cuda'
C_hidden_gen = 64
C_hidden_dis = 16
C_image=1

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

dataloader = DataLoader(
    MNIST('.', download=True, transform=transform),
    batch_size=batch_size,
    shuffle=True)

# Compare

### 1. Coursera

In [21]:
# UNQ_C1 (UNIQUE CELL IDENTIFIER, DO NOT EDIT)
def get_gradient_1(crit, real, fake, epsilon):
    mixed_images = real * epsilon + fake * (1 - epsilon)
    mixed_scores = crit(mixed_images)
    gradient = torch.autograd.grad(
        inputs=mixed_images,
        outputs=mixed_scores,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0]
    return gradient

In [22]:
# UNQ_C2 (UNIQUE CELL IDENTIFIER, DO NOT EDIT)
def gradient_penalty_1(gradient):
    gradient = gradient.view(len(gradient), -1)
    gradient_norm = gradient.norm(2, dim=1)
    penalty = ( ( gradient_norm - 1.0 )**2 ).mean(dim=0)
    return penalty

In [25]:
# UNQ_C3 (UNIQUE CELL IDENTIFIER, DO NOT EDIT)
def get_gen_loss_1(crit_fake_pred):
    gen_loss = (-crit_fake_pred).mean(dim=0)
    return gen_loss

In [24]:
# UNQ_C4 (UNIQUE CELL IDENTIFIER, DO NOT EDIT)
def get_crit_loss_1(crit_fake_pred, crit_real_pred, gp, c_lambda):
    crit_loss = ( crit_fake_pred - crit_real_pred + c_lambda*gp ).mean(dim=0)
    return crit_loss

In [53]:
# get instance from models
gen = Generator(C_noise, C_hidden_gen, C_image).to(device)
dis = Discriminator(C_image, C_hidden_dis).to(device)

# Initialize
gen = gen.apply(weights_init)
dis = dis.apply(weights_init)

# Optimizers
optim_dis = torch.optim.Adam(dis.parameters(), lr=lr, betas=(beta_1, beta_2))
optim_gen = torch.optim.Adam(gen.parameters(), lr=lr, betas=(beta_1, beta_2))

In [54]:
# RESULT
real = torch.ones(128,1,28,28).to(device)
cur_batch_size = len(real)
fake_noise = get_noise(cur_batch_size, C_noise, device=device)
fake = gen(fake_noise)
crit_fake_pred = dis(fake.detach())
crit_real_pred = dis(real)
epsilon = torch.rand(len(real), 1, 1, 1, device=device, requires_grad=True)
gradient = get_gradient_1(dis, real, fake.detach(), epsilon)
gp = gradient_penalty_1(gradient)
crit_loss = get_crit_loss_1(crit_fake_pred, crit_real_pred, gp, c_lambda)
crit_loss

tensor([8.1294], device='cuda:0', grad_fn=<MeanBackward1>)

In [58]:
import matplotlib.pyplot as plt

cur_step = 0
generator_losses = []
critic_losses = []
for epoch in range(n_epochs):
    # Dataloader returns the batches
    for real, _ in tqdm(dataloader):
        cur_batch_size = len(real)
        real = real.to(device)

        mean_iteration_critic_loss = 0
        for _ in range(crit_repeats):
            ### Update critic ###
            optim_dis.zero_grad()
            fake_noise = get_noise(cur_batch_size, C_noise, device=device)
            fake = gen(fake_noise)
            crit_fake_pred = dis(fake.detach())
            crit_real_pred = dis(real)

            epsilon = torch.rand(len(real), 1, 1, 1, device=device, requires_grad=True)
            gradient = get_gradient_1(dis, real, fake.detach(), epsilon)
            gp = gradient_penalty_1(gradient)
            crit_loss = get_crit_loss_1(crit_fake_pred, crit_real_pred, gp, c_lambda)

            # Keep track of the average critic loss in this batch
            mean_iteration_critic_loss += crit_loss.item() / crit_repeats
            # Update gradients
            crit_loss.backward(retain_graph=True)
            # Update optimizer
            optim_dis.step()
        critic_losses += [mean_iteration_critic_loss]

        ### Update generator ###
        optim_gen.zero_grad()
        fake_noise_2 = get_noise(cur_batch_size, C_noise, device=device)
        fake_2 = gen(fake_noise_2)
        crit_fake_pred = dis(fake_2)

        gen_loss = get_gen_loss_1(crit_fake_pred)
        gen_loss.backward()

        # Update the weights
        optim_gen.step()

        # Keep track of the average generator loss
        generator_losses += [gen_loss.item()]

        ### Visualization code ###
        if cur_step % display_step == 0 and cur_step > 0:
            gen_mean = sum(generator_losses[-display_step:]) / display_step
            crit_mean = sum(critic_losses[-display_step:]) / display_step
            print(f"Epoch {epoch}, step {cur_step}: Generator loss: {gen_mean}, critic loss: {crit_mean}")
            show_tensor_images(fake)
            show_tensor_images(real)
            step_bins = 20
            num_examples = (len(generator_losses) // step_bins) * step_bins
            plt.plot(
                range(num_examples // step_bins),
                torch.Tensor(generator_losses[:num_examples]).view(-1, step_bins).mean(1),
                label="Generator Loss"
            )
            plt.plot(
                range(num_examples // step_bins),
                torch.Tensor(critic_losses[:num_examples]).view(-1, step_bins).mean(1),
                label="Critic Loss"
            )
            plt.legend()
            plt.show()

        cur_step += 1


  0%|          | 0/469 [00:00<?, ?it/s]

  0%|          | 0/469 [00:00<?, ?it/s]

KeyboardInterrupt: 

### 3. Me

In [44]:
def get_epsilon(N_epsilon, device='cpu'):
    return torch.rand(N_epsilon, 1, 1, 1, device=device, requires_grad=True)

In [45]:
def get_gradient_2(dis, data):
    data.requires_grad_()
    score = dis(data)
    gradient = torch.autograd.grad(
        inputs=data,
        outputs=score,
        grad_outputs=torch.ones_like(score),
        create_graph=True,
        retain_graph=True,
    )[0]
    return gradient

In [46]:
def gradient_penalty_2(gradient):
    gradient = gradient.view(len(gradient), -1)
    gradient_norm = gradient.norm(2, dim=1)
    penalty = ( ( gradient_norm - 1.0 )**2 ).mean(dim=0)
    return penalty

In [47]:
def get_loss_dis_2(gen, dis,
                real,
                N_noise, C_noise,
                c_lambda,
                device):
    fake = gen(get_noise(N_noise, C_noise, device))
    epsilon = get_epsilon(N_epsilon = N_noise, device=device)
    mixed_images = real * epsilon + fake * (1 - epsilon)
    gp = gradient_penalty_2(get_gradient_2(dis, mixed_images))
    return ( dis(mixed_images) - dis(real) + c_lambda * gp ).mean(dim=0)

In [48]:
def get_loss_gen_2(gen, dis,
                N_noise, C_noise,
                device):
    fake = gen(get_noise(N_noise, C_noise, device))
    return torch.mean( -dis(fake) )

In [49]:
# get instance from models
gen = Generator(C_noise, C_hidden_gen, C_image).to(device)
dis = Discriminator(C_image, C_hidden_dis).to(device)

# Initialize
gen = gen.apply(weights_init)
dis = dis.apply(weights_init)

# Optimizers
optim_dis = torch.optim.Adam(dis.parameters(), lr=lr, betas=(beta_1, beta_2))
optim_gen = torch.optim.Adam(gen.parameters(), lr=lr, betas=(beta_1, beta_2))

In [55]:
real = torch.ones(128,1,28,28).to(device)
cur_batch_size = len(real)
loss_dis = get_loss_dis_2(gen, dis, real, cur_batch_size, C_noise, c_lambda, device)
loss_dis

tensor([7.9603], device='cuda:0', grad_fn=<MeanBackward1>)

In [59]:
import matplotlib.pyplot as plt

cur_step = 0
generator_losses = []
critic_losses = []
for epoch in range(n_epochs):
    # Dataloader returns the batches
    for real, _ in tqdm(dataloader):
        cur_batch_size = len(real)
        real = real.to(device)

        mean_iteration_critic_loss = 0
        for _ in range(crit_repeats):
            ### Update critic ###
            optim_dis.zero_grad()
            crit_loss = get_loss_dis_2(gen, dis,
                real,
                cur_batch_size, C_noise,
                c_lambda,
                device)

            # Keep track of the average critic loss in this batch
            mean_iteration_critic_loss += crit_loss.item() / crit_repeats
            # Update gradients
            crit_loss.backward(retain_graph=True)
            # Update optimizer
            optim_dis.step()
        critic_losses += [mean_iteration_critic_loss]

        ### Update generator ###
        optim_gen.zero_grad()

        gen_loss = get_loss_gen_2(gen, dis,
                cur_batch_size, C_noise,
                device)
        gen_loss.backward()

        # Update the weights
        optim_gen.step()

        # Keep track of the average generator loss
        generator_losses += [gen_loss.item()]

        ### Visualization code ###
        if cur_step % display_step == 0 and cur_step > 0:
            gen_mean = sum(generator_losses[-display_step:]) / display_step
            crit_mean = sum(critic_losses[-display_step:]) / display_step
            print(f"Epoch {epoch}, step {cur_step}: Generator loss: {gen_mean}, critic loss: {crit_mean}")
# new
            fake_noise = get_noise(cur_batch_size, C_noise, device=device)
            fake = gen(fake_noise)
# end new
            show_tensor_images(fake)
            show_tensor_images(real)
            step_bins = 20
            num_examples = (len(generator_losses) // step_bins) * step_bins
            plt.plot(
                range(num_examples // step_bins),
                torch.Tensor(generator_losses[:num_examples]).view(-1, step_bins).mean(1),
                label="Generator Loss"
            )
            plt.plot(
                range(num_examples // step_bins),
                torch.Tensor(critic_losses[:num_examples]).view(-1, step_bins).mean(1),
                label="Critic Loss"
            )
            plt.legend()
            plt.show()

        cur_step += 1


  0%|          | 0/469 [00:00<?, ?it/s]

KeyboardInterrupt: 