In [1]:
import os 
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.utils as vutils
import torchvision.transforms as transforms

In [2]:
data_path = '../datasets/celeba_data'
# data_path = '../datasets/anime_data'
workers = 2
img_size = 64
batch_size = 128

# number of channels and latent size
nc = 3
nz = 100

# feature map sizes (generator and discriminator)
ngf = 64
ndf = 64

num_train_d = 5
lambda_gp = 10

num_epochs = 5
lr_d = 5e-4
lr_g = 2e-4


In [3]:
device = "cuda" if torch.cuda.is_available() else "cpu"

dataset = datasets.ImageFolder(
    root=data_path,
    transform=transforms.Compose([
        transforms.Resize(img_size),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
        transforms.Normalize(
            (0.5, 0.5, 0.5),
            (0.5, 0.5, 0.5)
        )
    ])
)

dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=workers
)

In [4]:
print('using device', device)

using device cuda


In [5]:
def weights_init(m):
    class_name = m.__class__.__name__
    if class_name.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif class_name.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

In [6]:
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.net = nn.Sequential(
            nn.ConvTranspose2d(nz, ngf*8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf*8),
            nn.ReLU(),
            
            nn.ConvTranspose2d(ngf*8, ngf*4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf*4),
            nn.ReLU(),
            
            nn.ConvTranspose2d(ngf*4, ngf*2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf*2),
            nn.ReLU(),
            
            nn.ConvTranspose2d(ngf*2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(),
            
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
        )
    
    def forward(self, x):
        return self.net(x)

In [7]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.net = nn.Sequential(
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            
            nn.Conv2d(ndf, ndf*2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf*2),
            nn.LeakyReLU(0.2, inplace=True),
            
            nn.Conv2d(ndf*2, ndf*4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf*4),
            nn.LeakyReLU(0.2, inplace=True),
            
            nn.Conv2d(ndf*4, ndf*8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf*8),
            nn.LeakyReLU(0.2, inplace=True),
            
            nn.Conv2d(ndf*8, 1, 4, 1, 0, bias=False),
        )
        
    def forward(self, x):
        return self.net(x)

In [8]:
generator = Generator().to(device)
generator.apply(weights_init)

Generator(
  (net): Sequential(
    (0): ConvTranspose2d(100, 512, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
    (6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU()
    (9): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (10): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU()
    (12): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (13): Tanh()
  )
)

In [9]:
discriminator = Discriminator().to(device)
discriminator.apply(weights_init)

Discriminator(
  (net): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (9): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): LeakyReLU(negative_slope=0.2, inplace=True)
    (11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
  )
)

In [10]:
optimizer_d = optim.Adam(discriminator.parameters(), lr=lr_d, betas=(0, 0.9))
optimizer_g = optim.Adam(generator.parameters(), lr=lr_g, betas=(0, 0.9))

def get_noise(batch_size, noise_dim):
    return torch.randn(batch_size, noise_dim, 1, 1, device=device)

In [11]:
def gradient_penalty(real_data, fake_data):
    batch_size, c, h, w = real_data.size()
    alpha = torch.rand(batch_size, 1, 1, 1, device=device).repeat(1, c, h, w)
    
    interpolates = alpha * real_data + (1 - alpha) * fake_data
    disc_interpolates = discriminator(interpolates)
    
    gradients = torch.autograd.grad(
        outputs=disc_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones(disc_interpolates.size(), device=device),
        create_graph=True,
        retain_graph=True,
        only_inputs=True
    )[0]
    
    gradients = gradients.view(batch_size, -1)
    return torch.mean((gradients.norm(2, dim=1) - 1) ** 2)

In [None]:
os.makedirs('./celeba_wgangp_results', exist_ok=True)
loss_d_list, loss_g_list = [], []

for epoch in range(num_epochs):
    total_loss_d, total_loss_g = 0, 0
    for i, (images, _) in tqdm(enumerate(dataloader), total=len(dataloader)):
        batch_size = images.size(0)
        
        for _ in range(num_train_d):
            z = get_noise(batch_size, nz)
            fake_images = generator(z)
            outputs_real = discriminator(images.to(device)).view(-1)
            outputs_fake = discriminator(fake_images.detach()).view(-1)
            
            gp = gradient_penalty(images.to(device), fake_images)
            loss_d = torch.mean(outputs_fake) - torch.mean(outputs_real) + lambda_gp * gp
            
            discriminator.zero_grad()
            loss_d.backward(retain_graph=True)
            optimizer_d.step()
            
            
        outputs = discriminator(fake_images).view(-1)
        loss_g = -1 * torch.mean(outputs)
        
        generator.zero_grad()
        loss_g.backward()
        optimizer_g.step()
        
        total_loss_d += loss_d.item()
        total_loss_g += loss_g.item()
    
    total_loss_d /= len(dataloader)
    total_loss_g /= len(dataloader)
    loss_d_list.append(total_loss_d)
    loss_g_list.append(total_loss_g)
    
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss D: {loss_d.item():.4f}, Loss G: {loss_g.item():.4f}')
    
    fake_images = fake_images.reshape(batch_size, 3, img_size, img_size)
    vutils.save_image(fake_images, f'./celeba_wgangp_results/fake_celeba_results{epoch+1}.png', normalize=True)
        