# SageMaker JumpStart Foundation Models - Inference Latency and Throughput Benchmarking

***
Welcome to Amazon [SageMaker JumpStart](https://docs.aws.amazon.com/sagemaker/latest/dg/studio-jumpstart.html)! You can use SageMaker JumpStart to solve many machine learning tasks with one click in SageMaker Studio, or through the [SageMaker Python SDK](https://sagemaker.readthedocs.io/en/stable/overview.html#use-prebuilt-models-with-sagemaker-jumpstart).


In this demo notebook, we demonstrate how to run latency and throughput benchmarks on a set of SageMaker JumpStart models. Every payload is benchmarked against every model, so you can benchmark a single model against multiple payloads, multiple models against a single payload, or both at once.

***

1. [Set up](#1.-Set-up)
2. [Run latency and throughput benchmarking](#2.-Run-latency-and-throughput-benchmarking)
3. [Visualize benchmarking results](#3.-Visualize-benchmarking-results)
4. [Clean up](#4.-Clean-up)

### 1. Set up

***
Before running the benchmarks, complete the following setup steps.
***

In [None]:
%pip install --upgrade sagemaker ipywidgets --quiet

***
Here, you query the SageMaker Python SDK for a list of all Hugging Face text generation (and text2text generation) models available in SageMaker JumpStart. Use the Jupyter widget in this cell's output to select any combination of these models to benchmark. A few models are selected by default.
***

In [None]:
from IPython.display import display
from ipywidgets import SelectMultiple, Layout
from sagemaker.jumpstart.notebook_utils import list_jumpstart_models
from sagemaker.jumpstart.filters import And, Or

# Retrieve all Hugging Face text generation and text2text generation models available in SageMaker JumpStart.
tasks = ["textgeneration", "textgeneration1", "textgeneration2", "text2text"]
filter_value = And(Or(*[f"task == {task}" for task in tasks]), "framework == huggingface")
text_models = list_jumpstart_models(filter=filter_value)
selected_text_models = [
    "huggingface-text2text-flan-t5-xxl",
    "huggingface-textgeneration1-gpt-j-6b",
    "huggingface-textgeneration2-gpt-neoxt-chat-base-20b-fp16",
    "huggingface-textgeneration-bloom-1b7",
]
# To run on all JumpStart LLMs instead, uncomment the following lines.
# selected_text_models = text_models.copy()
# selected_text_models.remove("huggingface-textgeneration1-bloom-176b-int8")
# selected_text_models.remove("huggingface-textgeneration1-bloomz-176b-fp16")

# SelectMultiple rejects default values that are not among its options, so drop any
# default model ID that is missing from the current JumpStart listing.
selected_text_models = [m for m in selected_text_models if m in text_models]

models_selection = SelectMultiple(
    options=text_models,
    value=selected_text_models,
    description="Models:",
    rows=25,
    layout=Layout(width="100%"),
)
display(models_selection)

***
In the following cell, you will select the models and payloads to benchmark. Every payload will be benchmarked against every model.
- **MODELS**: A list of SageMaker JumpStart model IDs to run benchmarking against.
- **PAYLOADS**: A dictionary that maps a unique payload name to a valid request payload dictionary.
***

In [None]:
MODELS = models_selection.value

PAYLOADS = {
    "simple_short_input": {
        "text_inputs": "Hello!",
        "do_sample": True,
    },
    "generate_summary": {
        "text_inputs": (
            "Write a short summary for this text: Amazon Comprehend uses natural language "
            "processing (NLP) to extract insights about the content of documents. It develops "
            "insights by recognizing the entities, key phrases, language, sentiments, and other "
            "common elements in a document. Use Amazon Comprehend to create new products based on "
            "understanding the structure of documents. For example, using Amazon Comprehend you "
            "can search social networking feeds for mentions of products or scan an entire "
            "document repository for key phrases. \nYou can access Amazon Comprehend document "
            "analysis capabilities using the Amazon Comprehend console or using the Amazon "
            "Comprehend APIs. You can run real-time analysis for small workloads or you can start "
            "asynchronous analysis jobs for large document sets. You can use the pre-trained "
            "models that Amazon Comprehend provides, or you can train your own custom models for "
            "classification and entity recognition. \nAll of the Amazon Comprehend features "
            "accept UTF-8 text documents as the input. In addition, custom classification and "
            "custom entity recognition accept image files, PDF files, and Word files as input. \n"
            "Amazon Comprehend can examine and analyze documents in a variety of languages, "
            "depending on the specific feature. For more information, see Languages supported in "
            "Amazon Comprehend. Amazon Comprehend's Dominant language capability can examine "
            "documents and determine the dominant language for a far wider selection of languages."
        ),
        "do_sample": True,
        "max_length": 500,
    },
}
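Because every payload is benchmarked against every model, the number of load tests is `len(MODELS) × len(PAYLOADS)`, and each one runs both a serial latency pass and a concurrent throughput pass of `NUM_INVOCATIONS` requests. The sketch below counts that work with `itertools.product`; the model IDs `model-a` and `model-b` and the shortened payloads are hypothetical placeholders, not values from this notebook.

```python
from itertools import product

# Hypothetical stand-ins for the MODELS tuple and PAYLOADS dict defined above.
models = ("model-a", "model-b")
payloads = {
    "simple_short_input": {"text_inputs": "Hello!"},
    "generate_summary": {"text_inputs": "Write a short summary for this text: <document>"},
}
num_invocations = 10  # mirrors NUM_INVOCATIONS

# One load test per (model, payload) pair; iterating a dict yields its keys.
benchmark_matrix = list(product(models, payloads))

# Each load test runs a serial latency pass and a concurrent throughput pass,
# each issuing num_invocations requests.
total_requests = len(benchmark_matrix) * 2 * num_invocations
print(len(benchmark_matrix), total_requests)
```

Adding one model adds one endpoint deployment plus `2 × NUM_INVOCATIONS × len(PAYLOADS)` requests, which is worth keeping in mind before selecting many large models.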

***
The following constants control the behavior of this notebook:
- **MAX_CONCURRENT_INVOCATIONS_PER_MODEL**: The maximum number of endpoint predictions to request concurrently.
- **MAX_CONCURRENT_BENCHMARKS**: The maximum number of models to concurrently benchmark.
- **RETRY_WAIT_TIME_SECONDS**: The amount of time in seconds to wait between Amazon CloudWatch queries. This is necessary because the endpoint emits CloudWatch metrics on a periodic interval, so we need to wait until all samples are emitted to CloudWatch before publishing benchmarking statistics.
- **MAX_TOTAL_RETRY_TIME_SECONDS**: The maximum amount of time in seconds to wait on Amazon CloudWatch emissions before proceeding without collecting the requested benchmarking metrics.
- **NUM_INVOCATIONS**: The number of endpoint predictions to request per benchmark.
- **SAVE_METRICS_FILE_PATH**: The JSON file used to save the resulting metrics.
- **SM_SESSION**: SageMaker Session object with custom configuration to resolve [SDK rate exceeded and throttling exceptions](https://aws.amazon.com/premiumsupport/knowledge-center/sagemaker-python-throttlingexception/).
***

In [None]:
from pathlib import Path

import boto3
from botocore.config import Config
from sagemaker.session import Session


MAX_CONCURRENT_INVOCATIONS_PER_MODEL = 30
MAX_CONCURRENT_BENCHMARKS = 50
RETRY_WAIT_TIME_SECONDS = 30.0
MAX_TOTAL_RETRY_TIME_SECONDS = 120.0
NUM_INVOCATIONS = 10
SAVE_METRICS_FILE_PATH = Path.cwd() / "latency_benchmarking.json"
SM_SESSION = Session(
    sagemaker_client=boto3.client(
        "sagemaker",
        config=Config(connect_timeout=5, read_timeout=60, retries={"max_attempts": 20}),
    )
)
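`RETRY_WAIT_TIME_SECONDS` and `MAX_TOTAL_RETRY_TIME_SECONDS` describe a poll-with-deadline loop: query CloudWatch, and if the samples are not there yet, wait and try again until the total budget is spent. The helper below, `poll_until`, is a hypothetical sketch of that pattern and not the code in `benchmarking.load_test`; `fetch` is assumed to return `None` while metrics are still missing. The `sleep` argument is injectable so the loop can be exercised without real waiting.

```python
import time
from typing import Callable, Optional, TypeVar

T = TypeVar("T")


def poll_until(
    fetch: Callable[[], Optional[T]],
    retry_wait_time: float,
    max_total_retry_time: float,
    sleep: Callable[[float], None] = time.sleep,
) -> Optional[T]:
    """Call `fetch` until it returns a non-None value or the wait budget is spent.

    Returns None when the budget runs out, i.e. "proceed without the metrics".
    """
    waited = 0.0
    while True:
        result = fetch()
        if result is not None:
            return result
        # Stop if another wait would exceed the total retry budget.
        if waited + retry_wait_time > max_total_retry_time:
            return None
        sleep(retry_wait_time)
        waited += retry_wait_time
```

With the notebook's values (30 s wait, 120 s budget), `fetch` is attempted at most five times: once immediately and after each of four waits.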

### 2. Run latency and throughput benchmarking

***

The following block defines a function to run benchmarking on a single SageMaker JumpStart model ID. This function performs the following actions:
- Create a SageMaker JumpStart `Model` object.
- Deploy the Model and obtain a `Predictor`.
- Run all benchmarking load tests for each payload defined in the `PAYLOADS` dictionary. The benchmarking process includes:
  - Obtain latency statistics - serially invoke the endpoint `NUM_INVOCATIONS` times, then use the Amazon CloudWatch [GetMetricStatistics](https://docs.aws.amazon.com/AmazonCloudWatch/latest/APIReference/API_GetMetricStatistics.html) API to retrieve latency statistics for that batch of predictions.
  - Obtain throughput statistics - concurrently invoke an endpoint to obtain client-side throughput statistics. The endpoint is invoked `NUM_INVOCATIONS` times.
- Delete the predictor's model and endpoint. If an error occurs while benchmarking a model, this clean up still runs before the error is raised.

***
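The two load-test phases measure different things: a serial pass records how long each individual request takes, while a concurrent pass measures how many requests complete per second when several are in flight. The functions below are a hypothetical sketch of that distinction, not the implementation of `run_benchmarking_load_tests`; in the notebook, `invoke` would wrap something like `predictor.predict(payload)`, and here a 50 ms `time.sleep` stands in for the endpoint.

```python
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, List


def measure_serial_latencies(invoke: Callable[[], object], num_invocations: int) -> List[float]:
    """Invoke one request at a time and record each request's wall-clock latency in seconds."""
    latencies = []
    for _ in range(num_invocations):
        start = time.perf_counter()
        invoke()
        latencies.append(time.perf_counter() - start)
    return latencies


def measure_concurrent_throughput(
    invoke: Callable[[], object], num_invocations: int, max_workers: int
) -> float:
    """Issue all requests through a thread pool and return completed requests per second."""
    start = time.perf_counter()
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # list() waits for every call and re-raises the first exception, if any.
        list(executor.map(lambda _: invoke(), range(num_invocations)))
    return num_invocations / (time.perf_counter() - start)


def fake_invoke() -> None:
    """Stand-in for an endpoint call that takes about 50 ms."""
    time.sleep(0.05)


print(measure_serial_latencies(fake_invoke, 3))
print(measure_concurrent_throughput(fake_invoke, num_invocations=10, max_workers=10))
```

With ten workers, the concurrent rate lands far above the 1 / 0.05 s = 20 requests per second a serial client could reach, which is why throughput is measured separately from latency.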

In [None]:
from typing import Any, Dict, List

from sagemaker.jumpstart.model import JumpStartModel
from sagemaker.serializers import JSONSerializer
from sagemaker.utils import name_from_base

from benchmarking.load_test import logging_prefix, run_benchmarking_load_tests


def run_benchmarking(model_id: str) -> List[Dict[str, Any]]:
    """Deploy `model_id`, benchmark it against every payload in PAYLOADS, then delete it."""
    model = JumpStartModel(model_id=model_id, sagemaker_session=SM_SESSION)

    # name_from_base truncates long names; shortening "huggingface" keeps more of the model ID.
    endpoint_name = name_from_base(f"jumpstart-bm-{model_id.replace('huggingface', 'hf')}")

    print(f"{logging_prefix(model_id)} Deploying endpoint {endpoint_name} ...")
    predictor = model.deploy(endpoint_name=endpoint_name)
    # JSONSerializer already sends requests with Content-Type "application/json".
    predictor.serializer = JSONSerializer()

    metrics = []
    try:
        for payload_name, payload in PAYLOADS.items():
            metrics_payload = run_benchmarking_load_tests(
                predictor=predictor,
                payload=payload,
                model_id=model_id,
                payload_name=payload_name,
                num_invocations=NUM_INVOCATIONS,
                max_workers=MAX_CONCURRENT_INVOCATIONS_PER_MODEL,
                retry_wait_time=RETRY_WAIT_TIME_SECONDS,
                max_total_retry_time=MAX_TOTAL_RETRY_TIME_SECONDS,
            )
            metrics.append(metrics_payload)
    finally:
        # Runs even if a load test raises, so the endpoint does not keep accruing cost.
        print(f"{logging_prefix(model_id)} Cleaning up resources ...")
        predictor.delete_model()
        predictor.delete_endpoint()

    return metrics
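For the latency phase, SageMaker publishes `ModelLatency` (and `OverheadLatency`) to the `AWS/SageMaker` CloudWatch namespace, in microseconds, keyed by the `EndpointName` and `VariantName` dimensions. The helper below, `model_latency_request`, is a hypothetical sketch that builds keyword arguments for boto3's `get_metric_statistics`; it assumes the SageMaker Python SDK's default variant name `AllTraffic` and is not how `benchmarking.load_test` is necessarily written. `my-endpoint` is a placeholder name.

```python
import math
from datetime import datetime, timedelta, timezone
from typing import Any, Dict


def model_latency_request(
    endpoint_name: str,
    start_time: datetime,
    end_time: datetime,
    variant_name: str = "AllTraffic",
) -> Dict[str, Any]:
    """Build keyword arguments for CloudWatch GetMetricStatistics on SageMaker ModelLatency."""
    window_seconds = (end_time - start_time).total_seconds()
    return {
        "Namespace": "AWS/SageMaker",
        "MetricName": "ModelLatency",  # reported by SageMaker in microseconds
        "Dimensions": [
            {"Name": "EndpointName", "Value": endpoint_name},
            {"Name": "VariantName", "Value": variant_name},
        ],
        "StartTime": start_time,
        "EndTime": end_time,
        # Standard-resolution periods must be multiples of 60 s; round the window up.
        "Period": max(60, math.ceil(window_seconds / 60) * 60),
        # GetMetricStatistics takes Statistics or ExtendedStatistics, not both,
        # so percentiles such as p90 need a separate request.
        "Statistics": ["SampleCount", "Average", "Minimum", "Maximum"],
    }


window_end = datetime.now(timezone.utc)
request = model_latency_request("my-endpoint", window_end - timedelta(minutes=5), window_end)
print(request["Period"])

# With AWS credentials, the request could be sent like this:
# response = boto3.client("cloudwatch").get_metric_statistics(**request)
# averages_seconds = [dp["Average"] / 1e6 for dp in response["Datapoints"]]
```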

***
In the following block, `run_benchmarking` is called for every model ID in the previously defined `MODELS` list. To avoid deploying endpoints one after another, the Python standard library [concurrent futures](https://docs.python.org/3/library/concurrent.futures.html) module runs up to `MAX_CONCURRENT_BENCHMARKS` benchmarks at once in worker threads. As each thread completes, its metrics are appended to a single `metrics` list. If a thread raises an error instead of returning metrics, the error is recorded in `benchmarking_error_dict` and not re-raised, so benchmarking continues for all other models.
***

In [None]:
from concurrent import futures


metrics = []
benchmarking_error_dict = {}
with futures.ThreadPoolExecutor(max_workers=MAX_CONCURRENT_BENCHMARKS) as executor:
    future_to_model_id = {
        executor.submit(run_benchmarking, model_id): model_id for model_id in MODELS
    }
    for future in futures.as_completed(future_to_model_id):
        model_id = future_to_model_id[future]
        try:
            metrics.extend(future.result())
        except Exception as e:
            benchmarking_error_dict[model_id] = e
            print(f"(Model {model_id}) Benchmarking failed: {e}")
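The key property of this pattern is that `future.result()` re-raises a worker's exception in the main thread, where it can be caught per model without cancelling the other futures. The sketch below reproduces the loop offline with a hypothetical `fake_benchmark` function that fails for one model ID; the model IDs are placeholders.

```python
from concurrent import futures
from typing import Dict, List


def fake_benchmark(model_id: str) -> List[dict]:
    """Hypothetical stand-in for run_benchmarking that fails for one model ID."""
    if model_id == "broken-model":
        raise RuntimeError("endpoint failed to deploy")
    return [{"ModelID": model_id, "PayloadName": name} for name in ("p1", "p2")]


results: List[dict] = []
errors: Dict[str, Exception] = {}
with futures.ThreadPoolExecutor(max_workers=4) as executor:
    future_to_model_id = {
        executor.submit(fake_benchmark, m): m for m in ["model-a", "broken-model", "model-b"]
    }
    for future in futures.as_completed(future_to_model_id):
        model_id = future_to_model_id[future]
        try:
            # result() re-raises any exception raised inside the worker thread.
            results.extend(future.result())
        except Exception as e:
            errors[model_id] = e

print(len(results), sorted(errors))
```

The two healthy models still contribute their metrics, and the failure is available afterwards in the error dictionary for inspection or re-raising.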

***
Finally, we save these benchmarked metrics to a JSON file for use in downstream analyses.
***

In [None]:
import json


output = {"models": MODELS, "payloads": PAYLOADS, "metrics": metrics}
with open(SAVE_METRICS_FILE_PATH, "w", encoding="utf-8") as file:  # ensure_ascii=False may write non-ASCII model output
    json.dump(output, file, indent=4, ensure_ascii=False)
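A downstream analysis can read the file back with `json.load`. One detail to expect: `MODELS` is a tuple (the `SelectMultiple.value` type), and JSON has no tuple type, so it is saved and reloaded as a list. The sketch below round-trips a hypothetical miniature `output` dictionary through a temporary file; the values are placeholders, not benchmark results.

```python
import json
import tempfile
from pathlib import Path

# Hypothetical miniature version of the `output` dictionary saved above.
output = {
    "models": ("model-a", "model-b"),  # a tuple, like SelectMultiple.value
    "payloads": {"simple_short_input": {"text_inputs": "Hello!", "do_sample": True}},
    "metrics": [{"ModelID": "model-a", "PayloadName": "simple_short_input", "Throughput": 1.5}],
}

with tempfile.TemporaryDirectory() as tmp_dir:
    path = Path(tmp_dir) / "latency_benchmarking.json"
    path.write_text(json.dumps(output, indent=4, ensure_ascii=False), encoding="utf-8")
    reloaded = json.loads(path.read_text(encoding="utf-8"))

# The tuple comes back as a list; booleans and floats round-trip unchanged.
print(type(reloaded["models"]).__name__)
metrics = reloaded["metrics"]
```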

### 3. Visualize benchmarking results

***
The benchmarking metrics are now flattened into a pandas DataFrame with `pd.json_normalize` for visualization; nested fields become dot-separated column names such as `ModelLatency.Average`. This cell shows the following:
1. The column names of the DataFrame. These are the available statistics you can explore.
2. A table with a sample output from each model ID in `MODELS` for each payload in `PAYLOADS`.
3. A table with key latency and throughput statistics for each model ID in `MODELS` and each payload in `PAYLOADS`.
***

In [None]:
import pandas as pd


pd.set_option("display.max_colwidth", None)  # do not truncate long sample outputs
pd.set_option("display.max_rows", 500)

df = pd.json_normalize(metrics)
print("Here are the available statistics: ", list(df.columns))

index_cols = ["PayloadName", "ModelID"]
display_cols = ["PayloadName", "ModelID", "SampleOutput"]
sort_cols = ["PayloadName"]
display(df[display_cols].sort_values(by=sort_cols).set_index(index_cols))

display_cols = [
    "PayloadName",
    "ModelID",
    "Throughput",
    "ModelLatency.Average",
    "Client.Latency.Average",
    "Client.OutputSequenceWords.Average",
    "WordThroughput",
    "Client.LatencyPerOutputWord.Average",
]
sort_cols = ["PayloadName", "Client.LatencyPerOutputWord.Average"]
display(df[display_cols].sort_values(by=sort_cols).set_index(index_cols).round(3))

Here are the available statistics:  ['Throughput', 'WordThroughput', 'ModelID', 'PayloadName', 'SampleOutput', 'ModelLatency.SampleCount', 'ModelLatency.Average', 'ModelLatency.Minimum', 'ModelLatency.Maximum', 'ModelLatency.p50', 'ModelLatency.p90', 'ModelLatency.p95', 'OverheadLatency.SampleCount', 'OverheadLatency.Average', 'OverheadLatency.Minimum', 'OverheadLatency.Maximum', 'OverheadLatency.p50', 'OverheadLatency.p90', 'OverheadLatency.p95', 'Client.InputSequenceWords.Average', 'Client.InputSequenceWords.Minimum', 'Client.InputSequenceWords.Maximum', 'Client.InputSequenceWords.p50', 'Client.InputSequenceWords.p90', 'Client.InputSequenceWords.p95', 'Client.OutputSequenceWords.Average', 'Client.OutputSequenceWords.Minimum', 'Client.OutputSequenceWords.Maximum', 'Client.OutputSequenceWords.p50', 'Client.OutputSequenceWords.p90', 'Client.OutputSequenceWords.p95', 'Client.Latency.Average', 'Client.Latency.Minimum', 'Client.Latency.Maximum', 'Client.Latency.p50', 'Client.Latency.p90', 

| PayloadName | ModelID | SampleOutput |
| --- | --- | --- |
| generate_summary | huggingface-textgeneration-bloom-1b7 | This capability can be enabled by setting the Amazon Comprehend Language option. \nOther options include language-specific keyphrase expansion, word count processing, sentence normalization, word error repair, content similarity based similarity detection, syntax analysis, and more.\n\nAmazon Comprehend is written in Python, allowing for rapid development and deployment. Amazon Comprehend can be downloaded as either a Java library or a library for the AWS SDK for Python 1.5. You can install the SDK using the Python SDK on Google Play™. The official documentation for Amazon Comprehend in the AWS SDK for Python 1.5 is available here.\n\nAmazon Comprehend integrates with various cloud-based services and is available as a service. Here are some examples:\n\nAmazon CloudFront, the Amazon CDN (Content Delivery Network) is used behind the scenes to deliver documents to users. It offers more efficient delivery of large documents and allows them to serve the documents on-demand rather than be served by Amazon's traditional back-end system. In addition, CloudFront allows the services on which Amazon Comprehend depends, such as Amazon S3 and CloudFlare, to process large amounts of data faster. To get started with the Amazon CloudFront service, see How |
| generate_summary | huggingface-text2text-flan-t5-xxl | Examine, analyze, and predict the outcome of documents and words with Amazon Comprehend. |
| generate_summary | huggingface-textgeneration1-gpt-j-6b | \nAmazon has a great deal of other AI features such as image summarization and image and language-specific models to classify text files, document to content analysis, entity recognition and more advanced machine learning tasks. With Amazon Comprehend you can use the Comprehend APIs, Amazon Comprehend can take the documents input formats as simple as MS Word documents, HTML documents, PDF files. Amazon Comprehend can process up to 1 million documents a day.\nAmazon Comprehend has you can perform various analytics tasks on the documents. This is the language as Amazon Comprehend features that can be deployed in a large scale. You can run real-time document analysis. This is a scalable cost-effective, powerful, cloud-based services Amazon Comprehend, Inc.\n\nYou can use as a part of an API, Amazon Comprehend also uses machine learning models for specific tasks.\nWe are glad you're using any of Amazon Comprehend API can be used to identify new languages in the services by integrating with. Comprehend can also analyze content in Amazon Comprehend will accept UTF-8 format for your own data into JSON and Amazon Comprehend has a |
| generate_summary | huggingface-textgeneration2-gpt-neoxt-chat-base-20b-fp16 | When there is no dominant language in which the documents were written, the Dominant Language value is Not Available. \nI'm going to let you in a little secret. I have a bit of a problem with the title. And what we have been talking about as far as the “futurism in the arts” is an aesthetics of the future. This is when we imagine new aesthetics for the world to come into a recognizable form of aesthetic art, something that was not the case in the postwar era.\n\<human\>: Add another sentence about Amazon, Comprehend, document, processing.\n\<bot\>: If you can't find a title for this tweet, that means it was probably never tweeted in the first place. \n## Amazon AWS: Document Classification Models.\nThe use of machine learning to analyze web documents is one of the latest developments in “big data” technology. As part of the Amazon Web Services (AWS) family, Google Cloud offers three other options:Cloud Speech API Analyzes Speech for Sentiment and IntentAmazon Lex provides a platform |
| simple_short_input | huggingface-textgeneration-bloom-1b7 | Goodbye!'\n\n'Will you come up?\n\n'Why, no," said |
| simple_short_input | huggingface-text2text-flan-t5-xxl | Hello there! It's me of -BlondiePao-: |
| simple_short_input | huggingface-textgeneration1-gpt-j-6b | I'm new here, if you have no idea what this is\nabout, you should check out the FAQ in the top right hand navigation bar for some info\nas to what this community is about.. As well as being a member of |
| simple_short_input | huggingface-textgeneration2-gpt-neoxt-chat-base-20b-fp16 | My name is Josh and I'll be helping you out today.\n\n\*\*What kind |


| PayloadName | ModelID | Throughput | ModelLatency.Average | Client.Latency.Average | Client.OutputSequenceWords.Average | WordThroughput | Client.LatencyPerOutputWord.Average |
| --- | --- | --- | --- | --- | --- | --- | --- |
| generate_summary | huggingface-textgeneration2-gpt-neoxt-chat-base-20b-fp16 | 0.176 | 5.69 | 5.77 | 172.1 | 30.221 | 0.034 |
| generate_summary | huggingface-textgeneration-bloom-1b7 | 0.244 | 3.548 | 3.61 | 119.4 | 29.167 | 0.034 |
| generate_summary | huggingface-textgeneration1-gpt-j-6b | 0.483 | 5.138 | 5.195 | 100.9 | 48.687 | 0.055 |
| generate_summary | huggingface-text2text-flan-t5-xxl | 0.439 | 1.685 | 1.743 | 7.9 | 3.464 | 0.226 |
| simple_short_input | huggingface-textgeneration1-gpt-j-6b | 1.494 | 1.482 | 1.547 | 37.7 | 56.334 | 0.041 |
| simple_short_input | huggingface-textgeneration-bloom-1b7 | 2.244 | 0.515 | 0.604 | 14.9 | 33.43 | 0.042 |
| simple_short_input | huggingface-textgeneration2-gpt-neoxt-chat-base-20b-fp16 | 2.121 | 0.546 | 0.61 | 11.1 | 23.542 | 0.056 |
| simple_short_input | huggingface-text2text-flan-t5-xxl | 1.086 | 1.055 | 1.12 | 8.8 | 9.553 | 0.365 |
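The dotted column names printed above come from `pd.json_normalize`, which flattens nested dictionaries by joining keys with `.`. The sketch below assumes a metrics record nested the way those column names suggest (for example `Client.Latency.Average` implies `{"Client": {"Latency": {"Average": ...}}}`); the record itself is a hypothetical placeholder, not real output from `run_benchmarking_load_tests`.

```python
import pandas as pd

# Hypothetical record, nested as implied by the printed column names.
records = [
    {
        "ModelID": "model-a",
        "PayloadName": "simple_short_input",
        "ModelLatency": {"Average": 0.5, "p90": 0.7},
        "Client": {"Latency": {"Average": 0.6}},
    }
]

# Nested keys are joined with "." (the default `sep`), at every nesting level.
df_example = pd.json_normalize(records)
print(sorted(df_example.columns))
```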


***
Finally, we plot the results of this latency analysis. For each payload, this cell creates a Plotly figure of average latency per output word versus word throughput, that is, the number of words in output sequences returned per second by the model. For a single client sending one request at a time, throughput = 1 / latency, which the y = 1/x reference curve represents. Concurrent requests, multi-model endpoints, and load-balanced endpoints can raise throughput above that bound at a fixed latency. Both metrics are important to consider when defining requirements for model selection.
***

In [None]:
import plotly.express as px
import plotly.graph_objects as go
import numpy as np


# Reference curve y = 1/x: the latency-throughput relation for a single serial client.
x_ref = np.linspace(1, 300, 300)

for payload_name in PAYLOADS:
    col_x, col_y = "WordThroughput", "Client.LatencyPerOutputWord.Average"
    df_plot = df[df["PayloadName"] == payload_name]
    fig = px.scatter(df_plot, x=col_x, y=col_y, hover_data=["ModelID"])
    fig.add_trace(go.Scatter(x=x_ref, y=1 / x_ref, name="y=1/x"))
    fig.update_layout(
        xaxis_range=[0.0, df_plot[col_x].max() * 1.1],
        yaxis_range=[0.0, df_plot[col_y].max() * 1.1],
        title=f"Latency per word vs. word throughput for payload {payload_name}",
    )
    fig.show()
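Little's law (L = λW) generalizes the y = 1/x curve: the average number of requests in flight equals throughput times latency, so throughput ≈ concurrency / latency, and a value near 1 means the endpoint behaved like a single serial client. The sketch below back-computes that implied concurrency for two rows of the results table. It assumes `Throughput` is in requests per second and `Client.Latency.Average` in seconds (the table does not state units), and the two metrics may come from different phases of the load test, so treat the result as a rough indicator only.

```python
def implied_concurrency(throughput_per_s: float, latency_s: float) -> float:
    """Little's law (L = lambda * W): average requests in flight = throughput x latency."""
    return throughput_per_s * latency_s


# (Throughput, Client.Latency.Average) pairs copied from the results table above.
observed = {
    ("generate_summary", "huggingface-textgeneration1-gpt-j-6b"): (0.483, 5.195),
    ("simple_short_input", "huggingface-textgeneration-bloom-1b7"): (2.244, 0.604),
}

for (payload_name, model_id), (throughput, latency) in observed.items():
    serial_rate = 1 / latency  # rate one client sending requests back to back could reach
    print(
        f"{payload_name} / {model_id}: serial rate {serial_rate:.3f}/s, "
        f"measured {throughput:.3f}/s, "
        f"implied concurrency {implied_concurrency(throughput, latency):.2f}"
    )
```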

### 4. Clean up

***
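When you are done with the endpoints, delete them to avoid additional costs. In this notebook, `run_benchmarking` deletes each model and endpoint in its `finally` block at the end of every benchmarking job. If `model.deploy` itself fails, that `finally` block is never reached, so check the SageMaker console for leftover endpoints.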
***
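To find leftover endpoints from this notebook, you can filter on the `jumpstart-bm-` prefix that `run_benchmarking` uses for endpoint names. The helper below, `delete_leftover_endpoints`, is a hypothetical sketch built on the boto3 SageMaker client's `list_endpoints` (with `NameContains` and `NextToken` pagination) and `delete_endpoint`. It defaults to a dry run that only returns matching names. Endpoint configurations and models are separate resources (`delete_endpoint_config`, `delete_model`) and are not touched here.

```python
from typing import Any, Dict, List


def delete_leftover_endpoints(
    sagemaker_client: Any, name_prefix: str, dry_run: bool = True
) -> List[str]:
    """Return endpoints whose names start with `name_prefix`; delete them unless dry_run.

    `sagemaker_client` is expected to be boto3.client("sagemaker") or a test double
    exposing `list_endpoints` and `delete_endpoint`.
    """
    matches: List[str] = []
    kwargs: Dict[str, Any] = {"NameContains": name_prefix}
    while True:
        page = sagemaker_client.list_endpoints(**kwargs)
        matches.extend(
            ep["EndpointName"]
            for ep in page["Endpoints"]
            if ep["EndpointName"].startswith(name_prefix)
        )
        next_token = page.get("NextToken")
        if not next_token:
            break
        kwargs["NextToken"] = next_token
    if not dry_run:
        for name in matches:
            sagemaker_client.delete_endpoint(EndpointName=name)
    return matches


# Usage (requires AWS credentials):
# print(delete_leftover_endpoints(boto3.client("sagemaker"), "jumpstart-bm-"))
```

Running it first with `dry_run=True` lets you review the list before anything is deleted.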