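# Llama 2 7B Chat inference using SageMaker LMI
In this tutorial, you will deploy the Llama 2 7B Chat model to SageMaker with the Large Model Inference (LMI) container from AWS Deep Learning Containers (DLC), and run chat inference against it.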

Before running the notebook, make sure the following permissions are granted:

- S3 bucket push access
- SageMaker access

## Step 1: Upgrade the SageMaker SDK and import dependencies

In [None]:
%pip install sagemaker --upgrade --quiet

In [None]:
import boto3
import sagemaker
from sagemaker import Model, image_uris, serializers, deserializers

role = sagemaker.get_execution_role()  # execution role for the endpoint
sess = sagemaker.session.Session()  # sagemaker session for interacting with different AWS APIs
region = sess.boto_region_name  # region name of the current SageMaker Studio environment
account_id = sess.account_id()  # account_id of the current SageMaker Studio environment

## Step 2: Prepare the model artifacts
The LMI container expects the following artifacts to set up the model:
- serving.properties (required): Defines the model server settings
- model.py (optional): A Python file that defines the core inference logic
- requirements.txt (optional): Any additional pip packages that need to be installed

In [None]:
%%writefile serving.properties
engine=MPI
option.model_id=TheBloke/Llama-2-7B-Chat-fp16
option.task=text-generation
option.trust_remote_code=true
option.tensor_parallel_degree=1
option.max_rolling_batch_size=32
option.rolling_batch=lmi-dist
option.dtype=fp16
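Because the LMI container reads its configuration from this file, a quick local sanity check can catch typos before anything is uploaded. The sketch below is a simplified parser that assumes one `key=value` pair per line with optional `#` comments; the full Java properties syntax (`:` separators, line continuations) is not handled. `parse_serving_properties`, `missing_keys`, and the default list of required keys are hypothetical helpers for this notebook, not part of the LMI tooling.

```python
from typing import Dict, Iterable, List


def parse_serving_properties(text: str) -> Dict[str, str]:
    """Parse simple key=value lines; blank lines and '#' comments are skipped."""
    props: Dict[str, str] = {}
    for raw_line in text.splitlines():
        line = raw_line.strip()
        if not line or line.startswith("#"):
            continue
        # Split on the first '=' only, so values may themselves contain '='
        key, sep, value = line.partition("=")
        if not sep:
            raise ValueError(f"Malformed line (no '='): {raw_line!r}")
        props[key.strip()] = value.strip()
    return props


def missing_keys(props: Dict[str, str],
                 required: Iterable[str] = ("engine", "option.model_id")) -> List[str]:
    """Return required keys absent from props (the default list is an assumption)."""
    return [key for key in required if key not in props]


# Usage after the %%writefile cell above:
# with open("serving.properties") as f:
#     props = parse_serving_properties(f.read())
# print(missing_keys(props))
```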

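In this step, we override the [default HuggingFace handler](https://github.com/deepjavalibrary/djl-serving/blob/0.25.0-dlc/engines/python/setup/djl_python/huggingface.py) provided by DJLServing. The custom `model.py` replaces the handler's `parse_input` so that a `chat` list of messages in the payload is turned into a prompt with the tokenizer's chat template, and registers `chat_output_formatter` so that the response comes back as an assistant message (`{"role": "assistant", "content": ...}`).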
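The custom handler swaps in a new `parse_input` on the `_service` instance with `types.MethodType`. The minimal sketch below isolates that technique: binding a plain function to one instance makes it shadow the class method, so internal calls such as `self.parse_input(...)` pick up the override. `BaseService` and its methods are hypothetical stand-ins, not the DJLServing `HuggingFaceService` API.

```python
import types


class BaseService:
    """Hypothetical stand-in for a handler class such as HuggingFaceService."""

    def parse_input(self, inputs):
        return [inputs]

    def inference(self, inputs):
        # Attribute lookup checks the instance first, so a bound override wins
        return self.parse_input(inputs)


def custom_parse_input(self, inputs):
    # After binding, 'self' is the service instance, just like a regular method
    return [f"[chat] {inputs}"]


service = BaseService()
service.parse_input = types.MethodType(custom_parse_input, service)

print(service.inference("hello"))        # ['[chat] hello']
print(BaseService().inference("hello"))  # ['hello'] (other instances are unaffected)
```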

In [None]:
%%writefile model.py
from djl_python.huggingface import HuggingFaceService
from djl_python import Output
from djl_python.encode_decode import encode, decode
from transformers import AutoTokenizer
import logging
import json
import types

_service = HuggingFaceService()

def custom_parse_input(self, inputs):
    input_data = []
    input_size = []
    parameters = []
    errors = {}
    # used for chat completion
    if self.tokenizer is None:
        self.tokenizer = AutoTokenizer.from_pretrained(self.hf_configs.model_id_or_path)
    batch = inputs.get_batches()
    for i, item in enumerate(batch):
        try:
            content_type = item.get_property("Content-Type")
            input_map = decode(item, content_type)
        except Exception as e:  # pylint: disable=broad-except
            logging.warning(f"Parse input failed: {i}")
            input_size.append(0)
            errors[i] = str(e)
            continue
        # Chat message massaging: render the "chat" messages with the tokenizer's chat template
        chat = input_map.pop("chat", [])
        if len(chat) != 0:
            formatted_str = self.tokenizer.apply_chat_template(chat, tokenize=False)
            input_data.append(formatted_str)
        else:
            input_data.append("")
        input_size.append(1)
        # End of massaging
        _param = input_map.pop("parameters", {})
        if "seed" not in _param:
            # set server provided seed if seed is not part of request
            if item.contains_key("seed"):
                _param["seed"] = item.get_as_string(key="seed")
        for _ in range(input_size[i]):
            parameters.append(_param)

    return input_data, input_size, parameters, errors, batch


def chat_output_formatter(token, first_token, last_token, details, generated_tokens):
    """
    json output formatter

    :return: formatted output
    """
    json_encoded_str = f"{{\"role\": \"assistant\", \"content\": \"" if first_token else ""
    json_encoded_str = f"{json_encoded_str}{json.dumps(token.text, ensure_ascii=False)[1:-1]}"
    if last_token:
        if details:
            details_str = f"\"details\": {json.dumps(details, ensure_ascii=False)}"
            json_encoded_str = f"{json_encoded_str}\", {details_str}}}"
        else:
            json_encoded_str = f"{json_encoded_str}\"}}"

    return json_encoded_str


def handle(inputs):
    if not _service.initialized:
        props = inputs.get_properties()
        props["output_formatter"] = chat_output_formatter
        _service.initialize(props)
        # replace parse_input
        _service.parse_input = types.MethodType(custom_parse_input, _service)

    if inputs.is_empty():
        # initialization request
        return None

    return _service.inference(inputs)
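`chat_output_formatter` builds the JSON response piece by piece: the first token opens the object, each token contributes its JSON-escaped text, and the last token closes the string (adding `details` when present). The standalone sketch below copies that logic and feeds it a short token sequence to check that the pieces concatenate into valid JSON. The `Token` namedtuple is a hypothetical stand-in (only its `text` attribute is read), and the sketch assumes one request's chunks are joined in order.

```python
import json
from collections import namedtuple

# Hypothetical stand-in for the server's token object; only .text is read
Token = namedtuple("Token", ["text"])


def chat_output_formatter(token, first_token, last_token, details, generated_tokens):
    """Same logic as chat_output_formatter in model.py."""
    json_encoded_str = '{"role": "assistant", "content": "' if first_token else ""
    # json.dumps escapes quotes and newlines; [1:-1] strips the surrounding quotes
    json_encoded_str += json.dumps(token.text, ensure_ascii=False)[1:-1]
    if last_token:
        if details:
            json_encoded_str += f'", "details": {json.dumps(details, ensure_ascii=False)}}}'
        else:
            json_encoded_str += '"}'
    return json_encoded_str


def simulate_stream(texts, details=None):
    """Format each token in order and join the chunks into one response body."""
    last = len(texts) - 1
    return "".join(
        chat_output_formatter(Token(text), i == 0, i == last, details, generated_tokens="")
        for i, text in enumerate(texts)
    )


body = simulate_stream(["Hello", ', "world"', "!"])
print(json.loads(body))  # {'role': 'assistant', 'content': 'Hello, "world"!'}
```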

In [None]:
%%sh
mkdir -p mymodel
mv serving.properties mymodel/
mv model.py mymodel/
tar czvf mymodel.tar.gz mymodel/
rm -rf mymodel
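If you prefer to stay in Python rather than shell out, the packaging step can be sketched with the standard-library `tarfile` module. `package_model_dir` is a hypothetical helper; like `tar czvf mymodel.tar.gz mymodel/`, it keeps `mymodel/` as the top-level folder inside the archive. The demo builds a throwaway directory so it runs anywhere.

```python
import tarfile
import tempfile
from pathlib import Path
from typing import List


def package_model_dir(src_dir: str, archive_path: str) -> List[str]:
    """Write src_dir into a gzip-compressed tarball under its own folder name
    and return the sorted member names for a quick check."""
    src = Path(src_dir)
    with tarfile.open(archive_path, "w:gz") as tar:
        tar.add(str(src), arcname=src.name)  # directories are added recursively
    with tarfile.open(archive_path, "r:gz") as tar:
        return sorted(tar.getnames())


# Demo on a temporary copy of the expected layout
with tempfile.TemporaryDirectory() as tmp:
    model_dir = Path(tmp) / "mymodel"
    model_dir.mkdir()
    (model_dir / "serving.properties").write_text("engine=MPI\n")
    (model_dir / "model.py").write_text("# custom handler\n")
    demo_names = package_model_dir(str(model_dir), str(Path(tmp) / "mymodel.tar.gz"))

print(demo_names)  # ['mymodel', 'mymodel/model.py', 'mymodel/serving.properties']
```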

## Step 3: Build the SageMaker endpoint
In this step, we build the SageMaker endpoint from scratch.

### Getting the container image URI

See the list of [available Large Model Inference DLCs](https://github.com/aws/deep-learning-containers/blob/master/available_images.md#large-model-inference-containers).


In [None]:
image_uri = image_uris.retrieve(
    framework="djl-deepspeed",
    region=sess.boto_session.region_name,
    version="0.26.0",
)

### Upload the artifacts to S3 and create a SageMaker model

In [None]:
s3_code_prefix = "large-model-lmi/code"
bucket = sess.default_bucket()  # bucket to house artifacts
code_artifact = sess.upload_data("mymodel.tar.gz", bucket, s3_code_prefix)
print(f"S3 Code or Model tar ball uploaded to --- > {code_artifact}")

model = Model(image_uri=image_uri, model_data=code_artifact, role=role)

## Step 4: Create the SageMaker endpoint

Specify the instance type to use and the endpoint name.

In [None]:
instance_type = "ml.g5.12xlarge"
endpoint_name = sagemaker.utils.name_from_base("lmi-model")

model.deploy(
    initial_instance_count=1,
    instance_type=instance_type,
    endpoint_name=endpoint_name,
    # container_startup_health_check_timeout=3600,
)

# our requests and responses will be in json format so we specify the serializer and the deserializer
predictor = sagemaker.Predictor(
    endpoint_name=endpoint_name,
    sagemaker_session=sess,
    serializer=serializers.JSONSerializer(),
    deserializer=deserializers.JSONDeserializer()
)
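The `Predictor` hides the HTTP details. For clients that only have `boto3` available, the same request can be sent through the SageMaker Runtime `invoke_endpoint` API using this notebook's payload shape, `{"chat": [...], "parameters": {...}}`. This is a sketch: `build_chat_request` and `invoke_chat` are hypothetical helpers, and the caller is assumed to have permission to invoke the endpoint.

```python
import json


def build_chat_request(chat, **parameters):
    """Serialize the payload shape this notebook sends: {"chat": [...], "parameters": {...}}."""
    return json.dumps({"chat": chat, "parameters": parameters})


def invoke_chat(endpoint_name, chat, runtime_client=None, **parameters):
    """Send one chat request with the low-level SageMaker Runtime API and decode the JSON reply."""
    if runtime_client is None:
        import boto3  # imported lazily so build_chat_request works without boto3

        runtime_client = boto3.client("sagemaker-runtime")
    response = runtime_client.invoke_endpoint(
        EndpointName=endpoint_name,
        ContentType="application/json",
        Accept="application/json",
        Body=build_chat_request(chat, **parameters),
    )
    # Body is a streaming object; read it fully before parsing
    return json.loads(response["Body"].read())


# Usage once the endpoint is InService:
# reply = invoke_chat(endpoint_name, [{"role": "user", "content": "Hi"}], max_new_tokens=64)
# print(reply["content"])
```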

## Step 5: Inference

In [None]:
%%time

chat = [
  {"role": "user", "content": "Tell me a story about Albert Einstein "},
]

# NOTE: Generation details are added to the response only when "details": True is set in "parameters" (see chat_output_formatter in model.py).

result = predictor.predict(
    {"chat": chat, "parameters": {"do_sample": True, "max_new_tokens": 600}}
)

In [None]:
assistant_response = {"role": result["role"], "content": result["content"]}

print(f'{result["role"]}: {result["content"]}')

chat.append(assistant_response)
# Append a follow-up user turn and send the whole conversation again
user_input = {"role": "user", "content": "Tell me a joke."}
print(f'{user_input["role"]}: {user_input["content"]}')
chat.append(user_input)
result = predictor.predict(
    {"chat": chat, "parameters": {"do_sample": True, "details": True, "max_new_tokens": 512}}
)
print(f'{result["role"]}: {result["content"]}')
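The cell above follows a fixed pattern: append the assistant's reply, append the next user message, and send the whole history again, because the custom handler builds each prompt only from the `chat` list in the request. The sketch below wraps that pattern in a small helper. `chat_turn` and `echo_predict` are hypothetical; pass `predictor.predict` as `predict_fn` to use it against the real endpoint.

```python
from typing import Callable, Dict, List

Message = Dict[str, str]


def chat_turn(chat: List[Message], user_text: str,
              predict_fn: Callable[[dict], dict], **parameters) -> str:
    """Add a user turn, query the model with the full history,
    record the assistant reply in the history, and return its text."""
    chat.append({"role": "user", "content": user_text})
    result = predict_fn({"chat": chat, "parameters": parameters})
    chat.append({"role": result["role"], "content": result["content"]})
    return result["content"]


def echo_predict(payload: dict) -> dict:
    """Hypothetical offline stand-in for predictor.predict: echoes the last user message."""
    last_user = payload["chat"][-1]["content"]
    return {"role": "assistant", "content": f"echo: {last_user}"}


demo_history: List[Message] = []
demo_reply = chat_turn(demo_history, "Tell me a joke.", echo_predict, max_new_tokens=64)
print(demo_reply)  # echo: Tell me a joke.

# Against the deployed endpoint:
# history = []
# print(chat_turn(history, "Tell me a story about Albert Einstein", predictor.predict,
#                 do_sample=True, max_new_tokens=600))
```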


## Clean up the environment

In [None]:
sess.delete_endpoint(endpoint_name)
sess.delete_endpoint_config(endpoint_name)
model.delete_model()