# Building a custom inference container
1. [Part 1: Packaging your code for inference with Amazon SageMaker](#Part-1:-Packaging-your-code-for-inference-with-Amazon-SageMaker)
    1. [How Amazon SageMaker runs your Docker container during hosting](#How-Amazon-SageMaker-runs-your-Docker-container-during-hosting)
    1. [The parts of the sample container](#The-parts-of-the-sample-inference-container)
       1. [Creating an inference handler](#Creating-an-inference-handler)
       1. [Implement a handler service](#Implement-a-handler-service)       
       1. [Implement an entrypoint](#Implement-an-entrypoint)              
    1. [The Dockerfile](#The-Dockerfile)
1. [Part 2: Building and registering the container](#Part-2:-Building-and-registering-the-container)
1. [Part 3: Use the container for inference in Amazon SageMaker](#Part-3:-Use-the-container-for-inference-in-Amazon-SageMaker)
    1. [Import model into hosting](#Import-model-into-hosting)
    1. [Create endpoint configuration](#Create-endpoint-configuration)
    1. [Create endpoint](#Create-endpoint)
    1. [Invoke model](#Invoke-model)
1. [(Optional) cleanup](#(Optional)-cleanup)  

## Part 1: Packaging your code for inference with Amazon SageMaker

### How Amazon SageMaker runs your Docker container during hosting

Because you can run the same image in training or hosting, Amazon SageMaker runs your container with the argument `train` or `serve`. How your container processes this argument depends on the container. All SageMaker framework containers already cover this requirement and will trigger your defined training algorithm and inference code.

* If you specify a program as an `ENTRYPOINT` in the Dockerfile, that program will be run at startup and its first argument will be `train` or `serve`. The program can then look at that argument and decide what to do.
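
As a rough illustration of that dispatch, the sketch below inspects the first argument and picks a code path. The names `run_training`, `start_serving`, and `dispatch` are hypothetical placeholders, not part of any SageMaker library.

```python
import sys


def run_training():
    """Hypothetical placeholder for the training routine."""
    return "training"


def start_serving():
    """Hypothetical placeholder for starting the model server."""
    return "serving"


def dispatch(argv):
    """Choose what to run based on the first container argument."""
    if len(argv) < 2 or argv[1] not in ("train", "serve"):
        raise ValueError("expected 'train' or 'serve' as the first argument")
    return run_training() if argv[1] == "train" else start_serving()


def main():
    """Entry point an ENTRYPOINT program could call at startup."""
    return dispatch(sys.argv)
```

An entrypoint program would call `main()`, so that `docker run <image> serve` ends up in `start_serving()`.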

#### Running your container during hosting

Hosting has a very different model than training because hosting responds to inference requests that arrive over HTTP.

Amazon SageMaker uses two URLs in the container:

* `/ping` receives `GET` requests from the infrastructure. Your program returns 200 if the container is up and accepting requests.
* `/invocations` is the endpoint that receives client inference `POST` requests. The format of the request and the response is up to the algorithm. If the client supplied `ContentType` and `Accept` headers, these are passed in as well. 
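
In this tutorial the model server provides both routes for you. Purely to illustrate the contract, here is a minimal standard-library sketch. The `route` function, the `RequestHandler` class, and the placeholder response are assumptions made for illustration, not the way the inference toolkit implements it.

```python
import json
from http.server import BaseHTTPRequestHandler, HTTPServer


def route(method, path, body=b""):
    """Return (status_code, response_bytes) for one request."""
    if method == "GET" and path == "/ping":
        return 200, b""  # healthy: the container accepts requests
    if method == "POST" and path == "/invocations":
        inputs = json.loads(body)
        # Placeholder for a real model call: report how many inputs arrived.
        return 200, json.dumps({"num_inputs": len(inputs)}).encode("utf-8")
    return 404, b""


class RequestHandler(BaseHTTPRequestHandler):
    """Thin HTTP wrapper around route()."""

    def _reply(self, status, payload):
        self.send_response(status)
        self.send_header("Content-Type", "application/json")
        self.send_header("Content-Length", str(len(payload)))
        self.end_headers()
        self.wfile.write(payload)

    def do_GET(self):
        self._reply(*route("GET", self.path))

    def do_POST(self):
        length = int(self.headers.get("Content-Length", 0))
        self._reply(*route("POST", self.path, self.rfile.read(length)))


def serve(port=8080):
    """Block and serve requests; SageMaker sends traffic to port 8080."""
    HTTPServer(("0.0.0.0", port), RequestHandler).serve_forever()
```

Keeping the routing in a plain function makes the behavior easy to check without starting a server.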

If you are using the same container image for both training and serving the model, it will have the model files in the same place that they were written to during training:

    /opt/ml
    `-- model
        `-- <model files>
        
Alternatively, if you are using separate containers for training and inference, the model files are copied from the S3 location that the training container wrote them to when the inference container starts.
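
When debugging a container, it can help to confirm which files actually landed in the model directory. The helper below is a small sketch; the function name `list_model_files` is hypothetical, and the default path is the `/opt/ml/model` location described above.

```python
import os


def list_model_files(model_dir="/opt/ml/model"):
    """Return the relative paths of all files under model_dir, sorted."""
    found = []
    for root, _dirs, files in os.walk(model_dir):
        for name in files:
            full_path = os.path.join(root, name)
            found.append(os.path.relpath(full_path, model_dir))
    return sorted(found)
```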

### The parts of the sample inference container

The `inference_container` directory has all the components you need to extend the inference logic of the SageMaker scikit-learn container:

    .
    |-- Dockerfile
    |-- handler_service.py
    |-- serve.py

Let's discuss each of these in turn:

* __`Dockerfile`__ describes how to build your Docker container image. More details are provided below.
* __`handler_service.py`__ is the program that defines a handler service, together with an inference handler that loads the model, pre-processes input data, gets predictions, and processes the output data.
* __`serve.py`__ is the entrypoint for the application.

In this simple application, we install only the two Python files listed above in the container. That may be all you need, but if you have many supporting routines, you may wish to install more.

### Creating an inference handler

The [SageMaker inference toolkit](https://github.com/aws/sagemaker-inference-toolkit) is built on the multi-model server (MMS). MMS expects a Python script that implements functions to load the model, pre-process input data, get predictions from the model, and process the output data in a model handler.

#### The model_fn Function

The model_fn function is responsible for loading your model. It takes a `model_dir` argument that specifies where the model is stored. How you load your model depends on the framework you are using. There is no default implementation for the `model_fn` function. You must implement it yourself. This is how the `model_fn` function that loads the LightFM model would look:

In [None]:
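def model_fn(self, model_dir):
    import logging
    import os
    import pickle

    logging.getLogger(__name__).info('Loading LightFM model...')
    # Read model.pickle from model_dir rather than the working directory.
    with open(os.path.join(model_dir, "model.pickle"), "rb") as f:
        return pickle.load(f)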

#### The input_fn Function

The `input_fn` function is responsible for deserializing your input data so that it can be passed to your model. It takes input data and content type as parameters, and returns deserialized data. The SageMaker inference toolkit provides a default implementation that deserializes the following content types:

* JSON
* CSV
* Numpy array
* NPZ

If your model requires a different content type, or you want to preprocess your input data before sending it to the model, you must implement the `input_fn` function. The following example shows a simple implementation of the `input_fn` function.

In [None]:
from sagemaker_inference import decoder


def input_fn(self, input_data, content_type):
    """A default input_fn that can handle JSON, CSV and NPZ formats.

    Args:
        input_data: the request payload serialized in the content_type format
        content_type: the request content_type

    Returns: the deserialized input data as a NumPy array
    """
    return decoder.decode(input_data, content_type)
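
If you need custom pre-processing instead of the default decoder, a hand-written `input_fn` might look like the sketch below. It is shown as a standalone function (without `self`) and assumes the payload is a flat list of integer IDs, like the list this notebook sends when invoking the endpoint. The validation rules are an illustrative assumption.

```python
import json


def input_fn(input_data, content_type):
    """Deserialize a flat list of integer IDs from JSON or CSV."""
    if isinstance(input_data, (bytes, bytearray)):
        input_data = input_data.decode("utf-8")
    if content_type == "application/json":
        values = json.loads(input_data)
    elif content_type == "text/csv":
        # Treat newlines like commas and drop empty fields.
        fields = input_data.replace("\n", ",").split(",")
        values = [field for field in fields if field.strip()]
    else:
        raise ValueError("Unsupported content type: {}".format(content_type))
    # Cast every entry to int so the model always receives integer IDs.
    return [int(value) for value in values]
```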

#### The predict_fn Function

The `predict_fn` function is responsible for getting predictions from the model. It takes the data returned from `input_fn` and the model as parameters, and returns the prediction. There is no default implementation for `predict_fn`. You must implement it yourself. The following is a simple implementation of the `predict_fn` function for the LightFM model.

In [None]:
def predict_fn(self, data, model):
    """A simple predict_fn. Calls a model on data deserialized in input_fn.

    Args:
        data: input data (numpy array) for prediction deserialized by input_fn
        model: LightFM model loaded in memory by model_fn

    Returns: a prediction
    """
    # NOTE: LightFM's predict takes user IDs and item IDs, so adapt this
    # call to how your payload encodes them.
    return model.predict(data)

#### The output_fn Function
The output_fn function is responsible for serializing the data that the predict_fn function returns as a prediction. The SageMaker inference toolkit implements a default output_fn function that serializes Numpy arrays, JSON, and CSV. If your model outputs any other content type, or you want to perform other post-processing of your data before sending it to the user, you must implement your own output_fn function. The following shows a simple output_fn function for our model.

In [None]:
from sagemaker_inference import encoder
def output_fn(self, prediction, accept):
        """A default output_fn. Serializes predictions from predict_fn to JSON, CSV or NPY format.

        Args:
            prediction: a prediction result from predict_fn
            accept: type which the output data needs to be serialized

        Returns: output data serialized
        """
        return encoder.encode(prediction, accept)
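
If the default encoder does not fit your output, a hand-written `output_fn` could look like the following sketch. It is shown as a standalone function (without `self`) and assumes the prediction is a flat sequence of numeric scores. The supported types and formatting are illustrative choices.

```python
import json


def output_fn(prediction, accept):
    """Serialize a sequence of scores as JSON or as a single CSV line."""
    scores = [float(score) for score in prediction]
    if accept == "application/json":
        return json.dumps(scores)
    if accept == "text/csv":
        return ",".join(repr(score) for score in scores)
    raise ValueError("Unsupported accept type: {}".format(accept))
```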

### Implement a handler service
The handler service is executed by the model server. The handler service implements initialize and handle methods. The initialize method is invoked when the model server starts, and the handle method is invoked for all incoming inference requests to the model server. For more information, see [Custom Service in the Multi-model server documentation](https://github.com/awslabs/multi-model-server/blob/master/docs/custom_service.md). The following is an example of a handler service for our model server.

In [None]:
from sagemaker_inference.default_handler_service import DefaultHandlerService
from sagemaker_inference.transformer import Transformer
from sagemaker_pytorch_serving_container.default_inference_handler import DefaultPytorchInferenceHandler


class HandlerService(DefaultHandlerService):
    """Handler service that is executed by the model server.
    Determines specific default inference handlers to use based on model being used.
    This class extends ``DefaultHandlerService``, which defines the following:
        - The ``handle`` method is invoked for all incoming inference requests to the model server.
        - The ``initialize`` method is invoked at model server start up.
    Based on: https://github.com/awslabs/mxnet-model-server/blob/master/docs/custom_service.md
    """
    def __init__(self):
        transformer = Transformer(default_inference_handler=DefaultPytorchInferenceHandler())
        super(HandlerService, self).__init__(transformer=transformer)
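
To make the `initialize` and `handle` contract more concrete, here is a simplified, library-free sketch of what a handler service does conceptually. The class `SketchHandlerService` is hypothetical. The `context.system_properties` lookup and the `body`/`data` request keys are assumptions based on the Multi Model Server custom service documentation, and the real `DefaultHandlerService` delegates this work to its `Transformer`.

```python
import json


class SketchHandlerService:
    """Illustrative stand-in for a model server handler service."""

    def __init__(self, model_fn, predict_fn):
        self._model_fn = model_fn
        self._predict_fn = predict_fn
        self._model = None

    def initialize(self, context):
        """Load the model once, when the model server starts."""
        model_dir = context.system_properties.get("model_dir")
        self._model = self._model_fn(model_dir)

    def handle(self, data, context):
        """Run inference for every request in the incoming batch."""
        if self._model is None:
            self.initialize(context)
        responses = []
        for request in data:
            body = request.get("body") or request.get("data")
            inputs = json.loads(body)
            prediction = self._predict_fn(inputs, self._model)
            responses.append(json.dumps(prediction))
        return responses
```

Passing `model_fn` and `predict_fn` in as callables mirrors how the handler functions described earlier plug into the service.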

### Implement an entrypoint
The entrypoint starts the model server by invoking the handler service. You specify the location of the entrypoint in your Dockerfile. The following is an example of an entrypoint.

In [None]:
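from sagemaker_inference import model_server

# Assumption: the handler service module is importable under this name;
# adjust the value to wherever handler_service.py is installed in your image.
HANDLER_SERVICE = "handler_service"

model_server.start_model_server(handler_service=HANDLER_SERVICE)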

### The Dockerfile
In your Dockerfile, copy the handler service into the image and specify the entrypoint script from the previous section as the `ENTRYPOINT`. The following is an example of the lines you can add to your Dockerfile to do both.

In [None]:
# Copy the handler service that processes incoming data and inference requests
COPY handler_service.py /home/model-server/handler_service.py

# Copy the entrypoint script and run it when the container starts
COPY serve.py /usr/local/bin/serve.py
ENTRYPOINT ["python", "/usr/local/bin/serve.py"]

## Part 2: Building and registering the container

Just like with the training container, we are going to use the new [Amazon SageMaker Studio Image Build CLI](https://aws.amazon.com/blogs/machine-learning/using-the-amazon-sagemaker-studio-image-build-cli-to-build-container-images-from-your-studio-notebooks/).

Open a terminal window and run the following command:
```
cd ~/inference_container
sm-docker build . --repository lightfm-inference:1.0
```

## Part 3: Use the container for inference in Amazon SageMaker

In [None]:
import boto3 
def get_container_uri(ecr_repository, tag):
    account_id = boto3.client('sts').get_caller_identity().get('Account')

    region = boto3.session.Session().region_name

    uri_suffix = 'amazonaws.com'
    if region in ['cn-north-1', 'cn-northwest-1']:
        uri_suffix = 'amazonaws.com.cn'

    return '{}.dkr.ecr.{}.{}/{}:{}'.format(account_id, region, uri_suffix, ecr_repository, tag)

### Import model into hosting

When creating the Model entity for the endpoint, the container's `ModelDataUrl` is the S3 location of the model artifact (`model.tar.gz`) that the training job produced.

The `Mode` of the container is specified as `SingleModel` because the container hosts a single model.

In [None]:
import boto3
from sagemaker import get_execution_role
from time import gmtime, strftime

role = get_execution_role()
client = boto3.client(service_name='sagemaker')

byoc_image_uri = get_container_uri('lightfm-inference','1.0')
model_url = 's3://sagemaker-us-east-1-718026778991/light-fm-custom-container-train-job-2021-04-26-10-26-53-215/output/model.tar.gz'
model_name = 'Demo-LightFM-Inference-Model-' + strftime("%Y-%m-%d-%H-%M-%S", gmtime())

container = {
    'Image': byoc_image_uri,
    'ModelDataUrl': model_url,
    'Mode': 'SingleModel'
}

create_model_response = client.create_model(
    ModelName = model_name,
    ExecutionRoleArn = role,
    Containers = [container])

print("Model Arn: " + create_model_response['ModelArn'])
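
SageMaker downloads the archive at `ModelDataUrl` and extracts it into `/opt/ml/model` before the container starts serving. A training job normally produces this archive for you. If you ever need to package a model file by hand, the sketch below shows one way to do it with the standard library. The helper names `package_model` and `list_archive` are hypothetical.

```python
import os
import tarfile


def package_model(model_path, archive_path="model.tar.gz"):
    """Write model_path into a gzip-compressed tar archive, at its root."""
    with tarfile.open(archive_path, "w:gz") as archive:
        archive.add(model_path, arcname=os.path.basename(model_path))
    return archive_path


def list_archive(archive_path):
    """Return the member names stored in the archive."""
    with tarfile.open(archive_path, "r:gz") as archive:
        return archive.getnames()
```

Storing the file at the archive root means it ends up directly under the model directory after extraction, which is where `model_fn` looks for it.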

### Create endpoint configuration

In [None]:
endpoint_config_name = 'DEMO-LightFM-EndpointConfig-' + strftime("%Y-%m-%d-%H-%M-%S", gmtime())
print('Endpoint config name: ' + endpoint_config_name)

create_endpoint_config_response = client.create_endpoint_config(
    EndpointConfigName = endpoint_config_name,
    ProductionVariants=[{
        'InstanceType': 'ml.m5.xlarge',
        'InitialInstanceCount': 1,
        'InitialVariantWeight': 1,
        'ModelName': model_name,
        'VariantName': 'AllTraffic'}])

print("Endpoint config Arn: " + create_endpoint_config_response['EndpointConfigArn'])

### Create endpoint

In [None]:
import time

endpoint_name = 'DEMO-LightFMEndpoint-' + strftime("%Y-%m-%d-%H-%M-%S", gmtime())
print('Endpoint name: ' + endpoint_name)

create_endpoint_response = client.create_endpoint(
    EndpointName=endpoint_name,
    EndpointConfigName=endpoint_config_name)
print('Endpoint Arn: ' + create_endpoint_response['EndpointArn'])

resp = client.describe_endpoint(EndpointName=endpoint_name)
status = resp['EndpointStatus']
print("Endpoint Status: " + status)

print('Waiting for {} endpoint to be in service...'.format(endpoint_name))
waiter = client.get_waiter('endpoint_in_service')
waiter.wait(EndpointName=endpoint_name)
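
The waiter hides a simple polling loop. If you want custom logging or timeouts, a hand-rolled equivalent could look like the sketch below. The function `wait_for_endpoint` and its default `delay` and `max_attempts` values are illustrative assumptions. The `describe` argument is expected to behave like `client.describe_endpoint`.

```python
import time


def wait_for_endpoint(describe, endpoint_name, delay=30, max_attempts=120):
    """Poll the endpoint status until it is InService.

    Raises RuntimeError if the endpoint fails and TimeoutError if it does
    not come into service within max_attempts polls.
    """
    for _ in range(max_attempts):
        status = describe(EndpointName=endpoint_name)["EndpointStatus"]
        if status == "InService":
            return status
        if status == "Failed":
            raise RuntimeError("Endpoint {} failed to deploy".format(endpoint_name))
        time.sleep(delay)
    raise TimeoutError("Endpoint {} is still not in service".format(endpoint_name))
```

With the client from this notebook, the call would be `wait_for_endpoint(client.describe_endpoint, endpoint_name)`.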

### Invoke model

Now we invoke the model that we uploaded to S3 previously in the training step. 

The first invocation of a model may be slow because, behind the scenes, SageMaker downloads the model artifacts from S3 to the instance and loads them into the container.

In [None]:
%%time

import json
import numpy as np

runtime_client = boto3.client(service_name='sagemaker-runtime')

data = np.array([3, 42, 500])
payload = json.dumps(data.tolist())

response = runtime_client.invoke_endpoint(
    EndpointName=endpoint_name,
    ContentType='application/json',
    Body=payload)

print(*json.loads(response['Body'].read()), sep = '\n')

## (Optional) cleanup
When you're done with the endpoint, you should clean it up.

All of the training jobs, models and endpoints we created can be viewed through the SageMaker console of your AWS account, but you can also run the code below to easily clean up the resources.

In [None]:
client.delete_endpoint(EndpointName=endpoint_name)
client.delete_endpoint_config(EndpointConfigName=endpoint_config_name)
client.delete_model(ModelName=model_name)