In [None]:
import sagemaker
from sagemaker import get_execution_role

# Show which SageMaker Python SDK is installed: later cells use v2-style APIs
# (sagemaker.image_uris, instance_count=...).
print(sagemaker.__version__)

# One shared SageMaker session plus the IAM role that jobs and endpoints run under.
session = sagemaker.Session()
role = get_execution_role()

In [None]:
bucket = session.default_bucket()
prefix = 'pascalvoc'

# Estimator output (passed as output_path below) goes to s3://<default bucket>/pascalvoc/output.
s3_output_location = f's3://{bucket}/{prefix}/output'
print(s3_output_location)

In [None]:
# Account-specific settings -- replace ALL of these with your own values:
#   * file_system_id: the EFS file system holding /input/train and /input/validation
#   * subnets / security_group_ids: VPC settings for the training job; they must be
#     able to reach that file system's mount targets, since the training data is
#     read from EFS rather than S3.
file_system_id = 'fs-fe36ef34'
subnets = ['subnet-63715206', 'subnet-cbf5bdbc', 'subnet-59395b00']
security_group_ids = ['sg-0aa0a1c297a49e911']

In [None]:
from sagemaker.inputs import FileSystemInput


def _efs_channel(directory_path):
    """Build an EFS-backed training input for ``directory_path`` on ``file_system_id``."""
    return FileSystemInput(file_system_id=file_system_id,
                           file_system_type='EFS',
                           directory_path=directory_path)


# Both channels live on the same EFS file system, under /input/<channel>.
efs_train_data = _efs_channel('/input/train')
efs_validation_data = _efs_channel('/input/validation')

data_channels = {
    'train': efs_train_data,
    'validation': efs_validation_data,
}

In [None]:
from sagemaker import image_uris

# Resolve the built-in Object Detection algorithm image in the same region as the
# shared SageMaker `session`, which is the region the training job is launched in.
region = session.boto_region_name
container = image_uris.retrieve('object-detection', region)
print(container)

In [None]:
# Reuse `role` and `session` from the setup cell instead of resolving the execution
# role again and letting the estimator create a second, implicit session.
od = sagemaker.estimator.Estimator(image_uri=container,
                                   role=role,
                                   instance_count=1,
                                   instance_type='ml.p3.2xlarge',
                                   output_path=s3_output_location,
                                   sagemaker_session=session,
                                   # The training data is read from EFS, so the job is
                                   # launched inside this VPC configuration.
                                   subnets=subnets,
                                   security_group_ids=security_group_ids)

In [None]:
# Hyperparameters for the built-in (SSD-style) Object Detection algorithm.
# num_classes must agree with the class-name list used for visualization later;
# num_training_samples is presumably the image count of the EFS train channel -- confirm
# it if the dataset changes.
hyperparameters = {
    'base_network': 'resnet-50',
    'use_pretrained_model': 1,
    'num_classes': 20,
    'epochs': 1,
    'num_training_samples': 16551,
    'mini_batch_size': 90,
}
od.set_hyperparameters(**hyperparameters)

In [None]:
# Start the training job on the EFS-backed 'train' and 'validation' channels.
od.fit(data_channels)

In [None]:
# Host the trained model behind a real-time endpoint on a single ml.c5.2xlarge instance.
od_predictor = od.deploy(initial_instance_count=1,
                         instance_type='ml.c5.2xlarge')

In [None]:
# Download a sample image from Wikimedia Commons and save it as test.jpg, which the
# next cell sends to the endpoint. Requires the `wget` binary in the notebook environment.
!wget -O test.jpg https://upload.wikimedia.org/wikipedia/commons/6/67/Chin_Village.jpg

In [None]:
import json

import boto3

# Low-level runtime client: calls the endpoint directly to show the raw request and
# response contract (the predictor object could do the same via od_predictor.predict).
runtime = boto3.Session().client(service_name='sagemaker-runtime')

with open('test.jpg', 'rb') as f:
    payload = f.read()  # raw JPEG bytes, sent as-is with ContentType image/jpeg

raw_response = runtime.invoke_endpoint(EndpointName=od_predictor.endpoint_name,
                                       ContentType='image/jpeg',
                                       Body=payload)

# The body is a stream of JSON; later cells read the detections from
# response['prediction'] (rows of [class_id, score, x0, y0, x1, y1]).
response = json.loads(raw_response['Body'].read())

In [None]:
# The endpoint can return many detections; show how many came back and only the
# first few rows instead of dumping the entire payload into the notebook.
detections = response['prediction']
print('{} detections returned'.format(len(detections)))
print(detections[:5])

In [None]:
def visualize_detection(img_file, dets, classes=None, thresh=0.6):
    """
    Draw detection boxes and labels over an image and display the figure.

    Parameters
    ----------
    img_file : str
        Path of the image to draw on; it is read with ``matplotlib.image.imread``.
    dets : iterable of sequences
        Detections as ``[[class_id, score, x0, y0, x1, y1], ...]``, one row per
        object. Box corners are multiplied by the image width/height, so they are
        expected to be normalized to [0, 1].
    classes : sequence of str, optional
        Class names indexed by class id. When omitted, or when an id falls outside
        ``range(len(classes))`` (including negative ids), the numeric id is used
        as the label instead.
    thresh : float
        Score threshold; detections scoring below it are not drawn.
    """
    import random
    import matplotlib.pyplot as plt
    import matplotlib.image as mpimg

    # Avoid a mutable default argument; an empty tuple means "no names available".
    if classes is None:
        classes = ()

    img = mpimg.imread(img_file)
    plt.imshow(img)
    ax = plt.gca()
    height, width = img.shape[:2]
    colors = {}  # one random color per class id, reused for every box of that class
    for klass, score, x0, y0, x1, y1 in dets:
        if score < thresh:
            continue
        cls_id = int(klass)
        if cls_id not in colors:
            colors[cls_id] = (random.random(), random.random(), random.random())
        color = colors[cls_id]
        xmin = int(x0 * width)
        ymin = int(y0 * height)
        xmax = int(x1 * width)
        ymax = int(y1 * height)
        ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                                   fill=False, edgecolor=color, linewidth=3.5))
        # Check the lower bound too: classes[-1] would silently label an
        # out-of-range negative id with the last category name.
        if 0 <= cls_id < len(classes):
            class_name = classes[cls_id]
        else:
            class_name = str(cls_id)
        ax.text(xmin, ymin - 2, '{:s} {:.3f}'.format(class_name, score),
                bbox=dict(facecolor=color, alpha=0.5),
                fontsize=12, color='white')
    plt.show()

In [None]:
%matplotlib inline 

object_categories = ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 
                     'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person', 
                     'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor']

# Setting a threshold 0.20 will only plot detection results that have a confidence score greater than 0.20.
threshold = 0.30

# Visualize the detections.
visualize_detection('test.jpg', response['prediction'], object_categories, threshold)

In [None]:
# Tear down the hosted resources. The model is deleted first because the SDK's
# delete_model() locates it through the endpoint configuration, which delete_endpoint()
# removes. `finally` guarantees the endpoint itself is deleted even if model cleanup fails.
try:
    od_predictor.delete_model()
finally:
    od_predictor.delete_endpoint()