# Using TensorFlow Serving (part of TensorFlow Extended) to serve the model

We start by loading the necessary libraries.

In [None]:
import tensorflow as tf
import numpy as np
import requests
import matplotlib.pyplot as plt
import json

## Build a model
We'll build an MNIST classifier using the Keras Sequential API.

In [None]:
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

# Normalize pixel values from [0, 255] to [0, 1]
x_train = x_train / 255
x_test = x_test / 255

model = tf.keras.Sequential()
model.add(tf.keras.layers.Flatten(name="FLATTEN"))
model.add(tf.keras.layers.Dense(units=128, activation="relu", name="D1"))
model.add(tf.keras.layers.Dense(units=64, activation="relu", name="D2"))
model.add(tf.keras.layers.Dense(units=10, activation="softmax", name="OUTPUT"))
    
model.compile(optimizer="sgd", 
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"]
             )

model.fit(x=x_train, 
          y=y_train, 
          epochs=5,
          validation_data=(x_test, y_test)
         ) 

## Save the entire model in SavedModel format

Next, we save the model in SavedModel format. TensorFlow Serving expects one numbered subdirectory per model version, so we create a directory for version 1.

In [None]:
# Create the model directory
!mkdir "my_mnist_model"

# Create the subdirectory for version 1
!mkdir "my_mnist_model/1"

In [None]:
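# Save the model in SavedModel format.
# Note: with Keras 3 (the default from TensorFlow 2.16), model.save() requires a
# .keras or .h5 extension; model.export("my_mnist_model/1") writes a SavedModel there.
model.save("my_mnist_model/1")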
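
The version directory above is created by hand. As a minimal sketch, the next version number could instead be derived from the directories that already exist. `next_version_dir` is a hypothetical helper, not part of TensorFlow, and it assumes that version directories are named with plain integers.

```python
from pathlib import Path


def next_version_dir(base_dir):
    """Return the path of the next numeric version directory under base_dir.

    Hypothetical helper: it only computes the path and does not create it.
    """
    base = Path(base_dir)
    versions = []
    if base.is_dir():
        for child in base.iterdir():
            # Only sub-directories with purely numeric names count as versions.
            if child.is_dir() and child.name.isdecimal():
                versions.append(int(child.name))
    # Start at 1 when no version exists yet.
    return base / str(max(versions, default=0) + 1)
```

Once `my_mnist_model/1` exists, `next_version_dir("my_mnist_model")` points at `my_mnist_model/2`. With its default settings, TensorFlow Serving serves the highest-numbered version it finds under the model base path.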

## Download the TensorFlow Serving Docker image

We'll run TensorFlow Serving with Docker.

If Docker is not installed yet, we encourage readers to follow the installation instructions in the official Docker documentation.

The first step is to pull the latest TensorFlow Serving Docker image.

In [None]:
!docker pull tensorflow/serving

Now we'll start a Docker container that will:
- publish the container's REST API port 8501 on our host's port 8501
- mount the model directory `my_mnist_model` created earlier
- bind it to the model base path `/models/my_mnist_model` inside the container
- set the environment variable `MODEL_NAME` to `my_mnist_model`.


_Note that this command should be executed from a bash command-line prompt!_

`docker run -p 8501:8501 \
  --mount type=bind,source="$(pwd)/my_mnist_model/",target=/models/my_mnist_model \
  -e MODEL_NAME=my_mnist_model -t tensorflow/serving`
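
Before sending predictions, it can help to confirm that the container has loaded the model. The sketch below queries the model status endpoint of the TensorFlow Serving REST API (`GET /v1/models/my_mnist_model`) and lists the versions reported as `AVAILABLE`. The helper names `fetch_status` and `available_versions` are illustrative. The sketch uses `urllib` from the standard library so that it has no extra dependencies; `requests.get` would work equally well.

```python
import json
from urllib.request import urlopen

# Assumes the container started above is listening on localhost:8501.
STATUS_URL = "http://localhost:8501/v1/models/my_mnist_model"


def available_versions(status_payload):
    """Return the versions marked AVAILABLE in a model-status payload."""
    return [
        str(entry["version"])
        for entry in status_payload.get("model_version_status", [])
        if entry.get("state") == "AVAILABLE"
    ]


def fetch_status(url=STATUS_URL, timeout=5.0):
    """GET the model-status endpoint and decode its JSON body."""
    with urlopen(url, timeout=timeout) as resp:
        return json.load(resp)
```

With the server running, `available_versions(fetch_status())` should include `"1"` once version 1 has finished loading.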


## Display the images to predict

In [None]:
num_rows = 4
num_cols = 3
plt.figure(figsize=(2*2*num_cols, 2*num_rows))
for row in range(num_rows):
    for col in range(num_cols):
        index = num_cols * row + col
        image = x_test[index]
        true_label = y_test[index]
        plt.subplot(num_rows, 2*num_cols, 2*index+1)
        plt.imshow(image.reshape(28,28), cmap="binary")
        plt.axis('off')
        plt.title('\n\n It is a {}'.format(true_label), fontdict={'size': 16})
plt.tight_layout()
plt.show()

## Send a POST predict request to TensorFlow Serving

We'll send a POST predict request to our server, passing the first twelve test images.

For each image, the server returns ten probabilities, one for each digit from 0 to 9.

In [None]:
# Build the JSON body; the REST API expects {"instances": [...]}
json_request = json.dumps({"instances": x_test[0:12].tolist()})
resp = requests.post(
    'http://localhost:8501/v1/models/my_mnist_model:predict',
    data=json_request,
    headers={"content-type": "application/json"},
)
print('response.status_code: {}'.format(resp.status_code))
print('response.content: {}'.format(resp.content))
predictions = json.loads(resp.text)['predictions']
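
The response body is a JSON object whose `predictions` key holds one list of ten probabilities per image. As a rough sketch of how such a body can be decoded without NumPy, the hypothetical helper `decode_predictions` below picks the most probable digit in each row. The sample payload is made up for illustration and is not real server output.

```python
import json


def decode_predictions(response_text):
    """Return a (digit, probability) pair for each row of a predict response."""
    rows = json.loads(response_text)["predictions"]
    decoded = []
    for probs in rows:
        # Index of the largest probability, i.e. the predicted digit.
        digit = max(range(len(probs)), key=probs.__getitem__)
        decoded.append((digit, probs[digit]))
    return decoded


# Made-up response with two rows, for illustration only.
sample = json.dumps({"predictions": [
    [0.01, 0.02, 0.03, 0.04, 0.05, 0.05, 0.05, 0.70, 0.03, 0.02],
    [0.05, 0.05, 0.80, 0.02, 0.02, 0.02, 0.01, 0.01, 0.01, 0.01],
]})
print(decode_predictions(sample))
```

Against the live server, `decode_predictions(resp.text)` would give the same digits as applying `np.argmax` to each row of `predictions`.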

Then, we will display the results.

In [None]:
num_rows = 4
num_cols = 3
plt.figure(figsize=(2*2*num_cols, 2*num_rows))
for row in range(num_rows):
    for col in range(num_cols):
        index = num_cols * row + col
        image = x_test[index]
        predicted_label = np.argmax(predictions[index])
        true_label = y_test[index]
        plt.subplot(num_rows, 2*num_cols, 2*index+1)
        plt.imshow(image.reshape(28,28), cmap="binary")
        plt.axis('off')
        if predicted_label == true_label:
            color = 'blue'
        else:
            color = 'red'
        plt.title('\n\n The model predicts a {} \n and it is a {}'.format(predicted_label, true_label), fontdict={'size': 16}, color=color)
plt.tight_layout()
plt.show()