<a href="https://colab.research.google.com/github/GarlandZhang/hairy_gan/blob/master/hairy_gan.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from google.colab import drive

# Mount Google Drive; later cells read kaggle.json and read/write checkpoints under
# '/content/drive/My Drive/hairy_gan'.
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

In [None]:
from google.colab import files
import os
from IPython.display import clear_output

import pandas as pd
import os
import shutil
# One-time dataset/environment setup. It is skipped when kaggle.json already exists in the
# working directory; the first statement below creates that file, so re-running is a no-op.
if not os.path.exists('kaggle.json'):
  # Kaggle API credentials are kept on the mounted Google Drive.
  shutil.copy('/content/drive/My Drive/hairy_gan/kaggle.json', 'kaggle.json')
  # !pip install -q kaggle
  # files.upload()
  # Place the token in ~/.kaggle/ (presumably where the kaggle CLI looks for it).
  !mkdir -p ~/.kaggle
  !cp kaggle.json ~/.kaggle/
  # Download and unpack CelebA, then gather the image folder and the CSVs under celeba-dataset/.
  !kaggle datasets download -d jessicali9530/celeba-dataset --force
  !unzip celeba-dataset.zip
  !mv img_align_celeba celeba-dataset
  !mv list_eval_partition.csv celeba-dataset/list_eval_partition.csv
  !mv list_landmarks_align_celeba.csv celeba-dataset/list_landmarks_align_celeba.csv
  !mv list_attr_celeba.csv celeba-dataset/list_attr_celeba.csv
  !mv list_bbox_celeba.csv celeba-dataset/list_bbox_celeba.csv

  # One folder per official split (no -p: these fail harmlessly if the folders already exist).
  !mkdir celeba-dataset/train
  !mkdir celeba-dataset/validation
  !mkdir celeba-dataset/test

  # Copy every image into the folder of its official partition.
  # NOTE(review): images are read from celeba-dataset/img_align_celeba/, which assumes the archive
  # nests them as img_align_celeba/img_align_celeba/ before the `mv` above -- confirm.
  # NOTE(review): the loop variable `id` shadows the builtin of the same name.
  partitions_df = pd.read_csv('celeba-dataset/list_eval_partition.csv') # 0 => train, 1 => validation, 2 => test
  for i, set_name in enumerate(['train', 'validation', 'test']):
    set_ids_df = partitions_df.loc[partitions_df['partition'] == i]['image_id']
    set_ids = set_ids_df.tolist()
    for id in set_ids:
      shutil.copy(os.path.join('celeba-dataset/img_align_celeba', id), os.path.join('celeba-dataset', f'{set_name}', id))

  # Install keras-contrib, which provides the InstanceNormalization layer imported in the next cell.
  !git clone https://www.github.com/keras-team/keras-contrib.git \
    && cd keras-contrib \
    && pip install git+https://www.github.com/keras-team/keras-contrib.git \
    && python convert_to_tf_keras.py \
    && USE_TF_KERAS=1 python setup.py install

  # Presumably pinned because DataLoader relies on scipy.misc.imread / scipy.misc.imresize,
  # which newer SciPy releases no longer ship.
  !pip install scipy==1.1.0

  # Hide the long download/install logs.
  clear_output()

In [None]:
from __future__ import print_function, division
import scipy
from keras.datasets import mnist
from keras.models import Model, Sequential
from keras_contrib.layers.normalization.instancenormalization import InstanceNormalization
from keras.layers import Input, Dense, Reshape, Flatten, Dropout, Concatenate, Embedding
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.optimizers import Adam
from keras.models import load_model, save_model

import keras.backend as K
import datetime
import matplotlib.pyplot as plt
import sys
import numpy as np
import os
from glob import glob
import math
from PIL import Image

import tensorflow as tf
from tensorflow.python.keras.backend import set_session, clear_session
# from tensorflow.python.keras.models import load_model
# tf.compat.v1.disable_eager_execution()
tf.compat.v1.disable_v2_behavior()
# tf.compat.v1.enable_eager_execution()
import cv2
import copy

from tqdm import tqdm

In [None]:
# Build a small, balanced training subset for `feature`: the first `num_images_each` images that
# have it (attribute == 1) and the first `num_images_each` that do not (== -1), in CSV order.
# They are copied into celeba-dataset/train_filter, which is recreated from scratch on every run
# (pure Python instead of `!rm -r` / `!mkdir`, so a missing folder is not an error).
# NOTE(review): images are copied from celeba-dataset/train without checking the partition file;
# this assumes the selected ids fall in the train split (a FileNotFoundError means one does not).
train_filter_dir = 'celeba-dataset/train_filter'
shutil.rmtree(train_filter_dir, ignore_errors=True)
os.makedirs(train_filter_dir)

num_images_each = 50
num_train_images = num_images_each  # read by the validation-set cell to avoid overlap
feature = 'Eyeglasses'
complete_df = pd.read_csv('celeba-dataset/list_attr_celeba.csv')
for label in (1, -1):  # positives first, then negatives
  selected_ids = complete_df.loc[complete_df[feature] == label, 'image_id'].iloc[:num_images_each]
  for img_id in selected_ids:
    shutil.copy(f'celeba-dataset/train/{img_id}', f'{train_filter_dir}/{img_id}')

In [None]:
# Build a held-out evaluation subset: for each class of `feature`, the next `num_images_each`
# images after the first `num_train_images` (the ones taken for celeba-dataset/train_filter when
# the same feature is used), so the two subsets do not overlap. The folder is recreated on every
# run (pure Python instead of `!rm -r` / `!mkdir`).
# NOTE(review): these images also come from celeba-dataset/train, not the official validation split.
validation_dir = 'celeba-dataset/validation_set'
shutil.rmtree(validation_dir, ignore_errors=True)
os.makedirs(validation_dir)

num_images_each = 20
offset = num_train_images
num_validation_images = num_images_each
feature = 'Eyeglasses'
complete_df = pd.read_csv('celeba-dataset/list_attr_celeba.csv')
for label in (1, -1):  # positives first, then negatives
  selected_ids = complete_df.loc[complete_df[feature] == label, 'image_id'].iloc[offset:offset + num_images_each]
  for img_id in selected_ids:
    shutil.copy(f'celeba-dataset/train/{img_id}', f'{validation_dir}/{img_id}')

In [None]:
class DataLoader():
    """Loads CelebA images and 0/1 attribute labels from ``<dataset_name>/<split>/``.

    Images are resized to ``img_res`` and pixel values are mapped with ``x / 127.5 - 1``
    (i.e. [0, 255] -> [-1, 1]). Labels for ``self.features`` come from
    ``celeba-dataset/list_attr_celeba.csv``, where they are stored as -1/1; they are
    returned as 0/1.
    """

    def __init__(self, dataset_name, img_res):
        self.dataset_name = dataset_name
        self.img_res = img_res
        self.complete_df = pd.read_csv('celeba-dataset/list_attr_celeba.csv')
        # self.features = ['Bald', 'Bangs', 'Eyeglasses', 'Mustache', 'No_Beard', 'Straight_Hair', 'Wavy_Hair', 'Wearing_Hat']
        self.features = ['Eyeglasses']
        self.num_attrs = len(self.features)

    def _list_images(self, data_type):
        """Return the image paths under ``<dataset_name>/<data_type>/``, sorted."""
        # glob order is arbitrary; sorting makes batches reproducible under a fixed np.random seed.
        return sorted(glob('%s/%s/*' % (self.dataset_name, data_type)))

    def _get_attribs(self, img_path):
        """Return the 0/1 label vector (one entry per ``self.features``) for ``img_path``."""
        lookup = self.__dict__.get('_attr_lookup')
        if lookup is None:
            # Index by image_id once instead of scanning the whole attribute table per image.
            # NOTE: cached on first use; rebuild (del self._attr_lookup) if features/complete_df change.
            lookup = self.complete_df.set_index('image_id')[self.features]
            self._attr_lookup = lookup
        raw = lookup.loc[os.path.basename(img_path)].to_numpy()
        return (raw + 1) // 2  # -1/1 -> 0/1

    def _load_image(self, img_path, allow_flip):
        """Read and resize one image; flip it left-right with probability 1/2 if ``allow_flip``."""
        img = scipy.misc.imresize(self.imread(img_path), self.img_res)
        if allow_flip and np.random.random() > 0.5:
            img = np.fliplr(img)
        return img

    def load_data(self, dataset_type, batch_size=1, is_testing=False):
        """Return one random ``(imgs, attribs)`` batch from the ``dataset_type`` folder.

        Images are drawn without replacement whenever the folder holds at least ``batch_size``
        images, so a batch does not show the same face twice. Random flips only when not testing.

        Raises:
            ValueError: if the folder contains no images.
        """
        paths = self._list_images(dataset_type)
        if not paths:
            raise ValueError('No images found in %s/%s' % (self.dataset_name, dataset_type))

        batch_images = np.random.choice(paths, size=batch_size, replace=batch_size > len(paths))

        imgs = [self._load_image(p, allow_flip=not is_testing) for p in batch_images]
        attribs = [self._get_attribs(p) for p in batch_images]
        return np.array(imgs)/127.5 - 1., np.array(attribs)

    def load_batch(self, batch_size=1, is_testing=False, is_filter=False):
        """Infinite generator of ``(imgs, attribs)`` batches, reshuffled every epoch.

        The folder is 'train_filter' if ``is_filter``, else 'test' if ``is_testing``, else 'train'.
        Each epoch yields all ``self.n_batches`` full batches (the previous version skipped the
        last one and yielded nothing when only one full batch existed); leftover images that do
        not fill a batch are left out of that epoch. ``self.n_batches`` is set when the generator
        body first runs, i.e. on the first ``next()``.

        Raises:
            ValueError: (on first ``next()``) if fewer than ``batch_size`` images are available.
        """
        if is_filter:
          data_type = 'train_filter'
        elif is_testing:
          data_type = 'test'
        else:
          data_type = 'train'
        paths = self._list_images(data_type)

        self.n_batches = int(len(paths) / batch_size)
        total_samples = self.n_batches * batch_size
        if self.n_batches == 0:
            raise ValueError('Need at least %d images in %s/%s, found %d'
                             % (batch_size, self.dataset_name, data_type, len(paths)))

        while True:
            epoch_paths = np.random.choice(paths, total_samples, replace=False)
            for b in range(self.n_batches):
                batch = epoch_paths[b * batch_size:(b + 1) * batch_size]
                imgs = [self._load_image(p, allow_flip=not is_testing) for p in batch]
                attribs = [self._get_attribs(p) for p in batch]
                yield np.array(imgs)/127.5 - 1., np.array(attribs)
            # Re-list before the next epoch (as before) so folder changes are picked up.
            paths = self._list_images(data_type)

    def imread(self, path):
        """Read ``path`` as an RGB float64 array (``np.float`` was removed from NumPy)."""
        return scipy.misc.imread(path, mode='RGB').astype(np.float64)

In [None]:
def build_encoder(img_shape, num_filters=64, kernel_size=4, strides=2):
  """Build the generator encoder as five Conv2D -> BatchNormalization -> LeakyReLU blocks.

  Block i (i = 0..4) uses ``num_filters * 2**i`` filters, ``kernel_size`` and ``strides`` with
  'same' padding.

  Returns:
    Keras ``Model`` named 'encoder' mapping an ``img_shape`` image to the last block's output.
  """
  def down_block(tensor, filters):
    tensor = Conv2D(filters, kernel_size=kernel_size, strides=strides, padding='same')(tensor)
    tensor = BatchNormalization()(tensor)
    return LeakyReLU()(tensor)

  img = Input(shape=img_shape)
  features = img
  for multiplier in (1, 2, 4, 8, 16):
    features = down_block(features, num_filters * multiplier)

  model = Model(img, features, name='encoder')
  model.summary()
  return model

def build_embedding(img, label, input_shape, attr_size):
  """Concatenate a learned embedding of the 0/1 attribute vector ``label`` onto feature map ``img``.

  Each of the ``attr_size`` labels indexes a 2-row Embedding whose vectors have
  ``prod(input_shape)`` entries, i.e. one full ``input_shape`` map per attribute. The maps are
  stacked along the channel axis, so the result has ``input_shape[-1] * (1 + attr_size)`` channels.

  Args:
    img: feature tensor of shape ``(batch,) + input_shape``.
    label: int tensor of shape ``(batch, attr_size)`` with values in {0, 1}.
    input_shape: shape of ``img`` without the batch axis, channels last.
    attr_size: number of attributes.
  """
  from keras.layers import Permute  # local import: not part of the notebook's import cell

  input_shape = tuple(input_shape)
  spatial_rank = len(input_shape) - 1
  # (batch, attr_size) -> (batch, attr_size, prod(input_shape))
  label_embedding = Embedding(2, np.prod(input_shape), input_length=attr_size)(label)
  # style_embedding = Embedding(2, np.prod(input_shape), input_length=attr_size)(style)
  # label_style_embedding = Add()([label_embedding, style_embedding])
  # label_style_embedding = Reshape(input_shape[:-1] + (attr_size * input_shape[-1], ))(label_style_embedding)
  # emb_img = Concatenate(axis=-1)([img, label_style_embedding])

  # Keep each attribute's vector as its own map: (batch, attr_size, *spatial, C).
  label_embedding = Reshape((attr_size,) + input_shape)(label_embedding)
  # Move the attribute axis next to the channels: (batch, *spatial, attr_size, C). Reshaping the
  # flat embedding directly (as before) gave each spatial region a single attribute's values
  # whenever attr_size > 1; for attr_size == 1 the result is unchanged.
  permutation = tuple(range(2, spatial_rank + 2)) + (1, spatial_rank + 2)
  label_embedding = Permute(permutation)(label_embedding)
  label_embedding = Reshape(input_shape[:-1] + (attr_size * input_shape[-1], ))(label_embedding)
  emb_img = Concatenate(axis=-1)([img, label_embedding])
  return emb_img

def build_decoder(latent_space_shape, attr_size, num_filters=64, kernel_size=4, strides=1):
  """Build the generator decoder conditioned on an int32 attribute vector.

  The latent map is concatenated with the attribute embedding (``build_embedding``), then passed
  through four UpSampling2D(2) -> Conv2D -> BatchNormalization -> LeakyReLU blocks with
  ``num_filters * (16, 8, 4, 2)`` filters, a final UpSampling2D(2) and a 3-channel tanh Conv2D.

  Returns:
    Keras ``Model`` named 'decoder' with inputs ``[latent, label]``.
  """
  def up_block(tensor, filters):
    tensor = UpSampling2D(size=2)(tensor)
    tensor = Conv2D(filters, kernel_size=kernel_size, strides=strides, padding='same')(tensor)
    tensor = BatchNormalization()(tensor)
    return LeakyReLU()(tensor)

  latent = Input(shape=latent_space_shape)
  label = Input(shape=(attr_size, ), dtype='int32')

  x = build_embedding(latent, label, latent_space_shape, attr_size)
  for multiplier in (16, 8, 4, 2):
    x = up_block(x, num_filters * multiplier)
  x = UpSampling2D(size=2)(x)
  # tanh keeps outputs in [-1, 1], the same range DataLoader scales images to.
  x = Conv2D(3, kernel_size=kernel_size, strides=strides, padding='same', activation='tanh')(x)

  model = Model([latent, label], x, name='decoder')
  model.summary()
  return model

def build_convnet(img, num_filters=64, kernel_size=4, strides=2):
  """Shared trunk for the discriminator and classifier (returns a tensor, not a Model).

  Five Conv2D ('same' padding) -> InstanceNormalization -> LeakyReLU blocks with
  ``num_filters * (1, 2, 4, 8, 16)`` filters, then Flatten -> Dense(1024) ->
  InstanceNormalization -> LeakyReLU.
  """
  x = img
  for multiplier in (1, 2, 4, 8, 16):
    x = Conv2D(num_filters * multiplier, kernel_size=kernel_size, strides=strides, padding='same')(x)
    x = InstanceNormalization()(x)
    x = LeakyReLU()(x)

  x = Flatten()(x)
  x = Dense(1024)(x)
  x = InstanceNormalization()(x)
  return LeakyReLU()(x)

def disc_loss_fn(y_true, y_pred):
  """Return mean(y_pred) - mean(y_true) (looks like a WGAN-critic objective).

  Not referenced by any compile() call in this notebook; the discriminator uses binary cross-entropy.
  """
  mean_pred = tf.reduce_mean(y_pred)
  mean_true = tf.reduce_mean(y_true)
  return mean_pred - mean_true

def build_discriminator(img_shape, optimizer):
  """Build and compile the real/fake discriminator.

  ``build_convnet`` trunk with 8 base filters followed by a single sigmoid unit named
  'disc_output'; compiled with binary cross-entropy and ``optimizer``.
  """
  img = Input(shape=img_shape)
  features = build_convnet(img, num_filters=8)
  probability = Dense(1, name='disc_output', activation='sigmoid')(features)

  disc = Model(img, probability, name='discriminator')
  disc.compile(loss='binary_crossentropy', optimizer=optimizer)
  disc.summary()
  return disc

def build_classifier(img_shape, attr_size, optimizer):
  """Build and compile the multi-label attribute classifier.

  ``build_convnet`` trunk (default 64 base filters) followed by ``attr_size`` independent sigmoid
  outputs named 'classif_output'; compiled with binary cross-entropy and ``optimizer``.
  """
  img = Input(shape=img_shape)
  features = build_convnet(img)
  attr_probs = Dense(attr_size, activation='sigmoid', name='classif_output')(features)

  classif = Model(img, attr_probs, name='classifier')
  classif.compile(loss='binary_crossentropy', optimizer=optimizer)
  classif.summary()
  return classif

def gen_loss_fn(y_true, y_pred):
  """Return -mean(y_pred); ``y_true`` is ignored. Not referenced by any compile() in this notebook."""
  return tf.negative(tf.reduce_mean(y_pred))

def build_combined_generator(img_shape, attr_size, genc, gdec, classifier, discriminator, optimizer):
  """Build and compile the encoder/decoder training graph.

  Inputs ``[x_a, a, b]``: original image, its attributes, requested attributes.
  Outputs ``[b_hat, valid, x_a_hat]``:
    - classifier attributes of the edited image ``gdec([genc(x_a), b])``;
    - discriminator score of that edited image (the adversarial term);
    - reconstruction ``gdec([genc(x_a), a])``.
  Losses are BCE, BCE and MAE with weights 10, 1 and 100.

  ``classifier`` and ``discriminator`` are set non-trainable before compiling, so this model only
  updates the encoder/decoder.
  """
  for frozen in (classifier, discriminator):
    frozen.trainable = False

  x_a = Input(shape=img_shape)
  a = Input(shape=(attr_size, ))
  b = Input(shape=(attr_size, ))

  z = genc(x_a)
  x_b = gdec([z, b])
  outputs = [classifier(x_b), discriminator(x_b), gdec([z, a])]

  combined = Model(inputs=[x_a, a, b], outputs=outputs, name='combined')
  combined.compile(
      loss=['binary_crossentropy', 'binary_crossentropy', 'mae'],
      loss_weights=[10, 1, 100],
      optimizer=optimizer,
  )
  combined.summary()
  return combined

In [None]:
def shuffle(elems):
  """Return a randomly permuted copy of ``elems``; the input itself is left untouched."""
  shuffled = elems.copy()  # np.random.shuffle works in place, so permute a copy
  np.random.shuffle(shuffled)
  return shuffled

def create_random_attrs(attrs):
  """Draw a uniformly random 0/1 attribute array with the same shape as ``attrs``.

  Only ``attrs.shape`` is used; the current attribute values do not influence the draw.
  """
  return np.random.randint(2, size=attrs.shape)

In [None]:
def train_classifier_step(gen_batch, classifier):
  """Fit ``classifier`` on one ``(images, attributes)`` batch and return whatever ``fit`` returns."""
  images, true_attrs = gen_batch
  return classifier.fit(images, true_attrs)

def train_discriminator_step(batch_size, gen_batch, enc, dec, discriminator):
  """Run one discriminator update: real images labelled 1, then decoded edits labelled 0.

  Args:
    batch_size: unused; targets are sized from ``len(imgs)``.
    gen_batch: ``(imgs, attrs, new_attrs)``; ``attrs`` is unused here.
    enc, dec: encoder/decoder used (via ``predict``) to produce edited images from ``new_attrs``.
    discriminator: compiled model updated with ``train_on_batch``.

  Returns:
    Two objects with ``.history == {'loss': [loss]}`` (mimicking Keras History), for the
    real-image and fake-image updates respectively.
  """
  imgs, _, requested_attrs = gen_batch
  edited = dec.predict([enc.predict(imgs), requested_attrs])

  n = len(imgs)
  real_loss = discriminator.train_on_batch(imgs, np.ones((n, 1)))
  fake_loss = discriminator.train_on_batch(edited, np.zeros((n, 1)))

  class TempObj():
    def __init__(self, value):
      self.history = {'loss': [value]}

  return TempObj(real_loss), TempObj(fake_loss)

def train_encdec_step(batch_size, gen_batch, combined):
  """Run one encoder/decoder update through the ``combined`` model.

  Targets are the requested attributes ``new_attrs`` (classifier head), all-ones "real" labels
  (discriminator head) and the input images (reconstruction head).

  Args:
    batch_size: kept for backward compatibility; the target size now comes from the batch itself.
    gen_batch: ``(imgs, attrs, new_attrs)``.
    combined: compiled model with inputs ``[x_a, a, b]`` and three outputs.

  Returns:
    Object with ``.history == {'loss': [value returned by combined.train_on_batch]}``.
  """
  imgs, attrs, new_attrs = gen_batch

  # Sized from the actual batch, like train_discriminator_step; using `batch_size` gave a
  # target/input length mismatch whenever the two differed.
  real = np.ones((len(imgs), 1))

  g_real_loss = combined.train_on_batch([imgs, attrs, new_attrs], [new_attrs, real, imgs])
  # g_real_history = combined.fit([x_a, a, b], [b, real, x_a])

  class TempObj():
    def __init__(self, value):
      self.history = {}
      self.history['loss'] = [value]

  return TempObj(g_real_loss)

In [None]:
def add_metrics(metrics, histories):
  """Append every recorded value from ``histories`` into ``metrics`` and return it.

  Args:
    metrics: dict mapping metric name -> list of values; updated in place.
    histories: Keras History objects, or anything with a ``.history`` dict of lists.

  Returns:
    The same ``metrics`` dict.
  """
  for history in histories:
    for key, values in history.history.items():
      # Always copy into metrics' own list. Storing `values` itself (as before) aliased the
      # History's list, so later appends silently mutated that History as well.
      metrics.setdefault(key, []).extend(values)
  return metrics

In [None]:
def visualize_metrics(metrics):
  """Plot each metric series of ``metrics`` (name -> list of values) in its own stacked subplot.

  Prints a notice instead of plotting when ``metrics`` is empty.
  """
  if not metrics:
    print('No metrics recorded yet')
    return

  num_plots = len(metrics)
  # squeeze=False always yields a 2-D array of Axes. With the default, a single metric
  # (e.g. only 'classif_loss' during pretraining) returned a bare Axes and axes[0] raised.
  fig, axes = plt.subplots(num_plots, squeeze=False, figsize=(6.4, 2.5 * num_plots))

  for ax, (title, values) in zip(axes[:, 0], metrics.items()):
    ax.plot(values)
    ax.set_title(title)

  fig.tight_layout()
  plt.show()

In [None]:
def classif_build_conv(x, num_filters, kernel_size, strides):
  """Conv2D ('same' padding) -> InstanceNormalization -> LeakyReLU block (used by HairyGan.progress)."""
  out = Conv2D(num_filters, kernel_size=kernel_size, strides=strides, padding='same')(x)
  out = InstanceNormalization()(out)
  return LeakyReLU()(out)

In [None]:
def tag_name(dic):
  """Flatten ``dic`` into 'key1_val1_key2_val2...' (used as a per-configuration folder name).

  Keys must be strings; values are converted with ``str``. An empty dict gives ''.
  """
  return '_'.join(key + '_' + str(val) for key, val in dic.items())

In [None]:
class HairyGan(): # based on AttGan
  """AttGAN-style attribute editor: encoder/decoder generator, discriminator and attribute classifier.

  ``flags`` keys read here: 'learning_rate', 'enc_dec_interval', 'disc_interval', 'orig_dim',
  'new_dim', 'filter_on', and optionally 'from_scratch' (default False: reuse enc.h5 / dec.h5 /
  disc.h5 from ``project_path`` when they exist). classif.h5 is reused whenever it exists.

  NOTE(review): checkpoints are read from and written to a module-level ``project_path`` that
  must be defined before this class is instantiated.
  """
  def __init__(self, flags=None):
    # None instead of a shared mutable `{}` default; an empty dict still fails on the first key.
    if flags is None:
      flags = {}

    self.learning_rate = flags['learning_rate']

    self.enc_dec_interval = flags['enc_dec_interval']
    self.disc_interval = flags['disc_interval']

    self.img_rows = flags['orig_dim']
    self.img_cols = flags['orig_dim']
    self.img_channels = 3

    self.img_shape = (self.img_rows, self.img_cols, self.img_channels)

    # NOTE(review): not used anywhere visible; the discriminator actually outputs a single unit.
    patch = int(self.img_rows / 2**4)
    self.disc_out = (patch, patch, 1)

    self.dl = DataLoader(dataset_name='celeba-dataset', img_res=(self.img_rows, self.img_cols))

    self.optimizer = Adam(learning_rate=self.learning_rate, beta_1=0.5, beta_2=0.999)

    self.flags = flags

    if self.flags.get('from_scratch', False):
      self.enc = build_encoder(self.img_shape)
      self.dec = build_decoder(self._latent_shape(), self.dl.num_attrs)
      self.disc = build_discriminator(self.img_shape, self.optimizer)
    else:
      self.enc = self._load_or_build('enc.h5', 'encoder', lambda: build_encoder(self.img_shape))
      self.dec = self._load_or_build('dec.h5', 'decoder', lambda: build_decoder(self._latent_shape(), self.dl.num_attrs))
      self.disc = self._load_or_build('disc.h5', 'disc', lambda: build_discriminator(self.img_shape, self.optimizer))

    self.classif = self._load_or_build('classif.h5', 'classif', lambda: build_classifier(self.img_shape, self.dl.num_attrs, self.optimizer))

    self.combined = build_combined_generator(self.img_shape, self.dl.num_attrs, self.enc, self.dec, self.classif, self.disc, self.optimizer)

    self.metrics = {}

    if flags['new_dim'] != flags['orig_dim']: # progress!
      self.progress(flags['new_dim'])

  def _latent_shape(self):
    """Encoder output shape without the batch axis ((4, 4, 1024) for 128x128 inputs with defaults)."""
    # Taken from the encoder instead of hard-coding (4, 4, 1024), so other orig_dim values line up.
    return tuple(self.enc.output_shape[1:])

  def _load_or_build(self, filename, label, build_fn):
    """Load ``filename`` from ``project_path`` if it exists, otherwise return ``build_fn()``."""
    path = os.path.join(project_path, filename)
    if os.path.exists(path):
      print(f'Loading {label} from file')
      return load_model(path, custom_objects={'InstanceNormalization': InstanceNormalization})
    return build_fn()

  def progress(self, new_dim):
    """Grow the classifier's input resolution from ``orig_dim`` to ``new_dim``.

    ``new_dim`` must be ``orig_dim`` times a power of two greater than one. Assuming the classifier
    has the ``build_classifier`` layout (layers[0] = Input, layers[1] = first stride-2 Conv2D with
    64 filters), that first conv is replaced by ``log2(new_dim / orig_dim) + 1`` new stride-2 conv
    blocks. The reused layers from layers[2] on therefore see the same spatial size as before.
    The data loader is switched to ``new_dim`` as well.

    NOTE(review): only the classifier grows; encoder, decoder, discriminator and ``self.combined``
    still operate at ``orig_dim``.

    Raises:
      ValueError: if ``new_dim`` is not a supported multiple of ``orig_dim``.
    """
    print('Applying progression')

    orig_dim = self.img_shape[0]
    dim_scale = new_dim / orig_dim
    num_new_layers = round(math.log(dim_scale, 2))
    if num_new_layers < 1 or orig_dim * 2 ** num_new_layers != new_dim:
      raise ValueError(f'new_dim ({new_dim}) must be orig_dim ({orig_dim}) times a power of two > 1')
    new_input = Input(shape=(new_dim, new_dim, 3))

    # update data loader
    print('Update data loader')
    self.dl = DataLoader(dataset_name='celeba-dataset', img_res=(new_dim, new_dim))

    # update classifier
    print('Update classifier')
    output = new_input
    # One block per doubling, plus one standing in for the replaced original first conv.
    # (Exactly two were always added before, which is only right for new_dim == 2 * orig_dim.)
    for _ in range(num_new_layers + 1):
      output = classif_build_conv(output, 64, 4, 2)
    for layer in self.classif.layers[2:]:
      output = layer(output)
    # `inputs=`/`outputs=` are the functional-API keywords; `input=`/`output=` are legacy names.
    new_classif = Model(inputs=new_input, outputs=output, name='classifier')
    new_classif.compile(loss='binary_crossentropy', optimizer=self.optimizer)
    self.classif = new_classif

  def pretrain_classifier(self, num_epochs, batch_size, visualize_interval):
    """Train only the attribute classifier, sampling/saving/plotting every ``visualize_interval`` steps."""
    self.classif.trainable = True

    # set up data loader. The generator body (which sets self.dl.n_batches) only runs on the first
    # next(), so prime it here; this consumes one batch, as the previous for/break did.
    batch_gen = self.dl.load_batch(batch_size=batch_size, is_filter=self.flags['filter_on'])
    next(batch_gen, None)

    num_batches = self.dl.n_batches
    steps_per_epoch = num_batches

    count = 0

    for epoch in range(num_epochs):
      for step in tqdm(range(steps_per_epoch), desc=f'Train {epoch} / {num_epochs}', total=steps_per_epoch):
        gen_batch = next(batch_gen)

        classif_history = train_classifier_step(gen_batch, self.classif)
        classif_history.history['classif_loss'] = classif_history.history.pop('loss')

        self.metrics = add_metrics(self.metrics, [classif_history])

        clear_output()
        if (count + 1) % visualize_interval == 0:
          try:
            self.sample_images(epoch, step, is_filter=self.flags['filter_on'])

            # save model
            save_model(self.classif, os.path.join(project_path, 'classif.h5'))

            # visualize loss/accuracy
            visualize_metrics(self.metrics)
          except Exception as e:
            # Best effort: a failed sample/save/plot must not stop training.
            print(e)

        count += 1

  def train(self, num_epochs, batch_size, visualize_interval):
    """Alternate discriminator and encoder/decoder updates, sampling/saving/plotting every ``visualize_interval`` steps.

    Raises:
      ValueError: if 'disc_interval' or 'enc_dec_interval' is below 1 (no update would run, and the
        loss bookkeeping below would reference undefined histories).
    """
    if self.disc_interval < 1 or self.enc_dec_interval < 1:
      raise ValueError('disc_interval and enc_dec_interval must both be >= 1')

    self.classif.trainable = False

    # set up data loader; prime the generator so self.dl.n_batches is set (consumes one batch).
    batch_gen = self.dl.load_batch(batch_size=batch_size, is_filter=self.flags['filter_on'])
    next(batch_gen, None)

    num_batches = self.dl.n_batches
    steps_per_epoch = num_batches

    count = 0

    for epoch in range(num_epochs):
      for step in tqdm(range(steps_per_epoch), desc=f'Train {epoch} / {num_epochs}', total=steps_per_epoch):
        imgs, attrs = next(batch_gen)
        new_attrs = create_random_attrs(attrs)
        gen_batch = (imgs, attrs, new_attrs)

        print('Train discriminator')
        self.disc.trainable = True
        for i in range(self.disc_interval):
          disc_real_history, disc_fake_history = train_discriminator_step(batch_size, gen_batch, self.enc, self.dec, self.disc)

        print('Train encoder/decoder')
        self.disc.trainable = False
        for i in range(self.enc_dec_interval):
          g_real_history = train_encdec_step(batch_size, gen_batch, self.combined)

        g_real_history.history['g_real_loss'] = g_real_history.history.pop('loss')

        # Only the last real-image discriminator loss (key 'loss') and the generator loss are tracked.
        self.metrics = add_metrics(self.metrics, [disc_real_history, g_real_history])

        # set to trainable again
        self.disc.trainable = True

        clear_output()
        if (count + 1) % visualize_interval == 0:
          try:
            self.sample_images(epoch, step, is_filter=self.flags['filter_on'])

            # save models
            save_model(self.enc, 'enc.h5')
            shutil.copy('enc.h5', os.path.join(project_path, 'enc.h5'))

            save_model(self.dec, 'dec.h5')
            shutil.copy('dec.h5', os.path.join(project_path, 'dec.h5'))

            save_model(self.disc, os.path.join(project_path, 'disc.h5'))

            # visualize loss/accuracy
            visualize_metrics(self.metrics)
          except Exception as e:
            # Best effort: a failed sample/save/plot must not stop training.
            print(e)

        count += 1

  def sample_images(self, epoch, batch_i, is_filter=False):
    """Show two images with random attribute edits and reconstructions, plus disc/classifier outputs."""
    print(f'Epoch: {epoch} with batch: {batch_i}')
    rows, cols = 2, 3

    imgs, attrs = self.dl.load_data('test' if not is_filter else 'train_filter', batch_size=2, is_testing=True)

    new_attrs = create_random_attrs(attrs)

    encodings = self.enc.predict(imgs)

    reconstrs = self.dec.predict([encodings, attrs])

    new_imgs = self.dec.predict([encodings, new_attrs])

    gen_imgs = np.array([imgs[0], new_imgs[0], reconstrs[0], imgs[1], new_imgs[1], reconstrs[1]])

    # [-1, 1] -> [0, 1] for imshow.
    gen_imgs = 0.5 * gen_imgs + 0.5

    titles = ['Original', 'Translated', 'Reconstructed']
    fig, axes = plt.subplots(rows, cols)

    count = 0

    print(f'O.G. images fake or real: {self.disc.predict(imgs)}')
    print(f'New images fake or real: {self.disc.predict(new_imgs)}')
    print(f'Reconstructed images fake or real: {self.disc.predict(reconstrs)}')
    print(f'O.G. image attributes: {self.classif.predict(imgs)}')
    print(f'New image attributes: {self.classif.predict(new_imgs)}')
    print(f'Reconstructed image attributes: {self.classif.predict(reconstrs)}')

    for i in range(rows):
      for j in range(cols):
        axes[i, j].imshow(gen_imgs[count])
        axes[i, j].set_title(titles[j])
        axes[i, j].axis('off')
        count += 1

    plt.show()
    

In [None]:
# Hyper-parameter sweep: each dict in `flag_options` is one HairyGan configuration whose
# checkpoints live in its own folder under `og_project_path` (named by `tag_name`).
# All options are currently commented out, so the loop below does not run.
flag_options = [
  # { 'filter_on': True, 'orig_dim': 128, 'new_dim': 128, 'enc_dec_interval': 1, 'disc_interval': 1, 'learning_rate': 0.000005 },
  # { 'filter_on': True, 'orig_dim': 128, 'new_dim': 128, 'enc_dec_interval': 5, 'disc_interval': 1, 'learning_rate': 0.000005 },
  # { 'filter_on': True, 'orig_dim': 128, 'new_dim': 128, 'enc_dec_interval': 1, 'disc_interval': 1, 'learning_rate': 0.00005 },
  # { 'filter_on': True, 'orig_dim': 128, 'new_dim': 128, 'enc_dec_interval': 3, 'disc_interval': 1, 'learning_rate': 0.00005 },
  # { 'filter_on': True, 'orig_dim': 128, 'new_dim': 128, 'enc_dec_interval': 1, 'disc_interval': 1, 'learning_rate': 0.0005 },
  # { 'filter_on': True, 'orig_dim': 128, 'new_dim': 128, 'enc_dec_interval': 1, 'disc_interval': 1, 'learning_rate': 0.005 },
  # { 'filter_on': False, 'orig_dim': 128, 'new_dim': 128, 'enc_dec_interval': 1, 'disc_interval': 1, 'learning_rate': 0.000005 },
]

og_project_path = '/content/drive/My Drive/hairy_gan'
# Pre-trained classifier shared by every configuration in the sweep.
classif = load_model(os.path.join(og_project_path, 'classif.h5'), custom_objects={'InstanceNormalization': InstanceNormalization})

tags, losses = [], []
best_tag, best_loss = None, None

for flags in flag_options:
  # create tag name (from the sweep options themselves, before any defaults are added)
  tag = tag_name(flags)

  # set project path; HairyGan reads this module-level name for its checkpoints
  project_path = os.path.join(og_project_path, tag)

  # create folder
  os.makedirs(project_path, exist_ok=True)

  # create gan model. HairyGan reads flags['from_scratch'], which the sweep dicts do not set,
  # so default it to False (reuse checkpoints) unless an option overrides it.
  gan = HairyGan({'from_scratch': False, **flags})

  # Swap in the shared classifier and rebuild `combined`: it was compiled around the classifier
  # HairyGan created itself, so without the rebuild generator training would ignore the swap.
  gan.classif = classif
  gan.combined = build_combined_generator(gan.img_shape, gan.dl.num_attrs, gan.enc, gan.dec, gan.classif, gan.disc, gan.optimizer)

  # # train gan model
  # # if not flags['filter_on']:
  # #   gan.train(num_epochs=1, batch_size=32, visualize_interval=100)
  # # else:
  # #   gan.train(num_epochs=160, batch_size=2, visualize_interval=100)

  # # calculate loss
  # batch_size = num_validation_images
  # imgs, attrs = gan.dl.load_data('validation_set', batch_size=batch_size, is_testing=True)
  # new_attrs = np.ones(attrs.shape)
  # encodings = gan.enc.predict(imgs)
  # reconstrs = gan.dec.predict([encodings, attrs])
  # new_imgs = gan.dec.predict([encodings, new_attrs])

  # real = np.ones(shape=(batch_size,))
  # loss = gan.combined.evaluate([imgs, attrs, new_attrs], [new_attrs, real, imgs])[0]

  # # save best tag, loss
  # if best_loss is None or loss < best_loss:
  #   best_tag = tag
  #   best_loss = loss

  # tags.append(tag)
  # losses.append(loss)

  # # delete gan model
  # del gan

In [None]:
# Configuration for the single training run below; HairyGan reads `project_path` for checkpoints.
project_path = '/content/drive/My Drive/hairy_gan'
flags = dict(
  from_scratch=True,
  filter_on=True,
  orig_dim=128,
  new_dim=128,
  enc_dec_interval=3,
  disc_interval=1,
  learning_rate=0.0005,
)
gan = HairyGan(flags)

In [None]:
# gan.pretrain_classifier(num_epochs=10000, batch_size=2, visualize_interval=100)
# The filtered subset (celeba-dataset/train_filter) is small, so it is trained with tiny batches.
train_batch_size = 2 if flags['filter_on'] else 32
gan.train(num_epochs=10000, batch_size=train_batch_size, visualize_interval=100)

In [None]:
# Inspect the generator on two images from the filtered set. Each figure row shows the original,
# the image decoded with every attribute set to 1, and the reconstruction from the true attributes.
rows, cols = 2, 3

imgs, attrs = gan.dl.load_data('train_filter', batch_size=2, is_testing=True)

new_attrs = np.ones(attrs.shape)
# new_attrs = np.zeros(attrs.shape)
# new_attrs[:, 2] = 1

encodings = gan.enc.predict(imgs)
reconstrs = gan.dec.predict([encodings, attrs])
new_imgs = gan.dec.predict([encodings, new_attrs])
# combined.predict([imgs, attrs, new_attrs])

# Discriminator scores for the three batches first, then classifier attributes.
labelled_batches = [('O.G.', imgs), ('New', new_imgs), ('Reconstructed', reconstrs)]
for label, batch in labelled_batches:
  print(f'{label} images fake or real: {gan.disc.predict(batch)}')
for label, batch in labelled_batches:
  print(f'{label} image attributes: {gan.classif.predict(batch)}')

# Row-major grid order: (original, translated, reconstructed) for sample 0, then for sample 1;
# the [-1, 1] image range is mapped to [0, 1] for imshow.
gen_imgs = np.array([picture for k in range(rows) for picture in (imgs[k], new_imgs[k], reconstrs[k])])
gen_imgs = 0.5 * gen_imgs + 0.5

titles = ['Original', 'Translated', 'Reconstructed']
fig, axes = plt.subplots(rows, cols)
for ax, picture, title in zip(axes.flat, gen_imgs, titles * rows):
  ax.imshow(picture)
  ax.set_title(title)
  ax.axis('off')

visualize_metrics(gan.metrics)

plt.show()

In [None]:
# save_model(gan.enc, 'enc.h5')
# shutil.copy('enc.h5', os.path.join(project_path, 'enc.h5'))
# # shutil.copy('enc.h5', os.path.join(project_path, 'backup', 'enc.h5'))

# save_model(gan.dec, 'dec.h5')
# shutil.copy('dec.h5', os.path.join(project_path, 'dec.h5'))
# # shutil.copy('dec.h5', os.path.join(project_path, 'backup', 'dec.h5'))

# save_model(gan.disc, os.path.join(project_path, 'disc.h5'))

In [None]:
# Debug pair for the discriminator sanity check below. Despite the names these are not constant
# images (the np.ones/np.zeros versions were replaced): `all_ones` is the first real input image
# and `all_zeros` the first translated (generated) image from the inspection cell above.
img_shape = (128, 128, 3)
all_ones, all_zeros = imgs[0], new_imgs[0]

In [None]:
# Show the generated debug image. Images live in the [-1, 1] range (DataLoader scaling / tanh
# decoder), so rescale to [0, 1] like the other plotting cells; imshow clips float RGB outside [0, 1].
plt.imshow(0.5 * all_zeros + 0.5)
plt.axis('off')
plt.show()

In [None]:
def disc_loss_fn2(y_true, y_pred):
  """Return mean(y_pred) - mean(y_true); same formula as disc_loss_fn, unused in this notebook."""
  pred_mean = tf.reduce_mean(y_pred)
  return pred_mean - tf.reduce_mean(y_true)

def build_discriminator2(img_shape, optimizer):
  """Build a fresh, compiled real/fake discriminator for the sanity-check cells below.

  This was a line-for-line copy of ``build_discriminator``: the same five InstanceNorm conv blocks
  with 8 base filters, Flatten -> Dense(1024) -> InstanceNorm -> LeakyReLU, a sigmoid
  'disc_output' unit, binary cross-entropy loss and a model summary. It now delegates so the
  two cannot silently drift apart.
  """
  return build_discriminator(img_shape, optimizer)

In [None]:
# # del disc
# disc = build_discriminator2(img_shape, Adam(5e-04))
# clear_output()
# disc.summary()

In [None]:
# Sanity check: train a discriminator for 10 steps on one real image (`all_ones`, label 1) and one
# generated image (`all_zeros`, label 0). `disc` is only created by a commented-out cell above,
# so build a fresh one when it is missing; otherwise a clean kernel raises NameError here.
if 'disc' not in globals():
  disc = build_discriminator2(img_shape, Adam(5e-04))
  clear_output()  # hide the model summary printed by the builder

debug_batch = np.array([all_ones, all_zeros])
debug_labels = np.array([[1.], [0.]])  # explicit (batch, 1) targets for the single sigmoid output
for i in range(10):
  # loss = disc.train_on_batch(np.array([imgs[0], imgs[1], new_imgs[0], new_imgs[1]]), [1, 1, 0, 0])
  loss = disc.train_on_batch(debug_batch, debug_labels)
  # loss = disc.train_on_batch(np.expand_dims(all_ones, axis=0), [1]) # dont uncomment this; this overfits
  clear_output()
print(f'Discriminator loss after {i + 1} steps: {loss}')

In [None]:
# Discriminator scores for the debug pair (trained with 1 = real, 0 = generated).
for caption, sample in (('O.G.', all_ones), ('New', all_zeros)):
  print(f'{caption} images fake or real: {disc.predict(np.array([sample]))}')
# print(f'O.G. images fake or real: {disc.predict(imgs)}')
# print(f'New images fake or real: {disc.predict(new_imgs)}')