In [1]:
# we will train a conditional wgan on svhn dataset
# we will use the gradient penalty to stabilize the training


# imports

In [2]:

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

# hyperparameters for the dataset and the MOdel

In [3]:
# Hyperparameters etc.
device = "cuda" if torch.cuda.is_available() else "cpu"
LEARNING_RATE = 1e-4
BATCH_SIZE = 64
IMAGE_SIZE = 64
CHANNELS_IMG = 3
Z_DIM = 100
NUM_EPOCHS = 1
FEATURES_CRITIC = 16
FEATURES_GEN = 16
CRITIC_ITERATIONS = 5
LAMBDA_GP = 10

# prepare the dataset

### we will use SVHN dataset for this example
### we will combine the train, test and extra datasets to make a bigger dataset

In [4]:
transforms = transforms.Compose(
    [
        transforms.Resize(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(
            [0.5 for _ in range(CHANNELS_IMG)], [0.5 for _ in range(CHANNELS_IMG)]),
    ]
)

In [5]:
#get the dataset
#train part of svhn
train_dataset = datasets.SVHN(root="dataset_svhm/", split='train', transform=transforms, download=True)
#test part of svhn
test_dataset = datasets.SVHN(root="dataset_svhm/", split='test', transform=transforms, download=True)
#extra part of svhn
extra_dataset = datasets.SVHN(root="dataset_svhm/", split='extra', transform=transforms, download=True)
#concatenate the train, test and extra dataset
dataset = torch.utils.data.ConcatDataset([train_dataset, test_dataset, extra_dataset])
loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
)


Using downloaded and verified file: dataset_svhm/train_32x32.mat
Using downloaded and verified file: dataset_svhm/test_32x32.mat
Using downloaded and verified file: dataset_svhm/extra_32x32.mat


In [6]:
#print the total number of images in the dataset
print(len(dataset))
# print the shape of the images
print(dataset[0][0].shape)
# print the label of the image
print(dataset[0][1])

630420
torch.Size([3, 64, 64])
1


# Model


## generator

In [7]:
class Generator(nn.Module):
    def __init__(self, channels_noise, channels_img, features_g):
        super(Generator, self).__init__()
        self.net = nn.Sequential(
            # Input: N x channels_noise x 1 x 1
            self.block_architecture_generator(channels_noise, features_g * 16, 4, 1, 0),  # img: 4x4
            self.block_architecture_generator(features_g * 16, features_g * 8, 4, 2, 1),  # img: 8x8
            self.block_architecture_generator(features_g * 8, features_g * 4, 4, 2, 1),  # img: 16x16
            self.block_architecture_generator(features_g * 4, features_g * 2, 4, 2, 1),  # img: 32x32
            nn.ConvTranspose2d(
                features_g * 2, channels_img, kernel_size=4, stride=2, padding=1
            ),
            # Output: N x channels_img x 64 x 64
            nn.Tanh(),
        )

    def block_architecture_generator(self, in_channels, out_channels, kernel_size, stride, padding):
        return nn.Sequential(
            nn.ConvTranspose2d(
                in_channels, out_channels, kernel_size, stride, padding, bias=False,
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

    def forward(self, x):
        return self.net(x)

In [8]:
class Discriminator(nn.Module):
    def __init__(self, channels_img, features_d):
        super(Discriminator, self).__init__()
        self.disc = nn.Sequential(
            # input: N x channels_img x 64 x 64
            nn.Conv2d(channels_img, features_d, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            # _block(in_channels, out_channels, kernel_size, stride, padding)
            self.block_architecture_critic(features_d, features_d * 2, 4, 2, 1),
            self.block_architecture_critic(features_d * 2, features_d * 4, 4, 2, 1),
            self.block_architecture_critic(features_d * 4, features_d * 8, 4, 2, 1),
            # After all _block img output is 4x4 (Conv2d below makes into 1x1)
            nn.Conv2d(features_d * 8, 1, kernel_size=4, stride=2, padding=0),
        )

    def block_architecture_critic(self, in_channels, out_channels, kernel_size, stride, padding):
        return nn.Sequential(
            nn.Conv2d(
                in_channels, out_channels, kernel_size, stride, padding, bias=False,
            ),
            nn.InstanceNorm2d(out_channels, affine=True),
            nn.LeakyReLU(0.2),
        )

    def forward(self, x):
        return self.disc(x)

--------------------------------

# model initialization

In [9]:

def initialize_weights(model):
    # Initializes weights according to the DCGAN paper
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)):
            nn.init.normal_(m.weight.data, 0.0, 0.02)



In [10]:
# initialize gen and critic
gen = Generator(Z_DIM, CHANNELS_IMG, FEATURES_GEN).to(device)
critic = Discriminator(CHANNELS_IMG, FEATURES_CRITIC).to(device)
initialize_weights(gen)
initialize_weights(critic)

### initialize optimizer

In [11]:
# initializate optimizer
opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.9))
opt_critic = optim.Adam(critic.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.9))

### inintialize tensorboard

In [12]:
# for tensorboard plotting
fixed_noise = torch.randn(100, Z_DIM, 1, 1).to(device)
#plot loss of generator and critic
writer_loss = SummaryWriter(f"runs/conditional_WGAN/loss")
writer_real = SummaryWriter(f"logs/conditional_WGAN/real")
writer_fake = SummaryWriter(f"logs/conditional_WGAN/fake")

--------------------------------

## start training

In [13]:
step = 0

gen.train()
critic.train()

Discriminator(
  (disc): Sequential(
    (0): Conv2d(3, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2)
    (2): Sequential(
      (0): Conv2d(16, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (3): Sequential(
      (0): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (4): Sequential(
      (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (5): Conv2d(128, 1, kernel_size=(4, 4), stride=(2, 2))
  )
)

#### gradient penalty function for WGAN-GP

In [14]:
def gradient_penalty(critic, real, fake, device="cpu"):
    BATCH_SIZE, C, H, W = real.shape
    alpha = torch.rand((BATCH_SIZE, 1, 1, 1)).repeat(1, C, H, W).to(device)
    interpolated_images = real * alpha + fake * (1 - alpha)

    # Calculate critic scores
    mixed_scores = critic(interpolated_images)

    # Take the gradient of the scores with respect to the images
    gradient = torch.autograd.grad(
        inputs=interpolated_images,
        outputs=mixed_scores,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0]
    gradient = gradient.view(gradient.shape[0], -1)
    gradient_norm = gradient.norm(2, dim=1)
    gradient_penalty = torch.mean((gradient_norm - 1) ** 2)
    return gradient_penalty

In [15]:
for epoch in range(NUM_EPOCHS):
    # Target labels not needed! <3 unsupervised
    for batch_idx, (real, _) in enumerate(loader):
        real = real.to(device)
        cur_batch_size = real.shape[0]

        # Train Critic: max E[critic(real)] - E[critic(fake)]
        # equivalent to minimizing the negative of that
        for _ in range(CRITIC_ITERATIONS):
            noise = torch.randn(cur_batch_size, Z_DIM, 1, 1).to(device)
            fake = gen(noise)
            critic_real = critic(real).reshape(-1)
            critic_fake = critic(fake).reshape(-1)
            gp = gradient_penalty(critic, real, fake, device=device)
            loss_critic = (
                -(torch.mean(critic_real) - torch.mean(critic_fake)) + LAMBDA_GP * gp
            )
            critic.zero_grad()
            loss_critic.backward(retain_graph=True)
            opt_critic.step()

        # Train Generator: max E[critic(gen_fake)] <-> min -E[critic(gen_fake)]
        gen_fake = critic(fake).reshape(-1)
        loss_gen = -torch.mean(gen_fake)
        gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

        # Print losses occasionally and print to tensorboard
        if batch_idx % 100 == 0 and batch_idx > 0:
            print(
                f"Epoch [{epoch}/{NUM_EPOCHS}] Batch {batch_idx}/{len(loader)} \
                  Loss D: {loss_critic:.4f}, loss G: {loss_gen:.4f}"
            )

            with torch.no_grad():
                fake = gen(fixed_noise)
                # take out (up to) 32 examples
                img_grid_real = torchvision.utils.make_grid(real[:32], normalize=True)
                img_grid_fake = torchvision.utils.make_grid(fake[:32], normalize=True)

                writer_real.add_image("Real", img_grid_real, global_step=step)
                writer_fake.add_image("Fake", img_grid_fake, global_step=step)

            step += 1

Epoch [0/1] Batch 100/9851                   Loss D: -55.8379, loss G: 44.9305
Epoch [0/1] Batch 200/9851                   Loss D: -69.6663, loss G: 80.7570
Epoch [0/1] Batch 300/9851                   Loss D: -51.9374, loss G: 107.8354
Epoch [0/1] Batch 400/9851                   Loss D: -36.5594, loss G: 132.8058
Epoch [0/1] Batch 500/9851                   Loss D: -46.3237, loss G: 150.9261
Epoch [0/1] Batch 600/9851                   Loss D: -45.5233, loss G: 162.4883
Epoch [0/1] Batch 700/9851                   Loss D: -36.0344, loss G: 162.1082
Epoch [0/1] Batch 800/9851                   Loss D: -31.0289, loss G: 171.5005
Epoch [0/1] Batch 900/9851                   Loss D: -31.2241, loss G: 172.8750
Epoch [0/1] Batch 1000/9851                   Loss D: -13.3897, loss G: 163.6735
Epoch [0/1] Batch 1100/9851                   Loss D: -13.0480, loss G: 153.9199
Epoch [0/1] Batch 1200/9851                   Loss D: -13.3457, loss G: 134.1922
Epoch [0/1] Batch 1300/9851            

In [16]:
#  sudo fuser -k /dev/nvidia
