# Create and Deploy Custom Inference Workflows Using the SageMaker Python SDK

In this notebook you will learn how to use the `ModelBuilder` class to define and deploy your own custom inference workflows directly in the SageMaker Python SDK and have them ready for serving.

You will be able to define the `ResourceRequirements` for multiple Inference Components and deploy them in bulk. For a supported subset of JumpStart models, including the Llama 3 family, you don't need to specify `ResourceRequirements` at all and can instead use pre-benchmarked deployment configs.

You can run this notebook either from an Amazon SageMaker Studio notebook instance, which handles all credentials automatically, or locally after setting credentials manually. If you use Studio, make sure you are running the latest Studio image version (>=3.0.0).

***

### Additional Resources
- To learn more about `ModelBuilder`, see [Create a model in Amazon SageMaker with ModelBuilder](https://docs.aws.amazon.com/sagemaker/latest/dg/how-it-works-modelbuilder-creation.html)
- To learn more about Inference Components, see the [Inference Component Launch Blog](https://aws.amazon.com/blogs/aws/amazon-sagemaker-adds-new-inference-capabilities-to-help-reduce-foundation-model-deployment-costs-and-latency/)

## Prerequisites
This notebook was tested locally, as well as in SageMaker Studio in the `us-west-2` and `us-east-1` regions, on the following platforms and configurations.
| Platform |Config Name | Value |
| :--------  | :----------- | :---------- |
| JupyterLab |Instance Type| ml.m5.xlarge |
|            |Image         | SageMaker Distribution 2.0.0 |
|            |Kernel        | Python 3 (ipykernel) |
|            |Python Version| Python 3.12|
| Studio Classic |Instance Type| ml.t3.medium |
|            |Image         | Data Science 4.0 |
|            |Kernel        | Python 3 |
|            |Python Version| Python 3.12|


For best performance, use Python 3.12. 

Note: To enable live logging when deploying your custom orchestrator, ensure your execution role has the `logs:FilterLogEvents` permission. The default notebook role does NOT include it.
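As a minimal sketch of what granting that permission could look like, the snippet below builds an IAM policy document as a Python dictionary. The `Sid` and the `Resource` pattern are assumptions: the pattern presumes endpoint logs live under the standard `/aws/sagemaker/Endpoints/` log group prefix, and you may want to narrow it to a specific region, account, or endpoint.

```python
import json

# Illustrative policy only: the Sid and Resource pattern are assumptions,
# not values taken from this notebook. Narrow the Resource for production use.
live_logging_policy = {
    "Version": "2012-10-17",
    "Statement": [
        {
            "Sid": "AllowEndpointLogFiltering",
            "Effect": "Allow",
            "Action": ["logs:FilterLogEvents"],
            "Resource": "arn:aws:logs:*:*:log-group:/aws/sagemaker/Endpoints/*",
        }
    ],
}

# Serialize to JSON so it can be pasted into the IAM console or an IaC template.
policy_json = json.dumps(live_logging_policy, indent=2)
print(policy_json)
```

Attach a policy like this to the execution role returned by `get_execution_role()` before deploying.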

In [None]:
%pip install --upgrade pip
%pip install --upgrade sagemaker

### Step 1: Create our Inference Component `ModelBuilder` objects.
Define a `ModelBuilder` object for each Inference Component (IC) you'd like to create. Set either `inference_component_name` or `resource_requirements` to signal that an IC should be created.

Note: Some JumpStart models ship with pre-benchmarked `ResourceRequirements` deployment configurations, so all you need to do is set an `inference_component_name`.
This notebook uses `"meta-textgeneration-llama-3-1-8b"` to demonstrate pre-benchmarked deployment configurations.
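When you do set `ResourceRequirements` yourself, it can help to check that the combined requests fit on the target instance before deploying. The sketch below is a plain-Python sanity check, not part of the SDK: `ComponentRequest` and `fits_on_instance` are hypothetical helpers, and the Llama numbers are placeholders because its real values come from the pre-benchmarked config. The instance totals assume an `ml.g5.24xlarge` with 4 GPUs and 384 GiB of memory; usable memory on a live endpoint is lower.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ComponentRequest:
    """Hypothetical stand-in for one Inference Component's per-copy request."""
    name: str
    memory_mb: int
    num_accelerators: int
    copies: int = 1


def fits_on_instance(requests, instance_memory_mb, instance_accelerators):
    """Return True if the summed requests fit within the instance totals."""
    total_memory = sum(r.memory_mb * r.copies for r in requests)
    total_accelerators = sum(r.num_accelerators * r.copies for r in requests)
    return (total_memory <= instance_memory_mb
            and total_accelerators <= instance_accelerators)


# Hardware totals for ml.g5.24xlarge (assumed: 4 GPUs, 384 GiB memory).
G5_24XL_MEMORY_MB = 384 * 1024
G5_24XL_ACCELERATORS = 4

planned = [
    ComponentRequest("mistral", memory_mb=49152, num_accelerators=2),
    # Placeholder values; the real Llama config is pre-benchmarked by JumpStart.
    ComponentRequest("llama-placeholder", memory_mb=49152, num_accelerators=2),
]

plan_fits = fits_on_instance(planned, G5_24XL_MEMORY_MB, G5_24XL_ACCELERATORS)
print(plan_fits)
```

A check like this only catches obvious over-commitment; SageMaker performs the authoritative placement when the components are created.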

In [None]:
import uuid
import boto3
from sagemaker import get_execution_role

# Get the SageMaker execution role
role = get_execution_role()

# Define the names for our inference components and endpoint.
llama_mistral_endpoint_name = f"llama-mistral-endpoint-{uuid.uuid1().hex}"
mistral_ic_name = f"mistral-ic-{uuid.uuid1().hex}"
llama_ic_name = f"llama-ic-{uuid.uuid1().hex}"

region = boto3.Session().region_name

We first set up a `ModelBuilder` to deploy the Llama 3.1 8B model with the following configuration:

1. Model Selection: Uses the Llama 3.1 8B text generation model
2. Schema Definition:
    - Input schema includes the prompt text and generation parameters
    - Output schema defines the expected response format
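3. Resource Requirements: Define the resources required to deploy the model's Inference Component. For Llama 3.1 8B these come from JumpStart's pre-benchmarked configuration, so none are set explicitly. Here we will deploy a single copy of the model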

Note that this notebook uses the default SageMaker execution role. To manage the resources it creates more tightly, follow the [least-privilege permissions](https://docs.aws.amazon.com/sagemaker/latest/dg/security_iam_id-based-policy-examples.html) guidance.
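The schema from item 2 can be sketched without any SageMaker calls. The snippet below serializes a sample input of the same shape and parses a response of the same shape as the sample output. `extract_text` is a hypothetical helper that accepts either a list of results or a single dictionary, since the exact response shape depends on the model container.

```python
import json

# Same shapes as the sample input and output passed to SchemaBuilder.
sample_input = {"inputs": "Falcons are", "parameters": {"max_new_tokens": 32}}
sample_output = [{"generated_text": "Falcons are small to medium-sized birds of prey."}]

# What a JSON request and response body would look like on the wire.
request_body = json.dumps(sample_input).encode("utf-8")
response_body = json.dumps(sample_output).encode("utf-8")


def extract_text(raw_response: bytes) -> str:
    """Hypothetical helper: pull generated_text from a list or dict response."""
    parsed = json.loads(raw_response)
    first = parsed[0] if isinstance(parsed, list) else parsed
    return first["generated_text"]


print(extract_text(response_body))
```

Handling both shapes in one place avoids a `TypeError` if a container returns a list where a dictionary was expected.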

In [None]:
from sagemaker.session import Session
from sagemaker.serve import ModelBuilder
from sagemaker.serve.builder.schema_builder import SchemaBuilder
from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements


# Define sample input and output for the model
prompt = "Falcons are"
response = "Falcons are small to medium-sized birds of prey related to hawks and eagles."

# Create the input schema structure
sample_input = {
    "inputs": prompt,
    "parameters": {"max_new_tokens": 32}
}
# Define the expected output format
sample_output = [{"generated_text": response}]

# Create a ModelBuilder instance for Llama 3.1 8B
# Pre-benchmarked ResourceRequirements will be taken from JumpStart, as Llama-3.1-8b is a supported model.
llama_model_builder = ModelBuilder(
    model="meta-textgeneration-llama-3-1-8b",
    schema_builder=SchemaBuilder(sample_input, sample_output),
    inference_component_name=llama_ic_name,
    instance_type="ml.g5.24xlarge"
)

In [None]:
mistral_mb = ModelBuilder(
    model="huggingface-llm-mistral-7b",
    schema_builder=SchemaBuilder(sample_input, sample_output),
    inference_component_name=mistral_ic_name,
    resource_requirements=ResourceRequirements(
        requests={
           "memory": 49152,
           "num_accelerators": 2,
           "copies": 1
        }
    ),
    instance_type="ml.g5.24xlarge"
)

### Step 2: Define your custom inference orchestrator by creating a class which inherits from the new `CustomOrchestrator` class.
`CustomOrchestrator` expects a single `handle()` function to be implemented, which will serve as the container entrypoint when your endpoint is invoked. This example uses vanilla Python and boto3.

The following example is only meant to demonstrate chaining together multiple SageMaker models and does not necessarily reflect a practical real-world use case.


```
class CustomOrchestrator(ABC):
    """
    Templated class used to standardize the structure of an entry point based inference script.
    """

    @abstractmethod
    def handle(self, data, context=None):
        """abstract class for defining an entrypoint for the model server"""
        return NotImplemented
```

In [None]:
from sagemaker.serve.spec.inference_base import CustomOrchestrator

This implementation creates a pipeline where the output of the Llama model becomes the input for the Mistral model, allowing for sequential processing of text through two different language models within the same endpoint.
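Before deploying, the chaining logic can be dry-run locally. The sketch below is not the notebook's orchestrator: `StubRuntimeClient` and `run_chain` are hypothetical names, and the stub only mimics the shape of an `invoke_endpoint` response (a dictionary whose `Body` supports `.read()`) without making any network calls.

```python
import io
import json


class StubRuntimeClient:
    """Hypothetical stand-in for a runtime client; records calls, no network."""

    def __init__(self):
        self.calls = []

    def invoke_endpoint(self, EndpointName, Body, ContentType, InferenceComponentName):
        self.calls.append(InferenceComponentName)
        prompt = json.loads(Body)["inputs"]
        # Tag the text with the component name so the call order is visible.
        reply = {"generated_text": f"{prompt} -> [{InferenceComponentName}]"}
        return {"Body": io.BytesIO(json.dumps(reply).encode("utf-8"))}


def run_chain(client, endpoint_name, component_names, data: bytes) -> dict:
    """Feed each component's generated text into the next component."""
    text = data.decode("utf-8")
    for component in component_names:
        response = client.invoke_endpoint(
            EndpointName=endpoint_name,
            Body=json.dumps({"inputs": text}),
            ContentType="application/json",
            InferenceComponentName=component,
        )
        text = json.loads(response["Body"].read())["generated_text"]
    return {"generated_text": text}


stub = StubRuntimeClient()
result = run_chain(stub, "demo-endpoint", ["llama-ic", "mistral-ic"], b"Hello")
print(result)
```

Swapping the stub for a real runtime client keeps the same control flow, which makes the chaining order easy to unit test in isolation.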

In [None]:
import json
class PythonCustomInferenceEntryPoint(CustomOrchestrator):
    def __init__(self, region_name, endpoint_name, component_names):
        self.region_name = region_name
        self.endpoint_name = endpoint_name
        self.component_names = component_names
    
    def preprocess(self, data):
        payload = {
            "inputs": data.decode("utf-8")
        }
        return json.dumps(payload)

    def _invoke_workflow(self, data):
        # First model (Llama) inference
        payload = self.preprocess(data)
        
        llama_response = self.client.invoke_endpoint(
            EndpointName=self.endpoint_name,
            Body=payload,
            ContentType="application/json",
            InferenceComponentName=self.component_names[0]
        )
        llama_generated_text = json.loads(llama_response.get('Body').read())['generated_text']
        
        # Second model (Mistral) inference
        parameters = {
            "max_new_tokens": 50
        }
        payload = {
            "inputs": llama_generated_text,
            "parameters": parameters
        }
        mistral_response = self.client.invoke_endpoint(
            EndpointName=self.endpoint_name,
            Body=json.dumps(payload),
            ContentType="application/json",
            InferenceComponentName=self.component_names[1]
        )
        return {"generated_text": json.loads(mistral_response.get('Body').read())['generated_text']}
    
    def handle(self, data, context=None):
        return self._invoke_workflow(data)

### Step 3: Create a `ModelBuilder` for the custom orchestrator and pass the orchestrator in via the `inference_spec` field.
`ModelBuilder` will know to deploy as a custom orchestrator if `inference_spec` contains an instance of `CustomOrchestrator`.


`ModelBuilder` can automatically capture module-level dependencies if `"auto"` is set to `True` in the `dependencies` argument. If you have dependencies declared elsewhere, set `"auto"` to `False` and list the packages you need under `"custom"`.
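The two modes can be sketched as plain dictionaries. `build_dependencies` is a hypothetical helper, not an SDK function. The `"auto"` and `"custom"` keys appear in this notebook; the `"requirements"` key is an assumption based on the `ModelBuilder` docstring and should be checked against your SDK version.

```python
def build_dependencies(auto: bool, packages=None, requirements_path=None) -> dict:
    """Hypothetical helper that assembles a `dependencies` dictionary."""
    deps = {"auto": auto}
    if packages:
        # Explicit package list, used when auto-capture is disabled.
        deps["custom"] = list(packages)
    if requirements_path:
        # Assumed key: path to a requirements file.
        deps["requirements"] = requirements_path
    return deps


# Let ModelBuilder capture module-level imports on its own.
auto_deps = build_dependencies(auto=True)

# Disable auto-capture and list the packages explicitly.
pinned_deps = build_dependencies(auto=False, packages=["cloudpickle", "graphene"])

print(auto_deps)
print(pinned_deps)
```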

Then pass in the Inference Component `ModelBuilder` objects created in Step 1:

- `modelbuilder_list` - the `ModelBuilder` objects to deploy as Inference Components alongside the custom orchestrator.

Calling `build()` will prepare the chain for deployment, which can be triggered via `deploy()`.


In [None]:
from sagemaker.serve.builder.model_builder import ModelBuilder, SchemaBuilder
from sagemaker.session import Session
from sagemaker.resource_requirements import ResourceRequirements

custom_orchestrator_name = f"custom-orchestrator-{uuid.uuid1().hex}"

orchestrator = ModelBuilder(
    inference_spec=PythonCustomInferenceEntryPoint(
        region_name=region,
        endpoint_name=llama_mistral_endpoint_name,
        component_names=[llama_ic_name, mistral_ic_name],
    ),
    dependencies={
        "auto": False,
        "custom": [
            "cloudpickle",
            "graphene",
            # Define other dependencies here.
        ],
    },
    sagemaker_session=Session(),
    role_arn=role,
    resource_requirements=ResourceRequirements(
        requests={
           "memory": 4096,
           "num_accelerators": 1,
           "copies": 1,
           "num_cpus": 2
        }
    ),
    inference_component_name=custom_orchestrator_name, # IC name for your custom orchestrator
    name=custom_orchestrator_name, # Endpoint name if you want the custom orchestrator on its own endpoint
    schema_builder=SchemaBuilder(sample_input="Test", sample_output={"generated_text": "test"}),
    modelbuilder_list=[llama_model_builder, mistral_mb] # Inference Component ModelBuilders created in Step 1
)

# call the build function to prepare the chain for deployment
orchestrator.build()

Calling `deploy()` will deploy your inference component ModelBuilders to your desired instance type, and your custom workflow to 1 instance of `ml.c5.xlarge` by default.
You can set `custom_orchestrator_instance_type` and `custom_orchestrator_initial_instance_count` to configure these values.

In [None]:
predictors = orchestrator.deploy(
    instance_type="ml.g5.24xlarge",
    initial_instance_count=1,
    accept_eula=True, # Required for Llama3
    endpoint_name=llama_mistral_endpoint_name,
    # custom_orchestrator_instance_type="ml.t2.medium", # optional override of the orchestrator instance type
    # custom_orchestrator_initial_instance_count=1 # default
)

### Test Invoking the Inference Components and Custom Orchestrator Endpoint

Test the custom orchestrator. 
If your orchestrator supports streaming, you can call `Predictor.predict_stream()` with `"stream": True` set in the payload body to get a streaming response.
```

generator = predictor.predict_stream(
    json.dumps({
        "prompt": "where is the capital of india?",
        "stream": True
    })
)

for chunk in generator:
    print(
        str(chunk, encoding = 'utf-8'), 
        end = "", 
        flush = True
    )
```
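One caveat with decoding each chunk on its own is that a multi-byte UTF-8 character can be split across two chunks. A minimal sketch of a safer approach uses the standard library's incremental decoder. `decode_stream` is a hypothetical helper and `fake_chunks` stands in for the generator returned by `predict_stream()`.

```python
import codecs


def decode_stream(chunks):
    """Yield decoded text from byte chunks, buffering incomplete characters."""
    decoder = codecs.getincrementaldecoder("utf-8")()
    for chunk in chunks:
        text = decoder.decode(chunk)
        if text:
            yield text
    # Flush anything still buffered once the stream ends.
    tail = decoder.decode(b"", final=True)
    if tail:
        yield tail


# Simulated stream in which the two bytes of "é" arrive in separate chunks.
fake_chunks = [b"caf", b"\xc3", b"\xa9 au lait"]
streamed = "".join(decode_stream(fake_chunks))
print(streamed)
```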

In [None]:
from sagemaker.serializers import JSONSerializer
predictors[-1].serializer = JSONSerializer()
predictors[-1].predict("Tell me a story about ducks.")

Let's just verify the Inference Components work on their own. We'll test the Llama IC with a synchronous invocation, and Mistral with streaming.

In [None]:
from sagemaker.predictor import Predictor
mistral_predictor = Predictor(endpoint_name=llama_mistral_endpoint_name, component_name=mistral_ic_name)
mistral_predictor.content_type = "application/json"
llama_predictor = Predictor(endpoint_name=llama_mistral_endpoint_name, component_name=llama_ic_name)
llama_predictor.content_type = "application/json"

In [None]:
import json
payload = {
    "inputs": "What is the capital of Japan?"
}

llama_predictor.predict(json.dumps(payload))


In [None]:
import boto3
import json

# Define the prompt and other parameters
prompt = """
<s>[INST] Below is the question based on the context. 
Question: Given a reference text about Lollapalooza, where does it take place, who started it and what is it? 
Below is the given context: Lollapalooza /ˌlɒləpəˈluːzə/ (Lolla) is an annual American four-day music festival held in Grant Park in Chicago. 
It originally started as a touring event in 1991, but several years later, Chicago became its permanent location. Music genres include but are not limited to alternative rock, heavy metal, punk rock, hip hop, and electronic dance music. Lollapalooza has also featured visual arts, nonprofit organizations, and political organizations. 
The festival, held in Grant Park, hosts an estimated 400,000 people each July and sells out annually. Lollapalooza is one of the largest and most iconic music festivals in the world and one of the longest-running in the United States. Lollapalooza was conceived and created in 1991 as a farewell tour by Perry Farrell, singer of the group Jane's Addiction.. 
Write a response that appropriately completes the request.[/INST]
"""
 
max_tokens_to_sample = 200

# hyperparameters for llm
parameters = {
    "max_new_tokens": max_tokens_to_sample,
    "do_sample": True,
    "top_p": 0.9,
    "temperature": 0.5,
}

contentType = 'application/json'

body = json.dumps({
    "inputs": prompt,
    # specify the parameters as needed
    "parameters": parameters
})

mistral_predictor.content_type = contentType
for line in mistral_predictor.predict_stream(body):
    # Print each streamed chunk as-is; newlines inside the chunk are preserved.
    print(line.decode('utf-8'), end='')

### Clean Up the Resources
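The order matters here: Inference Components are removed first and the endpoint last. The sketch below shows that ordering with a hypothetical `clean_up` helper that keeps going if one deletion fails. `FakePredictor` only records calls so the example runs without AWS access; real predictors expose `delete_predictor()` and `delete_endpoint()` as used in the cells that follow.

```python
def clean_up(component_predictors, endpoint_predictor):
    """Delete each component, then the endpoint; collect errors instead of stopping."""
    errors = []
    for predictor in component_predictors:
        try:
            predictor.delete_predictor()
        except Exception as exc:  # keep cleaning up the remaining resources
            errors.append(exc)
    try:
        endpoint_predictor.delete_endpoint()
    except Exception as exc:
        errors.append(exc)
    return errors


class FakePredictor:
    """Hypothetical stand-in that records deletion calls instead of calling AWS."""

    def __init__(self, name, log):
        self.name = name
        self.log = log

    def delete_predictor(self):
        self.log.append(f"component:{self.name}")

    def delete_endpoint(self):
        self.log.append(f"endpoint:{self.name}")


call_log = []
fakes = [FakePredictor(n, call_log) for n in ("orchestrator", "mistral", "llama")]
cleanup_errors = clean_up(fakes, fakes[-1])
print(call_log)
```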

In [None]:
predictors[-1].delete_predictor()

In [None]:
mistral_predictor.delete_predictor()
llama_predictor.delete_predictor()
llama_predictor.delete_endpoint()