In [None]:
import os

import boto3
import sagemaker

In [None]:
# IAM execution role handed to the Estimator below (used for training and for
# the deployed endpoint). Defaults to the project's role; set the
# SAGEMAKER_ROLE_ARN environment variable to use a different role without
# editing the notebook.
role = os.environ.get(
    'SAGEMAKER_ROLE_ARN',
    'arn:aws:iam::638608113287:role/service-role/AmazonSageMaker-ExecutionRole-20180731T132167',
)
sess = sagemaker.Session()

In [None]:
# Resolve the region of the boto session backing `sess`, and the AWS account ID
# of the active credentials (via an STS GetCallerIdentity call).
region = sess.boto_session.region_name
account = (
    sess.boto_session
    .client('sts')
    .get_caller_identity()['Account']
)

In [None]:
# Custom Faster R-CNN training image in ECR. The URI is built from the account
# and region resolved in the previous cell rather than hardcoding them, so the
# job pulls the image from the account/region it actually runs in.
image = '{}.dkr.ecr.{}.amazonaws.com/faster-rcnn:gpu'.format(account, region)
train_instance_type = 'ml.p2.16xlarge'
instance_count = 1
# S3 location passed to the Estimator as `output_path` (model artifacts land here).
output_path = "s3://model-artifacts-alkymi/faster-rcnn/"
# S3 prefix with the training data; passed to `estimator.fit` below.
data_location = "s3://training-data-alkymi/pageseg/20190226"

# Hyperparameters forwarded to the training container. Every value is a
# string; how each is parsed is up to the container's training script.
hyperparameters = {
    "batch_size": "64",
    "epochs": "2",
    "lr": "0.0001",
    "lr_decay_gamma": "0.1",
    "lr_patience": "2",
    "patience": "4",
    "imdb_name": "pdfpages",
    "num_workers": "16",
    "USE_FLIPPED": "False",
}

In [None]:
# Generic SageMaker Estimator wrapping the custom training image. The first four
# arguments stay positional: image URI, execution role, instance count and
# training instance type.
estimator = sagemaker.estimator.Estimator(
    image,
    role,
    instance_count,
    train_instance_type,
    hyperparameters=hyperparameters,
    sagemaker_session=sess,
    output_path=output_path,
)

In [None]:
# Launch the training job on the S3 training data (by default the call waits for the job to finish).
estimator.fit(inputs=data_location)

In [None]:
# Host the trained model behind a real-time endpoint on one ml.p2.xlarge instance.
# NOTE: the endpoint keeps running (and billing) until removed, e.g. with
# predictor.delete_endpoint().
deploy_instance_type = 'ml.p2.xlarge'
predictor = estimator.deploy(
    instance_type=deploy_instance_type,
    initial_instance_count=1,
)

In [None]:
import requests
import json
from io import BytesIO
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.patches as patches
file_path = '../data/wf_celg_report.pdf'

# Render one page of the PDF to an image via the PDF service, then send that
# image to the deployed endpoint for page-segmentation predictions.
with open(file_path, 'rb') as f:
    doc_data = f.read()
params = {'page': 0}
# NOTE(review): the PDF bytes are sent as the body of a GET request; confirm
# the service expects this rather than a POST.
page_image_response = requests.get('https://pdf-service.alkymi.cloud/getPageImage',
                                   params=params, data=doc_data)
# Fail loudly on an HTTP error instead of forwarding an error page to the
# endpoint and to PIL, where it would surface as a confusing decode error.
page_image_response.raise_for_status()
prediction_response = predictor.predict(page_image_response.content)
# `pred` maps a region class name to a list of boxes, each read below as
# [x0, y0, x1, y1, score] (schema inferred from usage -- confirm against the
# inference container).
pred = json.loads(prediction_response)['pred']
img_bytes = BytesIO(page_image_response.content)
img = Image.open(img_bytes)

fig, ax = plt.subplots(figsize=(8.5, 11))
ax.axis('off')
ax.set_title('{} (page {})'.format(file_path, params['page']))

box_type_to_color = {'text': 'r', 'graphical_chart': 'g', 'structured_data': 'b'}
for box_type, boxes in pred.items():
    # Classes missing from the colour map are drawn in magenta instead of
    # raising KeyError.
    color = box_type_to_color.get(box_type, 'm')
    for box in boxes:
        # Coordinates were already coerced with float(); coerce the score too so
        # round() does not fail if the service returns numbers as strings.
        x0, y0, x1, y1, score = (float(value) for value in box[:5])
        rect = patches.Rectangle((x0, y0), x1 - x0, y1 - y0,
                                 linewidth=1,
                                 edgecolor=color,
                                 facecolor='none')
        ax.add_patch(rect)
        # Label each box with its score at the box's (x0, y0) corner.
        ax.annotate(str(round(score, 3)),
                    (x0, y0),
                    color=color,
                    fontsize=12, ha='center', va='center')

# Trailing semicolon suppresses the AxesImage repr in the notebook output.
ax.imshow(img);
