# PyTorch Pre-Trained Model Deployment Example

In this example we'll take a pre-trained PyTorch model and deploy it on SageMaker Real-Time Inference. We'll train a sample PyTorch model locally on artificial data and then deploy the trained model artifact to a SageMaker endpoint. The idea is to show the general SageMaker deployment flow using the AWS Boto3 Python SDK.

## Local Model Training

We train a sample PyTorch model locally, then take the serialized model artifact and deploy it for inference. In this case the artifact we generate is model.pth; the artifact format varies by framework. For instance, an SKLearn model might be saved as model.joblib, whereas LLMs typically ship with a variety of files (.json metadata, tensor weights, etc.).
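
To make concrete what a serialized artifact like model.pth holds, here is a minimal sketch, assuming the same single-layer linear model that this notebook trains. The `state_dict` is a mapping from parameter names to tensors, and that mapping is what `torch.save(model.state_dict(), ...)` writes to disk.

```python
import torch.nn as nn

# Same architecture as the notebook's LinearRegressionModel, redefined here
# so the snippet runs on its own.
class LinearRegressionModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(1, 1)

    def forward(self, x):
        return self.linear(x)

model = LinearRegressionModel()

# state_dict maps each parameter name to its tensor
param_shapes = {name: tuple(t.shape) for name, t in model.state_dict().items()}
print(param_shapes)  # {'linear.weight': (1, 1), 'linear.bias': (1,)}
```

Because only the parameters are saved, the class definition has to be available again wherever the artifact is loaded.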

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt

# artificial data for lin reg
torch.manual_seed(42)
X = 3 * torch.rand(100, 1)
y = 3 * X + 2 + 0.1 * torch.randn(100, 1)

# lin reg model
class LinearRegressionModel(nn.Module):
    def __init__(self):
        super(LinearRegressionModel, self).__init__()
        self.linear = nn.Linear(1, 1)

    def forward(self, x):
        return self.linear(x)

# train model
model = LinearRegressionModel()
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)
num_epochs = 100
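for epoch in range(num_epochs):
    # forward pass and loss
    y_pred = model(X)
    loss = criterion(y_pred, y)

    # backward pass and parameter update
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()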

In [None]:
# serialize model data
torch.save(model.state_dict(), 'model.pth')

In [None]:
# load model
loaded_model = LinearRegressionModel()
loaded_model.load_state_dict(torch.load("model.pth"))

# sample inference
samp_data = [[2.5]]
with torch.no_grad():
    prediction = loaded_model(torch.tensor(samp_data))
output = prediction.tolist()
output

## SageMaker Deployment

For SageMaker Deployment there are a few key constructs:

- SageMaker Model Object: Points towards model data and any inference artifacts.
- SageMaker Endpoint Configuration: Defines hardware for the model.
- SageMaker Endpoint: The persistent REST endpoint used for invocation; auto scaling can be attached to it.
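
Because the endpoint is persistent, it keeps running until it is deleted. Here is a minimal teardown sketch, assuming a Boto3 SageMaker client and the resource names chosen at creation time. The helper name `delete_deployment` is hypothetical; `delete_endpoint`, `delete_endpoint_config`, and `delete_model` are the Boto3 calls it relies on. It removes the three objects in the reverse of the order in which they are created.

```python
def delete_deployment(client, endpoint_name, epc_name, model_name):
    """Delete endpoint, endpoint configuration, and model, in that order.

    `client` is assumed to be boto3.client("sagemaker"); this helper is
    illustrative and not part of any SDK.
    """
    client.delete_endpoint(EndpointName=endpoint_name)
    client.delete_endpoint_config(EndpointConfigName=epc_name)
    client.delete_model(ModelName=model_name)
```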

### Custom Inference Script

When you want to control model loading or pre/post-processing, you can define your own inference script to override the default handlers of the model server that the container exposes. For PyTorch, these are the four functions that can be overridden:

- model_fn: Load the model
- input_fn: Handle input and preprocessing
- predict_fn: Control model inference
- output_fn: Handle output and structure the response as needed
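
As a simplified sketch of how these handlers fit together (an assumption about the call order, not the model server's actual code), the model is loaded once by model_fn, and each request then flows through input_fn, predict_fn, and output_fn. The stand-in model is a plain Python function so the snippet runs without PyTorch. `handle_request` is a hypothetical name, and the `{"outputs": ...}` response shape is only illustrative.

```python
import json

def model_fn(model_dir):
    # Stand-in for loading model.pth: a plain function y = 3x + 2
    return lambda rows: [[3 * x + 2 for x in row] for row in rows]

def input_fn(request_body, request_content_type):
    if request_content_type != "application/json":
        raise ValueError(f"Unsupported content type: {request_content_type}")
    return json.loads(request_body)["inputs"]

def predict_fn(input_object, model):
    return model(input_object)

def output_fn(prediction, accept):
    # Serialize the prediction for the response body
    return json.dumps({"outputs": prediction})

def handle_request(model, body, content_type, accept):
    # Hypothetical driver mimicking the order in which handlers are called
    data = input_fn(body, content_type)
    prediction = predict_fn(data, model)
    return output_fn(prediction, accept)

model = model_fn("/opt/ml/model")  # loaded once, reused across requests
response = handle_request(
    model, '{"inputs": [[2.5]]}', "application/json", "application/json"
)
print(response)  # {"outputs": [[9.5]]}
```

Any handler you leave out of your script falls back to the container's default implementation.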

We attach this inference script to our model data and package both into the model.tar.gz that SageMaker expects. The layout of this model.tar.gz depends on the model server/container you are using; each model server expects a different file structure.
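
As one way to build the archive from Python, here is a minimal sketch using the standard-library `tarfile` module. It assumes the flat layout this notebook uses (model.pth and inference.py at the archive root); other containers may expect a different layout. The helper name `build_model_archive` is hypothetical, and the placeholder files exist only so the snippet runs on its own.

```python
import os
import tarfile
import tempfile

def build_model_archive(archive_path, file_paths):
    """Write the given files into a gzipped tar at the archive root."""
    with tarfile.open(archive_path, "w:gz") as tar:
        for path in file_paths:
            # arcname drops directories so each file sits at the root
            tar.add(path, arcname=os.path.basename(path))

# Placeholder files standing in for the real artifact and script
workdir = tempfile.mkdtemp()
paths = []
for name in ("model.pth", "inference.py"):
    path = os.path.join(workdir, name)
    with open(path, "wb") as f:
        f.write(b"placeholder")
    paths.append(path)

archive = os.path.join(workdir, "model.tar.gz")
build_model_archive(archive, paths)

# List the archive contents to confirm the layout
with tarfile.open(archive, "r:gz") as tar:
    members = tar.getnames()
print(members)  # ['model.pth', 'inference.py']
```

Listing the members before uploading is a cheap way to catch a wrong layout early.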

In [None]:
%%writefile inference.py
import os
import json
import torch
import torch.nn as nn

class LinearRegressionModel(nn.Module):
    def __init__(self):
        super(LinearRegressionModel, self).__init__()
        self.linear = nn.Linear(1, 1)

    def forward(self, x):
        return self.linear(x)
    

# load model
def model_fn(model_dir): 
    model = LinearRegressionModel()   
    with open(os.path.join(model_dir, "model.pth"), "rb") as f:
        model.load_state_dict(torch.load(f))
    return model

# preprocessing input
def input_fn(request_body, request_content_type):
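    # reject unsupported content types explicitly (asserts are skipped under -O)
    if request_content_type != "application/json":
        raise ValueError(f"Unsupported content type: {request_content_type}")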
    data = json.loads(request_body)["inputs"]
    input_data = torch.tensor(data)
    return input_data

# inference
def predict_fn(input_object, model):
    with torch.no_grad():
        prediction = model(input_object)
    output = prediction.tolist()
    return output

### SageMaker Objects Creation

In [None]:
import boto3
import json
import subprocess
import time
from time import gmtime, strftime

import sagemaker

#Setup
client = boto3.client(service_name="sagemaker")
runtime = boto3.client(service_name="sagemaker-runtime")
boto_session = boto3.session.Session()
s3 = boto_session.resource('s3')
region = boto_session.region_name
print(region)
sagemaker_session = sagemaker.Session()
role = sagemaker.get_execution_role()

In [None]:
#Build tar file with model data + inference code
tar_command = ["tar", "-cvpzf", "model.tar.gz", "model.pth", "inference.py"]
# check=True raises CalledProcessError if tar fails
subprocess.run(tar_command, check=True)

In [None]:
# retrieve pytorch image
image_uri = sagemaker.image_uris.retrieve(
    framework="pytorch",
    region=region,
    version="2.1",
    py_version="py310",
    image_scope="inference",
    instance_type="ml.m5.xlarge"
)

In [None]:
#Bucket for model artifacts
default_bucket = sagemaker_session.default_bucket()
print(default_bucket)

#Upload tar.gz to bucket
model_artifacts = f"s3://{default_bucket}/model.tar.gz"
response = s3.meta.client.upload_file('model.tar.gz', default_bucket, 'model.tar.gz')

In [None]:
#Step 1: Model Creation
model_name = "pytorch-test" + strftime("%Y-%m-%d-%H-%M-%S", gmtime())
print("Model name: " + model_name)
create_model_response = client.create_model(
    ModelName=model_name,
    Containers=[
        {
            "Image": image_uri,
            "Mode": "SingleModel",
            "ModelDataUrl": model_artifacts,
            "Environment": {'SAGEMAKER_SUBMIT_DIRECTORY': model_artifacts,
                           'SAGEMAKER_PROGRAM': 'inference.py'} 
        }
    ],
    ExecutionRoleArn=role,
)
print("Model Arn: " + create_model_response["ModelArn"])

In [None]:
#Step 2: EPC Creation
epc_name = "pt-epc" + strftime("%Y-%m-%d-%H-%M-%S", gmtime())
endpoint_config_response = client.create_endpoint_config(
    EndpointConfigName=epc_name,
    ProductionVariants=[
        {
            "VariantName": "ptvariant",
            "ModelName": model_name,
            "InstanceType": "ml.c5.large",
            "InitialInstanceCount": 1
        },
    ],
)
print("Endpoint Configuration Arn: " + endpoint_config_response["EndpointConfigArn"])

In [None]:
#Step 3: EP Creation
endpoint_name = "pt-ep" + strftime("%Y-%m-%d-%H-%M-%S", gmtime())
create_endpoint_response = client.create_endpoint(
    EndpointName=endpoint_name,
    EndpointConfigName=epc_name,
)
print("Endpoint Arn: " + create_endpoint_response["EndpointArn"])

In [None]:
#Monitor creation
describe_endpoint_response = client.describe_endpoint(EndpointName=endpoint_name)
while describe_endpoint_response["EndpointStatus"] == "Creating":
    describe_endpoint_response = client.describe_endpoint(EndpointName=endpoint_name)
    print(describe_endpoint_response["EndpointStatus"])
    time.sleep(15)
print(describe_endpoint_response)

## Sample Inference

You can use the [SDK runtime client](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker-runtime/client/invoke_endpoint.html) to directly invoke the endpoint.
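
The request and response round trip can be sketched as a small helper. It assumes the endpoint accepts the `{"inputs": ...}` JSON shape that this notebook's inference.py parses. The helper name `invoke_json` is hypothetical, and `client` is expected to be a Boto3 `sagemaker-runtime` client.

```python
import json

def invoke_json(client, endpoint_name, inputs):
    """Send inputs as JSON to a SageMaker endpoint and decode the JSON reply."""
    response = client.invoke_endpoint(
        EndpointName=endpoint_name,
        ContentType="application/json",
        Body=json.dumps({"inputs": inputs}),
    )
    # Body is a streaming object; read it fully before decoding
    return json.loads(response["Body"].read().decode())
```

With a live endpoint this would be called as `invoke_json(boto3.client("sagemaker-runtime"), endpoint_name, [[2.5]])`.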

In [None]:
runtime_client = boto3.client('sagemaker-runtime')
content_type = "application/json"
request_body = {"inputs": [[2.5]]}
payload = json.dumps(request_body)

response = runtime_client.invoke_endpoint(
    EndpointName=endpoint_name,
    ContentType=content_type,
    Body=payload)
result = json.loads(response['Body'].read().decode())
result