In [15]:
## Code adapted from https://github.com/hoangthang1607/StarGAN-Keras/blob/master/StarGAN.py
## and https://github.com/yunjey/stargan/
## utils.py (preprocessing of the input images) is borrowed from
## https://github.com/hoangthang1607/StarGAN-Keras/blob/master/utils.py

In [16]:
# Keras imports for building and training the StarGAN generator and discriminator
from keras.datasets.cifar10 import load_data
from keras.models import Sequential, model_from_json
from keras.layers import Input
from keras.layers import Conv2D, UpSampling2D, ZeroPadding2D, Concatenate, Dropout, LeakyReLU, BatchNormalization, ReLU, Conv2DTranspose, Add
from keras.layers import Lambda
from keras.optimizers import Adam
from keras.layers import Reshape
from keras.layers import Conv2DTranspose
from keras.models import Model, load_model
from keras_contrib.layers.normalization.instancenormalization import InstanceNormalization
from keras.layers.merge import _Merge
from keras import backend as K
from functools import partial
# Standard-library, numerical, image-processing and plotting imports
import numpy
import csv
import random
import time
from utils import *
from skimage.transform import resize
from scipy.linalg import sqrtm
import matplotlib.pyplot as pyplot
import tensorflow as tf
import os
import shutil
from PIL import Image
import numpy as np

In [6]:
class RandomWeightedAverage(_Merge):
    """Random per-sample interpolation between real and generated images.

    Produces the x_hat samples used by the gradient penalty (StarGAN paper,
    Section 4.1): each x_hat lies uniformly at random on the straight line
    between a real image and a generated image.

    Call ``define_batch_size`` to fix the number of random weights drawn per
    call; if it was never called, the batch size is read from the input tensor
    at run time.
    """

    def define_batch_size(self, bs):
        """Record the batch size used to draw one interpolation weight per sample."""
        self.bs = bs

    def _merge_function(self, inputs):
        """Return ``alpha * inputs[0] + (1 - alpha) * inputs[1]`` with one uniform alpha per sample.

        The (batch, 1, 1, 1) alpha shape assumes rank-4 image tensors.
        """
        # The batch size used to be hard-coded to 4 here, which ignored
        # define_batch_size() and broke for any other batch size.
        batch = getattr(self, 'bs', None)
        if batch is None:
            batch = K.shape(inputs[0])[0]
        alpha = K.random_uniform((batch, 1, 1, 1))
        return (alpha * inputs[0]) + ((1 - alpha) * inputs[1])

In [7]:
def residual_block(inp, dim_out):
    """Residual block: two 3x3 conv + instance-norm layers plus an identity skip connection.

    Very deep networks are prone to vanishing or exploding gradients. Feeding
    a layer's activations directly to a deeper layer (a skip connection)
    mitigates this.

    Args:
        inp: input Keras tensor. Its channel count must equal ``dim_out`` so
            that the final ``Add`` of input and residual is shape-compatible.
        dim_out: number of filters of both convolutions.

    Returns:
        Keras tensor ``inp + F(inp)``. Padding 1 + 3x3 'valid' conv with
        stride 1 keeps the spatial size, so the shape equals ``inp``'s.
    """
    x = ZeroPadding2D(padding=1)(inp)
    # `use_bias` is the Keras 2 name; the legacy `bias=` kwarg is only accepted via a compatibility shim.
    x = Conv2D(dim_out, kernel_size=(3, 3), strides=(1, 1), padding='valid', use_bias=False)(x)
    x = InstanceNormalization(axis=-1)(x)
    x = ReLU()(x)
    x = ZeroPadding2D(padding=1)(x)
    x = Conv2D(dim_out, kernel_size=(3, 3), strides=(1, 1), padding='valid', use_bias=False)(x)
    x = InstanceNormalization(axis=-1)(x)
    return Add()([inp, x])

In [8]:
def generator(conv_dim=64, repeat_num=6, image_size = 128, c_dim = 5):
    """StarGAN generator: translate an image into the domain given by a label vector.

    The label vector is tiled over every pixel and concatenated to the image
    as extra channels. The result then goes through:
    - an encoder: one 7x7 conv and two stride-2 down-sampling convs;
    - ``repeat_num`` residual blocks;
    - a decoder: two 2x up-sampling + conv stages and a final 7x7 conv with tanh.

    Args:
        conv_dim: filters of the first conv; doubled by each down-sampling stage.
        repeat_num: number of residual (bottleneck) blocks.
        image_size: height and width of the square RGB input image.
        c_dim: length of the target-domain label vector (one entry per
            attribute, e.g. ['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Male', 'Young']).

    Returns:
        Keras ``Model`` with inputs ``[image, label]``. Its output is a
        3-channel tanh image with the same height and width as the input.
    """
    # Domain label of the desired target attributes.
    input_domain = Input(shape=(c_dim,))
    input_image = Input(shape=(image_size, image_size, 3))

    # Tile the label over all pixels: (c_dim,) -> (image_size**2, c_dim) -> (image_size, image_size, c_dim)
    c = Lambda(lambda x: K.repeat(x, image_size**2))(input_domain)
    c = Reshape((image_size, image_size, c_dim))(c)
    x = Concatenate()([input_image, c])

    # Initial 7x7 convolution (architecture follows the model summary in the paper).
    x = Conv2D(conv_dim, kernel_size=(7, 7), strides=(1, 1), padding='same', use_bias=False)(x)
    x = InstanceNormalization(axis=-1)(x)
    x = ReLU()(x)

    # Down-sampling: each stage halves height/width and doubles the channel count.
    curr_dim = conv_dim
    for _ in range(2):
        x = ZeroPadding2D(padding=1)(x)
        x = Conv2D(curr_dim * 2, kernel_size=(4, 4), strides=(2, 2), padding='valid', use_bias=False)(x)
        x = InstanceNormalization(axis=-1)(x)
        x = ReLU()(x)
        curr_dim *= 2

    # Bottleneck: residual blocks at constant resolution and channel count.
    for _ in range(repeat_num):
        x = residual_block(inp=x, dim_out=curr_dim)

    # Up-sampling: each stage doubles height/width and halves the channel count.
    for _ in range(2):
        x = UpSampling2D(size=2)(x)
        x = Conv2D(curr_dim // 2, kernel_size=(4, 4), strides=(1, 1), padding='same', use_bias=False)(x)
        x = InstanceNormalization(axis=-1)(x)
        x = ReLU()(x)
        curr_dim //= 2

    # Output: 7x7 conv back to 3 channels; tanh keeps pixel values in [-1, 1].
    x = ZeroPadding2D(padding=3)(x)
    o = Conv2D(filters=3, kernel_size=7, strides=1, padding='valid', activation='tanh', use_bias=False)(x)
    return Model(inputs=[input_image, input_domain], outputs=o)

In [9]:
def discriminator(in_shape=(128,128,3), conv_dim=64, image_size=128, repeat_num=6, c_dim=5):
    """StarGAN discriminator with a real/fake head and a domain-classification head.

    ``repeat_num`` stride-2 4x4 convolutions (each followed by LeakyReLU 0.01)
    shrink the image by 2**repeat_num. Two heads sit on the resulting features:
    - ``out_src``: a 1-channel map of real/fake scores at that resolution;
    - ``out_cls``: ``c_dim`` domain logits, reshaped to ``(c_dim,)``.

    Note: ``in_shape`` is accepted but unused; the input shape comes from
    ``image_size`` (square, 3 channels).
    """
    def _downsample(tensor, filters, with_bias):
        # pad 1 + 4x4 'valid' conv with stride 2 halves height and width
        padded = ZeroPadding2D(padding=1)(tensor)
        conv = Conv2D(filters=filters, kernel_size=(4, 4), strides=(2, 2), padding='valid', use_bias=with_bias)(padded)
        return LeakyReLU(alpha=0.01)(conv)

    image = Input(shape=(image_size, image_size, 3))

    # The first layer has no bias; the remaining ones use the Keras default (bias on).
    features = _downsample(image, conv_dim, False)
    channels = conv_dim
    for _ in range(repeat_num - 1):
        channels *= 2
        features = _downsample(features, channels, True)

    # A kernel as large as the final feature map collapses it to 1x1 for the classifier.
    cls_kernel = image_size // 2 ** repeat_num

    src = ZeroPadding2D(padding=1)(features)
    src = Conv2D(filters=1, kernel_size=(3, 3), strides=(1, 1), padding='valid', use_bias=False)(src)
    cls = Conv2D(filters=c_dim, kernel_size=cls_kernel, strides=(1, 1), padding='valid', use_bias=False)(features)
    cls = Reshape((c_dim,))(cls)
    return Model(image, [src, cls])

In [10]:
# borrowed the code from https://www.programcreek.com/python/example/90401/tensorflow.divide
def classification_loss(Y_true, Y_pred):
    """Mean sigmoid cross-entropy between multi-label domain targets and logits.

    Each attribute is an independent binary label, so a sigmoid (not softmax)
    cross-entropy is applied element-wise and then averaged.
    """
    per_label = tf.nn.sigmoid_cross_entropy_with_logits(labels=Y_true, logits=Y_pred)
    return tf.reduce_mean(per_label)

    
def wasserstein_loss(Y_true, Y_pred):
    """Wasserstein critic loss: mean of the element-wise product of labels and scores.

    Labels are expected to be +/-1 so that minimising the loss pushes the
    critic score down for +1-labelled samples and up for -1-labelled ones.
    """
    signed_scores = Y_true * Y_pred
    return K.mean(signed_scores)

#borrowed code from https://danijar.com/building-variational-auto-encoders-in-tensorflow/
def reconstruction_loss(Y_true, Y_pred):
    """L1 reconstruction loss: mean absolute difference between target and prediction."""
    abs_error = K.abs(Y_pred - Y_true)
    return K.mean(abs_error)

#borrowed code from https://github.com/hoangthang1607/StarGAN-Keras
def gradient_penalty_loss(y_true, y_pred, averaged_samples):
    """WGAN-GP gradient penalty: batch mean of (1 - ||d y_pred / d averaged_samples||_2)^2.

    Args:
        y_true: ignored; present only because Keras loss functions take (y_true, y_pred).
        y_pred: critic output for ``averaged_samples``.
        averaged_samples: the interpolated inputs the gradient is taken with respect to.

    The penalty weight (lambda) is NOT applied here. The caller must scale the
    result, e.g. through ``loss_weights`` at compile time.
    """
    grads = K.gradients(y_pred, averaged_samples)[0]
    # L2 norm per sample: reduce over every axis except the batch axis.
    reduce_axes = numpy.arange(1, len(grads.shape))
    per_sample_norm = K.sqrt(K.sum(K.square(grads), axis=reduce_axes))
    return K.mean(K.square(1 - per_sample_norm))

In [11]:
# Global hyper-parameters (each constant is assigned exactly once).

# Adam optimiser momentum terms
beta_1 = 0.5
beta_2 = 0.999

# Loss weights: domain classification, cycle reconstruction, gradient penalty
lambda_cls = 1
lambda_rec = 10
lambda_gp = 10

# Data
image_size = 128
batch_size = 4
selected_attrs = ['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Male', 'Young']
c_dim = len(selected_attrs)  # one label dimension per selected attribute

# Learning rates
d_lr = 0.0001
g_lr = 0.0001


In [12]:
def _save_weights_atomically(model, directory, filename):
    """Save ``model``'s weights to ``directory/filename`` without leaving a partial file behind.

    Weights are written to a ``tmp_``-prefixed file in the same directory and
    then moved over the destination with ``os.replace``. If the process dies
    mid-write, the previous checkpoint survives instead of a truncated HDF5
    file (which later fails with "OSError: Unable to open file (truncated file ...)").
    """
    final_path = os.path.join(directory, filename)
    tmp_path = os.path.join(directory, 'tmp_' + filename)
    model.save_weights(tmp_path)
    os.replace(tmp_path, final_path)


def _copy_checkpoints(src_dir, dst_dir):
    """Copy every finished checkpoint file in ``src_dir`` into ``dst_dir``; in-progress temp files are skipped."""
    os.makedirs(dst_dir, exist_ok=True)
    for filename in os.listdir(src_dir):
        src_path = os.path.join(src_dir, filename)
        if filename.startswith('tmp_') or not os.path.isfile(src_path):
            continue
        shutil.copy(src_path, dst_dir)


def _append_csv_row(path, row):
    """Append one row to the CSV file at ``path``."""
    with open(path, 'a', newline='') as handle:
        csv.writer(handle).writerow(row)


def build_and_train(previous_model = False, start_iteration = None):
    """Build the StarGAN training models and run the training loop.

    Two training models share the same ``gen``/``dis`` networks:
    - ``train_dis`` updates only the discriminator. Its losses are the
      Wasserstein real/fake loss, the domain-classification loss on real
      images and the gradient penalty on interpolated images.
    - ``train_gen`` updates only the generator. Its losses are the Wasserstein
      loss on fakes, the domain-classification loss on fakes and the L1
      reconstruction loss of the cycle back to the original domain.

    The generator is updated once every ``n_critic`` discriminator steps.

    Args:
        previous_model: if True, load all weights from ``model_previous/`` and
            resume training.
        start_iteration: iteration to start from. Defaults to 50500 when
            resuming (``previous_model=True``) and to 0 for a fresh run.

    Side effects: writes checkpoints to ``model_save/``, appends loss rows to
    ``data.csv`` and prints progress.
    """
    import traceback

    # ---- Hyper-parameters ------------------------------------------------------
    beta_1 = 0.5
    beta_2 = 0.999
    lambda_cls = 1
    lambda_rec = 10
    lambda_gp = 10
    image_size = 128
    batch_size = 4
    selected_attrs = ['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Male', 'Young']
    c_dim = len(selected_attrs)
    d_lr = 0.0001
    g_lr = 0.0001
    mode = 'train'
    model_save_dir = 'model_save/'
    model_load_dir = 'model_previous/'
    log_file = 'data.csv'

    # ---- Schedule --------------------------------------------------------------
    num_iters = 200000        # total training iterations
    num_iters_decay = 100000  # learning rates decay linearly to 0 over the last num_iters_decay iterations
    n_critic = 5              # discriminator steps per generator step
    log_step = 10
    model_save_step = 10000
    time_calculate = 100      # report timing every this many iterations

    if start_iteration is None:
        # Resume point used by this notebook's previous_model=True run; fresh runs start at 0.
        start_iteration = 50500 if previous_model else 0

    os.makedirs(model_save_dir, exist_ok=True)

    gen = generator(image_size=image_size, c_dim=c_dim)
    dis = discriminator(image_size=image_size, c_dim=c_dim)

    # ---- Discriminator training model --------------------------------------------
    # Keras applies `trainable` when a model is compiled: G is frozen for train_dis here,
    # and D is frozen just before train_gen is compiled.
    gen.trainable = False

    x_real = Input(shape=(image_size, image_size, 3))
    out_src_real, out_cls_real = dis(x_real)

    target_label = Input(shape=(c_dim,))
    x_fake = gen([x_real, target_label])
    out_src_fake, _ = dis(x_fake)

    # x_hat: random points on straight lines between real and generated images.
    random_average = RandomWeightedAverage()
    random_average.define_batch_size(batch_size)
    x_hat = random_average([x_real, x_fake])
    out_src_hat, _ = dis(x_hat)

    gp_loss = partial(gradient_penalty_loss, averaged_samples=x_hat)
    gp_loss.__name__ = 'gradient_penalty'  # Keras reads __name__ from every loss function

    train_dis = Model([x_real, target_label], [out_src_real, out_cls_real, out_src_fake, out_src_hat])
    if previous_model:
        train_dis.load_weights(os.path.join(model_load_dir, 'train_dis_weights.hdf5'))
    train_dis.compile(loss=[wasserstein_loss, classification_loss, wasserstein_loss, gp_loss],
                      optimizer=Adam(lr=d_lr, beta_1=beta_1, beta_2=beta_2),
                      loss_weights=[1, lambda_cls, 1, lambda_gp])

    # ---- Generator training model ----------------------------------------------------
    gen.trainable = True
    dis.trainable = False

    real_x = Input(shape=(image_size, image_size, 3))
    original_label = Input(shape=(c_dim,))
    target_label_g = Input(shape=(c_dim,))
    fake_x = gen([real_x, target_label_g])
    fake_out_src, fake_out_cls = dis(fake_x)
    # Translate the fake back to the original domain (cycle reconstruction).
    x_reconst = gen([fake_x, original_label])

    train_gen = Model([real_x, original_label, target_label_g], [fake_out_src, fake_out_cls, x_reconst])
    if previous_model:
        train_gen.load_weights(os.path.join(model_load_dir, 'train_gen_weights.hdf5'))
    train_gen.compile(loss=[wasserstein_loss, classification_loss, reconstruction_loss],
                      optimizer=Adam(lr=g_lr, beta_1=beta_1, beta_2=beta_2),
                      loss_weights=[1, lambda_cls, lambda_rec])

    if previous_model:
        gen.load_weights(os.path.join(model_load_dir, 'G_weights.hdf5'))
        dis.load_weights(os.path.join(model_load_dir, 'D_weights.hdf5'))

    # Columns 2-5 are discriminator losses, columns 6-8 generator losses.
    if not os.path.isfile(log_file):
        with open(log_file, 'w', newline='') as handle:
            csv.writer(handle).writerow(['num_iter','loss_real','loss_fake', 'loss_cls', 'loss_gp', 'loss_fake', 'loss_rec', 'loss_cls'])

    # Targets for D's spatial real/fake output; its shape is read from the graph
    # instead of being hard-coded.
    src_shape = (batch_size,) + tuple(K.int_shape(out_src_real)[1:])
    valid = -np.ones(src_shape)
    fake = np.ones(src_shape)
    dummy = np.zeros(src_shape)  # gradient_penalty_loss ignores y_true

    G_loss = [float('nan')] * 4  # reported until the first generator step has run
    next_iteration = start_iteration
    window_start = time.time()

    # If an iteration raises, it is skipped, the data pipeline is rebuilt and
    # training continues. Finished checkpoints are mirrored into
    # model_load_dir for a later restart. KeyboardInterrupt is not caught.
    while next_iteration < num_iters:
        image_data = ImageData(data_dir='celeba', selected_attrs=selected_attrs)
        image_data.preprocess()  # crop and resize
        data_iter = get_loader(image_data.train_dataset, image_data.train_dataset_label,
                               image_data.train_dataset_fix_label,
                               image_size=image_size, batch_size=batch_size, mode=mode)
        try:
            for iteration in range(next_iteration, num_iters):
                next_iteration = iteration
                imgs, original_labels, target_labels, fix_labels, _ = next(data_iter)

                # Linear decay to 0 over the last num_iters_decay iterations.
                if iteration > num_iters - num_iters_decay:
                    decay = (num_iters - iteration) / num_iters_decay
                    K.set_value(train_dis.optimizer.lr, d_lr * decay)
                    K.set_value(train_gen.optimizer.lr, g_lr * decay)

                D_loss = train_dis.train_on_batch(x=[imgs, target_labels], y=[valid, original_labels, fake, dummy])

                if (iteration + 1) % n_critic == 0:
                    G_loss = train_gen.train_on_batch(x=[imgs, original_labels, target_labels], y=[valid, target_labels, imgs])

                if (iteration + 1) % log_step == 0:
                    print(f"Iteration: [{iteration + 1}/{num_iters}]")
                    print(f"\tD/loss_real = [{D_loss[1]:.4f}], D/loss_fake = [{D_loss[3]:.4f}], D/loss_cls =  [{D_loss[2]:.4f}], D/loss_gp = [{D_loss[4]:.4f}]")
                    print(f"\tG/loss_fake = [{G_loss[1]:.4f}], G/loss_rec = [{G_loss[3]:.4f}], G/loss_cls = [{G_loss[2]:.4f}]")
                    _append_csv_row(log_file, [iteration + 1, D_loss[1], D_loss[3], D_loss[2], D_loss[4],
                                               G_loss[1], G_loss[3], G_loss[2]])

                # Checkpoint only. data_iter is deliberately left untouched here:
                # training must keep drawing from the train split.
                if (iteration + 1) % model_save_step == 0:
                    _save_weights_atomically(gen, model_save_dir, 'G_weights.hdf5')
                    _save_weights_atomically(dis, model_save_dir, 'D_weights.hdf5')
                    _save_weights_atomically(train_dis, model_save_dir, 'train_dis_weights.hdf5')
                    _save_weights_atomically(train_gen, model_save_dir, 'train_gen_weights.hdf5')

                if (iteration + 1) % time_calculate == 0:
                    now = time.time()
                    elapsed = now - window_start
                    remaining_iters = num_iters - (iteration + 1)
                    print("Time per {} iteration is {}".format(time_calculate, elapsed))
                    print("Total Time Remaining is {} minutes".format((remaining_iters / time_calculate) * elapsed / 60))
                    window_start = now
            next_iteration = num_iters  # all iterations done: leave the retry loop
        except Exception:
            traceback.print_exc()
            next_iteration += 1  # skip the iteration that failed
            print(f"Restarting training from iteration {next_iteration}")
            _copy_checkpoints(model_save_dir, model_load_dir)

In [13]:
# Resume training from the weights stored in model_previous/ (those .hdf5 files must be complete, not truncated).
build_and_train(previous_model=True)

  del sys.path[0]
  import sys
  # This is added back by InteractiveShellApp.init_path()


Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where


OSError: Unable to open file (truncated file: eof = 120143872, sblock->base_addr = 0, stored_eof = 212851720)

In [14]:
def test():
    """Translate test images into every selected target domain and save them as PNGs.

    Loads generator weights from ``model_save/G_weights.hdf5``. For
    ``sample_step // batch_size`` test batches, each image is run once per
    fixed target-label vector. Each result is written to ``model_result/`` as
    ``<input-stem>_<k>.png`` with k = 1..len(selected_attrs).
    """
    model_save_dir = 'model_save'
    result_dir = 'model_result'
    sample_step = 1000
    # 'test' mode: per the loader's convention, images are not randomly flipped (p=0.5) as in training.
    mode = 'test'
    image_size = 128
    batch_size = 4
    selected_attrs = ['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Male', 'Young']

    os.makedirs(result_dir, exist_ok=True)

    image_data = ImageData(data_dir='celeba', selected_attrs=selected_attrs)
    image_data.preprocess()

    G = generator()
    G.load_weights(os.path.join(model_save_dir, 'G_weights.hdf5'))

    data_iter = get_loader(image_data.test_dataset, image_data.test_dataset_label, image_data.test_dataset_fix_label,
                           image_size=image_size, batch_size=batch_size, mode=mode)
    n_batches = sample_step // batch_size

    for _ in range(n_batches):
        imgs, _orig_labels, _target_labels, fix_labels, names = next(data_iter)
        for j in range(batch_size):
            # One copy of the image per fixed target-label vector.
            # Assumes fix_labels[j] has len(selected_attrs) rows — TODO confirm against utils.get_loader.
            image_copies = np.repeat(np.expand_dims(imgs[j], axis=0), len(selected_attrs), axis=0)
            preds = G.predict([image_copies, fix_labels[j]])
            stem = os.path.splitext(os.path.basename(names[j]))[0]
            for k in range(len(selected_attrs)):
                # tanh output [-1, 1] -> [0, 255]; clip guards against float rounding outside the uint8 range.
                pixels = np.clip(preds[k] * 127.5 + 127.5, 0, 255).astype(np.uint8)
                Image.fromarray(pixels).save(os.path.join(result_dir, f'{stem}_{k + 1}.png'))