# Clarify post-training bias + MLflow logging

In [10]:
import os
import sys
import json
import subprocess
from time import gmtime, strftime

import boto3
import sagemaker
import pandas as pd

from sagemaker import clarify, image_uris
from sagemaker.session import Session

# ----------------------------------------------------------------------
# 1) ติดตั้ง MLflow + plugin
# ----------------------------------------------------------------------
# subprocess.check_call([sys.executable, "-m", "pip", "install", "mlflow==2.13.2", "sagemaker-mlflow==0.1.0"])

import mlflow as mlf
import sagemaker_mlflow  # activate SageMaker MLflow plugin

# ----------------------------------------------------------------------
# 2) SageMaker session / role / region
# ----------------------------------------------------------------------
sess: Session = sagemaker.Session()
bucket = sess.default_bucket()
role = sagemaker.get_execution_role()
region = boto3.Session().region_name
sm = boto3.client("sagemaker", region_name=region)

print("Region:", region)
print("Bucket:", bucket)
print("Role:", role)



Region: us-east-1
Bucket: sagemaker-us-east-1-423623839320
Role: arn:aws:iam::423623839320:role/service-role/SageMaker-ExecutionRole-20250705T232334


In [11]:
# โหลดตัวแปรจากขั้นก่อนหน้า
%store -r training_job_name
%store -r processed_train_data_s3_uri
%store -r processed_validation_data_s3_uri
%store -r processed_test_data_s3_uri

print("Training job name:         ", training_job_name)
print("Processed train S3:        ", processed_train_data_s3_uri)
print("Processed validation S3:   ", processed_validation_data_s3_uri)
print("Processed test S3:         ", processed_test_data_s3_uri)




Training job name:          sagemaker-xgboost-2025-12-03-08-55-12-951
Processed train S3:         s3://sagemaker-us-east-1-423623839320/sagemaker-scikit-learn-2025-12-03-07-46-41-610/output/retail-train
Processed validation S3:    s3://sagemaker-us-east-1-423623839320/sagemaker-scikit-learn-2025-12-03-07-46-41-610/output/retail-validation
Processed test S3:          s3://sagemaker-us-east-1-423623839320/sagemaker-scikit-learn-2025-12-03-07-46-41-610/output/retail-test


In [3]:
# ----------------------------------------------------------------------
# 4) สร้าง Model จาก training job (สำหรับ Clarify)
# ----------------------------------------------------------------------
inference_image_uri = image_uris.retrieve(
    framework="xgboost",
    region=region,
    version="1.7-1",
    py_version="py3",
    image_scope="inference",
)

print("Inference image URI:", inference_image_uri)

model_name = sess.create_model_from_job(
    training_job_name=training_job_name,
    image_uri=inference_image_uri,
)
print("Clarify will use model:", model_name)



Inference image URI: 683313688378.dkr.ecr.us-east-1.amazonaws.com/sagemaker-xgboost:1.7-1


Clarify will use model: sagemaker-xgboost-2025-12-03-08-55-12-951


# Verify

In [22]:
print("processed_test_data_s3_uri =", processed_test_data_s3_uri)
test_path = processed_test_data_s3_uri.rstrip("/") + "/test.csv"
print("test_path =", test_path)

!aws s3 ls $test_path

!mkdir -p data_check
!aws s3 cp $test_path data_check/test.csv


processed_test_data_s3_uri = s3://sagemaker-us-east-1-423623839320/sagemaker-scikit-learn-2025-12-03-07-46-41-610/output/retail-test
test_path = s3://sagemaker-us-east-1-423623839320/sagemaker-scikit-learn-2025-12-03-07-46-41-610/output/retail-test/test.csv
2025-12-03 07:48:45       3386 test.csv
download: s3://sagemaker-us-east-1-423623839320/sagemaker-scikit-learn-2025-12-03-07-46-41-610/output/retail-test/test.csv to data_check/test.csv


In [23]:
import pandas as pd
df_check = pd.read_csv("data_check/test.csv", nrows=5)
print("Columns in S3 test.csv:", df_check.columns.tolist())


Columns in S3 test.csv: ['record_id', 'date', 'store_id', 'day_of_week', 'is_weekend', 'is_holiday', 'holiday_name', 'max_temp_c', 'rainfall_mm', 'is_hot_day', 'is_rainy_day', 'base_price', 'discount_pct', 'is_promo', 'promo_type', 'final_price', 'units_sold', 'event_time', 'year', 'month', 'day', 'day_of_year', 'day_of_week_index', 'discount_amount', 'is_promo_or_holiday', 'high_demand', 'split_type']


# Solid Validate

In [24]:
# 1) ดึง test.csv (ล่าสุด) มาจาก processed_test_data_s3_uri
test_csv_s3_path = processed_test_data_s3_uri.rstrip("/") + "/test.csv"
!aws s3 cp $test_csv_s3_path ./clarify_raw_test.csv

df_test = pd.read_csv("./clarify_raw_test.csv")

# 2) ถ้าไม่มี high_demand ให้สร้างเลย
if "high_demand" not in df_test.columns:
    p75 = df_test["units_sold"].quantile(0.75)
    df_test["high_demand"] = (df_test["units_sold"] >= p75).astype(int)

# 3) เลือกเฉพาะฟีเจอร์ + high_demand
drop_cols_for_features = [
    "units_sold",
    "record_id",
    "event_time",
    "split_type",
    "date",
    "holiday_name",
    "promo_type",
    "day_of_week",
    "high_demand",
]
feature_cols_for_clarify = [c for c in df_test.columns if c not in drop_cols_for_features]
clarify_df = df_test[feature_cols_for_clarify + ["high_demand"]]

# 4) อัปโหลดไป prefix ใหม่ (ไม่มีของเก่าแน่นอน)
from time import gmtime, strftime
ts = strftime("%Y%m%d-%H%M%S", gmtime())

clarify_prefix = f"clarify/postbias-{training_job_name}-{ts}"
clarify_input_s3_uri = f"s3://{bucket}/{clarify_prefix}/clarify_test.csv"

clarify_df.to_csv("clarify_test.csv", index=False)
!aws s3 cp clarify_test.csv $clarify_input_s3_uri

print("Clarify input S3 =", clarify_input_s3_uri)


download: s3://sagemaker-us-east-1-423623839320/sagemaker-scikit-learn-2025-12-03-07-46-41-610/output/retail-test/test.csv to ./clarify_raw_test.csv
upload: ./clarify_test.csv to s3://sagemaker-us-east-1-423623839320/clarify/postbias-sagemaker-xgboost-2025-12-03-08-55-12-951-20251203-093340/clarify_test.csv
Clarify input S3 = s3://sagemaker-us-east-1-423623839320/clarify/postbias-sagemaker-xgboost-2025-12-03-08-55-12-951-20251203-093340/clarify_test.csv


In [25]:
# ----------------------------------------------------------------------
# 5) Clarify Configs
# ----------------------------------------------------------------------
bias_report_prefix = f"clarify/bias-report-{training_job_name}"
bias_post_report_output_path = f"s3://{bucket}/{bias_report_prefix}"
print("Bias report output S3:", bias_post_report_output_path)

# ❗ จุดสำคัญ: ไม่ส่ง headers ซ้ำ ถ้าไฟล์มี header แล้ว
data_config = clarify.DataConfig(
    s3_data_input_path=clarify_input_s3_uri,
    s3_output_path=bias_post_report_output_path,
    label="high_demand",
    dataset_type="text/csv",
)

# ModelConfig – Clarify ยิงไปที่ endpoint ของ XGBoost model
model_config = clarify.ModelConfig(
    model_name=model_name,
    instance_type="ml.m5.4xlarge",
    instance_count=1,
    content_type="text/csv",
    accept_type="text/csv",
)

# ModelPredictedLabelConfig – output คอลัมน์แรกจาก model
predictions_config = clarify.ModelPredictedLabelConfig(
    label=0,
)

# BiasConfig – facet = is_weekend, positive = high_demand == 1
bias_config = clarify.BiasConfig(
    label_values_or_threshold=[1],
    facet_name="is_weekend",
    facet_values_or_threshold=[1],
)

clarify_processor = clarify.SageMakerClarifyProcessor(
    role=role,
    instance_count=1,
    instance_type="ml.c5.2xlarge",
    sagemaker_session=sess,
)

def parse_s3_uri(s3_uri: str):
    if not s3_uri.startswith("s3://"):
        raise ValueError(f"Not a valid S3 URI: {s3_uri}")
    no_scheme = s3_uri[5:]
    bucket_name, _, key_prefix = no_scheme.partition("/")
    return bucket_name, key_prefix



INFO:sagemaker.image_uris:Ignoring unnecessary instance type: None.


Bias report output S3: s3://sagemaker-us-east-1-423623839320/clarify/bias-report-sagemaker-xgboost-2025-12-03-08-55-12-951


In [26]:
# ----------------------------------------------------------------------
# 6) MLflow config
# ----------------------------------------------------------------------
EXPERIMENT_NAME = "forcasting_demand_product"
MLFLOW_TRACKING_SERVER_ARN = (
    "arn:aws:sagemaker:us-east-1:423623839320:mlflow-tracking-server/tracking-server-demo"
)

mlf.set_tracking_uri(MLFLOW_TRACKING_SERVER_ARN)
mlf.set_experiment(EXPERIMENT_NAME)

suffix = strftime("%d-%H-%M-%S", gmtime())
run_name = f"clarify-post-bias-{suffix}"

post_training_methods = [
    "DPPL", "DI", "DCAcc", "DCR", "RD", "DAR", "DRR", "AD", "TE", "CDDPL", "FT"
]



In [27]:
# ----------------------------------------------------------------------
# 7) Run Clarify + download analysis.json + log เข้า MLflow
# ----------------------------------------------------------------------
with mlf.start_run(
    run_name=run_name,
    description="SageMaker Clarify post-training bias using high_demand label (p75 of units_sold)",
):
    mlf.log_param("clarify_model_name", model_name)
    mlf.log_param("clarify_input_s3_uri", clarify_input_s3_uri)
    mlf.log_param("bias_facet", "is_weekend")
    mlf.log_param("bias_positive_label_value", 1)

    print("Starting Clarify post-training bias job...")
    clarify_processor.run_post_training_bias(
        data_config=data_config,
        data_bias_config=bias_config,
        model_config=model_config,
        model_predicted_label_config=predictions_config,
        methods=post_training_methods,
        wait=True,
        logs=True,
    )
    print("Clarify job finished.")

    # ดึง analysis.json จาก S3
    s3_client = boto3.client("s3", region_name=region)
    bucket_name, key_prefix = parse_s3_uri(bias_report_output_path)
    analysis_s3_key = key_prefix.rstrip("/") + "/analysis.json"

    local_dir = f"./clarify_bias_{suffix}"
    os.makedirs(local_dir, exist_ok=True)
    local_analysis_path = os.path.join(local_dir, "analysis.json")

    print("Downloading analysis.json from S3:", f"s3://{bucket_name}/{analysis_s3_key}")
    s3_client.download_file(bucket_name, analysis_s3_key, local_analysis_path)

    # log report ทั้งก้อนเป็น artifact
    mlf.log_artifact(local_analysis_path)

    # Parse แล้ว log metrics ลง MLflow
    with open(local_analysis_path, "r") as f:
        analysis = json.load(f)

    post_metrics = analysis.get("post_training_bias_metrics", {})
    facets = post_metrics.get("facets", {})

    for facet_name, facet_entries in facets.items():
        for entry in facet_entries:
            facet_value = entry.get("value_or_threshold")
            metrics_list = entry.get("metrics", [])
            for m in metrics_list:
                short_name = m.get("name")     # เช่น DPPL, DI, ...
                value = m.get("value", None)
                if value is None:
                    continue
                metric_name = f"bias_post_{facet_name}_{facet_value}_{short_name}"
                print(metric_name, "=", value)
                mlf.log_metric(metric_name, value)

    print("Logged Clarify post-training bias metrics to MLflow.")

INFO:sagemaker.clarify:Analysis Config: {'dataset_type': 'text/csv', 'label': 'high_demand', 'label_values_or_threshold': [1], 'facet': [{'name_or_index': 'is_weekend', 'value_or_threshold': [1]}], 'methods': {'report': {'name': 'report', 'title': 'Analysis Report'}, 'post_training_bias': {'methods': ['DPPL', 'DI', 'DCAcc', 'DCR', 'RD', 'DAR', 'DRR', 'AD', 'TE', 'CDDPL', 'FT']}}, 'predictor': {'model_name': 'sagemaker-xgboost-2025-12-03-08-55-12-951', 'instance_type': 'ml.m5.4xlarge', 'initial_instance_count': 1, 'accept_type': 'text/csv', 'content_type': 'text/csv', 'label': 0}}
INFO:sagemaker:Creating processing-job with name Clarify-Posttraining-Bias-2025-12-03-09-34-41-539


Starting Clarify post-training bias job...
.................sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /root/.config/sagemaker/config.yaml
We are not in a supported iso region, /bin/sh exiting gracefully with no changes.
INFO:sagemaker-clarify-processing:Starting SageMaker Clarify Processing job
INFO:analyzer.data_loading.data_loader_util:Analysis config path: /opt/ml/processing/input/config/analysis_config.json
INFO:analyzer.data_loading.data_loader_util:Analysis result path: /opt/ml/processing/output
INFO:analyzer.data_loading.data_loader_util:This host is algo-1.
INFO:analyzer.data_loading.data_loader_util:This host is the leader.
INFO:analyzer.data_loading.data_loader_util:Number of hosts in the cluster is 1.
INFO:sagemaker-clarify-processing:Running Python / Pandas based analyzer.
INFO:analyzer.data_loading.data_loader_factory:Dataset type: text/csv uri: /opt/ml/pr