In [1]:
%store -r model_a_s3_path
%store -r model_b_s3_path

In [2]:
%store -r s3_bucket
%store -r prefix

In [3]:
import sagemaker
from sagemaker import get_execution_role

# IAM role the notebook runs under; it is handed to create_model further down.
role = get_execution_role()
# SageMaker SDK session object for this notebook.
session = sagemaker.Session()

In [4]:
from sagemaker.image_uris import retrieve

# Resolve the built-in XGBoost (0.90-2) container image URI.
# The region is taken from the notebook's SageMaker session instead of a
# hard-coded "us-east-1", so the image comes from the same region the models
# and the endpoint are created in.
image_uri = retrieve(
    "xgboost",
    region=session.boto_region_name,
    version="0.90-2",
)

# Bare expression so the resolved URI is displayed as the cell output.
image_uri

'683313688378.dkr.ecr.us-east-1.amazonaws.com/sagemaker-xgboost:0.90-2-cpu-py3'

In [5]:
# Both variants are served from the same container image.
image_uri_a = image_uri_b = image_uri

In [6]:
# Container definitions passed to create_model: same image for both, but each
# points at its own model artifact and carries its own hostname.
container1 = dict(
    Image=image_uri_a,
    ContainerHostname='containerA',
    ModelDataUrl=model_a_s3_path,
)

container2 = dict(
    Image=image_uri_b,
    ContainerHostname='containerB',
    ModelDataUrl=model_b_s3_path,
)

In [7]:
# NOTE(review): endpoint_config_name and endpoint_name are reassigned a few
# cells further down before they are ever used, and model_name is not referenced
# anywhere else in the visible notebook -- this cell looks like a leftover.
endpoint_name = 'ab-testing-endpoint'
endpoint_config_name = 'ab-testing-config'
model_name = "ab-testing"

In [9]:
import boto3
# Low-level SageMaker (control plane) client used to create and delete resources.
sm_client = boto3.Session().client(service_name='sagemaker')

In [10]:
# Names of the SageMaker resources managed by this notebook.
endpoint_name = 'ab-endpoint'
# NOTE(review): endpoint_config_name is not referenced again in the visible cells.
endpoint_config_name = 'ab-endpoint-config'

# One model per A/B variant.
model_name_a = "ab-model-a"
model_name_b = "ab-model-b"

In [11]:
# Best-effort cleanup so the create_model calls below can reuse the same model
# names when the notebook is re-run.
#
# Each model is deleted in its own try-block: with a single shared block, a
# failure while deleting model A (for instance because it does not exist yet)
# would skip the delete of model B and leave a stale model behind.
for stale_model_name in (model_name_a, model_name_b):
    try:
        sm_client.delete_model(ModelName=stale_model_name)
    except sm_client.exceptions.ClientError as error:
        # The service rejected the delete (presumably because the model does
        # not exist). Report it rather than swallowing it silently; anything
        # that is not a service-side error is left to propagate.
        print(f"Skipping delete of {stale_model_name}: {error}")

In [12]:
# Register one SageMaker model per variant, printing each service response.
for new_model_name, container in (
    (model_name_a, container1),
    (model_name_b, container2),
):
    response = sm_client.create_model(
        ModelName=new_model_name,
        ExecutionRoleArn=role,
        Containers=[container],
    )
    print(response)

{'ModelArn': 'arn:aws:sagemaker:us-east-1:581320662326:model/ab-model-a', 'ResponseMetadata': {'RequestId': '52a6a0cf-0f43-40d4-931b-1203c6536a34', 'HTTPStatusCode': 200, 'HTTPHeaders': {'x-amzn-requestid': '52a6a0cf-0f43-40d4-931b-1203c6536a34', 'content-type': 'application/x-amz-json-1.1', 'content-length': '72', 'date': 'Mon, 07 Jun 2021 17:45:15 GMT'}, 'RetryAttempts': 0}}
{'ModelArn': 'arn:aws:sagemaker:us-east-1:581320662326:model/ab-model-b', 'ResponseMetadata': {'RequestId': 'd5e88ad9-eff0-4f63-983d-3ac817a7e040', 'HTTPStatusCode': 200, 'HTTPHeaders': {'x-amzn-requestid': 'd5e88ad9-eff0-4f63-983d-3ac817a7e040', 'content-type': 'application/x-amz-json-1.1', 'content-length': '72', 'date': 'Mon, 07 Jun 2021 17:45:16 GMT'}, 'RetryAttempts': 2}}


In [13]:
from sagemaker.session import production_variant

# Settings common to both variants: a single ml.t2.medium instance each and an
# equal initial weight of 0.5.
shared_variant_settings = {
    "instance_type": "ml.t2.medium",
    "initial_instance_count": 1,
    "initial_weight": 0.5,
}

variant1 = production_variant(
    model_name=model_name_a,
    variant_name='VariantA',
    **shared_variant_settings,
)

variant2 = production_variant(
    model_name=model_name_b,
    variant_name='VariantB',
    **shared_variant_settings,
)

In [14]:
# Create the endpoint that serves both production variants.
# The call goes through `session`, the sagemaker.Session created near the top
# of the notebook. The name `sagemaker_session` is never defined in this
# notebook, so the previous form raised NameError on a fresh kernel.
session.endpoint_from_production_variants(
    name=endpoint_name,
    production_variants=[variant1, variant2]
)

---------------------!

'ab-endpoint'

In [15]:
# Runtime (data plane) client used by the test helpers to invoke the endpoint.
runtime_sm_client = boto3.client(service_name='sagemaker-runtime')

In [24]:
# Request payload sent to the endpoint by the test helpers, which declare it
# as 'text/csv'.
# NOTE(review): presumably one record with two feature values -- confirm
# against the data the models were trained on.
body = "10,-5"

In [26]:
from time import sleep

def test_ab_testing_setup():
    """Invoke the endpoint once without naming a variant and print the result.

    Sends the notebook-level ``body`` as 'text/csv' to ``endpoint_name`` through
    ``runtime_sm_client`` and prints one line of the form
    ``<variant> - <prediction>``, where the variant is the one reported by the
    service in the response.
    """
    result = runtime_sm_client.invoke_endpoint(
        EndpointName=endpoint_name,
        ContentType='text/csv',
        Body=body,
    )

    served_by = result['InvokedProductionVariant']
    # The response body is a stream; read it fully and decode it to text.
    predicted = result['Body'].read().decode("utf-8")

    print(" - ".join((served_by, predicted)))

# Ten invocations, pausing one second after each, to see how requests are
# spread across the variants.
for _ in range(10):
    test_ab_testing_setup()
    sleep(1)

VariantB - 0.8308258652687073
VariantA - 0.895996630191803
VariantB - 0.8308258652687073
VariantB - 0.8308258652687073
VariantA - 0.895996630191803
VariantA - 0.895996630191803
VariantA - 0.895996630191803
VariantA - 0.895996630191803
VariantB - 0.8308258652687073
VariantB - 0.8308258652687073


In [27]:
def test_direct_call():
    """Invoke the endpoint once with the request pinned to 'VariantB'.

    Sends the notebook-level ``body`` as 'text/csv' to ``endpoint_name`` through
    ``runtime_sm_client`` with ``TargetVariant='VariantB'`` and prints one line
    of the form ``<variant> - <prediction>``, where the variant is the one
    reported by the service in the response.
    """
    result = runtime_sm_client.invoke_endpoint(
        EndpointName=endpoint_name,
        ContentType='text/csv',
        TargetVariant='VariantB',
        Body=body,
    )

    served_by = result['InvokedProductionVariant']
    # The response body is a stream; read it fully and decode it to text.
    predicted = result['Body'].read().decode("utf-8")

    print(" - ".join((served_by, predicted)))

# Ten pinned invocations, pausing one second after each.
for _ in range(10):
    test_direct_call()
    sleep(1)

VariantB - 0.8308258652687073
VariantB - 0.8308258652687073
VariantB - 0.8308258652687073
VariantB - 0.8308258652687073
VariantB - 0.8308258652687073
VariantB - 0.8308258652687073
VariantB - 0.8308258652687073
VariantB - 0.8308258652687073
VariantB - 0.8308258652687073
VariantB - 0.8308258652687073


In [28]:
# Tear down the endpoint once testing is done.
# NOTE(review): only the endpoint is deleted here; the two models created
# earlier are not removed by this cell.
response = sm_client.delete_endpoint(EndpointName=endpoint_name)

# Bare expression so the service response is displayed as the cell output.
response

{'ResponseMetadata': {'RequestId': '842c0c49-78fc-4ada-8ff4-67ad5b2fea61',
  'HTTPStatusCode': 200,
  'HTTPHeaders': {'x-amzn-requestid': '842c0c49-78fc-4ada-8ff4-67ad5b2fea61',
   'content-type': 'application/x-amz-json-1.1',
   'content-length': '0',
   'date': 'Mon, 07 Jun 2021 18:04:53 GMT'},
  'RetryAttempts': 0}}