# In[ ]:
import os

import keras.backend as K
from keras.models import Model, Sequential
from keras.layers import Conv2D, MaxPool2D, Flatten, Dense, Input, Subtract, Lambda
from keras.optimizers import Adam, SGD
from keras.regularizers import l2
import keras.backend as K
import numpy.random as rng

import tensorflow as tf

import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import numpy as np


import import_ipynb
import lichensloader 

# In[ ]:

class SiameseNetwork:
    """
    Siamese convolutional network together with its training loop.

    Two images are encoded by the same convolutional tower, the element-wise
    absolute difference of the two encodings is taken, and a single sigmoid
    unit produces the final output. The model is compiled with
    ``contrastive_loss`` and an Adam optimizer.

    Attributes:
        input_shape: shape of one input image, fixed to (400, 400, 3)
        model: compiled Keras siamese model
        learning_rate: learning rate passed to the constructor.
            NOTE(review): this value is only stored; the optimizer is built
            with a hard-coded Adam learning rate (see
            ``_construct_siamese_architecture``) -- confirm that is intended.
        lichen_loader: ``lichensloader.Lichensloader`` instance that supplies
            training batches and runs the one-shot evaluation
    """

    def __init__(self, dataset_path, learning_rate, batch_size, use_augmentation,
                 learning_rate_multipliers, l2_regularization_penalization):
        """
        Build the model and the data loader.

        Arguments:
            dataset_path: forwarded to ``Lichensloader``
            learning_rate: stored in ``self.learning_rate``
            batch_size: forwarded to ``Lichensloader``
            use_augmentation: forwarded to ``Lichensloader``
            learning_rate_multipliers: forwarded to
                ``_construct_siamese_architecture`` (currently unused there)
            l2_regularization_penalization: mapping from layer name
                ('Conv1'..'Conv4', 'Dense1') to its L2 penalty factor
        """
        self.input_shape = (400, 400, 3)
        self.model = None
        self.learning_rate = learning_rate

        # The model is built before the loader, as in the original ordering.
        self._construct_siamese_architecture(
            learning_rate_multipliers, l2_regularization_penalization)

        self.lichen_loader = lichensloader.Lichensloader(
            dataset_path=dataset_path, use_augmentation=use_augmentation,
            batch_size=batch_size)

    def contrastive_loss(self, y_true, y_pred):
        '''Contrastive loss from Hadsell-et-al.'06
        http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf

        With a margin of 1, pairs labelled 1 are penalised by the squared
        prediction and pairs labelled 0 by the squared hinge
        ``max(margin - y_pred, 0)``; the mean over all elements is returned.
        '''
        margin = 1
        squared_pred = K.square(y_pred)
        margin_square = K.square(K.maximum(margin - y_pred, 0))
        return K.mean(y_true * squared_pred + (1 - y_true) * margin_square)

    def _construct_siamese_architecture(self, learning_rate_multipliers,
                                        l2_regularization_penalization):
        """
        Construct the siamese network architecture and store it in
        ``self.model``.

        Arguments:
            learning_rate_multipliers: accepted for interface compatibility;
                not used by this implementation
            l2_regularization_penalization: mapping from layer name
                ('Conv1', 'Conv2', 'Conv3', 'Conv4', 'Dense1') to the L2
                penalty applied to that layer's kernel
        """
        # (layer name, number of filters, kernel size) of each conv stage.
        conv_specs = (
            ('Conv1', 64, (10, 10)),
            ('Conv2', 128, (7, 7)),
            ('Conv3', 128, (4, 4)),
            ('Conv4', 256, (4, 4)),
        )

        convolutional_net = Sequential()
        for index, (layer_name, filters, kernel_size) in enumerate(conv_specs):
            # Only the first layer of a Sequential model needs input_shape;
            # later layers infer their input from the previous layer.
            extra_kwargs = {'input_shape': self.input_shape} if index == 0 else {}
            convolutional_net.add(Conv2D(
                filters=filters, kernel_size=kernel_size, activation='relu',
                kernel_regularizer=l2(l2_regularization_penalization[layer_name]),
                name=layer_name, **extra_kwargs))
            convolutional_net.add(MaxPool2D())

        convolutional_net.add(Flatten())
        convolutional_net.add(Dense(
            units=4096, activation='sigmoid',
            kernel_regularizer=l2(l2_regularization_penalization['Dense1']),
            name='Dense1'))

        # Both inputs go through the same tower, so the weights are shared.
        input_image_1 = Input(self.input_shape)
        input_image_2 = Input(self.input_shape)
        encoded_image_1 = convolutional_net(input_image_1)
        encoded_image_2 = convolutional_net(input_image_2)

        # Element-wise absolute difference between the two encodings.
        l1_distance_layer = Lambda(
            lambda tensors: K.abs(tensors[0] - tensors[1]))
        l1_distance = l1_distance_layer([encoded_image_1, encoded_image_2])

        # Prediction layer
        prediction = Dense(units=1, activation='sigmoid')(l1_distance)

        self.model = Model(inputs=[input_image_1, input_image_2],
                           outputs=prediction)

        # NOTE(review): the learning rate is hard-coded here and the
        # constructor's learning_rate argument is not used -- confirm intent.
        optimizer = Adam(lr=0.00006)

        self.model.compile(loss=self.contrastive_loss,
                           metrics=['binary_accuracy'],
                           optimizer=optimizer)

    def _save_model(self, model_name):
        """Write the architecture (JSON) and weights (HDF5) under ./models."""
        os.makedirs('./models', exist_ok=True)
        with open('models/' + model_name + '.json', "w") as json_file:
            json_file.write(self.model.to_json())
        self.model.save_weights('models/' + model_name + '.h5')

    def _update_optimizer(self, iteration, final_momentum, momentum_slope):
        """Apply the learning-rate decay and, if supported, the momentum ramp."""
        optimizer = self.model.optimizer

        # Decay the learning rate by 1 % every 500 iterations.
        if (iteration + 1) % 500 == 0:
            K.set_value(optimizer.lr, K.get_value(optimizer.lr) * 0.99)

        # Not every optimizer has a momentum variable (Adam, which this class
        # compiles with, does not), so only ramp it when it exists.
        momentum = getattr(optimizer, 'momentum', None)
        if momentum is not None and K.get_value(momentum) < final_momentum:
            K.set_value(momentum, K.get_value(momentum) + momentum_slope)

    def train_siamese_network(self, number_of_iterations, support_set_size,
                              final_momentum, momentum_slope, evaluate_each,
                              model_name):
        """ Train the Siamese net

        Runs up to ``number_of_iterations`` single-batch training steps.
        Every ``evaluate_each`` iterations a one-shot validation test is run
        through the loader, and the model is saved to ``./models`` whenever
        the validation accuracy improves on the best value seen so far.

        Arguments:
            number_of_iterations: maximum number of iterations to train.
            support_set_size: forwarded to the loader's ``one_shot_test``.
            final_momentum: upper bound for the optimizer momentum; only
                applied when the optimizer exposes a ``momentum`` variable.
            momentum_slope: amount added to the momentum at each iteration
                while it is below ``final_momentum``.
            evaluate_each: number of iterations between two validation runs.
            model_name: base file name (without extension) used when saving.
        Returns:
            The best validation accuracy reached, or 0 when training is
            aborted by one of the degenerate-training checks (validation
            accuracy of exactly 1.0 with train accuracy of exactly 0.5, or
            train accuracy of exactly 0.0 at an evaluation step).
        """
        # Let the loader split its data into training and validation parts.
        self.lichen_loader.split_train_datasets()

        # Per-iteration losses and accuracies since the last evaluation.
        # They are only buffered: logging of these values is currently
        # disabled, the buffers are kept so it can be re-enabled.
        train_losses = np.zeros(shape=(evaluate_each))
        train_accuracies = np.zeros(shape=(evaluate_each))
        count = 0

        # Stop criteria variables
        best_validation_accuracy = 0.0
        best_accuracy_iteration = 0
        validation_accuracy = 0.0

        # Number of runs requested from each one-shot validation test.
        number_of_runs_per_alphabet = 40

        # Train loop
        for iteration in range(number_of_iterations):
            print("----> ", iteration)

            # train set
            images, labels = self.lichen_loader.create_pairs_for_batch()
            train_loss, train_accuracy = self.model.train_on_batch(
                images, labels)

            self._update_optimizer(iteration, final_momentum, momentum_slope)

            train_losses[count] = train_loss
            train_accuracies[count] = train_accuracy
            count += 1

            print('Iteration %d/%d: Train loss: %f, Train Accuracy: %f, lr = %f' %
                  (iteration + 1, number_of_iterations, train_loss, train_accuracy, K.get_value(
                      self.model.optimizer.lr)))

            # Periodically run the one-shot validation test.
            if (iteration + 1) % evaluate_each == 0:
                print("******", number_of_runs_per_alphabet)
                validation_accuracy = self.lichen_loader.one_shot_test(
                    self.model, support_set_size, number_of_runs_per_alphabet,
                    is_validation=True)
                count = 0

                # Some hyperparameters lead to 100%, although the output is
                # almost the same in all images.
                if validation_accuracy == 1.0 and train_accuracy == 0.5:
                    print('Early Stopping: Gradient Explosion')
                    print('Validation Accuracy = ' +
                          str(best_validation_accuracy))
                    return 0
                if train_accuracy == 0.0:
                    return 0

                # Save the model whenever validation accuracy improves.
                if validation_accuracy > best_validation_accuracy:
                    best_validation_accuracy = validation_accuracy
                    best_accuracy_iteration = iteration
                    self._save_model(model_name)

            # If accuracy does not improve for 10000 batches stop the training
            if iteration - best_accuracy_iteration > 10000:
                print(
                    'Early Stopping: validation accuracy did not increase for 10000 iterations')
                print('Best Validation Accuracy = ' +
                      str(best_validation_accuracy))
                print('Validation Accuracy = ' + str(validation_accuracy))
                break

        print('Trained Ended!')
        return best_validation_accuracy
        
        
        
    
    
    
    
    
    
    
    
    
    
    
    
    
    
    
    
    
    
        
        