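## Dataset
Anime Face Dataset (Kaggle): https://www.kaggle.com/splcher/animefacedataset

Extract the images into a subdirectory of `dataset/` (for example `dataset/images/`): `torchvision.datasets.ImageFolder` expects every image to sit inside a subfolder of its root directory.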

In [None]:
from __future__ import print_function
#%matplotlib inline
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML

# Draw a fresh seed for every run, but print it: without a record of the
# value a promising run can never be reproduced. To repeat a run, replace the
# randint call with the printed number.
seed = random.randint(1, 10000000)
print("Random Seed:", seed)
random.seed(seed)
torch.manual_seed(seed)

In [None]:
# Run on the first CUDA GPU when one is visible; otherwise fall back to CPU.
device = (
    torch.device("cuda:0")
    if torch.cuda.is_available()
    else torch.device("cpu")
)

In [None]:
# ---- Training configuration ----
dataroot = 'dataset/'            # ImageFolder root; images must sit in a subfolder of it
batch_size = 128
img_size = 64                    # side of the square training images; the nets below are built for 64x64
num_channel = 3                  # image channels (3 for RGB)
workers = 2                      # DataLoader worker processes
# Length of the latent noise vector fed to the generator. The previous value
# of 3 left the generator almost no room to encode variation between faces;
# 100 is the latent size used in the DCGAN paper.
num_z = 100
num_generator_feature = 64       # base feature-map count of the generator
num_discriminator_feature = 64   # base feature-map count of the discriminator
num_epochs = 100
lr = 0.0002                      # Adam learning rate (both networks)
beta1 = 0.5                      # Adam beta1 (both networks)

In [None]:
# Load every image under `dataroot`: scale its shorter side to img_size,
# center-crop to an img_size x img_size square, and map pixels to [-1, 1]
# so real images share the range of the generator's Tanh output.
dataset = dset.ImageFolder(
    root=dataroot,
    transform=transforms.Compose([
        # Resize before cropping. CenterCrop alone keeps only the central
        # 64x64 patch of larger images (cutting most of the face away) and
        # zero-pads smaller ones.
        transforms.Resize(img_size),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
        # (x - 0.5) / 0.5 maps ToTensor's [0, 1] output to [-1, 1], per channel.
        transforms.Normalize((0.5,) * num_channel, (0.5,) * num_channel),
    ]),
)
dataloader = torch.utils.data.DataLoader(
    dataset, batch_size=batch_size, shuffle=True, num_workers=workers)

# Preview up to 64 training images as a grid.
real_batch = next(iter(dataloader))
plt.figure(figsize=(8, 8))
plt.axis("off")
plt.title("Training Images")
# make_grid works on CPU tensors, so the batch does not need to visit `device`.
grid = vutils.make_grid(real_batch[0][:64], padding=2, normalize=True)
plt.imshow(np.transpose(grid, (1, 2, 0)))  # CHW -> HWC for matplotlib

In [None]:
def weights_init(w):
    """Re-initialise one module's parameters in place.

    Meant to be used as ``model.apply(weights_init)``, which calls it once for
    every submodule:

    * modules whose class name contains ``Conv`` (e.g. ``Conv2d``,
      ``ConvTranspose2d``) get weights drawn from N(0, 0.02);
    * modules whose class name contains ``BatchNorm`` get their scale drawn
      from N(1, 0.02) and their shift set to 0.

    Every other module is left untouched.

    Args:
        w: the module to initialise.
    """
    classname = type(w).__name__
    if 'Conv' in classname:
        nn.init.normal_(w.weight, 0.0, 0.02)
    elif 'BatchNorm' in classname:
        # A BatchNorm built with affine=False has no weight/bias (they are None).
        if w.weight is not None:
            nn.init.normal_(w.weight, 1.0, 0.02)
        if w.bias is not None:
            nn.init.constant_(w.bias, 0)

In [None]:
class Generator(nn.Module):
    """DCGAN-style generator mapping latent vectors to images in [-1, 1].

    For an input of shape ``(N, nz, 1, 1)`` the first transposed convolution
    projects to 4x4 and each following one doubles the spatial size, giving
    an output of shape ``(N, nc, 64, 64)``.

    Args:
        nz: latent vector length; defaults to the global ``num_z``.
        ngf: base feature-map count; defaults to ``num_generator_feature``.
        nc: output image channels; defaults to ``num_channel``.
    """

    def __init__(self, nz=None, ngf=None, nc=None):
        super().__init__()
        # Unset sizes fall back to the notebook's globals, read at construction
        # time as before, so ``Generator()`` builds exactly the same network.
        nz = num_z if nz is None else nz
        ngf = num_generator_feature if ngf is None else ngf
        nc = num_channel if nc is None else nc
        self.ngpu = 1  # not read anywhere in this file; kept for external callers
        # Layer order is unchanged so saved state_dict keys (main.<idx>.*) still load.
        self.main = nn.Sequential(
            # (N, nz, 1, 1) -> (N, ngf*8, 4, 4)
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # -> (N, ngf*4, 8, 8)
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # -> (N, ngf*2, 16, 16)
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # -> (N, ngf, 32, 32)
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # -> (N, nc, 64, 64), squashed into [-1, 1]
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh(),
        )

    def forward(self, input):
        """Generate a batch of images from a ``(N, nz, 1, 1)`` noise tensor."""
        return self.main(input)

In [None]:
# Build the generator on the training device, then run the custom weight
# initialisation over every submodule (Module.to/apply both return the module).
net_generator = Generator().to(device).apply(weights_init)
print(net_generator)

In [None]:
class Discriminator(nn.Module):
    """DCGAN-style discriminator scoring images as real (1) or fake (0).

    For an input of shape ``(N, nc, 64, 64)`` each strided convolution halves
    the spatial size down to 4x4; a final 4x4 convolution and a sigmoid reduce
    that to one value in [0, 1] per image, shape ``(N, 1, 1, 1)``.

    Args:
        nc: input image channels; defaults to the global ``num_channel``.
        ndf: base feature-map count; defaults to ``num_discriminator_feature``.
    """

    def __init__(self, nc=None, ndf=None):
        super().__init__()
        # Unset sizes fall back to the notebook's globals, read at construction
        # time as before, so ``Discriminator()`` builds exactly the same network.
        nc = num_channel if nc is None else nc
        ndf = num_discriminator_feature if ndf is None else ndf
        self.ngpu = 1  # not read anywhere in this file; kept for external callers
        # Layer order is unchanged so saved state_dict keys (main.<idx>.*) still load.
        self.main = nn.Sequential(
            # (N, nc, 64, 64) -> (N, ndf, 32, 32); no BatchNorm on the input layer
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (N, ndf*2, 16, 16)
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (N, ndf*4, 8, 8)
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (N, ndf*8, 4, 4)
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (N, 1, 1, 1), squashed into [0, 1]
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, input):
        """Return per-image real/fake scores of shape ``(N, 1, 1, 1)``."""
        return self.main(input)

In [None]:
# Build the discriminator on the training device, then run the custom weight
# initialisation over every submodule (Module.to/apply both return the module).
net_discriminator = Discriminator().to(device).apply(weights_init)
print(net_discriminator)

In [None]:
# Binary cross-entropy on the discriminator's sigmoid output.
loss = nn.BCELoss()
# Target values: 1.0 marks a real image, 0.0 a generated one.
real_label, fake_label = 1., 0.
# 64 latent vectors held fixed for the whole run, so the progress snapshots
# taken during training are directly comparable.
fixed_noise = torch.randn(64, num_z, 1, 1, device=device)
# Both networks: Adam with the shared learning rate and betas (beta1, 0.999).
optimizer_discriminator = optim.Adam(
    net_discriminator.parameters(),
    lr=lr,
    betas=(beta1, 0.999),
)
optimizer_generator = optim.Adam(
    net_generator.parameters(),
    lr=lr,
    betas=(beta1, 0.999),
)

In [None]:
img_list = []              # grids of G(fixed_noise) captured during training
generator_losses = []      # generator loss, one entry per iteration
discriminator_losses = []  # discriminator loss (real + fake), one per iteration
iters = 0

log_every = 10        # print progress every this many batches
# Store a G(fixed_noise) grid every this many iterations. The previous value
# of 8 kept a full 64-image float grid for every 8 batches of a 100-epoch
# run, so img_list grew into gigabytes of memory.
snapshot_every = 500
num_batches = len(dataloader)

for epoch in range(num_epochs):
    for i, data in enumerate(dataloader, 0):
        # ---- (1) Discriminator step: maximise log D(x) + log(1 - D(G(z))) ----
        net_discriminator.zero_grad()

        # Real images, labelled real.
        real = data[0].to(device)
        b_size = real.size(0)
        label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
        output = net_discriminator(real).view(-1)
        error_discriminator_real = loss(output, label)
        error_discriminator_real.backward()
        D_x = output.mean().item()

        # Generated images, labelled fake. detach() keeps this backward pass
        # out of the generator.
        noise = torch.randn(b_size, num_z, 1, 1, device=device)
        fake = net_generator(noise)
        label.fill_(fake_label)
        output = net_discriminator(fake.detach()).view(-1)
        error_discriminator_fake = loss(output, label)
        error_discriminator_fake.backward()
        D_G_z1 = output.mean().item()

        # Gradients of both halves have accumulated; take one D step.
        error_discriminator = error_discriminator_real + error_discriminator_fake
        optimizer_discriminator.step()

        # ---- (2) Generator step: maximise log D(G(z)) ----
        net_generator.zero_grad()
        label.fill_(real_label)  # G is rewarded when D calls its fakes real
        output = net_discriminator(fake).view(-1)
        error_generator = loss(output, label)
        # This backward also writes gradients into D's parameters; they are
        # cleared by net_discriminator.zero_grad() at the next iteration.
        error_generator.backward()
        D_G_z2 = output.mean().item()
        optimizer_generator.step()

        # Read each scalar loss once and reuse it for logging and history.
        loss_d = error_discriminator.item()
        loss_g = error_generator.item()
        if i % log_every == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, num_epochs, i, num_batches,
                     loss_d, loss_g, D_x, D_G_z1, D_G_z2))
        generator_losses.append(loss_g)
        discriminator_losses.append(loss_d)

        # Periodic snapshot of G(fixed_noise), plus one after the final batch.
        is_last_batch = (epoch == num_epochs - 1) and (i == num_batches - 1)
        if iters % snapshot_every == 0 or is_last_batch:
            with torch.no_grad():
                samples = net_generator(fixed_noise).detach().cpu()
            img_list.append(vutils.make_grid(samples, padding=2, normalize=True))
        iters += 1

In [None]:
fig = plt.figure(figsize=(8, 8))
plt.axis("off")
# One animation frame per stored grid: each frame is a one-artist list, and
# each grid is transposed from CHW to the HWC layout imshow expects.
ims = [
    [plt.imshow(np.transpose(grid, (1, 2, 0)), animated=True)]
    for grid in img_list
]