# Chapter-4: Deploy model and infer

In the previous chapter, we trained a model using SageMaker's TensorFlow module. In this chapter, we explore two simple ways to run inference with the trained model.

In [None]:
import random
from glob import glob

import cv2
import matplotlib.pyplot as plt
import numpy as np
import numpy.ma as ma
from sagemaker.tensorflow import TensorFlowModel
from sagemaker.tensorflow import TensorFlowPredictor

## Set up account details

In [None]:
# AWS account settings used throughout this notebook.
# Replace the placeholder with your own AWS account ID before running.
ACCOUNT = "<account-number>"
ROLE_NAME = "notebookAccessRole"

# Derived values: the model-artifact bucket URI and the full IAM role ARN.
BUCKET_NAME = "s3://{}-model-bucket".format(ACCOUNT)
ROLE_ARN = "arn:aws:iam::{}:role/{}".format(ACCOUNT, ROLE_NAME)


## Method #1: Deploy the trained model from within a SageMaker instance

In [None]:
# Point at a training job's output artifact: use the chapter-3 checkpoint below,
# or pick another `model.tar.gz` from your S3 bucket.
model_location = f"{BUCKET_NAME}/tensorflow-training-2021-05-28-06-32-03-399/output/model.tar.gz"
# NOTE(review): presumably the TensorFlow version used for training in chapter 3 — keep them in sync.
framework_version = '2.4.1'

# Wrap the artifact in a deployable TensorFlow Serving model. The version and
# role come from the variables/config above instead of duplicated literals.
model = TensorFlowModel(
    framework_version=framework_version,
    role=ROLE_NAME,
    model_data=model_location,
)

In [None]:
# Stand up a single-instance real-time endpoint serving `model`.
# The returned object is a predictor bound to that endpoint (kept under the name
# `estimator` because later cells use it). Endpoints incur charges while running;
# see the Cleanup section at the end.
estimator = model.deploy(
    instance_type='ml.t2.large',
    initial_instance_count=1,
)

## Predict on test images

In [None]:
ALL_IMAGES = glob('../chapter-3/data/test/*.tiff')


def _read_resized(path, interpolation, size=(256, 256)):
    """Read ``path`` with ``cv2.imread`` and resize it to ``size`` using ``interpolation``."""
    img = cv2.imread(path)
    if img is None:
        # cv2.imread signals a missing/unreadable file by returning None rather than raising.
        raise FileNotFoundError(f"could not read image: {path}")
    return cv2.resize(img, size, interpolation=interpolation)


def get_test_data(num_samples=5, seed=None):
    """Sample ``num_samples`` test images and their label bitmaps.

    Each ``<name>.tiff`` image in ``ALL_IMAGES`` is paired with
    ``<name>_bitmap.png``. Both are read with ``cv2.imread`` (default flags) and
    resized to 256x256.

    Args:
        num_samples: number of image/label pairs to draw without replacement;
            capped at the number of available test images.
        seed: optional seed for reproducible sampling. ``None`` uses the global
            ``random`` state, as before.

    Returns:
        ``(images, labels)``: two numpy arrays stacked along a new first axis.

    Raises:
        FileNotFoundError: if no test images were found, or an image/bitmap
            cannot be read.
    """
    if not ALL_IMAGES:
        raise FileNotFoundError("no test images found for '../chapter-3/data/test/*.tiff'")

    rng = random if seed is None else random.Random(seed)
    # sample() leaves the module-level ALL_IMAGES list untouched (shuffle mutated it in place).
    test_images_sampled = rng.sample(ALL_IMAGES, min(num_samples, len(ALL_IMAGES)))
    print(test_images_sampled)

    test_array = []
    bmp_array = []
    for test_image in test_images_sampled:
        test_array.append(_read_resized(test_image, cv2.INTER_LINEAR))
        # Nearest-neighbour keeps the label's discrete pixel values; bilinear
        # interpolation would blend values along label edges and distort the
        # `!= 0` masks used when plotting.
        bmp_path = test_image.replace('.tiff', '_bitmap.png')
        bmp_array.append(_read_resized(bmp_path, cv2.INTER_NEAREST))
    return np.asarray(test_array), np.asarray(bmp_array)

In [None]:
modis_batch, bmp_batch = get_test_data()
# The endpoint response is a dict; its 'predictions' entry holds one output map per input image.
bmp_predict_batch = np.asarray(estimator.predict(modis_batch)['predictions'])
assert len(bmp_predict_batch) == len(modis_batch), 'expected one prediction per test image'

panel_titles = ('RGB Image', 'SME label overlay', 'Model Prediction overlay')
for image, label, prediction in zip(modis_batch, bmp_batch, bmp_predict_batch):
    fig, axes = plt.subplots(1, 3, constrained_layout=True, dpi=100)
    # Every panel shows the input image; panels 2 and 3 get an overlay drawn on top.
    for axis, title in zip(axes, panel_titles):
        axis.imshow(image.astype('uint8'))
        axis.set_title(title)
        axis.set_xticks([])
        axis.set_yticks([])
    # Label overlay: non-zero bitmap pixels are masked out, so only zero-valued pixels are tinted.
    axes[1].imshow(ma.masked_where(label != 0, label)[:, :, 0], alpha=0.35, cmap='Purples')
    # Prediction overlay: scores below 0.5 are masked out, so only pixels scored >= 0.5 are tinted.
    axes[2].imshow(ma.masked_where(prediction < 0.5, prediction)[:, :, 0], alpha=0.45, cmap='spring')
    plt.show()

## Method #2: Use the model prediction REST API endpoint

In [None]:
# Attach a new predictor to the deployed endpoint purely by its name. Any client
# that knows the endpoint name (and may invoke it) can query the model this way,
# without the `model` object. To query a different deployment, substitute its
# endpoint name here.
endpoint_name = estimator.endpoint_name
predictor = TensorFlowPredictor(endpoint_name)
# As in Method #1, the response dict's 'predictions' entry holds one map per image.
bmp_predict_batch = np.asarray(predictor.predict(modis_batch)['predictions'])

In [None]:
# Same three-panel comparison as Method #1, using the predictions fetched via the endpoint name.
assert len(bmp_predict_batch) == len(modis_batch), 'expected one prediction per test image'

panel_titles = ('RGB Image', 'SME label overlay', 'Model Prediction overlay')
for image, label, prediction in zip(modis_batch, bmp_batch, bmp_predict_batch):
    fig, axes = plt.subplots(1, 3, constrained_layout=True, dpi=100)
    # Every panel shows the input image; panels 2 and 3 get an overlay drawn on top.
    for axis, title in zip(axes, panel_titles):
        axis.imshow(image.astype('uint8'))
        axis.set_title(title)
        axis.set_xticks([])
        axis.set_yticks([])
    # Label overlay: non-zero bitmap pixels are masked out, so only zero-valued pixels are tinted.
    axes[1].imshow(ma.masked_where(label != 0, label)[:, :, 0], alpha=0.35, cmap='Purples')
    # Prediction overlay: scores below 0.5 are masked out, so only pixels scored >= 0.5 are tinted.
    axes[2].imshow(ma.masked_where(prediction < 0.5, prediction)[:, :, 0], alpha=0.45, cmap='spring')
    plt.show()

## Cleanup

In [None]:
# Remove the SageMaker model(s) behind the endpoint, then the endpoint itself.
# Order matters: delete_model() finds the models through the endpoint
# configuration, which delete_endpoint() also deletes by default.
estimator.delete_model()
estimator.delete_endpoint()

1. Delete the S3 buckets you created.
2. Delete any remaining endpoints and deployed models: https://docs.aws.amazon.com/sagemaker/latest/dg/ex1-cleanup.html
3. Delete the CloudWatch logs: https://console.aws.amazon.com/cloudwatch/home#logsV2:log-groups