In [None]:
from pai_tf_predict_proto import tf_predict_pb2
import numpy as np
import pickle
import re
import requests
from matplotlib import pyplot as plt

# The ten CIFAR-10 class names, in label order: a test-set label value is
# used directly as an index into this list.
classes = [
    'airplane', 'automobile', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
]

# Helper function to find the index of the maximum value 
# in a Python list
def argmax(mylist):
    maxval = max(mylist)
    for i in range(len(mylist)):
        if mylist[i] == maxval:
            return i

# Helper function to unpickle CIFAR-10 data
def unpickle(file):
    """Load and return the object pickled in the file at path *file*.

    ``encoding='bytes'`` makes any str objects pickled by Python 2 come
    back as ``bytes``. That is why callers index the result with bytes
    keys such as ``b'data'``.

    SECURITY NOTE: ``pickle.load`` can execute arbitrary code. Only call
    this on files from a trusted source (e.g. the official CIFAR-10
    download), never on user-supplied data.
    """
    with open(file, 'rb') as handle:
        return pickle.load(handle, encoding='bytes')

# Helper function to convert image data into a usable format
def reformatImage(images, index):
    """Return image number *index* of *images* as a (32, 32, 3) array.

    Each entry of *images* is reshaped to (3, 32, 32), i.e. treated as
    channel-first. The result moves the RGB channel axis to the end,
    giving the height x width x channel layout that ``plt.imshow``
    displays.
    """
    channels_first = images[index].reshape(3, 32, 32)
    # (C, H, W) -> (H, W, C): put RGB channel data "at the back".
    return channels_first.transpose(1, 2, 0)

# Load the CIFAR-10 test batch (a pickled dict with bytes keys) and keep its
# image rows ('data') and its labels ('labels'), which index into `classes`.
test_batch = unpickle('test_batch')
data, labels = test_batch[b'data'], test_batch[b'labels']

# Call our PAI-EAS API on the first NUM_TEST_IMAGES test images
NUM_TEST_IMAGES = 25

#
# NOTE: you need to replace 'serving_default' and 'conv2d_6_input' below with
# the appropriate field names returned by PAI EAS when you run:
#
#   curl <path_to_api_endpoint> -H 'Authorization: <your_secret_token>' | jq
#
# If you don't have 'jq' installed, you can leave off the '| jq', but the
# results will be harder to read since the JSON output won't be nicely
# formatted.
#
SIGNATURE_NAME = 'serving_default'
INPUT_NAME = 'conv2d_6_input'

#
# Don't forget to replace 'your_eas_url_here' and 'your_eas_secret_token_here'
# with the values returned after your PAI-EAS instance is created.
#
url = 'your_eas_url_here'
headers = {"Authorization": 'your_eas_secret_token_here'}

# Seconds to wait for the endpoint before giving up, so a stalled connection
# cannot hang the notebook forever. The value is arbitrary; tune as needed.
REQUEST_TIMEOUT_SECONDS = 30

for idx in range(NUM_TEST_IMAGES):

    # Pull out a single image and normalize it into range [0, 1]
    img = reformatImage(data, idx)
    img = img / 255.0

    # Serialize the image into the Protocol Buffers request expected by the
    # EAS endpoint: one float image of shape [1, 32, 32, 3].
    request = tf_predict_pb2.PredictRequest()
    request.signature_name = SIGNATURE_NAME
    model_input = request.inputs[INPUT_NAME]
    model_input.dtype = tf_predict_pb2.DT_FLOAT
    model_input.array_shape.dim.extend([1, 32, 32, 3])
    model_input.float_val.extend(img.reshape(3072))
    request_data = request.SerializeToString()

    # Make the API call. Raise on an HTTP error status (wrong URL/token,
    # server error) instead of trying to parse an error body as a protobuf.
    resp = requests.post(url, data=request_data, headers=headers,
                         timeout=REQUEST_TIMEOUT_SECONDS)
    resp.raise_for_status()

    # The response carries the model's 1x10 array of class scores, and the
    # predicted class is the index of the largest score. Read the scores
    # directly from each output tensor's float_val field, the same field the
    # request uses for its input. Output names are visited in sorted order,
    # so the concatenation order is deterministic.
    response = tf_predict_pb2.PredictResponse()
    response.ParseFromString(resp.content)
    vals = [
        score
        for name in sorted(response.outputs)
        for score in response.outputs[name].float_val
    ]
    predicted_class = argmax(vals)

    # Print out actual and predicted classes
    print("Predicted class: {}".format(classes[predicted_class]))
    print("Actual class: {}".format(classes[labels[idx]]))

    # Display image
    plt.figure(figsize=(2,2))
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.imshow(img)
    plt.show()
    