# SageMaker JumpStart TensorFlow Image Classification Benchmarking

In [None]:
%pip install --upgrade boto3 sagemaker

In [None]:
import concurrent.futures as cf
import json
import itertools
import queue
import time
import traceback
from pathlib import Path
from pprint import pprint
from typing import Any, Dict, List, NamedTuple, Optional, Tuple

import boto3
import pandas as pd
import sagemaker
import sagemaker.hyperparameters
import sagemaker.model_uris
import sagemaker.script_uris
import sagemaker.image_uris
from botocore.config import Config

***
Let's first identify some top-level constants that will be utilized throughout this notebook. These constants are gathered at the top of this notebook so adjustments can be made in one place.

The first set of constants control SageMaker training job behaviors:
- __EC2_INSTANCE_TYPE__: EC2 instance type used for training.
- __SM_AMT_MAX_JOBS__: Maximum total number of training jobs to start per hyperparameter tuning job.
- __SM_AMT_MAX_PARALLEL_TRAINING_JOBS_PER_TUNER__: Maximum number of parallel training jobs per hyperparameter tuning job.
- __SM_AMT_MAX_PARALLEL_TUNING_JOBS__: Maximum number of parallel hyperparameter tuning jobs.
- __SM_AMT_OBJECTIVE_METRIC_NAME__: Name of the metric for evaluating training jobs.
- __SM_SESSION__: SageMaker Session object with custom configuration to resolve [SDK rate exceeded and throttling exceptions](https://aws.amazon.com/premiumsupport/knowledge-center/sagemaker-python-throttlingexception/).

The next set of constants control the behavior of the training script:
- __HYPERPARAMETERS__: Set of hyperparameters overriding any [default built-in value](https://docs.aws.amazon.com/sagemaker/latest/dg/IC-TF-Hyperparameter.html).

Finally, this notebook provides features to re-attach previously launched training jobs and load previously saved metrics for further analysis. The following constants control this behavior:
- __SAVE_TUNING_JOB_NAMES_FILE_PATH__: Path of the [JSON Lines](https://jsonlines.org/) file that keeps track of the tuning job name associated with a unique model name and dataset name.
- __SAVE_METRICS_FILE_PATH__: Path of the JSON Lines file that records metrics associated with each tuning job.
***

In [None]:
# --- SageMaker training / tuning job behavior ---
# EC2 instance type used for every training job.
EC2_INSTANCE_TYPE = "ml.g4dn.xlarge"
# Maximum total number of training jobs started per hyperparameter tuning job.
SM_AMT_MAX_JOBS = 2
# Maximum number of training jobs a single tuning job runs in parallel.
SM_AMT_MAX_PARALLEL_TRAINING_JOBS_PER_TUNER = 2
# Maximum number of tuning jobs this notebook keeps running at the same time.
SM_AMT_MAX_PARALLEL_TUNING_JOBS = 20
# Metric the tuner uses to compare training jobs.
SM_AMT_OBJECTIVE_METRIC_NAME = "val_accuracy"
# Shared SageMaker session. The custom botocore Config sets connection/read
# timeouts and a larger retry budget so that throttled API calls are retried
# instead of failing immediately.
SM_SESSION = sagemaker.Session(
    sagemaker_client=boto3.client(
        "sagemaker",
        config=Config(
            connect_timeout=5,
            read_timeout=60,
            retries={"max_attempts": 20}
        )
    )
)

# --- Training script behavior ---
# Values applied on top of each model's default hyperparameters.
HYPERPARAMETERS = {
    "epochs": 10,
    "early_stopping": "True",
    "early_stopping_patience": 3,
    "early_stopping_min_delta": 0.001,
    "augmentation": "False",  # For now, augmentation is CPU-bound and slow
}

# --- Persistence (JSON Lines files written to the current working directory) ---
# Maps (model name, dataset name) to the tuning job name so jobs can be re-attached.
SAVE_TUNING_JOB_NAMES_FILE_PATH = Path.cwd() / "benchmarking_tuning_job_names.jsonl"
# One JSON object of metrics per completed tuning job.
SAVE_METRICS_FILE_PATH = Path.cwd() / "benchmarking_metrics.jsonl"

## Identify models and datasets
***
In this section, we will define two lists, `models` and `datasets`, which contain unique identifiers for all models and all datasets we wish to perform this benchmarking task on. The hyperparameter tuning jobs attempted to be instantiated will be the [cartesian product](https://docs.python.org/3/library/itertools.html#itertools.product) between these two lists.

***
First, we will identify all built-in image classification model IDs to run this benchmarking task on. Because SageMaker JumpStart maintains a large number of models for this task, the default code in this notebook identifies only a few models by model ID. If desired, it is possible to run a thorough benchmarking analysis on all TensorFlow image classification models made available by SageMaker Built-In Algorithms via:
```python
from sagemaker.jumpstart.notebook_utils import list_jumpstart_models
from sagemaker.jumpstart.filters import And

filter_value = And("task == ic", "framework == tensorflow")
models = list_jumpstart_models(filter=filter_value)
```
This may be desired if you have a unique dataset and would like to perform large-scale benchmarking or model selection tasks on your custom dataset. However, please be cautious as a benchmarking task with this many models will require the deployment of a large number of resources.
***

In [None]:
from sagemaker.jumpstart.notebook_utils import list_jumpstart_models
from sagemaker.jumpstart.filters import And

# Model IDs that are left out of the full-catalog sweep.
# NOTE(review): presumably these are too large for the default instance type -- confirm.
EXCLUDED_MODELS = frozenset(
    {
        "tensorflow-ic-cait-m48-448",
        "tensorflow-ic-cait-m36-384",
        "tensorflow-ic-cait-s36-384",
        "tensorflow-ic-bit-s-r152x4-ilsvrc2012",
        "tensorflow-ic-bit-m-r152x4-ilsvrc2012",
        "tensorflow-ic-bit-m-r152x4-imagenet21k",
        "tensorflow-ic-bit-s-r101x3-ilsvrc2012-classification-1",
        "tensorflow-ic-bit-m-r101x3-ilsvrc2012-classification-1",
        "tensorflow-ic-bit-m-r101x3-imagenet21k-classification-1",
        "tensorflow-ic-efficientnet-v2-imagenet21k-xl",
        "tensorflow-ic-efficientnet-v2-imagenet21k-ft1k-xl",
    }
)

# Retrieves all TensorFlow Image Classification models made available by SageMaker Built-In Algorithms.
filter_value = And("task == ic", "framework == tensorflow")
# Filter with a membership test rather than list.remove(): remove() raises
# ValueError as soon as one excluded ID is absent from the returned catalog.
all_models = [
    model_id
    for model_id in list_jumpstart_models(filter=filter_value)
    if model_id not in EXCLUDED_MODELS
]

# Default selection: a small, explicit set of models so the notebook stays cheap to run.
# To benchmark the whole (filtered) catalog instead, replace this with `models = all_models`.
models = [
    "tensorflow-ic-imagenet-mobilenet-v1-075-224-classification-4",
    "tensorflow-ic-swin-small-patch4-window7-224",
]
print(len(models))
pprint(models)

***
We also need to identify the datasets to perform benchmarking on. Unlike built-in models, JumpStart does not provide a default API to query available datasets. It is also likely that you may have your own dataset hosted on S3 that you would like to benchmark. The following data structures provide a consistent framework to define dataset locations for the scope of this notebook. This is important because the benchmarking task is most beneficial with a training/validation/test dataset split. While possible, it is not recommended to let the model transfer learning script perform this split. Fitting a SageMaker Estimator requires channel definitions, which these objects create automatically via the `S3DatasetSplit.channels` method.

*Notes on dataset channel behaviors*: Training will utilize only the data provided in the "training" channel, model selection across hyperparameters and epochs will use data in the "validation" channel, and the final evaluation of model performance will be based on data provided in the "test" channel. If a "test" channel is not provided, then training should complete successfully, but metric definitions with a name matching the pattern "test_\*" will not be available in the training job logs. If a "validation" channel is not provided, then the default behavior of the JumpStart TensorFlow Image Classification algorithm is to perform a split of the "training" channel dataset into training and validation datasets.
***

In [None]:
class S3Dataset:
    """Location of a dataset in S3, identified by a bucket and a key prefix."""

    def __init__(self, bucket: str, prefix: str) -> None:
        self.bucket = bucket
        self.prefix = prefix

    def path(self) -> str:
        """Return the ``s3://<bucket>/<prefix>`` URI of this dataset."""
        return f"s3://{self.bucket}/{self.prefix}"


class S3DatasetSplit:
    """A training dataset with optional validation and test datasets.

    Attributes:
        train: Dataset used for the "training" channel (required).
        validation: Dataset used for the "validation" channel, or None.
        test: Dataset used for the "test" channel, or None.
    """

    def __init__(
        self,
        train: S3Dataset,
        validation: Optional[S3Dataset] = None,
        test: Optional[S3Dataset] = None,
    ) -> None:
        self.train = train
        self.validation = validation
        self.test = test

    @classmethod
    def from_prefixes(
        cls,
        bucket: str,
        prefix_train: str,
        prefix_validation: Optional[str] = None,
        prefix_test: Optional[str] = None,
    ) -> "S3DatasetSplit":
        """Build a split whose datasets all live in the same bucket.

        Args:
            bucket: Name of the S3 bucket shared by every dataset in the split.
            prefix_train: Key prefix of the training data.
            prefix_validation: Key prefix of the validation data; None omits it.
            prefix_test: Key prefix of the test data; None omits it.

        Returns:
            A new instance of ``cls`` holding one ``S3Dataset`` per given prefix.
        """
        return cls(
            S3Dataset(bucket, prefix_train),
            S3Dataset(bucket, prefix_validation) if prefix_validation is not None else None,
            S3Dataset(bucket, prefix_test) if prefix_test is not None else None,
        )

    def channels(self) -> Dict[str, str]:
        """Return a channel-name -> S3 URI mapping for the datasets that are set.

        "training" is always present; "validation" and "test" are included
        only when the corresponding dataset is not None.
        """
        res = {"training": self.train.path()}
        if self.validation is not None:
            res["validation"] = self.validation.path()
        if self.test is not None:
            res["test"] = self.test.path()
        return res

***
Next, a dictionary of available datasets is created and one of these datasets is selected to perform analysis. To get a feel for the performance of different models with respect to different datasets, simply run this notebook for a different selected list of datasets! If you have your own dataset, just create a new entry that specifies the bucket along with prefixes for the train, validation, and test datasets. The dataset should be structured according to the [built-in algorithm training data input format](https://docs.aws.amazon.com/sagemaker/latest/dg/image-classification-tensorflow.html).
***

In [None]:
# Registry of datasets available for benchmarking, keyed by a short dataset name.
# Both entries live in the region-specific "jumpstart-cache-prod-<region>" bucket.
# NOTE(review): in both entries prefix_test is identical to prefix_train, so the
# "test" channel points at the same S3 prefix as the "training" channel and any
# test_* metric is computed on data the model was trained on -- confirm this is
# intended for the demo, and use a held-out prefix for real benchmarking.
# No validation prefix is given, so no "validation" channel is created.
DATASET_DICT = {
    "tf-flowers": S3DatasetSplit.from_prefixes(
        bucket=f"jumpstart-cache-prod-{SM_SESSION.boto_region_name}",
        prefix_train="training-datasets/tf_flowers/",
        prefix_test="training-datasets/tf_flowers/"
    ),
    "ants-and-bees": S3DatasetSplit.from_prefixes(
        bucket=f"jumpstart-cache-prod-{SM_SESSION.boto_region_name}",
        prefix_train="training-datasets/ants-and-bees/",
        prefix_test="training-datasets/ants-and-bees/"
    ),
}

# Names (keys of DATASET_DICT) of the datasets to benchmark in this run.
datasets = ["tf-flowers", "ants-and-bees"]

## Create assets for training

***
This section contains a variety of helper functions that will be utilized for this SageMaker TensorFlow image classification benchmarking task, including functions to:
1) create a SageMaker `Estimator` object from the JumpStart model hub
2) create a SageMaker `HyperparameterTuner` for a specified model from the JumpStart model hub
3) re-attach a SageMaker `HyperparameterTuner` if a tuning job has already started
4) extract metrics from `Estimator` logs
5) save tuning job information to file to enable re-attaching jobs in new sessions
6) save resulting benchmarking metrics to file

***
The following block contains a helper function to obtain a SageMaker `Estimator` for a given JumpStart built-in `model_id`. This includes obtaining the appropriate URIs for the training docker image, the training script tarball, and the pre-trained model tarball to further fine-tune. This retrieval is provided by the SageMaker JumpStart built-in algorithms and allows for the creation of a SageMaker `Estimator` instance directly from these URIs.
***

In [None]:
def create_jumpstart_estimator(
    model_id: str,
    role: str,
    job_name: str,
    s3_output_location: str,
    model_version: str,
    instance_type: str = EC2_INSTANCE_TYPE,
    metric_definitions: Optional[List[Dict[str, str]]] = None,
    hyperparameters: Optional[Dict[str, str]] = None,
) -> sagemaker.estimator.Estimator:
    """Build a SageMaker Estimator that fine-tunes a JumpStart model.

    Args:
        model_id: JumpStart model ID to fine-tune.
        role: IAM role ARN passed to the estimator.
        job_name: Base name for training jobs created by the estimator.
        s3_output_location: S3 URI where training output is written.
        model_version: JumpStart model version to resolve artifacts for.
        instance_type: Training instance type.
        metric_definitions: Metric name/regex pairs forwarded to the estimator.
        hyperparameters: Initial hyperparameters forwarded to the estimator.

    Returns:
        An Estimator bound to the module-level ``SM_SESSION``.
    """
    # Every JumpStart artifact lookup is keyed by the same model ID / version pair.
    model_selector = {"model_id": model_id, "model_version": model_version}

    # Training container image.
    image_uri = sagemaker.image_uris.retrieve(
        region=None,
        framework=None,
        image_scope="training",
        instance_type=instance_type,
        **model_selector,
    )
    # Tarball containing the training script.
    script_uri = sagemaker.script_uris.retrieve(script_scope="training", **model_selector)
    # Pre-trained model tarball used as the starting point for fine-tuning.
    pretrained_model_uri = sagemaker.model_uris.retrieve(model_scope="training", **model_selector)

    estimator_kwargs = {
        "role": role,
        "image_uri": image_uri,
        "source_dir": script_uri,
        "model_uri": pretrained_model_uri,
        "entry_point": "transfer_learning.py",
        "instance_count": 1,
        "instance_type": instance_type,
        "max_run": 360000,
        "hyperparameters": hyperparameters,
        "output_path": s3_output_location,
        "base_job_name": job_name,
        "metric_definitions": metric_definitions,
        "sagemaker_session": SM_SESSION,
    }
    return sagemaker.estimator.Estimator(**estimator_kwargs)

***
While we now have a means to create a SageMaker `Estimator`, default hyperparameter values may not be optimal for the considered task. Therefore, to obtain the best benchmarking results, we would like to wrap this `Estimator` within a SageMaker hyperparameter tuning job. The following function wraps `create_jumpstart_estimator` to obtain a SageMaker `HyperparameterTuner` for a given JumpStart built-in `model_id` with properties for this benchmarking task. Because tuning jobs have a 32-character name length limit and this benchmarking task can create a large number of tuning jobs with similar (or identical) names after truncation, a unique ID is provided for each model to enforce unique tuning job names.
***

In [None]:
def create_benchmarking_tuner(
    model_id: str,
    unique_id: int,
    session: sagemaker.session.Session = SM_SESSION,
    model_version: str = "*",
    objective_type: str = "Maximize",
) -> sagemaker.tuner.HyperparameterTuner:
    """Create the HyperparameterTuner used to benchmark one JumpStart model.

    The tuner wraps an estimator from ``create_jumpstart_estimator`` and
    searches over two learning rates: the model's default and one fifth of it.

    Args:
        model_id: JumpStart model ID to benchmark.
        unique_id: Number embedded in the job name to keep names distinct.
        session: Session used to look up the execution role and default bucket.
        model_version: JumpStart model version ("*" is passed through unchanged).
        objective_type: Direction in which the objective metric is optimized.

    Returns:
        A configured tuner that has not been started yet.
    """
    # Metric names captured with train_/val_/test_ prefixes; the later entries
    # cover multiclass (top_5_accuracy) and binary (precision, recall, auc, prc) cases.
    logged_metrics = ("accuracy", "loss", "top_5_accuracy", "precision", "recall", "auc", "prc")

    metric_definitions = [
        {"Name": f"train_{metric}", "Regex": f"- {metric}: ([0-9\\.]+)"} for metric in logged_metrics
    ]
    metric_definitions += [
        {"Name": f"val_{metric}", "Regex": f"- val_{metric}: ([0-9\\.]+)"} for metric in logged_metrics
    ]
    metric_definitions += [
        {"Name": f"test_{metric}", "Regex": f"- Test {metric}: ([0-9\\.]+)"} for metric in logged_metrics
    ]
    # Model-size, timing and throughput figures scraped from the training logs.
    metric_definitions += [
        {"Name": "num_params", "Regex": "- Number of parameters: ([0-9\\.]+)"},
        {"Name": "num_trainable_params", "Regex": "- Number of trainable parameters: ([0-9\\.]+)"},
        {"Name": "num_non_trainable_params", "Regex": "- Number of non-trainable parameters: ([0-9\\.]+)"},
        {"Name": "train_duration", "Regex": "- Total training duration: ([0-9\\.]+)"},
        {"Name": "train_duration_per_epoch", "Regex": "- Average training duration per epoch: ([0-9\\.]+)"},
        {"Name": "test_evaluation_latency", "Regex": "- Test evaluation latency: ([0-9\\.]+)"},
        {"Name": "test_latency_per_sample", "Regex": "- Average test latency per sample: ([0-9\\.]+)"},
        {"Name": "test_throughput", "Regex": "- Average test throughput: ([0-9\\.]+)"},
    ]

    execution_role = session.get_caller_identity_arn()
    bucket = session.default_bucket()
    # Drop the common "tensorflow-ic-" prefix to leave room for the distinguishing part.
    short_model_name = model_id.replace('tensorflow-ic-', '')
    job_name = sagemaker.utils.name_from_base(f"bm-{unique_id}-{short_model_name}")
    s3_output_location = f"s3://{bucket}/jumpstart-example-tf-ic-benchmarking/output"

    default_hyperparameters = sagemaker.hyperparameters.retrieve_default(
        model_id=model_id, model_version=model_version
    )

    estimator = create_jumpstart_estimator(
        model_id,
        execution_role,
        job_name,
        s3_output_location,
        model_version=model_version,
        metric_definitions=metric_definitions,
        hyperparameters=default_hyperparameters,
    )
    # Notebook-level overrides take precedence over the model defaults.
    estimator.set_hyperparameters(**HYPERPARAMETERS)

    base_learning_rate = float(default_hyperparameters["learning_rate"])
    candidate_learning_rates = [base_learning_rate, base_learning_rate / 5]
    hyperparameter_ranges = {
        "learning_rate": sagemaker.tuner.CategoricalParameter(candidate_learning_rates)
    }

    return sagemaker.tuner.HyperparameterTuner(
        estimator,
        SM_AMT_OBJECTIVE_METRIC_NAME,
        hyperparameter_ranges,
        metric_definitions,
        max_jobs=SM_AMT_MAX_JOBS,
        max_parallel_jobs=SM_AMT_MAX_PARALLEL_TRAINING_JOBS_PER_TUNER,
        objective_type=objective_type,
        base_tuning_job_name=job_name,
    )

***
With these helper functions established, it is easy to create a `HyperparameterTuner` object for each specified model. But what happens if there is an error or the kernel for this script is terminated? The hyperparameter tuning jobs would still run to completion and we would not want to re-launch these jobs in order to obtain our results. Therefore, we need yet another helper function that will either re-attach the hyperparameter tuning job if it exists or create a new one via `create_benchmarking_tuner`. To accomplish this, the JSON Lines file specified in `SAVE_TUNING_JOB_NAMES_FILE_PATH` is read and checked for whether a tuning job already exists for `model_name` and `dataset_name`. If it does exist, then the job is re-attached and returned. Otherwise, a new tuner is created and the `fit()` method is invoked with the `wait=False` argument and channels specified per the previously defined `S3DatasetSplit` object we used to store our dataset S3 location. We will have the thread wait for this job to complete later, but we first need to put this job information on the `queue_save_tuning_job` queue, which will indicate to the primary thread to append this job information to `SAVE_TUNING_JOB_NAMES_FILE_PATH`. Writing to this file needs to be done by the primary thread listener because multiple threads simultaneously writing to a file is not thread safe.
***

In [None]:
class JobInformation(NamedTuple):
    """Identifies one benchmarking job: a model, a dataset, and its tuning job."""

    # Name of the model being benchmarked.
    model_name: str
    # Name of the dataset the model is benchmarked on.
    dataset_name: str
    # Name of the associated tuning job; defaults to None when not supplied.
    tuning_job_name: Optional[str] = None


def create_or_attach_tuner(
    model_name: str,
    dataset_name: str,
    unique_id: int,
    dataset_dict: Dict[str, S3DatasetSplit] = DATASET_DICT,
    tuning_job_names_file_path: Path = SAVE_TUNING_JOB_NAMES_FILE_PATH,
    session: sagemaker.Session = SM_SESSION
) -> Tuple[sagemaker.tuner.HyperparameterTuner, JobInformation]:
    """Re-attach the tuning job recorded for a model/dataset pair, or start one.

    The JSON Lines file at ``tuning_job_names_file_path`` is scanned for a
    record matching ``model_name`` and ``dataset_name``. If one is found, its
    tuning job is re-attached. Otherwise a new tuner is created, started
    without waiting, and its job information is put on the module-level
    ``queue_save_tuning_job`` queue.

    Args:
        model_name: JumpStart model ID to benchmark.
        dataset_name: Key into ``dataset_dict`` selecting the dataset.
        unique_id: Number used to keep new tuning job names distinct.
        dataset_dict: Mapping from dataset name to its S3 split.
        tuning_job_names_file_path: JSON Lines file of previously started jobs.
        session: SageMaker session used to attach or create the tuner.

    Returns:
        The tuner and the job information describing it.

    Raises:
        KeyError: If a new job is needed and ``dataset_name`` is not in ``dataset_dict``.
    """
    # Scan the file line by line instead of loading it into a DataFrame: an
    # existing-but-empty file or duplicate records no longer break the lookup.
    saved_tuning_job_name = None
    if tuning_job_names_file_path.exists():
        with open(tuning_job_names_file_path) as file:
            for line in file:
                if not line.strip():
                    continue
                record = json.loads(line)
                if record.get("model_name") == model_name and record.get("dataset_name") == dataset_name:
                    # Keep scanning so the most recently appended record wins.
                    saved_tuning_job_name = record.get("tuning_job_name")

    if saved_tuning_job_name is not None:
        tuner = sagemaker.tuner.HyperparameterTuner.attach(saved_tuning_job_name, session)
        job_information = JobInformation(model_name, dataset_name, saved_tuning_job_name)
        print(f"> Re-attached previous SageMaker tuning job, {job_information}")
        return tuner, job_information

    # Forward the caller's session so it is honored for new jobs as well.
    tuner = create_benchmarking_tuner(model_name, unique_id, session=session)
    dataset = dataset_dict[dataset_name]
    tuner.fit(dataset.channels(), wait=False)
    tuning_job_name = tuner.latest_tuning_job.name
    job_information = JobInformation(model_name, dataset_name, tuning_job_name)
    # File writes are left to whoever consumes this queue.
    queue_save_tuning_job.put(job_information)
    print(f"> Starting new SageMaker tuning job, {job_information}")
    return tuner, job_information

***
Once a tuning job is complete, we need to obtain a description of the best training job. While the objective metric for the hyperparameter tuner is easily obtained via the `HyperparameterTuner.analytics()` method, which returns a `HyperparameterTuningJobAnalytics` object, the additional auxiliary metrics provided to the original estimator are not extracted with this object. The following function will probe the training job description in order to extract all metrics of interest to this benchmarking scenario from the key `FinalMetricDataList`.
***

In [None]:
def extract_metrics_from_logs(
    tuner: sagemaker.tuner.HyperparameterTuner,
    job_information: JobInformation,
    session: sagemaker.Session = SM_SESSION
) -> Dict[str, Any]:
    """Collect the final metrics of a tuner's best training job.

    Args:
        tuner: Tuner whose best training job is inspected.
        job_information: Job identifiers merged into the returned record.
        session: Session used to describe the training job.

    Returns:
        A dict mapping each metric name in the job's ``FinalMetricDataList`` to
        its value, plus the fields of ``job_information`` (which win on key clashes).
    """
    best_job_name = tuner.best_training_job()
    job_description = session.describe_training_job(best_job_name)

    record = {}
    for entry in job_description['FinalMetricDataList']:
        record[entry['MetricName']] = entry['Value']
    record.update(job_information._asdict())
    return record

***
Next, we need to define a function that runs a single tuning job. For a given `model_id`, this function will do three things: 1) obtain a `HyperparameterTuner` object for this model, 2) launch the hyperparameter tuning job and wait for the job to complete, and 3) extract the relevant metrics from CloudWatch logs for this training job. Additionally, this function requests access to the `queue_currently_running` queue, which has a maximum capacity and will block without timeout until there is an available spot on the queue. This allows us to cap the number of simultaneously running hyperparameter tuning jobs.
***

In [None]:
def run_tuner(model_name: str, dataset_name: str, unique_id: int) -> Dict[str, Any]:
    """Run one tuning job to completion and return the best job's metrics.

    Args:
        model_name: JumpStart model ID to benchmark.
        dataset_name: Name of the dataset to benchmark on.
        unique_id: Number used to keep tuning job names distinct.

    Returns:
        The metrics record produced by ``extract_metrics_from_logs``.
    """
    # Take a slot on the bounded queue first; this blocks while the queue is full.
    queue_currently_running.put(None)

    active_tuner, job_info = create_or_attach_tuner(model_name, dataset_name, unique_id)
    active_tuner.wait()

    best_job_metrics = extract_metrics_from_logs(active_tuner, job_info)
    print(f"> Completed SageMaker tuning job, {job_info}")
    return best_job_metrics

***
Finally, we need a couple of helper functions to log information to file. The first is intended to be triggered whenever the `create_or_attach_tuner` function puts job information onto the `queue_save_tuning_job` queue. Because we are using multithreading in this example and it is not thread safe to have multiple threads write to file simultaneously, we will have the primary script listening to the futures threads pass job information to be saved to this function. The second helper function here is intended to be called whenever a tuning job completes. It extracts the metrics as the return value of the future and writes a json line to file. It also prints out any exceptions generated by the future without raising an error to allow the remainder of jobs to complete. This prevents a single job failure from preventing any future analyses.
***

In [None]:
def append_tuning_job_to_file(
    job_information: JobInformation,
    file_path: Path = SAVE_TUNING_JOB_NAMES_FILE_PATH
) -> None:
    """Append one job record as a JSON line to ``file_path``.

    Args:
        job_information: Job identifiers to serialize.
        file_path: JSON Lines file to append to; created if it does not exist.
    """
    with open(file_path, "a+") as jsonl_file:
        serialized = json.dumps(job_information._asdict())
        jsonl_file.write(serialized + "\n")
    print(f"> Saved job information to file, {job_information}")


def append_metrics_to_file(
    future: cf.Future,
    job_information: JobInformation,
    file_path: Path = SAVE_METRICS_FILE_PATH
) -> None:
    """Append a finished future's metrics as a JSON line to ``file_path``.

    Any exception (from the future itself or from writing) is printed together
    with its traceback and then suppressed, so one failed job does not stop
    the remaining ones from being processed.

    Args:
        future: Completed future whose result is the metrics dict.
        job_information: Job identifiers used in the progress/error messages.
        file_path: JSON Lines file to append to.
    """
    try:
        job_metrics = future.result()
        with open(file_path, "a+") as jsonl_file:
            json_line = json.dumps(job_metrics) + "\n"
            jsonl_file.write(json_line)
        print(f"> Saved metrics to file, {job_information}")
    except Exception as exc:
        # Report and continue rather than re-raising.
        print(f"> Exception generated for {job_information}: {exc}")
        traceback.print_exc()

## Train models
***
Everything is now in place to launch training jobs and aggregate performance metrics for the benchmarking evaluation. This notebook makes use of the Python standard library's [concurrent futures](https://docs.python.org/3/library/concurrent.futures.html) module, which is a high-level interface for asynchronously executing callables. The `run_tuner` function will be repeatedly executed on a thread pool and the `queue_currently_running` queue will block any threads from launching additional training instances until the number of currently running tuning jobs is less than `SM_AMT_MAX_PARALLEL_TUNING_JOBS`. Note that this queue would not be necessary if a `ProcessPoolExecutor` was used in place of `ThreadPoolExecutor`, but a process pool cannot share global state and therefore calling the functions `append_tuning_job_to_file` and `append_metrics_to_file` would not be thread safe.

Once all jobs are submitted to the executor, this script listens to the futures job pool. Until all jobs are completed, it will perform two tasks: 1) call `append_tuning_job_to_file` with any job information that gets populated into `queue_save_tuning_job`, and 2) call `append_metrics_to_file` for any future that has finished execution.

__FINAL NOTE__: Depending on the number of models and datasets defined above, this block may take a long time to run and consume a large number of resources. Please double check your settings!
***

In [None]:
# Start from a clean metrics file: metrics are re-collected for every job in
# this run, including jobs that are re-attached rather than newly started.
if SAVE_METRICS_FILE_PATH.exists():
    SAVE_METRICS_FILE_PATH.unlink()

# Worker threads put JobInformation of newly started jobs here; the main
# thread drains it and performs the file writes.
queue_save_tuning_job = queue.Queue()
# Bounded queue used as a slot counter: run_tuner puts one item when it starts
# and the main thread removes one item for each finished future.
queue_currently_running = queue.Queue(maxsize=SM_AMT_MAX_PARALLEL_TUNING_JOBS)

# Every (model, dataset) combination becomes one tuning job.
jobs = itertools.product(models, datasets)

with cf.ThreadPoolExecutor(max_workers=SM_AMT_MAX_PARALLEL_TUNING_JOBS) as executor:
    # Map each future to the job it represents; the tuning job name is not
    # known yet at submission time, so it is left at its default of None.
    futures_to_job_information = {
        executor.submit(run_tuner, model_name, dataset_name, unique_id): JobInformation(model_name, dataset_name)
        for unique_id, (model_name, dataset_name) in enumerate(jobs)
    }
    
    # Poll until every future has been handled.
    while futures_to_job_information:
        # Wake up when a future completes or after 5 seconds, whichever is first,
        # so the save queue is drained regularly even while jobs are running.
        done, not_done = cf.wait(futures_to_job_information, timeout=5.0, return_when=cf.FIRST_COMPLETED)

        # Persist the names of any tuning jobs started since the last pass.
        while not queue_save_tuning_job.empty():
            job_information = queue_save_tuning_job.get()
            append_tuning_job_to_file(job_information)

        for future in done:
            # Free the slot taken by run_tuner, then record the job's metrics
            # (or report its exception) and stop tracking the future.
            queue_currently_running.get()
            job_information_before_execution = futures_to_job_information.pop(future)
            append_metrics_to_file(future, job_information_before_execution)

## Analyze results
***
At this point, all tuning jobs should have completed execution. Congratulations! Please check the file `SAVE_METRICS_FILE_PATH` to see that each job should have appended a JSON object to a new row in the file. Here, we read the contents of this file into a pandas `DataFrame` to view results in tabular form.
***

In [None]:
def model_name_clean(model_name: str):
    """Shorten a JumpStart model ID for display.

    Removes every "tensorflow-ic-" and "imagenet-" substring and drops
    everything from the first "-classification" onwards.
    """
    stripped = model_name.replace("tensorflow-ic-", "").replace("imagenet-", "")
    # partition() yields the text before the first separator, or the whole
    # string when the separator is absent.
    short_name, _, _ = stripped.partition("-classification")
    return short_name

# Load one row per completed tuning job from the JSON Lines metrics file.
metrics_df = pd.read_json(SAVE_METRICS_FILE_PATH, lines=True)
# Replace the full model IDs with their shortened display names.
metrics_df["model_name"] = metrics_df["model_name"].apply(model_name_clean)
# Category = first dash-separated token of the short name, ignoring a "tf2-preview-" prefix.
metrics_df["model_category"] = metrics_df["model_name"].apply(lambda x: x.replace("tf2-preview-", "").split("-")[0])

# Show the metrics grouped by dataset, then model category, then model name.
display(metrics_df.sort_values(by=["dataset_name", "model_category", "model_name"]).set_index(["dataset_name", "model_category", "model_name"]))

***
With a pandas DataFrame of all performance metrics populated, you can perform whatever analysis is of interest. Here, we show a quick example of how to create a figure illustrating the Pareto front tradeoff between test accuracy and throughput. If using Jupyter Lab, be sure to enable the plotly Jupyter extension for best viewing results.
***

In [None]:
import pandas as pd
import numpy as np
import plotly.express as px
from plotly.graph_objs import Figure


def benchmarking_figure(
    df: pd.DataFrame,
    dataset_name: str,
    x: str = "test_throughput",
    y: str = "test_accuracy",
    title: str = "SageMaker JumpStart TensorFlow Image Classification Benchmarking",
    model_name: str = "model_name",
    xaxis_title: str = "throughput (images per second)",
    yaxis_title: str = "test accuracy",
    size: str = "num_params",
    color: str = "model_category",
    width=800,
    height=600
) -> Figure:
    """Create a scatter plot of benchmarking results for one dataset.

    Each row of ``df`` whose "dataset_name" equals ``dataset_name`` becomes one
    marker. The x axis is logarithmic and marker size is the square root of the
    ``size`` column. The input DataFrame is left unmodified.

    Args:
        df: Metrics table; must contain "dataset_name" plus the columns named
            by ``x``, ``y``, ``model_name``, ``size`` and ``color``.
        dataset_name: Dataset whose rows are plotted.
        x: Column plotted on the x axis.
        y: Column plotted on the y axis.
        title: Figure title; the dataset name is appended in parentheses.
        model_name: Column used for sorting and as the hover label.
        xaxis_title: Label of the x axis.
        yaxis_title: Label of the y axis.
        size: Column whose square root sets the marker size.
        color: Column that determines marker color.
        width: Figure width passed to plotly.
        height: Figure height passed to plotly.

    Returns:
        The plotly figure.
    """
    size_column = f"sqrt_{size}"

    # sort_values and boolean indexing both return new objects, and assign()
    # adds the derived column to a copy, so the caller's DataFrame never gains
    # a stray "sqrt_*" column.
    plot_df = df.sort_values(by=[model_name])
    plot_df = plot_df[plot_df["dataset_name"] == dataset_name]
    plot_df = plot_df.assign(**{size_column: np.sqrt(plot_df[size])})

    fig = px.scatter(
        plot_df,
        x=x,
        y=y,
        color=color,
        size=size_column,
        title=f"{title} ({dataset_name})",
        hover_name=model_name,
        log_x=True,
        width=width,
        height=height,
    )
    fig.update_layout(
        xaxis_title=xaxis_title,
        yaxis_title=yaxis_title,
    )
    return fig

In [None]:
# Build one figure per benchmarked dataset, save it as a standalone HTML file
# in the current working directory, and render it inline.
for dataset_name in datasets:
    fig = benchmarking_figure(metrics_df, dataset_name)
    fig.write_html(f"jumpstart_tf_ic_benchmarking_pareto_{dataset_name}.html")
    fig.show()