<a href="https://colab.research.google.com/github/GarlandZhang/hairy_gan/blob/master/hairy_gan.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from google.colab import files
import os

import pandas as pd
import os
import shutil
if not os.path.exists('kaggle.json'):
  shutil.copy('/content/drive/My Drive/hairy_gan/kaggle.json', 'kaggle.json')
  # !pip install -q kaggle
  # files.upload()
  !mkdir -p ~/.kaggle
  !cp kaggle.json ~/.kaggle/
  !kaggle datasets download -d jessicali9530/celeba-dataset --force
  !unzip celeba-dataset.zip
  !mv img_align_celeba celeba-dataset
  !mv list_eval_partition.csv celeba-dataset/list_eval_partition.csv
  !mv list_landmarks_align_celeba.csv celeba-dataset/list_landmarks_align_celeba.csv
  !mv list_attr_celeba.csv celeba-dataset/list_attr_celeba.csv
  !mv list_bbox_celeba.csv celeba-dataset/list_bbox_celeba.csv

  !mkdir celeba-dataset/train
  !mkdir celeba-dataset/validation
  !mkdir celeba-dataset/test

  partitions_df = pd.read_csv('celeba-dataset/list_eval_partition.csv') # 0 => train, 1 => validation, 2 => test
  for i, set_name in enumerate(['train', 'validation', 'test']):
    set_ids_df = partitions_df.loc[partitions_df['partition'] == i]['image_id']
    set_ids = set_ids_df.tolist()
    for id in set_ids:
      shutil.copy(os.path.join('celeba-dataset/img_align_celeba', id), os.path.join('celeba-dataset', f'{set_name}', id))

  !git clone https://www.github.com/keras-team/keras-contrib.git \
    && cd keras-contrib \
    && pip install git+https://www.github.com/keras-team/keras-contrib.git \
    && python convert_to_tf_keras.py \
    && USE_TF_KERAS=1 python setup.py install

  !pip install scipy==1.1.0

In [2]:
from __future__ import print_function, division
import scipy
from keras.datasets import mnist
from keras.models import Model, Sequential
from keras_contrib.layers.normalization.instancenormalization import InstanceNormalization
from keras.layers import Input, Dense, Reshape, Flatten, Dropout, Concatenate, Embedding
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.optimizers import Adam
from keras.models import load_model, save_model

import datetime
import matplotlib.pyplot as plt
import sys
import numpy as np
import os
from glob import glob
from PIL import Image

import tensorflow as tf
from tensorflow.python.keras.backend import set_session, clear_session
# from tensorflow.python.keras.models import load_model
# tf.compat.v1.disable_eager_execution()
tf.compat.v1.disable_v2_behavior()
# tf.compat.v1.enable_eager_execution()

from tqdm import tqdm

Using TensorFlow backend.


Instructions for updating:
non-resource variables are not supported in the long term


In [19]:
!rm -r celeba-dataset/train_filter
!mkdir celeba-dataset/train_filter
# extract images of particular class for training
num_images_each = 5
feature = 'Bald'
complete_df = pd.read_csv('celeba-dataset/list_attr_celeba.csv')
for img_id in complete_df.loc[complete_df[feature] == 1][:num_images_each].filter(['image_id']).to_numpy():
  img_id = img_id[0]
  shutil.copy(f'celeba-dataset/train/{img_id}', f'celeba-dataset/train_filter/{img_id}')

for img_id in complete_df.loc[complete_df[feature] == -1][:num_images_each].filter(['image_id']).to_numpy():
  img_id = img_id[0]
  shutil.copy(f'celeba-dataset/train/{img_id}', f'celeba-dataset/train_filter/{img_id}')

In [28]:
class DataLoader():
    def __init__(self, dataset_name, img_res):
        self.dataset_name = dataset_name
        self.img_res = img_res
        self.complete_df = pd.read_csv('celeba-dataset/list_attr_celeba.csv')
        # self.features = ['Bald', 'Bangs', 'Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Bushy_Eyebrows', 'Eyeglasses', 'Gender', 'Mouth_Open', 'Mustache', 'No_Beard', 'Pale_Skin', 'Age']
        # self.num_attrs = 9 # should equal to length of self.features
        self.features = ['Bald']
        self.num_attrs = len(self.features)

    def load_data(self, dataset_type, batch_size=1, is_testing=False):
        data_type = dataset_type
        path = glob('%s/%s/*' % (self.dataset_name, data_type))

        batch_images = np.random.choice(path, size=batch_size)

        imgs = []
        attribs = []
        
        for img_path in batch_images:
            img = self.imread(img_path)
            if not is_testing:
                img = scipy.misc.imresize(img, self.img_res)

                if np.random.random() > 0.5:
                    img = np.fliplr(img)
            else:
                img = scipy.misc.imresize(img, self.img_res)
            imgs.append(img)

            # get attributes

            img_attribs = [(val + 1) // 2 for val in self.complete_df.loc[self.complete_df['image_id'] == os.path.basename(img_path)].filter(items=self.features).to_numpy()[0]]

            attribs.append(img_attribs)

        imgs = np.array(imgs)/127.5 - 1.
        attribs = np.array(attribs)

        return imgs, attribs

    def load_batch(self, batch_size=1, is_testing=False, is_filter=False):
        if is_filter:
          data_type = 'train_filter'
        elif is_testing:
          data_type = 'test'
        else:
          data_type = 'train'
        path = glob('%s/%s/*' % (self.dataset_name, data_type))

        self.n_batches = int(len(path) / batch_size)
        total_samples = self.n_batches * batch_size

        path = np.random.choice(path, total_samples, replace=False)

        i = 0
        while i < self.n_batches - 1:
            batch = path[i*batch_size:(i+1)*batch_size]
            imgs = []
            attribs = []
            for img_path in batch:
                img = self.imread(img_path)

                img = scipy.misc.imresize(img, self.img_res)

                if not is_testing and np.random.random() > 0.5:
                        img = np.fliplr(img)

                imgs.append(img)

                # get attributes

                img_attribs = np.array([(val + 1) // 2 for val in self.complete_df.loc[self.complete_df['image_id'] == os.path.basename(img_path)].filter(items=self.features).to_numpy()[0]])

                attribs.append(img_attribs)

            imgs = np.array(imgs)/127.5 - 1.
            attribs = np.array(attribs)

            i += 1
            if i == self.n_batches - 1:
              # reset
              path = glob('%s/%s/*' % (self.dataset_name, data_type))
              path = np.random.choice(path, total_samples, replace=False)
              i = 0

            yield imgs, attribs

    def imread(self, path):
        return scipy.misc.imread(path, mode='RGB').astype(np.float)

In [5]:
def build_encoder(img_shape, num_filters=64, kernel_size=4, strides=2):
  def build_conv(x, num_filters, kernel_size, strides):
    x = Conv2D(num_filters, kernel_size=kernel_size, strides=strides, padding='same')(x)
    x = BatchNormalization()(x)
    x = LeakyReLU()(x)
    return x

  img = Input(shape=img_shape)
  x = build_conv(img, num_filters, kernel_size, strides)
  x = build_conv(x, num_filters * 2, kernel_size, strides)
  x = build_conv(x, num_filters * 4, kernel_size, strides)
  x = build_conv(x, num_filters * 8, kernel_size, strides)
  x = build_conv(x, num_filters * 16, kernel_size, strides)
  # x.name = 'encoder_output'

  model = Model(img, x, name='encoder')

  model.summary()

  return model

def build_embedding(img, label, input_shape, attr_size):
  label_embedding = Embedding(2, np.prod(input_shape), input_length=attr_size)(label)
  # style_embedding = Embedding(2, np.prod(input_shape), input_length=attr_size)(style)
  # label_style_embedding = Add()([label_embedding, style_embedding])
  # label_style_embedding = Reshape(input_shape[:-1] + (attr_size * input_shape[-1], ))(label_style_embedding)
  # emb_img = Concatenate(axis=-1)([img, label_style_embedding])
  label_embedding = Reshape(input_shape[:-1] + (attr_size * input_shape[-1], ))(label_embedding)
  emb_img = Concatenate(axis=-1)([img, label_embedding])
  return emb_img

def build_decoder(latent_space_shape, attr_size, num_filters=64, kernel_size=4, strides=1):
  def build_deconv(x, num_filters, kernel_size, strides):
    x = UpSampling2D(size=2)(x)
    x = Conv2D(num_filters, kernel_size=kernel_size, strides=strides, padding='same')(x)
    x = BatchNormalization()(x)
    x = LeakyReLU()(x)
    return x

  img = Input(shape=latent_space_shape)
  label = Input(shape=(attr_size, ), dtype='int32')

  emb_img = build_embedding(img, label, latent_space_shape, attr_size)

  x = build_deconv(emb_img, num_filters * 16, kernel_size=kernel_size, strides=strides)
  x = build_deconv(x, num_filters * 8, kernel_size=kernel_size, strides=strides)
  x = build_deconv(x, num_filters * 4, kernel_size=kernel_size, strides=strides)
  x = build_deconv(x, num_filters * 2, kernel_size=kernel_size, strides=strides)
  x = UpSampling2D(size=2)(x)
  x = Conv2D(3, kernel_size=kernel_size, strides=strides, padding='same', activation='tanh')(x)
  # x.name = 'decoder_output'

  model = Model([img, label], x, name='decoder')

  model.summary()

  return model

def build_convnet(img, num_filters=64, kernel_size=4, strides=2):
  def build_conv(x, num_filters, kernel_size, strides):
    x = Conv2D(num_filters, kernel_size=kernel_size, strides=strides, padding='same')(x)
    x = InstanceNormalization()(x)
    x = LeakyReLU()(x)
    return x
  
  x = build_conv(img, num_filters, kernel_size, strides)
  x = build_conv(x, num_filters * 2, kernel_size, strides)
  x = build_conv(x, num_filters * 4, kernel_size, strides)
  x = build_conv(x, num_filters * 8, kernel_size, strides)
  x = build_conv(x, num_filters * 16, kernel_size, strides)
  x = Flatten()(x)
  x = Dense(1024)(x)
  x = InstanceNormalization()(x)
  x = LeakyReLU()(x)

  return x

def build_dc(img_shape, attr_size, optimizer): # NOTE: we ignore inputting original image to discriminator head. why? cause im not sure if its important
  img = Input(shape=img_shape)
  label = Input(shape=(attr_size, ), dtype='int32') # I don't understand. why do we have this?

  # emb_img = build_embedding(img, label, img_shape, attr_size)
  # x = build_convnet(emb_img)
  x = build_convnet(img)
  disc_output = Dense(1, name='disc_output')(x)
  classif_output = Dense(attr_size, activation='sigmoid', name='classif_output')(x)

  dc = Model([img, label], [disc_output, classif_output], name='dc')

  dc.compile(loss=['binary_crossentropy', 'binary_crossentropy'], loss_weights=[1, 1], optimizer=optimizer)

  dc.summary()

  return dc

def build_combined_generator(img_shape, attr_size, genc, gdec, dc, optimizer):
  dc.trainable = False

  x_a = Input(shape=img_shape) # original image
  a = Input(shape=(attr_size, )) # original attr
  b = Input(shape=(attr_size, )) # requested attr
  
  z = genc(x_a) # latent space representation of original image
  x_b = gdec([z, b]) # image with requested attr

  valid, b_hat = dc([x_b, b]) # guess real or fake and guess the requested features 

  x_a_hat = gdec([z, a]) # reconstr

  combined = Model(
      inputs=[x_a, a, b],
      outputs=[b_hat, valid, x_a_hat],
      name='combined'
  )

  combined.compile(loss=['binary_crossentropy', 'binary_crossentropy', 'mae'], loss_weights=[10, 1, 100], optimizer=optimizer)

  combined.summary()

  return combined

In [6]:
def shuffle(elems):
  new_elems = elems.copy()
  np.random.shuffle(new_elems)
  return new_elems

def create_random_attrs(attrs, count):
  attr_size = attrs[0].size
  
  new_attrs = np.random.randint(0, 2, size=(count, attr_size))
  for r in range(count):
    for c in range(attr_size):
      if attrs[r, c] == 1 and new_attrs[r, c] == 0:
        new_attrs[r, c] = 1
  
  return new_attrs

In [16]:
def train_dc_step(batch_size, gen_batch, genc, gdec, dc):
  imgs, attrs, new_attrs = gen_batch

  real = np.ones((batch_size, 1))
  fake = np.zeros((batch_size, 1))

  x_a = imgs
  a = attrs
  b = new_attrs

  z = genc.predict(x_a)
  x_b_hat = gdec.predict([z, b])

  dc_real_history = dc.fit([x_a, a], [real, a])
  dc_fake_history = dc.fit([x_b_hat, b], [fake, b])

  return dc_real_history, dc_fake_history

def train_encdec_step(batch_size, gen_batch, combined):
  imgs, attrs, new_attrs = gen_batch

  real = np.ones((batch_size, 1))
  fake = np.zeros((batch_size, 1))

  x_a = imgs
  a = attrs
  b = new_attrs

  g_real_history = combined.fit([x_a, a, b], [b, fake, x_a])

  return g_real_history

In [17]:
class HairyGan(): # based on AttGan
  def __init__(self, flags={}):

    self.img_rows = 128
    self.img_cols = 128
    self.img_channels = 3

    self.img_shape = (self.img_rows, self.img_cols, self.img_channels)
    
    patch = int(self.img_rows / 2**4)
    self.disc_out = (patch, patch, 1) # output shape of discriminator

    self.dl = DataLoader(dataset_name='celeba-dataset', img_res=(self.img_rows, self.img_cols))

    self.optimizer = Adam(learning_rate=0.0002, beta_1=0.5, beta_2=0.999)

    self.flags = flags

    if flags['filter_on']:
      self.enc = build_encoder(self.img_shape)
      self.dec = build_decoder((4, 4, 1024), self.dl.num_attrs)
      self.dc = build_dc(self.img_shape, self.dl.num_attrs, self.optimizer)
      self.combined = build_combined_generator(self.img_shape, self.dl.num_attrs, self.enc, self.dec, self.dc, self.optimizer)   

    else:
      if os.path.exists(os.path.join(project_path, 'enc.h5')):
        print('Loading encoder from file')
        self.enc = load_model(os.path.join(project_path, 'enc.h5'), custom_objects={'InstanceNormalization': InstanceNormalization})
      else:
        self.enc = build_encoder(self.img_shape)
      
      if os.path.exists(os.path.join(project_path, 'dec.h5')):
        print('Loading decoder from file')
        self.dec = load_model(os.path.join(project_path, 'dec.h5'), custom_objects={'InstanceNormalization': InstanceNormalization})
      else:
        self.dec = build_decoder((4, 4, 1024), self.dl.num_attrs)

      self.dc = build_dc(self.img_shape, self.dl.num_attrs, self.optimizer)
      
      if os.path.exists(os.path.join(project_path, 'dc.weights')):
        print('Loading dc from file')
        self.dc.load_weights(os.path.join(project_path, 'dc.weights'))

      self.combined = build_combined_generator(self.img_shape, self.dl.num_attrs, self.enc, self.dec, self.dc, self.optimizer)   

    self.metrics = {}

  def train(self, num_epochs, batch_size, visualize_interval):
    # set up data loader
    batch_gen = self.dl.load_batch(batch_size=batch_size, is_filter=self.flags['filter_on'])
    for i, elem in enumerate(batch_gen):
      break
  
    num_batches = self.dl.n_batches
    steps_per_epoch = num_batches

    for epoch in range(num_epochs):
      for step in tqdm(range(steps_per_epoch), desc=f'Train {epoch} / {num_epochs}', total=steps_per_epoch):
        imgs, attrs = next(batch_gen)
        new_attrs = create_random_attrs(attrs, batch_size)
        gen_batch = (imgs, attrs, new_attrs)

        dc_real_history, dc_fake_history = train_dc_step(batch_size, gen_batch, self.enc, self.dec, self.dc)
        g_real_history = train_encdec_step(batch_size, gen_batch, self.combined)

        dc_real_history.history['dc_real_loss'] = dc_real_history.history.pop('loss')
        dc_fake_history.history['dc_fake_loss'] = dc_fake_history.history.pop('loss')
        g_real_history.history['g_real_loss'] = g_real_history.history.pop('loss')

        self.metrics = add_metrics(self.metrics, [dc_real_history, dc_fake_history, g_real_history])

        if (step + 1) % visualize_interval == 0:
          try:
            self.sample_images(epoch, step, is_filter=self.flags['filter_on'])
            
            # save models
            # save_model(self.enc, 'enc.h5')
            # shutil.copy('enc.h5', os.path.join(project_path, 'enc.h5'))

            # save_model(self.dec, 'dec.h5')
            # shutil.copy('dec.h5', os.path.join(project_path, 'dec.h5'))

            # self.dc.save_weights('dc.weights')
            # shutil.copy('dc.weights', os.path.join(project_path, 'dc.weights'))

            # self.combined.save_weights('combined.weights')
            # shutil.copy('combined.weights', os.path.join(project_path, 'combined.weights'))

            # visualize loss/accuracy
            visualize_metrics(self.metrics)
          except Exception as e:
            print(e)

  def sample_images(self, epoch, batch_i, is_filter=False):
    print(f'Epoch: {epoch} with batch: {batch_i}')
    rows, cols = 2, 3

    imgs, attrs = self.dl.load_data('test' if not is_filter else 'train_filter', batch_size=2, is_testing=True)

    new_attrs = create_random_attrs(attrs, attrs.shape[0])

    encodings = self.enc.predict(imgs)

    reconstrs = self.dec.predict([encodings, attrs])

    new_imgs = self.dec.predict([encodings, new_attrs])
    # combined.predict([imgs, attrs, new_attrs]) 

    gen_imgs = np.array([imgs[0], new_imgs[0], reconstrs[0], imgs[1], new_imgs[1], reconstrs[1]])

    gen_imgs = 0.5 * gen_imgs + 0.5

    titles = ['Original', 'Translated', 'Reconstructed']
    fig, axes = plt.subplots(rows, cols)

    count = 0

    for i in range(rows):
      for j in range(cols):
        axes[i, j].imshow(gen_imgs[count])
        axes[i, j].set_title(titles[j])
        axes[i, j].axis('off')
        count += 1

    plt.show()
    

In [9]:
def add_metrics(metrics, histories):
  for history in histories:
    for k, v in history.history.items():
      if metrics.get(k) is None:
        metrics[k] = v
      else:
        metrics[k].append(v[0]) # array of 1 elem => elem
  return metrics

In [10]:
def visualize_metrics(metrics):
  num_plots = len(metrics.keys())

  fig, axes = plt.subplots(num_plots)

  for pl, (title, values) in enumerate(metrics.items()):
    axes[pl].plot(values)
    axes[pl].set_title(title)

  plt.show()

In [None]:
project_path = '/content/drive/My Drive/hairy_gan'
flags = { 'filter_on': True }
gan = HairyGan(flags)
gan.train(num_epochs=30, batch_size=2, visualize_interval=10)

Model: "encoder"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_62 (InputLayer)        (None, 128, 128, 3)       0         
_________________________________________________________________
conv2d_111 (Conv2D)          (None, 64, 64, 64)        3136      
_________________________________________________________________
batch_normalization_64 (Batc (None, 64, 64, 64)        256       
_________________________________________________________________
leaky_re_lu_112 (LeakyReLU)  (None, 64, 64, 64)        0         
_________________________________________________________________
conv2d_112 (Conv2D)          (None, 32, 32, 128)       131200    
_________________________________________________________________
batch_normalization_65 (Batc (None, 32, 32, 128)       512       
_________________________________________________________________
leaky_re_lu_113 (LeakyReLU)  (None, 32, 32, 128)       0   

`imread` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.
Use ``imageio.imread`` instead.
`imresize` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.
Use ``skimage.transform.resize`` instead.



Train 0 / 30:   0%|          | 0/5 [00:00<?, ?it/s][A[A[A

Model: "combined"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_67 (InputLayer)           (None, 128, 128, 3)  0                                            
__________________________________________________________________________________________________
encoder (Model)                 (None, 4, 4, 1024)   11154112    input_67[0][0]                   
__________________________________________________________________________________________________
input_69 (InputLayer)           (None, 1)            0                                            
__________________________________________________________________________________________________
decoder (Model)                 (None, 128, 128, 3)  44612995    encoder[1][0]                    
                                                                 input_69[0][0]            

  'Discrepancy between trainable weights and collected trainable'


Epoch 1/1
Epoch 1/1
Epoch 1/1





Train 0 / 30:  20%|██        | 1/5 [00:37<02:30, 37.60s/it][A[A[A

Epoch 1/1
Epoch 1/1
Epoch 1/1





Train 0 / 30:  40%|████      | 2/5 [00:46<01:26, 28.96s/it][A[A[A

Epoch 1/1
Epoch 1/1
Epoch 1/1





Train 0 / 30:  60%|██████    | 3/5 [00:55<00:45, 22.88s/it][A[A[A

Epoch 1/1
Epoch 1/1
Epoch 1/1





Train 0 / 30:  80%|████████  | 4/5 [01:03<00:18, 18.61s/it][A[A[A

Epoch 1/1
Epoch 1/1
Epoch 1/1





Train 0 / 30: 100%|██████████| 5/5 [01:12<00:00, 14.48s/it]



Train 1 / 30:   0%|          | 0/5 [00:00<?, ?it/s][A[A[A

Epoch 1/1
Epoch 1/1
Epoch 1/1





Train 1 / 30:  20%|██        | 1/5 [00:08<00:34,  8.67s/it][A[A[A

In [None]:
gan.sample_images(1, 1)