# Module 3: Batch Scoring using a pre-trained XGBoost model
**This notebook uses the feature store to prepare a test dataset for batch scoring and then uses the XGBoost model trained in the model training notebook to score it.**

**Note:** Please set the kernel to `Python 3 (Data Science)` and select the `ml.t3.medium` instance.

---

## Contents

1. [Background](#Background)
1. [Setup](#Setup)
1. [Prepare test data](#Prepare-test-data-for-batch-transform)
1. [Batch Transform](#Batch-Transform)

## Background

After the model is trained, if the goal is to generate predictions on a large dataset where minimizing latency isn't a concern, then SageMaker batch transform is a good fit. Functionally, batch transform uses the same mechanics as real-time hosting to generate predictions: it requires a web server that takes in HTTP POST requests containing a single observation, or mini-batch, at a time. However, unlike real-time hosted endpoints, which have persistent hardware (instances stay running until you shut them down), batch transform clusters are torn down when the job completes.
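
To make the "single observation or mini-batch" idea concrete, here is a simplified, local illustration of grouping newline-delimited records into mini-batches that each stay under a payload limit, similar in spirit to the `MultiRecord` batch strategy, which packs as many records into a request as fit within `MaxPayloadInMB`. The helper `pack_mini_batches` is hypothetical; it is a sketch of the concept, not SageMaker's implementation.

```python
from typing import Iterable, List


def pack_mini_batches(records: Iterable[str], max_payload_bytes: int) -> List[List[str]]:
    """Group newline-delimited records into mini-batches under a byte limit.

    Hypothetical helper for illustration only; the batch transform service
    does the real splitting and batching.
    """
    batches: List[List[str]] = []
    current: List[str] = []
    current_size = 0
    for record in records:
        # Count the record plus its trailing newline separator
        size = len(record.encode("utf-8")) + 1
        if size > max_payload_bytes:
            raise ValueError(f"Record of {size} bytes exceeds the payload limit")
        # Start a new mini-batch when this record would overflow the current one
        if current and current_size + size > max_payload_bytes:
            batches.append(current)
            current, current_size = [], 0
        current.append(record)
        current_size += size
    if current:
        batches.append(current)
    return batches


rows = ["1,0.5,3", "2,0.1,7", "3,0.9,2"]
print(pack_mini_batches(rows, max_payload_bytes=16))
# [['1,0.5,3', '2,0.1,7'], ['3,0.9,2']]
```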

In this example, we walk through the steps to prepare the batch test dataset from the offline feature store using an Athena CTAS query, and then run batch transform on the test data stored in Amazon S3.

## Setup


In [None]:
import json
import logging
import sys
import time
import uuid
from time import sleep
from urllib.parse import urlparse
from io import StringIO

import boto3
import pandas as pd
import sagemaker
from sagemaker import get_execution_role
from sagemaker.utils import name_from_base

sys.path.append("..")

from utilities import Utils

In [None]:
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
logger.addHandler(logging.StreamHandler())

#### Essentials

In [None]:
sagemaker_execution_role = get_execution_role()
logger.info(f"Role = {sagemaker_execution_role}")
session = boto3.Session()
sagemaker_session = sagemaker.Session()
sagemaker_client = session.client(service_name="sagemaker")


default_bucket = sagemaker_session.default_bucket()
prefix = "sagemaker-featurestore-workshop"

In [None]:
s3 = boto3.resource("s3")


def list_s3_files(s3uri):
    """Return the bucket name and key of every object under the given S3 URI."""
    parsed_url = urlparse(s3uri)
    bucket = s3.Bucket(parsed_url.netloc)
    prefix = parsed_url.path[1:]
    return [
        dict(bucket_name=k.bucket_name, key=k.key)
        for k in bucket.objects.filter(Prefix=prefix)
    ]

## Prepare test data for batch transform
We create the test dataset that we will use in our batch transform job using [*Athena CREATE TABLE AS SELECT (CTAS) query*](https://docs.aws.amazon.com/athena/latest/ug/ctas.html). A CTAS query creates a new table in Athena from the results of a SELECT statement from another query. Athena stores data files created by the CTAS statement in a specified location in Amazon S3.
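
As a sketch of the CTAS pattern, the hypothetical helper `build_ctas_query` below wraps an arbitrary SELECT statement in a `CREATE TABLE ... WITH (...) AS` statement that writes its result files to a given S3 location. It only builds the SQL string and does not call Athena; the table name and bucket in the usage example are placeholders.

```python
def build_ctas_query(
    select_query: str,
    table: str,
    external_location: str,
    file_format: str = "PARQUET",
) -> str:
    """Wrap a SELECT statement in an Athena CTAS statement (illustrative helper)."""
    if not external_location.startswith("s3://"):
        raise ValueError("external_location must be an s3:// URI")
    # The WITH clause sets table properties such as where and how data is stored
    return (
        f"CREATE TABLE {table}\n"
        f"WITH (external_location='{external_location}', format='{file_format}')\n"
        f"AS {select_query}"
    )


print(
    build_ctas_query(
        'SELECT "order_id", "feature_a" FROM "orders"',
        "sagemaker_processing.tmp_table",
        "s3://example-bucket/tmp/data",
    )
)
```

Athena writes the data files to `external_location`, which is what makes the output usable as batch transform input later on.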

We follow the steps below to prepare the test dataset for the batch transform job:
1. Generate the list of feature names to read from the offline feature store by passing the feature group descriptions and a list of features to exclude to the *generate_fsets* function.
2. Construct an Athena SELECT query that gets the test data from the target feature groups, and then wrap it in a CTAS query that writes the query results in *Parquet* format.

#### Generate the list of features needed from the feature store

We use the boto3 `sagemaker_client` to call the `DescribeFeatureGroup` action for each feature group. The response includes the creation time, the FeatureGroup name, the unique identifier for each FeatureGroup, and more; for the full response syntax, please refer to the [documentation](https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_DescribeFeatureGroup.html#API_DescribeFeatureGroup_ResponseSyntax). We also use it to look up the Glue Data Catalog database and table behind each feature group's offline store.
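
Because a `DescribeFeatureGroup` response lists the feature definitions as `{"FeatureName": ..., "FeatureType": ...}` entries, the feature list can also be derived without pandas. The sketch below assumes the response dictionaries were already fetched (for example with `sagemaker_client.describe_feature_group`); `feature_names_from_descriptions` is a hypothetical helper, and the sample response uses placeholder names and is trimmed to the fields used.

```python
from typing import Iterable, List, Mapping, Optional


def feature_names_from_descriptions(
    descriptions: Iterable[Mapping],
    exclude: Optional[Iterable[str]] = None,
) -> List[str]:
    """Collect feature names from DescribeFeatureGroup responses, preserving order."""
    excluded = set(exclude or [])
    names: List[str] = []
    for description in descriptions:
        for definition in description["FeatureDefinitions"]:
            if definition["FeatureName"] not in excluded:
                names.append(definition["FeatureName"])
    return names


# Trimmed, placeholder response for illustration
orders_description = {
    "FeatureGroupName": "orders-fg",
    "FeatureDefinitions": [
        {"FeatureName": "order_id", "FeatureType": "String"},
        {"FeatureName": "feature_a", "FeatureType": "Fractional"},
        {"FeatureName": "event_time", "FeatureType": "String"},
    ],
}
print(feature_names_from_descriptions([orders_description], exclude=["order_id", "event_time"]))
# ['feature_a']
```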

In [None]:
# Retrieve FG names
%store -r customers_feature_group_name
%store -r products_feature_group_name
%store -r orders_feature_group_name

customers_fg = sagemaker_client.describe_feature_group(
    FeatureGroupName=customers_feature_group_name
)
products_fg = sagemaker_client.describe_feature_group(
    FeatureGroupName=products_feature_group_name
)
orders_fg = sagemaker_client.describe_feature_group(
    FeatureGroupName=orders_feature_group_name
)

database_name = customers_fg["OfflineStoreConfig"]["DataCatalogConfig"]["Database"]
catalog = customers_fg["OfflineStoreConfig"]["DataCatalogConfig"]["Catalog"]

customers_table = customers_fg["OfflineStoreConfig"]["DataCatalogConfig"]["TableName"]
products_table = products_fg["OfflineStoreConfig"]["DataCatalogConfig"]["TableName"]
orders_table = orders_fg["OfflineStoreConfig"]["DataCatalogConfig"]["TableName"]

In [None]:
exclude_fsets = [
    "customer_id",
    "product_id",
    "order_id",
    "event_time",
    "purchase_amount",
    "n_days_since_last_purchase",
]

In [None]:
def generate_fsets(fg_list, exclude_fsets=None):
    """Return the feature definitions of the given feature groups, minus excluded features."""
    _fg_lst = []
    for _fg in fg_list:
        _fg_tmp = pd.DataFrame(
            Utils.describe_feature_group(_fg["FeatureGroupName"])["FeatureDefinitions"]
        )
        if exclude_fsets:
            _fg_tmp = _fg_tmp[~_fg_tmp.FeatureName.isin(exclude_fsets)]

        _fg_lst.append(_fg_tmp)
    return pd.concat(_fg_lst, ignore_index=True)

In [None]:
fsets_df = generate_fsets([orders_fg, customers_fg, products_fg], exclude_fsets)
features_names = fsets_df.FeatureName.tolist()

#### Use an Athena CTAS query to generate the test set for the batch job

We start by creating an Athena query to get the test data from the feature store. We can use [PyAthena](https://pypi.org/project/pyathena/) (a library that uses [Athena's REST API](https://docs.aws.amazon.com/athena/latest/APIReference/Welcome.html) to connect to Athena and fetch query results) or [boto3 Athena](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/athena.html) (a low-level client representing Amazon Athena) to run queries that derive the feature sets from the offline feature store. In this example, we use a CTAS statement, which creates a new table from existing tables using a subset of the data or a subset of the columns. CTAS statements help reduce cost and improve performance by letting users run queries on smaller tables constructed from larger ones. When creating new tables using CTAS, you can include a WITH clause to define table-specific parameters, such as file format, compression, and partition columns. For more information, please refer to this [blog](https://aws.amazon.com/blogs/big-data/using-ctas-statements-with-amazon-athena-to-reduce-cost-and-improve-performance/).

Note that the first column of the result is the unique identifier of the order (the record identifier of the orders feature group) and the second column is the target value. Also note that the query keeps only the latest version of each record: rows sharing a record identifier are ranked by event time, then API invocation time, then write time, only the top-ranked row is used, and records marked as deleted are skipped.
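
Window functions can be hard to read, so here is a small pure-Python sketch of the same deduplication rule on in-memory rows: for each record identifier, keep the row with the greatest `(event_time, api_invocation_time, write_time)` tuple. The row layout (dicts with these keys, an `is_deleted` flag, and an `id` field) is an assumption for illustration. This sketch applies the deleted-record filter after picking the latest version, so a record whose most recent write is a deletion is dropped entirely rather than falling back to an older version. Unlike `dense_rank`, it keeps exactly one row per identifier even when all three timestamps tie.

```python
from typing import Dict, Iterable, List, Mapping, Tuple


def _version_key(row: Mapping) -> Tuple:
    # Same ordering as the ORDER BY clause; the largest tuple is the newest version
    return (row["event_time"], row["api_invocation_time"], row["write_time"])


def latest_versions(rows: Iterable[Mapping], id_key: str) -> List[Mapping]:
    """Keep the most recent version of each record, then drop deleted records."""
    latest: Dict[object, Mapping] = {}
    for row in rows:
        current = latest.get(row[id_key])
        if current is None or _version_key(row) > _version_key(current):
            latest[row[id_key]] = row
    # Filter deletions only after the latest version has been chosen
    return [row for row in latest.values() if not row["is_deleted"]]


rows = [
    {"id": "a", "event_time": 1, "api_invocation_time": 1, "write_time": 1, "is_deleted": False, "value": 10},
    {"id": "a", "event_time": 2, "api_invocation_time": 2, "write_time": 2, "is_deleted": False, "value": 20},
    {"id": "b", "event_time": 1, "api_invocation_time": 1, "write_time": 1, "is_deleted": False, "value": 5},
    {"id": "b", "event_time": 3, "api_invocation_time": 3, "write_time": 3, "is_deleted": True, "value": 5},
]
print([row["value"] for row in latest_versions(rows, "id")])  # [20]
```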

In [None]:
batch_transform_columns_string = ",\n    ".join(f'"{c}"' for c in features_names)

customer_uid = customers_fg["RecordIdentifierFeatureName"]
product_uid = products_fg["RecordIdentifierFeatureName"]
order_uid = orders_fg["RecordIdentifierFeatureName"]

customer_et = customers_fg["EventTimeFeatureName"]
product_et = products_fg["EventTimeFeatureName"]
order_et = orders_fg["EventTimeFeatureName"]

destination_s3_path = f's3://{default_bucket}/{prefix}/athena/data/sagemaker-batch-{time.strftime("%Y-%m-%d-%H%M%S")}/data'

temp_database_name = "sagemaker_processing"
table_name = f"sagemaker_tmp_{uuid.uuid4().hex[:8]}"

query_string = f"""WITH customer_table AS (
    SELECT *,
        dense_rank() OVER (
            PARTITION BY "{customer_uid}"
            ORDER BY "{customer_et}" DESC,
                "api_invocation_time" DESC,
                "write_time" DESC
        ) AS "rank"
    FROM "{customers_table}"
),
product_table AS (
    SELECT *,
        dense_rank() OVER (
            PARTITION BY "{product_uid}"
            ORDER BY "{product_et}" DESC,
                "api_invocation_time" DESC,
                "write_time" DESC
        ) AS "rank"
    FROM "{products_table}"
),
order_table AS (
    SELECT *,
        dense_rank() OVER (
            PARTITION BY "{order_uid}"
            ORDER BY "{order_et}" DESC,
                "api_invocation_time" DESC,
                "write_time" DESC
        ) AS "rank"
    FROM "{orders_table}"
)

SELECT DISTINCT
    "{order_uid}",
    {batch_transform_columns_string}
FROM customer_table,
    product_table,
    order_table
WHERE order_table."customer_id" = customer_table."customer_id"
    AND order_table."product_id" = product_table."product_id"
    AND customer_table."rank" = 1
    AND product_table."rank" = 1
    AND order_table."rank" = 1
    AND NOT customer_table."is_deleted"
    AND NOT product_table."is_deleted"
    AND NOT order_table."is_deleted"
"""

# Create a temporary table with external_location pointing to the S3 location and save the query results as Parquet
CTAS_query = f"""CREATE TABLE {catalog}.{temp_database_name}.{table_name}
WITH (external_location='{destination_s3_path}', format='PARQUET') 
AS {query_string}
"""
print(CTAS_query)

In [None]:
athena = boto3.client("athena")
glue = boto3.client("glue")

tmp_uri = f"s3://{default_bucket}/{prefix}/offline-store/query_results/"
logger.info(f"Running query on database: {database_name}")
query_execution = athena.start_query_execution(
    QueryString=CTAS_query,
    QueryExecutionContext={"Database": database_name},
    ResultConfiguration={"OutputLocation": tmp_uri},
)
query_execution_id = query_execution["QueryExecutionId"]

# Poll until the Athena query reaches a terminal state
while True:
    query_response = athena.get_query_execution(QueryExecutionId=query_execution_id)
    query_state = query_response["QueryExecution"]["Status"]["State"]
    if query_state == "SUCCEEDED":
        print("", end="\r")
        # List the Parquet files written by the CTAS query
        for k in list_s3_files(destination_s3_path):
            print(k["key"])
        break
    if query_state in ("FAILED", "CANCELLED"):
        logger.error(json.dumps(query_response, indent=2, default=str))
        break
    print(".", end="")
    sleep(1)

# Remove the temporary table from the Glue Data Catalog; this only deletes the table
# metadata, so the Parquet files in destination_s3_path remain for batch transform
try:
    glue.delete_table(DatabaseName=temp_database_name, Name=table_name)
    logger.info("Temporary table removed from Glue Catalog")
except Exception:
    logger.exception("Failed to delete the temporary table in Glue Catalog")
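
The polling loop above has no timeout. A reusable variant could look like the sketch below; `wait_for_state` is a hypothetical helper that takes any zero-argument function returning the current state string, so it can wrap `athena.get_query_execution` (see the commented wiring) and can also be exercised locally without AWS.

```python
import time
from typing import Callable, Iterable


def wait_for_state(
    get_state: Callable[[], str],
    terminal_states: Iterable[str] = ("SUCCEEDED", "FAILED", "CANCELLED"),
    timeout_s: float = 600.0,
    poll_interval_s: float = 1.0,
) -> str:
    """Poll get_state() until it returns a terminal state or the timeout expires."""
    terminal = set(terminal_states)
    deadline = time.monotonic() + timeout_s
    while True:
        state = get_state()
        if state in terminal:
            return state
        if time.monotonic() >= deadline:
            raise TimeoutError(f"Still in state {state!r} after {timeout_s} seconds")
        time.sleep(poll_interval_s)


# Possible wiring for Athena (needs AWS credentials, so shown as a comment):
# state = wait_for_state(lambda: athena.get_query_execution(
#     QueryExecutionId=query_execution_id)["QueryExecution"]["Status"]["State"])

# Local usage with a fake sequence of states
states = iter(["QUEUED", "RUNNING", "SUCCEEDED"])
print(wait_for_state(lambda: next(states), poll_interval_s=0.0))  # SUCCEEDED
```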

## Batch Transform

SageMaker Batch Transform provides three attributes for associating prediction results with input records: __input_filter__, __join_source__, and __output_filter__. Please refer to [this page](https://docs.aws.amazon.com/sagemaker/latest/dg/batch-transform-data-processing.html) to learn more about how to use them. These attributes only work with CSV and JSON input, so in this notebook we use the [SageMaker Python SDK](https://github.com/aws/sagemaker-python-sdk) to run a single Batch Transform job on Parquet input and associate the predictions with the input inside the inference script instead.
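
For CSV input, these attributes act like column selections on each record: `input_filter="$[1:]"` sends every column except the first (an ID) to the model, `join_source="Input"` appends the model output to the original record, and `output_filter="$[0,-1]"` keeps the first column (the ID) and the last column (the prediction) of the joined record. The local emulation below is a simplified illustration of that data flow, not SageMaker code; `emulate_record` and `toy_model` are hypothetical.

```python
from typing import Callable, List, Sequence


def emulate_record(row: Sequence[str], model: Callable[[List[float]], float]) -> List[str]:
    """Mimic input_filter="$[1:]", join_source="Input", output_filter="$[0,-1]"."""
    # input_filter: drop the ID column before sending the record to the model
    features = [float(v) for v in row[1:]]
    prediction = model(features)
    # join_source="Input": append the prediction to the original record
    joined = list(row) + [str(prediction)]
    # output_filter="$[0,-1]": keep the ID and the prediction
    return [joined[0], joined[-1]]


def toy_model(features: List[float]) -> float:
    # Stand-in for a real model: the mean of the features
    return round(sum(features) / len(features), 3)


print(emulate_record(["order-1", "0.5", "1.5"], toy_model))  # ['order-1', '1.0']
```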

#### Create a model based on the pre-trained model artifacts on S3
Let's first create a model based on the training job from the previous notebook. We can use the `describe_training_job` boto3 API call to get the S3 URI of the model artifacts. We use the [*XGBoostModel*](https://github.com/aws/sagemaker-python-sdk/blob/master/src/sagemaker/xgboost/model.py#L66) class (a framework model) from the SageMaker SDK to create the model for batch transform. Note that you need to use the same framework version as the training job and also provide [an entry point script](https://sagemaker.readthedocs.io/en/stable/frameworks/xgboost/using_xgboost.html#write-an-inference-script) that implements (at least) the `model_fn` function, which loads the trained model from the model artifacts.

In this example, we implemented `input_fn` to handle the Parquet input format (the content type is "application/x-parquet") and convert the data to a pandas DataFrame. In `predict_fn`, we filter out the ID column and the target column, convert the feature columns from the DataFrame to the *DMatrix* data type, and call the model to get predictions. We also associate the input with the prediction results by appending the predictions as an additional column to the input matrix.
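
The notebook's `code/inference.py` is not shown here. As a rough sketch under stated assumptions, handlers matching the description could look like the following: `input_fn` decodes a Parquet payload into a DataFrame (this needs `pyarrow` or `fastparquet`), and `predict_fn` drops the first two columns (assumed to be the ID and the target), scores the rest, and appends the predictions as a final column. The optional `make_matrix` argument is an addition for local testing that defaults to `xgboost.DMatrix`. `model_fn` is omitted because it depends on how the training script saved the model.

```python
import io

import numpy as np
import pandas as pd


def input_fn(request_body, request_content_type):
    """Decode a Parquet payload into a DataFrame (sketch)."""
    if request_content_type != "application/x-parquet":
        raise ValueError(f"Unsupported content type: {request_content_type}")
    return pd.read_parquet(io.BytesIO(request_body))


def predict_fn(input_data, model, make_matrix=None):
    """Score every row and append the prediction as the last column (sketch).

    Assumes column 0 is the record ID and column 1 is the target.
    """
    if make_matrix is None:
        # Imported lazily so the sketch can be exercised without xgboost installed
        import xgboost as xgb

        make_matrix = xgb.DMatrix
    features = input_data.iloc[:, 2:]
    predictions = model.predict(make_matrix(features.values))
    # Keep the original columns so each prediction stays associated with its input
    return np.column_stack([input_data.values, predictions])
```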

In [None]:
%store -r training_jobName

from sagemaker.xgboost.model import XGBoostModel

training_job_info = sagemaker_client.describe_training_job(
    TrainingJobName=training_jobName
)
xgb_model_data = training_job_info["ModelArtifacts"]["S3ModelArtifacts"]

xgb_model = XGBoostModel(
    source_dir="./code",
    entry_point="inference.py",
    framework_version="1.0-1",
    model_data=xgb_model_data,
    role=sagemaker_execution_role,
    name=name_from_base("fs-workshop-xgboost-model"),
    sagemaker_session=sagemaker_session,
)

We used a CTAS query to create a new table in Parquet format from the source tables, because Parquet is a widely used columnar storage format that is efficient in both storage and processing. You can use the `format` property to specify ORC, AVRO, JSON, or TEXTFILE as the storage format for the new table instead. However, by default the output data is compressed, which means we cannot save the data directly as plain CSV (only as compressed CSV with a `.gz` extension). None of these output formats (with compression) is supported by batch transform directly; therefore, we need to either use a processing job to convert the data to a supported format or treat each file as one record so that batch transform doesn't need to read the file and split the records.

Batch transform supports [associating prediction results with the input](https://docs.aws.amazon.com/sagemaker/latest/dg/batch-transform-data-processing.html) and filtering input columns, but only when the input is in a supported format (CSV or JSON); Parquet is not supported. Therefore, we set the [batch strategy](https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateTransformJob.html#API_CreateTransformJob_RequestParameters) to `SingleRecord` and leave the split type unset, so each Parquet file is treated as one record and sent as a whole to the model for inference. The `max_concurrent_transforms` parameter specifies the maximum number of parallel requests that can be sent to each instance in a transform job. Because each file is sent as one request, every file must fit within `MaxPayloadInMB`, which defaults to 6 MB. If any Parquet file in your test data is larger than 6 MB, set `max_payload` to a value equal to or greater than the size of the largest Parquet file in S3. The `accept` parameter specifies the format of the batch transform output, while `content_type` defines the MIME type of the input data.
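
To choose `max_payload`, you need the size of the largest Parquet file. A minimal sketch, assuming you already have the object sizes in bytes (for example from the `size` attribute of the objects returned by `s3.Bucket(...).objects.filter(Prefix=...)`): the hypothetical `required_max_payload_mb` rounds the largest size up to whole megabytes, never goes below the 6 MB default, and rejects files above the 100 MB upper bound of `MaxPayloadInMB`. Treating one MB as 2**20 bytes is an assumption.

```python
import math
from typing import Iterable

DEFAULT_MAX_PAYLOAD_MB = 6
MAX_ALLOWED_PAYLOAD_MB = 100
BYTES_PER_MB = 1024 * 1024  # assumption: MB interpreted as 2**20 bytes


def required_max_payload_mb(object_sizes: Iterable[int]) -> int:
    """Smallest whole-MB max_payload that fits every file (illustrative)."""
    largest = max(object_sizes, default=0)
    needed = max(DEFAULT_MAX_PAYLOAD_MB, math.ceil(largest / BYTES_PER_MB))
    if needed > MAX_ALLOWED_PAYLOAD_MB:
        raise ValueError(
            f"Largest file needs {needed} MB, above the {MAX_ALLOWED_PAYLOAD_MB} MB limit; "
            "write smaller files before running batch transform"
        )
    return needed


print(required_max_payload_mb([2_000_000, 9_500_000]))  # 10
```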

In [None]:
output_path = f"s3://{default_bucket}/{prefix}/batch_output/{xgb_model.name}"

xgb_transformer = xgb_model.transformer(
    strategy="SingleRecord",
    instance_count=1,
    instance_type="ml.m5.xlarge",
    max_concurrent_transforms=4,
    accept="text/csv",
    output_path=output_path,
)
xgb_transformer.transform(
    destination_s3_path,
    content_type="application/x-parquet",
    # If you don't want to wait for the batch job to finish, uncomment the line below
    # and check the batch transform job status on the SageMaker console
    # wait=False,
)

Let's inspect the output of the Batch Transform job in S3 by listing all the generated files. They are in CSV format; to get Parquet (or any other supported format) instead, you can add an `output_fn` to `inference.py` that converts the inference result.
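
A minimal `output_fn` sketch, assuming `predict_fn` returns a 2-D NumPy array and that only CSV and Parquet outputs are needed. To request Parquet you would also set `accept="application/x-parquet"` on the transformer. Writing Parquet relies on pandas with `pyarrow` or `fastparquet` installed; column names are left as pandas defaults and converted to strings, because Parquet requires string column names.

```python
import io

import numpy as np
import pandas as pd


def output_fn(prediction, accept):
    """Serialize a 2-D prediction array as CSV or Parquet (sketch)."""
    frame = pd.DataFrame(np.asarray(prediction))
    if accept == "text/csv":
        return frame.to_csv(header=False, index=False)
    if accept == "application/x-parquet":
        # Parquet needs string column names
        frame.columns = [str(c) for c in frame.columns]
        buffer = io.BytesIO()
        frame.to_parquet(buffer, index=False)
        return buffer.getvalue()
    raise ValueError(f"Unsupported accept type: {accept}")
```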

In [None]:
output_file_list = list_s3_files(xgb_transformer.output_path)
for k in output_file_list:
    print(k["key"])

In [None]:
with pd.option_context("display.max_colwidth", 1200):
    display(
        pd.DataFrame(
            [k["key"].split("/")[-1] for k in output_file_list], columns=["File name"]
        )
    )

We can read the output files and combine them into a pandas DataFrame.

In [None]:
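# Read every output file, not just the first one, and concatenate the results
frames = []
for k in output_file_list:
    body = s3.Object(**k).get()["Body"]
    csv_string = body.read().decode("utf-8")
    frames.append(
        pd.read_csv(
            StringIO(csv_string),
            header=None,
            names=features_names + ["prediction"],
        )
    )
predictions_df = pd.concat(frames)
predictions_df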