# SageMaker Inference Components Deployment

In this notebook we use SageMaker Inference Components to deploy a Llama model and a BART model on a single endpoint. Our chatbot then uses this one endpoint for both tasks: Llama 7B Chat answers questions (QnA) and BART summarizes the conversation.

### Models Being Utilized
- [Llama-7B-Chat](https://huggingface.co/TheBloke/Llama-2-7B-Chat-fp16): For the QnA aspect of chatbot.
- [Fine-Tuned BART Model](https://huggingface.co/knkarthick/MEETING_SUMMARY): A fine-tuned BART model on the HuggingFace Hub. This has been fine-tuned on SAMSUM dataset as well.
    - License: [Apache 2.0](https://choosealicense.com/licenses/apache-2.0/). No changes were made to the model; we use the base model exactly as provided on the HF Hub.

## Setup

In [None]:
!pip install sagemaker --upgrade --quiet

In [None]:
import boto3
import sagemaker
import time
from time import gmtime, strftime

# --- Sessions, region, bucket and execution role ---------------------------
boto_session = boto3.session.Session()
region = boto_session.region_name
s3 = boto_session.resource("s3")

sagemaker_session = sagemaker.Session()
bucket = sagemaker_session.default_bucket()
role = sagemaker.get_execution_role()

print(f"Role ARN: {role}")
print(f"Region: {region}")

# --- Low-level boto3 clients -----------------------------------------------
# `client`/`sm_client` both talk to the SageMaker control plane and
# `runtime`/`smr_client` both talk to the SageMaker runtime; every name is
# kept because later cells refer to them.
client = boto3.client("sagemaker")
sm_client = boto3.client("sagemaker")
runtime = boto3.client("sagemaker-runtime")
smr_client = boto3.client("sagemaker-runtime")
s3_client = boto3.client("s3")

In [None]:

# Endpoint config name (timestamped so each run gets its own config)
epc_name = "ic-epc" + strftime("%Y-%m-%d-%H-%M-%S", gmtime())
print(f"Endpoint Config Name: {epc_name}")

# Container parameters; the download/health-check timeouts are raised for LLMs
variant_name = "AllTraffic"
instance_type = "ml.g5.12xlarge" # 4 GPUs available per instance
model_data_download_timeout_in_seconds = 3600
container_startup_health_check_timeout_in_seconds = 3600

# Managed instance scaling at the endpoint level
initial_instance_count = 1
max_instance_count = 2
print(f"Initial instance count: {initial_instance_count}")
print(f"Max instance count: {max_instance_count}")

# Endpoint config creation. No model is attached here: models are added later
# as inference components on top of this variant.
endpoint_config_response = client.create_endpoint_config(
    EndpointConfigName=epc_name,
    ExecutionRoleArn=role,
    ProductionVariants=[
        {
            "VariantName": variant_name,
            "InstanceType": instance_type,
            # Driven by the variable above so it always agrees with
            # MinInstanceCount when the initial count is changed.
            "InitialInstanceCount": initial_instance_count,
            "ModelDataDownloadTimeoutInSeconds": model_data_download_timeout_in_seconds,
            "ContainerStartupHealthCheckTimeoutInSeconds": container_startup_health_check_timeout_in_seconds,
            "ManagedInstanceScaling": {
                "Status": "ENABLED",
                "MinInstanceCount": initial_instance_count,
                "MaxInstanceCount": max_instance_count,
            },
            # can set to least outstanding or random: https://aws.amazon.com/blogs/machine-learning/minimize-real-time-inference-latency-by-using-amazon-sagemaker-routing-strategies/
            "RoutingConfig": {"RoutingStrategy": "LEAST_OUTSTANDING_REQUESTS"},
        }
    ],
)

print("Endpoint Configuration Arn: " + endpoint_config_response["EndpointConfigArn"])

In [None]:
# Create the endpoint from the endpoint config defined earlier; the name is
# timestamped so repeated runs do not collide.
endpoint_timestamp = strftime("%Y-%m-%d-%H-%M-%S", gmtime())
endpoint_name = f"ic-ep-chatbot{endpoint_timestamp}"

create_endpoint_response = client.create_endpoint(
    EndpointConfigName=epc_name,
    EndpointName=endpoint_name,
)
print(f"Endpoint Arn: {create_endpoint_response['EndpointArn']}")

In [None]:
# Monitor creation: poll until the endpoint leaves the "Creating" state.
describe_endpoint_response = client.describe_endpoint(EndpointName=endpoint_name)
while describe_endpoint_response["EndpointStatus"] == "Creating":
    # Sleep *before* re-polling so we stop as soon as a terminal status is seen
    # instead of waiting one more interval after it.
    time.sleep(15)
    describe_endpoint_response = client.describe_endpoint(EndpointName=endpoint_name)
    print(describe_endpoint_response["EndpointStatus"])
print(describe_endpoint_response)

# Fail loudly here rather than letting later cells hit a broken endpoint.
if describe_endpoint_response["EndpointStatus"] != "InService":
    raise RuntimeError(
        f"Endpoint {endpoint_name} ended in status "
        f"{describe_endpoint_response['EndpointStatus']}: "
        f"{describe_endpoint_response.get('FailureReason', 'no failure reason reported')}"
    )

### Llama 7B Chat IC Creation

For Llama 7B Chat we just follow the ready made guide here with the LMI Container: https://github.com/deepjavalibrary/djl-demo/blob/2a5152f578f5954b8b68acdee18eed4e2a75c81f/aws/sagemaker/large-model-inference/sample-llm/rollingbatch_llama_7b_chat.ipynb.

In [None]:
%%writefile serving.properties
# Serving configuration written next to model.py and packaged into mymodel.tar.gz.
# NOTE(review): the per-option notes below are read off the option names/values;
# confirm exact semantics against the LMI container documentation.
engine=MPI
# Hugging Face Hub id of the model to serve (fp16 Llama 2 7B Chat).
option.model_id=TheBloke/Llama-2-7B-Chat-fp16
option.task=text-generation
option.trust_remote_code=true
# Presumably one GPU per model copy; the Llama inference component in this
# notebook requests a single accelerator device.
option.tensor_parallel_degree=1
option.max_rolling_batch_size=32
option.rolling_batch=lmi-dist
option.dtype=fp16

In [None]:
%%writefile model.py
from djl_python.huggingface import HuggingFaceService
from djl_python import Output
from djl_python.encode_decode import encode, decode
from transformers import AutoTokenizer
import logging
import json
import types

_service = HuggingFaceService()

def custom_parse_input(self, inputs):
    """
    Parse a batch of requests carrying a "chat" message list into prompts.

    Bound onto the service object in place of its own ``parse_input``. Each
    request body is decoded according to its Content-Type; the "chat" list is
    rendered to a single prompt string with the tokenizer's chat template, and
    "parameters" is forwarded as the generation parameters for that prompt.

    :param inputs: batched request object exposing ``get_batches()``
    :return: tuple ``(input_data, input_size, parameters, errors, batch)`` where
        ``input_data`` holds one prompt per successfully decoded request,
        ``input_size`` holds 1 (decoded) or 0 (failed) per request,
        ``parameters`` holds one parameter dict per prompt, ``errors`` maps the
        index of each failed request to its error message, and ``batch`` is the
        list returned by ``inputs.get_batches()``
    """
    prompts, sizes, params_per_prompt, failures = [], [], [], {}

    # The tokenizer is needed for the chat template; load it lazily.
    if self.tokenizer is None:
        self.tokenizer = AutoTokenizer.from_pretrained(self.hf_configs.model_id_or_path)

    requests = inputs.get_batches()
    for idx, request in enumerate(requests):
        try:
            body = decode(request, request.get_property("Content-Type"))
        except Exception as exc:  # pylint: disable=broad-except
            # Record the failure and keep going with the rest of the batch.
            logging.warning(f"Parse input failed: {idx}")
            sizes.append(0)
            failures[idx] = str(exc)
            continue

        # Chat message massaging: render the message list to one prompt string;
        # a missing or empty "chat" becomes an empty prompt.
        messages = body.pop("chat", [])
        if len(messages) != 0:
            prompt = self.tokenizer.apply_chat_template(messages, tokenize=False)
        else:
            prompt = ""
        prompts.append(prompt)
        sizes.append(1)
        # End of massaging

        request_params = body.pop("parameters", {})
        # Use the server-provided seed only when the request did not set one.
        if "seed" not in request_params and request.contains_key("seed"):
            request_params["seed"] = request.get_as_string(key="seed")
        # One parameter entry per prompt produced by this request.
        params_per_prompt.extend([request_params] * sizes[idx])

    return prompts, sizes, params_per_prompt, failures, requests


def chat_output_formatter(token, first_token, last_token, details, generated_tokens):
    """
    Format one generated token as a fragment of a JSON chat message.

    Concatenating the fragments for all tokens of a response yields
    ``{"role": "assistant", "content": "<text>"}``, with a trailing
    ``"details"`` member when ``details`` is truthy on the last token.

    :param token: object whose ``text`` attribute is the token's text
    :param first_token: truthy for the first token (emits the JSON prefix)
    :param last_token: truthy for the last token (closes the JSON object)
    :param details: optional JSON-serializable details added on the last token
    :param generated_tokens: unused here
    :return: formatted output
    """
    fragments = []

    if first_token:
        fragments.append('{"role": "assistant", "content": "')

    # json.dumps yields a quoted, escaped JSON string; drop the surrounding
    # quotes because the text is embedded inside the already-open "content".
    fragments.append(json.dumps(token.text, ensure_ascii=False)[1:-1])

    if last_token:
        if details:
            fragments.append('", "details": ')
            fragments.append(json.dumps(details, ensure_ascii=False))
            fragments.append('}')
        else:
            fragments.append('"}')

    return "".join(fragments)


def handle(inputs):
    """
    Entry point for the serving container.

    On the first call the shared service is initialized with the request
    properties plus the chat output formatter, and its ``parse_input`` is
    replaced with ``custom_parse_input``.

    :param inputs: request object exposing ``get_properties()`` and ``is_empty()``
    :return: ``None`` for an empty (initialization) request, otherwise the
        result of ``_service.inference(inputs)``
    """
    if not _service.initialized:
        props = inputs.get_properties()
        props["output_formatter"] = chat_output_formatter
        # Pass the dict we just modified instead of calling get_properties()
        # again, so the formatter is applied even if that call returns a
        # fresh dict each time.
        _service.initialize(props)
        # replace parse_input
        _service.parse_input = types.MethodType(custom_parse_input, _service)

    if inputs.is_empty():
        # initialization request
        return None

    return _service.inference(inputs)

In [None]:
%%sh
# Package serving.properties and model.py into mymodel.tar.gz.
# Stop at the first failing command so a partial or empty archive is never produced.
set -e
# Clear leftovers from earlier runs; -f keeps the first run (nothing to delete) quiet.
rm -rf mymodel mymodel.tar.gz
mkdir mymodel
mv serving.properties mymodel/
mv model.py mymodel/
tar czvf mymodel.tar.gz mymodel/
rm -rf mymodel

In [None]:
# Look up the container image for the region of the SageMaker session.
djl_image_lookup = {
    "framework": "djl-deepspeed",
    "region": sagemaker_session.boto_session.region_name,
    "version": "0.26.0",
}
image_uri = sagemaker.image_uris.retrieve(**djl_image_lookup)
print(f"Image being used: {image_uri}")

In [None]:
from sagemaker.utils import name_from_base

# Upload the packaged model code (mymodel.tar.gz, built in the packaging cell)
# to the session's default bucket. create_model needs an S3 URI, and
# `code_artifact` was previously never defined, so this cell failed on a
# fresh kernel.
s3_code_prefix = "inference-components/llama-7b-chat/code"
code_artifact = sagemaker_session.upload_data("mymodel.tar.gz", bucket, s3_code_prefix)
print(f"Model code uploaded to: {code_artifact}")

llama_model_name = name_from_base("Llama-7b-chat")
print(llama_model_name)

# Register the SageMaker model: serving container image + the uploaded code.
create_model_response = sm_client.create_model(
    ModelName=llama_model_name,
    ExecutionRoleArn=role,
    PrimaryContainer={"Image": image_uri, "ModelDataUrl": code_artifact},
)
model_arn = create_model_response["ModelArn"]

print(f"Created Model: {model_arn}")

In [None]:
# Timestamped name for the Llama inference component
llama7b_ic_name = f"llama7b-chat-ic{strftime('%Y-%m-%d-%H-%M-%S', gmtime())}"
variant_name = "AllTraffic"

# Hardware reserved for each copy of the Llama model
llama_ic_specification = {
    "ModelName": llama_model_name,
    "ComputeResourceRequirements": {
        # need just one GPU for llama 7b chat
        "NumberOfAcceleratorDevicesRequired": 1,
        "NumberOfCpuCoresRequired": 1,
        "MinMemoryRequiredInMb": 1024,
    },
}

# llama inference component creation
create_llama_ic_response = sm_client.create_inference_component(
    InferenceComponentName=llama7b_ic_name,
    EndpointName=endpoint_name,
    VariantName=variant_name,
    Specification=llama_ic_specification,
    # can setup autoscaling for copies, each copy will retain the hardware you have allocated
    RuntimeConfig={"CopyCount": 1},
)

print(f"IC Llama Arn: {create_llama_ic_response['InferenceComponentArn']}")

In [None]:
# Poll until the Llama inference component leaves the "Creating" state.
describe_ic_llama_response = client.describe_inference_component(
    InferenceComponentName=llama7b_ic_name)

while describe_ic_llama_response["InferenceComponentStatus"] == "Creating":
    # Sleep *before* re-polling so we stop as soon as a terminal status is seen
    # instead of waiting one more 100 s interval after it.
    time.sleep(100)
    describe_ic_llama_response = client.describe_inference_component(InferenceComponentName=llama7b_ic_name)
    print(describe_ic_llama_response["InferenceComponentStatus"])
print(describe_ic_llama_response)

# Fail loudly here rather than at the first invoke_endpoint call.
if describe_ic_llama_response["InferenceComponentStatus"] != "InService":
    raise RuntimeError(
        f"Inference component {llama7b_ic_name} ended in status "
        f"{describe_ic_llama_response['InferenceComponentStatus']}: "
        f"{describe_ic_llama_response.get('FailureReason', 'no failure reason reported')}"
    )

#### Sample Inference

In [None]:
import json
content_type = "application/json"

# Conversation so far; the last user turn is what the model should answer.
chat = [
    {"role": "user", "content": "Hello, how are you?"},
    {"role": "assistant", "content": "I'm doing great. How can I help you today?"},
    {"role": "user", "content": "I am software engineer looking to learn more about machine learning."},
]

# NOTE(review): confirm that "max_tokens" is the generation-length key this
# container version honours.
generation_parameters = {"max_tokens": 256, "do_sample": True}
payload = {"chat": chat, "parameters": generation_parameters}

# InferenceComponentName selects which model on the shared endpoint is invoked.
response = smr_client.invoke_endpoint(
    EndpointName=endpoint_name,
    InferenceComponentName=llama7b_ic_name,
    ContentType=content_type,
    Body=json.dumps(payload),
)

raw_body = response["Body"].read()
result = json.loads(raw_body.decode())
print(type(result["content"]))
print(type(result))

In [None]:
print(f"initial chat: {chat}")
# Add the model's reply to the dialogue. Guarded so that re-running this cell
# does not append the same reply a second time.
if not chat or chat[-1] != result:
    chat.append(result) #add dialogue to chat
print(f"updated chat: {chat}")

### BART Summarization Model IC Creation

In [None]:
from sagemaker.utils import name_from_base

bart_model_name = name_from_base("bart-summarization")
print(bart_model_name)

# Hugging Face inference image, resolved for the region this notebook runs in
# instead of a hard-coded us-east-1.
# NOTE(review): the registry account id below is kept from the original URI;
# confirm it is the right one for your region.
# NOTE(review): this is the CPU build of the image, while the BART inference
# component in this notebook reserves one accelerator — confirm whether the
# GPU build is intended.
hf_transformers_image_uri = (
    f"763104351884.dkr.ecr.{region}.amazonaws.com/"
    "huggingface-pytorch-inference:1.13.1-transformers4.26.0-cpu-py39-ubuntu20.04"
)

# env variables: the container pulls the model straight from the HF Hub
env = {'HF_MODEL_ID': 'knkarthick/MEETING_SUMMARY',
      'HF_TASK':'summarization',
      'SAGEMAKER_CONTAINER_LOG_LEVEL':'20',
      'SAGEMAKER_REGION': region}

create_model_response = sm_client.create_model(
    ModelName=bart_model_name,
    ExecutionRoleArn=role,
    # in this case no model data point directly towards HF Hub
    PrimaryContainer={"Image": hf_transformers_image_uri, 
                      "Environment": env},
)
model_arn = create_model_response["ModelArn"]
print(f"Created Model: {model_arn}")

In [None]:
# Timestamped name for the BART inference component
bart_ic_name = f"bart-summarization-ic{strftime('%Y-%m-%d-%H-%M-%S', gmtime())}"
variant_name = "AllTraffic"

# Hardware reserved for each copy of the BART model
bart_ic_specification = {
    "ModelName": bart_model_name,
    "ComputeResourceRequirements": {
        # will reserve one GPU
        "NumberOfAcceleratorDevicesRequired": 1,
        "NumberOfCpuCoresRequired": 8,
        "MinMemoryRequiredInMb": 1024,
    },
}

# BART inference component creation
create_bart_ic_response = sm_client.create_inference_component(
    InferenceComponentName=bart_ic_name,
    EndpointName=endpoint_name,
    VariantName=variant_name,
    Specification=bart_ic_specification,
    # can setup autoscaling for copies, each copy will retain the hardware you have allocated
    RuntimeConfig={"CopyCount": 1},
)

print(f"IC BART Arn: {create_bart_ic_response['InferenceComponentArn']}")

In [None]:
# Poll until the BART inference component leaves the "Creating" state.
describe_ic_bart_response = client.describe_inference_component(
    InferenceComponentName=bart_ic_name)

while describe_ic_bart_response["InferenceComponentStatus"] == "Creating":
    # Sleep *before* re-polling so we stop as soon as a terminal status is seen
    # instead of waiting one more 100 s interval after it.
    time.sleep(100)
    describe_ic_bart_response = client.describe_inference_component(InferenceComponentName=bart_ic_name)
    print(describe_ic_bart_response["InferenceComponentStatus"])
print(describe_ic_bart_response)

# Fail loudly here rather than at the first invoke_endpoint call.
if describe_ic_bart_response["InferenceComponentStatus"] != "InService":
    raise RuntimeError(
        f"Inference component {bart_ic_name} ended in status "
        f"{describe_ic_bart_response['InferenceComponentStatus']}: "
        f"{describe_ic_bart_response.get('FailureReason', 'no failure reason reported')}"
    )

#### Sample Inference
Note we want to feed the conversation we have with Llama into this IC for summarization.

In [None]:
# prompt template (can use langchain to make cleaner if you want)
text = ''''''

# prepare payload
for resp in chat:
    if resp['role'] == "user":
        text += f"Ram: {resp['content']}\n"
    elif resp['role'] == "assistant":
        text += f"AI: {resp['content']}\n"
print(text)

In [None]:
# Send the transcript to the BART inference component on the shared endpoint.
payload = {"inputs": text}
serialized_payload = json.dumps(payload)

response = smr_client.invoke_endpoint(
    EndpointName=endpoint_name,
    InferenceComponentName=bart_ic_name,  # selects the BART model
    ContentType=content_type,
    Body=serialized_payload,
)

# The response body is a JSON list; the summary is in its first element.
raw_body = response["Body"].read()
result = json.loads(raw_body.decode())
print(result[0]["summary_text"])