In [None]:
import json
import re
import time
from urllib.parse import urlparse

import boto3
import sagemaker

In [None]:
# --- Input: validation images and their annotations ---
input_bucket_name = 'training-data-alkymi'
validation_folder = 'pageseg/5/validation'
annotation_folder = 'pageseg/5/validation_annotation'
# S3 URI used as the transform job's input prefix. The request cell reads
# `validation_location`, so it has to be defined here alongside the other paths.
validation_location = f's3://{input_bucket_name}/{validation_folder}'

# --- Output: where the transform job writes its results ---
output_bucket = 'batch-transform-results'
batch_job_name = "notebook-batch-transform6"
output_location = f's3://{output_bucket}/{batch_job_name}'

# --- Model and compute ---
model_name = 'faster-rcnn-2019-03-06-4-4'
instance_count = 20

In [None]:
# Low-level boto3 clients used by the cells below:
# `sm` talks to SageMaker, `s3` talks to S3.
sm = boto3.client(service_name='sagemaker')
s3 = boto3.client(service_name='s3')

In [None]:
# Assemble the transform-job request from its three nested sections.

# Input: every object under the validation prefix, sent as PNG images.
transform_input = {
    "DataSource": {
        "S3DataSource": {
            "S3DataType": "S3Prefix",
            "S3Uri": validation_location,
        },
    },
    "ContentType": "image/png",
    # Intentionally left unset: "SplitType": "Line", "CompressionType": "None"
}

# Output: results are written under the job's own S3 location.
transform_output = {"S3OutputPath": output_location}

# Compute: instance type and how many instances to use.
transform_resources = {
    "InstanceType": "ml.m4.2xlarge",
    "InstanceCount": instance_count,
}

request = {
    "TransformJobName": batch_job_name,
    "ModelName": model_name,
    # Intentionally left unset: "MaxConcurrentTransforms": 4
    "MaxPayloadInMB": 6,
    "BatchStrategy": "SingleRecord",
    "TransformOutput": transform_output,
    "TransformInput": transform_input,
    "TransformResources": transform_resources,
}

In [None]:
# Submit the batch transform job described by `request` to SageMaker.
sm.create_transform_job(**request)

# print() joins its arguments with a single space, so the name follows the label.
print("Created Transform job with name: ", batch_job_name, sep=" ")

In [None]:
### Wait until the job finishes
# Poll the job every 30 seconds until it reaches a terminal status.
# 'Completed' ends the loop normally; 'Failed' and 'Stopped' raise so that
# the cells below never run against missing or partial output.
while True:
    response = sm.describe_transform_job(TransformJobName=batch_job_name)
    status = response['TransformJobStatus']
    if status == 'Completed':
        print("Transform job ended with status: " + status)
        break
    if status == 'Failed':
        # .get() so that a missing FailureReason cannot mask the real failure
        # with a KeyError.
        message = response.get('FailureReason', 'no failure reason reported')
        print('Transform failed with the following error: {}'.format(message))
        raise Exception('Transform job failed')
    if status == 'Stopped':
        # 'Stopped' is terminal as well; without this branch the loop would
        # keep polling forever.
        print("Transform job ended with status: " + status)
        raise Exception('Transform job was stopped before completing')
    print("Transform job is in status: " + status)
    time.sleep(30)

In [None]:
# Load the predictions written by the transform job, keyed by S3 object key.
#
# Separate names are used for the S3 resource and the Bucket object so that
# `s3` (the client created earlier) and `output_bucket` (the bucket name string)
# are not overwritten; this keeps the cell safe to re-run.
s3_resource = boto3.resource('s3')
results_bucket = s3_resource.Bucket(output_bucket)

# Filter by prefix on the S3 side instead of listing the whole bucket. The
# trailing slash stops jobs whose names merely start with `batch_job_name`
# from being included.
# Assumes each output object is a JSON document with a top-level 'pred' key.
results = {
    obj.key: json.loads(obj.get()['Body'].read().decode('utf-8'))['pred']
    for obj in results_bucket.objects.filter(Prefix=f'{batch_job_name}/')
}

In [None]:
# Fetch the XML annotation that corresponds to the first transform result.
# Only one annotation is downloaded: the loop exits after the first key that
# looks like the output of a .png input.

# get_object is a client API, so use a dedicated client here rather than relying
# on whatever `s3` is currently bound to.
s3_client = boto3.client('s3')

for key in results:
    # Capture everything between the first '/' and the '.png' extension.
    # The dot is escaped so that only a literal '.png' matches.
    m = re.search(r"/(.*)\.png", key)
    if m is None:
        # Key does not come from a .png input; try the next one.
        continue
    file_base = m.group(1)
    annotation_file_key = f'{annotation_folder}/{file_base}.xml'
    obj = s3_client.get_object(Bucket=input_bucket_name, Key=annotation_file_key)
    xml_annotation = obj['Body'].read().decode('utf-8')
    break

In [None]:
### Fetch the transform output
# Download a single output file of the job and show its first line.
#
# NOTE(review): this expects an object named 'valid_data.csv.out' under the job's
# output prefix, while the request above sends PNG images. Confirm this job
# actually produces such a file.
# NOTE: this rebinds `results` to a list of text lines, replacing any earlier
# value held under that name.
parsed_output = urlparse(output_location)
output_key = "{}/valid_data.csv.out".format(parsed_output.path.lstrip('/'))

# The bucket is taken from the output URI itself, so the cell does not depend on
# names defined in other cells beyond `output_location`.
s3_client = boto3.client('s3')
s3_client.download_file(parsed_output.netloc, output_key, 'valid-result')

with open('valid-result') as f:
    results = f.readlines()

if results:
    print("Sample transform result: {}".format(results[0]))
else:
    print("Transform output file is empty")