# Develop Model Driver

In this notebook, we register the model trained earlier and develop the model driver (scoring script) that the web service will use.

In [None]:
%matplotlib inline

import io
import os
import sys

import azureml.core
import tensorflow as tf
from azureml.core import Workspace
from azureml.core.model import Model
from matplotlib import pyplot as plt
from object_detection.utils import label_map_util
from object_detection.utils import visualization_utils as vis_util

## Register Model

In [None]:
# Record which Azure ML core SDK version this notebook ran against.
print("SDK version:", azureml.core.VERSION)

In [None]:
# Connect to the Azure ML workspace from its saved configuration.
# (ws.name, ws.resource_group, ws.location and ws.subscription_id can be
# printed here to confirm the connection.)
ws = Workspace.from_config()

In [None]:
# Show every model already registered in this workspace.
for registered in Model.list(workspace=ws):
    print("Name:", registered.name,
          "\tVersion:", registered.version,
          "\tDescription:", registered.description,
          registered.tags)

In [None]:
# Register the frozen inference graph exported after training. The driver
# below loads it by this name via Model.get_model_path(); without this step
# `model` is undefined and this cell fails on a fresh kernel.
model = Model.register(workspace=ws,
                       model_path="export_maskrcnn/frozen_inference_graph.pb",
                       model_name="maskrcnn_outofstock_model",
                       description="Mask R-CNN out-of-stock detection model")
print(model.name, model.description, model.version, sep='\t')

In [None]:
from dotenv import set_key, get_key, find_dotenv

In [None]:
# Locate the project's .env file; raise instead of silently using no file.
env_path = find_dotenv(raise_error_if_not_found=True)

In [None]:
# Persist the registered model's name to the .env file under `model_name`.
set_key(env_path, "model_name", model.name)

## Develop Model Driver

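Here we write the model driver to `detect.py`. It defines `init()`, which loads the registered model, and `run()`, which runs detection on the image posted in the request and returns the results as JSON.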

In [None]:
%%writefile detect.py

from azureml.core.model import Model
from azureml.contrib.services.aml_request import rawhttp
from PIL import Image
import numpy as np
import tensorflow as tf
import json

def load_image_into_numpy_array(image):
    """Convert a PIL image into a (height, width, 3) uint8 RGB array.

    Args:
        image: a PIL image in any mode (RGB, RGBA, grayscale, palette, ...).

    Returns:
        A writable numpy array of shape (height, width, 3) and dtype uint8.
    """
    # Normalise to 3-channel RGB first: grayscale, palette and RGBA inputs
    # cannot be reshaped to (height, width, 3) directly and used to crash here.
    # np.array (not np.asarray) returns a writable copy, so callers can
    # annotate the pixels in place.
    return np.array(image.convert('RGB'), dtype=np.uint8)

def run_inference_for_single_image(image, sess, tensor_dict, image_tensor):
    """Run the detection graph on one image and return JSON-serialisable results.

    Args:
        image: (height, width, 3) uint8 array; a batch axis is added here.
        sess: session that evaluates the tensors in ``tensor_dict``.
        tensor_dict: output name -> tensor to fetch.
        image_tensor: input tensor that receives the image batch.

    Returns:
        The fetched outputs with the batch axis stripped and converted to plain
        Python types: ``num_detections`` as int, ``detection_classes`` as a list
        of ints, ``detection_boxes``/``detection_scores`` (and
        ``detection_masks`` when the graph produces it) as nested lists.
    """
    output_dict = sess.run(tensor_dict,
                           feed_dict={image_tensor: np.expand_dims(image, 0)})

    # All outputs are float32 numpy arrays with a leading batch axis of 1;
    # take item 0 and convert to built-in types so json.dumps can encode them.
    output_dict['num_detections'] = int(output_dict['num_detections'][0])
    # int64 rather than uint8 so class ids above 255 do not wrap around.
    output_dict['detection_classes'] = (
        output_dict['detection_classes'][0].astype(np.int64).tolist())
    output_dict['detection_boxes'] = output_dict['detection_boxes'][0].tolist()
    output_dict['detection_scores'] = output_dict['detection_scores'][0].tolist()
    # Mask R-CNN graphs also emit masks; left as an ndarray they made
    # json.dumps raise TypeError in run().
    if 'detection_masks' in output_dict:
        output_dict['detection_masks'] = output_dict['detection_masks'][0].tolist()
    return output_dict

def init():
    """Load the registered frozen graph and set the globals used by ``run``.

    Sets three module-level globals:
      * ``sess``: a ``tf.Session`` on the imported graph.
      * ``tensor_dict``: output name -> tensor, for each of the standard
        detection outputs that the graph actually contains.
      * ``image_tensor``: the ``image_tensor:0`` input fed with the image batch.
    """
    global sess, tensor_dict, image_tensor

    graph_path = Model.get_model_path("maskrcnn_outofstock_model")
    detection_graph = tf.Graph()
    with detection_graph.as_default():
        graph_def = tf.GraphDef()
        with tf.gfile.GFile(graph_path, 'rb') as graph_file:
            graph_def.ParseFromString(graph_file.read())
        tf.import_graph_def(graph_def, name='')
        sess = tf.Session()

        # Fetch only those outputs the exported graph really defines
        # (e.g. 'detection_masks' may be absent).
        graph = tf.get_default_graph()
        available = {t.name for op in graph.get_operations() for t in op.outputs}
        wanted = ('num_detections', 'detection_boxes', 'detection_scores',
                  'detection_classes', 'detection_masks')
        tensor_dict = {name: graph.get_tensor_by_name(name + ':0')
                       for name in wanted if name + ':0' in available}
        image_tensor = graph.get_tensor_by_name('image_tensor:0')
                       
@rawhttp
def run(request):
    """Detect objects in the file uploaded as ``image`` in the HTTP request.

    Returns the detection results encoded as a JSON string.
    """
    uploaded = request.files["image"]
    pixels = load_image_into_numpy_array(Image.open(uploaded))
    detections = run_inference_for_single_image(pixels, sess, tensor_dict, image_tensor)
    return json.dumps(detections)

In [None]:
# Execute the driver script so its helpers (e.g. load_image_into_numpy_array,
# run_inference_for_single_image) become available in this kernel.
%run detect.py

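For local debugging, we redefine `init()` and `run()`. The local `init()` loads the exported graph from the local `export_maskrcnn` folder. The local `run()` has no `@rawhttp` decorator and accepts a plain dictionary of open image files instead of an HTTP request.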

In [None]:
from azureml.core.model import Model
# from azureml.contrib.services.aml_request import rawhttp
from PIL import Image
import numpy as np
import tensorflow as tf
import json

def init():
    """Local variant of the driver's ``init`` that loads the exported graph file.

    Identical to the deployed version except that the model path points at
    the local export (``export_maskrcnn/frozen_inference_graph.pb``) rather
    than the registered model name. Sets the globals ``sess``,
    ``tensor_dict`` and ``image_tensor`` used by ``run``.
    """
    global sess, tensor_dict, image_tensor

    graph_path = Model.get_model_path("export_maskrcnn/frozen_inference_graph.pb")
    detection_graph = tf.Graph()
    with detection_graph.as_default():
        graph_def = tf.GraphDef()
        with tf.gfile.GFile(graph_path, 'rb') as graph_file:
            graph_def.ParseFromString(graph_file.read())
        tf.import_graph_def(graph_def, name='')
        sess = tf.Session()

        # Fetch only those outputs the exported graph really defines.
        graph = tf.get_default_graph()
        available = {t.name for op in graph.get_operations() for t in op.outputs}
        wanted = ('num_detections', 'detection_boxes', 'detection_scores',
                  'detection_classes', 'detection_masks')
        tensor_dict = {name: graph.get_tensor_by_name(name + ':0')
                       for name in wanted if name + ':0' in available}
        image_tensor = graph.get_tensor_by_name('image_tensor:0')
            
# Local stand-in for the deployed run(): no @rawhttp decorator, and `request`
# is a plain dict mapping "image" to an open binary image stream.
def run(request):
    """Detect objects in ``request["image"]`` and return the results as JSON."""
    pixels = load_image_into_numpy_array(Image.open(request["image"]))
    detections = run_inference_for_single_image(pixels, sess, tensor_dict, image_tensor)
    return json.dumps(detections)

In [None]:
# Load the model and populate sess / tensor_dict / image_tensor.
init()

In [None]:
# Path to a local test image -- replace the placeholder before running.
IMAGE_PATH = 'IMAGE HERE'

# Read the file once, then give the preview and the driver their own
# in-memory streams: no file handle is left open in the kernel, and the two
# consumers never share a read position.
with open(IMAGE_PATH, "rb") as image_file:
    image_bytes = image_file.read()
request_files = {"image": io.BytesIO(image_bytes)}
image = Image.open(io.BytesIO(image_bytes))
image

In [None]:
# Score the test image through the local driver and decode its JSON reply.
output_dict = json.loads(run(request_files))
# Separate pixel array that the visualization step draws onto.
image_np = load_image_into_numpy_array(image)

In [None]:
# Label map that translates the model's numeric class ids into display names.
PATH_TO_LABELS = './outofstock_label_map.pbtxt'
category_index = label_map_util.create_category_index_from_labelmap(
    PATH_TO_LABELS,
    use_display_name=True,
)

# Figure size (width, height) in inches for the detection plot.
IMAGE_SIZE = (24, 18)

In [None]:
MIN_SCORE_THRESH = 0.5  # detections scoring below this are not drawn


def box_masks_to_image_masks(box_masks, boxes, image_height, image_width, threshold=0.5):
    """Paste per-box instance masks into full-image binary masks.

    Args:
        box_masks: one 2-D mask of scores per detection (nested list or array),
            presumably relative to its box as emitted by the frozen graph --
            TODO confirm against the exported model.
        boxes: normalized [ymin, xmin, ymax, xmax] boxes, one per mask.
        image_height, image_width: target image size in pixels.
        threshold: score above which a mask pixel counts as foreground.

    Returns:
        uint8 array of shape (num_masks, image_height, image_width) holding
        0/1 values, the format vis_util expects for ``instance_masks``.
    """
    box_masks = np.asarray(box_masks, dtype=np.float32)
    image_masks = np.zeros((len(box_masks), image_height, image_width), dtype=np.uint8)
    for mask, (ymin, xmin, ymax, xmax), out in zip(box_masks, boxes, image_masks):
        top = max(int(round(ymin * image_height)), 0)
        left = max(int(round(xmin * image_width)), 0)
        bottom = min(int(round(ymax * image_height)), image_height)
        right = min(int(round(xmax * image_width)), image_width)
        if bottom <= top or right <= left:
            continue  # padded or degenerate detection: leave its mask empty
        resized = Image.fromarray(mask).resize((right - left, bottom - top), Image.BILINEAR)
        out[top:bottom, left:right] = np.asarray(resized) > threshold
    return image_masks


# The driver returns masks as JSON lists exactly as the graph emits them;
# vis_util needs uint8 masks the size of the image, so convert them first.
raw_masks = output_dict.get('detection_masks')
instance_masks = None
if raw_masks is not None:
    image_height, image_width = image_np.shape[:2]
    instance_masks = box_masks_to_image_masks(
        raw_masks, output_dict['detection_boxes'], image_height, image_width)

# Draws boxes, labels and masks onto image_np in place.
vis_util.visualize_boxes_and_labels_on_image_array(
    image_np,
    np.array(output_dict['detection_boxes']),
    np.array(output_dict['detection_classes']),
    np.array(output_dict['detection_scores']),
    category_index,
    instance_masks=instance_masks,
    max_boxes_to_draw=100,
    min_score_thresh=MIN_SCORE_THRESH,
    use_normalized_coordinates=True,
    line_thickness=4)

fig, ax = plt.subplots(figsize=IMAGE_SIZE)
ax.imshow(image_np)
ax.set_axis_off()
ax.set_title(f"Detections with score >= {MIN_SCORE_THRESH}")
plt.show()