# Pre-Trained SKLearn Model Deployment on SageMaker Real-Time Endpoints

In this sample we take a dummy SKLearn regression model and show how to deploy it to a [SageMaker Real-Time Endpoint](https://docs.aws.amazon.com/sagemaker/latest/dg/realtime-endpoints.html) using the [Boto3 AWS Python SDK](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html) together with the higher-level [SageMaker Python SDK](https://github.com/aws/sagemaker-python-sdk).

## Setup
We work on an ml.c5.large instance in SageMaker Studio using JupyterLab. First, we install the SDKs we use to interact with SageMaker, along with scikit-learn for training a dummy model locally.

In [None]:
!pip install -U sagemaker boto3 scikit-learn

In [None]:
import json
import subprocess
import time
from time import gmtime, strftime

import boto3
import joblib
import sagemaker

role = sagemaker.get_execution_role()  # execution role for the endpoint
sess = sagemaker.session.Session()  # sagemaker session for interacting with different AWS APIs
bucket = sess.default_bucket()  # bucket to house artifacts
region = sess.boto_region_name  # public attribute instead of the private _region_name
account_id = sess.account_id()
s3_model_prefix = "djl-sme-sklearn-regression" 

s3_client = boto3.client("s3")
sm_client = boto3.client("sagemaker")
smr_client = boto3.client("sagemaker-runtime")

## Sample Local Model Training
Here we generate some artificial data, train a SKLearn `LinearRegression` model on it, and save the trained model as a joblib file, which is the model artifact we deploy.

In [None]:
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression

In [None]:
# Generate dummy data
np.random.seed(0)
X = np.random.rand(100, 1)
y = 2 * X + 1 + 0.1 * np.random.randn(100, 1)  

# Split the data into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Create a Linear Regression model
model = LinearRegression()

# Train the model on the training data
model.fit(X_train, y_train)
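Because the dummy data is generated as y = 2x + 1 plus Gaussian noise, a fitted model should recover a slope close to 2 and an intercept close to 1. `LinearRegression` solves ordinary least squares with an intercept, so the sketch below reproduces that check with NumPy alone by running `np.linalg.lstsq` on a design matrix with an extra column of ones. It regenerates the data with the same seed so it runs on its own, and it fits on all 100 points rather than the 80-point training split, so the numbers will differ slightly from the model above.

```python
import numpy as np

# Regenerate the same dummy data as above: y = 2x + 1 + noise
np.random.seed(0)
X = np.random.rand(100, 1)
y = 2 * X + 1 + 0.1 * np.random.randn(100, 1)

# Ordinary least squares with an intercept: append a column of ones to X
design = np.hstack([X, np.ones_like(X)])
coef, *_ = np.linalg.lstsq(design, y, rcond=None)
slope, intercept = coef.ravel()

print(f"slope={slope:.3f}, intercept={intercept:.3f}")
print("prediction at x=0.5:", slope * 0.5 + intercept)
```

The prediction at x=0.5 should be close to 2 * 0.5 + 1 = 2.0, which is also what the sample inference and the endpoint should return.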

In [None]:
# Save the trained model to a file
import joblib
model_filename = "model.joblib"
joblib.dump(model, model_filename)

In [None]:
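# Reload the saved artifact to confirm it deserializes correctly
loaded_model = joblib.load(model_filename)

In [None]:
# Sample local inference; the endpoint will receive a payload of the same shape
payload = [[0.5]]
res = loaded_model.predict(payload).tolist()[0]
res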

## SageMaker Artifact Setup
SageMaker expects a model.tar.gz containing the model data/weights and any inference scripts. Here we write our inference script in the format that our model server/container, [DJL Serving](https://github.com/deepjavalibrary/djl-serving/tree/master), expects. Each model server has its own conventions for how the artifacts must be packaged, so check the documentation for the server you use.

### Inference Script Creation
This script is where we customize model loading and pre/post-processing. With DJL Serving's Python engine, the model server loads `model.py` and calls its `handle` function for each request, so `handle` is the one function we must implement.

In [None]:
%%writefile model.py
#!/usr/bin/env python
#
# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file
# except in compliance with the License. A copy of the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.

import logging
import os

import joblib
from djl_python import Input
from djl_python import Output


class SKLearnRegressor:
    def __init__(self):
        self.initialized = False
        self.model = None

    def initialize(self, properties: dict):
        """
        Load the model.joblib artifact from the model directory.
        """
        # DJL Serving passes the extracted model directory as "model_dir"; fall back to the CWD
        model_dir = properties.get("model_dir", ".")
        model_path = os.path.join(model_dir, "model.joblib")
        if not os.path.exists(model_path):
            raise ValueError(f"Expecting a model.joblib artifact for SKLearn model loading at {model_path}")
        self.model = joblib.load(model_path)
        self.initialized = True

    def inference(self, inputs: Input) -> Output:
        """
        Custom service entry point function.

        :param inputs: the Input object holding the request payload (a JSON list of feature rows)
        :return: the Output object to be sent back
        """
        # sample input: [[0.5]]
        try:
            data = inputs.get_as_json()
            logging.info("Received payload: %s", data)
            # Only the prediction for the first row is returned
            res = self.model.predict(data).tolist()[0]
            outputs = Output()
            outputs.add_as_json(res)
        except Exception as e:
            logging.exception("inference failed")
            # Return the error message to the caller
            outputs = Output().error(str(e))
        return outputs


_service = SKLearnRegressor()


def handle(inputs: Input):
    """
    Default handler function called by DJL Serving for each request.
    """
    if not _service.initialized:
        # Load the model once, on the first call
        _service.initialize(inputs.get_properties())

    if inputs.is_empty():
        # Nothing to predict on (e.g. a warm-up call without a payload)
        return None

    return _service.inference(inputs)
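Before packaging, it can help to see exactly what the endpoint will return for a given request body. The sketch below mirrors the parse/predict/serialize path of `inference()` without DJL Serving: `json.loads` stands in for `inputs.get_as_json()` and `json.dumps` for `outputs.add_as_json()`. `StubRegressor` and `run_inference` are hypothetical helpers; the stub hard-codes y = 2x + 1 and, like a `LinearRegression` fit on a 2-D `y`, returns one output row per input row.

```python
import json

import numpy as np


class StubRegressor:
    """Hypothetical stand-in for the fitted LinearRegression (y = 2x + 1)."""

    def predict(self, X):
        X = np.asarray(X, dtype=float)
        return 2.0 * X + 1.0  # shape (n_rows, 1)


def run_inference(model, request_body: bytes) -> bytes:
    """Mimic model.py: decode JSON, predict, keep the first row, encode JSON."""
    data = json.loads(request_body)           # like inputs.get_as_json()
    res = model.predict(data).tolist()[0]     # first row only, as in model.py
    return json.dumps(res).encode()           # like outputs.add_as_json(res)


print(run_inference(StubRegressor(), b"[[0.5]]"))  # b'[2.0]'
```

Because the handler keeps only the first row, sending several rows in one request still yields a single prediction.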

In [None]:
%%writefile requirements.txt
numpy
joblib
scikit-learn
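A joblib artifact is a pickle, and scikit-learn does not guarantee that a model saved with one version loads correctly in another, so it is safer to pin the container's packages to the versions used for training. A minimal sketch, assuming you run it in the same kernel where the model was trained; `pinned_requirements` is a hypothetical helper built on the standard library's `importlib.metadata`, and it overwrites the `requirements.txt` written above. The pinned versions also need to be installable on the container's Python version.

```python
from importlib.metadata import PackageNotFoundError, version


def pinned_requirements(packages):
    """Render requirements.txt lines pinned to the locally installed versions."""
    lines = []
    for pkg in packages:
        try:
            lines.append(f"{pkg}=={version(pkg)}")
        except PackageNotFoundError:
            # Not installed in this environment: leave the requirement unpinned
            lines.append(pkg)
    return "\n".join(lines) + "\n"


requirements = pinned_requirements(["numpy", "joblib", "scikit-learn"])
with open("requirements.txt", "w") as f:
    f.write(requirements)
print(requirements)
```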

In [None]:
%%writefile serving.properties
engine=Python

In [None]:
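# Package the model artifact and inference code into model.tar.gz (all files at the archive root).
# If you bring your own model, replace model.joblib with your artifact.
bash_command = ["tar", "-cvpzf", "model.tar.gz", "model.joblib", "requirements.txt", "model.py", "serving.properties"]
# check=True raises CalledProcessError if tar fails (e.g. a missing file) instead of failing silently
result = subprocess.run(bash_command, capture_output=True, text=True, check=True)
print(result.stdout)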
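If `tar` is unavailable, or you want to confirm the archive layout, Python's standard `tarfile` module can build the same archive. This is an alternative to the shell command, not what the notebook uses; `build_model_tarball` is a hypothetical helper that places every file at the archive root (the same layout the `tar` command produces) and returns the member names for a quick check.

```python
import tarfile
from pathlib import Path


def build_model_tarball(output_path, files):
    """Write files into a gzipped tarball at the archive root and return the member names."""
    with tarfile.open(output_path, "w:gz") as tar:
        for file in files:
            # arcname drops any directory prefix so each file sits at the top level
            tar.add(file, arcname=Path(file).name)
    with tarfile.open(output_path, "r:gz") as tar:
        return tar.getnames()
```

For this notebook, `build_model_tarball("model.tar.gz", ["model.joblib", "requirements.txt", "model.py", "serving.properties"])` should list exactly those four names.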

In [None]:
# upload model data to S3
with open("model.tar.gz", "rb") as f:
    s3_client.upload_fileobj(f, bucket, "{}/model.tar.gz".format(s3_model_prefix))

In [None]:
sme_artifacts = "s3://{}/{}/{}".format(bucket, s3_model_prefix, "model.tar.gz")
sme_artifacts

### Container Specification
This is where you specify the container/model server for your model. In this case we use the DJL CPU-based image, since we are serving a small model that runs well on CPU. For a list of all AWS-managed images, see [available_images.md](https://github.com/aws/deep-learning-containers/blob/master/available_images.md). You can also bring your own container with your own serving logic implemented; here's a [sample of that](https://github.com/RamVegiraju/SageMaker-Deployment/tree/master/RealTime/BYOC/PreTrained-Examples/SpacyNER).

In [None]:
# Replace this with the DJL image URI for your region; this is the us-east-1 CPU image
inference_image_uri = '763104351884.dkr.ecr.us-east-1.amazonaws.com/djl-inference:0.29.0-cpu-full'
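To avoid hard-coding us-east-1, the URI can be assembled from the session's `region` variable. This sketch rests on an assumption: that your region hosts the AWS Deep Learning Containers under account `763104351884`, which holds for most commercial regions but not all (check the available images list linked above for yours). `djl_cpu_image_uri` is a hypothetical helper, and the `0.29.0-cpu-full` tag is simply the one this notebook uses.

```python
def djl_cpu_image_uri(region: str, tag: str = "0.29.0-cpu-full", account: str = "763104351884") -> str:
    """Build a DJL inference image URI for a region (assumes the standard DLC account)."""
    return f"{account}.dkr.ecr.{region}.amazonaws.com/djl-inference:{tag}"


print(djl_cpu_image_uri("us-east-1"))
```

In the notebook you would then set `inference_image_uri = djl_cpu_image_uri(region)`.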

## SageMaker Constructs
An endpoint is built from three SageMaker constructs; each one below links to its respective API call:

1. [SageMaker Model Object](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker/client/create_model.html): Points to the model data (model.tar.gz) and the container image
2. [SageMaker Endpoint Configuration](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker/client/create_endpoint_config.html): Specifies the hardware and any production variants
3. [SageMaker Endpoint](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker/client/create_endpoint.html): The persistent REST endpoint that you can invoke and attach scaling policies to
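The three constructs reference each other by name: the endpoint points to an endpoint configuration, whose production variants each point to a model. A sketch of walking that chain with the `describe_endpoint` and `describe_endpoint_config` calls, handy for checking what an existing endpoint is serving; `trace_endpoint` is a hypothetical helper that takes the `sm_client` created in Setup.

```python
def trace_endpoint(sm_client, endpoint_name: str) -> dict:
    """Follow endpoint -> endpoint config -> model names using describe calls."""
    endpoint = sm_client.describe_endpoint(EndpointName=endpoint_name)
    config_name = endpoint["EndpointConfigName"]
    config = sm_client.describe_endpoint_config(EndpointConfigName=config_name)
    # Each production variant references one SageMaker Model by name
    model_names = [variant["ModelName"] for variant in config["ProductionVariants"]]
    return {"endpoint": endpoint_name, "endpoint_config": config_name, "models": model_names}
```

Once the steps below have run, `trace_endpoint(sm_client, sme_endpoint_name)` should report `sme_epc_name` and `[sme_model_name]`.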

### SageMaker Model Creation

In [None]:
#Step 1: Model Creation
sme_model_name = "sklearn-djl-sme" + strftime("%Y-%m-%d-%H-%M-%S", gmtime())
print("Model name: " + sme_model_name)

create_model_response = sm_client.create_model(
    ModelName=sme_model_name,
    ExecutionRoleArn=role,
    PrimaryContainer={"Image": inference_image_uri, "Mode": "SingleModel", "ModelDataUrl": sme_artifacts},
)
model_arn = create_model_response["ModelArn"]

print(f"Created Model: {model_arn}")

### SageMaker Endpoint Config Creation

In [None]:
#Step 2: EPC Creation
sme_epc_name = "sklearn-djl-sme-epc" + strftime("%Y-%m-%d-%H-%M-%S", gmtime())
endpoint_config_response = sm_client.create_endpoint_config(
    EndpointConfigName=sme_epc_name,
    ProductionVariants=[
        {
            "VariantName": "sklearnvariant",
            "ModelName": sme_model_name,
            "InstanceType": "ml.c5.xlarge",
            "InitialInstanceCount": 1
        },
    ],
)
print("Endpoint Configuration Arn: " + endpoint_config_response["EndpointConfigArn"])

### SageMaker Endpoint Creation
This step can take a few minutes while the endpoint resources are provisioned; the exact time varies with the instance type behind the endpoint.

In [None]:
#Step 3: EP Creation
sme_endpoint_name = "sklearn-djl-ep-sme" + strftime("%Y-%m-%d-%H-%M-%S", gmtime())
create_endpoint_response = sm_client.create_endpoint(
    EndpointName=sme_endpoint_name,
    EndpointConfigName=sme_epc_name,
)
print("Endpoint Arn: " + create_endpoint_response["EndpointArn"])

In [None]:
#Monitor creation
describe_endpoint_response = sm_client.describe_endpoint(EndpointName=sme_endpoint_name)
while describe_endpoint_response["EndpointStatus"] == "Creating":
    describe_endpoint_response = sm_client.describe_endpoint(EndpointName=sme_endpoint_name)
    print(describe_endpoint_response["EndpointStatus"])
    time.sleep(15)
print(describe_endpoint_response)
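The loop above exits as soon as the status leaves `Creating`, including when creation ends in `Failed`. A slightly more defensive sketch, assuming you want a timeout and a readable error: `wait_for_endpoint` is a hypothetical helper that polls `describe_endpoint` until `InService` and raises with the `FailureReason` field if creation fails. boto3 also provides a built-in waiter for this case, `sm_client.get_waiter("endpoint_in_service")`.

```python
import time


def wait_for_endpoint(sm_client, endpoint_name, poll_seconds=15, timeout_seconds=1800):
    """Poll until the endpoint is InService; raise on Failed or on timeout."""
    deadline = time.monotonic() + timeout_seconds
    while True:
        response = sm_client.describe_endpoint(EndpointName=endpoint_name)
        status = response["EndpointStatus"]
        print(status)
        if status == "InService":
            return response
        if status == "Failed":
            reason = response.get("FailureReason", "no reason given")
            raise RuntimeError(f"Endpoint {endpoint_name} failed: {reason}")
        if time.monotonic() >= deadline:
            raise TimeoutError(f"Endpoint {endpoint_name} still {status} after {timeout_seconds}s")
        time.sleep(poll_seconds)
```

Usage: `wait_for_endpoint(sm_client, sme_endpoint_name)`.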

## Sample Invocation
We use the boto3 SageMaker Runtime client (a different client from the one we used to create the resources) to invoke the model through the InvokeEndpoint API: https://boto3.amazonaws.com/v1/documentation/api/1.35.9/reference/services/sagemaker-runtime/client/invoke_endpoint.html

In [None]:
import json
content_type = "application/json"
request_body = '[[0.5]]' #replace with your request body

In [None]:
response = smr_client.invoke_endpoint(
    EndpointName=sme_endpoint_name,
    ContentType=content_type,
    Body=request_body)
result = json.loads(response['Body'].read().decode())
print(result)
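Rather than hand-writing the JSON string, the request body can be built with `json.dumps`. A sketch, assuming the `smr_client` from Setup; `predict` is a hypothetical helper. With the handler in `model.py` as written, only the first row's prediction comes back, so send one row per call (or change the handler to return every row).

```python
import json


def predict(smr_client, endpoint_name, rows):
    """Send rows of features as JSON and decode the endpoint's JSON response."""
    response = smr_client.invoke_endpoint(
        EndpointName=endpoint_name,
        ContentType="application/json",
        Body=json.dumps(rows),
    )
    # Body is a streaming object; read it fully before decoding
    return json.loads(response["Body"].read().decode())
```

Usage: `predict(smr_client, sme_endpoint_name, [[0.5]])`.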

## Cleanup
Delete your endpoint when you are done to avoid incurring further costs.

In [None]:
sm_client.delete_endpoint(EndpointName=sme_endpoint_name)
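Deleting the endpoint stops the instance charges, but the endpoint configuration and model objects remain in your account. A sketch for removing them as well, assuming the `sm_client` and names from the cells above; `delete_remaining_resources` is a hypothetical helper that wraps the `delete_endpoint_config` and `delete_model` calls.

```python
def delete_remaining_resources(sm_client, endpoint_config_name, model_name):
    """Delete the endpoint configuration and model left behind after delete_endpoint."""
    sm_client.delete_endpoint_config(EndpointConfigName=endpoint_config_name)
    sm_client.delete_model(ModelName=model_name)
```

Usage: `delete_remaining_resources(sm_client, sme_epc_name, sme_model_name)`.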