## DJL MME SKLearn Regression Example

In this example we will train a dummy SKLearn Regression model, make 1,000 copies of it, and use those copies to scale up a sample Multi-Model Endpoint (MME) with DJL Serving as the backend on a CPU-based instance.

### Imports

In [None]:
import sagemaker
from sagemaker import image_uris
import boto3
import os
import time
import json
from pathlib import Path
import boto3
import json
import os
import joblib
import pickle
import tarfile
import sagemaker
from sagemaker.estimator import Estimator
import time
from time import gmtime, strftime
import subprocess

# Resolve the execution context once: IAM role, SageMaker session, default
# artifact bucket, and the region/account the session is bound to.
role = sagemaker.get_execution_role()  # execution role for the endpoint
sess = sagemaker.session.Session()  # sagemaker session for interacting with different AWS APIs
bucket = sess.default_bucket()  # bucket to house artifacts
# Use the public accessor instead of the private ``_region_name`` attribute,
# which is an SDK implementation detail and may change without notice.
region = sess.boto_region_name
account_id = sess.account_id()
s3_model_prefix = "djl-mme-sklearn-regression"  # key prefix used for every uploaded model tarball

# Low-level clients: S3 for artifact uploads, SageMaker for model/endpoint
# management calls, SageMaker Runtime for invoking the endpoint.
s3_client = boto3.client("s3")
sm_client = boto3.client("sagemaker")
smr_client = boto3.client("sagemaker-runtime")

### Local Model Training

We will train a sample Linear Regression model on some dummy numpy data to create the joblib artifact we need for inference.

In [None]:
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression

In [None]:
# Synthetic one-feature regression problem: y = 2x + 1 plus a small noise term.
np.random.seed(0)  # fixed seed so the dummy data is reproducible
X = np.random.rand(100, 1)
y = 2 * X + 1 + 0.1 * np.random.randn(100, 1)

# Hold out 20% of the samples as a test split.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.2,
    random_state=42,
)

# Build the estimator, then fit it on the training split only.
model = LinearRegression()
model.fit(X_train, y_train)

In [None]:
# Persist the fitted estimator to disk; this file is what gets packaged
# into the model tarball later on.
import joblib

model_filename = "model.joblib"
joblib.dump(value=model, filename=model_filename)

In [None]:
# Load the artifact back from disk; despite the name, this is the
# deserialized estimator object, which the next cell uses for a sample
# prediction.
serialized_model = joblib.load(model_filename)

In [None]:
# Smoke-test the reloaded model locally with one single-feature sample.
payload = [[0.5]]
predictions = serialized_model.predict(payload)
res = predictions.tolist()[0]  # keep only the first row of the prediction
res

### DJL Artifact Creation

We now have our model artifact, but DJL Serving also needs the following files:

- model.py: Inference script with custom model loading and pre/post-processing code
- requirements.txt: Additional dependencies; in this case we need to install scikit-learn and numpy
- serving.properties: Configuration properties for DJL Serving; settings such as the number of workers can be adjusted here.

In [None]:
%%writefile model.py
#!/usr/bin/env python
#
# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file
# except in compliance with the License. A copy of the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.

import logging
import numpy as np
import time
import os
import joblib
from djl_python import Input
from djl_python import Output


class SKLearnRegressor(object):
    """
    Lazily-loaded wrapper around a joblib-serialized scikit-learn model.

    The instance starts uninitialized; ``initialize`` loads ``model.joblib``
    and ``inference`` runs ``predict`` on the JSON body of a request.
    """

    def __init__(self):
        self.initialized = False
        # Declared up front so the attribute always exists, even before
        # ``initialize`` has been called.
        self.model = None

    def initialize(self, properties: dict):
        """
        Initialize model.

        The artifact is looked up first in the directory named by the
        ``model_dir`` entry of ``properties`` (when present) and then in the
        current working directory, which is the only place the lookup used
        to happen.

        :param properties: handler properties; may be empty or None
        :raises ValueError: if no ``model.joblib`` artifact can be found
        """
        print(os.listdir())

        # NOTE(review): DJL Serving is expected to pass the model directory
        # as ``model_dir`` in the properties -- confirm for the container
        # version in use. The working-directory fallback keeps the previous
        # behavior when the property is absent.
        search_dirs = []
        model_dir = (properties or {}).get("model_dir")
        if model_dir:
            search_dirs.append(model_dir)
        search_dirs.append(os.curdir)

        for directory in search_dirs:
            model_path = os.path.join(directory, "model.joblib")
            if os.path.exists(model_path):
                self.model = joblib.load(model_path)
                break
        else:
            raise ValueError("Expecting a model.joblib artifact for SKLearn Model Loading")
        self.initialized = True

    def inference(self, inputs):
        """
        Custom service entry point function.

        :param inputs: the Input object whose body is JSON, e.g. ``[[0.5]]``
        :return: the Output object to be sent back; it carries the first row
            of the prediction on success, or an error built from the
            exception message on failure
        """

        #sample input: [[0.5]]

        try:
            data = inputs.get_as_json()
            print(data)
            print(type(data))
            # Only the first row of the prediction is returned, so any
            # additional rows in the request are dropped.
            res = self.model.predict(data).tolist()[0]
            outputs = Output()
            outputs.add_as_json(res)
        except Exception as e:
            # Request boundary: log the full traceback and report the failure
            # to the caller instead of letting the exception propagate.
            logging.exception("inference failed")
            # error handling
            outputs = Output().error(str(e))

        print(outputs)
        print(type(outputs))
        print("Returning inference---------")
        return outputs


# Module-level singleton so the model is loaded once and reused by every
# subsequent call to ``handle``.
_service = SKLearnRegressor()


def handle(inputs: Input):
    """
    Default handler function.

    Initializes the shared service on first use, then either returns None
    for an empty request or delegates to ``SKLearnRegressor.inference``.

    :param inputs: the incoming Input object
    :return: the Output produced by inference, or None for an empty request
    """
    if not _service.initialized:
        # First call: load the model using the properties carried by the
        # request.
        _service.initialize(inputs.get_properties())

    # An empty request has nothing to score, so no Output is produced.
    return None if inputs.is_empty() else _service.inference(inputs)

In [None]:
%%writefile requirements.txt
numpy==1.22.3
joblib
scikit-learn==0.23.2

In [None]:
%%writefile serving.properties
engine=Python
# idle time in seconds before the worker thread is scaled down; the DJL Serving default is 60
max_idle_time=600

### Tarball Creation

In [None]:
# Build tar file with model data + inference code, replace this cell with your model.joblib
bashCommand = "tar -cvpzf model.tar.gz model.joblib requirements.txt model.py serving.properties"
# check=True raises CalledProcessError when tar exits non-zero, so a failed
# build (e.g. a missing input file) stops the notebook here instead of letting
# a missing or stale model.tar.gz be uploaded by the following cells.
# stderr is captured as well so `error` actually holds tar's diagnostics.
process = subprocess.run(
    bashCommand.split(),
    stdout=subprocess.PIPE,
    stderr=subprocess.PIPE,
    check=True,
)
output, error = process.stdout, process.stderr

In [None]:
%%time
# Upload 1000 copies of the same tarball as stand-ins for distinct models;
# replace this with your own per-model tarballs.
for copy_idx in range(1000):
    object_key = f"{s3_model_prefix}/sklearn-{copy_idx}.tar.gz"
    with open("model.tar.gz", "rb") as tarball:
        s3_client.upload_fileobj(tarball, bucket, object_key)

In [None]:
# S3 prefix (with trailing slash) under which every model tarball was uploaded.
mme_artifacts = f"s3://{bucket}/{s3_model_prefix}/"
mme_artifacts

In [None]:
# verify that all 1000 tarballs are present
!aws s3 ls {mme_artifacts}

### MME Model Creation

In [None]:
# Replace this with your own ECR image URI for your region; we are using the CPU image here.
# NOTE(review): the value below is hard-coded to a specific account ID and to
# us-east-1, and it uses the mutable `latest` tag -- consider pinning an
# explicit image version.
inference_image_uri = '474422712127.dkr.ecr.us-east-1.amazonaws.com/djl-serving-cpu:latest'

In [None]:
# Step 1: register a SageMaker Model in MultiModel mode whose data URL is the
# S3 prefix holding all tarballs. The UTC timestamp keeps the name unique.
model_suffix = strftime("%Y-%m-%d-%H-%M-%S", gmtime())
mme_model_name = f"sklearn-djl-mme{model_suffix}"
print(f"Model name: {mme_model_name}")

primary_container = {
    "Image": inference_image_uri,
    "Mode": "MultiModel",
    "ModelDataUrl": mme_artifacts,
}
create_model_response = sm_client.create_model(
    ModelName=mme_model_name,
    ExecutionRoleArn=role,
    PrimaryContainer=primary_container,
)
model_arn = create_model_response["ModelArn"]

print(f"Created Model: {model_arn}")

### MME Endpoint Config Creation

In [None]:
# Step 2: endpoint configuration (EPC) creation.
# NOTE(review): 20 x ml.c5d.18xlarge is a sizeable fleet for a sample --
# adjust the instance type and count to your needs.
epc_suffix = strftime("%Y-%m-%d-%H-%M-%S", gmtime())
mme_epc_name = f"sklearn-djl-mme-epc{epc_suffix}"

sklearn_variant = {
    "VariantName": "sklearnvariant",
    "ModelName": mme_model_name,
    "InstanceType": "ml.c5d.18xlarge",
    "InitialInstanceCount": 20,
}
endpoint_config_response = sm_client.create_endpoint_config(
    EndpointConfigName=mme_epc_name,
    ProductionVariants=[sklearn_variant],
)
print(f"Endpoint Configuration Arn: {endpoint_config_response['EndpointConfigArn']}")

### MME Endpoint Creation

In [None]:
# Step 3: endpoint (EP) creation from the endpoint configuration above.
endpoint_suffix = strftime("%Y-%m-%d-%H-%M-%S", gmtime())
mme_endpoint_name = f"sklearn-djl-ep-mme{endpoint_suffix}"

create_endpoint_response = sm_client.create_endpoint(
    EndpointConfigName=mme_epc_name,
    EndpointName=mme_endpoint_name,
)
print(f"Endpoint Arn: {create_endpoint_response['EndpointArn']}")

In [None]:
# Monitor creation: poll every 15 seconds until the endpoint leaves the
# "Creating" state.
describe_endpoint_response = sm_client.describe_endpoint(EndpointName=mme_endpoint_name)
while describe_endpoint_response["EndpointStatus"] == "Creating":
    # Sleep *before* re-polling so the loop exits as soon as a different
    # status is observed, rather than waiting an extra 15 seconds afterwards.
    time.sleep(15)
    describe_endpoint_response = sm_client.describe_endpoint(EndpointName=mme_endpoint_name)
    print(describe_endpoint_response["EndpointStatus"])
print(describe_endpoint_response)

# Fail fast here instead of letting the invoke cells fail later with a less
# helpful error when the endpoint did not come up.
if describe_endpoint_response["EndpointStatus"] != "InService":
    raise RuntimeError(
        "Endpoint {} ended in status {}: {}".format(
            mme_endpoint_name,
            describe_endpoint_response["EndpointStatus"],
            describe_endpoint_response.get("FailureReason", "no failure reason reported"),
        )
    )

### Sample Inference

In [None]:
# Request sent to the endpoint: a JSON array holding one single-feature sample.
import json

content_type = "application/json"
request_body = json.dumps([[0.5]])  # replace with your request body

In [None]:
# Invoke one specific model behind the multi-model endpoint; TargetModel is
# the tarball's key relative to the model data prefix.
response = smr_client.invoke_endpoint(
    EndpointName=mme_endpoint_name,
    TargetModel="sklearn-902.tar.gz",
    ContentType=content_type,
    Body=request_body,
)
raw_body = response["Body"].read()
result = json.loads(raw_body.decode())
print(result)

In [None]:
%%time

# Send 100 back-to-back requests to the same target model so the cell's
# %%time output reflects repeated invocations.
for _ in range(100):
    response = smr_client.invoke_endpoint(
        EndpointName=mme_endpoint_name,
        TargetModel="sklearn-902.tar.gz",
        ContentType=content_type,
        Body=request_body,
    )

### Cleanup

In [None]:
# Tear down every SageMaker resource this notebook created, not just the
# endpoint: the endpoint configuration and the model would otherwise be left
# behind in the account.
sm_client.delete_endpoint(EndpointName=mme_endpoint_name)
sm_client.delete_endpoint_config(EndpointConfigName=mme_epc_name)
sm_client.delete_model(ModelName=mme_model_name)
# NOTE: the model tarballs uploaded to S3 are not removed here; delete them
# separately if they are no longer needed.