In [None]:
# Install or upgrade the SageMaker Python SDK into this kernel's environment
# (-U upgrades an existing install, -q keeps pip's output quiet).
# The next cell checks for version 2.146.0 or newer; if this command upgraded
# the package, restart the kernel so the new version is the one imported.
%pip install -U -q sagemaker

In [13]:
import os
import boto3
import shutil
import sagemaker

print(sagemaker.__version__)

# Oldest SageMaker SDK release this notebook accepts, as (major, minor, patch).
MIN_SAGEMAKER_VERSION = (2, 146, 0)


def _release_tuple(version, width=3):
    """Parse the leading numeric fields of a dotted version string.

    Args:
        version: Version string such as "2.221.1" or "2.146.0rc1".
        width: Number of dot-separated fields to keep; missing fields count as 0.

    Returns:
        A tuple of ``width`` ints, e.g. ``(2, 221, 1)``.
    """
    fields = []
    for piece in version.split(".")[:width]:
        digits = ""
        for char in piece:
            # Stop at the first non-numeric character (e.g. the "rc1" in "0rc1").
            if not char.isdecimal():
                break
            digits += char
        fields.append(int(digits) if digits else 0)
    fields.extend([0] * (width - len(fields)))
    return tuple(fields)


# Compare numerically: comparing the raw strings is lexicographic, so e.g.
# "2.20.0" >= "2.146.0" is True even though 2.20 is older than 2.146.
if _release_tuple(sagemaker.__version__) < MIN_SAGEMAKER_VERSION:
    print("You need to upgrade or restart the kernel if you already upgraded")

# Session-derived settings reused by the cells below.
sess = sagemaker.Session()
role = sagemaker.get_execution_role()
bucket = sess.default_bucket()
region = sess.boto_region_name

print(f"sagemaker role arn: {role}")
print(f"sagemaker bucket: {bucket}")
print(f"sagemaker session region: {region}")

2.221.1
sagemaker role arn: arn:aws:iam::236241703319:role/service-role/AmazonSageMaker-ExecutionRole-20240525T225318
sagemaker bucket: sagemaker-ap-northeast-2-236241703319
sagemaker session region: ap-northeast-2


In [24]:
from sagemaker import image_uris
# Look up the Inferentia PyTorch inference image for the session's region
# (taken from `sess.boto_region_name`) instead of a hard-coded region, so the
# image always comes from the same region the endpoint is created in.
image_uri = image_uris.retrieve(
    framework="inferentia-pytorch",
    region=region,
    version="1.9",
    py_version="py3",
)

image_uri

'151534178276.dkr.ecr.ap-northeast-2.amazonaws.com/sagemaker-neo-pytorch:1.9-inf-py3'

In [25]:
import logging
from sagemaker.utils import name_from_base
from sagemaker.pytorch.model import PyTorchModel
from sagemaker.huggingface import HuggingFaceModel
from sagemaker.predictor import Predictor
from datetime import datetime

# Hardware tag and timestamp; both are interpolated into the endpoint name
# when the model is deployed.
hardware = "inf1"
date_string = datetime.now().strftime("%Y%m-%d%H-%M%S")

# S3 location of the compiled model archive to serve.
model_data = "s3://sagemaker-ap-northeast-2-236241703319/neuron-experiments/bge-m3/compiled-model/compiled_model_v2.tar.gz"

# Wrap the archive in a SageMaker PyTorch model that runs on the container
# image resolved earlier in the notebook.
pytorch_model = PyTorchModel(
    name=name_from_base("bge-m3"),
    role=role,
    image_uri=image_uri,
    model_data=model_data,
    framework_version="1.13.1",
    sagemaker_session=sess,
    container_log_level=logging.WARN,
    # NOTE(review): value is presumably in seconds - confirm against the
    # model server documentation.
    env={"SAGEMAKER_MODEL_SERVER_TIMEOUT": "7200"},
    # Optional, currently disabled: model_server_workers=4  (1 worker per core)
)

# Mark the model as already compiled through the SDK's private attribute.
# NOTE(review): presumably this keeps the SDK from treating the artifact as
# uncompiled when deploying to Inferentia - confirm against the installed SDK.
pytorch_model._is_compiled_model = True

In [26]:
%%time
from sagemaker.serializers import JSONSerializer
from sagemaker.deserializers import JSONDeserializer

# Build the endpoint name once so the name passed to deploy() and the name
# reported to the user cannot drift apart (the printed name used to omit the
# "-v2" suffix that the real endpoint carries).
endpoint_name = f"bge-m3-{hardware}-{date_string}-v2"

# Deploy the model to a single ml.inf1.6xlarge instance; requests and
# responses go through the JSON serializer/deserializer configured here.
predictor = pytorch_model.deploy(
    initial_instance_count=1,
    instance_type="ml.inf1.6xlarge",
    endpoint_name=endpoint_name,
    serializer=JSONSerializer(),
    deserializer=JSONDeserializer(),
)

print(f"Endpoint Name: {endpoint_name}")

-------------!Endpoint Name: bge-m3-inf1-202405-2805-4315
CPU times: user 87.1 ms, sys: 4.03 ms, total: 91.1 ms
Wall time: 7min 2s


In [29]:
# Two example sentences to send to the endpoint.
seq_0 = "Hi, global"
seq_1 = "Hello, world"

# Send both sentences as one payload and receive the inference result.
payload = (seq_0, seq_1)
outputs = predictor.predict(payload)

# Display the endpoint's response as the cell output.
outputs

'Similaritiy of two sentences are 0.855'