# SageMaker Model Deployment

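This notebook demonstrates how to deploy a trained XGBoost model to a real-time Amazon SageMaker endpoint with data capture enabled, and then send test traffic to that endpoint to generate predictions.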

## Task 1: Setup Environment

In [None]:
# Install dependencies
# NOTE(review): seaborn is not referenced in any visible cell and the install is unpinned — confirm it is still needed.
%pip install seaborn
# Clear the interactive namespace (-f: no confirmation prompt) before the imports below run.
%reset -f

import time
from time import gmtime, strftime

import boto3
import numpy as np
import pandas as pd
import sagemaker
from botocore.exceptions import ClientError, ParamValidationError
from sagemaker import image_uris

# Build a single boto3 session and derive the region and every client from it,
# so all AWS calls in this notebook share one credential/region configuration.
boto3_session = boto3.Session()
region = boto3_session.region_name

# Execution role for this notebook; passed as ExecutionRoleArn when the model is created.
role = sagemaker.get_execution_role()

# 'sagemaker' is used for create_model / create_endpoint_config / create_endpoint / describe_endpoint.
sagemaker_client = boto3_session.client('sagemaker')
# 'sagemaker-runtime' is the client that exposes invoke_endpoint.
sagemaker_runtime = boto3_session.client('sagemaker-runtime')
s3_client = boto3_session.client('s3')

## Task 2: Check if S3 Bucket and Test Data File Exist

In [None]:
# Check if S3 bucket exists
from botocore.exceptions import ClientError


def check_s3_bucket(bucket_name: str) -> bool:
    """Check if the S3 bucket exists and is accessible.

    Uses the notebook's shared ``s3_client`` (the same client that
    ``check_s3_object`` uses) instead of building a new client on every call.

    Args:
        bucket_name (str): Name of the S3 bucket.

    Returns:
        bool: True if bucket exists and accessible, False otherwise. A name
        that fails client-side validation (for example one containing
        spaces) is also reported as False rather than raising.
    """
    try:
        s3_client.head_bucket(Bucket=bucket_name)
    except ParamValidationError:
        # The name was rejected before any request was sent, so no such bucket can exist.
        return False
    except ClientError:
        # The service rejected the HEAD request: bucket missing or not accessible to this caller.
        return False
    return True


def check_s3_object(bucket_name: str, key: str) -> bool:
    """Report whether ``key`` can be found in ``bucket_name``.

    Args:
        bucket_name (str): Name of the S3 bucket.
        key (str): Object key inside the bucket.

    Returns:
        bool: True when the HEAD request succeeds, False when the service
        answers with a "404" / "NoSuchKey" error code.

    Raises:
        PermissionError: If the service answers with error code "403".
        ClientError: Any other service error is re-raised unchanged.
    """
    try:
        s3_client.head_object(Bucket=bucket_name, Key=key)
    except ClientError as err:
        code = err.response["Error"]["Code"]
        if code == "403":
            raise PermissionError(f"Access denied for s3://{bucket_name}/{key}")
        if code not in ("404", "NoSuchKey"):
            # Unknown failure: let the caller see the original error.
            raise
        return False
    return True


def get_user_input(prompt: str) -> str:
    """Ask the user for a value, repeating the question until it is non-empty.

    Leading and trailing whitespace is stripped before the emptiness check,
    so a whitespace-only answer counts as empty.

    Args:
        prompt (str): Prompt text to display.

    Returns:
        str: The stripped, non-empty user input.
    """
    answer = input(prompt).strip()
    while not answer:
        print("Input cannot be empty. Please try again.")
        answer = input(prompt).strip()
    return answer


# Ask for the bucket and for the key of the test-data object inside it.
bucket_name = get_user_input("Enter the S3 bucket name: ")
test_data_filepath = get_user_input("Enter 'test' data file path which contains test data: ")

# Validate the bucket first, then the object; stop at the first failure.
if not check_s3_bucket(bucket_name):
    raise ValueError(
        f"S3 Bucket '{bucket_name}' does not exist or you don't have access!"
    )

if not check_s3_object(bucket_name, test_data_filepath):
    raise ValueError(
        f"'test' data file '{test_data_filepath}' does not exist or you don't have access!"
    )

# Both confirmations are printed only after both checks have passed.
print(
    f"S3 Bucket '{bucket_name}' exists ✅",
    f"'test' data file '{test_data_filepath}' exists ✅",
    sep="\n",
)

## Task 3: Download Test Dataset from S3

In [None]:
# Copy the validated test-data object from S3 into the working directory.
s3_client.download_file(
    Bucket=bucket_name,
    Key=test_data_filepath,
    Filename='adult_data_processed_test.csv',
)

## Task 4: Remove Labels from the Test Dataset

In [None]:
INPUT_FILE = "adult_data_processed_test.csv"
OUTPUT_FILE = "adult_data_processed_test_no_target.csv"

# Read the downloaded test set (read_csv treats the first row as the header).
# NOTE(review): assumes the file has a header row and the label in the first column — confirm.
df = pd.read_csv(INPUT_FILE)

# Keep the first column aside as the labels.
df_labels = df.iloc[:, 0]

# Everything except that first column is the feature set.
df = df.drop(columns=df.columns[0])

# Write the features only: no header row and no index column.
df.to_csv(OUTPUT_FILE, header=False, index=False)

## Task 5: Setup the Model

In [None]:
# -----------------------------
# Constants
# -----------------------------
FRAMEWORK_NAME = "xgboost"
FRAMEWORK_VERSION = "1.7-1"

# -----------------------------
# Generate a unique run name
# -----------------------------
# strftime without a time tuple formats the current local time; the
# second-resolution stamp keeps model names distinct between runs.
run_timestamp = strftime("%Y%m%d-%H%M%S")
model_name = f"vijay-xgboost-income-model-{run_timestamp}"

# -----------------------------
# Model artifact location
# -----------------------------
MODEL_DATA_S3_PATH = "scripts/data/models/model.tar.gz"
# Derived from the key above so the object that is checked below is exactly
# the object handed to SageMaker as ModelDataUrl.
MODEL_DATA_S3_URL = f"s3://{bucket_name}/{MODEL_DATA_S3_PATH}"

# -------------------------
# Check Model file existence
# -------------------------
if not check_s3_object(bucket_name, MODEL_DATA_S3_PATH):
    raise ValueError(f"Model data file '{MODEL_DATA_S3_PATH}' does not exist or you don't have access!")

print(f"Model file '{MODEL_DATA_S3_PATH}' exists ✅")

# -----------------------------
# Retrieve XGBoost container URI
# -----------------------------
container_uri = image_uris.retrieve(
    framework=FRAMEWORK_NAME,
    region=region,
    version=FRAMEWORK_VERSION
)

# -----------------------------
# Create the model in SageMaker
# -----------------------------
# Registers the container image and the model artifact under model_name.
income_model = sagemaker_client.create_model(
    ModelName=model_name,
    ExecutionRoleArn=role,
    PrimaryContainer={
        "Image": container_uri,
        "ModelDataUrl": MODEL_DATA_S3_URL,
    },
)

print(f"Created SageMaker model: {model_name}")

## Task 6: Configure an Endpoint

In [None]:
# Timestamp suffix keeps endpoint-config names distinct between runs.
create_date = strftime("%Y%m%d-%H%M%S")
ENDPOINT_CONFIG_NAME = f"vijay-income-model-real-time-endpoint-{create_date}"
INSTANCE_TYPE = "ml.m5.xlarge"
INITIAL_SAMPLING_PERCENTAGE = 25  # The percentage of requests SageMaker AI will capture. A lower value is recommended for Endpoints with high traffic.
CAPTURE_MODES = ["Input", "Output"]
# Single definition of the capture destination; reused in DataCaptureConfig below.
DATA_CAPTURE_S3_URI = f"s3://{bucket_name}/data-capture"

# -----------------------------
# Create endpoint config
# https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_DataCaptureConfig.html
# https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_ProductionVariant.html
# -----------------------------
# DataCaptureConfig is a feature of SageMaker endpoints that allows you to automatically record inference requests and responses and store them in S3.
endpoint_config_response = sagemaker_client.create_endpoint_config(
    EndpointConfigName=ENDPOINT_CONFIG_NAME,
    ProductionVariants=[
        {
            "VariantName": "VijayTestModel1",
            "ModelName": model_name,
            "InstanceType": INSTANCE_TYPE,
            "InitialInstanceCount": 1,
        }
    ],
    DataCaptureConfig={
        'EnableCapture': True,  # Whether data should be captured or not.
        'InitialSamplingPercentage': INITIAL_SAMPLING_PERCENTAGE,
        'DestinationS3Uri': DATA_CAPTURE_S3_URI,
        # One {"CaptureMode": ...} entry per mode listed in CAPTURE_MODES.
        'CaptureOptions': [{"CaptureMode": capture_mode} for capture_mode in CAPTURE_MODES]
    }
)

print(f"Created EndpointConfig: {endpoint_config_response['EndpointConfigArn']}")

## Task 7: Create the Endpoint

In [None]:
# The endpoint name is the endpoint-config name (already timestamped) plus a fixed suffix.
endpoint_name = f"{ENDPOINT_CONFIG_NAME}-endpoint"

# Request the endpoint; its status is polled in the next task.
create_endpoint_response = sagemaker_client.create_endpoint(
    EndpointConfigName=ENDPOINT_CONFIG_NAME,
    EndpointName=endpoint_name,
)

print(f"Endpoint creation started: {create_endpoint_response['EndpointArn']}")

## Task 8: Check Status of Endpoint Creation

In [None]:
# Seconds to pause between describe_endpoint calls while the endpoint is being created.
POLL_INTERVAL_SECONDS = 15

# ARN returned when creation was requested; used in both the failure and success messages.
endpoint_arn = create_endpoint_response.get("EndpointArn", "Unknown ARN")

response = sagemaker_client.describe_endpoint(EndpointName=endpoint_name)
status = response.get("EndpointStatus")
print(status)

# Poll endpoint status until it leaves "Creating".
# time.sleep needs the `time` module itself to be imported; `from time import ...`
# alone does not bind the name `time`.
while status == "Creating":
    print("Waiting for endpoint creation...")
    time.sleep(POLL_INTERVAL_SECONDS)
    response = sagemaker_client.describe_endpoint(EndpointName=endpoint_name)
    status = response.get("EndpointStatus")

# Handle failure: any final status other than "InService" aborts the run.
if status != "InService":
    failure_reason = response.get("FailureReason", "Unknown reason")
    print(f"Failed to create endpoint. Status: {status}")
    print(f"Response: {response}")
    raise SystemExit(
        f"Failed to create endpoint {endpoint_arn}. "
        f"Status: {status}. Reason: {failure_reason}"
    )

# Success message
print(f"Endpoint {endpoint_arn} successfully created.")

## Task 9: Generate Predictions

In [None]:
# Set the cutoff value for binary classification
CUTOFF_VALUE = 0.5

def convert_probability_to_binary(probability: float, cutoff: "float | None" = None) -> int:
    """Convert a probability to a binary class label.

    Args:
        probability (float): Predicted probability of the positive class.
        cutoff (float | None): Threshold to compare against. When omitted,
            the module-level ``CUTOFF_VALUE`` is read at call time, so
            existing one-argument callers behave exactly as before.

    Returns:
        int: 1 if ``probability`` is greater than or equal to the threshold,
        otherwise 0.
    """
    threshold = CUTOFF_VALUE if cutoff is None else cutoff
    return 1 if probability >= threshold else 0


print(f"Sending test traffic to the endpoint {endpoint_name}. \nPlease wait...")

predictions = []

# Send the label-free test rows to the endpoint one at a time.
with open("adult_data_processed_test_no_target.csv", "r") as f:
    for row in f:
        payload = row.strip()
        if not payload:
            # Skip blank lines so an empty body is never sent to the endpoint.
            continue

        # sagemaker_runtime is the 'sagemaker-runtime' client created in the setup cell.
        response = sagemaker_runtime.invoke_endpoint(
            EndpointName=endpoint_name,
            ContentType="text/csv",
            Body=payload
        )

        # The response body is parsed as a single float for the row that was sent.
        pred_probability = float(response["Body"].read().decode("utf-8"))
        predictions.append(convert_probability_to_binary(pred_probability))

# Convert predictions list to numpy array
pred_np = np.array(predictions, dtype=int)

print("Done!")