## BART Model Hosting With DJL Serving on Amazon SageMaker Real-Time Inference

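Environment: `conda_amazonei_pytorch_latest_p37` kernel on an ml.c5.9xlarge classic notebook instance. The endpoint itself is hosted on a separate instance type, configured in the SageMaker Hosting section below.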

### Local Sample Inference

In [None]:
from transformers import AutoTokenizer, AutoModel

# Load the pretrained BART checkpoint locally for a quick sanity check
# before packaging it for DJL Serving.
checkpoint = "facebook/bart-large"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModel.from_pretrained(checkpoint)

In [None]:
import torch

inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
# Inference only: disable autograd so the tensors kept in notebook globals
# do not hold a computation graph (and its saved activations) in memory.
with torch.no_grad():
    outputs = model(**inputs)

last_hidden_states = outputs.last_hidden_state
last_hidden_states

### DJL Specific Artifacts

For DJL Serving, there are three artifacts we need to package into our model tarball:

- model.py: Your pre/post-processing logic as well as model inference. You can add any customization in this script.
- requirements.txt: Any additional libraries or packages that model.py depends on.
- serving.properties: Defines the engine and other DJL Serving configuration. The `option.*` entries are also handed to model.py through the `properties` dict with the `option.` prefix removed (for example, `option.model_id` is read as `properties.get("model_id")`).

In [None]:
%%writefile model.py

import logging
import time
import os
from djl_python import Input
from djl_python import Output
import torch
from transformers import AutoTokenizer, AutoModel
import numpy as np


class BartModel(object):
    """
    Deploying Bart with DJL Serving.

    Loads the Hugging Face model/tokenizer named by ``model_id`` (set via
    ``option.model_id`` in serving.properties) once, then answers each request
    with the model's ``last_hidden_state`` serialized as JSON.
    """

    def __init__(self):
        # Flipped to True by initialize(); handle() checks it so the model is
        # loaded lazily on the first request only.
        self.initialized = False

    def initialize(self, properties: dict):
        """
        Initialize model.

        :param properties: DJL Serving properties. ``model_id`` selects the
            checkpoint (falls back to ``facebook/bart-large`` when absent);
            ``task`` is stored and logged.
        """
        logging.info("Working directory contents: %s", os.listdir())
        logging.info("-----------------")
        logging.info(properties)

        # Load the checkpoint exactly once, from the configured model_id.
        # (Previously facebook/bart-large was also loaded into unused locals,
        # doubling startup time and memory.)
        self.model_name = properties.get("model_id") or "facebook/bart-large"
        self.task = properties.get("task")
        logging.info("-----------------")
        logging.info(self.model_name)
        logging.info("-----------------")
        logging.info(self.task)
        self.model = AutoModel.from_pretrained(self.model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.initialized = True

    def inference(self, inputs):
        """
        Custom service entry point function.

        :param inputs: the Input object holding the raw request text for the
            BART model to infer upon
        :return: an Output carrying ``last_hidden_state`` as nested JSON lists,
            or an error Output if any step fails
        """
        # sample input: "This is the sample text that I am passing in"
        try:
            data = inputs.get_as_string()
            logging.info("-----------------")
            logging.info(data)
            logging.info(type(data))
            logging.info("-----------------")
            encoded = self.tokenizer(data, return_tensors="pt")
            # Pure inference: skip autograd bookkeeping so no activations are
            # retained for a backward pass that never happens.
            with torch.no_grad():
                preds = self.model(**encoded)
            logging.info("-----------------")
            logging.info(type(preds))
            logging.info("-----------------")
            # Convert the tensor to nested Python lists so it is JSON serializable.
            res = preds.last_hidden_state.cpu().numpy().tolist()
            outputs = Output()
            outputs.add_as_json(res)
        except Exception as e:
            # Top-level request boundary: log the traceback and report the
            # failure to the client instead of crashing the worker.
            logging.exception("inference failed")
            outputs = Output().error(str(e))

        logging.info(outputs)
        logging.info(type(outputs))
        logging.info("Returning inference---------")
        return outputs


_service = BartModel()


def handle(inputs: Input):
    """
    Default handler function called by DJL Serving for each request.

    On the first call, loads the model using the request's properties (the
    values from serving.properties). Returns None for an empty input;
    otherwise returns the Output produced by ``BartModel.inference``.
    """
    service = _service
    if not service.initialized:
        # stateful model: load once, reuse for every later request
        service.initialize(inputs.get_properties())
    # NOTE(review): empty inputs presumably come from DJL's startup/warm-up
    # call, which only needs initialization to happen — confirm.
    return None if inputs.is_empty() else service.inference(inputs)

In [None]:
%%writefile requirements.txt
numpy

In [None]:
%%writefile serving.properties
engine=Python
option.model_id=facebook/bart-large
option.task=feature-extraction

### SageMaker Hosting

In [None]:
import sagemaker, boto3
from sagemaker import image_uris
import subprocess
import time
from time import gmtime, strftime

# One boto3 session shared by every client, so they all use the same
# credentials and region.
boto_session = boto3.session.Session()
s3 = boto_session.resource("s3")
client = boto_session.client(service_name="sagemaker")
runtime = boto_session.client(service_name="sagemaker-runtime")

instance_type = "ml.g5.12xlarge"  # instance type for the hosted endpoint (not this notebook)
role = sagemaker.get_execution_role()  # execution role for the endpoint
# sagemaker session for interacting with different AWS APIs
session = sagemaker.session.Session(boto_session=boto_session)
region = session.boto_region_name  # public attribute; _region_name is private
bucket = session.default_bucket()  # bucket to house artifacts

img_uri = image_uris.retrieve(framework="djl-deepspeed", region=region, version="0.21.0")
img_uri

In [None]:
# Build tar file with model data + inference code.
# Capture stdout/stderr and check the exit status: a failed tar (e.g. a
# missing artifact) must stop here instead of uploading a stale or partial
# model.tar.gz in the next cell.
bash_command = ["tar", "-cvpzf", "model.tar.gz", "model.py", "requirements.txt", "serving.properties"]
process = subprocess.run(bash_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
output, error = process.stdout, process.stderr
if process.returncode != 0:
    raise RuntimeError(
        f"tar failed with exit code {process.returncode}: {error.decode(errors='replace')}"
    )

In [None]:
# Upload the packaged tarball to the session's default bucket. The S3 URI in
# model_artifacts is what create_model hands to the serving container.
tarball = "model.tar.gz"
model_artifacts = "s3://{}/{}".format(bucket, tarball)
response = s3.meta.client.upload_file(tarball, bucket, tarball)

In [None]:
# Display the S3 URI of the uploaded model tarball.
model_artifacts

In [None]:
# Timestamp suffix keeps the model name unique across notebook runs.
model_name = "djl-bart" + strftime("%Y-%m-%d-%H-%M-%S", gmtime())
print("Model name: " + model_name)

# Point the DJL container image at the uploaded artifacts.
container = {"Image": img_uri, "ModelDataUrl": model_artifacts}
create_model_response = client.create_model(
    ModelName=model_name,
    ExecutionRoleArn=role,
    PrimaryContainer=container,
)
model_arn = create_model_response["ModelArn"]
print("Model Arn: " + model_arn)

In [None]:
endpoint_config_name = "djl-bart" + strftime("%Y-%m-%d-%H-%M-%S", gmtime())

# A single variant takes all traffic. The generous download and startup
# health-check timeouts give the container time to fetch the model data and
# come up before SageMaker gives up on it.
variant = dict(
    VariantName="AllTraffic",
    ModelName=model_name,
    InitialInstanceCount=1,
    InstanceType=instance_type,
    ModelDataDownloadTimeoutInSeconds=1800,
    ContainerStartupHealthCheckTimeoutInSeconds=3600,
)
production_variants = [variant]

endpoint_config = dict(
    EndpointConfigName=endpoint_config_name,
    ProductionVariants=production_variants,
)

endpoint_config_response = client.create_endpoint_config(**endpoint_config)
config_arn = endpoint_config_response["EndpointConfigArn"]
print("Endpoint Configuration Arn: " + config_arn)

In [None]:
endpoint_name = "djl-bart" + strftime("%Y-%m-%d-%H-%M-%S", gmtime())
# Start endpoint creation from the config above; the next cell polls until
# the endpoint leaves the "Creating" state.
create_endpoint_response = client.create_endpoint(
    EndpointName=endpoint_name, EndpointConfigName=endpoint_config_name
)
endpoint_arn = create_endpoint_response["EndpointArn"]
print("Endpoint Arn: " + endpoint_arn)

In [None]:
# Poll every 15 s while the endpoint is still being created. Sleeping before
# each re-check (rather than after it) avoids an extra 15 s wait once the
# status has already changed.
describe_endpoint_response = client.describe_endpoint(EndpointName=endpoint_name)
while describe_endpoint_response["EndpointStatus"] == "Creating":
    time.sleep(15)
    describe_endpoint_response = client.describe_endpoint(EndpointName=endpoint_name)
    print(describe_endpoint_response["EndpointStatus"])
print(describe_endpoint_response)

# Fail loudly here (e.g. status "Failed") instead of letting the sample
# invocation below fail with a less informative error.
endpoint_status = describe_endpoint_response["EndpointStatus"]
if endpoint_status != "InService":
    raise RuntimeError(
        f"Endpoint {endpoint_name} finished in status {endpoint_status}: "
        f"{describe_endpoint_response.get('FailureReason', 'no failure reason reported')}"
    )

### Sample Inference

In [None]:
import json  # was never imported in this notebook; json.loads raised NameError

# Send raw text; model.py reads the request body with get_as_string().
response = runtime.invoke_endpoint(
    EndpointName=endpoint_name,
    ContentType="text/plain",
    Body="I think my dog is really cute!",
)
# The handler replies with last_hidden_state encoded as JSON nested lists.
result = json.loads(response["Body"].read().decode())
print(result)