In [None]:
import boto3, cv2, time, numpy as np, matplotlib.pyplot as plt
from sagemaker.pytorch import PyTorchPredictor
from sagemaker.deserializers import JSONDeserializer

## Run Inference

In [None]:
sm_client = boto3.client(service_name="sagemaker")

endpoint_name = 'ADD THE ENDPOINT NAME AS IT WAS CREATED IN THE PREVIOUS NOTEBOOK'

# Block until the endpoint created in the previous notebook is ready to serve.
# The status is re-read on every iteration (a single list_endpoints() snapshot
# would never change), and the cell fails loudly instead of falling through
# when the endpoint cannot become InService.
POLL_INTERVAL_SECONDS = 5
MAX_WAIT_SECONDS = 30 * 60
# Statuses from which the endpoint will not reach InService on its own.
UNRECOVERABLE_STATUSES = ('Failed', 'Deleting')

wait_deadline = time.monotonic() + MAX_WAIT_SECONDS
while True:
    # Raises a botocore ClientError if no endpoint with this name exists
    # (e.g. the placeholder name above was not replaced).
    endpoint_description = sm_client.describe_endpoint(EndpointName=endpoint_name)
    endpoint_status = endpoint_description['EndpointStatus']
    print(f"Endpoint Status = {endpoint_status}")
    if endpoint_status == 'InService':
        break
    if endpoint_status in UNRECOVERABLE_STATUSES:
        failure_reason = endpoint_description.get('FailureReason', 'no reason given')
        raise RuntimeError(
            f"Endpoint {endpoint_name!r} is {endpoint_status}: {failure_reason}")
    if time.monotonic() > wait_deadline:
        raise TimeoutError(
            f"Endpoint {endpoint_name!r} not InService after {MAX_WAIT_SECONDS} s "
            f"(last status: {endpoint_status})")
    time.sleep(POLL_INTERVAL_SECONDS)

In [None]:
# Client-side handle for the deployed endpoint; responses are decoded from JSON
# by the JSONDeserializer passed in here.
predictor = PyTorchPredictor(
    endpoint_name=endpoint_name,
    deserializer=JSONDeserializer(),
)

In [None]:
# Send one image to the endpoint, time the round trip, and draw the returned
# detections onto the original full-resolution image.
image_path = 'bus.jpg'

# perf_counter is monotonic and high-resolution, which suits interval timing.
infer_start_time = time.perf_counter()

orig_image = cv2.imread(image_path)
if orig_image is None:
    # cv2.imread returns None (rather than raising) when the file is missing or unreadable.
    raise FileNotFoundError(f"Could not read image file {image_path!r}")

image_height, image_width, _ = orig_image.shape
model_height, model_width = 300, 300  # size the image is resized to before being sent
# Scale factors mapping coordinates on the resized image back onto the original.
x_ratio = image_width/model_width
y_ratio = image_height/model_height

# cv2.resize takes dsize as (width, height), not (height, width).
resized_image = cv2.resize(orig_image, (model_width, model_height))
encoded_ok, encoded_image = cv2.imencode('.jpg', resized_image)
if not encoded_ok:
    raise RuntimeError("JPEG encoding of the resized image failed")
payload = encoded_image.tobytes()
result = predictor.predict(payload)

infer_end_time = time.perf_counter()

# NOTE: this interval also covers image loading, resizing and encoding, not only the endpoint call.
print(f"Inference Time = {infer_end_time - infer_start_time:0.4f} seconds")

# Assumes each entry of result['boxes'] is (x1, y1, x2, y2, confidence, class_id)
# in resized-image pixel coordinates — TODO confirm against the endpoint's inference script.
for x1,y1,x2,y2,conf,lbl in result['boxes']:
    x1, x2 = int(x_ratio*x1), int(x_ratio*x2)
    y1, y2 = int(y_ratio*y1), int(y_ratio*y2)
    # Boxes and labels are drawn directly onto orig_image, which the next cell displays.
    cv2.rectangle(orig_image, (x1,y1), (x2,y2), (0,255,0), 4)
    cv2.putText(orig_image, f"Class: {int(lbl)}", (x1,y1-40), cv2.FONT_HERSHEY_SIMPLEX, 1, (255,255,255), 2, cv2.LINE_AA)
    cv2.putText(orig_image, f"Conf: {int(conf*100)}", (x1,y1-10), cv2.FONT_HERSHEY_SIMPLEX, 1, (255,255,255), 2, cv2.LINE_AA)

In [None]:
# OpenCV stores images in BGR channel order; matplotlib expects RGB.
fig, ax = plt.subplots()
ax.imshow(cv2.cvtColor(orig_image, cv2.COLOR_BGR2RGB))
ax.set_title("Detections returned by the endpoint")
ax.axis('off')  # pixel-index ticks add nothing to an annotated photo
plt.show()

## Clean up: delete the endpoint, endpoint configuration, and model

In [None]:
# Delete the endpoint, its endpoint configuration, and the model(s) it serves.
# The config name is read from the endpoint itself instead of assuming it
# matches the endpoint name.
endpoint_description = sm_client.describe_endpoint(EndpointName=endpoint_name)
endpoint_config_name = endpoint_description['EndpointConfigName']
response = sm_client.describe_endpoint_config(EndpointConfigName=endpoint_config_name)
# Collect model names before deleting the config, since they are looked up through it.
model_names = [prod_var['ModelName'] for prod_var in response['ProductionVariants']]

# Delete Endpoint
print(f"Deleting endpoint {endpoint_name}")
sm_client.delete_endpoint(EndpointName=endpoint_name)

# Delete Endpoint Configuration
print(f"Deleting endpoint config {endpoint_config_name}")
sm_client.delete_endpoint_config(EndpointConfigName=endpoint_config_name)

# Delete Model
for model_name in model_names:
    print(f"Deleting model {model_name}")
    sm_client.delete_model(ModelName=model_name)