# MNIST-Azure
## Test Inference Graph
### By: Sebastian Goodfellow

In [85]:
# Configure Notebook
import warnings
warnings.filterwarnings('ignore')
%matplotlib inline
%load_ext autoreload
%autoreload 2

# Import standard library and 3rd party libraries
import os
import cv2
import sys
import json
import base64
import numpy as np
import tensorflow as tf
import matplotlib.pylab as plt
from matplotlib.image import imread
from azureml.core import Workspace
from azureml.core.model import Model
from ipywidgets import interact, fixed
from ipywidgets.widgets import IntSlider

# Import local libraries
sys.path.insert(0, '/home/sebastiangoodfellow/Documents/Code/mnist-azure')
from mnistazure.config import WORKING_PATH, DATA_PATH

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Get Data

In [19]:
# Shape every image array is reshaped to: 28 x 28 with a trailing axis of size 1
image_shape = (28, 28, 1)

# Keep only the files in the image directory whose name contains 'val'
images_dir = os.path.join(DATA_PATH, 'images')
file_names = list(filter(lambda name: 'val' in name, os.listdir(images_dir)))

# Preview the first five entries
print(file_names[:5])

['val_2670.jpg', 'val_4752.jpg', 'val_7445.jpg', 'val_7043.jpg', 'val_3452.jpg']


# Build Inference Graph

In [112]:
# Connect to the Azure ML workspace (subscription id and resource group are blank here)
ws = Workspace.get(
    name='mnist-azure',
    subscription_id='',
    resource_group='',
)

# Look up version 4 of the registered model
model = Model(workspace=ws, name='mnist_tf_model', version=4)

# Local directory the model files are downloaded into
model_path = os.path.join(WORKING_PATH, 'assets')

# Selects which 'inference_graph_<graph_type>.meta' file is imported below
graph_type = 'string'

# Download the model files into model_path
model.download(target_dir=model_path, exist_ok=True)

# Locations of the meta graph and checkpoint inside the downloaded files
outputs_path = os.path.join(model_path, 'outputs')
meta_graph_path = os.path.join(outputs_path, 'graphs', 'inference_graph_{}.meta'.format(graph_type))
checkpoint_path = os.path.join(outputs_path, 'checkpoints', 'model')

# Clear the default graph, then open a session with soft device placement enabled
tf.reset_default_graph()
session_config = tf.ConfigProto(allow_soft_placement=True)
sess = tf.Session(config=session_config)

# Import the meta graph; the returned saver is used to restore the checkpoint below
saver = tf.train.import_meta_graph(meta_graph_path)

# Handle on the default graph, which now holds the imported ops
graph = tf.get_default_graph()

# Input and output tensors, looked up by name
images = graph.get_tensor_by_name(name='images:0')
prediction = graph.get_tensor_by_name(name='prediction:0')

# Initialize global variables, then overwrite them with the checkpoint values
sess.run(tf.global_variables_initializer())
saver.restore(sess=sess, save_path=checkpoint_path)

INFO:tensorflow:Restoring parameters from /home/sebastiangoodfellow/Documents/Code/mnist-azure/assets/outputs/checkpoints/model


In [119]:
def plot_prediction(file_name_id, file_names):
    """Display one image titled with its predicted class and score.

    :param file_name_id: index into file_names of the image to show.
    :param file_names: sequence of image file names.
    """
    # Run inference on the selected file
    image, scores = get_prediction(file_name=file_names[file_name_id])

    # Index of the highest score, and that score as a truncated integer percentage
    top_class = np.argmax(scores)
    top_score_pct = int(scores[0][top_class] * 100)

    # Single-axes figure on a white background
    fig = plt.figure(figsize=(5, 5), facecolor='w')
    fig.subplots_adjust(wspace=0, hspace=1.2)
    axis = plt.subplot2grid((1, 1), (0, 0))
    axis.set_title('Prediction: {}\nScore: {} %'.format(top_class, top_score_pct), fontsize=16)

    # Draw the first slice of the last axis in grayscale on a fixed 0-255 scale
    axis.imshow(image[:, :, 0], cmap='gray', vmin=0, vmax=255)

    # Hide both axes
    for hidden_axis in (axis.axes.get_xaxis(), axis.axes.get_yaxis()):
        hidden_axis.set_visible(False)

    plt.show()

    
def get_prediction(file_name):
    """Run the restored inference graph on a single image file.

    The image is read from DATA_PATH/images and reshaped to image_shape. What is fed
    to the graph depends on the module-level graph_type: 'array' feeds the pixel array
    itself, 'string' feeds the image re-encoded as JPEG bytes.

    :param file_name: name of the image file inside DATA_PATH/images.
    :return: tuple (image_array, prediction), where prediction is the single value
        fetched from the session run.
    :raises ValueError: if graph_type is neither 'array' nor 'string'.
    """
    # Compare by value: `is` tests object identity, which only matched the literals
    # through string interning and raises a SyntaxWarning on Python 3.8+.
    if graph_type not in ('array', 'string'):
        # Fail here rather than returning None, which the caller cannot unpack
        raise ValueError("Unsupported graph_type: {!r} (expected 'array' or 'string')".format(graph_type))

    image_array = imread(os.path.join(DATA_PATH, 'images', file_name)).reshape(image_shape)

    if graph_type == 'array':
        model_input = image_array
    else:
        # Re-encode as JPEG bytes; tobytes() replaces the deprecated tostring() alias.
        # Alternative (not used): feed the raw bytes read directly from the file.
        model_input = cv2.imencode('.jpg', image_array)[1].tobytes()

    return image_array, sess.run(fetches=[prediction], feed_dict={images: [model_input]})[0]
    

# Slider covering every valid index into file_names
image_slider = IntSlider(
    value=0,
    min=0,
    max=len(file_names) - 1,
    description='Image ID',
    disabled=False,
)

# Launch the interactive plotting widget; file_names is passed through via fixed()
# NOTE(review): the original carried a bare "2634" note here, possibly an image ID of interest; confirm
_ = interact(plot_prediction, file_name_id=image_slider, file_names=fixed(file_names))

interactive(children=(IntSlider(value=0, description='Image ID', max=9999), Output()), _dom_classes=('widget-i…