# 1. 升级 Python SDK（boto3 与 sagemaker）

In [None]:
!pip install --upgrade boto3
!pip install --upgrade sagemaker

# 2. 获取Runtime资源配置

In [None]:
import boto3
import sagemaker
from sagemaker import get_execution_role

# SageMaker session plus the IAM execution role and default S3 bucket that
# the later cells use for uploading assets and creating the model.
sess = sagemaker.Session()
role = get_execution_role()
sagemaker_default_bucket = sess.default_bucket()

# Identity of the current AWS session: account id (via STS) and region name.
boto_session = sess.boto_session
sts_client = boto_session.client("sts")
account = sts_client.get_caller_identity()["Account"]
region = boto_session.region_name

# 3. 准备 Dummy 模型

打包一个只含空文件 `dummy` 的占位 `model.tar.gz`，上传到 S3，作为后面 `PyTorchModel` 的 `model_data`。

In [None]:
import tarfile
import tempfile
from pathlib import Path

# Build a placeholder model.tar.gz holding a single empty file named "dummy"
# and upload it to s3://<default bucket>/chatglm/assets/. It is only used as
# PyTorchModel's model_data; presumably the inference script in ./code obtains
# the real weights itself -- confirm against chatglm-inference.py.
model_prefix = 'chatglm'
assets_dir = 's3://{0}/{1}/assets/'.format(sagemaker_default_bucket, model_prefix)

# Work in a temporary directory so no stray files are left in the notebook's
# working directory (replaces the touch / tar / rm shell commands).
with tempfile.TemporaryDirectory() as tmp_dir:
    work_dir = Path(tmp_dir)
    dummy_file = work_dir / 'dummy'
    dummy_file.touch()
    archive = work_dir / 'model.tar.gz'
    with tarfile.open(archive, 'w:gz') as tar:
        tar.add(dummy_file, arcname='dummy')
    # Upload via the SageMaker session instead of `!aws s3 cp`: a failing shell
    # escape does not stop the notebook, while upload_data raises, so a bad
    # upload can no longer leave model_data pointing at a missing object.
    # For a single file it returns that object's URI:
    # s3://<bucket>/chatglm/assets/model.tar.gz
    model_data = sess.upload_data(
        str(archive),
        bucket=sagemaker_default_bucket,
        key_prefix='{0}/assets'.format(model_prefix),
    )

# 4. 配置模型参数

In [None]:
# Model settings consumed by the PyTorchModel cell.
# model_name=None: no explicit name is passed, so the SageMaker SDK picks one.
model_name = None
# Inference handler script, resolved relative to the ./code source directory.
entry_point = 'chatglm-inference.py'
# PyTorch framework / Python versions of the serving container.
framework_version = '1.13.1'
py_version = 'py39'
# Environment variables for the serving container: a longer model-server
# timeout (presumably seconds -- confirm) and a single worker process.
model_environment = dict(
    SAGEMAKER_MODEL_SERVER_TIMEOUT='600',
    SAGEMAKER_MODEL_SERVER_WORKERS='1',
)

In [None]:
from sagemaker.pytorch.model import PyTorchModel

# Wrap the placeholder archive and the ./code inference handler in a
# SageMaker PyTorch model definition, using the settings from the cells above.
model = PyTorchModel(
    entry_point=entry_point,
    source_dir='./code',
    model_data=model_data,
    role=role,
    framework_version=framework_version,
    py_version=py_version,
    env=model_environment,
    name=model_name,
)

# 5. 部署ChatGLM模型

In [None]:

from sagemaker.serializers import JSONSerializer
from sagemaker.deserializers import JSONDeserializer

# Endpoint settings: no explicit endpoint name, one ml.g4dn.2xlarge instance.
endpoint_name = None
instance_type = 'ml.g4dn.2xlarge'
instance_count = 1

# Deploy the model; requests are serialized to JSON and responses are parsed
# back from JSON, so predictor.predict() takes and returns Python objects.
predictor = model.deploy(
    initial_instance_count=instance_count,
    instance_type=instance_type,
    endpoint_name=endpoint_name,
    serializer=JSONSerializer(),
    deserializer=JSONDeserializer(),
)

# 6. 测试ChatGLM模型推理

In [None]:
# Smoke test: send a short greeting under the "ask" key and print the
# "answer" field of the parsed JSON response.
inputs = {"ask": "你好!"}
response = predictor.predict(inputs)
answer = response["answer"]
print(answer)


In [None]:
# Attribute-style prompt: "key#value" pairs separated by a plain "*".
# The separator must be written unescaped -- "\*" is an invalid escape
# sequence (SyntaxWarning on Python 3.12+) that keeps a literal backslash in
# the string, so the model would receive "\*" instead of "*".
inputs = {
    "ask": "类型#上衣*材质#牛仔布*颜色#白色*风格#简约*图案#刺绣*衣样式#外套*衣款式#破洞"
}

response = predictor.predict(inputs)
print(response["answer"])
