# Amazon SageMaker Multi-Model Endpoints using PyTorch

> *This notebook works well with SageMaker Studio kernel `Python 3 (Data Science)`, or SageMaker Notebook Instance kernel `conda_python3`*

With [Amazon SageMaker multi-model endpoints](https://docs.aws.amazon.com/sagemaker/latest/dg/multi-model-endpoints.html), customers can create an endpoint that hosts up to thousands of models. These endpoints are well suited to use cases where any one of a large number of models, all servable from a common inference container, needs to be invokable on demand, and where it is acceptable for infrequently invoked models to incur some additional latency. For applications that require consistently low inference latency, a traditional endpoint is still the best choice.

In [None]:
# Python Built-Ins:
import os
import json
import logging
import time

# External Dependencies:
import boto3
import numpy as np
import sagemaker
from sagemaker.multidatamodel import MultiDataModel
from sagemaker.pytorch import PyTorch as PyTorchEstimator, PyTorchModel

smsess = sagemaker.Session()
role = sagemaker.get_execution_role()

# Configuration:
bucket_name = smsess.default_bucket()
prefix = "mnist/"
output_path = f"s3://{bucket_name}/{prefix[:-1]}"

## The example use case: MNIST

MNIST is a widely used dataset for handwritten digit classification. It consists of 70,000 labeled 28x28 pixel grayscale images of hand-written digits. The dataset is split into 60,000 training images and 10,000 test images.

In this example, we copy the MNIST data from a public S3 bucket into your default SageMaker bucket.


In [None]:
def fetch_sample_data(to_bucket, to_prefix, from_bucket="sagemaker-sample-files", from_prefix="datasets/image/MNIST", dataset="mnist-train"):
    DATASETS = {
        "mnist-train": ["train-images-idx3-ubyte.gz", "train-labels-idx1-ubyte.gz"],
        "mnist-test": ["t10k-images-idx3-ubyte.gz", "t10k-labels-idx1-ubyte.gz"],
    }

    if dataset not in DATASETS:
        raise ValueError(f"dataset '{dataset}' not in known set: {set(DATASETS.keys())}")

    if len(from_prefix) and not from_prefix.endswith("/"):
        from_prefix += "/"
    if len(to_prefix) and not to_prefix.endswith("/"):
        to_prefix += "/"

    s3client = boto3.client("s3")
    for key in DATASETS[dataset]:
        s3client.copy_object(
            CopySource={
                "Bucket": from_bucket,
                "Key": f"{from_prefix}{key}",
            },
            Bucket=to_bucket,
            Key=f"{to_prefix}{key}",
        )

train_prefix = f"{prefix}data/train"
fetch_sample_data(to_bucket=bucket_name, to_prefix=train_prefix, dataset="mnist-train")
train_s3uri = f"s3://{bucket_name}/{train_prefix}"
print(f"Uploaded training data to {train_s3uri}")

test_prefix = f"{prefix}data/test"
fetch_sample_data(to_bucket=bucket_name, to_prefix=test_prefix, dataset="mnist-test")
test_s3uri = f"s3://{bucket_name}/{test_prefix}"
print(f"Uploaded test data to {test_s3uri}")

In [None]:
print("Training data:")
!aws s3 ls --recursive $train_s3uri
print("Test data:")
!aws s3 ls --recursive $test_s3uri

## Train multiple models

In the following section, we'll train multiple models on the same dataset, using the SageMaker PyTorch Framework Container.

To keep things simple, we'll create two models `A` and `B`, using the same code but some slightly different hyperparameters.
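
As a minimal sketch of the "same code, slightly different hyperparameters" idea, the snippet below merges per-model overrides into a shared set of defaults. The helper name `merge_hyperparameters` is hypothetical and used only for illustration; the default values mirror the ones used in this notebook.

```python
# Shared defaults, mirroring the values used in this notebook.
DEFAULTS = {
    "batch-size": 128,
    "epochs": 20,
    "learning-rate": 1e-3,
    "log-interval": 100,
}


def merge_hyperparameters(overrides=None):
    """Return a new dict: defaults first, then overrides layered on top."""
    merged = dict(DEFAULTS)         # copy, so DEFAULTS is never mutated
    merged.update(overrides or {})  # overrides win on key collisions
    return merged


params_a = merge_hyperparameters({"weight-decay": 1e-4})
params_b = merge_hyperparameters({"weight-decay": 1e-2})
print(params_a["weight-decay"], params_b["weight-decay"])
```

Copying the defaults before updating keeps the two models' settings independent of each other.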

In [None]:
def get_estimator(base_job_name, hyperparam_overrides=None):
    # Default hyperparameters; any overrides passed in take precedence.
    hyperparameters = {
        "batch-size": 128,
        "epochs": 20,
        "learning-rate": 1e-3,
        "log-interval": 100,
    }
    hyperparameters.update(hyperparam_overrides or {})

    return PyTorchEstimator(
        base_job_name=base_job_name,
        entry_point='train.py',
        source_dir='code', # directory of your training script
        role=role,
        framework_version='1.8',
        py_version='py3',
        instance_type="ml.c4.xlarge",
        instance_count=1,
        output_path=output_path,
        hyperparameters=hyperparameters,
    )

estimatorA = get_estimator(base_job_name="mnist-a", hyperparam_overrides={ "weight-decay": 1e-4 })
estimatorB = get_estimator(base_job_name="mnist-b", hyperparam_overrides={ "weight-decay": 1e-2 })

In [None]:
estimatorA.fit({ "training": train_s3uri, "testing": test_s3uri }, wait=False)
print("Started estimator A training in background (logs will not show)")

print("Training estimator B with logs:")
estimatorB.fit({ "training": train_s3uri, "testing": test_s3uri })

print("\nWaiting for estimator A to complete:")
estimatorA.latest_training_job.wait(logs=False)

## Check single-model deployment

Before we try to set up a multi-model deployment, let's quickly check we're able to deploy and call a single model as expected:

In [None]:
modelA = estimatorA.create_model(name="mnist-a", role=role, source_dir="code", entry_point="inference.py")

In [None]:
predictorA = modelA.deploy(
    initial_instance_count=1,
    instance_type="ml.c5.xlarge",
)
predictorA.serializer = sagemaker.serializers.JSONSerializer()
predictorA.deserializer = sagemaker.deserializers.JSONDeserializer()

In [None]:
def get_dummy_request():
    """Create a dummy predictor.predict example data (16 images of random pixels)"""
    return {
        "inputs": np.random.rand(16, 1, 28, 28).tolist()
    }

dummy_data = get_dummy_request()

start_time = time.time()
predicted_value = predictorA.predict(dummy_data)
duration = time.time() - start_time

print(f"Model took {int(duration * 1000):,d} ms")
np.array(predicted_value)[0]
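
The request sent above is a JSON document with a single `inputs` key holding a batch shaped `(16, 1, 28, 28)`. The sketch below rebuilds an equivalent payload with the standard library only, to make the nesting explicit. The helper names `make_dummy_payload` and `nested_shape` are hypothetical, and `json.dumps` is used as an approximation of what a JSON serializer would put on the wire.

```python
import json
import random


def make_dummy_payload(batch=16, channels=1, height=28, width=28):
    """Build a nested-list batch of random pixel values in [0, 1)."""
    return {
        "inputs": [
            [[[random.random() for _ in range(width)] for _ in range(height)]
             for _ in range(channels)]
            for _ in range(batch)
        ]
    }


def nested_shape(values):
    """Return the shape of a regular, non-empty nested list."""
    shape = []
    while isinstance(values, list):
        shape.append(len(values))
        values = values[0]
    return tuple(shape)


payload = make_dummy_payload()
body = json.dumps(payload)  # approximate request body text
print(nested_shape(payload["inputs"]), f"{len(body):,d} characters")
```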

## Create the multi-model endpoint with the SageMaker SDK

### Create a SageMaker Model from one of the Estimators

In [None]:
model = estimatorA.create_model(role=role, source_dir="code", entry_point="inference.py")

### Create the Amazon SageMaker MultiDataModel entity

We create the multi-model endpoint using the [```MultiDataModel```](https://sagemaker.readthedocs.io/en/stable/api/inference/multi_data_model.html) class.

You can create a MultiDataModel by directly passing in a `sagemaker.model.Model` object, in which case the Endpoint will inherit information about the image to use, as well as any environment variables, network isolation, etc., once the MultiDataModel is deployed.

A MultiDataModel can also be created without explicitly passing a `sagemaker.model.Model` object. Please refer to the documentation for additional details.

In [None]:
# This is where our MME will read models from on S3.
multi_model_prefix = f"{prefix}multi-model/"
multi_model_s3uri = f"s3://{bucket_name}/{multi_model_prefix}"
print(multi_model_s3uri)

In [None]:
mme = MultiDataModel(
    name="mnist-multi",
    model_data_prefix=multi_model_s3uri,
    model=model,  # passing our model
    sagemaker_session=smsess,
)

## Deploy the Multi Model Endpoint

You need to consider the appropriate instance type and number of instances for the projected prediction workload across all the models you plan to host behind your multi-model endpoint. The number and size of the individual models will also drive memory requirements.
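
As a rough illustration of how model size drives memory requirements, the sketch below estimates how many models could stay loaded at once. Every number in it is an assumption chosen for illustration (8 GiB of instance memory, 50 MiB per loaded model, 30% reserved for the container and framework), and `models_that_fit` is a hypothetical helper, not part of the SageMaker SDK. Measure your own models before sizing an endpoint.

```python
def models_that_fit(instance_memory_gib, model_memory_mib, reserved_fraction=0.3):
    """Estimate how many models could be held in memory at the same time.

    reserved_fraction is an assumed share of memory kept back for the
    serving container, framework runtime, and request handling.
    """
    usable_mib = instance_memory_gib * 1024 * (1 - reserved_fraction)
    return int(usable_mib // model_memory_mib)


# Assumed figures: 8 GiB instance, 50 MiB per loaded model.
print(models_that_fit(instance_memory_gib=8, model_memory_mib=50))
```

If the estimate is smaller than the number of models you expect to be requested regularly, a larger instance type or more instances may be worth considering.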

In [None]:
try:
    predictor.delete_endpoint(delete_endpoint_config=True)
    print("Deleting previous endpoint...")
    time.sleep(10)
except NameError:
    pass

predictor = mme.deploy(
    endpoint_name="mnist-multi",
    initial_instance_count=1,
    instance_type="ml.c5.xlarge",
)

### Our endpoint has launched! Let's look at what models are available to the endpoint!

By 'available', we mean the model artifacts currently stored under the S3 prefix we defined when setting up the `MultiDataModel` above (i.e. `model_data_prefix`).

Since we have no artifacts (i.e. `tar.gz` files) stored under that S3 prefix yet, our endpoint has no models 'available' to serve inference requests.

We will demonstrate how to make models 'available' to our endpoint below.
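
Conceptually, the set of 'available' models is whatever sits under the prefix, named by its path relative to that prefix. The sketch below shows that idea on a hand-written list of keys. The keys and the helper `relative_model_paths` are hypothetical, and this is an approximation of the concept rather than the SDK's actual listing code.

```python
def relative_model_paths(keys, prefix):
    """Strip the prefix from each key, ignoring keys outside the prefix."""
    return [
        key[len(prefix):]
        for key in keys
        if key.startswith(prefix) and len(key) > len(prefix)
    ]


# Hypothetical object keys in the bucket.
keys = [
    "mnist/multi-model/ModelA",
    "mnist/multi-model/ModelB",
    "mnist/data/train/train-images-idx3-ubyte.gz",
]
print(relative_model_paths(keys, "mnist/multi-model/"))
```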

In [None]:
# No models visible!
list(mme.list_models())

### Let's deploy model artifacts to be found by the endpoint

We are now using the `.add_model()` method of the `MultiDataModel` to copy over our model artifacts from where they were initially stored, during training, to where our endpoint will source model artifacts for inference requests.

`model_data_source` refers to the location of our model artifact (i.e. where it was deposited on S3 after training completed)

`model_data_path` is the **relative** path to the S3 prefix we specified above (i.e. `model_data_prefix`) where our endpoint will source models for inference requests.

Since this is a **relative** path, we can simply pass the name we wish to use for the model artifact at inference time (e.g. `ModelA`).
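
To make the relative-path idea concrete, the sketch below joins a prefix and a relative path into the full S3 location where the artifact would end up. The bucket name and the helper `resolve_model_uri` are hypothetical and serve only as an illustration.

```python
def resolve_model_uri(model_data_prefix, model_data_path):
    """Join the endpoint's S3 prefix with a relative artifact path."""
    if not model_data_prefix.endswith("/"):
        model_data_prefix += "/"
    return model_data_prefix + model_data_path.lstrip("/")


# Hypothetical bucket name, for illustration only.
print(resolve_model_uri("s3://my-bucket/mnist/multi-model/", "ModelA"))
```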

### Dynamically deploying additional models

Note that we can call the `.add_model()` method at any time, as shown below, to dynamically deploy more models to the endpoint and serve inference requests as needed.

In [None]:
for name, est in { "ModelA": estimatorA, "ModelB": estimatorB }.items():
    artifact_path = est.latest_training_job.describe()["ModelArtifacts"]["S3ModelArtifacts"]
    #model_name = artifact_path.split('/')[-4]+'.tar.gz'
    # This is copying over the model artifact to the S3 location for the MME.
    mme.add_model(model_data_source=artifact_path, model_data_path=name)

## We have added the 2 model artifacts from our training jobs!

We can see that the S3 prefix we specified when setting up `MultiDataModel` now has 2 model artifacts. As such, the endpoint can now serve up inference requests for these models.

In [None]:
list(mme.list_models())

## Get predictions from the endpoint

Recall that ```mme.deploy()``` returns a [Predictor](https://github.com/aws/sagemaker-python-sdk/blob/master/src/sagemaker/predictor.py#L35) (called `RealTimePredictor` in version 1 of the SageMaker Python SDK) that we saved in a variable called ```predictor```.

We will use ```predictor``` to submit requests to the endpoint.
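
For callers that do not use the SageMaker Python SDK, the same request can be made with the low-level runtime API, where the `TargetModel` parameter of `invoke_endpoint` selects the artifact. The sketch below assumes a JSON request and response. `invoke_target_model` and `StubRuntimeClient` are hypothetical names, and the stub returns a canned response only so that the example runs without AWS credentials.

```python
import io
import json


def invoke_target_model(runtime_client, endpoint_name, target_model, payload):
    """Send a JSON request to one specific model on a multi-model endpoint."""
    response = runtime_client.invoke_endpoint(
        EndpointName=endpoint_name,
        TargetModel=target_model,  # artifact path relative to the S3 prefix
        ContentType="application/json",
        Accept="application/json",
        Body=json.dumps(payload),
    )
    return json.loads(response["Body"].read())


class StubRuntimeClient:
    """Offline stand-in for boto3.client("sagemaker-runtime")."""

    def invoke_endpoint(self, **kwargs):
        self.last_request = kwargs
        return {"Body": io.BytesIO(b"[[0.1, 0.9]]")}  # canned response


stub = StubRuntimeClient()
result = invoke_target_model(stub, "mnist-multi", "ModelA", {"inputs": [[0.0]]})
print(result, stub.last_request["TargetModel"])
```

Against a live endpoint, the stub would be replaced by `boto3.client("sagemaker-runtime")`.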

In [None]:
dummy_data = get_dummy_request()

start_time = time.time()
predicted_value = predictor.predict(dummy_data, target_model="ModelA")
duration = time.time() - start_time

print(f"Model took {int(duration * 1000):,d} ms")
np.array(predicted_value)[0]
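
The first request to a given model is generally expected to be slower than later ones, because the endpoint has to fetch the artifact from S3 and load it before serving. One way to observe this is to time several consecutive calls, as in the sketch below. `time_calls` is a hypothetical helper, and a short `time.sleep` stands in for the real `predictor.predict(...)` call so the example runs offline.

```python
import time


def time_calls(fn, repeats=3):
    """Call fn() several times and return each duration in milliseconds."""
    durations = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        durations.append((time.perf_counter() - start) * 1000)
    return durations


# Stand-in for: lambda: predictor.predict(dummy_data, target_model="ModelA")
durations = time_calls(lambda: time.sleep(0.01))
print([f"{d:.0f} ms" for d in durations])
```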

## Clean-up

In [None]:
predictorA.delete_endpoint(delete_endpoint_config=True)
predictor.delete_endpoint(delete_endpoint_config=True)