In [2]:
import boto3
import time

In [3]:
# IAM role ARN passed as ExecutionRoleArn when the SageMaker model is created.
role = ('arn:aws:iam::638608113287:role/service-role/'
        'AmazonSageMaker-ExecutionRole-20180731T132167')
# Low-level SageMaker API client built from the default boto3 session
# (region and credentials come from the environment / AWS config).
client = boto3.client('sagemaker')

In [4]:
# Training job whose S3 model artifacts will be deployed; the model reuses its name.
training_job_name = 'faster-rcnn-align-2019-07-29'
model_name = training_job_name

info = client.describe_training_job(TrainingJobName=training_job_name)
model_data_url = info['ModelArtifacts']['S3ModelArtifacts']

# Serve with a dedicated inference image. The training image
# (info['AlgorithmSpecification']['TrainingImage']) could be used here instead.
image = '638608113287.dkr.ecr.us-east-1.amazonaws.com/faster-rcnn:1.0-align'

In [5]:
# Register the trained artifacts together with the serving image as a SageMaker model.
# NOTE(review): CreateModel presumably rejects a ModelName that already exists, so this
# cell is likely not re-runnable without deleting the model first — confirm.

primary_container = dict(
    Image=image,
    ModelDataUrl=model_data_url,
)

create_model_response = client.create_model(
    ModelName=model_name,
    ExecutionRoleArn=role,
    PrimaryContainer=primary_container,
)

print(create_model_response['ModelArn'])

arn:aws:sagemaker:us-east-1:638608113287:model/faster-rcnn-align-2019-07-29


In [5]:
# Create an endpoint config that serves `model_name` on `instance_count` instances
# of `instance_type`, with a single variant receiving all traffic.

instance_type = 'ml.p2.xlarge'
instance_count = 1

# A UTC timestamp suffix gives each config a distinct name (one-second resolution).
endpoint_config_name = f'{training_job_name}-{time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime())}'
print(endpoint_config_name)

create_endpoint_config_response = client.create_endpoint_config(
    EndpointConfigName=endpoint_config_name,
    ProductionVariants=[
        dict(
            VariantName='AllTraffic',
            ModelName=model_name,
            InitialInstanceCount=instance_count,
            InstanceType=instance_type,
        )
    ],
)

print("Endpoint Config Arn: " + create_endpoint_config_response['EndpointConfigArn'])

faster-rcnn-align-2019-07-29-2019-07-31-15-25-38
Endpoint Config Arn: arn:aws:sagemaker:us-east-1:638608113287:endpoint-config/faster-rcnn-align-2019-07-29-2019-07-31-15-25-38


In [None]:
# Create the endpoint from the config above and block until it is in service.

# A UTC timestamp suffix gives each endpoint a distinct name (one-second resolution).
endpoint_name = f'{training_job_name}-' + time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime())
print(endpoint_name)

create_endpoint_response = client.create_endpoint(
    EndpointName=endpoint_name,
    EndpointConfigName=endpoint_config_name)
print(create_endpoint_response['EndpointArn'])

resp = client.describe_endpoint(EndpointName=endpoint_name)
status = resp['EndpointStatus']
print("Status: " + status)

try:
    # The boto3 waiter polls the endpoint and raises if it fails or the wait times out.
    client.get_waiter('endpoint_in_service').wait(EndpointName=endpoint_name)
finally:
    # Runs even when the waiter raised, so the final status is always reported.
    resp = client.describe_endpoint(EndpointName=endpoint_name)
    status = resp['EndpointStatus']
    print("Arn: " + resp['EndpointArn'])
    print("Create endpoint ended with status: " + status)

    if status != 'InService':
        # Reuse `resp` rather than describing the endpoint a second time.
        # FailureReason is only reported for failed endpoints. An endpoint that is
        # still 'Creating' (waiter timeout or interrupt) has none, and indexing it
        # directly would raise KeyError here and hide the actual state.
        message = resp.get('FailureReason',
                           'no FailureReason reported (status: {})'.format(status))
        print('Create endpoint failed with the following error: {}'.format(message))
        raise Exception('Endpoint creation did not succeed')

In [None]:
import requests
import json
from io import BytesIO
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import sagemaker

# Render page 0 of a local PDF to an image via the PDF service, run that image
# through the deployed endpoint, and overlay the predicted boxes on the page.
predictor = sagemaker.RealTimePredictor(endpoint_name)

file_path = '../data/wf_celg_report.pdf'

with open(file_path, 'rb') as f:
    doc_data = f.read()
params = {'page': 0}
# NOTE(review): the PDF bytes are sent as the body of a GET request; presumably that is
# what this service expects — confirm.
page_image_response = requests.get('https://pdf-service.alkymi.cloud/getPageImage',
                                   params=params, data=doc_data)
# Fail fast: otherwise an error page's bytes would be sent to the model as the "image".
page_image_response.raise_for_status()

prediction_response = predictor.predict(page_image_response.content)
# The endpoint's JSON carries a 'pred' mapping of box type -> list of boxes.
pred = json.loads(prediction_response)['pred']
img_bytes = BytesIO(page_image_response.content)
img = Image.open(img_bytes)

fig, ax = plt.subplots(figsize=(8.5, 11))
ax.axis('off')

box_type_to_color = {'text':'r', 'graphical_chart':'g', 'structured_data':'b'}
for box_type, boxes in pred.items():
    # An unmapped box type is drawn in black instead of aborting the whole plot.
    color = box_type_to_color.get(box_type, 'k')
    for box in boxes:
        # Layout as used here: [x_min, y_min, x_max, y_max, value]. box[4] is
        # presumably a confidence score — confirm. All entries are cast with float()
        # because the payload may carry them as strings.
        x_min, y_min, x_max, y_max = (float(v) for v in box[:4])
        score = float(box[4])
        rect = patches.Rectangle((x_min, y_min),
                                 x_max - x_min,
                                 y_max - y_min,
                                 linewidth=1,
                                 edgecolor=color,
                                 facecolor='none')
        ax.add_patch(rect)
        ax.annotate(str(round(score, 3)),
                    (x_min, y_min),
                    color=color,
                    fontsize=12, ha='center', va='center')

# Trailing ';' suppresses the AxesImage repr in the cell output.
ax.imshow(img);