## SageMaker Asynchronous Inference With LLMs

In this example we'll use SageMaker's Asynchronous Inference hosting option to host the [Flan-T5 XXL LLM](https://huggingface.co/google/flan-t5-xxl).

- Async Inference Launch Code Sample/Reference: https://github.com/aws/amazon-sagemaker-examples/blob/main/async-inference/Async-Inference-Walkthrough-SageMaker-Python-SDK.ipynb
- This example reuses some reference code from my colleague Abhi Sodhani's original example: https://github.com/abhisodhani/sagemaker-falcon-asyn-hosting/blob/main/notebook/huggingface-large-model-aync-inference.ipynb. We will extend that example for our own LLM use case/model.

### Setup

In [None]:
import json
import sagemaker
import boto3
from sagemaker.huggingface import HuggingFaceModel, get_huggingface_llm_image_uri

# Resolve the IAM role ARN used for the model. If get_execution_role() raises
# ValueError, fall back to looking up a role by name through the IAM API.
try:
    role = sagemaker.get_execution_role()
except ValueError:
    iam = boto3.client("iam")
    role = iam.get_role(RoleName="sagemaker_execution_role")["Role"]["Arn"]

print(f"Role: {role}")

In [None]:
# Create a SageMaker session and use its default S3 bucket for this example.
sagemaker_session = sagemaker.Session()
default_bucket = sagemaker_session.default_bucket()
# Key prefix under which this example stores its S3 objects.
bucket_prefix = "async-llm-output"
# S3 URI passed to the async inference config as its output path.
async_output_path = f"s3://{default_bucket}/{bucket_prefix}/output"
print(f"My model inference outputs will be stored at this S3 path: {async_output_path}")

In [None]:
from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig

# Async inference settings handed to deploy(): the S3 output path and the
# max_concurrent_invocations_per_instance value for the endpoint.
async_config = AsyncInferenceConfig(
    output_path=async_output_path,
    max_concurrent_invocations_per_instance=10,
    # Optionally specify Amazon SNS topics for success / error notifications:
    # notification_config = {
    # "SuccessTopic": "arn:aws:sns:<aws-region>:<account-id>:<topic-name>",
    # "ErrorTopic": "arn:aws:sns:<aws-region>:<account-id>:<topic-name>",
    # }
)

In [None]:
# Container environment, taken from the Hugging Face Hub deploy snippet:
# the model id to load and the SM_NUM_GPUS setting.
hub = {
    "HF_MODEL_ID": "google/flan-t5-xxl",
    "SM_NUM_GPUS": json.dumps(4),
}

# Model definition only; the async config is supplied when deploy() is called.
huggingface_model = HuggingFaceModel(
    image_uri=get_huggingface_llm_image_uri("huggingface", version="1.1.0"),
    env=hub,
    role=role,
)

In [None]:
# Deploy the model to a SageMaker endpoint, passing the async inference
# config so the endpoint is created with it.
predictor = huggingface_model.deploy(
    initial_instance_count=1,
    instance_type="ml.g5.12xlarge",
    container_startup_health_check_timeout=300,
    async_inference_config=async_config,
)

### Sample Inference

In [None]:
# Single invocation: send one request through the predictor.

# Request body: the prompt goes under "inputs" and the generation settings
# under "parameters". input_data is reused by later cells.
payload = "What is the capitol of the United States?"
input_data = {
    "inputs": payload,
    "parameters": {
        "early_stopping": True,
        "length_penalty": 2.0,
        "max_new_tokens": 50,
        "temperature": 1,
        "min_length": 10,
        "no_repeat_ngram_size": 3,
        },
}
predictor.predict(input_data)

In [None]:
import json
import os

# Write 19 JSON Lines files (inputs/input_1.jsonl .. inputs/input_19.jsonl),
# each containing a single line: a copy of the sample request.
output_directory = 'inputs'
os.makedirs(output_directory, exist_ok=True)

for i in range(1, 20):
    json_data = [input_data.copy()]
    file_path = os.path.join(output_directory, f'input_{i}.jsonl')

    # One JSON document per line, each terminated by a newline.
    with open(file_path, 'w') as input_file:
        input_file.writelines(json.dumps(line) + '\n' for line in json_data)

In [None]:
# NOTE(review): neither name is referenced by the cells visible here
# (upload_file builds its key prefix from bucket_prefix and takes its own
# input_location parameter) — confirm whether these are still needed.
bucket_prefix_input = "input-data-llm"
input_location = "inputs.jsonl"

def upload_file(input_location):
    """Upload a local file to the default bucket via the SageMaker session.

    The object is stored under the ``{bucket_prefix}/input`` key prefix with
    its content type set to ``application/json``.

    Args:
        input_location: Local path of the file to upload.

    Returns:
        The value returned by ``sagemaker_session.upload_data``.
    """
    upload_options = {
        "bucket": default_bucket,
        "key_prefix": f"{bucket_prefix}/input",
        "extra_args": {"ContentType": "application/json"},  # make sure to specify
    }
    return sagemaker_session.upload_data(input_location, **upload_options)

# Upload the first generated input file; the returned location is used as the
# InputLocation of the async invocation.
sample_data_point = upload_file("inputs/input_1.jsonl")
print(f"Sample data point uploaded: {sample_data_point}")

In [None]:
import boto3
# SageMaker runtime client, used here to call the async invocation API directly.
runtime = boto3.client("sagemaker-runtime")

# Submit the uploaded sample file to the endpoint as an asynchronous request.
response = runtime.invoke_endpoint_async(
    EndpointName=predictor.endpoint_name,
    InputLocation=sample_data_point,
    Accept='application/json',
    ContentType="application/json"
)

# Keep the output location from the response; it is polled for the result below.
output_location = response["OutputLocation"]
print(f"OutputLocation: {output_location}")

In [None]:
import urllib, time
from botocore.exceptions import ClientError

# function reference/credit: https://github.com/aws/amazon-sagemaker-examples/blob/main/async-inference/Async-Inference-Walkthrough-SageMaker-Python-SDK.ipynb
def get_output(output_location, poll_interval=2, timeout=None):
    """Poll S3 until the inference output object exists, then return its contents.

    Args:
        output_location: ``s3://bucket/key`` URI of the expected output object.
        poll_interval: Seconds to sleep between attempts while the object is
            missing. Defaults to 2, the previously hard-coded interval.
        timeout: Maximum number of seconds to keep waiting, or ``None``
            (the default) to wait indefinitely.

    Returns:
        Whatever ``sagemaker_session.read_s3_file`` returns for the object.

    Raises:
        TimeoutError: If ``timeout`` is set and elapses before the object appears.
        ClientError: For any S3 error other than ``NoSuchKey``.
    """
    # Import the submodule explicitly: a bare ``import urllib`` does not
    # guarantee that ``urllib.parse`` has been loaded.
    from urllib.parse import urlparse

    output_url = urlparse(output_location)
    bucket = output_url.netloc
    key = output_url.path[1:]  # drop the leading "/"
    deadline = None if timeout is None else time.monotonic() + timeout

    while True:
        try:
            return sagemaker_session.read_s3_file(bucket=bucket, key_prefix=key)
        except ClientError as e:
            # Only a missing object means "not ready yet"; anything else is a real error.
            if e.response["Error"]["Code"] != "NoSuchKey":
                raise
        if deadline is not None and time.monotonic() >= deadline:
            raise TimeoutError(
                f"No output at {output_location} after {timeout} seconds"
            )
        print("waiting for output...")
        time.sleep(poll_interval)

In [None]:
# Wait for the result of the sample invocation and print it.
output = get_output(output_location)
print(f"Output: {output}")

In [None]:
# Upload every generated input file and submit an async invocation for it,
# recording which output location belongs to which input file.
inferences = []
for i in range(1, 20):
    input_file = f"inputs/input_{i}.jsonl"
    input_file_s3_location = upload_file(input_file)
    print(f"Invoking Endpoint with {input_file}")

    async_response = predictor.predict_async(input_path=input_file_s3_location)
    output_location = async_response.output_path
    print(output_location)

    inferences.append((input_file, output_location))
    time.sleep(0.5)

# Fetch the results in submission order, waiting for each one to appear.
for input_file, output_location in inferences:
    output = get_output(output_location)
    print(f"Input File: {input_file}, Output: {output}")

### AutoScaling

In [None]:
# Set up autoscaling for the async endpoint: register the endpoint variant as
# a scalable target, then attach a target-tracking policy to it.
client = boto3.client(
    "application-autoscaling"
)  # Application Auto Scaling client; the service handles scaling for SageMaker amongst other services

resource_id = (
    "endpoint/" + predictor.endpoint_name + "/variant/" + "AllTraffic"
)  # This is the format in which Application Auto Scaling references the endpoint variant

# Configure autoscaling on the asynchronous endpoint, allowing it to go down to zero instances
response = client.register_scalable_target(
    ServiceNamespace="sagemaker",
    ResourceId=resource_id,
    ScalableDimension="sagemaker:variant:DesiredInstanceCount",
    MinCapacity=0,
    MaxCapacity=5,
)

response = client.put_scaling_policy(
    PolicyName="Invocations-ScalingPolicy",
    ServiceNamespace="sagemaker",  # The namespace of the AWS service that provides the resource.
    ResourceId=resource_id,  # Endpoint variant resource id built above
    ScalableDimension="sagemaker:variant:DesiredInstanceCount",  # SageMaker supports only Instance Count
    PolicyType="TargetTrackingScaling",  # 'StepScaling'|'TargetTrackingScaling'
    TargetTrackingScalingPolicyConfiguration={
        "TargetValue": 5.0,  # The target value for the metric - here the metric is ApproximateBacklogSizePerInstance (specified below)
        "CustomizedMetricSpecification": {
            "MetricName": "ApproximateBacklogSizePerInstance",
            "Namespace": "AWS/SageMaker",
            "Dimensions": [{"Name": "EndpointName", "Value": predictor.endpoint_name}],
            "Statistic": "Average",
        },
        "ScaleInCooldown": 600,  # The cooldown period helps prevent launching or terminating
        # additional instances before the effects of previous activities are visible.
        # You can configure the length of time based on your instance startup time or other application needs.
        # ScaleInCooldown - The amount of time, in seconds, after a scale in activity completes before another scale in activity can start.
        "ScaleOutCooldown": 100  # ScaleOutCooldown - The amount of time, in seconds, after a scale out activity completes before another scale out activity can start.
        # 'DisableScaleIn': True|False - Indicates whether scale in by the target tracking policy is disabled.
        # If the value is true, scale in is disabled and the target tracking policy won't remove capacity from the scalable resource.
    },
)

In [None]:
# Generate load for 15 minutes by calling the endpoint in a loop.
# NOTE(review): predict() is called sequentially; if it waits for each result
# before returning, the request backlog may stay small — confirm this queues
# enough requests to exercise the scaling policy.
request_duration = 60 * 15 # 15 minutes
end_time = time.time() + request_duration
print(f"test will run for {request_duration} seconds")
while time.time() < end_time:
    predictor.predict(input_data)

In [None]:
# Report the endpoint status and, while the endpoint is 'Updating', poll once
# a minute and print the status and the first variant's current instance count.
sm_client = boto3.client(service_name='sagemaker')
# Fixed: was `predictorendpoint_name` (missing dot), which raised NameError.
response = sm_client.describe_endpoint(EndpointName=predictor.endpoint_name)
status = response['EndpointStatus']
print("Status: " + status)


while status == 'Updating':
    time.sleep(60)
    response = sm_client.describe_endpoint(EndpointName=predictor.endpoint_name)
    status = response['EndpointStatus']
    instance_count = response['ProductionVariants'][0]['CurrentInstanceCount']
    print(f"Status: {status}")
    print(f"Current Instance count: {instance_count}")