<a href="https://colab.research.google.com/github/GarlandZhang/hairy_gan/blob/master/hairy_gan.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from google.colab import files
import os

import pandas as pd
import os
import shutil
if not os.path.exists('kaggle.json'):
  shutil.copy('/content/drive/My Drive/hairy_gan/kaggle.json', 'kaggle.json')
  # !pip install -q kaggle
  # files.upload()
  !mkdir -p ~/.kaggle
  !cp kaggle.json ~/.kaggle/
  !kaggle datasets download -d jessicali9530/celeba-dataset --force
  !unzip celeba-dataset.zip
  !mv img_align_celeba celeba-dataset
  !mv list_eval_partition.csv celeba-dataset/list_eval_partition.csv
  !mv list_landmarks_align_celeba.csv celeba-dataset/list_landmarks_align_celeba.csv
  !mv list_attr_celeba.csv celeba-dataset/list_attr_celeba.csv
  !mv list_bbox_celeba.csv celeba-dataset/list_bbox_celeba.csv

  !mkdir celeba-dataset/train
  !mkdir celeba-dataset/validation
  !mkdir celeba-dataset/test

  complete_df = pd.read_csv('celeba-dataset/list_attr_celeba.csv')
  partitions_df = pd.read_csv('celeba-dataset/list_eval_partition.csv') # 0 => train, 1 => validation, 2 => test
  for i, set_name in enumerate(['train', 'validation', 'test']):
    set_ids_df = partitions_df.loc[partitions_df['partition'] == i]['image_id']
    set_ids = set_ids_df.tolist()
    for id in set_ids:
      shutil.copy(os.path.join('celeba-dataset/img_align_celeba', id), os.path.join('celeba-dataset', f'{set_name}', id))

  !git clone https://www.github.com/keras-team/keras-contrib.git \
    && cd keras-contrib \
    && pip install git+https://www.github.com/keras-team/keras-contrib.git \
    && python convert_to_tf_keras.py \
    && USE_TF_KERAS=1 python setup.py install

  !pip install scipy==1.1.0

In [None]:
from __future__ import print_function, division
import scipy
from keras.datasets import mnist
from keras.models import Model, Sequential
from keras_contrib.layers.normalization.instancenormalization import InstanceNormalization
from keras.layers import Input, Dense, Reshape, Flatten, Dropout, Concatenate
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.optimizers import Adam
from keras.models import load_model

import datetime
import matplotlib.pyplot as plt
import sys
import numpy as np
import os
from glob import glob
from PIL import Image

import tensorflow as tf
from tensorflow.python.keras.backend import set_session
# from tensorflow.python.keras.models import load_model
tf.compat.v1.enable_eager_execution()
# tf.compat.v1.disable_eager_execution()

from tqdm import tqdm

Using TensorFlow backend.


In [None]:
class DataLoader():
    def __init__(self, dataset_name, img_res):
        self.dataset_name = dataset_name
        self.img_res = img_res

    def load_data(self, dataset_type, batch_size=1):
        data_type = dataset_type
        path = glob('%s/%s/*' % (self.dataset_name, data_type))

        batch_images = np.random.choice(path, size=batch_size)

        imgs = []
        for img_path in batch_images:
            img = self.imread(img_path)
            if not is_testing:
                img = scipy.misc.imresize(img, self.img_res)

                if np.random.random() > 0.5:
                    img = np.fliplr(img)
            else:
                img = scipy.misc.imresize(img, self.img_res)
            imgs.append(img)

        imgs = np.array(imgs)/127.5 - 1.

        return imgs

    def load_batch(self, batch_size=1, is_testing=False):
        data_type = "train" if not is_testing else "val"
        path = glob('%s/%s/*' % (self.dataset_name, data_type))

        self.n_batches = int(len(path) / batch_size)
        total_samples = self.n_batches * batch_size

        path = np.random.choice(path, total_samples, replace=False)

        for i in range(self.n_batches-1):
            batch = path[i*batch_size:(i+1)*batch_size]
            imgs = []
            for img_path in batch:
                img = self.imread(img_path)

                img = scipy.misc.imresize(img, self.img_res)

                if not is_testing and np.random.random() > 0.5:
                        img = np.fliplr(img)

                imgs.append(img)

            imgs = np.array(imgs)/127.5 - 1.

            yield imgs

    def imread(self, path):
        return scipy.misc.imread(path, mode='RGB').astype(np.float)

In [8]:
def build_encoder(img_shape, num_filters=64, kernel_size=4, strides=2):
  def build_conv(x, num_filters, kernel_size, strides):
    x = Conv2D(num_filters, kernel_size=kernel_size, strides=strides, padding='same')(x)
    x = BatchNormalization()(x)
    x = LeakyReLU()(x)
    return x

  img = Input(shape=img_shape)
  x = build_conv(img, num_filters, kernel_size, strides)
  x = build_conv(x, num_filters * 2, kernel_size, strides)
  x = build_conv(x, num_filters * 4, kernel_size, strides)
  x = build_conv(x, num_filters * 8, kernel_size, strides)
  x = build_conv(x, num_filters * 16, kernel_size, strides)

  return Model(img, x)

def build_embedding(img, img_shape, attr_size):
  label = Input(shape=(attr_size, ), dtype='int32')
  label_embedding = Embedding(2, np.prod(img_shape), input_length=attr_size)(label)
  label_embedding = Flatten()(label_embedding)
  label_embedding = Reshape(img_shape)(label_embedding)
  emb_img = Concatenate(axis=-1)([img, label_embedding])
  return emb_img

def build_decoder(img_shape, attr_size, num_filters=64, kernel_size=4, strides=1):
  def build_deconv(x, num_filters, kernel_size, strides):
    x = UpSampling2D(size=2)(x)
    x = Conv2D(num_filters, kernel_size=kernel_size, strides=strides, padding='same')(x)
    x = BatchNormalization()(x)
    x = LeakyReLU()(x)
    return x

  img = Input(shape=img_shape)

  emb_img = build_embedding(img, img_shape, attr_size)

  x = build_deconv(emb_img, num_filters * 16, kernel_size=kernel_size, strides=strides)
  x = build_deconv(x, num_filters * 8, kernel_size=kernel_size, strides=strides)
  x = build_deconv(x, num_filters * 4, kernel_size=kernel_size, strides=strides)
  x = build_deconv(x, num_filters * 2, kernel_size=kernel_size, strides=strides)
  x = UpSampling2D(size=2)(x)
  x = Conv2D(3, kernel_size=kernel_size, strides=strides, padding='same', activation='tanh')(x)

  return Model([img, label], x)

def build_convnet(img, num_filters=64, kernel_size=4, strides=2):
  def build_conv(x, num_filters, kernel_size, strides):
    x = Conv2D(num_filters, kernel_size=kernel_size, strides=strides, padding='same')(x)
    x = InstanceNormalization()(x)
    x = LeakyReLU()(x)
    return x
  
  x = build_conv(img, num_filters, kernel_size, strides)
  x = build_conv(x, num_filters * 2, kernel_size, strides)
  x = build_conv(x, num_filters * 4, kernel_size, strides)
  x = build_conv(x, num_filters * 8, kernel_size, strides)
  x = build_conv(x, num_filters * 16, kernel_size, strides)
  x = Flatten()(x)
  x = Dense(num_filters * 1024)
  x = InstanceNormalization()(x)
  x = LeakyReLU()(x)

  return x

def build_discriminator(img_shape):
  img = Input(shape=img_shape)
  emb_img = build_embedding(img, img_shape, attr_size)
  x = build_convnet(emb_img)
  x = Dense(1)
  return Model(img, x)

def build_classifier(img_shape, attr_size):
  img = Input(shape=img_shape)
  x = build_convnet(img)
  x = Dense(attr_size, activation='sigmoid')
  return Model(img, x)

def build_combined(img_shape, attr_size, genc, gdec, d, c):
  d.trainable = False

  x_a = Input(shape=img_shape) # original image
  a = Input(shape=(attr_size, )) # original attr
  b = Input(shape=(attr_size, )) # requested attr
  
  z = genc(x_a) # latent space representation of original image
  x_b = gdec([z, b]) # image with requested attr

  b_hat = c(x_b) # guess of the requested features

  valid = d([x_b, x_a]) # guess if real or fake

  x_a_hat = gdec([z, a]) # reconstr

  combined = Model(
      inputs=[x_a, a, b],
      outputs=[x_b, b_hat, valid, x_a_hat]
  )

  return combined

(2, 1)

In [None]:
class HairyGan(): # based on AttGan
  def __init__(self):

    self.img_rows = 128
    self.img_cols = 128
    self.img_channels = 3

    self.img_shape = (self.img_rows, self.img_cols, self.img_channels)
    
    patch = int(self.img_rows / 2**4)
    self.disc_out = (patch, patch, 1) # output shape of discriminator

    self.data_loader = DataLoader(dataset_name='celeba-dataset', img_res=(self.img_rows, self.img_cols))

    self.optimizer = Adam(learning_rate=0.0002, beta_1=0.5, beta_2=0.999)

  def train(self):
    genc = build_encoder(self.img_shape)
    

  ## TODO: re-implement this method without domains
  def sample_images(self, epoch, batch_i):
    print(f'Epoch: {epoch} with batch: {batch_i}')
    rows, cols = 2, 3

    imgs_A = self.data_loader.load_data(batch_size=1, is_testing=True)
    imgs_B = self.data_loader.load_data(domain='B', batch_size=1, is_testing=True)

    fake_A = self.g_BA.predict(imgs_B)
    fake_B = self.g_AB.predict(imgs_A)

    reconstr_A = self.g_BA.predict(fake_B)
    reconstr_B = self.g_AB.predict(fake_A)

    gen_imgs = np.concatenate([imgs_A, fake_B, reconstr_A, imgs_B, fake_A, reconstr_B])

    gen_imgs = 0.5 * gen_imgs + 0.5

    titles = ['Original', 'Translated', 'Reconstructed']
    fig, axes = plt.subplots(rows, cols)

    count = 0

    for i in range(rows):
      for j in range(cols):
        axes[i, j].imshow(gen_imgs[count])
        axes[i, j].set_title(titles[j])
        axes[i, j].axis('off')
        count += 1

    plt.show()
    

In [None]:
project_path = '/content/drive/My Drive/hairy_gan'
num_attributes = 40
gan = HairyGan()
gan.train(epochs=100, batch_size=64, sample_interval=100)

`imread` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.
Use ``imageio.imread`` instead.
`imresize` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.
Use ``skimage.transform.resize`` instead.
`imresize` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.
Use ``skimage.transform.resize`` instead.
  'Discrepancy between trainable weights and collected trainable'


Epoch 1/1
Epoch 1/1
Epoch 1/1
Epoch 1/1
Epoch 1/1


Train 0 / 100:   0%|          | 1/466 [00:30<3:55:21, 30.37s/it]

Epoch 1/1


FailedPreconditionError: ignored

In [None]:
gan.g_AB = load_model('/content/drive/My Drive/hairy_gan/g_AB.hdf5', custom_objects={'InstanceNormalization': InstanceNormalization})
gan.g_BA = load_model('/content/drive/My Drive/hairy_gan/g_BA.hdf5', custom_objects={'InstanceNormalization': InstanceNormalization})

In [None]:
gan.sample_images()