In [1]:
!pip install --upgrade boto3
!pip install --upgrade sagemaker

Looking in indexes: https://pypi.org/simple, https://pip.repos.neuron.amazonaws.com
Looking in indexes: https://pypi.org/simple, https://pip.repos.neuron.amazonaws.com


In [7]:
import boto3
import sagemaker

# Look up the caller's AWS account id and the region of the default boto3 session.
account_id = boto3.client('sts').get_caller_identity().get('Account')
region_name = boto3.session.Session().region_name

# SageMaker session, its default S3 bucket, and the notebook's execution role.
sagemaker_session = sagemaker.Session()
bucket = sagemaker_session.default_bucket()
role = sagemaker.get_execution_role()

print(role)
print(bucket)

# In AWS China regions (region names starting with "cn-"), prepend a pip
# index line pointing at the Tsinghua PyPI mirror to ./code/requirements.txt.
requirements_path = './code/requirements.txt'
mirror_line = "-i https://pypi.tuna.tsinghua.edu.cn/simple\n"

# `region_name` is None when no region is configured; skip the patch in that
# case instead of failing on the membership test.
if region_name and region_name.startswith("cn-"):
    with open(requirements_path, 'r') as original:
        requirements_text = original.read()
    # Prepend only once so that re-running this cell does not stack
    # duplicate index lines at the top of the file.
    if not requirements_text.startswith(mirror_line):
        with open(requirements_path, 'w') as modified:
            modified.write(mirror_line + requirements_text)

!touch dummy
!tar czvf model.tar.gz dummy
assets_dir = 's3://{0}/{1}/assets/'.format(bucket, 'eb_chinese')
model_data = 's3://{0}/{1}/assets/model.tar.gz'.format(bucket, 'eb_chinese')
!aws s3 cp model.tar.gz $assets_dir
!rm -f dummy model.tar.gz

from sagemaker.pytorch.model import PyTorchModel

# Settings for the PyTorch model object.
model_name = None  # no explicit name given; presumably the SDK picks one -- TODO confirm
entry_point = 'inference.py'
framework_version = '1.13.1'
py_version = 'py39'

# Environment variables passed to the model through `env`.
model_environment = dict(
    SAGEMAKER_MODEL_SERVER_TIMEOUT='600',
    SAGEMAKER_MODEL_SERVER_WORKERS='1',
)

# Collect the constructor arguments in one place, then build the model.
model_kwargs = dict(
    name=model_name,
    model_data=model_data,
    entry_point=entry_point,
    source_dir='./code',
    role=role,
    framework_version=framework_version,
    py_version=py_version,
    env=model_environment,
)
model = PyTorchModel(**model_kwargs)

from sagemaker.deserializers import JSONDeserializer
from sagemaker.serializers import JSONSerializer

# Endpoint settings: a single CPU instance behind a fixed endpoint name.
endpoint_name = 'huggingface-inference-eb'
instance_count = 1
instance_type = 'ml.m5.2xlarge'
# instance_type='ml.g4dn.2xlarge' 

# Deploy the model; requests and responses are (de)serialized as JSON.
deploy_options = dict(
    endpoint_name=endpoint_name,
    instance_type=instance_type,
    initial_instance_count=instance_count,
    serializer=JSONSerializer(),
    deserializer=JSONDeserializer(),
)
predictor = model.deploy(**deploy_options)

arn:aws-cn:iam::472964545242:role/service-role/AmazonSageMaker-ExecutionRole-20230614T115763
sagemaker-cn-north-1-472964545242
dummy
upload: ./model.tar.gz to s3://sagemaker-cn-north-1-472964545242/eb_chinese/assets/model.tar.gz
------------!

In [None]:
# Sleep for one minute to make sure the model can be fully loaded.
import time

WARMUP_SECONDS = 60
time.sleep(WARMUP_SECONDS)

In [8]:
# Inference testing: send one short sentence through the predictor and
# print whatever it returns.
data = dict(inputs="你好。")
res = predictor.predict(data)
print(res)

[[[-0.04677451401948929, 0.12917593121528625, 0.7252687811851501, 0.6887288093566895, 0.6720475554466248, -0.6931467056274414, 0.6111407279968262, 0.045531265437603, -0.7269701957702637, 0.6346073150634766, -0.011951988562941551, 0.5081378817558289, 0.06084213778376579, -0.36596569418907166, 0.0296473391354084, -0.3710579574108124, -0.6186922788619995, 0.023741114884614944, -0.24080410599708557, -0.5034646987915039, -0.26080578565597534, 0.49112120270729065, -0.6039136648178101, 0.5322360992431641, 0.23669886589050293, 0.22232456505298615, -0.3281732499599457, -0.06767880916595459, 1.2170698642730713, 0.510742723941803, -0.1280217468738556, 0.7181637287139893, -0.21124252676963806, 0.014374107122421265, -1.0576074123382568, 0.25009575486183167, 0.4328520894050598, 0.7128980755805969, -0.51764976978302, 0.4310057759284973, 0.33416396379470825, 0.08634588122367859, -1.109205961227417, 0.25449997186660767, -0.08281388133764267, -0.3182655870914459, 0.7469677329063416, 1.5996872186660767, 

In [10]:
# Inference testing against the already-deployed endpoint, with timing.
import time

# Import the predictor class explicitly instead of reaching it through
# `sagemaker.huggingface.model` attribute access, which only works when some
# earlier import happened to load that submodule.
from sagemaker.huggingface.model import HuggingFacePredictor

# Attach to the existing endpoint by name.
hfp = HuggingFacePredictor('huggingface-inference-eb')

# Request payload: the two-character string '打印' repeated 100 times.
# Built outside the timed region so only the predict call is measured.
payload = {'inputs': '打印' * 100}

# perf_counter is a monotonic clock intended for measuring elapsed time.
t0 = time.perf_counter()
for i in range(1):
    x = hfp.predict(payload)
print(time.perf_counter() - t0)

# Length of the innermost vector of the response, then the full response.
print(len(x[0][0]))
print(x)

0.21089696884155273
768
[[[-0.12084222584962845, -0.13449691236019135, 0.3932150900363922, 0.16505874693393707, 0.07772558927536011, -1.3287967443466187, 0.2512413561344147, 0.357142835855484, -1.6107465028762817, -0.3020375072956085, 0.8068588376045227, -0.5724591016769409, -0.2644294798374176, 0.25229570269584656, -0.46913301944732666, 1.034051537513733, -0.24374452233314514, -0.32981857657432556, 0.0024372439365833998, -0.3498744070529938, -0.5836201906204224, 0.8180524110794067, -0.5290428400039673, 0.13322144746780396, 0.7637916207313538, 0.5445929765701294, -0.9730080962181091, 0.44369974732398987, 0.06250642240047455, 1.091039776802063, 0.6685947179794312, -0.7008579969406128, -0.6584241390228271, -0.07846096158027649, -0.08940890431404114, 1.393285870552063, -0.08004172146320343, 0.34632283449172974, 1.0319600105285645, -0.07839339226484299, -0.8218563795089722, 1.2033993005752563, 0.5746997594833374, -0.6333560943603516, -0.18563978374004364, -0.5973430871963501, 0.70659530162