In [12]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf

import os
import sys

from tensorflow import keras

sys.path.append('..')
import model
import utils

%matplotlib inline
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
# Load the shared config, then re-root its data/model directories one level up
# (relative to this notebook's working directory) and override the batch size
# used by the iterators built below.
params = utils.yaml_to_dict(os.path.join('..', 'config.yml'))
for dir_key in ('data_dir', 'model_dir'):
    params[dir_key] = os.path.join('..', params[dir_key])
params['batch_size'] = 10

In [3]:
def make_datagenerator(params, mode, shuffle=False):
    """Build a Keras image iterator over the images listed in ``<data_dir>/<mode>.txt``.

    The listing is read as a tab-separated file with a header row; its
    ``images`` column is used as ``x_col``. Pixels are rescaled by 1/255 and no
    labels are yielded (``class_mode=None``), so the iterator is for inference.

    Args:
        params: config dict; reads ``data_dir``, ``image_shape`` (passed as the
            Keras ``target_size``) and ``batch_size``.
        mode: split name (e.g. ``'test'``) selecting ``<mode>.txt``.
        shuffle: whether to shuffle image order. Defaults to False so that the
            order of predictions matches the row order of the listing file.

    Returns:
        The iterator returned by ``ImageDataGenerator.flow_from_dataframe``.
    """
    data_path = os.path.join(params['data_dir'], mode + '.txt')
    df = pd.read_csv(data_path, sep="\t", header=0)

    datagen = keras.preprocessing.image.ImageDataGenerator(rescale=1. / 255)
    # flow_from_dataframe shuffles by default. With class_mode=None no labels
    # travel with the images, so a shuffled order could not be mapped back to
    # the rows of `df` after prediction.
    return datagen.flow_from_dataframe(
        dataframe=df,
        x_col='images',
        target_size=params['image_shape'],
        batch_size=params['batch_size'],
        class_mode=None,
        shuffle=shuffle)

In [33]:
def load_model(params):
    """Instantiate ``model.ModelArchitecture`` and build its variables.

    The network is called once on a symbolic ``Input`` so that its layers are
    built before weights are loaded into it. The returned model is neither
    compiled nor loaded with weights.

    Args:
        params: config dict; reads ``image_shape`` and ``num_classes``.

    Returns:
        The built ``model.ModelArchitecture`` instance.
    """
    # Keras image tensors are (height, width, channels). `image_shape` is also
    # used as flow_from_dataframe's `target_size`, which is (height, width),
    # so the names below reflect that order (the shape itself is unchanged).
    height, width = params['image_shape']
    # 3 channels: flow_from_dataframe's default color_mode is 'rgb'.
    inputs = tf.keras.layers.Input(shape=(height, width, 3))
    net = model.ModelArchitecture(num_classes=params['num_classes'])
    # Called only for its side effect of building the network; output unused.
    net(inputs, training=False)
    return net
    

def make_predictions(generator, params):
    """Predict a class id for every image yielded by ``generator``.

    Builds the network via ``load_model``, compiles it, restores weights from
    the ``tf_ckpt`` checkpoint in ``params['model_dir']`` and runs inference
    over the whole generator.

    Args:
        generator: Keras image iterator yielding input batches only (e.g. one
            built by ``make_datagenerator``).
        params: config dict; reads ``learning_rate``, ``loss`` and ``model_dir``
            in addition to the keys ``load_model`` reads.

    Returns:
        ``np.ndarray`` of arg-max class indices, one per image, in the order
        the generator yields them.
    """
    net = load_model(params)

    # `learning_rate` is the supported kwarg; `lr` is deprecated and rejected
    # by newer Keras releases.
    # NOTE(review): compile is not required for `predict`; likely removable.
    optimizer = tf.keras.optimizers.Adam(learning_rate=params['learning_rate'])
    net.compile(optimizer=optimizer, loss=params['loss'],
                metrics=['sparse_categorical_accuracy'])
    net.load_weights(os.path.join(params['model_dir'], 'tf_ckpt'))

    # `Model.predict` accepts generators / keras Sequences directly;
    # `predict_generator` is deprecated and removed in newer Keras.
    predictions = net.predict(generator)

    return np.argmax(predictions, axis=1)

In [34]:
def plot_results(data_to_predict, predictions, params, figsize=(20, 20)):
    """Show each image next to its predicted label, one row per image.

    Args:
        data_to_predict: indexable collection of images, each accepted by
            ``plt.imshow``. NOTE(review): presumably an array of (H, W, 3)
            images rather than the batched generator itself — confirm.
        predictions: class ids aligned index-for-index with
            ``data_to_predict``; each is mapped to a label through
            ``utils.load_id_label_map(params)``.
        params: config dict forwarded to ``utils.load_id_label_map``.
        figsize: matplotlib figure size for the whole grid.

    Does nothing when ``data_to_predict`` is empty.
    """
    n_images = len(data_to_predict)
    if n_images == 0:
        # plt.subplots cannot create a grid with zero rows.
        return

    id_label_map = utils.load_id_label_map(params)

    # squeeze=False keeps `axes` 2-D even for a single row, so no padding row
    # is needed for `axes[i]` to yield an (image, label) pair of axes.
    fig, axes = plt.subplots(n_images, 2, figsize=figsize, squeeze=False)

    for i in range(n_images):
        image_ax, label_ax = axes[i]

        image_ax.imshow(data_to_predict[i])
        image_ax.axis('off')

        pred = id_label_map[predictions[i]]
        label_ax.text(0.5, 0.5, pred, size=25, ha="center", va="center",
                      bbox=dict(
                          boxstyle="round",
                          ec=(1., 0.5, 0.5),
                          fc=(1., 0.8, 0.8),
                      ))
        label_ax.set_title('Prediction')
        label_ax.axis('off')

    plt.show()

In [50]:
# Inference iterator over the images listed in <data_dir>/test.txt.
generator = make_datagenerator(mode='test', params=params)

Found 30 images.


In [51]:
# Predicted class ids, one per image yielded by `generator`.
preds = make_predictions(params=params, generator=generator)
preds

array([3, 3, 1, 9, 5, 7, 7, 9, 8, 6, 3, 5, 1, 1, 2, 0, 4, 9, 2, 8, 8, 0,
       4, 5, 6, 2, 0, 4, 6, 7], dtype=int64)