## Perform Batch Inference (Predictions) using SageMaker Batch Transform

### Imports 

In [2]:
from sagemaker import get_execution_role
from time import gmtime, strftime
import pandas as pd
import sagemaker
import boto3
import time

### 1. Essentials

In [3]:
# Default S3 bucket of the current SageMaker session; the batch input/output
# URIs and the result key below are all built on top of it.
BUCKET = sagemaker.Session().default_bucket()
# Key prefix under which this notebook's objects live inside BUCKET.
PREFIX = 'clf'
# Region passed to the SageMaker API client.
# NOTE(review): hard-coded, while BUCKET comes from the session's own
# configuration — confirm both refer to the same region.
REGION = 'us-east-1'

In [4]:
# S3 prefix the transform job reads its input records from.
batch_input = 's3://{}/{}/batch_test/'.format(BUCKET, PREFIX)
batch_input

's3://sagemaker-us-east-1-119174016168/clf/batch_test/'

In [5]:
# S3 prefix the transform job writes its predictions to.
batch_output = 's3://{}/{}/batch_test_out/'.format(BUCKET, PREFIX)
batch_output

's3://sagemaker-us-east-1-119174016168/clf/batch_test_out/'

In [6]:
# UTC timestamp (second resolution) embedded in the model and batch job names.
current_timestamp = time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime())

In [7]:
# Previously run training job whose model artifact is reused for inference.
TRAINING_JOB_NAME = 'classifier-2022-10-04-20-54-21-138'  # Copy this from the console
# Timestamped names so that every run creates a fresh model and transform job.
MODEL_NAME = f'clf-xgboost-model-{current_timestamp}'
BATCH_JOB_NAME = f'clf-xgboost-batch-job-{current_timestamp}'

session = boto3.Session()
sagemaker_execution_role = get_execution_role()
sagemaker_session = sagemaker.session.Session()
sagemaker_client = boto3.client('sagemaker', region_name=REGION)
s3_client = boto3.client('s3')

# Resolve the image in REGION — the region the SageMaker client above creates
# the model in — instead of session.region_name, which is None when no default
# region is configured and may otherwise differ from REGION.
# The image backs a model used for batch inference, so request the
# 'inference' scope rather than 'training'.
container_uri = sagemaker.image_uris.retrieve(region=REGION,
                                              framework='xgboost',
                                              version='1.0-1',
                                              image_scope='inference')
container_uri

'683313688378.dkr.ecr.us-east-1.amazonaws.com/sagemaker-xgboost:1.0-1-cpu-py3'

### 2. Create a Model object using the name of a previously run training job

In [8]:
# Fetch the full description of the training job; the model artifact location
# is read from this response in the next cell.
info = sagemaker_client.describe_training_job(TrainingJobName=TRAINING_JOB_NAME)
info

{'TrainingJobName': 'classifier-2022-10-04-20-54-21-138',
 'TrainingJobArn': 'arn:aws:sagemaker:us-east-1:119174016168:training-job/classifier-2022-10-04-20-54-21-138',
 'ModelArtifacts': {'S3ModelArtifacts': 's3://sagemaker-us-east-1-119174016168/clf/model-artifacts/classifier-2022-10-04-20-54-21-138/output/model.tar.gz'},
 'TrainingJobStatus': 'Completed',
 'SecondaryStatus': 'Completed',
 'HyperParameters': {'num_round': '100', 'objective': 'binary:logistic'},
 'AlgorithmSpecification': {'TrainingImage': '683313688378.dkr.ecr.us-east-1.amazonaws.com/sagemaker-xgboost:1.0-1-cpu-py3',
  'TrainingInputMode': 'File',
  'MetricDefinitions': [{'Name': 'train:mae',
    'Regex': '.*\\[[0-9]+\\].*#011train-mae:([-+]?[0-9]*\\.?[0-9]+(?:[eE][-+]?[0-9]+)?).*'},
   {'Name': 'validation:aucpr',
    'Regex': '.*\\[[0-9]+\\].*#011validation-aucpr:([-+]?[0-9]*\\.?[0-9]+(?:[eE][-+]?[0-9]+)?).*'},
   {'Name': 'validation:f1_binary',
    'Regex': '.*\\[[0-9]+\\].*#011validation-f1_binary:([-+]?[0-9]*\\

In [9]:
# S3 URI of the trained model archive reported by the training job description.
# NOTE(review): presumably only present once the training job has completed —
# confirm before pointing this at a job that is still running.
model_artifact_url = info['ModelArtifacts']['S3ModelArtifacts']
model_artifact_url

's3://sagemaker-us-east-1-119174016168/clf/model-artifacts/classifier-2022-10-04-20-54-21-138/output/model.tar.gz'

In [10]:
# Container definition handed to create_model: the container image plus the
# S3 location of the model artifact it should load.
primary_container = dict(
    Image=container_uri,
    ModelDataUrl=model_artifact_url,
)

In [11]:
# Register a SageMaker model named MODEL_NAME that runs primary_container
# under the notebook's execution role.
response = sagemaker_client.create_model(
    ModelName=MODEL_NAME,
    ExecutionRoleArn=sagemaker_execution_role,
    PrimaryContainer=primary_container)

In [12]:
# Show the create_model API response (contains the new model's ARN).
response

{'ModelArn': 'arn:aws:sagemaker:us-east-1:119174016168:model/clf-xgboost-model-2022-10-04-21-07-13',
 'ResponseMetadata': {'RequestId': '1efebbf6-62b6-413b-ab5c-091f1cb7663c',
  'HTTPStatusCode': 200,
  'HTTPHeaders': {'x-amzn-requestid': '1efebbf6-62b6-413b-ab5c-091f1cb7663c',
   'content-type': 'application/x-amz-json-1.1',
   'content-length': '99',
   'date': 'Tue, 04 Oct 2022 21:07:24 GMT'},
  'RetryAttempts': 0}}

### 3. Create a Batch Transform Job for Inference

In [13]:
# Keyword arguments for create_transform_job, assembled section by section.
request = dict(
    TransformJobName=BATCH_JOB_NAME,
    ModelName=MODEL_NAME,
    BatchStrategy="MultiRecord",
    # Where the predictions are written.
    TransformOutput=dict(S3OutputPath=batch_output),
    # Every object under the input prefix is read as uncompressed CSV and
    # split into records on line boundaries.
    TransformInput=dict(
        DataSource=dict(
            S3DataSource=dict(
                S3DataType="S3Prefix",
                S3Uri=batch_input,
            ),
        ),
        ContentType="text/csv",
        SplitType="Line",
        CompressionType="None",
    ),
    # Compute used to run the job.
    TransformResources=dict(
        InstanceType="ml.m5.xlarge",
        InstanceCount=1,
    ),
)

In [14]:
# Submit the batch transform job described by `request`; the job's progress is
# polled in the next cell.
response = sagemaker_client.create_transform_job(**request)
response

{'TransformJobArn': 'arn:aws:sagemaker:us-east-1:119174016168:transform-job/clf-xgboost-batch-job-2022-10-04-21-07-13',
 'ResponseMetadata': {'RequestId': '2f572e18-4a3f-4bd8-94da-fe4b34a0629a',
  'HTTPStatusCode': 200,
  'HTTPHeaders': {'x-amzn-requestid': '2f572e18-4a3f-4bd8-94da-fe4b34a0629a',
   'content-type': 'application/x-amz-json-1.1',
   'content-length': '118',
   'date': 'Tue, 04 Oct 2022 21:07:30 GMT'},
  'RetryAttempts': 0}}

In [15]:
# Poll the transform job every 30 seconds until it reaches a terminal state.
# 'Completed' ends the loop normally; 'Failed' and 'Stopped' raise so that the
# cells below never run against missing or partial output.
while True:
    response = sagemaker_client.describe_transform_job(TransformJobName=BATCH_JOB_NAME)
    status = response['TransformJobStatus']
    if status == 'Completed':
        print("Transform job ended with status: {}".format(status))
        break
    if status == 'Failed':
        # Use .get so a missing FailureReason cannot mask the real failure
        # with a KeyError.
        message = response.get('FailureReason', 'no failure reason reported')
        print('Transform failed with the following error: {}'.format(message))
        raise Exception('Transform job failed')
    if status == 'Stopped':
        # A stopped job never becomes Completed or Failed; without this branch
        # the loop would keep polling forever.
        print("Transform job ended with status: {}".format(status))
        raise Exception('Transform job was stopped before completing')
    # Remaining statuses (e.g. InProgress, Stopping) are transient: keep waiting.
    print("Transform job is still in status: {}".format(status))
    time.sleep(30)

Transform job is still in status: InProgress
Transform job is still in status: InProgress
Transform job is still in status: InProgress
Transform job is still in status: InProgress
Transform job is still in status: InProgress
Transform job is still in status: InProgress
Transform job is still in status: InProgress
Transform job is still in status: InProgress
Transform job is still in status: InProgress
Transform job is still in status: InProgress
Transform job is still in status: InProgress
Transform job ended with status: Completed


### 4. Evaluate Output

In [16]:
# Object key of the transform output inside BUCKET.
# NOTE(review): assumes the input object was named batch_test.csv and that the
# job wrote its result as <input name>.out — confirm against the bucket.
key = '{}/batch_test_out/batch_test.csv.out'.format(PREFIX)

In [17]:
# Download the transform output object from S3.
obj = s3_client.get_object(Bucket=BUCKET, Key=key)
# Parse the streamed body as CSV; passing names= makes pandas treat every line
# as data and label the single column 'Predictions'.
results_df = pd.read_csv(obj['Body'], names=['Predictions'])

In [18]:
# Display the predictions (pandas truncates the rendering of long frames).
results_df

Unnamed: 0,Predictions
0,0.127000
1,0.108036
2,0.078757
3,0.263060
4,0.046278
...,...
2482,0.465947
2483,0.080168
2484,0.123397
2485,0.027480
