In [None]:
import sys
import os
import shutil
sys.path.insert(0,'../..')

from AutoGAN import GAN
from AutoGAN.schemes.CycleGAN_TrainingScheme import CycleWGAN_TrainingScheme, CycleGAN_TrainingScheme
from AutoGAN.schemes.SimGAN_TrainingScheme import SimGAN_TrainingScheme

import keras

from keras.layers import Input, Dense, Reshape, Flatten, Dropout, Concatenate
from keras.layers import Conv1D, Conv2D, MaxPooling2D, GlobalMaxPooling2D, GlobalAveragePooling2D
from keras.layers import UpSampling2D, LeakyReLU, Lambda, Add, Multiply, Activation, Conv2DTranspose
from keras.layers import Cropping2D, ZeroPadding2D, Flatten, Subtract, Input, add, multiply
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam 

import matplotlib.pyplot as plt
import numpy as np
import scipy
from skimage.transform import resize
import glob
from random import shuffle

import tensorflow as tf
from keras.backend.tensorflow_backend import set_session
# Build the TensorFlow session configuration that Keras will run on.
config = tf.ConfigProto()
# Upper bound on the fraction of the GPU's memory this process may claim.
config.gpu_options.per_process_gpu_memory_fraction = 0.95
# Expose only the GPU with index 1 to this process.
config.gpu_options.visible_device_list = "1"
# Allocate GPU memory on demand instead of reserving it all up front.
config.gpu_options.allow_growth = True
# Install a session using this configuration as the Keras backend session.
set_session(tf.Session(config=config))

def build_generator(size):
    """
    Build the refiner (generator) network as a small residual CNN.

    The network is a 3 x 3 convolution with 64 feature maps, followed by four residual
    blocks and a final 1 x 1 convolution with `tanh` activation that produces a single
    feature map. Every convolution uses `padding='same'` with the default stride, so the
    spatial size of the output equals that of the input.
    See https://arxiv.org/pdf/1612.07828v1.pdf for the architecture this follows.
    :param size: Shape of a single input image, forwarded to `Input(shape=...)`.
    :return: Keras `Model` mapping an input image to its refined, single-channel version.
    """
    def resnet_block(input_features, nb_features=64, nb_kernel_rows=3, nb_kernel_cols=3):
        """
        Residual block: conv -> relu -> conv, summed with the block input, then relu.
        See Figure 6 in https://arxiv.org/pdf/1612.07828v1.pdf.
        :param input_features: Input tensor; it is added to the convolution output, so it
           must carry `nb_features` channels.
        :return: Output tensor of the block.
        """
        kernel = (nb_kernel_rows, nb_kernel_cols)
        residual = Conv2D(nb_features, kernel, padding='same')(input_features)
        residual = Activation('relu')(residual)
        residual = Conv2D(nb_features, kernel, padding='same')(residual)

        # skip connection: sum the block input with the convolved path
        merged = add([input_features, residual])
        return Activation('relu')(merged)

    # stem: 3 x 3 filters producing 64 feature maps
    image_in = Input(shape=size)
    features = Conv2D(64, (3, 3), padding='same', activation='relu')(image_in)

    # four residual blocks applied back to back
    nb_blocks = 4
    for _block in range(nb_blocks):
        features = resnet_block(features)

    # 1 x 1 convolution collapsing the 64 feature maps into the refined image
    refined = Conv2D(1, (1, 1), padding='same', activation='tanh')(features)
    return Model(image_in, refined)


def build_discriminator(size):
    """
    Build the discriminator network: 5 convolution layers and 1 max-pooling layer.

    The first two convolutions use a stride of 2; all remaining layers use a stride of 1.
    The last layer is a 1 x 1 convolution with a sigmoid activation and one feature map, so
    the model outputs a map of per-location `is_real` scores rather than a single scalar.
    :param size: Shape of a single input image, forwarded to `Input(shape=...)`.
    :return: Keras `Model` mapping an image (real or refined) to its `is_real` score map.
    """
    image_in = Input(shape=size)

    # layers are listed in the order they are applied
    layer_stack = [
        Conv2D(96, (3, 3), padding='same', strides=2, activation='relu'),
        Conv2D(64, (3, 3), padding='same', strides=2, activation='relu'),
        MaxPooling2D(pool_size=(3, 3), padding='same', strides=(1, 1)),
        Conv2D(32, (3, 3), padding='same', strides=1, activation='relu'),
        Conv2D(32, (1, 1), padding='same', strides=1, activation='relu'),
        # here one feature map corresponds to `is_real`
        Conv2D(1, (1, 1), padding='same', strides=1, activation='sigmoid'),
    ]

    tensor = image_in
    for layer in layer_stack:
        tensor = layer(tensor)
    return Model(image_in, tensor)


def butchered_mp_normalized_matlab_helper(mat_file_path):
    """
    Export the eye images stored in one MPIIGaze "normalized" matlab file as PNG files.

    Normalized data is provided in matlab files in MPIIGaze Dataset and these are tricky to
    load with Python; the indexing below was found by trial and error.
    Every left-eye and right-eye image is written to `./RealGaze_data/<random uuid>.png`.
    The output directory is created if it does not exist yet.
    :param mat_file_path: Full path to MPIIGaze Dataset matlab file.
    :return: None. The images are only written to disk.
    """
    import os
    import uuid

    import numpy as np
    from PIL import Image
    import scipy.io as sio

    output_dir = './RealGaze_data'
    # Image.save does not create missing directories, so make sure the target exists.
    if not os.path.isdir(output_dir):
        os.makedirs(output_dir)

    # the 'data' entry is a 1 x 1 struct array; [0, 0] unwraps the struct itself
    data = sio.loadmat(mat_file_path).get('data')
    record = data[0, 0]

    # the 'image' field of each eye is again wrapped in a 1 x 1 array
    left_imgs = record['left']['image'][0, 0]
    right_imgs = record['right']['image'][0, 0]

    for img in np.concatenate((left_imgs, right_imgs)):
        # a random uuid keeps file names unique across repeated calls
        Image.fromarray(img).save(os.path.join(output_dir, '{}.png'.format(uuid.uuid4())))


def plot_batch(image_batch, figure_path, label_batch=None, vmin=0, vmax=255, scale=True):
    """
    Plots a batch of images on a 3-column grid and saves the figure to disk.

    Images fill the grid column by column: with `3 * k` images the first `k` go into the
    left column, the next `k` into the middle column and the last `k` into the right one.
    When the batch size is not a multiple of 3 the trailing grid cells are left empty.
    Each image is displayed as `(image + 1) / 2` with a gray colormap, which maps data in
    [-1, 1] (the range returned by `load_data_h5`) to [0, 1].
    Note that this updates the global matplotlib setting `axes.titlesize`.
    :param image_batch: np array holding the batch of images to be plotted. A trailing
       axis of length 1 (gray scale channel) is dropped before plotting.
    :param figure_path: Full path of the filename the plot will be saved as.
    :param label_batch: Batch of labels corresponding to `image_batch`. Only the label of
       the first image in each column is displayed, as that column's title.
    :param vmin: Currently unused; kept so existing callers keep working.
    :param vmax: Currently unused; kept so existing callers keep working.
    :param scale: Currently unused; kept so existing callers keep working.
    :raises AssertionError: If the batch is empty, is not an np array, or the number of
       labels differs from the number of images.
    """
    if label_batch is not None:
        assert len(image_batch) == len(label_batch), 'Their must be a label for each image to be plotted.'

    batch_size = len(image_batch)
    assert batch_size >= 1

    assert isinstance(image_batch, np.ndarray), 'image_batch must be an np array.'

    # for gray scale images with a trailing channel axis of length 1, plt requires that axis to be removed
    if image_batch.shape[-1] == 1:
        image_batch = image_batch.reshape(image_batch.shape[:-1])

    # ceil division, so small batches still get one row and leftover images are not dropped
    nb_columns = 3
    nb_rows = -(-batch_size // nb_columns)

    import matplotlib
    matplotlib.rcParams.update({'axes.titlesize': 10})
    # squeeze=False keeps `axs` two-dimensional even when there is a single row
    _, axs = plt.subplots(nb_rows, nb_columns, figsize=(5, 10), squeeze=False)

    for cnt in range(nb_rows * nb_columns):
        # column-major layout: consecutive images go down a column before moving right
        column, row = divmod(cnt, nb_rows)
        ax = axs[row, column]
        ax.axis('off')
        if cnt >= batch_size:
            # no image left for this cell; leave it blank
            continue
        ax.imshow((image_batch[cnt] + 1.) / 2., cmap='gray')
        if label_batch is not None and row == 0:
            ax.set_title(label_batch[cnt])

    plt.savefig(figure_path)
    plt.close()

def _load_image_stack(h5_path, required_keys):
    """
    Read every dataset of the 'image' group of an HDF5 file into a single np array.

    Each image gets a trailing channel axis of length 1 before stacking, so for 2-D gray
    scale images the result has shape (nb_images, height, width, 1).
    :param h5_path: Path of the HDF5 file to read.
    :param required_keys: Iterable of `(key, message)` pairs; every key must be present at
       the top level of the file, otherwise an AssertionError with `message` is raised.
    :return: np array stacking all images along a new first axis.
    """
    import h5py
    with h5py.File(h5_path, 'r') as h5_file:
        for key, message in required_keys:
            assert key in h5_file, message
        # the data has to be copied out while the file is still open
        return np.stack([np.expand_dims(img, -1) for img in h5_file['image'].values()], 0)


def load_data_h5():
    """
    Load the synthetic (`./gaze.h5`) and real (`./real_gaze.h5`) eye images.

    The shape of each stack is printed after loading. Both sets are shuffled independently
    with `random.shuffle` and rescaled from [0, 255] to [-1, 1].
    :return: Tuple `(A, B)` of np arrays: shuffled, rescaled synthetic and real images.
    :raises AssertionError: If an expected top-level group is missing from either file.
    """
    syn_image_stack = _load_image_stack('./gaze.h5', (
        ('image', "Images are missing"),
        ('look_vec', "Look vector is missing"),
        ('path', "Paths are missing"),
    ))
    print(syn_image_stack.shape)

    real_image_stack = _load_image_stack('./real_gaze.h5', (
        ('image', "Images are missing"),
    ))
    print(real_image_stack.shape)

    # shuffle index lists rather than the arrays so each set gets its own permutation
    A_list = list(range(syn_image_stack.shape[0]))
    B_list = list(range(real_image_stack.shape[0]))
    shuffle(A_list)
    shuffle(B_list)
    A = syn_image_stack[A_list]
    B = real_image_stack[B_list]
    return (2. * A/255.) - 1., (2. * B/255.) - 1.

class save_images(keras.callbacks.Callback):
    """
    Keras callback that writes synthetic / refined / real comparison figures to disk.

    Figures go to `images/<dataset>/`. That directory is wiped and recreated when the
    callback is constructed.
    :param model: Object exposing `generator_model()`; the returned model is used through
       `predict_on_batch` to refine the synthetic images.
    :param A: Synthetic images, indexed as `A[i, :, :, 0]`.
    :param B: Real images, indexed the same way.
    :param freq: A figure is written after every batch whose index is a multiple of `freq`.
    :param dataset: Name of the sub-directory of `images/` that receives the figures.
    """

    # number of images shown per column (synthetic / refined / real) of one figure
    SAMPLES_PER_FIGURE = 8

    def __init__(self, model, A, B, freq, dataset):
        super(save_images, self).__init__()
        output_dir = 'images/%s' % dataset
        # start from an empty directory; a directory that does not exist yet is fine
        shutil.rmtree(output_dir, ignore_errors=True)
        if not os.path.isdir(output_dir):
            os.makedirs(output_dir)
        # stored as `full_model` because Keras assigns `self.model` itself via `set_model`
        self.full_model = model
        self.A = A
        self.B = B
        self.epoch = 0
        self.freq = freq
        self.dataset = dataset

    def _save_figure(self, synthetic, refined, real, file_name):
        """Plot three equally long image batches as labelled columns of one figure."""
        gen_imgs = np.concatenate([synthetic[:, :, :, 0], refined[:, :, :, 0], real[:, :, :, 0]])
        gen_labels = ["Synthetic"] * len(synthetic) + ["Refined"] * len(refined) + ["Real"] * len(real)
        plot_batch(gen_imgs, "images/%s/%s" % (self.dataset, file_name), gen_labels)

    def on_epoch_begin(self, epoch, logs=None):
        # remembered so that on_batch_end can put the epoch into the file name
        self.epoch = epoch

    def on_batch_end(self, batch, logs=None):
        if batch % self.freq != 0:
            return
        count = min(self.SAMPLES_PER_FIGURE, len(self.A), len(self.B))
        if count == 0:
            return
        # only the plotted samples are refined; predicting on all of `self.A` is wasted work
        synthetic = self.A[:count]
        refined = self.full_model.generator_model().predict_on_batch(synthetic)
        self._save_figure(synthetic, refined, self.B[:count],
                          "sim_%d_%d.png" % (self.epoch, batch))

    def on_train_end(self, logs=None):
        step = self.SAMPLES_PER_FIGURE
        # iterate over the images held by this callback, not over a notebook-level global
        total = min(len(self.A), len(self.B))
        if total < step:
            return
        # the prediction does not depend on the loop variable, so compute it once
        preds = self.full_model.generator_model().predict_on_batch(self.A)
        # `+ 1` so the last complete group of `step` images gets a figure as well
        for i in range(0, total - step + 1, step):
            self._save_figure(self.A[i:i + step], preds[i:i + step], self.B[i:i + step],
                              "sim_final_%d.png" % i)

def simgan(image_A, image_B):
    """
    Assemble and compile the SimGAN model.

    The generator and the discriminator are both built for the shape of `image_A`.
    :param image_A: A single sample; only its `.shape` is used.
    :param image_B: Not used by this function.
    :return: The compiled `GAN` instance.
    """
    input_shape = image_A.shape
    refiner = build_generator(input_shape)
    critic = build_discriminator(input_shape)
    gan = GAN(generator=refiner, discriminator=critic)

    # the same SGD instance is handed to both compile configurations
    sgd = keras.optimizers.SGD(lr=1e-3)
    discriminator_kwargs = dict(
        loss='binary_crossentropy',
        metrics=['accuracy'],
        optimizer=sgd,
    )
    generator_kwargs = dict(
        generator_loss='mae',
        optimizer=sgd,
        discriminator_loss='binary_crossentropy',
        discriminator_loss_weight=1,
        generator_loss_weight=1,
    )

    # NOTE(review): the meaning of 100 * 512 is defined by SimGAN_TrainingScheme and is not
    # visible here; presumably a buffer size tied to the batch size of 512 - confirm.
    scheme = SimGAN_TrainingScheme(100 * 512)
    gan.compile(training_scheme=scheme,
                generator_kwargs=generator_kwargs,
                discriminator_kwargs=discriminator_kwargs)
    return gan

      
            
class pretrain_model(keras.callbacks.Callback):
    """
    Keras callback that pretrains a model right before the main training run starts.

    In `on_train_begin` the wrapped model is compiled with the given loss, metrics and
    optimizer and then fitted on `(x, y)`, holding out 20% of the data for validation.
    NOTE(review): this recompiles `my_model`; whether the surrounding GAN compiles it
    again afterwards is not visible here - confirm.
    :param my_model: Model to pretrain; must provide `compile` and `fit`.
    :param x: Pretraining inputs.
    :param y: Pretraining targets.
    :param epochs: Number of pretraining epochs.
    :param batch_size: Batch size used for pretraining.
    :param loss: Loss passed to `compile`.
    :param metrics: Metrics passed to `compile`.
    :param optimizer: Optimizer passed to `compile`.
    """

    def __init__(self, my_model, x, y, epochs, batch_size, loss, metrics, optimizer):
        # initialise the base Callback state, like the other callback in this file does
        super(pretrain_model, self).__init__()
        # stored as `my_model` because Keras assigns `self.model` itself via `set_model`
        self.my_model = my_model
        self.x, self.y = x, y
        self.batch_size = batch_size
        self.epochs = epochs
        self.loss = loss
        self.metrics = metrics
        self.optimizer = optimizer

    def on_train_begin(self, logs=None):
        self.my_model.compile(loss=self.loss, metrics=self.metrics, optimizer=self.optimizer)
        self.my_model.fit(self.x, self.y, epochs=self.epochs, batch_size=self.batch_size,
                          verbose=1, shuffle=True, validation_split=0.2)



In [None]:
# Load the synthetic (A) and real (B) image sets; load_data_h5 already shuffles and rescales them.
A, B = load_data_h5()
# Hold out the last 1000 samples of each set; they are only used by the image-saving callback.
A_test, B_test = A[-1000:], B[-1000:]
A, B = A[:-1000], B[:-1000]

# A single sample of each set is passed; simgan reads the image shape from the first one.
model = simgan(A[0], B[0])
#model.summary()
%matplotlib inline
# Output shape of the discriminator without its leading axis, used to build targets of matching shape.
out_shape = list(model.discriminator_model().output_shape[1:])

# Pretrain the generator to reproduce its own input (x and y are the same 4096 synthetic images).
pre_train_gen = pretrain_model(model.generator_model(), A[:4096], A[:4096], 5, 64, 'mae', ['mae'], keras.optimizers.Adam(lr=0.0001))
# Pretrain the discriminator on 4096 synthetic images (target 0) followed by 4096 real images (target 1).
pre_train_dis = pretrain_model(model.discriminator_model(), np.concatenate([A[:4096], B[:4096]]),
                               np.concatenate([np.zeros([4096]+out_shape), np.ones([4096]+out_shape)]), 
                               5, 32, 'binary_crossentropy', ['accuracy'], keras.optimizers.Adam(lr=0.0001))
# NOTE(review): both pretraining callbacks and the image dump are passed as generator_callbacks;
# how AutoGAN's fit dispatches them is not visible here - confirm.
model.fit(x=A, y=B, epochs=10, steps_per_epoch=1000, batch_size=512,
          generator_callbacks=[pre_train_gen, pre_train_dis, save_images(model, A_test, B_test, 100,'simeyes')], verbose=1)
