In [None]:
# Upgrade pip itself first, then install/upgrade the SageMaker SDK, boto3 and the AWS CLI.
# NOTE(review): none of these versions are pinned, so the environment can differ between runs — consider pinning.
%pip install --upgrade pip --quiet
%pip install sagemaker boto3 awscli --upgrade --quiet

In [None]:
import boto3
import sagemaker
from sagemaker import Model, serializers, deserializers

# Resolve the AWS context once; every later cell reuses these names.
role = sagemaker.get_execution_role()  # execution role for the endpoint
sess = sagemaker.session.Session()  # sagemaker session for interacting with different AWS APIs
bucket = sess.default_bucket()  # bucket to house artifacts
# Use the public attribute rather than the private `_region_name`, which is an
# implementation detail of the SDK and not part of its documented interface.
region = sess.boto_region_name  # region name of the current SageMaker Studio environment
account_id = sess.account_id()  # account_id of the current SageMaker Studio environment

# Key prefix inside `bucket` under which the packaged serving code is uploaded.
s3_code_prefix_deepspeed = "east-ai-models/llama2-chinese-7b-chat/deepspeed"

In [None]:
# Sanity check: show the role, region and account the notebook resolved to.
print(f"{role} {region} {account_id}")

In [None]:
# -p keeps the cell re-runnable: no error when the directory already exists.
!mkdir -p mymodel

In [None]:
%%writefile ./mymodel/serving.properties
# Serving configuration packaged next to model.py in model.tar.gz.
# NOTE(review): presumably the option.* entries reach model.py through inputs.get_properties() — confirm.
engine=DeepSpeed
# Model identifier; model.py hands its "model_id" property to from_pretrained().
option.model_id=FlagAlpha/Llama2-Chinese-7b-Chat
# model.py passes its "tensor_parallel_degree" property to deepspeed.init_inference as mp_size.
option.tensor_parallel_degree=1

In [None]:
%%writefile ./mymodel/model.py
import json
import os
import warnings
from typing import Any, Dict, Optional, Tuple

from djl_python import Input, Output
import torch
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from transformers import LlamaTokenizer, LlamaForCausalLM
import deepspeed

# Module-level cache: handle() populates both names on its first call and
# reuses them on later calls instead of loading the model again.
model = None
tokenizer = None


def get_model(properties):
    """Load the Llama model and tokenizer and wrap the model for DeepSpeed inference.

    Args:
        properties: Mapping of serving properties. Must contain ``model_id``
            (handed to ``from_pretrained``) and ``tensor_parallel_degree``
            (an int or a numeric string).

    Returns:
        A ``(model, tokenizer)`` tuple, where ``model`` is whatever
        ``deepspeed.init_inference`` returns for the loaded checkpoint.

    Raises:
        KeyError: if a required property is missing.
        ValueError: if ``tensor_parallel_degree`` is not an integer value.
    """
    model_name = properties["model_id"]
    # Property values may arrive as strings (e.g. "1"); coerce explicitly so
    # DeepSpeed always receives an integer degree.
    tensor_parallel_degree = int(properties["tensor_parallel_degree"])

    model = LlamaForCausalLM.from_pretrained(
        model_name, low_cpu_mem_usage=True, trust_remote_code=True, torch_dtype=torch.bfloat16
    )
    model = deepspeed.init_inference(model, mp_size=tensor_parallel_degree)
    tokenizer = LlamaTokenizer.from_pretrained(model_name, trust_remote_code=True)
    return model, tokenizer


def handle(inputs: Input) -> Optional[Output]:
    """Entry point for one request: generate text for the supplied prompt.

    The model and tokenizer are loaded on the first call and cached in the
    module-level ``model`` / ``tokenizer`` names.

    Args:
        inputs: Incoming request. Its JSON body is expected to hold the prompt
            under ``"inputs"`` and optional ``generate`` keyword arguments
            under ``"parameters"``.

    Returns:
        ``None`` for an empty request, otherwise an ``Output`` carrying the
        JSON string ``{"generated_text": ...}``.
    """
    global model, tokenizer
    print("print inputs: " + str(inputs) + '.'*20)
    # Test against None explicitly: the truth value of the loaded model object
    # is not a reliable "already loaded" signal.
    if model is None:
        model, tokenizer = get_model(inputs.get_properties())

    if inputs.is_empty():
        # Model server makes an empty call to warmup the model on startup
        return None
    input_map = inputs.get_as_json()
    # Falls back to the whole payload when there is no "inputs" key.
    data = input_map.pop("inputs", input_map)
    parameters = input_map.pop("parameters", {})
    print("print data: " + str(data) + '.'*20)
    input_tokens = tokenizer(data, return_tensors="pt").to(torch.cuda.current_device())
    # Inference only, so skip gradient tracking.
    with torch.no_grad():
        output_tokens = model.generate(input_tokens.input_ids, **parameters)
    # Only the first generated sequence is decoded and returned.
    generated_text = tokenizer.decode(output_tokens[0], skip_special_tokens=True)
    print("print generated_text: " + generated_text + '.'*20)
    out = {'generated_text': generated_text}
    return Output().add(json.dumps(out))

In [None]:
# Remove the Jupyter checkpoint folder and any previous archive, re-pack
# mymodel/ and upload the tarball under the configured prefix.
!rm -rf mymodel/.ipynb_checkpoints
!rm -f model.tar.gz
!tar czvf model.tar.gz -C mymodel .
s3_code_artifact_deepspeed = sess.upload_data(
    path="model.tar.gz",
    bucket=bucket,
    key_prefix=s3_code_prefix_deepspeed,
)
print(f"S3 Code or Model tar uploaded to --- > {s3_code_artifact_deepspeed}")

In [None]:
# Look up the djl-deepspeed container image for the current region.
image_uri = sagemaker.image_uris.retrieve(framework="djl-deepspeed", version="0.23.0", region=region)
print(image_uri)

# Model definition: container image + uploaded code archive + execution role.
model = Model(
    role=role,
    image_uri=image_uri,
    model_data=s3_code_artifact_deepspeed,
)

In [None]:
# Single-GPU instance type; one GPU is all that is needed because the
# tensor split (tensor_parallel_degree) is 1.
instance_type = "ml.g5.2xlarge"

# name_from_base derives the endpoint name from this base string.
endpoint_name = sagemaker.utils.name_from_base("llama2-chinese-7b-chat-lmi-model")

model.deploy(
    endpoint_name=endpoint_name,
    instance_type=instance_type,
    initial_instance_count=1,
    container_startup_health_check_timeout=900,
)

# Requests and responses are JSON, so attach the matching serializer and
# deserializer to the predictor bound to the new endpoint.
predictor = sagemaker.Predictor(
    endpoint_name=endpoint_name,
    sagemaker_session=sess,
    deserializer=deserializers.JSONDeserializer(),
    serializer=serializers.JSONSerializer(),
)

In [None]:
import time

# Time one round trip to the endpoint. perf_counter() is a monotonic,
# high-resolution clock, so the interval is not distorted by wall-clock
# adjustments the way time.time() can be.
tic = time.perf_counter()
res = predictor.predict(
    {"inputs": "你好，睡眠不好怎么办？", "parameters": {"max_new_tokens": 256}}
)
toc = time.perf_counter()
print(res)
print(toc - tic)  # elapsed seconds
# NOTE(review): nothing in this notebook deletes the endpoint — remember to
# clean it up (e.g. predictor.delete_endpoint()) once finished.