# Bolt classifier client


In [None]:
import operator
import os
import random
import subprocess

import numpy
from grpc.beta import implementations

import tensorflow as tf
from tensorflow_serving.apis import predict_pb2
from tensorflow_serving.apis import prediction_service_pb2_grpc

from PIL import Image
from matplotlib.pyplot import imshow

In [None]:
def load_labels(label_file):
    """Read a label file and return one label per line.

    The file is read through ``tf.gfile.GFile``. Each line has its trailing
    whitespace stripped. Blank lines are kept as empty strings, so list index
    ``i`` always corresponds to line ``i`` of the file.

    Args:
        label_file: Path to a text file containing one label per line.

    Returns:
        A list of label strings in file order.
    """
    # Context manager guarantees the handle is closed even if reading fails.
    with tf.gfile.GFile(label_file) as label_fh:
        return [line.rstrip() for line in label_fh.readlines()]

def classify_bolt(service_host, service_port, file_path,
                  labels_file="labels.txt", timeout=10.0):
    """Classify a bolt image with the model server and print the top class.

    The image is converted to RGB, resized to 224x224, scaled to [0, 1] and
    sent as the ``image`` input of the ``serving_default`` signature of the
    ``model-server-workflow-v1`` model. The entry with the largest value in
    the ``prediction`` output is printed together with its label.

    Args:
        service_host: Hostname of the model server.
        service_port: Port of the model server's gRPC endpoint.
        file_path: Path of the image file to classify.
        labels_file: Text file with one label per line. Line ``i`` names
            position ``i`` of the ``prediction`` output.
        timeout: RPC deadline in seconds.

    Raises:
        ValueError: If the server returns an empty ``prediction`` output.
    """
    labels = load_labels(labels_file)

    # Close the file promptly. Converting to RGB first makes grayscale, RGBA
    # and palette images match the [1, 224, 224, 3] tensor shape sent below.
    # LANCZOS is the filter that ANTIALIAS aliased; ANTIALIAS was removed in
    # Pillow 10.
    with Image.open(file_path) as raw_image:
        resized = raw_image.convert("RGB").resize((224, 224), Image.LANCZOS)
    image = numpy.array(resized).astype(numpy.float32) / 255

    channel = implementations.insecure_channel(service_host, service_port)
    # The generated (non-beta) stub needs the underlying grpc channel object.
    stub = prediction_service_pb2_grpc.PredictionServiceStub(channel._channel)

    request = predict_pb2.PredictRequest()
    request.model_spec.name = "model-server-workflow-v1"
    request.model_spec.signature_name = "serving_default"
    request.inputs['image'].CopyFrom(
        tf.make_tensor_proto(image, shape=[1, 224, 224, 3]))

    result = stub.Predict(request, timeout)

    plist = result.outputs["prediction"].float_val
    if not plist:
        raise ValueError("model server returned an empty 'prediction' output")
    index, value = max(enumerate(plist), key=operator.itemgetter(1))
    print("type: {}, score: {}".format(labels[index], value))


## Model service endpoint

In [None]:
# Pick the last service listed in the kubeflow namespace whose line contains
# "v1-bolts", and build its in-cluster DNS name.
# kubectl is run directly (no shell pipeline), so a kubectl failure raises
# CalledProcessError. A missing match also fails loudly instead of producing
# the host ".kubeflow.svc.cluster.local".
svc_lines = subprocess.check_output(
    ["kubectl", "get", "svc", "-n", "kubeflow"]).decode('UTF-8').splitlines()
svc_names = [line.split()[0] for line in svc_lines if "v1-bolts" in line]
if not svc_names:
    raise RuntimeError("no service matching 'v1-bolts' found in namespace kubeflow")
service_host = "{}.kubeflow.svc.cluster.local".format(svc_names[-1])
service_port = 9000
print("{}:{}".format(service_host, service_port))

## Test case 1

In [None]:
%matplotlib inline
test_image_1 = "bxm/201721882325_125_9.jpg"
imshow(numpy.asarray(Image.open(test_image_1, 'r')))
classify_bolt(service_host,service_port,test_image_1)

## Test case 2

In [None]:
%matplotlib inline
test_image_2 = "3pax/201721892859_115_9.jpg"
imshow(numpy.asarray(Image.open(test_image_2, 'r')))
classify_bolt(service_host,service_port,test_image_2)