# Sticky Session Routing with the SageMaker LMI Container
The LMI v16 container also supports sticky session routing, which enables stateful generative AI applications. With sticky session routing, all requests that belong to the same session are routed to the same instance, so your application can reuse previously processed information to reduce latency and improve the user experience.

### Additional Resources/Credits
- Initially launched with the TorchServe container: https://aws.amazon.com/blogs/machine-learning/build-ultra-low-latency-multimodal-generative-ai-applications-using-sticky-session-routing-in-amazon/
- Official LMI v16 docs/notebook (most of this code is adapted from it): https://github.com/deepjavalibrary/djl-demo/blob/master/aws/sagemaker/large-model-inference/sample-llm/stateful_inference_llama3_8b.ipynb

## Setup

In [None]:
# Install or upgrade the SageMaker Python SDK (Jupyter magic, not plain Python).
%pip install sagemaker --upgrade --quiet --no-warn-conflicts

In [None]:
import json
import os
import time
from time import gmtime, strftime

import boto3
import sagemaker

# Shared SageMaker session state and AWS clients used by every later cell.
role = sagemaker.get_execution_role()  # IAM execution role assumed by the endpoint
sess = sagemaker.session.Session()  # SageMaker session for interacting with AWS APIs
bucket = sess.default_bucket()  # default S3 bucket to house artifacts
region = sess.boto_region_name  # public accessor (avoids the private `_region_name`)
account_id = sess.account_id()

# Pin both clients to the session's region; a bare boto3.client() resolves its
# region independently and could point at a different region than `sess`.
sm_client = boto3.client("sagemaker", region_name=region)  # client to interact with SageMaker
smr_client = boto3.client("sagemaker-runtime", region_name=region)  # client to interact with SageMaker endpoints

print(f"sagemaker role arn: {role}")
print(f"sagemaker bucket: {bucket}")
print(f"sagemaker session region: {region}")
print(f"sagemaker version: {sagemaker.__version__}")

## Deploy a Qwen SageMaker Endpoint Using the LMI v16 Container

In [None]:
# Hardware for the endpoint. num_gpu should equal the GPU count of
# instance_type because it is passed as the tensor-parallel degree below.
instance_type = "ml.g5.4xlarge"
num_gpu = 1

# LMI v16 container image tag.
CONTAINER_VERSION = "0.34.0-lmi16.0.0-cu128"
# NOTE(review): 763104351884 is the Deep Learning Containers registry account
# for most commercial regions; some opt-in regions use another account — confirm.
inference_image = f"763104351884.dkr.ecr.{region}.amazonaws.com/djl-inference:{CONTAINER_VERSION}"
print(f"Using image URI: {inference_image}")

# Container environment: serve the model with the vLLM async handler.
vllm_env = {
    "HF_MODEL_ID": "Qwen/Qwen3-1.7B",
    "SERVING_FAIL_FAST": "true",
    "OPTION_ASYNC_MODE": "true",
    "OPTION_ROLLING_BATCH": "disable",
    "OPTION_TENSOR_PARALLEL_DEGREE": str(num_gpu),
    "OPTION_ENTRYPOINT": "djl_python.lmi_vllm.vllm_async_service",
    "OPTION_TRUST_REMOTE_CODE": "true",
    # Seconds a session remains valid before it expires (defaults to 1200).
    "OPTION_SESSIONS_EXPIRATION": "3600",
}

# Forward a Hugging Face token only when one is configured (export HF_TOKEN
# before starting the notebook). Hard-coding a placeholder would send an
# invalid credential to the container and invites committing real secrets.
hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    vllm_env["HF_TOKEN"] = hf_token

In [None]:
# Step 1: register the model (container image + environment) with SageMaker.
model_name = f"lmi-sticky-ep{strftime('%Y-%m-%d-%H-%M-%S', gmtime())}"
primary_container = {
    "Image": inference_image,
    "Environment": vllm_env,
}
create_model_response = sm_client.create_model(
    ModelName=model_name,
    ExecutionRoleArn=role,
    PrimaryContainer=primary_container,
)
print(f"Model Arn: {create_model_response['ModelArn']}")

# Step 2: endpoint configuration. Two instances are provisioned; with sticky
# sessions, every request in a session stays on one of them.
epc_name = f"lmi-sticky-epc{strftime('%Y-%m-%d-%H-%M-%S', gmtime())}"
production_variant = {
    "VariantName": "AllTraffic",
    "ModelName": model_name,
    "InstanceType": instance_type,
    "InitialInstanceCount": 2,
    # Allow up to 30 minutes each for model download and startup health checks.
    "ModelDataDownloadTimeoutInSeconds": 1800,
    "ContainerStartupHealthCheckTimeoutInSeconds": 1800,
}
endpoint_config_response = sm_client.create_endpoint_config(
    EndpointConfigName=epc_name,
    ProductionVariants=[production_variant],
)
print(f"Endpoint Configuration Arn: {endpoint_config_response['EndpointConfigArn']}")

# Step 3: create the endpoint from the configuration. create_endpoint returns
# while the endpoint is still being provisioned.
endpoint_name = f"lmi-sticky-ep{strftime('%Y-%m-%d-%H-%M-%S', gmtime())}"
create_endpoint_response = sm_client.create_endpoint(
    EndpointName=endpoint_name,
    EndpointConfigName=epc_name,
)
print(f"Endpoint Arn: {create_endpoint_response['EndpointArn']}")

# Poll until the endpoint leaves the "Creating" state, then stop with an error
# if it did not reach "InService", so later cells never invoke a failed endpoint.
poll_interval_s = 60
describe_endpoint_response = sm_client.describe_endpoint(EndpointName=endpoint_name)
while describe_endpoint_response["EndpointStatus"] == "Creating":
    print(describe_endpoint_response["EndpointStatus"])
    # Sleep before re-polling so the loop exits as soon as the status changes,
    # rather than sleeping one extra interval after the endpoint is ready.
    time.sleep(poll_interval_s)
    describe_endpoint_response = sm_client.describe_endpoint(EndpointName=endpoint_name)
print(describe_endpoint_response)

endpoint_status = describe_endpoint_response["EndpointStatus"]
if endpoint_status != "InService":
    failure_reason = describe_endpoint_response.get("FailureReason", "no failure reason reported")
    raise RuntimeError(
        f"Endpoint {endpoint_name} finished in status {endpoint_status}: {failure_reason}"
    )

## Start Session

In [None]:
# Open a new sticky session: the special SessionId "NEW_SESSION" asks SageMaker
# to create one, and the new ID comes back in a response header.
payload = json.dumps({"requestType": "NEW_SESSION"})

create_session_response = smr_client.invoke_endpoint(
    EndpointName=endpoint_name,
    Body=payload,
    ContentType="application/json",
    SessionId="NEW_SESSION",
)

# The header value carries extra ';'-separated attributes after the ID
# (presumably the expiry); keep only the leading session ID.
new_session_headers = create_session_response["ResponseMetadata"]["HTTPHeaders"]
session_id = new_session_headers["x-amzn-sagemaker-new-session-id"].partition(";")[0]
print(f"Created Session ID: {session_id}")

## Invoke the Endpoint

In [None]:
# Send a prompt inside the sticky session; passing SessionId routes the request
# to the instance that owns the session.
request_body = json.dumps({"inputs": "What is Amazon SageMaker?"})
response_model = smr_client.invoke_endpoint(
    EndpointName=endpoint_name,
    Body=request_body,
    ContentType="application/json",
    SessionId=session_id,
)
decoded_response = json.loads(response_model["Body"].read().decode())
result = decoded_response["generated_text"]
print(result)

## Close Session

In [None]:
# Ask the endpoint to close the session identified by session_id.
payload = json.dumps({"requestType": "CLOSE"})

close_session_response = smr_client.invoke_endpoint(
    EndpointName=endpoint_name,
    Body=payload,
    ContentType="application/json",
    SessionId=session_id,
)

In [None]:
# The ID of the session that was just closed is returned in a response header.
close_headers = close_session_response["ResponseMetadata"]["HTTPHeaders"]
closed_session_id = close_headers["x-amzn-sagemaker-closed-session-id"]

print(f"closed_session_id: {closed_session_id}")

## Invoking a Closed Session Fails
When we try to invoke the session we just closed, the request is rejected because the session has been terminated.

In [None]:
# Re-using the closed session ID is expected to be rejected by the endpoint.
# Catch the service error so this demonstration does not abort "Run All".
try:
    response_model = smr_client.invoke_endpoint(
        EndpointName=endpoint_name,
        Body=json.dumps({"inputs": "What is Amazon SageMaker?"}),
        ContentType="application/json",
        SessionId=session_id,
    )
except smr_client.exceptions.ClientError as err:
    # botocore raises ClientError (or one of its modeled subclasses) when the
    # service rejects the request; report it instead of crashing the cell.
    print(f"Invocation on closed session {session_id} failed as expected: {err}")
else:
    # Not expected: the closed session still accepted a request.
    result = json.loads(response_model["Body"].read().decode())["generated_text"]
    print(result)