In [1]:
import os 
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.utils as vutils
import torchvision.transforms as transforms

In [2]:
# data_path = '../datasets/celeba_data'
data_path = '../datasets/anime_data'
workers = 2
img_size = 64
batch_size = 128

# number of channels and latent size
nc = 3
nz = 100

# feature map sizes (generator and discriminator)
ngf = 64
ndf = 64

num_train_d = 5
weight_clip = 0.01

num_epochs = 15
lr_d = 5e-4
lr_g = 2e-4
beta1 = 0.5


In [3]:
device = "cuda" if torch.cuda.is_available() else "cpu"

dataset = datasets.ImageFolder(
    root=data_path,
    transform=transforms.Compose([
        transforms.Resize(img_size),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
        transforms.Normalize(
            (0.5, 0.5, 0.5),
            (0.5, 0.5, 0.5)
        )
    ])
)

dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=workers
)

In [4]:
print('using device', device)

using device cuda


In [5]:
def weights_init(m):
    class_name = m.__class__.__name__
    if class_name.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif class_name.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

In [6]:
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.net = nn.Sequential(
            nn.ConvTranspose2d(nz, ngf*8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf*8),
            nn.ReLU(),
            
            nn.ConvTranspose2d(ngf*8, ngf*4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf*4),
            nn.ReLU(),
            
            nn.ConvTranspose2d(ngf*4, ngf*2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf*2),
            nn.ReLU(),
            
            nn.ConvTranspose2d(ngf*2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(),
            
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
        )
    
    def forward(self, x):
        return self.net(x)

In [7]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.net = nn.Sequential(
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            
            nn.Conv2d(ndf, ndf*2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf*2),
            nn.LeakyReLU(0.2, inplace=True),
            
            nn.Conv2d(ndf*2, ndf*4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf*4),
            nn.LeakyReLU(0.2, inplace=True),
            
            nn.Conv2d(ndf*4, ndf*8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf*8),
            nn.LeakyReLU(0.2, inplace=True),
            
            nn.Conv2d(ndf*8, 1, 4, 1, 0, bias=False),
        )
        
    def forward(self, x):
        return self.net(x)

In [8]:
generator = Generator().to(device)
generator.apply(weights_init)

Generator(
  (net): Sequential(
    (0): ConvTranspose2d(100, 512, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
    (6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU()
    (9): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (10): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU()
    (12): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (13): Tanh()
  )
)

In [9]:
discriminator = Discriminator().to(device)
discriminator.apply(weights_init)

Discriminator(
  (net): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (9): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): LeakyReLU(negative_slope=0.2, inplace=True)
    (11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
  )
)

In [10]:
optimizer_d = optim.RMSprop(discriminator.parameters(), lr=lr_d)
optimizer_g = optim.RMSprop(generator.parameters(), lr=lr_g)

def get_noise(batch_size, noise_dim):
    return torch.randn(batch_size, noise_dim, 1, 1, device=device)

In [11]:
os.makedirs('./anime_wgan_results', exist_ok=True)
loss_d_list, loss_g_list = [], []

for epoch in range(num_epochs):
    total_loss_d, total_loss_g = 0, 0
    for i, (images, _) in tqdm(enumerate(dataloader), total=len(dataloader)):
        batch_size = images.size(0)
        
        for _ in range(num_train_d):
            z = get_noise(batch_size, nz)
            fake_images = generator(z)
            outputs_real = discriminator(images.to(device)).view(-1)
            outputs_fake = discriminator(fake_images.detach()).view(-1)
            
            loss_d = torch.mean(outputs_fake) - torch.mean(outputs_real)
            
            discriminator.zero_grad()
            loss_d.backward(retain_graph=True)
            optimizer_d.step()
            
            for p in discriminator.parameters():
                p.data.clamp_(-weight_clip, weight_clip)
            
        outputs = discriminator(fake_images).view(-1)
        loss_g = -1 * torch.mean(outputs)
        
        generator.zero_grad()
        loss_g.backward()
        optimizer_g.step()
        
        total_loss_d += loss_d.item()
        total_loss_g += loss_g.item()
    
    total_loss_d /= len(dataloader)
    total_loss_g /= len(dataloader)
    loss_d_list.append(total_loss_d)
    loss_g_list.append(total_loss_g)
    
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss D: {loss_d.item():.4f}, Loss G: {loss_g.item():.4f}')
    
    fake_images = fake_images.reshape(batch_size, 3, img_size, img_size)
    vutils.save_image(fake_images, f'./anime_wgan_results/fake_anime_results_{epoch+1}.png', normalize=True)
        

100%|██████████| 497/497 [06:24<00:00,  1.29it/s]

Epoch [1/15], Loss D: -1.3020, Loss G: 0.6365



100%|██████████| 497/497 [06:29<00:00,  1.28it/s]

Epoch [2/15], Loss D: -1.1692, Loss G: 0.6112



100%|██████████| 497/497 [06:27<00:00,  1.28it/s]

Epoch [3/15], Loss D: -0.8118, Loss G: 0.5806



100%|██████████| 497/497 [06:27<00:00,  1.28it/s]

Epoch [4/15], Loss D: -0.8242, Loss G: 0.2824



100%|██████████| 497/497 [06:28<00:00,  1.28it/s]

Epoch [5/15], Loss D: -0.5926, Loss G: 0.0574



100%|██████████| 497/497 [06:27<00:00,  1.28it/s]

Epoch [6/15], Loss D: -0.8474, Loss G: 0.5345



100%|██████████| 497/497 [06:27<00:00,  1.28it/s]

Epoch [7/15], Loss D: -0.7706, Loss G: 0.2364



100%|██████████| 497/497 [06:27<00:00,  1.28it/s]

Epoch [8/15], Loss D: -0.5848, Loss G: 0.0868



100%|██████████| 497/497 [06:27<00:00,  1.28it/s]

Epoch [9/15], Loss D: -0.4325, Loss G: 0.0275



100%|██████████| 497/497 [06:27<00:00,  1.28it/s]

Epoch [10/15], Loss D: -0.5365, Loss G: 0.1703



100%|██████████| 497/497 [06:27<00:00,  1.28it/s]

Epoch [11/15], Loss D: -0.7135, Loss G: 0.1311



100%|██████████| 497/497 [06:27<00:00,  1.28it/s]

Epoch [12/15], Loss D: -0.5011, Loss G: -0.0685



100%|██████████| 497/497 [06:27<00:00,  1.28it/s]

Epoch [13/15], Loss D: -0.4604, Loss G: 0.0403



100%|██████████| 497/497 [06:28<00:00,  1.28it/s]

Epoch [14/15], Loss D: -0.5371, Loss G: 0.3334



100%|██████████| 497/497 [06:29<00:00,  1.28it/s]

Epoch [15/15], Loss D: -0.5702, Loss G: 0.5588



