In [18]:
import torch
import torch.nn as nn
import torch.optim as optim

import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms

from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

import matplotlib.pyplot as plt

from tqdm import tqdm

# 1- WGAN

In [19]:
class Critic(nn.Module):
    """Convolutional WGAN critic that scores 64x64 images.

    A stride-2 convolution stem is followed by three Conv-BatchNorm-LeakyReLU
    stages and a final convolution, reducing an (N, channels_img, 64, 64) batch
    to raw scores of shape (N, 1, 1, 1). No activation is applied to the output.

    Args:
        channels_img: Number of channels in the input images.
        features_d: Width of the first convolution; each later stage doubles it.
    """

    def __init__(self, channels_img, features_d):
        super().__init__()
        # Stem: 64x64 -> 32x32, deliberately without normalisation.
        layers = [
            nn.Conv2d(channels_img, features_d, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
        ]
        # Three stages: 32x32 -> 16x16 -> 8x8 -> 4x4, doubling the width each time.
        width = features_d
        for _ in range(3):
            layers.append(self._block(width, width * 2, 4, 2, 1))
            width = width * 2
        # Head: 4x4 -> 1x1, a single score per image.
        layers.append(nn.Conv2d(width, 1, kernel_size=4, stride=2, padding=0))
        self.disc = nn.Sequential(*layers)

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Build one Conv2d -> BatchNorm2d -> LeakyReLU(0.2) stage."""
        # The convolution has no bias: the BatchNorm shift that follows covers it.
        conv = nn.Conv2d(
            in_channels, out_channels, kernel_size, stride, padding, bias=False
        )
        norm = nn.BatchNorm2d(out_channels)
        activation = nn.LeakyReLU(0.2)
        return nn.Sequential(conv, norm, activation)

    def forward(self, x):
        """Return the critic scores for the image batch ``x``."""
        scores = self.disc(x)
        return scores

In [20]:
class Generator(nn.Module):
    """DCGAN-style generator that maps latent vectors to 64x64 images.

    Takes noise of shape (N, z_dim, 1, 1) and upsamples it through four
    ConvTranspose-BatchNorm-ReLU stages plus a final transposed convolution,
    producing (N, channels_img, 64, 64) images squashed to [-1, 1] by tanh.

    Args:
        z_dim: Number of channels of the latent input.
        channels_img: Number of channels in the generated images.
        features_g: Base width; the first stage uses ``features_g * 16`` channels.
    """

    def __init__(self, z_dim, channels_img, features_g):
        super().__init__()
        # Project the latent vector: 1x1 -> 4x4.
        stages = [self._block(z_dim, features_g * 16, 4, 1, 0)]
        # 4x4 -> 8x8 -> 16x16 -> 32x32, halving the channel count at each stage.
        for factor in (16, 8, 4):
            stages.append(
                self._block(features_g * factor, features_g * (factor // 2), 4, 2, 1)
            )
        # Output layer: 32x32 -> 64x64, no normalisation, values in [-1, 1].
        stages.append(
            nn.ConvTranspose2d(
                features_g * 2, channels_img, kernel_size=4, stride=2, padding=1
            )
        )
        stages.append(nn.Tanh())
        self.gen = nn.Sequential(*stages)

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Build one ConvTranspose2d -> BatchNorm2d -> ReLU stage."""
        # No bias on the transposed convolution: BatchNorm's shift covers it.
        upsample = nn.ConvTranspose2d(
            in_channels, out_channels, kernel_size, stride, padding, bias=False
        )
        norm = nn.BatchNorm2d(out_channels)
        activation = nn.ReLU()  # plain ReLU in the generator (no leaky variant)
        return nn.Sequential(upsample, norm, activation)

    def forward(self, x):
        """Return generated images for the latent batch ``x``."""
        images = self.gen(x)
        return images

In [21]:
# Initialize weights (DCGAN-style)
def initialize_weights(model):
    """Re-initialise the conv and batch-norm parameters of ``model`` in place.

    * ``nn.Conv2d`` / ``nn.ConvTranspose2d`` weights are drawn from N(0, 0.02).
    * ``nn.BatchNorm2d`` scale is drawn from N(1, 0.02) and its shift is zeroed.
      A zero-centred scale would multiply every normalised activation by
      roughly 0.02 and nearly switch the layer off at the start of training.

    Normalisation layers without affine parameters are skipped, and other
    module types are left untouched.

    Args:
        model: Any ``nn.Module``; all of its sub-modules are visited.
    """
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.normal_(m.weight, 0.0, 0.02)
        elif isinstance(m, nn.BatchNorm2d) and m.weight is not None:
            nn.init.normal_(m.weight, 1.0, 0.02)
            nn.init.constant_(m.bias, 0.0)

# testing discriminator and generator
def test():
    """Shape smoke test for ``Critic`` and ``Generator``.

    Builds both networks with a base width of 8, applies
    ``initialize_weights`` and asserts that a batch of eight 3x64x64 images
    yields (8, 1, 1, 1) scores and that eight latent vectors of length 100
    yield (8, 3, 64, 64) images.

    Raises:
        AssertionError: If either network returns an unexpected shape.
    """
    batch, channels, height, width = 8, 3, 64, 64
    latent_dim = 100

    # Critic: images in, one score per image out.
    images = torch.randn((batch, channels, height, width))
    disc = Critic(channels, 8)
    initialize_weights(disc)
    scores = disc(images)
    assert scores.shape == (batch, 1, 1, 1), "Discriminator test failed"

    # Generator: latent vectors in, full-size images out.
    gen = Generator(latent_dim, channels, 8)
    initialize_weights(gen)
    latent = torch.randn((batch, latent_dim, 1, 1))
    generated = gen(latent)
    assert generated.shape == (batch, channels, height, width), "Generator test failed"

In [26]:
# Hyperparameters etc.
# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- optimisation (lr, clip value and critic steps match the WGAN paper defaults) ---
learning_rate = 5e-5
batch_size = 64
num_epochs = 5
critic_iterations = 5  # critic updates per generator update
weight_clip = 0.01  # critic weights are clamped to [-weight_clip, weight_clip]

# --- data and model sizes ---
image_size = 64  # the networks above are written for 64x64 inputs
channels_img = 1
z_dim = 100  # channels of the generator's latent input
features_d = 64  # base width of the critic
features_g = 64  # base width of the generator

# Resize, convert to a tensor, then normalise every channel with mean 0.5 and
# std 0.5 (the list length follows channels_img, so any channel count works).
transform = transforms.Compose(
    [
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize([0.5] * channels_img, [0.5] * channels_img),
    ]
)

In [27]:
# prepare dataset
# MNIST stored under dataset/ (download requested), preprocessed by `transform`.
dataset = datasets.MNIST(
    root="dataset/",
    transform=transform,
    download=True,
)
# Shuffled mini-batches of `batch_size` images.
loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)

In [28]:
# define networks and move them to the target device
gen = Generator(z_dim, channels_img, features_g)
gen = gen.to(device)
critic = Critic(channels_img, features_d)
critic = critic.to(device)

# initialize weights
initialize_weights(gen)
initialize_weights(critic)

# set to train mode
gen.train()
critic.train()

# optimizers: RMSprop for both networks, same learning rate
opt_gen = optim.RMSprop(params=gen.parameters(), lr=learning_rate)
opt_critic = optim.RMSprop(params=critic.parameters(), lr=learning_rate)

# fixed noise: the same 32 latent vectors are rendered at every logging step
fixed_noise = torch.randn((32, z_dim, 1, 1)).to(device)

# tensorboard writers (separate runs for real and generated image grids)
writer_real = SummaryWriter(log_dir="logs/real")
writer_fake = SummaryWriter(log_dir="logs/fake")
step = 0  # global step used when logging images


In [1]:
# Training (WGAN with weight clipping)
for epoch in range(num_epochs):
    for batch_idx, (real, _) in enumerate(tqdm(loader)):
        real = real.to(device)  # real images

        # Train Critic: max E[critic(real)] - E[critic(fake)]
        for _ in range(critic_iterations):
            # Size the noise from the real batch: the last batch of an epoch
            # can be smaller than batch_size.
            noise = torch.randn((real.shape[0], z_dim, 1, 1)).to(device)
            fake = gen(noise)  # fake images (graph kept for the generator step)

            critic_real = critic(real).view(-1)  # scores of real images
            # detach(): the critic update must not backpropagate into the
            # generator, and the generator graph stays intact for later.
            critic_fake = critic(fake.detach()).view(-1)  # scores of fake images

            # Minimising this maximises E[critic(real)] - E[critic(fake)].
            loss_critic = -(torch.mean(critic_real) - torch.mean(critic_fake))

            critic.zero_grad()
            loss_critic.backward()
            opt_critic.step()

            # Clip weights to keep the critic (roughly) Lipschitz
            for p in critic.parameters():
                p.data.clamp_(-weight_clip, weight_clip)

        # Train Generator: min -E[critic(gen_fake)]
        # `fake` is the batch from the last critic iteration; the generator has
        # not changed since, so its graph is still valid.
        output = critic(fake).view(-1)
        lossG = -torch.mean(output)
        # Backpropagation (generator)
        gen.zero_grad()
        lossG.backward()
        # Update weights
        opt_gen.step()

        # Print losses occasionally and print to tensorboard
        if batch_idx == 0:
            print(
                f"Epoch: [{epoch}/{num_epochs}] \\ "
                f"LossD: {loss_critic:.8f}, LossG: {lossG:.8f}"
            )
            with torch.no_grad():
                fake = gen(fixed_noise)
                data = real

                img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)
                img_grid_real = torchvision.utils.make_grid(data, normalize=True)

                # write to tensorboard
                # writer fake images
                writer_fake.add_image(
                    "Fake Images", img_grid_fake, global_step=step
                )
                # writer real images
                writer_real.add_image(
                    "Real Images", img_grid_real, global_step=step
                )
                step += 1

                # channels-first grid -> channels-last for matplotlib
                plt.imshow(img_grid_fake.to("cpu").permute(1, 2, 0))
                plt.show()


# 2- WGAN-GP 
WGAN-GP is a variant of WGAN that uses a gradient penalty to enforce the Lipschitz constraint instead of weight clipping. 

Weight clipping is a poor way to enforce the Lipschitz constraint because it can lead to undesired behavior such as vanishing or exploding gradients. 

In [None]:
# gradient penalty function
def gradient_penatly(critic, real, fake, device=None):
    """Compute the WGAN-GP gradient penalty for one batch.

    Each real/fake pair is mixed with its own random weight, the critic is
    evaluated on the mixtures, and the penalty is the batch mean of
    ``(||d critic / d mixture||_2 - 1) ** 2``.

    Args:
        critic: Callable mapping an image batch to scores.
        real: Batch of real images; the first dimension is the batch.
        fake: Batch of generated images with the same shape as ``real``.
        device: Device for the random mixing weights. Defaults to
            ``real.device``.

    Returns:
        A scalar tensor. It is differentiable with respect to the critic's
        parameters (``create_graph=True``), so it can be added to the critic
        loss. No gradient flows back into ``real`` or ``fake``.
    """
    if device is None:
        device = real.device
    batch_size = real.shape[0]

    # One epsilon in [0, 1) per sample; broadcasting spreads it over the
    # remaining dimensions, so no explicit repeat is needed.
    eps_shape = (batch_size,) + (1,) * (real.dim() - 1)
    eps_dtype = real.dtype if real.is_floating_point() else None
    epsilon = torch.rand(eps_shape, device=device, dtype=eps_dtype)

    # Interpolate between real and fake images. The inputs are detached and the
    # mixture is made a leaf that requires grad, so autograd.grad below works
    # whether or not the caller's tensors carry an autograd graph.
    interpolated_images = real.detach() * epsilon + fake.detach() * (1 - epsilon)
    interpolated_images.requires_grad_(True)

    # calculate critic scores
    mixed_scores = critic(interpolated_images)

    # gradient of scores wrt interpolated images
    gradient = torch.autograd.grad(
        inputs=interpolated_images,
        outputs=mixed_scores,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0]

    # per-sample L2 norm of the flattened gradient
    gradient = gradient.reshape(gradient.shape[0], -1)
    gradient_norm = gradient.norm(2, dim=1)

    # penalise deviation of the norm from 1
    return torch.mean((gradient_norm - 1) ** 2)

In [None]:
# Using Instance Norm instead of Batch Norm in the Critic
class Critic(nn.Module):
    """Convolutional critic for 64x64 images using instance normalisation.

    A stride-2 convolution stem is followed by three
    Conv-InstanceNorm-LeakyReLU stages and a final convolution, reducing an
    (N, channels_img, 64, 64) batch to raw scores of shape (N, 1, 1, 1).

    Args:
        channels_img: Number of channels in the input images.
        features_d: Width of the first convolution; each later stage doubles it.
    """

    def __init__(self, channels_img, features_d):
        super().__init__()
        # Stem: 64x64 -> 32x32, deliberately without normalisation.
        layers = [
            nn.Conv2d(channels_img, features_d, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
        ]
        # Three stages: 32x32 -> 16x16 -> 8x8 -> 4x4, doubling the width each time.
        width = features_d
        for _ in range(3):
            layers.append(self._block(width, width * 2, 4, 2, 1))
            width = width * 2
        # Head: 4x4 -> 1x1, a single score per image.
        layers.append(nn.Conv2d(width, 1, kernel_size=4, stride=2, padding=0))
        self.disc = nn.Sequential(*layers)

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Build one Conv2d -> InstanceNorm2d(affine) -> LeakyReLU(0.2) stage."""
        # No conv bias: the affine InstanceNorm that follows has its own shift.
        conv = nn.Conv2d(
            in_channels, out_channels, kernel_size, stride, padding, bias=False
        )
        norm = nn.InstanceNorm2d(out_channels, affine=True)
        activation = nn.LeakyReLU(0.2)
        return nn.Sequential(conv, norm, activation)

    def forward(self, x):
        """Return the critic scores for the image batch ``x``."""
        scores = self.disc(x)
        return scores

In [None]:
# Hyperparameters etc.
# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- optimisation (lr, penalty weight and critic steps match the WGAN-GP paper defaults) ---
learning_rate = 1e-4
batch_size = 64
num_epochs = 5
critic_iterations = 5  # critic updates per generator update
lambda_gp = 10  # weight of the gradient penalty in the critic loss

# --- data and model sizes ---
image_size = 64  # the networks above are written for 64x64 inputs
channels_img = 1
z_dim = 100  # channels of the generator's latent input
features_d = 64  # base width of the critic
features_g = 64  # base width of the generator

# Resize, convert to a tensor, then normalise every channel with mean 0.5 and
# std 0.5 (the list length follows channels_img, so any channel count works).
transform = transforms.Compose(
    [
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize([0.5] * channels_img, [0.5] * channels_img),
    ]
)

In [None]:
# define networks and move them to the target device
gen = Generator(z_dim, channels_img, features_g)
gen = gen.to(device)
critic = Critic(channels_img, features_d)
critic = critic.to(device)

# initialize weights
initialize_weights(gen)
initialize_weights(critic)

# set to train mode
gen.train()
critic.train()

# optimizers: Adam with beta1 = 0.0 and beta2 = 0.9 for both networks
adam_betas = (0.0, 0.9)
opt_gen = optim.Adam(params=gen.parameters(), lr=learning_rate, betas=(0.0, 0.9))
opt_critic = optim.Adam(params=critic.parameters(), lr=learning_rate, betas=(0.0, 0.9))
del adam_betas  # documentation-only helper; keep the namespace unchanged

# fixed noise: the same 32 latent vectors are rendered at every logging step
fixed_noise = torch.randn((32, z_dim, 1, 1)).to(device)

# tensorboard writers (separate runs for real and generated image grids)
writer_real = SummaryWriter(log_dir="logs/real")
writer_fake = SummaryWriter(log_dir="logs/fake")
step = 0  # global step used when logging images


In [None]:
# Training (WGAN-GP)
for epoch in range(num_epochs):
    for batch_idx, (real, _) in enumerate(tqdm(loader)):
        real = real.to(device)  # real images

        # Train Critic: min E[critic(fake)] - E[critic(real)] + lambda_gp * GP
        for _ in range(critic_iterations):
            # Size the noise from the real batch. The last batch of an epoch
            # can be smaller than batch_size, and the gradient penalty mixes
            # real and fake images element-wise, so the shapes must match.
            noise = torch.randn((real.shape[0], z_dim, 1, 1)).to(device)
            fake = gen(noise)  # fake images

            critic_real = critic(real).view(-1)  # scores of real images
            critic_fake = critic(fake).view(-1)  # scores of fake images

            # gradient penalty
            gp = gradient_penatly(critic, real, fake, device=device)

            # Minimising this maximises E[critic(real)] - E[critic(fake)]
            # while penalising gradient norms away from 1.
            loss_critic = (
                -(torch.mean(critic_real) - torch.mean(critic_fake)) + lambda_gp * gp
            )

            critic.zero_grad()
            # retain_graph=True keeps the generator graph behind `fake` alive;
            # it is reused for the generator step below.
            loss_critic.backward(retain_graph=True)
            opt_critic.step()
            # No weight clipping here: the gradient penalty takes its place.

        # Train Generator: min -E[critic(gen_fake)]
        output = critic(fake).view(-1)
        lossG = -torch.mean(output)
        # Backpropagation (generator)
        gen.zero_grad()
        lossG.backward()
        # Update weights
        opt_gen.step()

        # Print losses occasionally and print to tensorboard
        if batch_idx == 0:
            print(
                f"Epoch: [{epoch}/{num_epochs}] \\ "
                f"LossD: {loss_critic:.8f}, LossG: {lossG:.8f}"
            )
            with torch.no_grad():
                fake = gen(fixed_noise)
                data = real

                img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)
                img_grid_real = torchvision.utils.make_grid(data, normalize=True)

                # write to tensorboard
                # writer fake images
                writer_fake.add_image(
                    "Fake Images", img_grid_fake, global_step=step
                )
                # writer real images
                writer_real.add_image(
                    "Real Images", img_grid_real, global_step=step
                )
                step += 1

                # channels-first grid -> channels-last for matplotlib
                plt.imshow(img_grid_fake.to("cpu").permute(1, 2, 0))
                plt.show()
