In [None]:
import json

import sagemaker
from sagemaker.huggingface import HuggingFaceModel
 
# One session is created here and reused for the role lookup, the image region
# and the model, so every call below targets the same account and region.
sess = sagemaker.Session()
sm_role = sagemaker.get_execution_role(sagemaker_session=sess)

# Hugging Face TGI inference image, pulled from the ECR registry of the session's region.
llm_image = f"763104351884.dkr.ecr.{sess.boto_region_name}.amazonaws.com/huggingface-pytorch-tgi-inference:2.1-tgi2.0-gpu-py310-cu121-ubuntu22.04"

# print ecr image uri
print(f"llm image uri: {llm_image}")

# sagemaker config
instance_type = "ml.g5.2xlarge"
health_check_timeout = 900  # seconds (15 minutes) allowed for the container to start and load the model

# S3 location of the packaged model archive. Replace YOUR_BUCKET with a real bucket before running.
model_data_uri = 's3://YOUR_BUCKET/llama-chinese.tar.gz'
if "YOUR_BUCKET" in model_data_uri:
    # Fail here with a clear message instead of partway through the deployment.
    raise ValueError(
        "model_data_uri still contains the YOUR_BUCKET placeholder; "
        "point it at the S3 archive that holds the model."
    )

# Environment variables handed to the inference container.
config = {
    'HF_MODEL_ID': '/opt/ml/model',  # a local path inside the container rather than a Hub model id
    'MAX_INPUT_LENGTH': "2048",  # Max length of input text
    'MAX_TOTAL_TOKENS': "4096",  # Max length of the generation (including input text)
    'MAX_BATCH_TOTAL_TOKENS': "8192",  # Limits the number of tokens that can be processed in parallel during the generation
    'MESSAGES_API_ENABLED': "true",  # Enable the messages API
}

# create HuggingFaceModel with the image uri
llm_model = HuggingFaceModel(
    model_data=model_data_uri,
    role=sm_role,
    image_uri=llm_image,
    env=config,
    sagemaker_session=sess,  # same session that supplied the region for llm_image
)

# Deploy model to an endpoint
# https://sagemaker.readthedocs.io/en/stable/api/inference/model.html#sagemaker.model.Model.deploy
llm = llm_model.deploy(
    initial_instance_count=1,
    instance_type=instance_type,
    container_startup_health_check_timeout=health_check_timeout,  # 15 minutes to be able to load the model
)

In [None]:
# Conversation to send: a system instruction followed by a single user turn.
messages = [
    {"role": "system", "content": "你是个智能AI，小心回答用户的脑筋急转弯问题"},
    {"role": "user", "content": "我的蓝牙耳机坏了，我该去看牙科还是耳鼻喉科？"},
]

# Generation settings sent alongside the messages.
parameters = dict(
    model="meta-llama/Meta-Llama-3-8B-Instruct",  # placeholder value; the field itself is required
    top_p=0.8,
    temperature=0.9,
    max_tokens=512,
    stop=["<|eot_id|>"],
)

# Build one request body: the messages first, then the generation settings.
request_body = {"messages": messages}
request_body.update(parameters)
chat = llm.predict(request_body)

# Print the text of the first returned choice, trimmed of surrounding whitespace.
first_choice = chat["choices"][0]
print(first_choice["message"]["content"].strip())