# MNIST-Azure
## Test Service
### By: Sebastian Goodfellow

In [1]:
# Configure Notebook
import warnings
warnings.filterwarnings('ignore')
%matplotlib inline
%load_ext autoreload
%autoreload 2

# Import 3rd party libraries
import os
import cv2
import sys
import json
import base64
import numpy as np
import tensorflow as tf
import matplotlib.pylab as plt
from matplotlib.image import imread
from azureml.core import Workspace
from azureml.core.model import Model
from ipywidgets import interact, fixed
from ipywidgets.widgets import IntSlider

# Import local Libraries
sys.path.insert(0, './../')
from mnistazure.config import WORKING_PATH, DATA_PATH

## Get Data

In [2]:
# Image shape each file is reshaped to before being sent to the service:
# 28 x 28 pixels with an explicit single channel axis
image_shape = (28, 28, 1)

# Get list of validation images.
# os.listdir returns entries in arbitrary, filesystem-dependent order; sort the names
# so the index -> file mapping used later in the notebook is reproducible.
file_names = sorted(
    file_name
    for file_name in os.listdir(os.path.join(DATA_PATH, 'images'))
    if 'val' in file_name
)
print(file_names[0:5])

['val_0.jpg', 'val_1.jpg', 'val_10.jpg', 'val_100.jpg', 'val_1000.jpg']


In [4]:
# Get workspace.
# Subscription ID and resource group are read from the environment instead of being
# hard-coded in the notebook (set AZURE_SUBSCRIPTION_ID / AZURE_RESOURCE_GROUP before
# starting the kernel). When unset they fall back to '', the previous placeholder value.
workspace = Workspace.get(
    name='mnist-azure',
    subscription_id=os.environ.get('AZURE_SUBSCRIPTION_ID', ''),
    resource_group=os.environ.get('AZURE_RESOURCE_GROUP', ''),
)

# Get the deployed web service by name, and its scoring endpoint
service = workspace.webservices['mnist-tf']
scoring_uri = service.scoring_uri

In [None]:
def plot_prediction(file_name_id, file_names):
    """Score one validation image with the deployed web service and plot it.

    Reads ``file_names[file_name_id]`` from ``DATA_PATH/images``, reshapes it to
    ``image_shape``, sends it to ``service`` as a UTF-8 encoded JSON body of the
    form ``{"data": <nested list of pixels>}``, prints the decoded response and
    displays the image in grayscale.

    Args:
        file_name_id: Index into ``file_names`` (driven by the Image ID slider).
        file_names: Image file names relative to ``DATA_PATH/images``.

    Relies on the notebook globals ``service``, ``image_shape`` and ``DATA_PATH``.
    """
    # Get file name
    file_name = file_names[file_name_id]

    # Load the image and give it an explicit channel axis
    image_array = imread(os.path.join(DATA_PATH, 'images', file_name)).reshape(image_shape)

    # Serialize the request payload. It is deliberately not printed: it contains every
    # pixel value and would flood the output on each slider move.
    test_samples = bytes(json.dumps({'data': image_array.tolist()}), encoding='utf8')

    # Get prediction.
    # NOTE(review): json.loads requires service.run to return a JSON-encoded str/bytes;
    # confirm this matches what the scoring script returns.
    result = json.loads(service.run(input_data=test_samples))
    print(result)

    # Plot image (single axes, so no subplot spacing adjustment is needed)
    _, ax = plt.subplots(figsize=(5, 5), facecolor='w')
    ax.imshow(image_array[:, :, 0], cmap='gray', vmin=0, vmax=255)
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)

    plt.show()

    
def get_prediction(file_name):
    """Run a local TensorFlow session prediction on one image from ``DATA_PATH/images``.

    NOTE(review): this helper is not called in this notebook. It reads the globals
    ``graph_type``, ``sess``, ``prediction`` and ``images``, none of which are defined
    in the notebook as shown. It is presumably left over from a local-session workflow,
    and will raise NameError unless those names are defined first.

    Args:
        file_name: Image file name relative to ``DATA_PATH/images``.

    Returns:
        ``(image_array, output)`` where ``image_array`` is the image reshaped to
        ``image_shape`` and ``output`` is the first element returned by ``sess.run``.

    Raises:
        ValueError: If ``graph_type`` is neither ``'array'`` nor ``'string'``.
    """
    image_array = imread(os.path.join(DATA_PATH, 'images', file_name)).reshape(image_shape)

    # Compare with ==, not `is`: identity of equal strings depends on interning and is
    # not guaranteed.
    if graph_type == 'array':
        # Feed the raw pixel array
        feed = image_array
    elif graph_type == 'string':
        # Feed the image re-encoded as JPEG bytes; tobytes() replaces the deprecated
        # ndarray.tostring() and returns the same bytes.
        feed = cv2.imencode('.jpg', image_array)[1].tobytes()
    else:
        # Previously this fell through and silently returned None
        raise ValueError("Unsupported graph_type: {!r}".format(graph_type))

    return image_array, sess.run(fetches=[prediction], feed_dict={images: [feed]})[0]
    

# Launch interactive plotting widget.
# With an empty file_names list the slider would get max=-1 < min=0, which fails
# with an unhelpful widget error; fail early with a clear message instead.
if not file_names:
    raise ValueError("No validation images found; check DATA_PATH/images.")

image_id_slider = IntSlider(
    value=0,
    min=0,
    max=len(file_names) - 1,
    description='Image ID',
    disabled=False,
)
_ = interact(
    plot_prediction,
    file_name_id=image_id_slider,
    file_names=fixed(file_names),  # held constant rather than exposed as a widget
)
# NOTE(review): the original cell ended with a bare "# 2634" annotation, presumably an
# image ID worth inspecting; confirm and document it, or remove it.