# Deploy a scalable token streaming solution on SageMaker
In this notebook, we explore how to host a large language model on SageMaker using the latest container that packages some of the most popular open source libraries for model parallel inference like DeepSpeed and HuggingFace Accelerate. We use DJLServing as the model serving solution in this example. DJLServing is a high-performance universal model serving solution powered by the [Deep Java Library](https://github.com/deepjavalibrary/djl) (DJL) that is programming language agnostic. To learn more about DJL and DJLServing, you can refer to our [recent blog post](https://aws.amazon.com/blogs/machine-learning/deploy-bloom-176b-and-opt-30b-on-amazon-sagemaker-with-large-model-inference-deep-learning-containers-and-deepspeed/).

In this notebook, we will deploy a 7B GPT-NeoX-architecture model, `stabilityai/stablelm-base-alpha-7b`, on an ml.g5.2xlarge instance. We will also demonstrate a streaming experience in which the model runs end to end and the generated tokens are retrieved page by page.


## License agreement
- Review the model license information (Apache 2.0) before using the model.
- This is a sample notebook and is not intended for production use. Please refer to the license at https://github.com/aws/mit-0.


## Permission

In order to conduct this lab, we will need the following permissions:

- ECR Push/Pull access
- S3 bucket push access
- SageMaker access
- DynamoDB access (create DB and query)


## Upgrade SageMaker and import dependencies

In [None]:
%pip install sagemaker boto3 awscli --upgrade  --quiet

In [None]:
import json
import boto3
import sagemaker
from sagemaker import Model, serializers, deserializers

role = sagemaker.get_execution_role()  # execution role for the endpoint
sess = sagemaker.session.Session()  # sagemaker session for interacting with different AWS APIs
region = sess.boto_region_name  # region name of the current SageMaker Studio environment
account_id = sess.account_id()  # account_id of the current SageMaker Studio environment

## Bring your own container to ECR repository

*Note: Make sure your AWS credentials have permission to push to the ECR repository.*

In this step, we pull the LMI nightly container from Docker Hub and then push it to the ECR repository.

This process may take a while, depending on the container size and your network bandwidth.

In [None]:
%%bash

# The name of our container
repo_name=djlserving-byoc
# Target container
target_container="deepjavalibrary/djl-serving:deepspeed-nightly"

account=$(aws sts get-caller-identity --query Account --output text)

# Get the region defined in the current configuration (default to us-west-2 if none defined)
region=$(aws configure get region)
region=${region:-us-west-2}

fullname="${account}.dkr.ecr.${region}.amazonaws.com/${repo_name}:latest"
echo "Creating ECR repository ${fullname}"

# If the repository doesn't exist in ECR, create it.

aws ecr describe-repositories --repository-names "${repo_name}" > /dev/null 2>&1

if [ $? -ne 0 ]
then
    aws ecr create-repository --repository-name "${repo_name}" > /dev/null
fi

# Get the login command from ECR and execute it directly
aws ecr get-login-password --region ${region} | docker login --username AWS --password-stdin "${account}.dkr.ecr.${region}.amazonaws.com"

# Pull the container image, tag it with the full ECR name,
# and then push it to ECR.
echo "Start pulling container: ${target_container}"

docker pull ${target_container}
docker tag ${target_container} ${fullname}
docker push ${fullname}

## Create a SageMaker-compatible model artifact, upload it to S3, and use the DJL built-in streaming handler

SageMaker Large Model Inference containers can be used to host models without providing your own inference code. This is extremely useful when there is no custom pre-processing of the input data or postprocessing of the model's predictions.

In this notebook, we take that route: the model is served by the DJL built-in streaming handler, so the artifact contains configuration only and no custom inference code.

The LMI container expects the following artifacts to set up the model:
- `serving.properties` is the configuration file that can be used to configure the model server.
- `requirements.txt` (optional) lists the pip packages to install at runtime.

For more details on the configuration options, including an exhaustive list, refer to the documentation: https://docs.aws.amazon.com/sagemaker/latest/dg/realtime-endpoints-large-model-configuration.html.

In [None]:
%%writefile serving.properties
engine=Python
option.dtype=fp16
option.model_id=stabilityai/stablelm-base-alpha-7b
option.tensor_parallel_degree=1
option.enable_streaming=True
option.low_cpu_mem_usage=True
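
As a quick sanity check, the properties written above can be parsed into a dictionary before packaging. This is a minimal sketch, assuming the file only contains plain `key=value` lines; `parse_properties` is a hypothetical helper for this notebook, not part of DJLServing.

```python
def parse_properties(text):
    """Parse simple key=value lines, skipping blank lines and '#' comments."""
    props = {}
    for raw in text.splitlines():
        line = raw.strip()
        if not line or line.startswith("#"):
            continue
        # Split on the first '=' only, so values may contain '='.
        key, _, value = line.partition("=")
        props[key.strip()] = value.strip()
    return props


# Same content as the serving.properties cell above.
sample = """\
engine=Python
option.dtype=fp16
option.model_id=stabilityai/stablelm-base-alpha-7b
option.tensor_parallel_degree=1
option.enable_streaming=True
option.low_cpu_mem_usage=True
"""

props = parse_properties(sample)
print(props["option.model_id"])
```

All values come back as strings, which is enough to spot a missing or misspelled key before the endpoint is created.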

In [None]:
%%sh
mkdir mymodel
mv serving.properties mymodel/
tar czvf mymodel.tar.gz mymodel/
rm -rf mymodel
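
The archive layout produced by the shell cell above (a `mymodel/` folder containing `serving.properties`) can also be built and inspected from Python. The following is a sketch using only the standard library; `build_model_tarball` and `list_tarball` are hypothetical helper names, and the demo writes to a temporary directory so it does not touch `mymodel.tar.gz`.

```python
import io
import os
import tarfile
import tempfile


def build_model_tarball(properties_text, out_path, model_dir="mymodel"):
    """Write a gzip tarball containing <model_dir>/serving.properties."""
    data = properties_text.encode("utf-8")
    info = tarfile.TarInfo(name=f"{model_dir}/serving.properties")
    info.size = len(data)
    with tarfile.open(out_path, "w:gz") as tar:
        tar.addfile(info, io.BytesIO(data))


def list_tarball(path):
    """Return the member names stored in a gzip tarball."""
    with tarfile.open(path, "r:gz") as tar:
        return tar.getnames()


with tempfile.TemporaryDirectory() as tmp:
    demo_path = os.path.join(tmp, "demo.tar.gz")
    build_model_tarball("engine=Python\n", demo_path)
    members = list_tarball(demo_path)

print(members)
```

Calling `list_tarball("mymodel.tar.gz")` on the archive created by the shell cell is a cheap way to confirm its contents before uploading.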

## Start building SageMaker endpoint

### Upload artifact on S3 and create SageMaker model

The tarball that we created is uploaded to the session's default S3 bucket, which SageMaker creates if it does not already exist.

In [None]:
s3_code_prefix = "large-model-lmi/nocode"
bucket = sess.default_bucket()  # bucket to house artifacts
code_artifact = sess.upload_data("mymodel.tar.gz", bucket, s3_code_prefix)
print(f"S3 Code or Model tar ball uploaded to --- > {code_artifact}")

repo_name = "djlserving-byoc"
image_uri = f"{account_id}.dkr.ecr.{region}.amazonaws.com/{repo_name}"
env = {"SERVING_DDB_CACHE": "true"}  # use DynamoDB for response caching
# extra env you can set
# SERVING_DDB_BATCH=5 [default] writing to DDB every 5 tokens
# DDB_TABLE_NAME=djl-page [default] DDB name is djl-page
model = Model(image_uri=image_uri, model_data=code_artifact, env=env, role=role)
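
The optional settings mentioned in the comments above can be collected in a small helper. This is a sketch only: `build_ddb_env` is a hypothetical function, the variable names `SERVING_DDB_BATCH` and `DDB_TABLE_NAME` are taken from the comments in the previous cell, and their exact behavior is determined by the container. Values are converted to strings because SageMaker environment variables are string-to-string mappings.

```python
def build_ddb_env(batch_size=None, table_name=None):
    """Build the container environment for DynamoDB response caching."""
    env = {"SERVING_DDB_CACHE": "true"}
    if batch_size is not None:
        if batch_size < 1:
            raise ValueError("batch_size must be a positive integer")
        # Number of tokens written to DynamoDB per batch (see comment above).
        env["SERVING_DDB_BATCH"] = str(batch_size)
    if table_name is not None:
        # Name of the DynamoDB table used for paging (see comment above).
        env["DDB_TABLE_NAME"] = table_name
    return env


default_env = build_ddb_env()
custom_env = build_ddb_env(batch_size=10, table_name="my-page-table")
print(custom_env)
```

The resulting dictionary can be passed as the `env` argument of `Model` in place of the literal used above.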

### Create SageMaker endpoint

Here, we use an ml.g5.2xlarge instance. The endpoint name is generated from the base name `lmi-model` with a timestamp suffix.

#### This step can take ~ 10 min or longer so please be patient

In [None]:
instance_type = "ml.g5.2xlarge"
endpoint_name = sagemaker.utils.name_from_base("lmi-model")
print(f"endpoint_name is {endpoint_name}")

model.deploy(
    initial_instance_count=1,
    instance_type=instance_type,
    endpoint_name=endpoint_name,
)

Let's define a few helper functions. You will use `predict_async` to make an initial asynchronous invocation,
and then use `fetch_next_page` to iterate through the pages.

In [None]:
def predict_async(input_data, endpoint_name):
    data = serializers.JSONSerializer().serialize(input_data)
    request_args = {
        "CustomAttributes": "x-synchronous=false",
        "EndpointName": endpoint_name,
        "Body": data,
    }

    resp = sess.sagemaker_runtime_client.invoke_endpoint(**request_args)
    return parse_response(resp)


def fetch_next_page(next_page_id, endpoint_name):
    input_data = {"inputs": "fetch"}
    # "x-max-items=X" can be added to limit the maximum number of items returned per page;
    # this is useful when returning image output
    request_args = {
        "CustomAttributes": f"x-starting-token={next_page_id}",
        "EndpointName": endpoint_name,
        "Body": serializers.JSONSerializer().serialize(input_data),
    }
    resp = sess.sagemaker_runtime_client.invoke_endpoint(**request_args)
    return parse_response(resp)


def get_next_page_id(resp):
    custom_attr = resp["ResponseMetadata"]["HTTPHeaders"].get("x-amzn-sagemaker-custom-attributes")
    if custom_attr:
        return custom_attr.split("=")[1]

    return None


def parse_response(resp):
    next_page_id = get_next_page_id(resp)
    response_body = resp["Body"].read().decode("utf-8")
    return (next_page_id, response_body)


def parse_outputs(body):
    result = []
    if body:
        lines = body.rstrip().split("\n")
        for line in lines:
            result.append(json.loads(line).get("outputs"))
    return result


def merge_content(outputs, parsed_outputs):
    for content in parsed_outputs:
        for idx, token in enumerate(content):
            outputs[idx] += token
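
To see how `parse_outputs` and `merge_content` cooperate, the sketch below runs copies of both helpers on a hand-written response body. The sample body is an assumption inferred from the helper code (one JSON object per line, each with an `outputs` list holding one token per prompt); it is not captured from a real endpoint.

```python
import json


def parse_outputs(body):
    """Turn a JSON Lines body into a list of per-line "outputs" lists."""
    result = []
    if body:
        for line in body.rstrip().split("\n"):
            result.append(json.loads(line).get("outputs"))
    return result


def merge_content(outputs, parsed_outputs):
    """Append each new token to the running text of its prompt (in place)."""
    for content in parsed_outputs:
        for idx, token in enumerate(content):
            outputs[idx] += token


# Hypothetical page: two lines, each carrying one token for each of two prompts.
sample_body = '{"outputs": ["Hello", "Foo"]}\n{"outputs": [" world", " bar"]}\n'

demo_outputs = ["", ""]
merge_content(demo_outputs, parse_outputs(sample_body))
print(demo_outputs)
```

Because `merge_content` mutates its first argument, the same list can be reused across pages to accumulate the full generated text for every prompt.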

## Test and benchmark the inference

Here, we use the SageMaker endpoint together with a simple DynamoDB-backed fetcher to retrieve the response.

- Send the prompt request and receive an `x-next-token` header.
- Use `x-starting-token` to retrieve the streamed tokens with pagination.
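
The exchange described above can be exercised offline with a fake response object. This is a sketch under assumptions: `fake_response` is a hypothetical stand-in for the dictionary returned by `invoke_endpoint`, and the header value format `x-next-token=<id>` is inferred from how `get_next_page_id` splits the custom attributes on `=`.

```python
import io


def get_next_page_id(resp):
    """Extract the paging token from the custom attributes header, if present."""
    custom_attr = resp["ResponseMetadata"]["HTTPHeaders"].get(
        "x-amzn-sagemaker-custom-attributes"
    )
    if custom_attr:
        return custom_attr.split("=")[1]
    return None


def parse_response(resp):
    """Return (next_page_id, decoded_body) for one invocation response."""
    next_page_id = get_next_page_id(resp)
    response_body = resp["Body"].read().decode("utf-8")
    return (next_page_id, response_body)


def fake_response(body_text, next_token=None):
    """Build a response shaped like the one the helpers above expect."""
    headers = {}
    if next_token is not None:
        headers["x-amzn-sagemaker-custom-attributes"] = f"x-next-token={next_token}"
    return {
        "ResponseMetadata": {"HTTPHeaders": headers},
        "Body": io.BytesIO(body_text.encode("utf-8")),
    }


# A page that announces a follow-up page, then a final page without a token.
first_page = parse_response(fake_response('{"outputs": ["Hi"]}\n', next_token="page-1"))
last_page = parse_response(fake_response('{"outputs": ["!"]}\n'))
print(first_page[0], last_page[0])
```

A missing header yields `None`, which is the condition the fetch loop relies on to stop paging.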

In [None]:
import time

input_text = ["Large language model is", "Amazon is a company"]
input_data = {"inputs": input_text}

outputs = ["" for _ in range(len(input_text))]
next_page_id, body = predict_async(input_data, endpoint_name)

merge_content(outputs, parse_outputs(body))

while next_page_id:
    next_page_id, body = fetch_next_page(next_page_id, endpoint_name)
    merge_content(outputs, parse_outputs(body))
    print("Fetching some tokens after 200ms...")
    time.sleep(0.2)

for content in outputs:
    print("\nGenerated text: " + content)
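
The loop above keeps polling for as long as the endpoint returns a next-page token. A variant with an overall deadline may be useful when an endpoint stalls. The sketch below is one possible shape, not part of the original sample: `collect_pages` is a hypothetical helper that takes any callable returning `(next_token, body)`, and it is demonstrated here against an in-memory fake instead of a real endpoint.

```python
import time


def collect_pages(fetch_page, first_token, poll_interval=0.2, timeout=60.0):
    """Follow paging tokens until none is returned or the deadline passes."""
    bodies = []
    token = first_token
    deadline = time.monotonic() + timeout
    while token:
        if time.monotonic() > deadline:
            raise TimeoutError("paging did not finish before the deadline")
        token, body = fetch_page(token)
        bodies.append(body)
        if token:
            # Give the model time to produce more tokens before the next fetch.
            time.sleep(poll_interval)
    return bodies


# In-memory fake: each token maps to (next_token, body).
fake_pages = {"t1": ("t2", "a"), "t2": ("t3", "b"), "t3": (None, "c")}
demo_bodies = collect_pages(lambda tok: fake_pages[tok], "t1", poll_interval=0.0)
print(demo_bodies)
```

With the helpers defined earlier, the callable could be `lambda t: fetch_next_page(t, endpoint_name)`, and each returned body would then go through `parse_outputs` and `merge_content`.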

## Clean up the environment

If you have a Lambda and API Gateway environment, do the following to clean up:

Clean up the SageMaker endpoint:

In [None]:
sess.delete_endpoint(endpoint_name)
sess.delete_endpoint_config(endpoint_name)
model.delete_model()