# A/B Testing with Amazon SageMaker

In production ML workflows, data scientists and data engineers frequently try to improve their models in various ways, such as by performing [automatic model tuning](https://docs.aws.amazon.com/sagemaker/latest/dg/automatic-model-tuning.html), training on additional or more recent data, and improving feature selection. Performing A/B testing between a new model and an old model with production traffic can be an effective final step in the validation process for a new model. In A/B testing, you test different variants of your models and compare how each variant performs relative to the others. If the new version delivers better performance than the previously existing version, you then replace the old model with the better-performing one.

Amazon SageMaker enables you to test multiple models or model versions behind the same endpoint using production variants. Each production variant identifies a machine learning (ML) model and the resources deployed for hosting the model. You can distribute endpoint invocation requests across multiple production variants by providing the traffic distribution for each variant, or you can invoke a specific variant directly for each request.
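To build intuition for how weights translate into traffic, here is a minimal simulation that routes each request independently, picking a variant with probability proportional to its weight via Python's `random.choices`. This is an illustrative model under that assumption, not SageMaker's internal routing implementation, and the function name `simulate_routing` is hypothetical.

```python
import random
from collections import Counter


def simulate_routing(weights, n_requests, seed=0):
    """Route n_requests independently; each picks a variant with probability weight / sum(weights).

    `weights` maps variant name -> weight. Toy model only, not SageMaker's router.
    """
    rng = random.Random(seed)  # seeded so repeated runs give the same counts
    names = list(weights)
    picks = rng.choices(names, weights=[weights[n] for n in names], k=n_requests)
    return Counter(picks)


counts = simulate_routing({"Variant1": 1, "Variant2": 1}, n_requests=10_000)
print(counts)
```

With equal weights the counts land close to, but usually not exactly at, an even split; a variant with weight 0 is never picked.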

In this notebook we'll:
* Evaluate models by invoking specific variants
* Gradually release a new model by specifying traffic distribution

Reference notebook example: [A/B Testing with Amazon SageMaker](https://github.com/aws/amazon-sagemaker-examples/blob/main/sagemaker_endpoints/a_b_testing/a_b_testing.ipynb)

### Configuration
Let's set up some required imports and basic initial variables:

In [None]:
%matplotlib inline
import datetime
import time
import os
import boto3
import re
import json
import pandas as pd
import numpy as np
import sagemaker
from sagemaker import get_execution_role, session
from sagemaker.s3 import S3Downloader, S3Uploader

sm_session = sagemaker.Session()
role = get_execution_role()
bucket = sm_session.default_bucket()
region = boto3.Session().region_name
sm_client = boto3.client("sagemaker", region)
sm_runtime = boto3.Session().client("sagemaker-runtime")
prefix = "sagemaker/huggingface-pytorch-sentiment-analysis"
time_now = f'{datetime.datetime.now():%Y-%m-%d-%H-%M-%S}'
time_now

In [None]:
%store
%store -r

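## Step 1: Deploy the models created in the previous multi-model endpoint notebook

The `%store -r` magic above restores the variables saved by the previous notebook, including the model names `roberta_mme_model_name` and `distilbert_model_name`. Each model becomes one production variant. The helper below builds the `ProductionVariant` description expected by the `CreateEndpointConfig` API:
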
In [None]:
def production_variant(
    model_name,
    instance_type=None,
    initial_instance_count=None,
    variant_name="AllTraffic",
    initial_weight=1,
    accelerator_type=None,
    serverless_inference_config=None,
):
    """Create a production variant description suitable for use in a ``ProductionVariant`` list.
    This is also part of a ``CreateEndpointConfig`` request.
    Args:
        model_name (str): The name of the SageMaker model this production variant references.
        instance_type (str): The EC2 instance type for this production variant. For example,
            'ml.c4.8xlarge'.
        initial_instance_count (int): The initial instance count for this production variant
            (default: 1).
        variant_name (string): The ``VariantName`` of this production variant
            (default: 'AllTraffic').
        initial_weight (int): The relative ``InitialVariantWeight`` of this production variant
            (default: 1).
        accelerator_type (str): Type of Elastic Inference accelerator for this production variant.
            For example, 'ml.eia1.medium'.
            For more information: https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html
        serverless_inference_config (dict): Configuration dict for a serverless endpoint,
            converted from a sagemaker.serverless.ServerlessInferenceConfig object. When set,
            instance_type and initial_instance_count are ignored (default: None).
    Returns:
        dict[str, str]: A SageMaker ``ProductionVariant`` description.
    """
    production_variant_configuration = {
        "ModelName": model_name,
        "VariantName": variant_name,
        "InitialVariantWeight": initial_weight,
    }

    if accelerator_type:
        production_variant_configuration["AcceleratorType"] = accelerator_type

    if serverless_inference_config:
        production_variant_configuration["ServerlessConfig"] = serverless_inference_config
    else:
        initial_instance_count = initial_instance_count or 1
        production_variant_configuration["InitialInstanceCount"] = initial_instance_count
        production_variant_configuration["InstanceType"] = instance_type

    return production_variant_configuration

In [None]:
variant1 = production_variant(
    model_name=roberta_mme_model_name,
    instance_type="ml.c5.2xlarge",
    initial_instance_count=1,
    variant_name="Variant1",
    initial_weight=1,
)
variant2 = production_variant(
    model_name=distilbert_model_name,
    instance_type="ml.c5.xlarge",
    initial_instance_count=1,
    variant_name="Variant2",
    initial_weight=1,
)

(variant1, variant2)
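Later cells address each variant by its `VariantName` (in `TargetVariant`, in weight updates, and as a CloudWatch dimension), so every name must be present and distinct. The hypothetical helper below is a small client-side sanity check you could run on `[variant1, variant2]` before creating the endpoint configuration; it is not part of the SageMaker SDK.

```python
def check_variant_names(production_variants):
    """Return the variant names, raising ValueError if any name is missing or repeated.

    Hypothetical client-side check; SageMaker performs its own validation
    when the endpoint configuration is created.
    """
    names = [v.get("VariantName") for v in production_variants]
    if any(not n for n in names):
        raise ValueError("Every production variant needs a VariantName")
    # Names that occur more than once, sorted for a stable error message
    duplicates = sorted({n for n in names if names.count(n) > 1})
    if duplicates:
        raise ValueError(f"Duplicate VariantName values: {duplicates}")
    return names
```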

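### Deploy
Next, deploy both variants behind a single SageMaker endpoint. The helpers below first create an endpoint configuration that lists the production variants and then create the endpoint from that configuration: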

In [None]:
def create_endpoint(endpoint_name, config_name, tags=None):
    """Create an Amazon SageMaker ``Endpoint`` from an existing endpoint configuration.

    Once the ``Endpoint`` is in service, client applications can send requests to obtain
    inferences. The endpoint configuration must already exist (it is created with the
    ``CreateEndpointConfig`` API). This function returns immediately and does not wait
    for the deployment to finish.

    Args:
        endpoint_name (str): Name of the Amazon SageMaker ``Endpoint`` being created.
        config_name (str): Name of the Amazon SageMaker endpoint configuration to deploy.
        tags (list[dict[str, str]]): Key-value pairs for tagging the endpoint (default: None).

    Returns:
        str: Name of the Amazon SageMaker ``Endpoint`` created.
    """
    print("Creating endpoint with name {}".format(endpoint_name))

    tags = tags or []

    sm_client.create_endpoint(
        EndpointName=endpoint_name, EndpointConfigName=config_name, Tags=tags
    )
    return endpoint_name

def endpoint_from_production_variants(
    name,
    production_variants,
    tags=None,
    kms_key=None,
    data_capture_config_dict=None,
    async_inference_config_dict=None,
):
    """Create a SageMaker endpoint configuration and ``Endpoint`` from a list of production variants.

    The endpoint configuration and the endpoint share the same ``name``. This function
    returns as soon as the creation requests are accepted; it does not wait for the
    endpoint to be in service.

    Args:
        name (str): The name of the endpoint configuration and ``Endpoint`` to create.
        production_variants (list[dict[str, str]]): The list of production variants to deploy.
        tags (list[dict[str, str]]): A list of key-value pairs for tagging the endpoint
            (default: None).
        kms_key (str): The KMS key that is used to encrypt the data on the storage volume
            attached to the instance hosting the endpoint (default: None).
        data_capture_config_dict (dict): Specifies configuration related to Endpoint data
            capture for use with Amazon SageMaker Model Monitoring (default: None).
        async_inference_config_dict (dict): Specifies configuration related to an asynchronous
            endpoint. Use this configuration to create an async endpoint for async inference
            (default: None).
    Returns:
        str: The name of the created ``Endpoint``.
    """
    config_options = {"EndpointConfigName": name, "ProductionVariants": production_variants}
    if tags:
        config_options["Tags"] = tags
    if kms_key:
        config_options["KmsKeyId"] = kms_key
    if data_capture_config_dict is not None:
        config_options["DataCaptureConfig"] = data_capture_config_dict
    if async_inference_config_dict is not None:
        config_options["AsyncInferenceConfig"] = async_inference_config_dict

    print("Creating endpoint-config with name {}".format(name))
    sm_client.create_endpoint_config(**config_options)

    return create_endpoint(endpoint_name=name, config_name=name, tags=tags)

In [None]:
endpoint_name = f"demo-hf-pytorch-variant-{time_now}"
print(f"EndpointName={endpoint_name}")

endpoint_from_production_variants(
    name=endpoint_name, production_variants=[variant1, variant2]
)

In [None]:
describe_endpoint_response = sm_client.describe_endpoint(EndpointName=endpoint_name)

while describe_endpoint_response["EndpointStatus"] == "Creating":
    describe_endpoint_response = sm_client.describe_endpoint(EndpointName=endpoint_name)
    print(describe_endpoint_response["EndpointStatus"])
    time.sleep(20)

describe_endpoint_response
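The polling loop above has no upper bound, and the notebook repeats the same pattern after each weight update. Below is a minimal sketch of a reusable poller with a timeout; `wait_for_status` and its parameters are hypothetical names, and the status callable (and the sleep function) are injected so the helper does not depend on boto3.

```python
import time


def wait_for_status(get_status, done_states=("InService", "Failed"),
                    poll_seconds=20, timeout_seconds=3600, sleep=time.sleep):
    """Poll get_status() until it returns one of done_states or the timeout expires.

    get_status is any zero-argument callable that returns a status string.
    Returns the final status; raises TimeoutError if the deadline passes first.
    """
    deadline = time.monotonic() + timeout_seconds
    while True:
        status = get_status()
        if status in done_states:
            return status
        if time.monotonic() >= deadline:
            raise TimeoutError(f"Still {status!r} after {timeout_seconds} seconds")
        sleep(poll_seconds)
```

For example, `wait_for_status(lambda: sm_client.describe_endpoint(EndpointName=endpoint_name)["EndpointStatus"])`. The caller should still check whether the returned status is `Failed`. boto3 also provides a built-in waiter for this case: `sm_client.get_waiter("endpoint_in_service").wait(EndpointName=endpoint_name)`.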

## Step 2: Invoke the deployed models

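You can now send data to this endpoint to get inferences in real time. Because no target variant is specified, SageMaker distributes the requests across the two variants in proportion to their weights; with both weights set to 1, each variant receives roughly half of the traffic.

In [None]:
# One request that carries all test sentences in the "inputs" list
test_data = pd.read_csv("../sample_payload/test_data.csv", header=None)
json_data = {"inputs": test_data.iloc[:, 0].to_list()}
batch_data = pd.read_csv("../sample_payload/batch_data.csv", header=None)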

In [None]:
%%time
predictions = []

for i in range(5):
    response = sm_runtime.invoke_endpoint(
        EndpointName=endpoint_name,
        Body=json.dumps(json_data),
        ContentType="application/json",
    )
    predictions.append(response["Body"].read().decode("utf-8"))
    time.sleep(0.5)

print(*predictions, sep='\n')
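The `invoke_endpoint` response also reports which production variant served each request, in its `InvokedProductionVariant` field. The sketch below tallies those names so you can see how default routing split the requests. `count_invoked_variants` is a hypothetical helper, and it assumes you keep the raw response dicts, for example by appending each `response` to a list inside the loop above.

```python
from collections import Counter


def count_invoked_variants(responses):
    """Count how many requests each production variant served.

    Each item is assumed to be an invoke_endpoint response dict, which names
    the serving variant under "InvokedProductionVariant".
    """
    return Counter(r["InvokedProductionVariant"] for r in responses)
```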

### Invoke a specific variant

Now let's invoke a specific variant. To do this, pass the `TargetVariant` parameter to `invoke_endpoint` with the name of the production variant you want. The next cell sends the request to Variant1, and the cell after it sends the same request to Variant2.
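Because the only difference between a weighted request and a targeted request is the optional `TargetVariant` argument, one way to keep the calls consistent is to build the keyword arguments in one place. A minimal sketch, using the hypothetical helper name `build_invoke_request` and the JSON content type used throughout this notebook:

```python
import json


def build_invoke_request(endpoint_name, payload, target_variant=None):
    """Build keyword arguments for sm_runtime.invoke_endpoint.

    With target_variant=None, SageMaker routes the request by variant weight;
    otherwise the request is sent to the named production variant.
    """
    request = {
        "EndpointName": endpoint_name,
        "Body": json.dumps(payload),
        "ContentType": "application/json",
    }
    if target_variant is not None:
        request["TargetVariant"] = target_variant
    return request
```

For example, `sm_runtime.invoke_endpoint(**build_invoke_request(endpoint_name, json_data, variant1["VariantName"]))` sends the same request as the next cell.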

In [None]:
%%time
response = sm_runtime.invoke_endpoint(
    EndpointName=endpoint_name,
    Body=json.dumps(json_data),
    ContentType="application/json",
    TargetVariant=variant1["VariantName"],
)

print(response["Body"].read().decode("utf-8"))

In [None]:
%%time
response = sm_runtime.invoke_endpoint(
    EndpointName=endpoint_name,
    Body=json.dumps(json_data),
    ContentType="application/json",
    TargetVariant=variant2["VariantName"],
)

print(response["Body"].read().decode("utf-8"))

## Step 3: Evaluate variant performance

### Evaluating Variant1

Using `TargetVariant`, let us send every example in the labeled test set to Variant1 and measure its accuracy.

The test data comes from the [Kaggle financial sentiment analysis dataset](https://www.kaggle.com/datasets/sbhatti/financial-sentiment-analysis).

In [None]:
import io
import csv
import json
import matplotlib.pyplot as plt
from sklearn import metrics
from sklearn.metrics import roc_auc_score

df_data = pd.read_csv("../sample_payload/batch_data.csv")
source_data = df_data.to_json(orient='records')
json_lst = json.loads(source_data)
json_lst[0]

In [None]:
def invoke_with_single_sentence(list_data, endpoint_name, variant_name):
    print(f"Sending test traffic to the endpoint {endpoint_name}. \nPlease wait...")
    predictions = []
    for payload in list_data:
        print(".", end="", flush=True)
        response = sm_runtime.invoke_endpoint(
            EndpointName=endpoint_name,
            ContentType="application/json",
            Body=json.dumps(payload),
            TargetVariant=variant_name,
        )
        predictions.append(response["Body"].read().decode("utf-8"))
        time.sleep(0.5)
    print('\nDone!')
    return predictions



In [None]:
predictions1 = invoke_with_single_sentence(json_lst, endpoint_name, variant1["VariantName"])

In [None]:
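# Keep only the highest-scoring label for each prediction.
# DataFrame.append was removed in pandas 2.0, so rows are collected in a list first.
rows = []
for prediction in predictions1:
    tmp_df = pd.DataFrame(json.loads(prediction)[0])
    rows.append(tmp_df.loc[tmp_df['score'].idxmax()])
df = pd.DataFrame(rows, columns=['label', 'score']).reset_index(drop=True)
df.head()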

In [None]:
value_map = {'LABEL_0': 0, 'LABEL_1': 1, 'LABEL_2': 2}
df = df.replace({'label': value_map})
df.head()

In [None]:
# Let's get the labels of our test set; we will use these to evaluate our predictions
df_with_labels = pd.read_csv("../sample_payload/batch_data_groundtruth.csv")

value_map = {'negative': 0, 'neutral': 1, 'positive': 2}
df_with_labels = df_with_labels.replace({'sentiment': value_map})

In [None]:
test_labels = df_with_labels.iloc[:, 1]
labels = test_labels.to_numpy()
preds = df.label.to_numpy()

# Calculate accuracy
accuracy = sum(preds == labels) / len(labels)
print(f"Accuracy: {accuracy}")
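Accuracy alone can hide poor performance on an individual class. As a sketch, the following pure-Python helpers compute per-class precision, recall, and F1, plus their macro average, from `labels` and `preds`. They follow the usual definitions and return 0.0 where a ratio is undefined, similar to scikit-learn's `precision_recall_fscore_support` with `zero_division=0`; the function names are hypothetical.

```python
def per_class_scores(y_true, y_pred, classes):
    """Return {class: (precision, recall, f1)} for two equal-length label sequences."""
    scores = {}
    for c in classes:
        pairs = list(zip(y_true, y_pred))
        tp = sum(1 for t, p in pairs if t == c and p == c)  # predicted c, truly c
        fp = sum(1 for t, p in pairs if t != c and p == c)  # predicted c, truly other
        fn = sum(1 for t, p in pairs if t == c and p != c)  # missed a true c
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        scores[c] = (precision, recall, f1)
    return scores


def macro_f1(scores):
    """Unweighted mean of the per-class F1 scores."""
    return sum(f1 for _, _, f1 in scores.values()) / len(scores)
```

For example, `per_class_scores(labels, preds, classes=[0, 1, 2])` returns one `(precision, recall, f1)` tuple per class (0 = negative, 1 = neutral, 2 = positive).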


### Evaluating Variant2

Next, we collect predictions from Variant2 for the same test set:

In [None]:
predictions2 = invoke_with_single_sentence(json_lst, endpoint_name, variant2["VariantName"])

In [None]:
# Keep only the highest-scoring label for each prediction (see the Variant1 cell above)
rows = []
for prediction in predictions2:
    tmp_df = pd.DataFrame(json.loads(prediction))
    rows.append(tmp_df.loc[tmp_df['score'].idxmax()])
df2 = pd.DataFrame(rows, columns=['label', 'score']).reset_index(drop=True)
df2.head()

In [None]:
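# Variant2 is a binary model. Map its labels onto the ground-truth encoding
# (negative=0, neutral=1, positive=2); it can never predict neutral.
value_map = {'NEGATIVE': 0, 'POSITIVE': 2}
df2 = df2.replace({'label': value_map})
df2.head()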

In [None]:
preds = df2.label.to_numpy()

# Calculate accuracy
accuracy = sum(preds == labels) / len(labels)
print(f"Accuracy: {accuracy}")
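Variant2 is a binary model (its only labels are `NEGATIVE` and `POSITIVE`), so it is always wrong on neutral examples, and its three-class accuracy mixes model quality with that task mismatch. For a fairer side-by-side check, a hypothetical helper can compute accuracy with one ground-truth class excluded. This sketch assumes both variants' predictions use the ground-truth encoding (0 = negative, 1 = neutral, 2 = positive), which requires Variant2's `POSITIVE` label to be mapped to 2.

```python
def accuracy_excluding(y_true, y_pred, excluded_label):
    """Accuracy over the examples whose true label is not excluded_label."""
    pairs = [(t, p) for t, p in zip(y_true, y_pred) if t != excluded_label]
    if not pairs:
        raise ValueError("No examples left after exclusion")
    return sum(t == p for t, p in pairs) / len(pairs)
```

For example, compare `accuracy_excluding(labels, df.label.to_numpy(), excluded_label=1)` with `accuracy_excluding(labels, df2.label.to_numpy(), excluded_label=1)`.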

## Step 4: Dialing up our chosen variant in production

Now that we have determined that Variant1 performs better than Variant2, we will shift more traffic to it.

We could keep using `TargetVariant` to send requests to the chosen variant, but that requires every client to change how it calls the endpoint. A simpler approach is to update the weights assigned to each variant with `UpdateEndpointWeightsAndCapacities`. This changes the traffic distribution across your production variants without creating a new endpoint configuration or redeploying the models.
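`UpdateEndpointWeightsAndCapacities` takes a list of `DesiredWeightsAndCapacities` entries, each naming a variant and its new `DesiredWeight` (and optionally a `DesiredInstanceCount`). Here is a small hypothetical helper that builds that list from a plain mapping:

```python
def desired_weights(weights, instance_counts=None):
    """Build a DesiredWeightsAndCapacities list for update_endpoint_weights_and_capacities.

    weights maps VariantName -> DesiredWeight; instance_counts optionally maps
    VariantName -> DesiredInstanceCount for variants whose capacity should change.
    """
    instance_counts = instance_counts or {}
    entries = []
    for name, weight in weights.items():
        if weight < 0:
            raise ValueError(f"Weight for {name} must be non-negative")
        entry = {"VariantName": name, "DesiredWeight": weight}
        if name in instance_counts:
            entry["DesiredInstanceCount"] = instance_counts[name]
        entries.append(entry)
    return entries
```

For example, `sm_client.update_endpoint_weights_and_capacities(EndpointName=endpoint_name, DesiredWeightsAndCapacities=desired_weights({"Variant1": 75, "Variant2": 25}))` sends the same weights as the explicit list used later in this step.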

Recall our variant weights are as follows:

In [None]:
{
    variant["VariantName"]: variant["CurrentWeight"]
    for variant in sm_client.describe_endpoint(EndpointName=endpoint_name)["ProductionVariants"]
}

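We'll first write a helper that sends traffic to the endpoint without `TargetVariant`, so that SageMaker routes each request according to the current variant weights. It sends one row of `batch_data.csv` per second, so its run time depends on the number of rows in that file: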

In [None]:
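def invoke_endpoint_for_two_minutes():
    """Send each row of batch_data.csv to the endpoint, one request per second.

    No TargetVariant is set, so SageMaker routes each request according to
    the current variant weights.
    """
    with open("../sample_payload/batch_data.csv", "r") as f:
        for row in f:
            print(".", end="", flush=True)
            payload = row.rstrip("\n")
            response = sm_runtime.invoke_endpoint(
                EndpointName=endpoint_name, ContentType="text/csv", Body=payload
            )
            response["Body"].read().decode("utf-8")  # read the body; the result is discarded
            time.sleep(1)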

In [None]:
cw = boto3.Session().client("cloudwatch")

def get_invocation_metrics_for_endpoint_variant(endpoint_name, variant_name, start_time, end_time):
    """Return per-minute invocation counts for one production variant from CloudWatch."""
    response = cw.get_metric_statistics(
        Namespace="AWS/SageMaker",
        MetricName="Invocations",
        StartTime=start_time,
        EndTime=end_time,
        Period=60,
        Statistics=["Sum"],
        Dimensions=[
            {"Name": "EndpointName", "Value": endpoint_name},
            {"Name": "VariantName", "Value": variant_name},
        ],
    )
    datapoints = response["Datapoints"]
    if not datapoints:
        # CloudWatch may not have published any datapoints for this variant yet
        return pd.DataFrame(columns=[variant_name])
    return (
        pd.DataFrame(datapoints)
        .sort_values("Timestamp")
        .set_index("Timestamp")
        .drop("Unit", axis=1)
        .rename(columns={"Sum": variant_name})
    )


def plot_endpoint_metrics(start_time=None):
    """Plot invocations per minute for both variants since start_time (default: last 60 minutes)."""
    start_time = start_time or datetime.datetime.now() - datetime.timedelta(minutes=60)
    end_time = datetime.datetime.now()
    metrics_variant1 = get_invocation_metrics_for_endpoint_variant(
        endpoint_name, variant1["VariantName"], start_time, end_time
    )
    metrics_variant2 = get_invocation_metrics_for_endpoint_variant(
        endpoint_name, variant2["VariantName"], start_time, end_time
    )
    metrics_variants = metrics_variant1.join(metrics_variant2, how="outer")
    metrics_variants.plot()
    return metrics_variants

Let's invoke the endpoint for a while to see the roughly even split of invocations between the two variants:

In [None]:
invocation_start_time = datetime.datetime.now()
invoke_endpoint_for_two_minutes()
time.sleep(20)  # give metrics time to catch up
plot_endpoint_metrics(invocation_start_time)

Now let us shift 75% of the traffic to Variant1 by assigning new weights to each variant with UpdateEndpointWeightsAndCapacities. Amazon SageMaker will then send 75% of the inference requests to Variant1 and the remaining 25% to Variant2.
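These percentages follow from normalizing the weights: each variant receives its weight divided by the sum of all weights, so weights of 75 and 25 give 75% and 25%, and weights of 1 and 1 give 50% each. A minimal sketch of that arithmetic (the name `traffic_shares` is hypothetical):

```python
def traffic_shares(weights):
    """Return each variant's fraction of traffic: its weight divided by the sum of all weights."""
    total = sum(weights.values())
    if total <= 0:
        raise ValueError("At least one variant needs a positive weight")
    return {name: weight / total for name, weight in weights.items()}
```

Because it only normalizes numbers, the same function can turn observed invocation totals into shares, for example `traffic_shares(plot_endpoint_metrics(invocation_start_time).sum().to_dict())`, to compare what CloudWatch recorded against the configured weights.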

In [None]:
sm_client.update_endpoint_weights_and_capacities(
    EndpointName=endpoint_name,
    DesiredWeightsAndCapacities=[
        {"DesiredWeight": 75, "VariantName": variant1["VariantName"]},
        {"DesiredWeight": 25, "VariantName": variant2["VariantName"]},
    ],
)

In [None]:
print("Waiting for update to complete")
while True:
    status = sm_client.describe_endpoint(EndpointName=endpoint_name)["EndpointStatus"]
    if status in ["InService", "Failed"]:
        print("Done")
        break
    print(".", end="", flush=True)
    time.sleep(1)

{
    variant["VariantName"]: variant["CurrentWeight"]
    for variant in sm_client.describe_endpoint(EndpointName=endpoint_name)["ProductionVariants"]
}

Now let's check how that has impacted invocation metrics:

In [None]:
invoke_endpoint_for_two_minutes()
time.sleep(20)  # give metrics time to catch up
plot_endpoint_metrics(invocation_start_time)

We can continue to monitor our metrics, and when we're satisfied with a variant's performance, we can route 100% of the traffic to it. Here we again use UpdateEndpointWeightsAndCapacities to update the traffic assignments: the weight for Variant1 is set to 1 and the weight for Variant2 is set to 0. Therefore, Amazon SageMaker will send 100% of all inference requests to Variant1.
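This notebook moves traffic in two steps (75/25, then 100/0). For a longer, more gradual rollout, one option is to precompute a schedule and apply one step at a time, checking metrics in between. A sketch with hypothetical step percentages and the hypothetical helper name `ramp_schedule`:

```python
def ramp_schedule(target, other, steps=(50, 75, 90, 100)):
    """Return a list of {VariantName: DesiredWeight} dicts shifting traffic toward `target`.

    Each step gives `target` the listed percentage and `other` the remainder,
    so the two weights in every step sum to 100.
    """
    schedule = []
    for pct in steps:
        if not 0 <= pct <= 100:
            raise ValueError(f"Step {pct} is outside 0-100")
        schedule.append({target: pct, other: 100 - pct})
    return schedule
```

For each step, pass `[{"VariantName": name, "DesiredWeight": weight} for name, weight in step.items()]` as `DesiredWeightsAndCapacities`, wait for the endpoint to return to `InService`, and review the invocation metrics before moving on.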

In [None]:
sm_client.update_endpoint_weights_and_capacities(
    EndpointName=endpoint_name,
    DesiredWeightsAndCapacities=[
        {"DesiredWeight": 1, "VariantName": variant1["VariantName"]},
        {"DesiredWeight": 0, "VariantName": variant2["VariantName"]},
    ],
)
print("Waiting for update to complete")
while True:
    status = sm_client.describe_endpoint(EndpointName=endpoint_name)["EndpointStatus"]
    if status in ["InService", "Failed"]:
        print("Done")
        break
    print(".", end="", flush=True)
    time.sleep(1)

{
    variant["VariantName"]: variant["CurrentWeight"]
    for variant in sm_client.describe_endpoint(EndpointName=endpoint_name)["ProductionVariants"]
}

In [None]:
invoke_endpoint_for_two_minutes()
time.sleep(20)  # give metrics time to catch up
plot_endpoint_metrics(invocation_start_time)

The Amazon CloudWatch invocation metrics plotted above show that the most recent inference requests are all processed by Variant1 and none by Variant2.

You can now safely update your endpoint and remove Variant2 from it. You can also continue testing new models in production by adding new variants to your endpoint and repeating Steps 2 to 4.
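Removing Variant2 means pointing the endpoint at a new endpoint configuration that lists only Variant1, using the `CreateEndpointConfig` and `UpdateEndpoint` APIs. The sketch below takes the client as an argument so it can be exercised without AWS. `retire_variants` and the new configuration name are hypothetical; the helper reuses each kept variant's original description, and a single remaining variant receives all traffic, assuming its weight is positive.

```python
def retire_variants(sm_client, endpoint_name, production_variants, keep, new_config_name):
    """Update an endpoint so that it serves only the variants whose names are in `keep`.

    Creates a new endpoint configuration containing copies of the kept variant
    descriptions, then switches the endpoint to it. Returns the kept descriptions.
    """
    kept = [dict(v) for v in production_variants if v["VariantName"] in keep]
    if not kept:
        raise ValueError("At least one production variant must be kept")
    # New endpoint configuration with only the kept variants
    sm_client.create_endpoint_config(
        EndpointConfigName=new_config_name, ProductionVariants=kept
    )
    # Switch the live endpoint to the new configuration (this call does not wait)
    sm_client.update_endpoint(EndpointName=endpoint_name, EndpointConfigName=new_config_name)
    return kept
```

For example, `retire_variants(sm_client, endpoint_name, [variant1, variant2], keep={"Variant1"}, new_config_name=f"{endpoint_name}-v1-only")`, then wait for the endpoint to return to `InService`. Remember to delete this extra endpoint configuration when you clean up.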

## Delete the endpoint

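If you do not plan to use this endpoint further, delete it to avoid incurring additional charges. The cell below also removes the endpoint configuration created in Step 1.

In [None]:
sm_session.delete_endpoint(endpoint_name)
# endpoint_from_production_variants gave the endpoint configuration the same name as the endpoint
sm_session.delete_endpoint_config(endpoint_name)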