# Wasserstein GAN (WGAN)

In [1]:
import os
import numpy as np

import torch
import torch.nn as nn

import torchvision.transforms as transforms
from torchvision.utils import save_image

from torch.utils.data import DataLoader
from torchvision import datasets

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
data_path = '../data'
os.makedirs(data_path, exist_ok=True)

In [3]:
img_size = 32 # size of each image dimension
channels = 3 # number of image channels
img_shape = (channels, img_size, img_size)
hidden_dim = 64
lr = 0.0002
n_cpu = os.cpu_count()//2 
batch_size = 256 if torch.cuda.is_available() else 64
n_epochs = 10 
noise_dim = 100 # dimensionality of the latent space

## Only used for WGAN
n_critic = 5 # number of training steps for disc. per iter
clip_value = 0.01 # lower and upper clip value for disc. weights

In [4]:
transform = transforms.Compose([transforms.Resize(img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])

loader_kwargs = {'num_workers': n_cpu, 'pin_memory': True} 

train_data = datasets.CIFAR10(root=data_path, train=True, download=True, transform=transform)
test_data = datasets.CIFAR10(root=data_path, train=False, download=True, transform=transform)

train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, **loader_kwargs)

Files already downloaded and verified
Files already downloaded and verified


In [5]:
class Generator(nn.Module):
    def __init__(self, img_size=img_size):
        super(Generator, self).__init__()
        
        self.lin = nn.Linear(noise_dim, hidden_dim * img_size * img_size)

        self.conv = nn.Sequential(
            nn.BatchNorm2d(hidden_dim),
            nn.Conv2d(hidden_dim, hidden_dim, 3, stride=1, padding=1),
            nn.BatchNorm2d(hidden_dim),
            nn.LeakyReLU(True),
            nn.Conv2d(hidden_dim, hidden_dim, 3, stride=1, padding=1),
            nn.BatchNorm2d(hidden_dim),
            nn.LeakyReLU(True),
            nn.Conv2d(hidden_dim, channels, 3, stride=1, padding=1),
            nn.Tanh(),
        )

    def forward(self, z):
        out = self.lin(z)
        out = out.view(out.shape[0], hidden_dim, img_size, img_size)
        img = self.conv(out)
        return img

In [6]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        self.model = nn.Sequential(
            nn.Conv2d(channels, hidden_dim, 3,padding=1),
            nn.LeakyReLU(True),
            nn.Dropout2d(),
            nn.Conv2d(hidden_dim, hidden_dim, 3,padding=1),
            nn.LeakyReLU(True),
            nn.Dropout2d(),
            nn.BatchNorm2d(hidden_dim),
            nn.Conv2d(hidden_dim, hidden_dim, 3,padding=1),
            nn.LeakyReLU(True),
            nn.Dropout2d(),
            nn.BatchNorm2d(hidden_dim),
        )

        self.adv_layer = nn.Sequential(
            nn.Linear(hidden_dim * img_size * img_size, 1), 
            nn.Sigmoid()
        )

    def forward(self, img):
        out = self.model(img)
        out = out.view(out.shape[0], -1)
        validity = self.adv_layer(out)

        return validity

In [7]:
# Initialize generator and discriminator
generator = Generator()
discriminator = Discriminator()

# Optimizers
optimizer_G = torch.optim.AdamW(generator.parameters())
optimizer_D = torch.optim.AdamW(discriminator.parameters())

generator.to(device)
discriminator.to(device)

Discriminator(
  (model): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): LeakyReLU(negative_slope=True)
    (2): Dropout2d(p=0.5, inplace=False)
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): LeakyReLU(negative_slope=True)
    (5): Dropout2d(p=0.5, inplace=False)
    (6): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): LeakyReLU(negative_slope=True)
    (9): Dropout2d(p=0.5, inplace=False)
    (10): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (adv_layer): Sequential(
    (0): Linear(in_features=65536, out_features=1, bias=True)
    (1): Sigmoid()
  )
)

## Training

In [9]:
saved_dir = 'wgan_images'
os.makedirs(saved_dir, exist_ok = True)

In [10]:
loss_estimate = []
batches_done = 0
for epoch in range(n_epochs):
    for i, (imgs, _) in enumerate(train_loader):

        # ===== Discriminator =====
        # Train the generator every n_critic iterations
        for _ in range(n_critic):
            optimizer_D.zero_grad()

            # Configure input
            real_imgs = imgs.to(device)
            
            # Sample noise
            z = torch.normal(0, 1, (imgs.shape[0], noise_dim), device=device)

            fake_imgs = generator(z).detach()
            
            loss_D = -torch.mean(discriminator(real_imgs)) + torch.mean(discriminator(fake_imgs))
            loss_D.backward()
            optimizer_D.step()

            # Clip weights of discriminator
            for p in discriminator.parameters():
                p.data.clamp_(-clip_value, clip_value)


        # ===== Generator =====
        optimizer_G.zero_grad()

        # Sample noise
        z = torch.normal(0, 1, (imgs.shape[0], noise_dim), device=device)

        gen_imgs = generator(z)

        loss_G = -torch.mean(discriminator(gen_imgs))
        loss_G.backward()
        optimizer_G.step()

    # ===== save images and print logs =====
    save_image(gen_imgs.data[:25], f"{saved_dir}/{epoch+1}.png", nrow=5, normalize=True)
    print(f"epoch: {epoch+1}/{n_epochs}, D loss: {loss_D.item():.4f}, G loss: {loss_G.item():.4f}")

epoch: 1/10, D loss: -0.4462, G loss: -0.2699
epoch: 2/10, D loss: -0.4393, G loss: -0.3918
epoch: 3/10, D loss: -0.3528, G loss: -0.2999
epoch: 4/10, D loss: -0.3601, G loss: -0.2689
epoch: 5/10, D loss: -0.4256, G loss: -0.3053
epoch: 6/10, D loss: -0.3600, G loss: -0.3750
epoch: 7/10, D loss: -0.3359, G loss: -0.3823
epoch: 8/10, D loss: -0.3452, G loss: -0.4666
epoch: 9/10, D loss: -0.3007, G loss: -0.3703
epoch: 10/10, D loss: -0.3216, G loss: -0.2814
