# Develop Model Driver
In this notebook we develop the driver that will call our model. It needs to initialise the model and transform the input from the Flask app into the format the model expects. We expect the input to be JSON with the image encoded as a base64 string. The cell below uses the writefile magic to write its contents to the file driver.py.
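To make the expected input concrete, here is a minimal sketch of building and unpacking such a payload. The helper names are hypothetical, and the payload shape, an "input" object holding one base64 string per image key, is an assumption inferred from how the later cells index the parsed JSON. It is not the code in testing_utilities.

```python
import base64
import json


def image_bytes_to_json(image_bytes, key="image"):
    # Hypothetical helper: base64-encode raw bytes and wrap them in the
    # assumed payload shape {"input": {key: "<base64 string>"}}.
    encoded = base64.b64encode(image_bytes).decode("utf-8")
    return json.dumps({"input": {key: encoded}})


def json_to_image_bytes(payload, key="image"):
    b64_string = json.loads(payload)["input"][key]
    # Like the driver, tolerate a string that is the repr of a bytes
    # object, i.e. one that starts with b' and ends with '.
    if b64_string.startswith("b'"):
        b64_string = b64_string[2:-1]
    return base64.b64decode(b64_string)


raw = b"placeholder bytes standing in for an image file"
payload = image_bytes_to_json(raw)
restored = json_to_image_bytes(payload)
```

The round trip returns the original bytes, which the driver then hands to PIL for decoding and resizing.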

In [1]:
import logging
from testing_utilities import img_url_to_json

In [2]:
%%writefile driver.py
import base64
import json
import logging
import os
import timeit as t
from io import BytesIO

import numpy as np
import tensorflow as tf
from PIL import Image, ImageOps
from tensorflow.contrib.slim.nets import resnet_v1

_MODEL_FILE = os.getenv('MODEL_FILE', "resnet_v1_152.ckpt")
_LABEL_FILE = os.getenv('LABEL_FILE', "synset.txt")
_NUMBER_RESULTS = 3


def _create_label_lookup(label_path):
    with open(label_path, 'r') as f:
        label_list = [l.rstrip() for l in f]
        
    def _label_lookup(*label_locks):
        return [label_list[l] for l in label_locks]
    
    return _label_lookup


def _load_tf_model(checkpoint_file):
    # Placeholder
    input_tensor = tf.placeholder(tf.float32, shape=(None,224,224,3), name='input_image')
    
    # Load the model
    sess = tf.Session()
    arg_scope = resnet_v1.resnet_arg_scope()
    with tf.contrib.slim.arg_scope(arg_scope):
        logits, _ = resnet_v1.resnet_v1_152(input_tensor, num_classes=1000, is_training=False, reuse=tf.AUTO_REUSE)
    probabilities = tf.nn.softmax(logits)
    
    saver = tf.train.Saver()
    saver.restore(sess, checkpoint_file)
    
    def predict_for(image):
        pred, pred_proba = sess.run([logits,probabilities], feed_dict={input_tensor: image})
        return pred_proba
    
    return predict_for


def _base64img_to_numpy(base64_img_string):
    if base64_img_string.startswith('b\''):
        base64_img_string = base64_img_string[2:-1]
    base64Img = base64_img_string.encode('utf-8')

    # Decode the base64 payload into raw image bytes
    decoded_img = base64.b64decode(base64Img)
    img_buffer = BytesIO(decoded_img)

    # Load image with PIL (RGB)
    pil_img = Image.open(img_buffer).convert('RGB')
    pil_img = ImageOps.fit(pil_img, (224, 224), Image.ANTIALIAS)
    return np.array(pil_img, dtype=np.float32)


def create_scoring_func(model_path=_MODEL_FILE, label_path=_LABEL_FILE):
    logger = logging.getLogger("model_driver")
    
    start = t.default_timer()
    labels_for = _create_label_lookup(label_path)
    predict_for = _load_tf_model(model_path)
    end = t.default_timer()

    loadTimeMsg = "Model loading time: {0} ms".format(round((end-start)*1000, 2))
    logger.info(loadTimeMsg)
    
    def call_model(image_array, number_results=_NUMBER_RESULTS):
        pred_proba = predict_for(image_array).squeeze()
        selected_results = np.flip(np.argsort(pred_proba), 0)[:number_results]
        labels = labels_for(*selected_results)
        return list(zip(labels, pred_proba[selected_results].astype(np.float64)))
    return call_model


def get_model_api():
    logger = logging.getLogger("model_driver")
    scoring_func = create_scoring_func()
    
    def process_and_score(images_dict, number_results=_NUMBER_RESULTS):
        start = t.default_timer()

        results = {}
        for key, base64_img_string in images_dict.items():
            rgb_image = _base64img_to_numpy(base64_img_string)
            batch_image = np.expand_dims(rgb_image, 0)
            results[key] = scoring_func(batch_image, number_results=number_results)
        
        end = t.default_timer()

        logger.info("Predictions: {0}".format(results))
        logger.info("Predictions took {0} ms".format(round((end-start)*1000, 2)))
        return (results, 'Computed in {0} ms'.format(round((end-start)*1000, 2)))
    return process_and_score

def version():
    return tf.__version__
    

Overwriting driver.py
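The ranking step inside call_model can be illustrated without TensorFlow or NumPy. The sketch below is a pure-Python stand-in, not the driver's code: top_k is a hypothetical name, and sorted is swapped in for the np.argsort and np.flip combination, so the ordering of exactly tied probabilities may differ.

```python
def top_k(probabilities, labels, k=3):
    # Indices ordered from highest to lowest probability.
    order = sorted(range(len(probabilities)),
                   key=lambda i: probabilities[i],
                   reverse=True)
    # Pair each selected index's label with its probability.
    return [(labels[i], probabilities[i]) for i in order[:k]]


# Made-up probabilities and labels, for illustration only.
example = top_k([0.1, 0.7, 0.2], ["cat", "lynx", "leopard"], k=2)
```

Here example holds the two most probable labels with their scores, highest first, which is the same shape the driver stores for each image key.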


In [3]:
logging.basicConfig(level=logging.DEBUG)

We run the file driver.py, which brings everything it defines into the context of the notebook.

In [4]:
%run driver.py

We will use the same Lynx image we used earlier to check that our driver works as expected.

In [5]:
IMAGEURL = "https://upload.wikimedia.org/wikipedia/commons/thumb/6/68/Lynx_lynx_poing.jpg/220px-Lynx_lynx_poing.jpg"

In [6]:
jsonimg = img_url_to_json(IMAGEURL)

In [7]:
json_lod = json.loads(jsonimg)

In [8]:
predict_for = get_model_api()

INFO:tensorflow:Restoring parameters from resnet_v1_152.ckpt
INFO:model_driver:Model loading time: 17208.69 ms
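Loading takes roughly 17 seconds in the log above, so the driver pays that cost once and returns a closure that reuses the loaded model on every call. A minimal sketch of that pattern follows, with a hypothetical make_scorer factory and a fake loader standing in for the checkpoint restore.

```python
import time


def make_scorer(load_model):
    # Expensive setup runs once, when the factory is called.
    start = time.perf_counter()
    model = load_model()
    load_ms = round((time.perf_counter() - start) * 1000, 2)

    def call_model(x):
        # Every call reuses the model captured by the closure.
        return model(x)

    call_model.load_ms = load_ms  # keep the timing for logging
    return call_model


load_calls = []


def fake_loader():
    # Hypothetical stand-in for restoring the ResNet checkpoint.
    load_calls.append(1)
    return lambda x: x * 2


score = make_scorer(fake_loader)
results = [score(1), score(2), score(3)]
```

After three scoring calls the loader has still run only once, which is the property get_model_api relies on when the Flask app serves many requests.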


In [9]:
output = predict_for(json_lod['input'])

DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
DEBUG:PIL.PngImagePlugin:STREAM b'iCCP' 41 292
DEBUG:PIL.PngImagePlugin:iCCP profile name b'ICC Profile'
DEBUG:PIL.PngImagePlugin:Compression method 0
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 345 65536
INFO:model_driver:Predictions: {'image': [('n02127052 lynx, catamount', 0.9974517226219177), ('n02128385 leopard, Panthera pardus', 0.0015077503630891442), ('n02128757 snow leopard, ounce, Panthera uncia', 0.0005164773901924491)]}
INFO:model_driver:Predictions took 1916.85 ms


The output of our prediction function is a tuple holding the results and a timing message. Serialised to JSON, it is what will be returned to our Flask app. It looks like our model predicted lynx with over 99% probability.

In [10]:
json.dumps(output)

'[{"image": [["n02127052 lynx, catamount", 0.9974517226219177], ["n02128385 leopard, Panthera pardus", 0.0015077503630891442], ["n02128757 snow leopard, ounce, Panthera uncia", 0.0005164773901924491]]}, "Computed in 1916.85 ms"]'
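The serialised string above shows every tuple turned into a JSON array. A small sketch with made-up values (not real model output) shows the same conversion and what a client gets back after parsing.

```python
import json

# Made-up values in the same shape as the driver's return value:
# (results dict, timing message).
output = ({"image": [("lynx", 0.99), ("leopard", 0.01)]},
          "Computed in 1.0 ms")

serialised = json.dumps(output)      # tuples become JSON arrays
results, timing = json.loads(serialised)
top_label, top_score = results["image"][0]
```

After the round trip the pairs come back as lists rather than tuples, so client code should index or unpack them instead of relying on tuple types. The driver's cast to np.float64 matters for the same step: np.float64 subclasses Python's float and serialises cleanly, whereas json.dumps rejects np.float32 values.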

We can move on to [building our Docker image](02_BuildImage.ipynb).