# SageMaker JumpStart Foundation Models - Inference Latency and Throughput Benchmarking

***
Welcome to Amazon [SageMaker JumpStart](https://docs.aws.amazon.com/sagemaker/latest/dg/studio-jumpstart.html)! You can use SageMaker JumpStart to solve many machine learning tasks with one click in SageMaker Studio or through the [SageMaker Python SDK](https://sagemaker.readthedocs.io/en/stable/overview.html#use-prebuilt-models-with-sagemaker-jumpstart).


In this demo notebook, we demonstrate how to run latency and throughput benchmarking analyses on a set of SageMaker JumpStart models. The structure of the notebook allows you to both benchmark a single model against multiple payloads and multiple models against a single payload. 

***

1. [Set up](#1.-Set-up)
2. [Run latency and throughput benchmarking](#2.-Run-latency-and-throughput-benchmarking)
3. [Visualize benchmarking results](#3.-Visualize-benchmarking-results)
4. [Clean up](#4.-Clean-up)

### 1. Set up

***
Before running the notebook, complete the following setup steps.
***

In [None]:
%pip install --upgrade sagemaker ipywidgets --quiet

***
Here, you query the SageMaker Python SDK for a list of all Hugging Face text generation (and text2text) models hosted by SageMaker Model Hub. Use the Jupyter widget in the output of this cell to select any combination of these models to benchmark. By default, only a few models are selected.
***

In [None]:
from ipywidgets import SelectMultiple, Layout
from sagemaker.jumpstart.notebook_utils import list_jumpstart_models
from sagemaker.jumpstart.filters import And, Or

# Retrieve all Hugging Face text generation and text2text models available through SageMaker Built-In Algorithms.
tasks = ["textgeneration", "textgeneration1", "textgeneration2", "text2text"]
filter_value = And(Or(*[f"task == {task}" for task in tasks]), "framework == huggingface")
text_models = list_jumpstart_models(filter=filter_value)
selected_text_models = [
    "huggingface-text2text-flan-t5-xxl",
    "huggingface-textgeneration1-gpt-j-6b",
    "huggingface-textgeneration2-gpt-neoxt-chat-base-20b-fp16",
    "huggingface-textgeneration-bloom-1b7",
]
# if you would like to run on all JumpStart LLMs instead, uncomment the following line.
# selected_text_models = text_models.copy()
# selected_text_models.remove("huggingface-textgeneration1-bloom-176b-int8")
# selected_text_models.remove("huggingface-textgeneration1-bloomz-176b-fp16")

models_selection = SelectMultiple(
    options=text_models,
    value=selected_text_models,
    description="Models:",
    rows=25,
    layout=Layout(width="100%"),
)
display(models_selection)
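
If the widget list is long, the default selection can also be narrowed programmatically before it is passed to `value`. The following is a minimal sketch, assuming the model list is a plain list of model ID strings. `select_models` is a hypothetical helper, and `example_text_models` is a hard-coded stand-in for the list returned by `list_jumpstart_models`.

```python
from typing import Iterable, List


def select_models(
    model_ids: Iterable[str], include: Iterable[str], exclude: Iterable[str] = ()
) -> List[str]:
    """Keep IDs that contain any `include` substring and none of the `exclude` substrings."""
    include, exclude = list(include), list(exclude)
    return [
        model_id
        for model_id in model_ids
        if any(token in model_id for token in include)
        and not any(token in model_id for token in exclude)
    ]


# Stand-in for the SDK result (IDs taken from this notebook).
example_text_models = [
    "huggingface-text2text-flan-t5-xxl",
    "huggingface-textgeneration1-gpt-j-6b",
    "huggingface-textgeneration1-bloom-176b-int8",
    "huggingface-textgeneration-bloom-1b7",
]

# Select BLOOM variants but skip the 176B checkpoints.
small_bloom = select_models(example_text_models, include=["bloom"], exclude=["176b"])
print(small_bloom)
```

The resulting list could then be used as the `value` argument of `SelectMultiple`, provided every entry also appears in `options`.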

***
In the following cell, you will select the models and payloads to benchmark. Every payload will be benchmarked against every model.
- **MODELS**: A list of SageMaker JumpStart model IDs to run benchmarking against.
- **PAYLOADS**: A dictionary with keys identifying a unique name for a query payload and values containing a valid payload dictionary.
***

In [None]:
MODELS = models_selection.value

PAYLOADS = {
    "simple_short_input": {
        "text_inputs": "Hello!",
        "do_sample": True,
    },
    "generate_summary": {
        "text_inputs": (
            "Write a short summary for this text: Amazon Comprehend uses natural language "
            "processing (NLP) to extract insights about the content of documents. It develops "
            "insights by recognizing the entities, key phrases, language, sentiments, and other "
            "common elements in a document. Use Amazon Comprehend to create new products based on "
            "understanding the structure of documents. For example, using Amazon Comprehend you "
            "can search social networking feeds for mentions of products or scan an entire "
            "document repository for key phrases. \nYou can access Amazon Comprehend document "
            "analysis capabilities using the Amazon Comprehend console or using the Amazon "
            "Comprehend APIs. You can run real-time analysis for small workloads or you can start "
            "asynchronous analysis jobs for large document sets. You can use the pre-trained "
            "models that Amazon Comprehend provides, or you can train your own custom models for "
            "classification and entity recognition. \nAll of the Amazon Comprehend features "
            "accept UTF-8 text documents as the input. In addition, custom classification and "
            "custom entity recognition accept image files, PDF files, and Word files as input. \n"
            "Amazon Comprehend can examine and analyze documents in a variety of languages, "
            "depending on the specific feature. For more information, see Languages supported in "
            "Amazon Comprehend. Amazon Comprehend's Dominant language capability can examine "
            "documents and determine the dominant language for a far wider selection of languages."
        ),
        "do_sample": True,
        "max_length": 500,
    },
}
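
Because every payload is benchmarked against every model, the number of load tests grows as the product of the two collections. The sketch below enumerates the pairs and checks that each payload can be serialized to JSON before any endpoint is deployed. The names `example_models` and `example_payloads` are illustrative stand-ins for `MODELS` and `PAYLOADS`, and the `capped_length` payload is hypothetical.

```python
import json
from itertools import product

example_models = [
    "huggingface-text2text-flan-t5-xxl",
    "huggingface-textgeneration-bloom-1b7",
]
example_payloads = {
    "simple_short_input": {"text_inputs": "Hello!", "do_sample": True},
    "capped_length": {"text_inputs": "Hello!", "max_length": 50},  # hypothetical payload
}

# Fail early if a payload contains a value that cannot be sent as JSON.
for payload_name, payload in example_payloads.items():
    json.dumps(payload)

# One (model, payload) pair per benchmarking load test.
benchmark_pairs = list(product(example_models, example_payloads))
print(f"{len(benchmark_pairs)} load tests:")
for model_id, payload_name in benchmark_pairs:
    print(f"  {model_id} <- {payload_name}")
```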

***
The following set of constants drive the behavior of this notebook:
- **MAX_CONCURRENT_INVOCATIONS_PER_MODEL**: The maximum number of endpoint predictions to request concurrently.
- **MAX_CONCURRENT_BENCHMARKS**: The maximum number of models to concurrently benchmark.
- **RETRY_WAIT_TIME_SECONDS**: The amount of time in seconds to wait between Amazon CloudWatch queries. This is necessary because the endpoint emits CloudWatch metrics on a periodic interval, so we need to wait until all samples are emitted to CloudWatch before publishing benchmarking statistics.
- **MAX_TOTAL_RETRY_TIME_SECONDS**: The maximum amount of time in seconds to wait on Amazon CloudWatch emissions before proceeding without collecting the requested benchmarking metrics.
- **NUM_INVOCATIONS**: The number of endpoint predictions to request per benchmark.
- **SAVE_METRICS_FILE_PATH**: The JSON file used to save the resulting metrics.
- **SM_SESSION**: SageMaker Session object with custom configuration to resolve [SDK rate exceeded and throttling exceptions](https://aws.amazon.com/premiumsupport/knowledge-center/sagemaker-python-throttlingexception/).
***

In [None]:
from pathlib import Path

import boto3
from botocore.config import Config
from sagemaker.session import Session


MAX_CONCURRENT_INVOCATIONS_PER_MODEL = 30
MAX_CONCURRENT_BENCHMARKS = 50
RETRY_WAIT_TIME_SECONDS = 30.0
MAX_TOTAL_RETRY_TIME_SECONDS = 120.0
NUM_INVOCATIONS = 10
SAVE_METRICS_FILE_PATH = Path.cwd() / "latency_benchmarking.json"
SM_SESSION = Session(
    sagemaker_client=boto3.client(
        "sagemaker",
        config=Config(connect_timeout=5, read_timeout=60, retries={"max_attempts": 20}),
    )
)
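
The two retry constants describe a polling loop: query, wait `RETRY_WAIT_TIME_SECONDS`, and give up once `MAX_TOTAL_RETRY_TIME_SECONDS` would be exceeded. The sketch below shows one way such a loop could look. It is not the implementation in `benchmarking.load_test`, which is not shown in this notebook. `wait_for_sample_count` and the fake metric source are hypothetical.

```python
import time
from typing import Callable, Optional


def wait_for_sample_count(
    get_sample_count: Callable[[], int],
    expected_count: int,
    retry_wait_time: float,
    max_total_retry_time: float,
) -> Optional[int]:
    """Poll until `expected_count` samples are visible or the time budget is spent."""
    waited = 0.0
    while True:
        count = get_sample_count()
        if count >= expected_count:
            return count
        if waited + retry_wait_time > max_total_retry_time:
            return None  # Give up; the caller proceeds without these metrics.
        time.sleep(retry_wait_time)
        waited += retry_wait_time


# Fake metric source: reports 4, then 7, then 10 samples on successive queries.
fake_counts = iter([4, 7, 10])
result = wait_for_sample_count(
    lambda: next(fake_counts),
    expected_count=10,
    retry_wait_time=0.01,  # Short waits so the example runs quickly.
    max_total_retry_time=0.05,
)
print(result)
```

Under this sketch, the notebook's values of 30 and 120 seconds would allow at most five queries (at 0, 30, 60, 90, and 120 seconds) before benchmarking continues without the CloudWatch statistics.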

### 2. Run latency and throughput benchmarking

***

The following block defines a function to run benchmarking on a single SageMaker JumpStart model ID. This function performs the following actions:
- Create a SageMaker JumpStart `Model` object.
- Deploy the Model and obtain a `Predictor`.
- Run all benchmarking load tests for each payload defined in the `PAYLOADS` dictionary. The benchmarking process includes:
  - Obtain latency statistics - serially invoke an endpoint to obtain a batch of predictions and utilize the Amazon CloudWatch [GetMetricStatistics](https://docs.aws.amazon.com/AmazonCloudWatch/latest/APIReference/API_GetMetricStatistics.html) API to obtain latency statistics regarding the batch of predictions. The endpoint is invoked `NUM_INVOCATIONS` times.
  - Obtain throughput statistics - concurrently invoke an endpoint to obtain client-side throughput statistics. The endpoint is invoked `NUM_INVOCATIONS` times.
- Clean up the model and endpoint. If an error occurs while the benchmarking load tests run for a given model, this cleanup still happens before the error is raised.

***
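
To make the CloudWatch step concrete, the sketch below builds the keyword arguments for a `GetMetricStatistics` request on the `ModelLatency` metric. It only illustrates the request shape and is not the code in `benchmarking.load_test`. Assumptions: `build_latency_query` is a hypothetical helper, the endpoint name is a placeholder, and the variant name `AllTraffic` is assumed to be the default for endpoints deployed through the SDK. The API reference describes `Statistics` and `ExtendedStatistics` as mutually exclusive in a single call, so percentiles are requested separately here.

```python
from datetime import datetime, timedelta, timezone
from typing import Any, Dict


def build_latency_query(
    endpoint_name: str, start: datetime, end: datetime, percentiles: bool = False
) -> Dict[str, Any]:
    """Build keyword arguments for CloudWatch GetMetricStatistics on ModelLatency."""
    query: Dict[str, Any] = {
        "Namespace": "AWS/SageMaker",
        "MetricName": "ModelLatency",
        "Dimensions": [
            {"Name": "EndpointName", "Value": endpoint_name},
            {"Name": "VariantName", "Value": "AllTraffic"},  # Assumed default variant name.
        ],
        "StartTime": start,
        "EndTime": end,
        "Period": 60,  # Seconds; must be a multiple of 60 for standard-resolution metrics.
    }
    if percentiles:
        query["ExtendedStatistics"] = ["p50", "p90", "p95"]
    else:
        query["Statistics"] = ["SampleCount", "Average", "Minimum", "Maximum"]
    return query


end_time = datetime.now(timezone.utc)
start_time = end_time - timedelta(minutes=10)
basic_query = build_latency_query("my-benchmark-endpoint", start_time, end_time)
percentile_query = build_latency_query(
    "my-benchmark-endpoint", start_time, end_time, percentiles=True
)

# With AWS credentials configured, each query could be sent as:
#   boto3.client("cloudwatch").get_metric_statistics(**basic_query)
print(basic_query["Statistics"], percentile_query["ExtendedStatistics"])
```

CloudWatch documents `ModelLatency` in microseconds, so a conversion is needed if latencies are to be reported in seconds.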

In [None]:
from typing import Any, Dict, List

from sagemaker.jumpstart.model import JumpStartModel
from sagemaker.serializers import JSONSerializer
from sagemaker.utils import name_from_base

from benchmarking.load_test import run_benchmarking_load_tests
from benchmarking.load_test import logging_prefix


def run_benchmarking(model_id: str) -> List[Dict[str, Any]]:
    model = JumpStartModel(model_id=model_id, sagemaker_session=SM_SESSION)

    endpoint_name = name_from_base(f"jumpstart-bm-{model_id.replace('huggingface', 'hf')}")

    print(f"{logging_prefix(model_id)} Deploying endpoint {endpoint_name} ...")
    predictor = model.deploy(endpoint_name=endpoint_name)
    predictor.serializer = JSONSerializer()
    predictor.content_type = "application/json"

    metrics = []
    try:
        for payload_name, payload in PAYLOADS.items():
            metrics_payload = run_benchmarking_load_tests(
                predictor=predictor,
                payload=payload,
                model_id=model_id,
                payload_name=payload_name,
                num_invocations=NUM_INVOCATIONS,
                max_workers=MAX_CONCURRENT_INVOCATIONS_PER_MODEL,
                retry_wait_time=RETRY_WAIT_TIME_SECONDS,
                max_total_retry_time=MAX_TOTAL_RETRY_TIME_SECONDS,
            )
            metrics.append(metrics_payload)
    finally:
        print(f"{logging_prefix(model_id)} Cleaning up resources ...")
        predictor.delete_model()
        predictor.delete_endpoint()

    return metrics
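
The difference between the serial latency test and the concurrent throughput test can be illustrated without an endpoint. The sketch below is a simplified, client-side-only illustration and not the logic of `run_benchmarking_load_tests`. `fake_invoke` is a hypothetical stand-in for `predictor.predict` that sleeps to simulate inference time.

```python
import time
from concurrent import futures
from statistics import mean


def fake_invoke(payload: dict) -> str:
    """Hypothetical stand-in for `predictor.predict`; sleeps to simulate inference."""
    time.sleep(0.02)
    return "generated text"


def measure_serial_latency(invoke, payload: dict, num_invocations: int) -> float:
    """Invoke one request at a time and return the mean latency in seconds."""
    latencies = []
    for _ in range(num_invocations):
        start = time.perf_counter()
        invoke(payload)
        latencies.append(time.perf_counter() - start)
    return mean(latencies)


def measure_throughput(invoke, payload: dict, num_invocations: int, max_workers: int) -> float:
    """Invoke requests concurrently and return completed requests per second."""
    start = time.perf_counter()
    with futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        results = list(executor.map(invoke, [payload] * num_invocations))
    elapsed = time.perf_counter() - start
    return len(results) / elapsed


example_payload = {"text_inputs": "Hello!"}
avg_latency = measure_serial_latency(fake_invoke, example_payload, num_invocations=10)
throughput = measure_throughput(fake_invoke, example_payload, num_invocations=10, max_workers=5)
print(f"Average latency: {avg_latency:.3f} s, throughput: {throughput:.1f} requests/s")
```

With a real endpoint, concurrent requests compete for the same accelerator, so throughput typically scales less than the thread count suggests.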

***
In the following block, the `run_benchmarking` function is called for every model ID in the previously defined `MODELS` list. To avoid deploying models one after another, the Python standard library [concurrent futures](https://docs.python.org/3/library/concurrent.futures.html) module runs up to `MAX_CONCURRENT_BENCHMARKS` benchmarks concurrently, each in its own executor thread. When a thread completes, its metrics are appended to a single list. If a thread raises an error instead of returning metrics, the error is recorded in a dictionary and is not re-raised, which allows benchmarking to continue for all other models.
***

In [None]:
from concurrent import futures


metrics = []
benchmarking_error_dict = {}
with futures.ThreadPoolExecutor(max_workers=MAX_CONCURRENT_BENCHMARKS) as executor:
    future_to_model_id = {
        executor.submit(run_benchmarking, model_id): model_id for model_id in MODELS
    }
    for future in futures.as_completed(future_to_model_id):
        model_id = future_to_model_id[future]
        try:
            metrics.extend(future.result())
        except Exception as e:
            benchmarking_error_dict[model_id] = e
            print(f"(Model {model_id}) Benchmarking failed: {e}")

***
Finally, we save these benchmarked metrics to a JSON file for use in downstream analyses.
***

In [None]:
import json


output = {"models": MODELS, "payloads": PAYLOADS, "metrics": metrics}
with open(SAVE_METRICS_FILE_PATH, "w", encoding="utf-8") as file:  # ensure_ascii=False may write non-ASCII text
    json.dump(output, file, indent=4, ensure_ascii=False)
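
For downstream analyses, the saved file can be read back and its nested metrics flattened into dotted column names such as `ModelLatency.Average`. The sketch below shows the idea with the standard library only. The nested record shape is an assumption inferred from the column names reported later in this notebook, `flatten` is a hypothetical helper that mimics the basic behavior of `pandas.json_normalize` for nested dictionaries, and a temporary file stands in for `SAVE_METRICS_FILE_PATH`.

```python
import json
import tempfile
from pathlib import Path
from typing import Any, Dict


def flatten(record: Dict[str, Any], parent_key: str = "", sep: str = ".") -> Dict[str, Any]:
    """Flatten nested dictionaries into a single level with dotted keys."""
    flat: Dict[str, Any] = {}
    for key, value in record.items():
        full_key = f"{parent_key}{sep}{key}" if parent_key else key
        if isinstance(value, dict):
            flat.update(flatten(value, full_key, sep))
        else:
            flat[full_key] = value
    return flat


# Assumed record shape; the values are copied from one row of the results table.
example_output = {
    "metrics": [
        {
            "ModelID": "huggingface-textgeneration1-gpt-neo-125m",
            "PayloadName": "generate_summary",
            "Throughput": 1.823,
            "ModelLatency": {"Average": 0.473},
            "Client": {"Latency": {"Average": 0.535}},
        }
    ]
}

with tempfile.TemporaryDirectory() as tmp_dir:
    path = Path(tmp_dir) / "latency_benchmarking.json"
    with open(path, "w", encoding="utf-8") as file:
        json.dump(example_output, file, indent=4, ensure_ascii=False)
    with open(path, encoding="utf-8") as file:
        reloaded = json.load(file)

flat_rows = [flatten(record) for record in reloaded["metrics"]]
print(sorted(flat_rows[0]))
```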

### 3. Visualize benchmarking results

***
The benchmarking metrics are now loaded into a normalized pandas DataFrame for visualization. This cell shows the following:
1. The column names of the DataFrame. These are the statistics available for you to explore.
2. A table that shows a sample output from each model ID in `MODELS` for each payload in `PAYLOADS`.
3. A table that shows key latency and throughput statistics for each model ID in `MODELS` and each payload in `PAYLOADS`.
***

In [None]:
import pandas as pd


pd.set_option("display.max_colwidth", 0)
pd.set_option("display.max_rows", 500)

df = pd.json_normalize(metrics)
print("Here are the available statistics: ", list(df.columns))

index_cols = ["PayloadName", "ModelID"]
display_cols = ["PayloadName", "ModelID", "SampleOutput"]
sort_cols = ["PayloadName"]
display(df[display_cols].sort_values(by=sort_cols).set_index(index_cols))

display_cols = [
    "PayloadName",
    "ModelID",
    "Throughput",
    "ModelLatency.Average",
    "Client.Latency.Average",
    "Client.OutputSequenceWords.Average",
    "WordThroughput",
    "Client.LatencyPerOutputWord.Average",
]
sort_cols = ["PayloadName", "Client.LatencyPerOutputWord.Average"]
display(df[display_cols].sort_values(by=sort_cols).set_index(index_cols).round(3))

Here are the available statistics:  ['Throughput', 'WordThroughput', 'ModelID', 'PayloadName', 'SampleOutput', 'ModelLatency.SampleCount', 'ModelLatency.Average', 'ModelLatency.Minimum', 'ModelLatency.Maximum', 'ModelLatency.p50', 'ModelLatency.p90', 'ModelLatency.p95', 'OverheadLatency.SampleCount', 'OverheadLatency.Average', 'OverheadLatency.Minimum', 'OverheadLatency.Maximum', 'OverheadLatency.p50', 'OverheadLatency.p90', 'OverheadLatency.p95', 'Client.InputSequenceWords.Average', 'Client.InputSequenceWords.Minimum', 'Client.InputSequenceWords.Maximum', 'Client.InputSequenceWords.p50', 'Client.InputSequenceWords.p90', 'Client.InputSequenceWords.p95', 'Client.OutputSequenceWords.Average', 'Client.OutputSequenceWords.Minimum', 'Client.OutputSequenceWords.Maximum', 'Client.OutputSequenceWords.p50', 'Client.OutputSequenceWords.p90', 'Client.OutputSequenceWords.p95', 'Client.Latency.Average', 'Client.Latency.Minimum', 'Client.Latency.Maximum', 'Client.Latency.p50', 'Client.Latency.p90', 

| PayloadName | ModelID | SampleOutput |
|---|---|---|
| generate_summary | huggingface-text2text-flan-t5-xxl-fp16 | Use Amazon Comprehend to analyze text documents in a variety of languages. |
| generate_summary | huggingface-text2text-bigscience-t0pp-bnb-int8 | Documents in UTF-8, PDF files, Word files, and image files |
| generate_summary | huggingface-textgeneration-bloom-1b7 | When you choose Amazon Comprehend to analyze your documents, you can change Amazon Comprehend to other languages if you don't want your document analyzed using a dominant language.\n\nQuestions related to the above: \n\nIf a user has created Amazon Comprehend's document repository, the document repository can be used to run analysis jobs with Amazon Comprehend. If not, how can the documents in the document repository used to run Amazon Comprehend analysis?\nCan I run a dataset on my document repository using Amazon Comprehend? How and where I should start? I know I can import any document in the repository to run analysis, but I didn't get the training dataset, or any samples of the document in the repository that could be used to implement the algorithm.\nCan I start a new analysis job with a document repository already having input documents from a search? What should I do to include the documents (or subsets of them) I already have in the analysis job?\n\nPlease let me know if I'm missing some important details. Thanks!\n\nA:\n\nAmazon Comprehend documents are stored as base64 encoded byte arrays. If you convert input stream to bytes, then you can use that as the input stream for the analysis. However, if you have to convert |
| generate_summary | huggingface-text2text-bigscience-t0pp | Use Amazon Comprehend to scan social networking feeds for mentions of specific products using Amazon Comprehend Spotlight. |
| generate_summary | huggingface-text2text-flan-t5-xxl | Finds common elements from a document by performing Natural Language Processing (NLP) of textual content. |
| generate_summary | huggingface-text2text-bigscience-t0pp-fp16 | Documents and other media (images, PDFs, Word files) Sentiment analysis |
| generate_summary | huggingface-text2text-flan-t5-xxl-bnb-int8 | Understand the structure of documents using Amazon Comprehend. |
| generate_summary | huggingface-text2text-flan-ul2-bf16 | Describes the architecture of Amazon Comprephend. |
| generate_summary | huggingface-textgeneration1-gpt-neo-125m | \n\nProperties\n\nAmazon Comprehends the Language in the Document in the Document:\n[https://aws.amazon.com/comprehend/](https://aws.amazon.com/comprehend/)\nAmazon Comprehend's Language\n[https://developer.amazon.com/comprehends/latest/feature/compularnglbl/?hl=en-US,pct=en-GB&ioc='amazon-comprehend']\n\nA:\n\nI use the Amazon Comprehend API in a couple of situations in my work as a project manager for Amazon Ecosystem team.\nThey have a collection called Resources for your project, which you can interact with as you work. Also you can create a Resource, which contains your user-defined resources. Here is a link for a related project.\nNow, these are the APIs that you can use to make your project's resources work. In case you are using a library that provides a library that requires some data, they can use it to provide the access that you need to your library.\n\n |
| generate_summary | huggingface-textgeneration1-bloom-7b1 | |


| PayloadName | ModelID | Throughput | ModelLatency.Average | Client.Latency.Average | Client.OutputSequenceWords.Average | WordThroughput | Client.LatencyPerOutputWord.Average |
|---|---|---|---|---|---|---|---|
| generate_summary | huggingface-textgeneration1-gpt-neo-125m | 1.823 | 0.473 | 0.535 | 109.7 | 199.965 | 0.006 |
| generate_summary | huggingface-textgeneration1-gpt-neo-125m-fp16 | 2.062 | 0.416 | 0.474 | 102.1 | 210.496 | 0.006 |
| generate_summary | huggingface-textgeneration1-gpt-neo-1-3b-fp16 | 0.554 | 1.669 | 1.726 | 147.2 | 81.576 | 0.012 |
| generate_summary | huggingface-text2text-qcpg-sentences | 0.801 | 0.886 | 0.948 | 62.2 | 49.813 | 0.015 |
| generate_summary | huggingface-textgeneration1-gpt-neo-2-7b-fp16 | 0.459 | 2.667 | 2.73 | 173.2 | 79.54 | 0.016 |
| generate_summary | huggingface-textgeneration1-gpt-2-xl-fp16 | 0.414 | 1.729 | 1.787 | 128.8 | 53.285 | 0.018 |
| generate_summary | huggingface-textgeneration1-gpt-neo-1-3b | 0.468 | 2.754 | 2.811 | 155.6 | 72.866 | 0.018 |
| generate_summary | huggingface-text2text-flan-t5-small | 5.224 | 0.139 | 0.197 | 10.1 | 52.759 | 0.02 |
| generate_summary | huggingface-text2text-bart4csc-base-chinese | 0.822 | 1.324 | 1.383 | 74.7 | 61.433 | 0.021 |
| generate_summary | huggingface-textgeneration-gpt2 | 0.373 | 3.488 | 3.548 | 174.2 | 64.945 | 0.021 |

***
Finally, we plot the results of this latency analysis. For each payload, this cell creates a Plotly figure of the average latency per output word versus word throughput, that is, the number of output words returned per second by the model. When requests are served strictly one at a time, throughput = 1 / latency, which is the reference curve y = 1/x drawn in each figure. However, multi-model endpoints and load-balanced endpoints can improve throughput for a fixed latency. Both latency and throughput are important metrics to consider when designing requirements for model selection.
***

In [None]:
import plotly.express as px
import plotly.graph_objects as go
import numpy as np


for payload_name in PAYLOADS:
    col_x, col_y = "WordThroughput", "Client.LatencyPerOutputWord.Average"
    df_plot = df[df["PayloadName"] == payload_name]
    fig = px.scatter(df_plot, x=col_x, y=col_y, hover_data=["ModelID"])
    fig.add_trace(
        go.Scatter(x=np.linspace(1, 300, 300), y=1 / np.linspace(1, 300, 300), name="y=1/x")
    )
    fig.update_layout(
        xaxis_range=[0.0, df_plot[col_x].max() * 1.1],
        yaxis_range=[0.0, df_plot[col_y].max() * 1.1],
        title=f"Latency per word vs. word throughput for payload {payload_name}",
    )
    fig.show()
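
The relationship between the plotted points and the y = 1/x reference curve can also be checked numerically. The sketch below multiplies the measured word throughput by the latency per output word for three rows of the results table above. A ratio above 1 means the point lies above the curve, a ratio below 1 means it lies below. The variable names are illustrative, and because the table values are rounded to three decimals, the ratios are only rough estimates.

```python
# (ModelID, WordThroughput, Client.LatencyPerOutputWord.Average) from the results table.
rows = [
    ("huggingface-textgeneration1-gpt-neo-125m", 199.965, 0.006),
    ("huggingface-text2text-qcpg-sentences", 49.813, 0.015),
    ("huggingface-text2text-flan-t5-small", 52.759, 0.02),
]

ratios = {}
for model_id, word_throughput, latency_per_word in rows:
    # Word throughput implied by serving one request at a time.
    serial_estimate = 1.0 / latency_per_word
    ratios[model_id] = word_throughput / serial_estimate
    print(f"{model_id}: serial estimate {serial_estimate:.1f} words/s, ratio {ratios[model_id]:.2f}")
```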

### 4. Clean up

***
When you are done with the endpoints, you should delete them to avoid additional costs. In this demonstration, clean up occurs at the end of each individual benchmarking job.
***