<a href="https://colab.research.google.com/github/Machine-Learning-Tokyo/Intro-to-GANs/blob/master/WassersteinGAN/Faces_COND_WDCGAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Faces Wasserstein GAN (WGAN)

In this notebook we reuse the model from [Wasserstein GAN (WGAN)  -- Solutions]() and train it on a dataset of faces, so that our GAN learns to generate faces of people.



## Download dataset

First of all we need a dataset of faces. We're going to use the [CelebA](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html) dataset, a dataset with over 200k pictures of celebrities.

To download it we're going to use a script that belongs to the [StarGAN](https://github.com/yunjey/stargan) project. StarGAN is an advanced GAN that modifies faces of people. There is no need to understand it for this notebook, but go ahead and have a look if you are curious.

So run the following cells to download the dataset. This will take a while.

In [0]:
#%%capture
# download CelebA data
!wget https://raw.githubusercontent.com/yunjey/StarGAN/master/download.sh
!bash download.sh celeba

In [0]:
!echo "There are `ls data/celeba/images | wc -l` images"

### Imports

In [0]:
from keras.models import Model
from keras.layers import Input, Dense, BatchNormalization, Reshape, Flatten
from keras.layers.advanced_activations import LeakyReLU
from keras.datasets import fashion_mnist
from keras.optimizers import Adam, RMSprop

from keras.layers import Conv2D, UpSampling2D, concatenate, Lambda
from keras.initializers import RandomNormal
from keras.utils import to_categorical
import keras.backend as K

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from PIL import Image
from matplotlib import animation, rc
from IPython.display import Image as ipyImage
from IPython.display import HTML
from os import listdir

## Hidden data loader code

To use a dataset we need a data loader. In our previous WGAN notebooks we used a pre-loaded dataset, Fashion-MNIST. It is small and simple, so its images are provided in an array that can be indexed to generate batches. CelebA, by contrast, is very large, so keeping it in memory as an array is not feasible. Instead, we implement a loader that reads the image files from disk and generates the batches on the fly.

The code for this data loader is hidden.  It's cumbersome and it's not the point of this exercise so you can ignore it.
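
If you are curious about the idea anyway, here is a minimal sketch of an on-the-fly loader. It is not the hidden loader itself: `batch_generator` and `fake_load` are hypothetical names, and `fake_load` only stands in for reading and decoding an image file from disk.

```python
from itertools import cycle
import random

import numpy as np


def batch_generator(names, load_fn, batch_size, seed=0):
    """Yield batches forever, loading each item only when it is needed."""
    names = list(names)
    random.Random(seed).shuffle(names)  # shuffle once, then loop endlessly
    stream = cycle(names)
    while True:
        batch = [load_fn(next(stream)) for _ in range(batch_size)]
        yield np.stack(batch)


def fake_load(name):
    # Hypothetical stand-in for reading one 64x64 RGB image from disk.
    return np.full((64, 64, 3), float(name), dtype=np.float32)


gen = batch_generator(range(10), fake_load, batch_size=4)
first_batch = next(gen)
print(first_batch.shape)  # (4, 64, 64, 3)
```

Only one batch lives in memory at a time, which is why this pattern scales to datasets that do not fit in an array.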

In [0]:
# templates
DATA_DIR = 'data/celeba/'
IMGS_DIR = '{}images/'.format(DATA_DIR)
CSV_FILE = '{}list_attr_celeba.txt'.format(DATA_DIR)

In [0]:
#@title Data loader

import numpy as np
import cv2
import pandas as pd
from itertools import cycle
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from random import shuffle
import random

def download_celeba(data_dir, imgs_dir):
  %rm -r {data_dir}
  %mkdir {data_dir}
  %cd {data_dir}

  !pip install pydrive
  from shutil import unpack_archive
  # these classes allow you to request the Google drive API
  from pydrive.auth import GoogleAuth
  from pydrive.drive import GoogleDrive
  from oauth2client.client import GoogleCredentials

  from googleapiclient.http import MediaIoBaseDownload
  from google.colab import auth
  auth.authenticate_user()

  gauth = GoogleAuth()
  gauth.credentials = GoogleCredentials.get_application_default()
  drive = GoogleDrive(gauth)
  file_id_dict = {
      'imgs.zip': '0B7EVK8r0v71pZjFTYXZWM3FlRnM',
      'attributes.txt': '0B7EVK8r0v71pblRyaVFSWGxPY0U'
  }
  for file_, id_ in file_id_dict.items():
    downloaded = drive.CreateFile({'id': id_})
    downloaded.GetContentFile(file_)

#   %rm -r {imgs_dir}*
  unpack_archive('imgs.zip', imgs_dir)
  %rm imgs.zip

def get_imgs_lists(df, category):
  pos_img_names = df.index[df[category] == 1]
  neg_img_names = df.index[df[category] == -1]
  
  return list(pos_img_names), list(neg_img_names)

def preprocess_img(img, img_shape, crop=None):
  img = img / 127.5 - 1
  if crop is not None:
    cropx, cropy = crop
    y,x,_ = img.shape
    startx = x//2-(cropx//2)
    starty = y//2-(cropy//2)
    img = img[starty:starty+cropy,startx:startx+cropx,:]
  img = cv2.resize(img, img_shape[:2])
  
  return img

def single_core_category_generator(img_paths, img_shape):
  imgs = np.zeros((len(img_paths), *img_shape))
  for i, img_path in enumerate(img_paths):
    img = mpimg.imread(img_path)
    imgs[i] = preprocess_img(img, img_shape)
  return imgs


def mp_category_generator(imgs_dir, img_names, img_shape, batch_size):
  import multiprocessing as mp
  cpu_num = min(mp.cpu_count(), batch_size)
  
  def chunks(lst, n):
    for i in range(n):
        yield lst[i::n]
  
  shuffle(img_names)
  img_names = cycle(img_names)
  
  while True:
    imgs = np.zeros((batch_size, *img_shape))
    batch_paths = [imgs_dir + next(img_names) for _ in range(batch_size)]
    batch_paths = chunks(batch_paths, cpu_num)
    
    pool = mp.Pool(processes=cpu_num)
    imgs = [pool.apply(single_core_category_generator,
                       args=(next(batch_paths), img_shape)) for i in range(cpu_num)]
    pool.terminate()
    imgs = np.concatenate(imgs)
    
    yield imgs
    
def category_generator(imgs_dir, img_names, img_shape, batch_size, crop=None):  
  shuffle(img_names)
  img_names = cycle(img_names)
  
  while True:
    imgs = np.zeros((batch_size, *img_shape))
    for i in range(batch_size):
      img_path = imgs_dir + next(img_names)
      img = mpimg.imread(img_path)
      imgs[i] = preprocess_img(img, img_shape, crop=crop)
    
    yield imgs
 
def get_generators(img_shape, batch_size, category=None, download=True, crop=None):
  data_dir = 'data/'
  imgs_dir = '{}celeba/images/'.format(data_dir)

  if download:
    download_celeba(data_dir, imgs_dir)
  
  # skip the first line (image count); the file names become the index
  df = pd.read_csv('{}celeba/list_attr_celeba.txt'.format(data_dir),
                   sep=' +', skiprows=[0], engine='python')
  
  img_names = list(df.index)
  if category is None:
    gen = category_generator(imgs_dir, img_names, img_shape, batch_size, crop=crop)
    return gen
  
  pos_img_names, neg_img_names = get_imgs_lists(df, category)

  pos_gen = category_generator(imgs_dir, pos_img_names, img_shape, batch_size, crop=crop)
  neg_gen = category_generator(imgs_dir, neg_img_names, img_shape, batch_size, crop=crop)
  
  return pos_gen, neg_gen
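
The preprocessing inside the loader comes down to a center crop and a rescaling to the range of `tanh`. The sketch below reproduces both steps with NumPy only. The helper names are hypothetical, the random array stands in for a real photo, and the 218x178 size is an assumption about the aligned CelebA images.

```python
import numpy as np


def center_crop(img, crop_h, crop_w):
    """Cut a crop_h x crop_w window out of the middle of the image."""
    h, w = img.shape[:2]
    top = h // 2 - crop_h // 2
    left = w // 2 - crop_w // 2
    return img[top:top + crop_h, left:left + crop_w]


def to_tanh_range(img_uint8):
    """Map pixel values from [0, 255] to [-1, 1]."""
    return img_uint8.astype(np.float32) / 127.5 - 1.0


def from_tanh_range(img):
    """Map values from [-1, 1] back to [0, 1] for plotting."""
    return img / 2.0 + 0.5


# Random stand-in for one image (assumed size: 218 rows, 178 columns).
raw = np.random.RandomState(0).randint(0, 256, size=(218, 178, 3)).astype(np.uint8)
cropped = center_crop(raw, 150, 150)
scaled = to_tanh_range(cropped)
print(cropped.shape, scaled.min(), scaled.max())
```

The generator ends with a `tanh` activation, so real and generated images share the same value range.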

### Function to build the generator

In [0]:
def build_generator(noise_size, img_shape, num_classes):
  # block: Conv, Batch norm, Upsampling
  k_size = 5, 5
  k_init = RandomNormal(0, 0.01)
  filters = 1024 #CHANGE
  
  noise = Input((noise_size,))
  labels = Input((num_classes,))
  
  model_input = concatenate([noise, labels])
  
  x = Dense(4*4*filters, kernel_initializer=k_init, activation='relu')(model_input)
  x = Reshape((4, 4, filters))(x)  # 4, 4
  x = BatchNormalization()(x)
  x = UpSampling2D()(x)  # 8, 8
  
  x = Conv2D(filters // 2, k_size, padding='same', kernel_initializer=k_init, activation='relu')(x)
  x = BatchNormalization()(x)
  x = UpSampling2D()(x)  # 16, 16
  
  x = Conv2D(filters // 4, k_size, padding='same', kernel_initializer=k_init, activation='relu')(x)
  x = BatchNormalization()(x)
  x = UpSampling2D()(x)  # 32, 32
  
  #CHANGE
  x = Conv2D(filters // 8, k_size, padding='same', kernel_initializer=k_init, activation='relu')(x)
  x = BatchNormalization()(x)
  x = UpSampling2D()(x)  # 64, 64
  
  img = Conv2D(img_shape[-1], k_size, padding='same', kernel_initializer=k_init, activation='tanh')(x)
  
  generator = Model([noise, labels], img)
  return generator
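
To check that the generator really ends at 64x64, we can trace the tensor shapes by hand. This is only a bookkeeping sketch that assumes the layer order used above (`padding='same'` convolutions keep the spatial size, `UpSampling2D` doubles it); `trace_generator_shapes` is a hypothetical helper, not part of Keras.

```python
def trace_generator_shapes(filters=1024, channels=3):
    """Return (layer, output shape) pairs for the generator defined above."""
    shape = (4, 4, filters)
    trace = [("reshape", shape)]

    # the first upsampling comes directly after the reshape
    shape = (shape[0] * 2, shape[1] * 2, shape[2])
    trace.append(("upsample", shape))

    for divisor in (2, 4, 8):
        # 'same' convolution: spatial size unchanged, fewer filters
        shape = (shape[0], shape[1], filters // divisor)
        trace.append(("conv", shape))
        # upsampling: spatial size doubled, filters unchanged
        shape = (shape[0] * 2, shape[1] * 2, shape[2])
        trace.append(("upsample", shape))

    # final convolution maps to the image channels
    shape = (shape[0], shape[1], channels)
    trace.append(("conv_tanh", shape))
    return trace


for name, shape in trace_generator_shapes():
    print(name, shape)
```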

### Function to build the discriminator

In [0]:
def build_discriminator(img_shape, num_classes):
  # block: Conv, Batch norm, LeakyRelu
  k_size = 5, 5
  k_init = RandomNormal(0, 0.01)
  filters = 1024 #CHANGE
  
  
  img = Input(img_shape)  # 64, 64, 3
  labels = Input((num_classes,))  # num_classes
  
  # broadcast the labels to every pixel, one extra channel per class
  n_labels = Reshape((1, 1, -1))(labels)  # (batch_size), 1, 1, num_classes
  n_labels = Lambda(lambda x: K.tile(x, [1, img_shape[0], img_shape[1], 1]))(n_labels)  # (batch_size), 64, 64, num_classes
  model_input = concatenate([img, n_labels])  # 64, 64, 3 + num_classes
  
  #CHANGE
  x = Conv2D(filters // 8, k_size, strides=(2, 2), padding='same', kernel_initializer=k_init)(model_input)
  x = BatchNormalization()(x)
  x = LeakyReLU(alpha=0.2)(x)  #32, 32
  
  x = Conv2D(filters // 4, k_size, strides=(2, 2), padding='same', kernel_initializer=k_init)(x)#CHANGE, model_input -> x
  x = BatchNormalization()(x)
  x = LeakyReLU(alpha=0.2)(x)  #16, 16
  
  x = Conv2D(filters // 2, k_size, strides=(2, 2), padding='same', kernel_initializer=k_init)(x)
  x = BatchNormalization()(x)
  x = LeakyReLU(alpha=0.2)(x)  # 8, 8
  
  x = Conv2D(filters, k_size, strides=(2, 2), padding='same', kernel_initializer=k_init)(x)
  x = BatchNormalization()(x)
  x = LeakyReLU(alpha=0.2)(x)  # 4, 4
  
  x = Flatten()(x)
  validity = Dense(1, activation='linear', kernel_initializer=k_init)(x) #CHANGE: one critic score per image
  
  discriminator = Model([img, labels], validity)
  return discriminator
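
The way the discriminator receives the label can be reproduced with plain NumPy. The sketch below mirrors the `Reshape`, `K.tile` and `concatenate` steps on a toy batch; the batch of zeros is made up for illustration.

```python
import numpy as np

batch_size, height, width, channels, num_classes = 2, 64, 64, 3, 1

imgs = np.zeros((batch_size, height, width, channels), dtype=np.float32)
labels = np.array([[1.0], [0.0]], dtype=np.float32)  # one label per image

# (batch, num_classes) -> (batch, 1, 1, num_classes)
label_maps = labels.reshape(batch_size, 1, 1, num_classes)
# repeat the label at every pixel -> (batch, height, width, num_classes)
label_maps = np.tile(label_maps, (1, height, width, 1))
# append the label maps as extra channels
critic_input = np.concatenate([imgs, label_maps], axis=-1)

print(critic_input.shape)  # (2, 64, 64, 4)
```

Every pixel of the first image now carries the value 1.0 in its last channel, and every pixel of the second image carries 0.0, so each convolution sees the condition.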

### Function to compile the models

In [0]:
def critic_loss(y_true, y_pred):
  return K.mean(y_true * y_pred)
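
To see what this loss computes, here is a NumPy version evaluated on made-up critic scores. It assumes the sign convention used in the training loop of this notebook: real images get the target -1 and generated images get +1.

```python
import numpy as np


def critic_loss_np(y_true, y_pred):
    """NumPy counterpart of critic_loss: mean of target times score."""
    return float(np.mean(y_true * y_pred))


# Made-up critic scores for three real and three generated images.
real_scores = np.array([2.0, 1.0, 3.0])
fake_scores = np.array([-1.0, 0.0, -2.0])

real_term = critic_loss_np(-np.ones(3), real_scores)  # -mean(D(real)) = -2.0
fake_term = critic_loss_np(np.ones(3), fake_scores)   # +mean(D(fake)) = -1.0

# Minimising real_term + fake_term maximises mean(D(real)) - mean(D(fake)).
wasserstein_estimate = -(real_term + fake_term)
print(real_term, fake_term, wasserstein_estimate)  # -2.0 -1.0 3.0
```

The generator is trained with the target -1 on its own images, so it minimises -mean(D(fake)), pushing the critic's score for generated images up.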

In [0]:
def get_compiled_models(generator, discriminator, noise_size, num_classes):
  
  optimizer = RMSprop(0.0002)
  
  discriminator.compile(optimizer, loss=critic_loss)
  discriminator.trainable = False
  
  noise = Input((noise_size,))
  labels = Input((num_classes,))
  
  img = generator([noise, labels])
  validity = discriminator([img, labels])
  combined = Model([noise, labels], validity)
  
  combined.compile(optimizer, loss=critic_loss)
  
  return generator, discriminator, combined

### Function to sample and save generated images

In [0]:
#FIXME this is old code, unused
def sample_imgs(generator, noise_size, step, plot_img=True, cond=False, num_classes=10):
  np.random.seed(0)
  
  r, c = num_classes, 10
  if cond:
    noise = np.random.normal(0, 1, (c, noise_size))
    noise = np.tile(noise, (r, 1))

    sampled_labels = np.arange(r).reshape(-1, 1)
    sampled_labels = to_categorical(sampled_labels, r)
    sampled_labels = np.repeat(sampled_labels, c, axis=0)

    imgs = generator.predict([noise, sampled_labels])
  else:
    noise = np.random.normal(0, 1, (r*c, noise_size))
    imgs = generator.predict_on_batch(noise)
  
  imgs = imgs / 2 + 0.5
  imgs = np.reshape(imgs, [r, c, imgs.shape[1], imgs.shape[2], -1])
  
  figsize = 1 * c, 1 * r
  fig, axs = plt.subplots(r, c, figsize=figsize)
  
  for i in range(r):
    for j in range(c):
      img = imgs[i, j] if len(imgs.shape) == 4 else imgs[i, j, :, :, 0]
      axs[i, j].imshow(img, cmap='gray')
      axs[i, j].axis('off')
  plt.subplots_adjust(wspace=0.1, hspace=0.1)
  fig.savefig(f'/content/images/{step}.png')
  if plot_img:
    plt.show()
  plt.close()
  
  np.random.seed(None)

In [0]:
rows, cols = 4, 7
sampled_labels = np.array([0.0 if i >= rows*cols/2 else 1.0 for i in range(rows*cols)])

def sample_imgs_(generator, g_loss_buffer, noise_size, step):
  test_images = generator.predict([test_noise, sampled_labels])
  fig = plt.figure(1, figsize=(2*1.2*cols, 1.2*rows))
  gs = gridspec.GridSpec(rows, 2*cols)
  for j in range(rows*cols):
    plt.subplot(gs[j//cols, j%cols])#invert!
    plt.imshow(test_images[j]/2.0 + 0.5)  # index j so each image matches its label
    axs = plt.gca()
    if j >= rows*cols/2:
      axs.tick_params(axis=u'both', which=u'both',length=5)
      axs.set_xticks([])
      axs.set_yticks([])
    else:
      axs.axis('off')
  #plot error here
  plt.subplot(gs[:,cols+1:])
  plt.plot(g_loss_buffer)
  plt.grid(True)
  plt.subplots_adjust(wspace=0.1, hspace=0.1)
  fig.savefig('/content/images/{}.png'.format(step))
  plt.show()

### Function to train the models

Select how often (in training steps) a debugging image is saved.

In [0]:
SAVE_STEP = 5

In [0]:
def train(models, data_loader, noise_size, img_shape, num_classes, batch_size, steps):
  
  generator, discriminator, combined = models
  pos_loader, neg_loader = data_loader
  #CHANGE delete fashion mnist
  
  g_loss_buffer = []
  for step in range(1, steps + 1):
    for i in range(n_critic):
      # train discriminator
      if (step + i) % 2 == 0:
        real_imgs, labels = next(pos_loader), np.ones(batch_size)
      else:
        real_imgs, labels = next(neg_loader), np.zeros(batch_size)
      
      noise = np.random.normal(0, 1, (batch_size, noise_size))
      gen_imgs = generator.predict([noise, labels])

      gen_validity = np.ones(batch_size)
      real_validity = - np.ones(batch_size)

      r_loss = discriminator.train_on_batch([real_imgs, labels], real_validity)
      g_loss = discriminator.train_on_batch([gen_imgs, labels], gen_validity)
      disc_loss = np.add(r_loss, g_loss) / 2
        
      # clipping (after every critic update, as in the WGAN algorithm)
      for layer in discriminator.layers:
        weights = layer.get_weights()
        clipped_weights = [np.clip(w, -c, c) for w in weights]
        layer.set_weights(clipped_weights)
      
    # train generator
    noise = np.random.normal(0, 1, (batch_size, noise_size))
    gen_loss = combined.train_on_batch([noise, labels], -np.ones(batch_size))
    g_loss_buffer.append(gen_loss)
    
    #print progress
    if step % SAVE_STEP == 0:
      print('step: %d, D_loss: %f, G_loss: %f' % (step, disc_loss, gen_loss))
    
    # save_samples
    if step % SAVE_STEP == 0:
      sample_imgs_(generator, g_loss_buffer, noise_size, step)
      
    # save model
    if step % 1000 == 0:
      generator.save('faces_g_step{}.h5'.format(step))
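
The clipping step is the part that is specific to WGAN, so here it is in isolation. This is a NumPy sketch on made-up weight arrays; `clip_weights` is a hypothetical helper that does what the loop over `discriminator.layers` does for a single layer.

```python
import numpy as np


def clip_weights(weights, c):
    """Clamp every value of every weight array to the interval [-c, c]."""
    return [np.clip(w, -c, c) for w in weights]


# Made-up kernel and bias of one layer.
weights = [np.array([[0.5, -0.003], [0.02, -0.2]]), np.array([0.0, 0.011])]
clipped = clip_weights(weights, c=0.01)

print(clipped[0])  # values outside [-0.01, 0.01] are moved to the boundary
print(clipped[1])
```

Values already inside the interval, such as -0.003, are left untouched. Keeping the weights in a small box is the crude way in which WGAN bounds how fast the critic's output can change.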

### Define hyperparameters

In [0]:
%rm -r /content/images
%mkdir /content/images
noise_size = 100
img_shape = 64, 64, 3 #CHANGE
num_classes = 1 #CHANGE
batch_size = 32
steps = 100000

c = 0.01
n_critic = 5

#CHANGE
category = 'Male'
data_loader = get_generators(img_shape, batch_size, category, download=False, crop=(150, 150))
test_noise = np.random.normal(size=(rows*cols, noise_size))

### Generate the models

In [0]:
generator = build_generator(noise_size, img_shape, num_classes)
discriminator = build_discriminator(img_shape, num_classes)
compiled_models = get_compiled_models(generator, discriminator, noise_size, num_classes)

### Train the models

In [0]:
train(compiled_models, data_loader, noise_size, img_shape, num_classes, batch_size, steps)

## Plot results

### Display samples

Let's start by checking the images that we have stored.

In [0]:
%ls /content/images

You can check any image you wish by doing:

In [0]:
image_number = SAVE_STEP
ipyImage('/content/images/%d.png' % image_number)

### Do an animation

Probably the best way to show the training process is to build an animation from all the saved images. The next cells do it for you.

In [0]:
path = '/content/images/{}.png'

In [0]:
class AnimObject(object):
    def __init__(self, images):
        print(len(images))
        self.fig, self.ax = plt.subplots()
        self.ax.set_title("")
        self.fig.set_size_inches((20, 10))
        self.plot = plt.imshow(images[0])
        plt.tight_layout()
        self.images = images
        
    def init(self):
        self.plot.set_data(self.images[0])
        self.ax.grid(False)
        return (self.plot,)
      
    def animate(self, i):
        self.plot.set_data(self.images[i])
        self.ax.grid(False)
        self.ax.set_xticks([])
        self.ax.set_yticks([])
        self.ax.set_title("index {}".format(i))
        return (self.plot,)

def get_figures(template, indices):
    import os.path
    images = []
    for index in indices:
        if os.path.isfile(template.format(index)):
            images.append(Image.open(template.format(index)))
    return images


images = get_figures("/content/images/{}.png", 
                     range(0, SAVE_STEP * len(listdir('/content/images')) + 1, SAVE_STEP))
print(images)
animobject = AnimObject(images)
anim = animation.FuncAnimation(
              animobject.fig,
              animobject.animate,
              frames=len(animobject.images),
              interval=150,
              blit=True)

In [0]:
HTML(anim.to_jshtml())

## Download images and generator

This cell saves the trained generator and downloads it to your computer. You can skip it if you don't need a local copy.

In [0]:
gen_path = '/content/faces_cond_w_dcgan_gen.h5'
generator.save(gen_path)
from google.colab import files
files.download(gen_path)