In [34]:
%store -r model_a_s3_path
%store -r model_b_s3_path

In [35]:
%store -r s3_bucket
%store -r prefix

In [36]:
import sagemaker
from sagemaker import get_execution_role

# IAM role ARN handed to SageMaker as the models' execution role (see create_model).
role = get_execution_role()
# High-level SageMaker session; used later to create the endpoint.
session = sagemaker.Session()

In [37]:
from sagemaker.image_uris import retrieve

# Resolve the built-in XGBoost inference image for the region this SageMaker
# session is actually configured for. The region was previously hard-coded to
# "us-east-1", which yields another region's ECR image whenever the notebook
# runs elsewhere; SageMaker typically expects the image to be in the same
# region as the model/endpoint.
image_uri = retrieve(
    "xgboost",
    region=session.boto_region_name,
    version="0.90-2"
)

image_uri

'683313688378.dkr.ecr.us-east-1.amazonaws.com/sagemaker-xgboost:0.90-2-cpu-py3'

In [38]:
# Both arms run the same XGBoost image; only their model artifacts differ.
image_uri_a = image_uri_b = image_uri

In [39]:
# Container definitions for create_model: one per arm, differing only in
# image variable, hostname and the S3 model artifact they load.
container1 = dict(
    Image=image_uri_a,
    ContainerHostname='containerA',
    ModelDataUrl=model_a_s3_path,
)

container2 = dict(
    Image=image_uri_b,
    ContainerHostname='containerB',
    ModelDataUrl=model_b_s3_path,
)

In [40]:
# The endpoint and endpoint-config names actually used by this notebook are
# assigned a couple of cells below (together with the per-arm model names).
# Assigning different values here only for them to be overwritten made the
# notebook misleading, so only the base model name is kept.
# NOTE(review): `model_name` is not referenced by any visible cell.
model_name = "ab-testing"

In [41]:
import boto3
# Low-level SageMaker control-plane client (create/delete models and endpoints),
# built from a fresh boto3 session.
_boto_session = boto3.Session()
sm_client = _boto_session.client('sagemaker')

In [42]:
# Names of the SageMaker resources created (and torn down) by this notebook.
model_name_a = "ab-model-a"
model_name_b = "ab-model-b"
endpoint_name = 'ab-endpoint'
# `Session.endpoint_from_production_variants` names the endpoint config it
# creates after the endpoint itself, so keep this variable truthful instead of
# pointing at a config ('ab-endpoint-config') that is never created.
endpoint_config_name = endpoint_name

In [43]:
# Best-effort removal of models left over from a previous run, so the
# create_model calls below can be re-run. Each model is handled on its own:
# previously a failure deleting model A (e.g. it did not exist yet) skipped
# the deletion of model B, and the bare `except` also hid unrelated errors
# such as NameErrors or KeyboardInterrupt.
for _stale_model in (model_name_a, model_name_b):
    try:
        sm_client.delete_model(ModelName=_stale_model)
    except sm_client.exceptions.ClientError as err:
        # Expected on a first run, when the model does not exist yet.
        print(f"Skipping delete of {_stale_model}: {err}")

In [44]:
# Register one single-container SageMaker model per A/B arm and print each
# create_model response (it contains the new ModelArn).
for _arm_model_name, _arm_container in ((model_name_a, container1),
                                        (model_name_b, container2)):
    response = sm_client.create_model(
        ModelName=_arm_model_name,
        ExecutionRoleArn=role,
        Containers=[_arm_container],
    )
    print(response)

{'ModelArn': 'arn:aws:sagemaker:us-east-1:581320662326:model/ab-model-a', 'ResponseMetadata': {'RequestId': '7fc36097-902f-4afa-8e61-931d3b1b246a', 'HTTPStatusCode': 200, 'HTTPHeaders': {'x-amzn-requestid': '7fc36097-902f-4afa-8e61-931d3b1b246a', 'content-type': 'application/x-amz-json-1.1', 'content-length': '72', 'date': 'Tue, 08 Jun 2021 07:54:36 GMT'}, 'RetryAttempts': 0}}
{'ModelArn': 'arn:aws:sagemaker:us-east-1:581320662326:model/ab-model-b', 'ResponseMetadata': {'RequestId': 'c50a537b-469c-424d-84bc-337192bbd430', 'HTTPStatusCode': 200, 'HTTPHeaders': {'x-amzn-requestid': 'c50a537b-469c-424d-84bc-337192bbd430', 'content-type': 'application/x-amz-json-1.1', 'content-length': '72', 'date': 'Tue, 08 Jun 2021 07:54:40 GMT'}, 'RetryAttempts': 3}}


In [45]:
from sagemaker.session import production_variant

def _ab_variant(arm_model_name, arm_variant_name):
    """Build a production variant: one ml.t2.medium instance, relative weight 0.5."""
    return production_variant(
        model_name=arm_model_name,
        instance_type="ml.t2.medium",
        initial_instance_count=1,
        variant_name=arm_variant_name,
        initial_weight=0.5,
    )


# Equal weights on both arms -> traffic is split evenly between A and B.
variant1 = _ab_variant(model_name_a, 'VariantA')
variant2 = _ab_variant(model_name_b, 'VariantB')

In [46]:
# Create the endpoint config and the endpoint from both variants. The call
# waits for the endpoint to come up (the progress dashes) and returns its name.
ab_variants = [variant1, variant2]
session.endpoint_from_production_variants(
    production_variants=ab_variants,
    name=endpoint_name,
)

---------------------!

'ab-endpoint'

In [47]:
# Data-plane client used to send inference requests to the endpoint.
runtime_sm_client = boto3.client(service_name='sagemaker-runtime')

In [48]:
# A single CSV row of features sent as the request payload (ContentType text/csv).
body = '10,-5'

In [49]:
from time import sleep

def test_ab_testing_setup():
    """Invoke the endpoint without a TargetVariant and print which arm answered.

    The module-level CSV ``body`` is sent to ``endpoint_name``. Since no
    ``TargetVariant`` is given, SageMaker picks the variant according to the
    configured variant weights.

    Returns:
        Tuple ``(invoked_variant, prediction)``, where ``prediction`` is the
        UTF-8-decoded response body. Returned so callers can measure the
        observed traffic split.
    """
    response = runtime_sm_client.invoke_endpoint(
        EndpointName=endpoint_name,
        ContentType='text/csv',
        Body=body
    )

    variant = response['InvokedProductionVariant']
    stream = response['Body']
    try:
        prediction = stream.read().decode("utf-8")
    finally:
        # Close the streaming body explicitly rather than relying on GC.
        stream.close()

    print(variant + " - " + prediction)
    return variant, prediction


from collections import Counter

N_REQUESTS = 10

# Tally which variant served each request. With only a handful of requests,
# the observed split can deviate noticeably from the configured weights.
variant_counts = Counter()
for _ in range(N_REQUESTS):
    invoked_variant, _prediction = test_ab_testing_setup()
    variant_counts[invoked_variant] += 1
    sleep(1)

variant_counts

VariantA - 0.895996630191803
VariantA - 0.895996630191803
VariantA - 0.895996630191803
VariantA - 0.895996630191803
VariantB - 0.8308258652687073
VariantB - 0.8308258652687073
VariantB - 0.8308258652687073
VariantA - 0.895996630191803
VariantA - 0.895996630191803
VariantA - 0.895996630191803


In [50]:
def test_direct_call(target_variant='VariantB'):
    """Invoke the endpoint pinned to one production variant and print the result.

    Passing ``TargetVariant`` overrides the endpoint's weight-based routing,
    so every request should be served by ``target_variant``.

    Args:
        target_variant: Name of the production variant to route to
            (defaults to 'VariantB', the previously hard-coded value).

    Returns:
        Tuple ``(invoked_variant, prediction)``, where ``prediction`` is the
        UTF-8-decoded response body.
    """
    response = runtime_sm_client.invoke_endpoint(
        EndpointName=endpoint_name,
        ContentType='text/csv',
        TargetVariant=target_variant,
        Body=body
    )

    variant = response['InvokedProductionVariant']
    stream = response['Body']
    try:
        prediction = stream.read().decode("utf-8")
    finally:
        # Close the streaming body explicitly rather than relying on GC.
        stream.close()

    print(variant + " - " + prediction)
    return variant, prediction


N_REQUESTS = 10

for _ in range(N_REQUESTS):
    invoked_variant, _prediction = test_direct_call()
    # Targeted requests must never be rerouted by the traffic weights.
    assert invoked_variant == 'VariantB', invoked_variant
    sleep(1)

VariantB - 0.8308258652687073
VariantB - 0.8308258652687073
VariantB - 0.8308258652687073
VariantB - 0.8308258652687073
VariantB - 0.8308258652687073
VariantB - 0.8308258652687073
VariantB - 0.8308258652687073
VariantB - 0.8308258652687073
VariantB - 0.8308258652687073
VariantB - 0.8308258652687073


In [51]:
# Tear down everything this notebook created. Deleting only the endpoint left
# the endpoint config and both models behind. The leftover config also breaks
# a re-run, because endpoint_from_production_variants would try to create an
# endpoint config that already exists.
response = sm_client.delete_endpoint(
    EndpointName=endpoint_name
)

# Session.endpoint_from_production_variants names the endpoint config after
# the endpoint, so the config to delete has the same name as the endpoint.
sm_client.delete_endpoint_config(EndpointConfigName=endpoint_name)

for _arm_model in (model_name_a, model_name_b):
    sm_client.delete_model(ModelName=_arm_model)

# Display the delete_endpoint response, as before.
response

{'ResponseMetadata': {'RequestId': '36e82d08-54ac-4e79-82ec-35267e91c906',
  'HTTPStatusCode': 200,
  'HTTPHeaders': {'x-amzn-requestid': '36e82d08-54ac-4e79-82ec-35267e91c906',
   'content-type': 'application/x-amz-json-1.1',
   'content-length': '0',
   'date': 'Tue, 08 Jun 2021 08:05:33 GMT'},
  'RetryAttempts': 0}}