#### Imports 

In [None]:
from IPython.core.display import HTML
import boto3
import json
import os

#### Endpoint configuration

In [None]:
# Name of the deployed SageMaker endpoint to query; replace it with the name of
# your own deployment if it differs.
ENDPOINT_NAME = 'jumpstart-ftc-hurricane-damage-classifier'
# MIME type sent with every request; the request body is the raw image file bytes.
CONTENT_TYPE = 'application/x-image'
# SageMaker runtime client used to invoke the endpoint. 'sagemaker-runtime' is the
# canonical boto3 service name ('runtime.sagemaker' is only a legacy alias for it).
sagemaker = boto3.client('sagemaker-runtime')

In [None]:
# Class name for each position in the score list returned by the endpoint.
label_map = dict(enumerate(('damaged', 'not-damaged')))

#### Invoke SageMaker Endpoint and Predict

In [None]:
def query_endpoint(image, endpoint_name=None, content_type=None):
    """Send raw image bytes to a SageMaker endpoint and return its decoded JSON reply.

    Args:
        image: Encoded image file contents (bytes), sent as the request body.
        endpoint_name: Endpoint to invoke. Defaults to the module-level
            ENDPOINT_NAME, looked up at call time so later reassignments in the
            notebook still take effect.
        content_type: MIME type of ``image``. Defaults to the module-level
            CONTENT_TYPE, also looked up at call time.

    Returns:
        The response body parsed with ``json.loads``. The prediction cells in this
        notebook treat it as a list of per-class scores indexed like ``label_map``
        (assumption about the model's output format — confirm for other models).
    """
    response = sagemaker.invoke_endpoint(
        EndpointName=ENDPOINT_NAME if endpoint_name is None else endpoint_name,
        ContentType=CONTENT_TYPE if content_type is None else content_type,
        Body=image,
    )
    # 'Body' is a streaming object; read it fully, then decode the JSON payload.
    return json.loads(response['Body'].read())

##### Test damaged images

In [None]:
# Preview the five damaged test images side by side, each scaled to 200px tall.
HTML(
    '<table><tr><td>'
    + '</td><td> '.join(
        f'<img src="./data/test/damage/{n}.jpeg" alt="{n}" style="height: 200px;"/>'
        f'<figcaption>{n}.jpeg</figcaption>'
        for n in range(1, 6)
    )
    + '</td></tr></table>'
)

In [None]:
# Load every JPEG test image in the "damage" folder as raw bytes, keyed by file name.
damaged_images = {}

root_dir = './data/test/damage/'
# os.listdir() returns entries in arbitrary order; sort them so the predictions
# below print in a stable, reproducible order. The extension is matched
# case-insensitively, and both common JPEG suffixes are accepted.
for filename in sorted(os.listdir(root_dir)):
    if filename.lower().endswith(('.jpeg', '.jpg')):
        with open(os.path.join(root_dir, filename), 'rb') as file:
            damaged_images[filename] = file.read()

In [None]:
# Classify each damaged test image and print the label with the highest score.
for name, payload in damaged_images.items():
    scores = query_endpoint(payload)
    # Position of the first maximal score (same tie-breaking as list.index(max(...))).
    best = max(range(len(scores)), key=scores.__getitem__)
    print(f'{name} = {label_map[best]}')
  

##### Test not-damaged images

In [None]:
# Preview the five undamaged test images side by side, each scaled to 200px tall.
HTML(
    '<table><tr><td>'
    + '</td><td> '.join(
        f'<img src="./data/test/no-damage/{n}.jpeg" alt="{n}" style="height: 200px;"/>'
        f'<figcaption>{n}.jpeg</figcaption>'
        for n in range(1, 6)
    )
    + '</td></tr></table>'
)

In [None]:
# Load every JPEG test image in the "no-damage" folder as raw bytes, keyed by file name.
not_damaged_images = {}

root_dir = './data/test/no-damage/'
# os.listdir() returns entries in arbitrary order; sort them so the predictions
# below print in a stable, reproducible order. The extension is matched
# case-insensitively, and both common JPEG suffixes are accepted.
for filename in sorted(os.listdir(root_dir)):
    if filename.lower().endswith(('.jpeg', '.jpg')):
        with open(os.path.join(root_dir, filename), 'rb') as file:
            not_damaged_images[filename] = file.read()

In [None]:
# Classify each undamaged test image and print the label with the highest score.
for name, payload in not_damaged_images.items():
    scores = query_endpoint(payload)
    # Position of the first maximal score (same tie-breaking as list.index(max(...))).
    best = max(range(len(scores)), key=scores.__getitem__)
    print(f'{name} = {label_map[best]}')