# Partitioned Inference Batch
Load inference data, run TABLE(model!PREDICT(...) OVER (PARTITION BY stats_ntile_group)), save to INFERENCE_LOGS.


In [None]:
from snowflake.snowpark.context import get_active_session
from snowflake.ml.registry import Registry
import time

# Bind to the notebook's active Snowpark session and pin the database/schema
# that unqualified object names resolve against.
session = get_active_session()
for use_statement in ("USE DATABASE BD_AA_DEV", "USE SCHEMA SC_STORAGE_BMX_PS"):
    session.sql(use_statement).collect()

# Registry handle; note the models live in a different schema from the data.
registry = Registry(session=session, database_name="BD_AA_DEV", schema_name="SC_MODELS_BMX")

# Fraction of the inference spine to keep. A falsy value disables the
# sampling notice below; the loading cell only samples when 0 < value < 1.
INFERENCE_SAMPLE_FRACTION = 0.01

print("‚úÖ Connected to Snowflake")
print(f"   Database: {session.get_current_database()}")
print(f"   Schema: {session.get_current_schema()}")
if INFERENCE_SAMPLE_FRACTION:
    sample_percent = INFERENCE_SAMPLE_FRACTION*100
    print(f"   ‚ö†Ô∏è  Sampling: {sample_percent:.1f}% del dataset de inferencia")


## 1. Verify model


In [None]:
# Section banner.
banner_rule = "=" * 80
print("\n" + banner_rule)
print("üîç VERIFYING PARTITIONED MODEL")
print(banner_rule)

# Look up the registered model, then the version addressed as "PRODUCTION".
model_ref = registry.get_model("UNI_BOX_REGRESSION_PARTITIONED")
model_version = model_ref.version("PRODUCTION")
resolved_version_name = model_version.version_name
print(f"‚úÖ UNI_BOX_REGRESSION_PARTITIONED @ {resolved_version_name} (PRODUCTION)")


## 2. Load inference data


In [None]:
print("\n" + "="*80)
print("üè™ LOADING FEATURES")
print("="*80)

FEATURES_TABLE = "BD_AA_DEV.SC_FEATURES_BMX.UNI_BOX_FEATURES"
# Columns that identify one inference row and are used to join spine and features.
JOIN_KEYS = ["customer_id", "brand_pres_ret", "week"]


def _distinct_spine(table_name):
    """Return the distinct (join keys + stats_ntile_group) rows of *table_name*."""
    return session.table(table_name).select(*JOIN_KEYS, "stats_ntile_group").distinct()


def _sample_if_configured(spine):
    """Sample *spine* when INFERENCE_SAMPLE_FRACTION is strictly between 0 and 1.

    The sampled frame is materialized with ``cache_result()``. Snowpark
    DataFrames are lazy, so an un-cached ``sample()`` is re-drawn by every
    action (count/show/join/save); the printed counts and the empty-join check
    would then describe different random subsets than the rows actually used.
    """
    if INFERENCE_SAMPLE_FRACTION and 0 < INFERENCE_SAMPLE_FRACTION < 1:
        return spine.sample(frac=INFERENCE_SAMPLE_FRACTION).cache_result()
    return spine


inference_spine = _distinct_spine("BD_AA_DEV.SC_STORAGE_BMX_PS.INFERENCE_DATASET_CLEANED")

# Resolve the partition column with whatever casing Snowpark reports for it.
partition_col = next((c for c in inference_spine.columns if c.upper() == "STATS_NTILE_GROUP"), None)
if partition_col is None:
    raise ValueError("stats_ntile_group not found in inference dataset.")

inference_spine = _sample_if_configured(inference_spine)

print(f"   Inference keys: {inference_spine.count():,}")
inference_spine.group_by(partition_col).count().sort(partition_col).show()

try:
    features_df = session.table(FEATURES_TABLE)
    inference_df = inference_spine.join(features_df, on=JOIN_KEYS, how="inner")
    n_records = inference_df.count()
except Exception as e:
    # Surface a short hint, then re-raise so the notebook stops here.
    print(f"‚ùå Features table error: {str(e)[:200]}. Run 02_feature_store_setup.py first.")
    raise

if n_records == 0:
    # Fallback: no inference key has features, so use the training keys as the spine instead.
    print("‚ö†Ô∏è  No match between inference keys and feature table (UNI_BOX_FEATURES is built from training). Using TRAIN_DATASET_CLEANED as spine for this run.")
    train_spine = _sample_if_configured(
        _distinct_spine("BD_AA_DEV.SC_STORAGE_BMX_PS.TRAIN_DATASET_CLEANED")
    )
    inference_df = train_spine.join(features_df, on=JOIN_KEYS, how="inner")
    n_records = inference_df.count()

print(f"‚úÖ Loaded {n_records:,} records")
inference_df.select("customer_id", "week", "brand_pres_ret", partition_col, "sum_past_12_weeks", "week_of_year").show(5)


## 3. Prepare inference input (misma lógica que 02/04: excluir metadata, solo numéricas)


In [None]:
print("\n" + "="*80)
print("üîß PREPARING INFERENCE INPUT")
print("="*80)

# Same exclusions as script 04 (train_segment_model): these columns are not features.
excluded_cols = [
    "customer_id", "brand_pres_ret", "week",
    "FEATURE_TIMESTAMP",
    partition_col,
]
excluded_upper = {c.upper() for c in excluded_cols}

# Materialize the joined input so the SQL PREDICT calls in later cells can read it by name.
inference_df.write.mode('overwrite').save_as_table('BD_AA_DEV.SC_STORAGE_BMX_PS.INFERENCE_INPUT_TEMP')

temp_table = session.table('BD_AA_DEV.SC_STORAGE_BMX_PS.INFERENCE_INPUT_TEMP')
actual_cols = temp_table.columns
temp_schema = session.sql("DESCRIBE TABLE BD_AA_DEV.SC_STORAGE_BMX_PS.INFERENCE_INPUT_TEMP").collect()
# Upper-cased column name -> upper-cased type string as reported by DESCRIBE TABLE.
col_type_dict = {row['name'].upper(): str(row['type']).upper() for row in temp_schema}

# Only numeric, non-excluded columns, in table order (same as 04 with dtype.kind in 'iufb').
# NOTE(review): 'b' in dtype.kind would include booleans, but BOOLEAN is not in
# NUMERIC_PREFIXES — confirm that no boolean feature was used at training time.
NUMERIC_PREFIXES = ("FLOAT", "NUMBER", "INTEGER", "BIGINT", "DOUBLE")
feature_cols_actual = [
    c for c in actual_cols
    if c.upper() not in excluded_upper
    and (col_type_dict.get(c.upper()) or "").startswith(NUMERIC_PREFIXES)
]

# Fail fast: with no features the generated PREDICT argument list would end in
# a dangling comma and only fail later with an opaque SQL compilation error.
if not feature_cols_actual:
    raise ValueError(
        "No numeric feature columns found in INFERENCE_INPUT_TEMP; cannot build the PREDICT call."
    )


def _resolve_column(name):
    """Return the entry of actual_cols matching *name* case-insensitively, else *name* itself."""
    wanted = name.upper()
    return next((c for c in actual_cols if c.upper() == wanted), name)


customer_id_col = _resolve_column("CUSTOMER_ID")
brand_col = _resolve_column("BRAND_PRES_RET")
week_col = _resolve_column("WEEK")
partition_col_actual = _resolve_column(partition_col)

print(f"‚úÖ {len(feature_cols_actual)} features, partition: {partition_col_actual}")


## 4. Execute partitioned inference


In [None]:
# Run the partitioned PREDICT table function over INFERENCE_INPUT_TEMP, join the
# predictions back to the input to recover week/brand, then count and preview.
print("\n" + "="*80)
print("üöÄ EXECUTING PARTITIONED INFERENCE")
print("="*80)

start_time = time.time()
# Pass columns as-is; model was registered with sample_input from training (same types).
feature_list = ", ".join(f"i.{col}" for col in feature_cols_actual)

# NOTE(review): model_predictions exposes only customer_id, the partition column
# and the prediction, and it is joined back to the input on (customer_id, partition)
# alone. The input is keyed by (customer_id, brand_pres_ret, week), so if a customer
# has more than one input row, and assuming PREDICT emits one row per input row, this
# join pairs every prediction of that customer with every week/brand of that customer
# (row fan-out, mismatched predictions). Confirm against the model's output columns;
# a correct join needs week/brand (or a row id) carried through the model output.
predictions_sql = f"""
WITH model_predictions AS (
    SELECT 
        p.customer_id,
        p.{partition_col_actual},
        p.predicted_uni_box_week
    FROM BD_AA_DEV.SC_STORAGE_BMX_PS.INFERENCE_INPUT_TEMP i,
        TABLE(
            BD_AA_DEV.SC_MODELS_BMX.UNI_BOX_REGRESSION_PARTITIONED!PREDICT(
                i.{customer_id_col},
                i.{partition_col_actual},
                {feature_list}
            ) OVER (PARTITION BY i.{partition_col_actual})
        ) p
)
SELECT 
    mp.customer_id,
    mp.{partition_col_actual},
    i.{week_col},
    i.{brand_col},
    ROUND(mp.predicted_uni_box_week, 2) AS predicted_uni_box_week
FROM model_predictions mp
JOIN BD_AA_DEV.SC_STORAGE_BMX_PS.INFERENCE_INPUT_TEMP i 
    ON mp.customer_id = i.{customer_id_col}
    AND mp.{partition_col_actual} = i.{partition_col_actual}
ORDER BY mp.{partition_col_actual}, mp.customer_id
"""

# session.sql() is lazy: count() below triggers the query, and show() triggers it again.
predictions_df = session.sql(predictions_sql)
prediction_count = predictions_df.count()
# Elapsed time covers building the SQL text and the count() execution only.
inference_time = time.time() - start_time

print(f"‚úÖ Done in {inference_time:.2f}s ‚Äî {prediction_count:,} predictions")
predictions_df.show(10)


## 5. Statistics


In [None]:
# Summary statistics (count, distinct customers, min/max/mean/stddev, quartiles)
# computed directly on the PREDICT output, without joining back to the input.
# This issues the same PREDICT call as the previous cell, so the model is invoked again.
stats_sql = f"""
WITH model_predictions AS (
    SELECT 
        p.customer_id,
        p.predicted_uni_box_week
    FROM BD_AA_DEV.SC_STORAGE_BMX_PS.INFERENCE_INPUT_TEMP i,
        TABLE(
            BD_AA_DEV.SC_MODELS_BMX.UNI_BOX_REGRESSION_PARTITIONED!PREDICT(
                i.{customer_id_col},
                i.{partition_col_actual},
                {feature_list}
            ) OVER (PARTITION BY i.{partition_col_actual})
        ) p
)
SELECT
    COUNT(*) AS TOTAL_PREDICTIONS,
    COUNT(DISTINCT customer_id) AS UNIQUE_CUSTOMERS,
    ROUND(MIN(predicted_uni_box_week), 2) AS MIN_PREDICTION,
    ROUND(MAX(predicted_uni_box_week), 2) AS MAX_PREDICTION,
    ROUND(AVG(predicted_uni_box_week), 2) AS AVG_PREDICTION,
    ROUND(STDDEV(predicted_uni_box_week), 2) AS STDDEV_PREDICTION,
    ROUND(PERCENTILE_CONT(0.25) WITHIN GROUP (ORDER BY predicted_uni_box_week), 2) AS Q1,
    ROUND(PERCENTILE_CONT(0.50) WITHIN GROUP (ORDER BY predicted_uni_box_week), 2) AS MEDIAN,
    ROUND(PERCENTILE_CONT(0.75) WITHIN GROUP (ORDER BY predicted_uni_box_week), 2) AS Q3
FROM model_predictions
"""

# Execute and print the single-row result.
session.sql(stats_sql).show()


## 6. Save to INFERENCE_LOGS


In [None]:
# Create the log table on first use; an existing table is left untouched.
session.sql("""
    CREATE TABLE IF NOT EXISTS BD_AA_DEV.SC_STORAGE_BMX_PS.INFERENCE_LOGS (
        customer_id VARCHAR,
        week VARCHAR,
        brand_pres_ret VARCHAR,
        stats_ntile_group VARCHAR,
        predicted_uni_box_week FLOAT,
        inference_timestamp TIMESTAMP_NTZ DEFAULT CURRENT_TIMESTAMP(),
        model_version VARCHAR
    )
""").collect()

# NOTE(review): predictions are joined back to the input on (customer_id, partition)
# only, while the input is keyed by (customer_id, brand_pres_ret, week). If a customer
# has several input rows, and assuming PREDICT emits one row per input row, this logs
# every prediction of that customer against every week/brand of that customer —
# confirm against the model's output columns.
# NOTE(review): the logged model_version is the name resolved for "PRODUCTION", but
# the PREDICT call names the model without a version qualifier — confirm both refer
# to the same model version.
insert_sql = f"""
INSERT INTO BD_AA_DEV.SC_STORAGE_BMX_PS.INFERENCE_LOGS
    (customer_id, week, brand_pres_ret, stats_ntile_group, predicted_uni_box_week, model_version)
WITH model_predictions AS (
    SELECT 
        p.customer_id,
        p.{partition_col_actual},
        p.predicted_uni_box_week
    FROM BD_AA_DEV.SC_STORAGE_BMX_PS.INFERENCE_INPUT_TEMP i,
        TABLE(
            BD_AA_DEV.SC_MODELS_BMX.UNI_BOX_REGRESSION_PARTITIONED!PREDICT(
                i.{customer_id_col},
                i.{partition_col_actual},
                {feature_list}
            ) OVER (PARTITION BY i.{partition_col_actual})
        ) p
)
SELECT 
    mp.customer_id,
    i.{week_col},
    i.{brand_col},
    mp.{partition_col_actual},
    mp.predicted_uni_box_week,
    '{model_version.version_name}'
FROM model_predictions mp
JOIN BD_AA_DEV.SC_STORAGE_BMX_PS.INFERENCE_INPUT_TEMP i 
    ON mp.customer_id = i.{customer_id_col}
    AND mp.{partition_col_actual} = i.{partition_col_actual}
"""

# An INSERT statement returns a single row whose only column is the number of
# rows inserted; report that as "saved" instead of the cumulative table size,
# which also counts rows written by earlier runs.
insert_result = session.sql(insert_sql).collect()
inserted_count = insert_result[0][0]
log_count = session.sql("SELECT COUNT(*) as CNT FROM BD_AA_DEV.SC_STORAGE_BMX_PS.INFERENCE_LOGS").collect()[0]['CNT']
print(f"‚úÖ Saved {inserted_count:,} to INFERENCE_LOGS")
print(f"   Total rows in INFERENCE_LOGS: {log_count:,}")
session.sql("SELECT * FROM BD_AA_DEV.SC_STORAGE_BMX_PS.INFERENCE_LOGS ORDER BY inference_timestamp DESC LIMIT 5").show()


## 7. Summary


In [None]:
# Final recap: prediction count, timing of the inference cell, and model version name.
summary_line = f"\nüéâ Done ‚Äî {prediction_count:,} predictions, {inference_time:.2f}s, model {model_version.version_name}"
print(summary_line)
