Llama2-7b Model with response streaming (using TGI container)

In [None]:
# Hugging Face Hub read token. The deployment cell forwards it to the serving
# container through the HUGGING_FACE_HUB_TOKEN environment variable.
# Replace the placeholder before running, and do not commit a real token.
HUGGING_FACE_HUB_TOKEN="<your Huggingface read token goes here>"

In [None]:
from sagemaker.model import Model
from sagemaker import get_execution_role
from sagemaker.huggingface import get_huggingface_llm_image_uri
from sagemaker.huggingface import HuggingFaceModel
# Retrieve the URI of the Hugging Face LLM (TGI) serving image, pinned to 1.1.0.
llm_image = get_huggingface_llm_image_uri(
  "huggingface",
  version="1.1.0"
)

role = get_execution_role()
hf_model_id = "meta-llama/Llama-2-7b-chat-hf" # model id from huggingface.co/models
# Model name derived from the model id with "/" and "." replaced by "-"
# (presumably to satisfy SageMaker resource-name rules -- confirm).
model_name = hf_model_id.replace("/","-").replace(".","-")
# The inference cell further down invokes the endpoint by this same name.
endpoint_name = "Llama-2-7b-chat-hf-endpoint"
instance_type = "ml.g5.2xlarge" # instance type to use for deployment
number_of_gpus = 1 # number of gpus to use for inference and tensor parallelism
health_check_timeout = 900 # Health check timeout in seconds (900 s = 15 minutes) to leave time for downloading the model

# Describe the model: serving image, execution role and container environment.
llm_model = HuggingFaceModel(
      role=role,
      image_uri=llm_image,
      env={
        'HF_MODEL_ID': hf_model_id,
        'HUGGING_FACE_HUB_TOKEN': HUGGING_FACE_HUB_TOKEN,
        # Environment values are passed as strings, hence the f-string.
        'SM_NUM_GPUS': f"{number_of_gpus}"
      },
      name=model_name
    )

# Deploy the model to a single-instance endpoint; `llm` is the object returned by deploy().
llm = llm_model.deploy(
  initial_instance_count=1,
  instance_type=instance_type,
  container_startup_health_check_timeout=health_check_timeout,
  endpoint_name=endpoint_name,
)

In [None]:
import io

class StreamIterator:
    """Iterate over the newline-delimited lines of an event stream.

    The stream is an iterable of event dicts. Events that carry a
    ``'PayloadPart'`` key contribute ``event['PayloadPart']['Bytes']`` to an
    internal buffer; any other event is reported with ``print`` and skipped.
    Payload parts may split a line at an arbitrary byte, so bytes are
    buffered until a full line (terminated by ``b"\\n"``) is available.

    Each iteration yields one line as ``bytes`` without its trailing newline.
    If the stream ends with data that is not newline-terminated, that
    remainder is yielded once as the final line.
    """

    def __init__(self, stream):
        """Wrap *stream*, an iterable of event dicts."""
        self.byte_iterator = iter(stream)
        self.buffer = io.BytesIO()
        # Offset in the buffer of the first byte not yet returned to the caller.
        self.read_pos = 0

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next buffered line, pulling more events as needed.

        Raises:
            StopIteration: when the stream is exhausted and every buffered
                byte has been returned.
        """
        while True:
            self.buffer.seek(self.read_pos)
            line = self.buffer.readline()
            if line.endswith(b"\n"):
                self.read_pos += len(line)
                return line[:-1]
            try:
                chunk = next(self.byte_iterator)
            except StopIteration:
                if line:
                    # The stream ended mid-line. Hand back the unterminated
                    # remainder once; retrying here would re-read the same
                    # bytes forever because no more data can arrive.
                    self.read_pos += len(line)
                    return line
                raise
            if 'PayloadPart' not in chunk:
                print(f"Unknown event type: {chunk}")
                continue
            # Append at the end; the next pass seeks back to read_pos.
            self.buffer.seek(0, io.SEEK_END)
            self.buffer.write(chunk['PayloadPart']['Bytes'])

In [None]:
import boto3
import json

# Name of the deployed endpoint to invoke.
endpoint_name = "Llama-2-7b-chat-hf-endpoint"

message = "Write a short poem about Sydney in Australia"

system_message = "You are a pirate"

# Llama 2 chat instruction delimiters: opening tag, then closing tag.
# The closing tag uses a forward slash ("[/INST]").
Roles = ["<s>[INST]", "[/INST]"]

# Single-turn prompt: the system message is wrapped in <<SYS>> ... <</SYS>>
# and placed with the user message between the instruction delimiters.
prompt = f"{Roles[0]} <<SYS>>\n{system_message}\n<</SYS>>\n\n{message} {Roles[1]}"

smr = boto3.client("sagemaker-runtime")
special = False
data = {
    "inputs": prompt,
    "parameters": {
        "max_new_tokens": 1024,
        "stop": ["</s>"],
    },
    "stream": True,
}

res = smr.invoke_endpoint_with_response_stream(
    Body=json.dumps(data),
    EndpointName=endpoint_name,
    ContentType="application/json",
)

# Each streamed line is expected to look like b'data:{...json...}'.
sse_prefix = b"data:"
pieces = []
for chunk in StreamIterator(res["Body"]):
    # Skip blank separator lines and anything that is not a data line,
    # instead of blindly slicing off the first five bytes.
    if not chunk.startswith(sse_prefix):
        continue
    # Parse the payload once per line and reuse the decoded token.
    token = json.loads(chunk[len(sse_prefix):])["token"]
    special = token["special"]
    pieces.append(token["text"])
    if not special:
        # Only non-special tokens are echoed; all tokens are kept in `text`.
        print(token["text"], end="")
text = "".join(pieces)

# Append the generated text so `prompt` holds the conversation so far.
prompt += text