In [None]:
# Identifier for this notebook/project (not referenced by any other cell visible here)
project_name="anime-dcgan"

### Get Dataset

In [None]:
# import opendatasets as od
# dataset_url = "https://www.kaggle.com/splcher/animefacedataset"
# od.download(dataset_url)

In [None]:
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
import torchvision.transforms as T

In [None]:
image_size = 64  # side length, in pixels, that images are resized and cropped to
batch_size = 128  # number of images per batch
stats = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5) # (per-channel mean, per-channel std) handed to T.Normalize

In [None]:
# Dataset of images under ./animefacedataset; every image goes through the same preprocessing pipeline.
train_ds = ImageFolder('./animefacedataset',transform=T.Compose([
  T.Resize(image_size),# scale so the shorter edge is image_size (aspect ratio is kept, so not yet square)
  T.CenterCrop(image_size),# cut the central image_size x image_size square
  T.ToTensor(),# PIL image -> float tensor with values in [0, 1]
  T.Normalize(*stats)]))# (x - 0.5) / 0.5 per channel: maps [0, 1] to [-1, 1]
# Shuffled batches; loading runs in 3 worker processes and uses pinned host memory.
train_dl = DataLoader(train_ds, batch_size, shuffle=True, num_workers=3, pin_memory=True)

In [None]:
import torch
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
def denorm(img_tensors):
    """Undo the normalisation described by `stats`: multiply by the std, then add the mean.

    Only the first channel's mean and std are read and applied to every element.
    With the 0.5 / 0.5 values in `stats` this maps (-1, 1) back to (0, 1).
    """
    mean, std = stats[0][0], stats[1][0]
    return img_tensors * std + mean

In [None]:
def show_images(images, nmax=64):
    """Display up to `nmax` images from a batch as a grid with 8 images per row.

    Args:
        images: batch of normalised image tensors; may live on any device and
            may carry autograd history.
        nmax: maximum number of images taken from the start of the batch.
    """
    fig, ax = plt.subplots(figsize=(8, 8))
    ax.set_xticks([]); ax.set_yticks([])
    # detach() drops autograd history; cpu() is required because matplotlib cannot
    # read tensors that live on an accelerator.
    batch = images.detach().cpu()[:nmax]
    grid = make_grid(denorm(batch), nrow=8)
    # make_grid returns channels-first (C, H, W); imshow expects channels-last (H, W, C).
    ax.imshow(grid.permute(1, 2, 0))

def show_batch(dl, nmax=64):
    """Plot the first batch yielded by `dl`; do nothing if `dl` yields no batches.

    Each batch is expected to unpack into (images, labels); the labels are ignored.
    """
    first_batch = next(iter(dl), None)
    if first_batch is not None:
        images, _labels = first_batch
        show_images(images, nmax)

In [None]:
# Sanity check of the input pipeline: display a grid from the first batch the loader yields
show_batch(train_dl)

In [None]:
def get_default_device():
    """Return the CUDA device when CUDA is available, otherwise the CPU device."""
    device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
    return torch.device(device_name)
    
def to_device(data, device):
    """Move `data` to `device`.

    Lists and tuples are handled recursively and always come back as lists;
    anything else must provide `.to(device, non_blocking=True)`.
    """
    if not isinstance(data, (list, tuple)):
        return data.to(device, non_blocking=True)
    return [to_device(item, device) for item in data]

class DeviceDataLoader():
    """Thin wrapper around a data loader that moves every batch to a device."""

    def __init__(self, dl, device):
        self.dl, self.device = dl, device

    def __len__(self):
        """Number of batches, as reported by the wrapped loader."""
        return len(self.dl)

    def __iter__(self):
        """Lazily yield each batch of the wrapped loader after moving it to the device."""
        yield from (to_device(batch, self.device) for batch in self.dl)


In [None]:
# Resolve the compute device once; the bare name on the last line makes the cell display it
device = get_default_device()
device

In [None]:
# Wrap the loader so every batch is moved to `device` as it is yielded.
# NOTE(review): this rebinds `train_dl`, so re-running the cell wraps the wrapper again.
train_dl = DeviceDataLoader(train_dl, device)

### Discriminator Network

In [None]:
# Classifier model (DCGAN)
import torch.nn as nn

discriminator = nn.Sequential(
    # Four stride-2 down-sampling stages, each Conv -> BatchNorm -> LeakyReLU(0.2):
    #   3 x 64 x 64 -> 64 x 32 x 32 -> 128 x 16 x 16 -> 256 x 8 x 8 -> 512 x 4 x 4
    *[
        layer
        for in_channels, out_channels in [(3, 64), (64, 128), (128, 256), (256, 512)]
        for layer in (
            nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2, inplace=True),
        )
    ],
    # A 4x4 kernel over the 4x4 feature map collapses it to a single value: 1 x 1 x 1
    nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=0, bias=False),
    # (batch, 1, 1, 1) -> (batch, 1)
    nn.Flatten(),
    # Real-vs-fake is a single binary decision, so one sigmoid output in (0, 1) is enough
    # (softmax is only needed when choosing among several classes).
    nn.Sigmoid(),
)


In [None]:
import math
import numpy as np


def sigmoid(x):
    """Return the logistic sigmoid 1 / (1 + e^-v) of every element of `x`.

    Args:
        x: iterable of real numbers.

    Returns:
        A list of floats in [0, 1], one per input element, in input order.
    """
    def _stable_sigmoid(value):
        # Only ever exponentiate a non-positive number, so math.exp cannot
        # overflow; the naive form raises OverflowError for large negative inputs.
        if value >= 0:
            return 1 / (1 + math.exp(-value))
        exp_value = math.exp(value)
        return exp_value / (1 + exp_value)

    return [_stable_sigmoid(item) for item in x]


# Plot the sigmoid curve on [-10, 10) in steps of 0.2 to show how it squashes inputs
x = np.arange(-10., 10., 0.2)
sig = sigmoid(x)
plt.plot(x, sig)
plt.show()


In [None]:
# Move the discriminator to the selected device
discriminator = to_device(discriminator, device)

In [None]:
latent_size = 128 # length of the random noise vector fed to the generator; the author suggests larger values (e.g. 1024) for richer data such as human faces

In [None]:
generator = nn.Sequential(
    # Four up-sampling stages, each ConvTranspose -> BatchNorm -> ReLU:
    #   latent_size x 1 x 1 -> 512 x 4 x 4 -> 256 x 8 x 8 -> 128 x 16 x 16 -> 64 x 32 x 32
    *[
        layer
        for in_channels, out_channels, stride, padding in [
            (latent_size, 512, 1, 0),
            (512, 256, 2, 1),
            (256, 128, 2, 1),
            (128, 64, 2, 1),
        ]
        for layer in (
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=stride, padding=padding, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True),
        )
    ],
    # Final up-sampling to an RGB image: 3 x 64 x 64
    nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1, bias=False),
    # tanh keeps pixel values in (-1, 1), the same range as the normalised training images
    nn.Tanh(),
)

In [None]:
# tanh maps any real input into the open interval (-1, 1)
def tanh(x):
    """Return a list holding the hyperbolic tangent of every element of `x`, in order."""
    return list(map(math.tanh, x))

# Plot the tanh curve on [-10, 10) in steps of 0.2
x = np.arange(-10., 10., 0.2)
tan = tanh(x)
plt.plot(x, tan)
plt.show()

In [None]:
# Smoke-test the untrained generator: feed one batch of random latent vectors and look at the output.
# NOTE(review): `xb` is created on the CPU, so this presumably has to run before the generator is moved to a GPU.
xb = torch.randn(batch_size, latent_size, 1, 1) # one random latent vector per batch element
fake_images = generator(xb)
print(fake_images.shape)
show_images(fake_images)

In [None]:
# Move the generator to the selected device
generator = to_device(generator, device)

### Discriminator Network Training

The discriminator outputs a single probability that an image is real, so it is trained with binary cross-entropy loss.

In [None]:
import torch.nn.functional as F

In [None]:
def train_discriminator(real_images, opt_d):
    """Run one discriminator update on a real batch plus a freshly generated fake batch.

    Args:
        real_images: batch of real images, already on `device`.
        opt_d: optimizer that updates the discriminator's parameters.

    Returns:
        Tuple `(loss, real_score, fake_score)` of Python floats: the summed
        real + fake binary cross-entropy loss, and the discriminator's mean
        prediction on the real and on the fake batch respectively.
    """
    # Clear discriminator gradients
    opt_d.zero_grad()

    # Real images should be classified as 1
    real_preds = discriminator(real_images)
    real_targets = torch.ones(real_images.size(0), 1, device=device)
    real_loss = F.binary_cross_entropy(real_preds, real_targets)
    real_score = torch.mean(real_preds).item()

    # Generate fake images. Only the discriminator is updated in this step, so the
    # generator runs without autograd: no graph is kept and backward() does not
    # waste time computing generator gradients that would be thrown away.
    latent = torch.randn(batch_size, latent_size, 1, 1, device=device)
    with torch.no_grad():
        fake_images = generator(latent)

    # Fake images should be classified as 0
    fake_targets = torch.zeros(fake_images.size(0), 1, device=device)
    fake_preds = discriminator(fake_images)
    fake_loss = F.binary_cross_entropy(fake_preds, fake_targets)
    fake_score = torch.mean(fake_preds).item()

    # Update discriminator weights
    loss = real_loss + fake_loss
    loss.backward()
    opt_d.step()
    return loss.item(), real_score, fake_score


### Generator Network Training

In [None]:
def train_generator(opt_g):
    """Run one generator update and return its loss as a Python float.

    The generator is rewarded when the discriminator labels its images as real,
    so the binary cross-entropy targets are all ones.
    """
    opt_g.zero_grad()

    # Fresh noise -> fake images -> discriminator verdict
    noise = torch.randn(batch_size, latent_size, 1, 1, device=device)
    verdict = discriminator(generator(noise))

    # Loss is low when the discriminator is fooled into predicting "real"
    wanted = torch.ones(batch_size, 1, device=device)
    g_loss = F.binary_cross_entropy(verdict, wanted)

    g_loss.backward()
    opt_g.step()

    return g_loss.item()

In [None]:
from torchvision.utils import save_image
import os
sample_dir = "generated"  # directory that receives the generated sample grids
os.makedirs(sample_dir, exist_ok=True)  # exist_ok=True: no error if the directory already exists

In [None]:
def save_samples(index, latent_tensors, show=True):
    """Generate images from `latent_tensors`, save them as one PNG grid, optionally display it.

    Args:
        index: integer used in the file name `generated-images-<index, 4 digits>.png`.
        latent_tensors: latent batch fed to the generator (on the generator's device).
        show: when True, also render the same grid inline.
    """
    # Nothing is trained here, so skip autograd; denormalise once and reuse the
    # result for both the saved file and the inline preview so the two match.
    with torch.no_grad():
        fake_images = denorm(generator(latent_tensors)).cpu()
    fake_fname = 'generated-images-{:04d}.png'.format(index)
    save_image(fake_images, os.path.join(sample_dir, fake_fname), nrow=8)
    print('Saving', fake_fname)
    if show:
        fig, ax = plt.subplots(figsize=(8, 8))
        ax.set_xticks([]); ax.set_yticks([])
        # make_grid returns channels-first (C, H, W); imshow expects channels-last (H, W, C).
        ax.imshow(make_grid(fake_images, nrow=8).permute(1, 2, 0))

In [None]:
# Fixed batch of 64 latent vectors reused for every saved sample grid, so grids are comparable across epochs
fixed_latent = torch.randn(64, latent_size, 1, 1, device=device)

In [None]:
# Save and show the generator's output for the fixed latents as sample 0, before fit() is run
save_samples(0, fixed_latent)

In [None]:
from tqdm.notebook import tqdm

In [None]:
def fit(epochs, lr, start_idx=1):
    """Train the discriminator and generator alternately for `epochs` passes over `train_dl`.

    Args:
        epochs: number of passes over `train_dl`.
        lr: learning rate used by both Adam optimizers.
        start_idx: index of the first saved sample grid; epoch `e` (0-based)
            is saved under index `e + start_idx`.

    Returns:
        Tuple `(losses_g, losses_d, real_scores, fake_scores)` of lists with
        one entry per epoch, each taken from the LAST batch of that epoch.

    Raises:
        ValueError: if `train_dl` yields no batches during an epoch.
    """
    torch.cuda.empty_cache()

    # Per-epoch history (values of the last batch of each epoch)
    losses_g, losses_d = [], []
    real_scores, fake_scores = [], []

    # One Adam optimizer per network, both with betas=(0.5, 0.999)
    opt_d = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))
    opt_g = torch.optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))

    for epoch in range(epochs):
        # Sentinel: stays None only if the loader produced no batch this epoch
        loss_g = None

        # Alternate one discriminator step and one generator step per batch
        for real_images, _ in tqdm(train_dl):
            loss_d, real_score, fake_score = train_discriminator(real_images, opt_d)
            loss_g = train_generator(opt_g)

        if loss_g is None:
            # Without this guard the code below would hit a NameError (first epoch)
            # or silently re-log the previous epoch's values (later epochs).
            raise ValueError("train_dl yielded no batches; nothing to train on")

        # Keep record of losses and scores
        losses_g.append(loss_g)
        losses_d.append(loss_d)
        real_scores.append(real_score)
        fake_scores.append(fake_score)

        # Log losses & scores (last batch)
        print(f"Epoch [{epoch+1}/{epochs}] Loss_D: {loss_d} Loss_G: {loss_g} Real_Score: {real_score} Fake_Score: {fake_score}")

        # Save a sample grid generated from the fixed latents
        save_samples(epoch+start_idx, fixed_latent)

    return losses_g, losses_d, real_scores, fake_scores

In [None]:
lr = 0.0002  # learning rate passed to fit(), which uses it for both Adam optimizers
epochs = 25  # number of passes over the training loader

In [None]:
# Run the full training loop; `history` is the tuple of per-epoch lists returned by fit()
history = fit(epochs, lr)

In [None]:
# Unpack in the same order fit() returns them
losses_g, losses_d, real_scores, fake_scores = history

In [None]:
# Save the model checkpoints: only the state_dicts are written, to G.pth / D.pth in the
# working directory, so the same architectures must be rebuilt before loading them back.
torch.save(generator.state_dict(), 'G.pth')
torch.save(discriminator.state_dict(), 'D.pth')

In [None]:
from IPython.display import Image

In [None]:
import cv2
import os

vid_fname = 'gans_training.avi'

# The zero-padded index in the file names makes lexicographic order the order of the indices
files = sorted(
    os.path.join(sample_dir, f) for f in os.listdir(sample_dir) if 'generated' in f
)

# cv2.imread returns None for files it cannot decode; skip those instead of writing None
frames = [frame for frame in (cv2.imread(path) for path in files) if frame is not None]

# Take the frame size from the images themselves rather than hard-coding it, so the
# video still gets its frames if the image size or grid layout changes.
# (530, 530) is the previously hard-coded size, kept as the fallback when there are no frames.
frame_size = (frames[0].shape[1], frames[0].shape[0]) if frames else (530, 530)

# NOTE(review): the 'MP4V' codec is paired with an .avi container - confirm this is intended.
out = cv2.VideoWriter(vid_fname, cv2.VideoWriter_fourcc(*'MP4V'), 1, frame_size)
try:
    for frame in frames:
        out.write(frame)
finally:
    # Always finalise the file, even if writing a frame fails
    out.release()

In [None]:
# Discriminator vs generator loss; one point per epoch (the value recorded by fit() for that epoch)
plt.plot(losses_d, '-')
plt.plot(losses_g, '-')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.legend(['Discriminator', 'Generator'])
# Trailing semicolon suppresses the text repr of the returned Title object
plt.title('Losses');