# In [6]:
import os, sys
import pickle
from matplotlib import pyplot as plt
# import python library
sys.path.append(os.path.join(os.getcwd().split(os.environ.get('USER'))[0],os.environ.get('USER'), 'wdml', 'py'))

from dataset import Dataset
from database import Database
from sample import Sample

class NeuralNetwork(Database):
    """Save/load training artefacts of a CNN and visualise the model.

    Artefacts live in ``<database_location>/<site>/models/`` (both parts come
    from the ``Database`` base class) and share the file-name prefix
    ``str(transforms) + '_' + str(transforms_params)``, so runs made with
    different transform settings get different files.
    """

    # Class-level placeholders; not read or written by the methods below.
    __scaler = None
    __model = None
    __history = None

    def __init__(self, dataset_location, database_location, site, transforms, transforms_params):
        """Store the transform settings and initialise the ``Database`` base.

        Param
            dataset_location, database_location, site
                Forwarded unchanged to ``Database.__init__``.
            transforms, transforms_params
                Their ``str()`` forms build the artefact file-name prefix.
        """
        self.__transforms = transforms
        self.__transforms_params = transforms_params
        super().__init__(dataset_location, database_location, site)

    # ------------------------------------------------------------------
    # SAVE AND LOAD MODEL
    # ------------------------------------------------------------------
    def _models_dir(self, create=False):
        """Return ``<database_location>/<site>/models``, optionally creating it."""
        models_dir = os.path.join(self.get_database_location(), self.get_site(), 'models')
        if create:
            os.makedirs(models_dir, exist_ok=True)
        return models_dir

    def _artifact_path(self, suffix, create_dir=False):
        """Return the path of the artefact whose file name ends in *suffix*.

        Param
            suffix
                File-name ending such as ``'_model.h5'``.
            create_dir
                Create the ``models`` directory first if it is missing.
        """
        prefix = str(self.__transforms) + '_' + str(self.__transforms_params)
        return os.path.join(self._models_dir(create=create_dir), prefix + suffix)

    def save_scaler(self, scaler):
        """Pickle the scaler used to scale the training data.

        Stored as ``{"scaler": scaler}`` in ``<prefix>_scaler.pickle``.
        """
        print('Saving scaler ....', end='')
        with open(self._artifact_path('_scaler.pickle', create_dir=True), 'wb') as handle:
            pickle.dump({"scaler": scaler}, handle)
        print('Done.')

    def save_history(self, history):
        """Pickle the per-epoch metrics of a training run.

        Param
            history
                Object with a ``.history`` attribute (presumably the History
                returned by Keras ``Model.fit``). Only ``history.history`` is
                stored, as ``{"history": history.history}``.
        """
        print('Saving history ....', end='')
        with open(self._artifact_path('_history.pickle', create_dir=True), 'wb') as handle:
            pickle.dump({"history": history.history}, handle)
        print('Done.')

    def save_model(self, model):
        """Save the trained model to ``<prefix>_model.h5`` via ``model.save``."""
        print('Saving model ....', end='')
        model.save(self._artifact_path('_model.h5', create_dir=True))
        print('Done.')

    def load_scaler(self):
        """Return the scaler written by ``save_scaler``.

        Only load files you trust: unpickling can execute arbitrary code.
        """
        with open(self._artifact_path('_scaler.pickle'), 'rb') as handle:
            return pickle.load(handle)["scaler"]

    def load_history(self):
        """Return the metrics dict written by ``save_history``.

        Only load files you trust: unpickling can execute arbitrary code.
        """
        with open(self._artifact_path('_history.pickle'), 'rb') as handle:
            return pickle.load(handle)["history"]

    def load_model(self, loader):
        """Return the model written by ``save_model``.

        Param
            loader
                Callable that takes a file path and returns a model, e.g.
                ``keras.models.load_model``. It is injected so this module
                does not have to import a deep-learning framework itself.
        """
        return loader(self._artifact_path('_model.h5'))

    # ------------------------------------------------------------------
    # MODEL DETECTION
    # ------------------------------------------------------------------
    # (no methods yet)

    # ------------------------------------------------------------------
    # MODEL VISUALISATION
    # ------------------------------------------------------------------
    def get_conv_layers(self, model):
        """Return the layers of *model* whose name contains ``'conv'``.

        Selection is by layer name only, in ``model.layers`` order.
        """
        return [layer for layer in model.layers if 'conv' in layer.name]

    @staticmethod
    def _normalise(array):
        """Min-max scale *array* to [0, 1]; a constant array maps to zeros."""
        a_min, a_max = array.min(), array.max()
        span = a_max - a_min
        if span == 0:
            # Avoid 0/0 -> NaN, which would render as blank images.
            return array - a_min
        return (array - a_min) / span

    @staticmethod
    def _plot_channels(maps):
        """Draw every slice ``maps[:, :, k]`` side by side in a single row."""
        count = maps.shape[-1]
        plt.figure()
        for k in range(count):
            ax = plt.subplot(1, count, k + 1)
            ax.set_xticks([])
            ax.set_yticks([])
            plt.imshow(maps[:, :, k], cmap='jet')
        plt.show()

    def visualization(self, model, summary=False, plot=False, conv_layer_number=0):
        '''Visualize the CNN model.

        Param
            model
                Keras-style model (uses ``summary()`` and ``layers``).
            summary
                Print the layer summary.
            plot
                Render the architecture with ``plot_model`` to
                ``<database_location>/<site>/models/model.png`` and show it.
            conv_layer_number
                Index into ``get_conv_layers(model)`` of the layer whose
                filters are drawn.
        '''
        if summary:
            # Keras' summary() prints by itself and returns None; wrapping it
            # in print() would emit a stray "None".
            model.summary()
        if plot:
            path = os.path.join(self._models_dir(create=True), 'model.png')
            # NOTE(review): plot_model is not imported in this cell; it is
            # assumed to come from the notebook namespace (e.g. keras.utils).
            plot_model(model, show_shapes=True, show_layer_names=True, to_file=path)
            # An IPython Image built inside a method is never displayed, so
            # show the rendered PNG with matplotlib instead.
            plt.figure()
            plt.imshow(plt.imread(path))
            plt.axis('off')
            plt.show()
        # get_weights() is [kernel] or [kernel, bias]; only the kernel is used.
        filters = self.get_conv_layers(model)[conv_layer_number].get_weights()[0]
        filters = self._normalise(filters)
        # assumes Conv2D kernels shaped (kh, kw, in_channels, n_filters) — TODO
        # confirm; only the first input channel of each filter is drawn.
        self._plot_channels(filters[:, :, 0, :])

    def visualization_fm(self, model, train_X, conv_layer_number=0, sample_index=None):
        """Plot the feature maps one convolutional layer produces for a sample.

        Param
            model
                Keras-style model whose ``layers`` are called one after another.
            train_X
                Indexable collection of input samples.
            conv_layer_number
                Index into ``get_conv_layers(model)`` of the layer whose
                output is shown.
            sample_index
                Index of the sample to use; a random one when None.
        """
        import numpy as np

        target = self.get_conv_layers(model)[conv_layer_number]
        if sample_index is None:
            sample_index = np.random.randint(len(train_X))
        # Layers expect a batch, so prepend a batch axis of size 1.
        activations = np.asarray(train_X[sample_index], dtype='float32')[np.newaxis, ...]
        # NOTE(review): assumes the layers form a single chain from input to
        # output (Sequential-style) — confirm for branched functional models.
        for layer in model.layers:
            if type(layer).__name__ == 'InputLayer':
                continue
            activations = layer(activations)
            if layer is target:
                break
        # Drop the batch axis; assumes output shaped (1, H, W, n_filters).
        feature_maps = np.asarray(activations)[0]
        self._plot_channels(feature_maps)

    def show_history(self, history):
        """Plot training (and, if present, validation) accuracy and loss.

        Param
            history
                Dict of per-epoch metric lists keyed by ``'accuracy'`` and
                ``'loss'`` (optionally ``'val_accuracy'`` / ``'val_loss'``),
                as returned by ``load_history``. An object exposing such a
                dict as ``.history`` is accepted too.
        """
        metrics = getattr(history, 'history', history)
        for key in ('accuracy', 'loss'):
            plt.figure()
            plt.plot(metrics[key])
            labels = ['train']
            # Validation curves only exist when fitting used validation data.
            if 'val_' + key in metrics:
                plt.plot(metrics['val_' + key])
                labels.append('val')
            plt.title('model ' + key)
            plt.ylabel(key)
            plt.xlabel('epoch')
            plt.legend(labels, loc='upper left')
            plt.show()