# SageMaker Real-time Dynamic Batching Inference with TorchServe

This notebook demonstrates dynamic batching on SageMaker with [TorchServe](https://github.com/pytorch/serve/) as the model server. It covers the following:
1. Batch inference using a DLC (Deep Learning Container), i.e. SageMaker's default backend container, through the SageMaker Python SDK in script mode.
2. Specifying TorchServe inference parameters through environment variables.
3. Optionally, using a custom container with a TorchServe config file baked into the image.
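
To make the idea of dynamic batching concrete, here is a minimal sketch of how a batch size and a maximum batch delay interact. It is a simplified model written for illustration, not TorchServe's implementation; the helper `collect_batch` and the request names are hypothetical.

```python
import queue
import time


def collect_batch(request_queue, batch_size, max_batch_delay_ms):
    """Gather up to batch_size requests, waiting at most max_batch_delay_ms
    after the first request has arrived (simplified model, an assumption)."""
    batch = [request_queue.get()]  # block until the first request arrives
    deadline = time.monotonic() + max_batch_delay_ms / 1000.0
    while len(batch) < batch_size:
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break  # delay expired: send whatever has been collected
        try:
            batch.append(request_queue.get(timeout=remaining))
        except queue.Empty:
            break  # no further request arrived before the deadline
    return batch


# Five pending requests and a batch size of 3
requests = queue.Queue()
for i in range(5):
    requests.put(f"request-{i}")

first = collect_batch(requests, batch_size=3, max_batch_delay_ms=50)
second = collect_batch(requests, batch_size=3, max_batch_delay_ms=50)
print(first)
print(second)
```

The first batch fills up immediately with three requests. The second batch holds only the two remaining requests and is released once the delay expires, which is why a large delay combined with low traffic increases latency.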

**Installs**

In [None]:
!pip install torch-model-archiver

**Imports**

In [None]:
import base64
import json
import os
import time

import boto3
import matplotlib.pyplot as plt
import numpy as np
import sagemaker
from PIL import Image

**Initiate session and retrieve region, account details**

In [None]:
sm_sess = sagemaker.Session()
role = sagemaker.get_execution_role()

In [None]:
sess = boto3.Session()
region = sess.region_name
account = boto3.client("sts").get_caller_identity().get("Account")

In [None]:
bucket = sm_sess.default_bucket()
prefix = "ts-dynamic-batching"
model_name = "BERTSeqClassification"
mar_file = f"{model_name}.mar"

In [None]:
model_artifact = f"s3://{bucket}/{prefix}/models/{model_name}.tar.gz"

## Build a Custom Container

#### This approach bakes a custom model configuration into the container through `config.properties`. The configuration sets `batch_size`, `max_batch_delay` and other properties that control batching for the model
### Refer to the `docker/` directory
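
The actual file in `docker/` is not reproduced in this notebook, so the following is only a sketch of what such a `config.properties` could contain. The property names (`inference_address`, `model_store`, `load_models`, `models`) and the per-model keys follow the TorchServe configuration format; the specific values are assumptions chosen to match the batch settings used later in this notebook.

```python
import json

# Hypothetical values for illustration; the real file in docker/ may differ.
model_name = "BERTSeqClassification"
model_settings = {
    model_name: {
        "1.0": {
            "defaultVersion": True,
            "marName": f"{model_name}.mar",
            "minWorkers": 1,
            "maxWorkers": 1,
            "batchSize": 3,  # maximum number of requests per batch
            "maxBatchDelay": 100,  # milliseconds to wait for a full batch
            "responseTimeout": 120,  # seconds
        }
    }
}

config_lines = [
    "inference_address=http://0.0.0.0:8080",  # SageMaker sends requests to port 8080
    "model_store=/opt/ml/model",  # SageMaker extracts model artifacts here
    "load_models=all",
    "models=" + json.dumps(model_settings),
]
config_text = "\n".join(config_lines)
print(config_text)
```

Generating the `models` entry with `json.dumps` keeps the nested JSON valid, which is easy to get wrong when editing the properties file by hand.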

#### The following script builds a container and pushes it to ECR

In [None]:
%%sh

container_name=custom-dynamic-torchserve
account=$(aws sts get-caller-identity --query Account --output text)

# Get the region defined in the current configuration (default to us-west-2 if none defined)
region=$(aws configure get region)
region=${region:-us-west-2}

fullname="${account}.dkr.ecr.${region}.amazonaws.com/${container_name}"

# If the repository doesn't exist in ECR, create it.
aws ecr describe-repositories --repository-names "${container_name}" > /dev/null 2>&1
if [ $? -ne 0 ]
then
    aws ecr create-repository --repository-name "${container_name}" > /dev/null
fi

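# Log in to ECR. `aws ecr get-login` is not available in AWS CLI v2, so use get-login-password instead
aws ecr get-login-password --region ${region} | docker login --username AWS --password-stdin "${account}.dkr.ecr.${region}.amazonaws.com"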

# Build the docker image locally with the image name and then push it to ECR
# with the full name.
docker build --no-cache -t ${container_name} docker/
docker tag ${container_name} ${fullname}

docker push ${fullname}

**Prepare Data**

In [None]:
!aws s3 cp s3://torchserve/mar_files/{mar_file} .
!tar -cvzf {model_name}.tar.gz {mar_file}
!aws s3 cp {model_name}.tar.gz s3://{bucket}/{prefix}/models/
!rm {mar_file} {model_name}.tar.gz

f"s3://{bucket}/{prefix}/models/{model_name}.tar.gz"

In [None]:
container_name = "custom-dynamic-torchserve"
image_uri = f"{account}.dkr.ecr.{region}.amazonaws.com/{container_name}"

#### Create SageMaker model, deploy and predict

In [None]:
from sagemaker.pytorch.model import PyTorchModel
from sagemaker.model import Model
from sagemaker.predictor import Predictor

pytorch_model = Model(
    model_data=model_artifact,
    role=role,
    image_uri=image_uri,
    predictor_cls=Predictor,
)

endpoint_name = 'torchserve-endpoint-' + time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime())

# Change the instance type as necessary, or use 'local' to run in SageMaker local mode
instance_type = "ml.c5.9xlarge"
# instance_type = "local"

predictor = pytorch_model.deploy(
    initial_instance_count=1,
    instance_type=instance_type,
    serializer=sagemaker.serializers.JSONSerializer(),
    deserializer=sagemaker.deserializers.BytesDeserializer(),
    endpoint_name=endpoint_name
)

# Wait for model to load in case of local mode
time.sleep(10)

## Predictions

In [None]:
import multiprocessing


def invoke(num_request):
    return predictor.predict(
        data="{Bloomberg has decided to publish a new report on global economic situation.}"
    )

pool = multiprocessing.Pool(3)
results = pool.map(invoke, range(10))
pool.close()
pool.join()
print(results)

In [None]:
# Clean up
# delete_endpoint() takes no endpoint name; it deletes the endpoint bound to this predictor
predictor.delete_endpoint()

## Use AWS Deep Learning Container

#### The AWS DLCs use sagemaker-pytorch-inference-toolkit to set up and start the model server. Currently, the model artifacts need to be archived into a `*.tar.gz` file along with a manifest (model metadata), as required by TorchServe
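
Before uploading, it can be useful to confirm that the archive really contains the manifest. The sketch below checks a `tar.gz` for `MAR-INF/MANIFEST.json`, the file that `torch-model-archiver` writes; the helpers `has_manifest` and `build_demo_archive` are hypothetical names, and the demo archive is built in memory with placeholder contents.

```python
import io
import posixpath
import tarfile


def has_manifest(archive_bytes):
    """Return True if the tar.gz contains MAR-INF/MANIFEST.json."""
    with tarfile.open(fileobj=io.BytesIO(archive_bytes), mode="r:gz") as tar:
        # Normalize names so "./MAR-INF/MANIFEST.json" also matches
        names = {posixpath.normpath(name) for name in tar.getnames()}
    return "MAR-INF/MANIFEST.json" in names


def build_demo_archive(include_manifest=True):
    """Build a small in-memory tar.gz with placeholder files."""
    files = [("./pytorch_model.bin", b"placeholder weights")]
    if include_manifest:
        files.append(("./MAR-INF/MANIFEST.json", b"{}"))
    buffer = io.BytesIO()
    with tarfile.open(fileobj=buffer, mode="w:gz") as tar:
        for name, payload in files:
            info = tarfile.TarInfo(name=name)
            info.size = len(payload)
            tar.addfile(info, io.BytesIO(payload))
    return buffer.getvalue()


print(has_manifest(build_demo_archive(include_manifest=True)))
print(has_manifest(build_demo_archive(include_manifest=False)))
```

For a real artifact, the same check can be run on the bytes of the local `{model_name}.tar.gz` before it is copied to S3.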

**Prepare Data**

In [None]:
!aws s3 cp s3://torchserve/mar_files/{mar_file} .
!unzip {mar_file}
# Use torch-model-archiver (the following command can serve as a reference when using custom models and handlers).
# Note that the 'no-archive' option only generates the metadata manifest inside MAR-INF/.
# This command creates a folder {model_name}, i.e. BERTSeqClassification/
!torch-model-archiver --version 1.0 --model-name {model_name} --handler Transformer_handler_generalized.py --serialized-file pytorch_model.bin --extra-files setup_config.json,index_to_name.json,config.json --archive-format no-archive -f
# SageMaker requires that models be stored as a *.tar.gz archive
!tar -cvzf {model_name}.tar.gz -C {model_name}/ .
!aws s3 cp {model_name}.tar.gz s3://{bucket}/{prefix}/models/
!rm -rf {model_name}

In [None]:
image_uri = sagemaker.image_uris.retrieve(
    framework="pytorch",
    region=region,
    py_version="py38",
    image_scope="inference",
    version="1.10",
    instance_type="ml.c5.9xlarge",
)

# We'll use a PyTorch inference DLC image that ships with sagemaker-pytorch-inference-toolkit v2.0.10. This version includes support for the TorchServe environment variables used below
# The PT 1.11 image is released, but is not part of the Python SDK yet
image_uri = image_uri.replace("1.10", "1.11")
print(f"Using image: {image_uri}")

#### Create SageMaker model, deploy and predict

In [None]:
from sagemaker.pytorch.model import PyTorchModel

env_variables_dict = {
    "SAGEMAKER_TS_BATCH_SIZE": "3",
    "SAGEMAKER_TS_MAX_BATCH_DELAY": "100000",
    "SAGEMAKER_TS_MIN_WORKERS": "1",
    "SAGEMAKER_TS_MAX_WORKERS": "1",
}


pytorch_model = PyTorchModel(
    model_data=model_artifact,
    role=role,
    image_uri=image_uri,
    source_dir="code",
    framework_version="1.11",
    env=env_variables_dict,
    entry_point="inference.py",
)

# Change the instance type as necessary, or use 'local' to run in SageMaker local mode
instance_type = "ml.c5.9xlarge"
#instance_type = "local"

predictor = pytorch_model.deploy(
    initial_instance_count=1,
    instance_type=instance_type,
    serializer=sagemaker.serializers.JSONSerializer(),
    deserializer=sagemaker.deserializers.BytesDeserializer(),
)

# Wait for model to load in case of local mode
time.sleep(10)

## Predictions

#### By spawning a pool of 3 processes we're able to simulate requests from multiple clients and verify inference results
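
The same client-side pattern can be tried without a live endpoint. The sketch below swaps the process pool for a thread pool from `concurrent.futures` and replaces `predictor.predict` with a stand-in function, `fake_invoke`, which is hypothetical and only returns placeholder bytes. It illustrates that at most three requests are in flight at a time and that `map` returns results in request order.

```python
import time
from concurrent.futures import ThreadPoolExecutor


def fake_invoke(request_id):
    """Stand-in for predictor.predict(); returns placeholder bytes."""
    time.sleep(0.01)  # simulate network and inference latency
    return f"response-{request_id}".encode("utf-8")


# Three workers play the role of three concurrent clients
with ThreadPoolExecutor(max_workers=3) as pool:
    raw_results = list(pool.map(fake_invoke, range(10)))

# A bytes deserializer yields raw bytes, so decode before inspecting
decoded = [result.decode("utf-8") for result in raw_results]
print(decoded)
```

With a real endpoint, each worker would call `predict` instead, and the server would be free to group concurrent requests into one batch.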

In [None]:
import multiprocessing


def invoke(endpoint_name):
    predictor = sagemaker.predictor.Predictor(
        endpoint_name,
        sm_sess,
        serializer=sagemaker.serializers.JSONSerializer(),
        deserializer=sagemaker.deserializers.BytesDeserializer(),
    )
    return predictor.predict(
        "{Bloomberg has decided to publish a new report on global economic situation.}"
    )


endpoint_name = predictor.endpoint_name
pool = multiprocessing.Pool(3)
results = pool.map(invoke, 10 * [endpoint_name])
pool.close()
pool.join()
print(results)

In [None]:
# Clean up
# delete_endpoint() takes no endpoint name; it deletes the endpoint bound to this predictor
predictor.delete_endpoint()

## Conclusion

Through this exercise, we covered the basics of batch inference using TorchServe on Amazon SageMaker. We learnt that inference requests from different processes or users can be batched together and processed by the model as a single batch of inputs. We also learnt that we can either use SageMaker's default DLC as the base environment, or create a custom container for more involved workflows.