# Elyza Japanese Llama2 TGI SageMaker Inference

This is sample code that deploys `elyza/ELYZA-japanese-Llama-2-7b-instruct` with Hugging Face Text Generation Inference (TGI) on Amazon SageMaker.

In [None]:
# Install/upgrade the SageMaker Python SDK, boto3/botocore and pip in the kernel's environment.
# NOTE(review): a kernel restart may be needed before upgraded packages take effect.
%pip install sagemaker pip boto3 botocore --upgrade  --quiet

In [None]:
from sagemaker.huggingface import get_huggingface_llm_image_uri
from sagemaker.huggingface import HuggingFaceModel
from sagemaker.session import Session
import sagemaker

# SageMaker session shared by the calls below.
sagemaker_session = Session()
# ARN of the identity running this notebook; passed as the model's role at deploy time.
# NOTE(review): outside a SageMaker-managed notebook this may be a user/assumed-role ARN
# that SageMaker cannot assume — confirm it has a SageMaker trust policy.
role = sagemaker_session.get_caller_identity_arn()

## Deploy Model

In [None]:
# --- Deployment settings ---
instance_type = "ml.g5.2xlarge"  # SageMaker instance type that hosts the endpoint
number_of_gpus = 1  # GPUs used for inference / tensor parallelism (sent as SM_NUM_GPUS)
health_check_timeout = 5 * 60  # seconds; generous so the container can download the model first
hf_model_id = "elyza/ELYZA-japanese-Llama-2-7b-instruct"  # model id on huggingface.co/models

In [None]:
# Container image URI for the TGI backend, pinned to version 0.9.3.
llm_image = get_huggingface_llm_image_uri("huggingface", version="0.9.3")

# Unique endpoint name derived from the given base string.
endpoint_name = sagemaker.utils.name_from_base("elyza-7b-inference")

# Environment variables handed to the TGI container.
tgi_env = {
    'HF_MODEL_ID': hf_model_id,
    'SM_NUM_GPUS': str(number_of_gpus),
    'DTYPE': 'bfloat16',
    'MAX_INPUT_LENGTH': "2048",  # max length of the input prompt, in tokens
    'MAX_TOTAL_TOKENS': "4096",  # max input + generated tokens per request
    'MAX_BATCH_TOTAL_TOKENS': "8192",
}
# Optional settings — uncomment to enable:
# tgi_env['REVISION'] = '2140541486bfb31269acd035edd51208da40185b'  # pin a model revision
# tgi_env['HF_MODEL_QUANTIZE'] = "bitsandbytes"  # quantize the weights

llm_model = HuggingFaceModel(role=role, image_uri=llm_image, env=tgi_env)

llm = llm_model.deploy(
    initial_instance_count=1,
    instance_type=instance_type,
    endpoint_name=endpoint_name,
    # Allow time for the model download before the startup health check gives up.
    container_startup_health_check_timeout=health_check_timeout,
)

## Run Inference

In [None]:
import json
import boto3
import logging
import io

# Attach a stream handler at INFO level to the root logger (name "" selects the root).
boto3.set_stream_logger(name="", level=logging.INFO)

# Low-level SageMaker Runtime client used to invoke the endpoint.
smr = boto3.client(service_name="sagemaker-runtime")

# Read the endpoint name back from the deployed predictor.
endpoint_name = llm.endpoint_name


class LineIterator:
    r"""
    A helper class for parsing the byte stream returned by the TGI container.

    The output of the model will be in the following format:
    ```
    b'data:{"token": {"text": " a"}}\n\n'
    b'data:{"token": {"text": " challenging"}}\n\n'
    b'data:{"token": {"text": " problem"
    b'}}'
    ...
    ```

    Each PayloadPart event from the event stream usually contains a byte array
    with a complete JSON object. This is not guaranteed, and a JSON object may be
    split across PayloadPart events. For example:
    ```
    {'PayloadPart': {'Bytes': b'{"outputs": '}}
    {'PayloadPart': {'Bytes': b'[" problem"]}\n'}}
    ```

    This class handles that by appending the bytes of every PayloadPart event to
    an internal buffer and yielding complete lines (ending with b'\n', which is
    stripped) from it. It remembers the position of the last line returned so no
    bytes are yielded twice. A trailing partial line stays buffered until the rest
    arrives. If the stream ends with a partial line, that remainder is yielded as
    the final item.
    """

    def __init__(self, stream):
        # `stream` is any iterable of event dicts, e.g. the 'Body' EventStream
        # returned by invoke_endpoint_with_response_stream.
        self.byte_iterator = iter(stream)
        # Every payload byte received so far.
        self.buffer = io.BytesIO()
        # Offset in `buffer` of the first byte not yet returned to the caller.
        self.read_pos = 0

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next line of the stream without its trailing newline.

        Raises:
            StopIteration: when the event stream is exhausted and every buffered
                byte has been returned.
        """
        while True:
            self.buffer.seek(self.read_pos)
            line = self.buffer.readline()
            if line and line[-1] == ord('\n'):
                self.read_pos += len(line)
                return line[:-1]
            try:
                chunk = next(self.byte_iterator)
            except StopIteration:
                # The stream is finished. With no newline left, readline() read to
                # EOF, so `line` is exactly the unreturned remainder. Flush it
                # instead of polling the exhausted iterator again (which would
                # loop forever).
                if line:
                    self.read_pos += len(line)
                    return line
                raise
            if 'PayloadPart' not in chunk:
                # Not a data event: report it and keep reading. Use print's
                # separator rather than str + dict, which would raise TypeError.
                print('Unknown event type:', chunk)
                continue
            self.buffer.seek(0, io.SEEK_END)
            self.buffer.write(chunk['PayloadPart']['Bytes'])


            
# Llama-2 chat template markers used to build prompts.
B_INST = "[INST]"
E_INST = "[/INST]"
B_SYS = "<<SYS>>\n"
E_SYS = "\n<</SYS>>\n\n"
# End-of-sequence marker: sent as a stop sequence and hidden from streamed output.
stop_token = '</s>'

def inference(text, system="あなたは誠実で優秀な日本人のアシスタントです。", *,
              max_new_tokens=512, temperature=0.3):
    """Send a single prompt to the endpoint and print the complete reply.

    The user text is wrapped in the Llama-2 chat template:
    ``<s>[INST] <<SYS>>\\n{system}\\n<</SYS>>\\n\\n{text} [/INST] ``.

    Args:
        text: The user's message.
        system: System prompt placed inside the <<SYS>> block. The default means
            "You are an honest and excellent Japanese assistant."
        max_new_tokens: Maximum number of tokens to generate (default 512).
        temperature: Sampling temperature; sampling is always enabled (default 0.3).
    """
    prompt = "{bos_token}{b_inst} {system}{prompt} {e_inst} ".format(
        bos_token="<s>",
        b_inst=B_INST,
        system=f"{B_SYS}{system}{E_SYS}",
        prompt=text,
        e_inst=E_INST,
    )
    body = {
        "inputs": prompt,
        "parameters": {
            "max_new_tokens": max_new_tokens,
            "return_full_text": False,
            "do_sample": True,
            "temperature": temperature,
            "stop": [stop_token],
        },
    }
    response = smr.invoke_endpoint(
        EndpointName=endpoint_name,
        ContentType='application/json',
        Accept='application/json',
        Body=json.dumps(body),
    )
    # The response body is a JSON list whose first element holds "generated_text".
    generated = json.loads(response['Body'].read())[0]['generated_text']
    # Drop a trailing stop sequence if the server echoed it, so the output matches
    # the streaming path, which never prints the stop token.
    if generated.endswith(stop_token):
        generated = generated[:-len(stop_token)]
    print(generated)


def inference_stream(text, system="あなたは誠実で優秀な日本人のアシスタントです。", *,
                     max_new_tokens=512, temperature=0.3):
    """Send a prompt to the endpoint and print the reply token by token as it streams.

    The user text is wrapped in the Llama-2 chat template:
    ``<s>[INST] <<SYS>>\\n{system}\\n<</SYS>>\\n\\n{text} [/INST] ``.
    Tokens containing the stop token are not printed. A newline is printed once
    the stream ends.

    Args:
        text: The user's message.
        system: System prompt placed inside the <<SYS>> block. The default means
            "You are an honest and excellent Japanese assistant."
        max_new_tokens: Maximum number of tokens to generate (default 512).
        temperature: Sampling temperature; sampling is always enabled (default 0.3).

    Raises:
        RuntimeError: if a streamed JSON event carries no "token" (e.g. a server error).
    """
    prompt = "{bos_token}{b_inst} {system}{prompt} {e_inst} ".format(
        bos_token="<s>",
        b_inst=B_INST,
        system=f"{B_SYS}{system}{E_SYS}",
        prompt=text,
        e_inst=E_INST,
    )
    body = {
        "inputs": prompt,
        "parameters": {
            "max_new_tokens": max_new_tokens,
            "return_full_text": False,
            "do_sample": True,
            "temperature": temperature,
            "stop": [stop_token],
        },
        "stream": True,
    }
    resp = smr.invoke_endpoint_with_response_stream(
        EndpointName=endpoint_name,
        Body=json.dumps(body),
        ContentType='application/json',
    )
    start_json = b'{'
    for line in LineIterator(resp['Body']):
        # Lines look like b'data:{"token": {...}}'. Skip blank separators and
        # anything else without a JSON object.
        start = line.find(start_json)
        if start == -1:
            continue
        data = json.loads(line[start:].decode('utf-8'))
        if 'token' not in data:
            # NOTE(review): TGI is expected to report mid-stream failures as an
            # object with an "error" key instead of "token" — confirm for 0.9.3.
            raise RuntimeError(f"Streaming inference failed: {data.get('error', data)}")
        token_text = data['token']['text']
        if stop_token not in token_text:
            # Flush so each token is shown as soon as it arrives.
            print(token_text, end='', flush=True)
    # Terminate the streamed line.
    print()

In [None]:
# Stream the answer to: "What is AWS? Please summarize it in one phrase."
inference_stream("AWSとはなんですか？一言で要約してください")

## Delete Endpoint

In [None]:
# Tear down the resources created by this notebook: first the SageMaker model, then the endpoint.
llm.delete_model()
llm.delete_endpoint()