# Deploy ChatGLM2-6B on SageMaker

#### Install and upgrade dependencies

In [None]:
# Upgrade pip itself first, then install/upgrade the SageMaker SDK, boto3 and the AWS CLI.
# --quiet suppresses most of pip's output.
%pip install --upgrade pip --quiet
%pip install sagemaker boto3 awscli --upgrade --quiet

In [None]:
# init sagemaker parameters
import boto3
import sagemaker
from sagemaker import Model, serializers, deserializers

role = sagemaker.get_execution_role()  # execution role for the endpoint
sess = sagemaker.session.Session()  # sagemaker session for interacting with different AWS APIs
bucket = sess.default_bucket()  # bucket to house artifacts
# Use the public accessor rather than the private `_region_name` attribute,
# which is an SDK implementation detail and may change between releases.
region = sess.boto_region_name  # region name of the current session
account_id = sess.account_id()  # AWS account id of the current session

# Key prefix inside `bucket` under which the packaged model code is uploaded.
s3_code_prefix = "east-ai-models/chatglm2/accelerate"

In [None]:
!mkdir mymodel

#### Write the SageMaker LMI serving.properties, requirements.txt and model.py

In [None]:
%%writefile ./mymodel/serving.properties
engine=Python
option.tensor_parallel_degree=1
option.enable_streaming=True
option.predict_timeout=240
option.model_id=THUDM/chatglm2-6b

In [None]:
%%writefile ./mymodel/requirements.txt
transformers==4.30.2

In [None]:
%%writefile ./mymodel/model.py
from djl_python import Input, Output
import os
import torch
from transformers import AutoTokenizer, AutoModel
from typing import Any, Dict, Tuple
import warnings
import json

model = None
tokenizer = None


def get_model(properties):
    """Load the model and tokenizer named by ``properties["model_id"]``.

    The model is loaded with ``trust_remote_code=True``, converted to half
    precision and moved to the current CUDA device.

    Args:
        properties: Mapping that must contain the key ``"model_id"``.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    checkpoint = properties["model_id"]
    # NOTE(review): the rank is parsed but never used below; the read is kept
    # so a malformed LOCAL_RANK still fails here exactly as before.
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))

    network = AutoModel.from_pretrained(checkpoint, trust_remote_code=True)
    network = network.half().cuda()
    text_tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
    return network, text_tokenizer


def stream_items(prompt, history, max_length, top_p, temperature):
    """Yield incremental chat output for ``prompt`` as JSON-serialisable dicts.

    ``model.stream_chat`` yields the full response generated so far on every
    step; this generator converts that into deltas so each item carries only
    the text added since the previous item.

    Args:
        prompt: User message passed to ``model.stream_chat``.
        history: Prior conversation turns, forwarded as ``history=``.
        max_length, top_p, temperature: Generation settings, forwarded as-is.

    Yields:
        ``{"outputs": <new text>, "history": <list of turns as lists>}``.
    """
    global model, tokenizer
    chat_steps = model.stream_chat(
        tokenizer,
        prompt,
        history=history,
        max_length=max_length,
        top_p=top_p,
        temperature=temperature,
    )
    emitted = 0  # number of response characters already sent to the client
    for full_text, turns in chat_steps:
        delta = full_text[emitted:]
        emitted = len(full_text)
        # Turns are converted to lists before being placed in the payload.
        yield {"outputs": delta, "history": [list(turn) for turn in turns]}


def handle(inputs: Input) -> "Output | None":
    """Serve one request: lazily load the model, then stream the chat reply.

    The request JSON is read as::

        {"inputs": <prompt>, "parameters": {...}, "history": [...]}

    ``parameters`` is expanded into keyword arguments for ``stream_items``,
    so it must supply ``max_length``, ``top_p`` and ``temperature`` (that
    function declares no defaults). ``history`` is optional and defaults to
    an empty conversation.

    Args:
        inputs: Request object supplied by the model server.

    Returns:
        ``None`` for the empty warm-up call, otherwise an ``Output`` whose
        streamed content is produced by ``stream_items``.
    """
    global model, tokenizer
    print("print inputs: " + str(inputs) + '.'*20)
    # Explicit None check: the first call loads the model, later calls reuse it.
    if model is None:
        model, tokenizer = get_model(inputs.get_properties())

    if inputs.is_empty():
        # Model server makes an empty call to warmup the model on startup
        return None
    input_map = inputs.get_as_json()
    # When there is no "inputs" key the whole request body is used as the prompt.
    data = input_map.pop("inputs", input_map)
    params = input_map.pop("parameters", {})
    # A missing or null history means a fresh conversation rather than a KeyError.
    history = input_map.get("history") or []
    print("print data: " + str(data) + '.'*20)
    model = model.eval()
    outputs = Output()
    outputs.add_property("content-type", "application/jsonlines")
    outputs.add_stream_content(stream_items(data, history=history, **params))
    print("Start streaming output" + "."*20)
    return outputs

In [None]:
# Package the contents of ./mymodel into model.tar.gz and upload it to S3.
# Remove any archive left over from a previous run.
!rm -f model.tar.gz
# Drop Jupyter checkpoint files so they are not shipped inside the archive.
!rm -rf mymodel/.ipynb_checkpoints
# -C mymodel . archives the directory's contents at the root of the tarball.
!tar czvf model.tar.gz -C mymodel .
# Upload to s3://<bucket>/<s3_code_prefix>/ and keep the returned S3 URI for the Model definition.
s3_code_artifact = sess.upload_data("model.tar.gz", bucket, s3_code_prefix)
print(f"S3 Code or Model tar uploaded to --- > {s3_code_artifact}")

#### Model deployment

In [None]:
# Retrieve the SageMaker LMI (djl-deepspeed) container image URI for this region,
# pinned to version 0.23.0.
image_uri = sagemaker.image_uris.retrieve(
    framework="djl-deepspeed", region=region, version="0.23.0"
)


print(image_uri)

# Pair the container image with the uploaded code archive.
# NOTE: this notebook-level `model` is a sagemaker.Model; it is unrelated to the
# `model` global inside mymodel/model.py.
model = Model(image_uri=image_uri, model_data=s3_code_artifact, role=role)

In [None]:
# A single-GPU instance type is enough here because serving.properties sets
# option.tensor_parallel_degree=1.
instance_type = "ml.g5.2xlarge"

endpoint_name = "chatglm2-lmi-model"

# Create the endpoint with one instance. The startup health-check timeout is
# raised to 900 (seconds, per the SageMaker SDK documentation).
model.deploy(
    initial_instance_count=1,
    instance_type=instance_type,
    endpoint_name=endpoint_name,
    container_startup_health_check_timeout=900,
)

#### Prediction

In [None]:
import io

class StreamScanner:
    """
    A helper class for parsing the InvokeEndpointWithResponseStream event stream.

    The output of the model is newline-delimited JSON, for example:
    ```
    b'{"outputs": [" a"]}\n'
    b'{"outputs": [" challenging"]}\n'
    b'{"outputs": [" problem"]}\n'
    ...
    ```

    While usually each PayloadPart event from the event stream will contain a byte array
    with a full json, this is not guaranteed and some of the json objects may be split across
    PayloadPart events. For example:
    ```
    {'PayloadPart': {'Bytes': b'{"outputs": '}}
    {'PayloadPart': {'Bytes': b'[" problem"]}\n'}}
    ```

    This class accounts for this by appending every chunk passed to 'write' to one
    buffer, and having 'readlines' return only complete lines (those terminated by
    '\n'). A trailing partial line stays in the buffer until the rest of it arrives.
    The position of the last consumed byte is tracked so that a line is never
    returned twice.
    """

    def __init__(self):
        self.buff = io.BytesIO()  # every byte written so far
        self.read_pos = 0  # offset of the first byte not yet returned by readlines

    def write(self, content):
        """Append ``content`` (bytes) to the end of the buffer."""
        self.buff.seek(0, io.SEEK_END)
        self.buff.write(content)

    def readlines(self):
        """Yield each complete, not-yet-consumed line without its trailing newline."""
        self.buff.seek(self.read_pos)
        for line in self.buff.readlines():
            # Indexing bytes gives an int, so `line[-1] != b'\n'` was always true
            # and partial lines were emitted with their last byte cut off.
            # Only a newline-terminated line is complete.
            if not line.endswith(b'\n'):
                # Only the final line can be unterminated; leave it buffered.
                break
            self.read_pos += len(line)
            yield line[:-1]

    def reset(self):
        """Rewind the read position to the start; buffered bytes are kept."""
        self.read_pos = 0

In [None]:
import boto3
import json

smr = boto3.client('sagemaker-runtime')

# Generation settings sent to the endpoint under the "parameters" key.
parameters = {"max_length": 4092, "temperature": 0.01, "top_p": 0.8}

# Request body: the prompt (Chinese; asks for a few holiday destination
# recommendations), the generation settings, and an empty conversation history.
request_body = json.dumps(
    {
        "inputs": """推荐几个适合度假的地方""",
        "parameters": parameters,
        "history": [],
    }
)

response_model = smr.invoke_endpoint_with_response_stream(
    EndpointName=endpoint_name,
    Body=request_body,
    ContentType="application/json",
)

event_stream = response_model['Body']
scanner = StreamScanner()
for event in event_stream:
    # Feed each payload chunk to the scanner, then print whatever lines it returns.
    scanner.write(event['PayloadPart']['Bytes'])
    for line in scanner.readlines():
        try:
            # NOTE(review): assumes each line decodes to {"outputs": {"outputs": <text>, ...}};
            # confirm the nesting against the container's stream formatter.
            resp = json.loads(line)['outputs']['outputs']
            print(resp, end='')
        except Exception:
            # Best effort: lines that do not parse or lack the expected keys are skipped.
            continue