## Preparation

Based on https://github.com/triton-inference-server/fastertransformer_backend/tree/dev/v1.1_beta

* Model creation and local mode testing were done on an ml.p3.16xlarge notebook instance.
* The images are large, so increase the disk size when creating the notebook instance and change the Docker image storage path.
* Following the fastertransformer_backend README, git clone fastertransformer_backend, triton, and FasterTransformer.
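One way to change the Docker image storage path is to set the daemon's `data-root` key in `daemon.json` so that images land on the notebook's EBS volume, which is mounted at `/home/ec2-user/SageMaker`. The sketch below is a hypothetical helper, not part of the original setup. It assumes the usual `/etc/docker/daemon.json` location (writing there needs root) and merges into the existing file so keys such as the nvidia runtime entry on GPU instances are kept. Restart the Docker daemon afterwards (for example `sudo systemctl restart docker`).

```python
import json
import os


def set_docker_data_root(daemon_json_path: str, data_root: str) -> dict:
    """Set "data-root" in a Docker daemon.json while keeping every other key."""
    config = {}
    if os.path.exists(daemon_json_path):
        with open(daemon_json_path) as f:
            content = f.read().strip()
        if content:
            config = json.loads(content)
    # Merge rather than overwrite, so existing runtime settings survive
    config["data-root"] = data_root
    with open(daemon_json_path, "w") as f:
        json.dump(config, f, indent=4)
    return config
```

Example (run as root): `set_docker_data_root("/etc/docker/daemon.json", "/home/ec2-user/SageMaker/docker")`.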

Pull the SageMaker Triton image (us-east-1)
* Requires logging in to ECR
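With the AWS CLI the login is typically `aws ecr get-login-password --region us-east-1 | docker login --username AWS --password-stdin 785573368785.dkr.ecr.us-east-1.amazonaws.com`. A minimal Python sketch of the same idea uses boto3's `get_authorization_token`. The helper names are hypothetical, and the sketch assumes the notebook role is allowed to pull from the SageMaker-owned registry (785573368785). The returned password would be piped to `docker login --password-stdin`.

```python
import base64
from typing import Tuple


def decode_ecr_token(authorization_token: str) -> Tuple[str, str]:
    """Decode an ECR authorizationToken (base64 of "user:password")."""
    decoded = base64.b64decode(authorization_token).decode("utf-8")
    username, password = decoded.split(":", 1)
    return username, password


def get_ecr_login(registry_id: str, region: str = "us-east-1") -> Tuple[str, str, str]:
    """Return (registry_host, username, password) for `docker login`."""
    import boto3  # lazy import so decode_ecr_token can be used without boto3

    ecr = boto3.client("ecr", region_name=region)
    data = ecr.get_authorization_token()["authorizationData"][0]
    username, password = decode_ecr_token(data["authorizationToken"])
    # Assumption: the caller's token is accepted by the target registry
    registry_host = f"{registry_id}.dkr.ecr.{region}.amazonaws.com"
    return registry_host, username, password
```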

In [None]:
!docker pull 785573368785.dkr.ecr.us-east-1.amazonaws.com/sagemaker-tritonserver:21.08-py3

Image build & push
* In the [Dockerfile](https://github.com/triton-inference-server/fastertransformer_backend/blob/dev/v1.1_beta/docker/Dockerfile), replace the base image with the SageMaker Triton image, and at the end replace the serve file.
* The serve file is the [original](https://github.com/triton-inference-server/server/blob/main/docker/sagemaker/serve) with only the final launch command modified, following the FasterTransformer backend's launch command.
* Paste Dockerfile.sm into the directory that holds the original Dockerfile (workspace/fastertransformer_backend/docker).
* Paste the serve file into the parent directory (workspace/fastertransformer_backend), then run docker build from a terminal in that directory:
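The Dockerfile.sm changes described above can be sketched as a text transformation: swap the first `FROM` for the SageMaker Triton image and append a `COPY` that overwrites the serve script. This is a hypothetical helper that assumes a single-stage Dockerfile and assumes the SageMaker Triton image keeps its entrypoint script at `/usr/bin/serve`; check both against the actual files. Since `COPY` keeps file permissions from the build context, run `chmod +x serve` on the host first.

```python
def make_sagemaker_dockerfile(original: str, base_image: str,
                              serve_src: str = "serve",
                              serve_dst: str = "/usr/bin/serve") -> str:
    """Replace the first FROM with base_image and append a COPY for serve."""
    lines = []
    replaced = False
    for line in original.splitlines():
        if not replaced and line.lstrip().upper().startswith("FROM "):
            lines.append(f"FROM {base_image}")
            replaced = True
        else:
            lines.append(line)
    if not replaced:
        raise ValueError("no FROM instruction found")
    # serve_src is relative to the build context (workspace/fastertransformer_backend)
    lines.append(f"COPY {serve_src} {serve_dst}")
    return "\n".join(lines) + "\n"
```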

```
docker build -t {account_number}.dkr.ecr.us-east-1.amazonaws.com/sm-triton-ft:21.08-py3 -f docker/Dockerfile.sm .
```

* Before pushing, create the ECR repository sm-triton-ft, log in to ECR, and set up push permissions.
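Creating the repository can be made idempotent with boto3's `create_repository` and `describe_repositories`, as in this sketch. `ensure_ecr_repository` is a hypothetical name; the client is passed in so it can be `boto3.client("ecr", region_name="us-east-1")` in the notebook. Push permissions (IAM) are not covered here.

```python
def ensure_ecr_repository(ecr_client, name: str) -> str:
    """Create the ECR repository if it does not exist and return its URI."""
    try:
        resp = ecr_client.create_repository(repositoryName=name)
        return resp["repository"]["repositoryUri"]
    except ecr_client.exceptions.RepositoryAlreadyExistsException:
        # Already created earlier: look up the existing URI instead
        resp = ecr_client.describe_repositories(repositoryNames=[name])
        return resp["repositories"][0]["repositoryUri"]
```

Example: `ensure_ecr_repository(boto3.client("ecr", region_name="us-east-1"), "sm-triton-ft")`.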

In [None]:
!docker push {account_number}.dkr.ecr.us-east-1.amazonaws.com/sm-triton-ft:21.08-py3

## Model creation and S3 upload

Create the model
* Follow "Prepare Triton GPT model store" in the [fastertransformer_backend README.md How to set the model configuration](https://github.com/triton-inference-server/fastertransformer_backend/tree/dev/v1.1_beta#how-to-set-the-model-configuration) section.

Edit config.pbtxt
* tensor_para_size = 8
* model_checkpoint_path = "/opt/ml/model/fastertransformer/1/8-gpu"
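The two config.pbtxt edits can also be applied programmatically. This sketch assumes the FasterTransformer backend stores these settings as `parameters { key: "..." value: { string_value: "..." } }` blocks; `set_pbtxt_parameter` is a hypothetical helper that fails loudly if the key is missing or duplicated rather than silently doing nothing.

```python
import re


def set_pbtxt_parameter(config_text: str, key: str, value: str) -> str:
    """Replace the string_value of the parameter block whose key matches."""
    pattern = re.compile(
        r'(key:\s*"' + re.escape(key) + r'"\s*value:\s*\{\s*string_value:\s*")[^"]*(")'
    )
    new_text, count = pattern.subn(lambda m: m.group(1) + value + m.group(2), config_text)
    if count != 1:
        raise ValueError(f"expected exactly one parameter {key!r}, found {count}")
    return new_text
```

Usage: read config.pbtxt, apply `set_pbtxt_parameter(text, "tensor_para_size", "8")` and `set_pbtxt_parameter(text, "model_checkpoint_path", "/opt/ml/model/fastertransformer/1/8-gpu")`, then write it back.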

Compress the model and upload it to S3

In [None]:
!wget https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json -P models
!wget https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt -P models
!wget --content-disposition https://api.ngc.nvidia.com/v2/models/nvidia/megatron_lm_345m/versions/v0.0/zip -O megatron_lm_345m_v0.0.zip
!mkdir -p ./models/megatron-models/345m
!unzip megatron_lm_345m_v0.0.zip -d models/megatron-models/345m

In [None]:
!mkdir -p triton-serve-ft/fastertransformer/
!cp -r ./fastertransformer_backend/all_models/ triton-serve-ft/fastertransformer/1/
!cp config.pbtxt triton-serve-ft/fastertransformer
!tar -C triton-serve-ft/ -czf model.tar.gz fastertransformer
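SageMaker extracts model.tar.gz into `/opt/ml/model`, so the archive root has to be `fastertransformer/` for the `model_checkpoint_path` in config.pbtxt to resolve. The sketch below is a hypothetical `tarfile` equivalent of the `tar -C` command that also returns the member names so the layout can be checked before uploading.

```python
import os
import tarfile
from typing import List


def package_model(source_root: str, model_dir: str, output_path: str) -> List[str]:
    """Python equivalent of `tar -C <source_root> -czf <output_path> <model_dir>`."""
    with tarfile.open(output_path, "w:gz") as tar:
        # arcname keeps model_dir as the archive root, e.g. fastertransformer/config.pbtxt
        tar.add(os.path.join(source_root, model_dir), arcname=model_dir)
    with tarfile.open(output_path, "r:gz") as tar:
        return tar.getnames()
```

Example: `package_model("triton-serve-ft", "fastertransformer", "model.tar.gz")`.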

In [None]:
import boto3, json, sagemaker, time
from sagemaker import get_execution_role

sess = boto3.Session()
sm = sess.client("sagemaker")
sagemaker_session = sagemaker.Session(boto_session=sess)
role = get_execution_role()
client = boto3.client("sagemaker-runtime")

In [None]:
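model_uri = sagemaker_session.upload_data(path="model.tar.gz", key_prefix="triton-serve-ft")
# Fill in the account ID instead of leaving the literal {account_number} placeholder
# (assumes the image was pushed to this account's ECR in us-east-1)
account_number = boto3.client("sts").get_caller_identity()["Account"]
image_uri = f"{account_number}.dkr.ecr.us-east-1.amazonaws.com/sm-triton-ft:21.08-py3"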

## Local mode test

* See [fastertransformer_backend Run Serving on Single Node](https://github.com/triton-inference-server/fastertransformer_backend/tree/dev/v1.1_beta#run-serving-on-single-node)
* See the [SageMaker Triton example](https://github.com/aws/amazon-sagemaker-examples/blob/1072934944e5270f7f2fb0d9e0e1a86ce96aa57e/sagemaker-triton/nlp_bert/triton_nlp_bert.ipynb)

In [None]:
sm_model_name = "triton-ft-" + time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime())
model = sagemaker.model.Model(image_uri=image_uri, model_data=model_uri, role=role, 
                              name=sm_model_name)
model.deploy(initial_instance_count=1, instance_type='local_gpu')

In [None]:
!pip install tritonclient[http]

In [None]:
from sagemaker.predictor import Predictor
from sagemaker.serializers import JSONSerializer
from sagemaker.local import LocalSession

local_sess = LocalSession()
predictor = Predictor(
    endpoint_name=model.endpoint_name, 
    sagemaker_session=local_sess,
    serializer=JSONSerializer()
)

In [None]:
import tritonclient.http as httpclient
import numpy as np
# In the pip tritonclient package these helpers live in tritonclient.utils
from tritonclient.utils import np_to_triton_dtype, InferenceServerException

input_start_ids = np.array([
    [9915, 27221, 59, 77, 383, 1853, 3327, 1462],
    [6601, 4237, 345, 460, 779, 284, 787, 257],
    [59, 77, 611, 7, 9248, 796, 657, 8],
    [38, 10128, 6032, 651, 8699, 4, 4048, 20753],
    [21448, 7006, 930, 12901, 930, 7406, 7006, 198],
    [13256, 11, 281, 1605, 3370, 11, 1444, 6771],
    [9915, 27221, 59, 77, 383, 1853, 3327, 1462],
    [6601, 4237, 345, 460, 779, 284, 787, 257]
], np.uint32)
input_start_ids = input_start_ids.reshape([input_start_ids.shape[0], 1, input_start_ids.shape[1]])
input_data = np.tile(input_start_ids, (1, 1, 1))
input_len = np.array([[sentence.size] for sentence in input_start_ids], np.uint32)
output_len = np.ones_like(input_len).astype(np.uint32) * 24

payload = {
    "inputs": [
        {"name": "INPUT_ID", "shape": input_data.shape, "datatype": np_to_triton_dtype(input_data.dtype), 
         "data": input_data.tolist()},
        {"name": "REQUEST_INPUT_LEN", "shape": input_len.shape, "datatype": np_to_triton_dtype(input_len.dtype), 
         "data": input_len.tolist()},
        {"name": "REQUEST_OUTPUT_LEN", "shape": output_len.shape, "datatype": np_to_triton_dtype(output_len.dtype),
         "data": output_len.tolist()}
    ]
}
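# Send the same request repeatedly; the requests are sequential, not parallel
request_parallelism = 100
for i in range(request_parallelism):
    response = predictor.predict(payload)
# The default Predictor deserializer returns raw bytes; show the last response
print(json.loads(response))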
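The three input dictionaries above follow one pattern (name, shape, Triton datatype, nested data), so they can be generated from a dict of arrays, and the JSON response can be turned back into arrays. This is a sketch with hypothetical helper names; the dtype table is a small subset that mirrors `np_to_triton_dtype` for the types used here, and since the FasterTransformer output tensor names are not given in this notebook, the parser keys its result by whatever names the server returns.

```python
import json

import numpy as np

# Subset of the numpy -> Triton datatype mapping (tritonclient.utils.np_to_triton_dtype covers all types)
_TRITON_DTYPES = {"uint32": "UINT32", "int32": "INT32", "int64": "INT64", "float32": "FP32"}
_NUMPY_DTYPES = {v: k for k, v in _TRITON_DTYPES.items()}


def make_triton_payload(inputs: dict) -> dict:
    """Build a Triton JSON inference request from {input_name: ndarray}."""
    return {
        "inputs": [
            {
                "name": name,
                "shape": list(arr.shape),
                "datatype": _TRITON_DTYPES[arr.dtype.name],
                "data": arr.tolist(),  # nested lists, as in the cell above
            }
            for name, arr in inputs.items()
        ]
    }


def parse_triton_response(body) -> dict:
    """Turn a Triton JSON response body into {output_name: ndarray}."""
    if isinstance(body, (bytes, bytearray)):
        body = body.decode("utf-8")
    result = json.loads(body)
    outputs = {}
    for out in result["outputs"]:
        dtype = _NUMPY_DTYPES.get(out["datatype"])  # None lets numpy infer the type
        outputs[out["name"]] = np.array(out["data"], dtype=dtype).reshape(out["shape"])
    return outputs
```

Usage: `payload = make_triton_payload({"INPUT_ID": input_data, "REQUEST_INPUT_LEN": input_len, "REQUEST_OUTPUT_LEN": output_len})`, then `parse_triton_response(predictor.predict(payload))`.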

## Delete the endpoint and model

In [None]:
# predictor.delete_endpoint()
# model.delete_model()

## Create and test the endpoint

In [None]:
sm_model_name = "triton-ft-" + time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime())

container = {
    "Image": image_uri,
    "ModelDataUrl": model_uri,
}

create_model_response = sm.create_model(
    ModelName=sm_model_name, ExecutionRoleArn=role, PrimaryContainer=container
)

print("Model Arn: " + create_model_response["ModelArn"])

In [None]:
endpoint_config_name = "triton-ft-" + time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime())

create_endpoint_config_response = sm.create_endpoint_config(
    EndpointConfigName=endpoint_config_name,
    ProductionVariants=[
        {
            "InstanceType": "ml.p3.16xlarge",
            "InitialVariantWeight": 1,
            "InitialInstanceCount": 1,
            "ModelName": sm_model_name,
            "VariantName": "AllTraffic",
        }
    ],
)

print("Endpoint Config Arn: " + create_endpoint_config_response["EndpointConfigArn"])

In [None]:
endpoint_name = "triton-ft-" + time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime())

create_endpoint_response = sm.create_endpoint(
    EndpointName=endpoint_name, EndpointConfigName=endpoint_config_name
)

print("Endpoint Arn: " + create_endpoint_response["EndpointArn"])
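`create_endpoint` returns immediately while the endpoint is still `Creating`, so invoking it right away fails. boto3 provides `sm.get_waiter("endpoint_in_service")`; the loop below is a hypothetical alternative built on `describe_endpoint` that also surfaces `FailureReason` when the custom container does not start, which helps when debugging the modified serve script.

```python
import time


def wait_for_endpoint(sm_client, endpoint_name: str, poll_seconds: float = 30.0,
                      timeout_seconds: float = 3600.0) -> str:
    """Poll describe_endpoint until InService; raise with FailureReason on Failed."""
    deadline = time.monotonic() + timeout_seconds
    while True:
        desc = sm_client.describe_endpoint(EndpointName=endpoint_name)
        status = desc["EndpointStatus"]
        if status == "InService":
            return status
        if status == "Failed":
            reason = desc.get("FailureReason", "unknown")
            raise RuntimeError(f"endpoint {endpoint_name} failed: {reason}")
        if time.monotonic() >= deadline:
            raise TimeoutError(f"endpoint {endpoint_name} still {status} after {timeout_seconds} s")
        print(f"[INFO] {endpoint_name}: {status}")
        time.sleep(poll_seconds)
```

Usage: `wait_for_endpoint(sm, endpoint_name)` before sending requests.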

In [None]:
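import json
import numpy as np
import tritonclient.http as httpclient
# Needed below even if the local mode cells were skipped
from tritonclient.utils import np_to_triton_dtype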

input_start_ids = np.array([
    [9915, 27221, 59, 77, 383, 1853, 3327, 1462],
    [6601, 4237, 345, 460, 779, 284, 787, 257],
    [59, 77, 611, 7, 9248, 796, 657, 8],
    [38, 10128, 6032, 651, 8699, 4, 4048, 20753],
    [21448, 7006, 930, 12901, 930, 7406, 7006, 198],
    [13256, 11, 281, 1605, 3370, 11, 1444, 6771],
    [9915, 27221, 59, 77, 383, 1853, 3327, 1462],
    [6601, 4237, 345, 460, 779, 284, 787, 257]
], np.uint32)
input_start_ids = input_start_ids.reshape([input_start_ids.shape[0], 1, input_start_ids.shape[1]])
input_data = np.tile(input_start_ids, (1, 1, 1))
input_len = np.array([[sentence.size] for sentence in input_start_ids], np.uint32)
output_len = np.ones_like(input_len).astype(np.uint32) * 24

from datetime import datetime

# Number of requests; they are sent one after another, not concurrently
request_parallelism = 10

# The inputs do not change between requests, so build the payload once
payload = {
    "inputs": [
        {"name": "INPUT_ID", "shape": input_data.shape, "datatype": np_to_triton_dtype(input_data.dtype),
         "data": input_data.tolist()},
        {"name": "REQUEST_INPUT_LEN", "shape": input_len.shape, "datatype": np_to_triton_dtype(input_len.dtype),
         "data": input_len.tolist()},
        {"name": "REQUEST_OUTPUT_LEN", "shape": output_len.shape, "datatype": np_to_triton_dtype(output_len.dtype),
         "data": output_len.tolist()}
    ]
}

# Start the clock before the requests and stop it after the last response
start_time = datetime.now()
for i in range(request_parallelism):
    response = client.invoke_endpoint(
        EndpointName=endpoint_name, ContentType="application/octet-stream", Body=json.dumps(payload)
    )
    print(json.loads(response["Body"].read().decode("utf8")))
stop_time = datetime.now()

# Average wall-clock time per request (includes printing each response)
latency = (stop_time - start_time).total_seconds() * 1000.0 / request_parallelism
print(f"[INFO] average latency per request: {latency} ms")
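An average over sequential requests hides tail latency. The sketch below times each call separately with `time.perf_counter` and reports mean, nearest-rank p50/p90, and max. `measure_latency` is a hypothetical helper; like the loop above it sends requests one at a time, so it measures per-request latency rather than throughput under concurrency.

```python
import math
import statistics
import time
from typing import Callable, Dict, List


def measure_latency(call: Callable[[], object], n_requests: int) -> Dict[str, float]:
    """Run `call` n_requests times one after another; return latency stats in ms."""
    if n_requests < 1:
        raise ValueError("n_requests must be >= 1")
    samples: List[float] = []
    for _ in range(n_requests):
        start = time.perf_counter()
        call()
        samples.append((time.perf_counter() - start) * 1000.0)
    samples.sort()

    def percentile(p: int) -> float:
        # Nearest-rank percentile on the sorted samples
        rank = max(1, math.ceil(p * len(samples) / 100))
        return samples[rank - 1]

    return {
        "mean": statistics.mean(samples),
        "p50": percentile(50),
        "p90": percentile(90),
        "max": samples[-1],
    }
```

Usage: `measure_latency(lambda: client.invoke_endpoint(EndpointName=endpoint_name, ContentType="application/octet-stream", Body=json.dumps(payload))["Body"].read(), request_parallelism)`. Reading the body inside the lambda makes each sample cover the full response.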

## Clean up

In [None]:
# Delete the endpoint first, then the config and model it was built from
sm.delete_endpoint(EndpointName=endpoint_name)
sm.delete_endpoint_config(EndpointConfigName=endpoint_config_name)
sm.delete_model(ModelName=sm_model_name)