# Migration: Partitioned Inference Batch

## Overview
This notebook runs batch inference against the registered partitioned model, calling its `PREDICT` method through Snowflake's partitioned table-function syntax.

## What We'll Do:
1. Verify the partitioned model in the registry
2. Load inference data from the cleaned table
3. Prepare features for inference
4. Execute partitioned inference using TABLE(...) OVER (PARTITION BY ...) syntax
5. Generate prediction statistics
6. Save predictions to inference logs


In [None]:
from snowflake.snowpark.context import get_active_session
from snowflake.ml.registry import Registry
from snowflake.snowpark import functions as F
import pandas as pd
import time

# Reuse the session that the hosting notebook environment already opened.
session = get_active_session()

# Pin warehouse, database and schema so later unqualified names resolve
# predictably. The statements run in the order listed.
for context_statement in (
    "USE WAREHOUSE ARCA_DEMO_WH",
    "USE DATABASE BD_AA_DEV",
    "USE SCHEMA SC_STORAGE_BMX_PS",
):
    session.sql(context_statement).collect()

# The model registry lives in its own schema of the same database.
registry = Registry(session=session, database_name="BD_AA_DEV", schema_name="MODEL_REGISTRY")

print("‚úÖ Connected to Snowflake")
print("   Database:", session.get_current_database())
print("   Schema:", session.get_current_schema())


## 1. Verify Partitioned Model


In [None]:
print()
print("=" * 80, "üîç VERIFYING PARTITIONED MODEL", "=" * 80, sep="\n")

# Resolve the model through the registry and pick the version behind the
# PRODUCTION alias; `model_version` is reused by later cells.
model_ref = registry.get_model("UNI_BOX_REGRESSION_PARTITIONED")
model_version = model_ref.version("PRODUCTION")

print("‚úÖ Model: UNI_BOX_REGRESSION_PARTITIONED")
print("   Version: " + str(model_version.version_name))
print("   Alias: PRODUCTION")

# List the SQL-callable functions exposed by the model.
functions = session.sql("""
    SHOW FUNCTIONS IN MODEL BD_AA_DEV.MODEL_REGISTRY.UNI_BOX_REGRESSION_PARTITIONED
""").collect()

print("\nüìã Available functions:")
for function_row in functions:
    function_name = function_row['name']
    print(f"   - {function_name}")


## 2. Load Inference Data


In [None]:
print()
print("=" * 80, "üìä LOADING INFERENCE DATA", "=" * 80, sep="\n")

# Lazy reference to the cleaned inference table; nothing is fetched yet.
inference_df = session.table("BD_AA_DEV.SC_STORAGE_BMX_PS.INFERENCE_DATASET_CLEANED")

print("\n‚úÖ Inference data loaded")

# Each count below issues its own query against the table.
total_records = inference_df.count()
print(f"   Total records: {total_records:,}")

unique_customers = inference_df.select('customer_id').distinct().count()
print(f"   Unique customers: {unique_customers:,}")

# Preview a handful of identifying and feature columns.
print("\nüìã Sample inference data:")
sample_columns = (
    'customer_id',
    'week',
    'brand_pres_ret',
    'sum_past_12_weeks',
    'week_of_year',
)
inference_df.select(*sample_columns).show(5)


## 3. Prepare Inference Input


In [None]:
print("\n" + "="*80)
print("üîß PREPARING INFERENCE INPUT")
print("="*80)

# Columns that are present in the inference table but are not model features.
excluded_cols = [
    'customer_id', 'brand_pres_ret', 'week',
    'group', 'stats_group', 'percentile_group', 'stats_ntile_group'
]

# Snowpark reports unquoted identifiers in upper case and quoted identifiers
# wrapped in double quotes. A plain `col not in excluded_cols` test against the
# lower-case names above would therefore exclude nothing, so compare on a
# normalized (unquoted, upper-case) form instead.
excluded_normalized = {name.upper() for name in excluded_cols}

# Feature columns keep the table's column order and the exact spelling that
# Snowpark reports, so they can be interpolated into SQL unchanged.
inference_columns = inference_df.columns
feature_cols = [
    col for col in inference_columns
    if col.strip('"').upper() not in excluded_normalized
]

# Fail early with a clear message rather than generating invalid SQL later.
if not feature_cols:
    raise ValueError("No feature columns remain after applying excluded_cols")

print(f"\nüìã Features for inference ({len(feature_cols)}):")
for col in sorted(feature_cols):
    print(f"   - {col}")

# The model is invoked through the partitioned syntax, which needs a
# PARTITION BY column; a constant puts every row in the same partition.
inference_input = inference_df.with_column("dummy_partition", F.lit("ALL"))

# Persist the input so the SQL in later cells can read it by name.
# NOTE(review): no table_type is passed, so despite the "_TEMP" suffix this is
# presumably a regular table that outlives the session - confirm that is intended.
inference_input.write.mode('overwrite').save_as_table(
    'BD_AA_DEV.SC_STORAGE_BMX_PS.INFERENCE_INPUT_TEMP'
)

print(f"\n‚úÖ Inference input prepared and saved to temporary table")
print(f"   Records: {inference_input.count():,}")


## 4. Execute Partitioned Inference


In [None]:
print("\n" + "="*80)
print("üöÄ EXECUTING PARTITIONED INFERENCE")
print("="*80)

print("\nüìù Running partitioned inference...")
print("   Syntax: TABLE(model!PREDICT(...) OVER (PARTITION BY dummy_partition))")
print("   This enables partitioned inference even with a single model\n")

# perf_counter is monotonic, so the elapsed time cannot be skewed by clock changes.
start_time = time.perf_counter()

# Positional argument list for PREDICT, in the order of `feature_cols`.
# NOTE(review): PREDICT arguments are positional - confirm this order matches
# the signature the model was registered with.
feature_list = ", ".join(f"i.{col}" for col in feature_cols)

# The CTE runs the model over the single constant partition; the outer query
# joins the predictions back to the input to recover week and brand_pres_ret.
# NOTE(review): the join is on customer_id only. If a customer can have more
# than one row in INFERENCE_INPUT_TEMP this multiplies rows - confirm
# customer_id is unique in the input.
predictions_sql = f"""
WITH model_predictions AS (
    SELECT 
        p.customer_id,
        p.predicted_uni_box_week
    FROM BD_AA_DEV.SC_STORAGE_BMX_PS.INFERENCE_INPUT_TEMP i,
        TABLE(
            BD_AA_DEV.MODEL_REGISTRY.UNI_BOX_REGRESSION_PARTITIONED!PREDICT(
                i.customer_id,
                {feature_list}
            ) OVER (PARTITION BY i.dummy_partition)
        ) p
)
SELECT 
    mp.customer_id,
    i.week,
    i.brand_pres_ret,
    ROUND(mp.predicted_uni_box_week, 2) AS predicted_uni_box_week
FROM model_predictions mp
JOIN BD_AA_DEV.SC_STORAGE_BMX_PS.INFERENCE_INPUT_TEMP i 
    ON mp.customer_id = i.customer_id
ORDER BY mp.customer_id
"""

# Snowpark DataFrames are lazy: calling count() and then show() on the raw
# query would run the model twice. cache_result() materializes the predictions
# once, so the timing covers the real inference and later actions are cheap.
predictions_df = session.sql(predictions_sql).cache_result()
prediction_count = predictions_df.count()
inference_time = time.perf_counter() - start_time

print(f"‚úÖ Inference complete!")
print(f"   ‚è±Ô∏è  Time: {inference_time:.2f} seconds")
print(f"   üìä Predictions: {prediction_count:,}")

# Row order is not guaranteed when reading the cached result back, so sort
# explicitly to keep the sample ordered by customer as before.
print("\nüìä Sample Predictions:")
predictions_df.sort("customer_id").show(10)


## 5. Analyze Prediction Statistics


In [None]:
print("\n" + "="*80)
print("üìà PREDICTION STATISTICS")
print("="*80)

# Summary statistics over the raw model output. The CTE invokes PREDICT over
# the single constant partition of INFERENCE_INPUT_TEMP, and the outer query
# aggregates it into one row: counts, min/max/avg/stddev and the quartiles.
# Unlike the sample-prediction query, there is no join back to the input table
# here, so the counts describe the rows returned by the model itself.
#
# NOTE(review): the PREDICT argument list below is hard-coded, whereas the
# other inference queries in this notebook interpolate `feature_list`. If the
# two ever differ, these statistics describe different predictions from the
# ones that are displayed and logged - confirm both match the model signature.
# NOTE(review): this query runs the model again rather than reusing earlier
# results.
stats_sql = """
WITH model_predictions AS (
    SELECT 
        p.customer_id,
        p.predicted_uni_box_week
    FROM BD_AA_DEV.SC_STORAGE_BMX_PS.INFERENCE_INPUT_TEMP i,
        TABLE(
            BD_AA_DEV.MODEL_REGISTRY.UNI_BOX_REGRESSION_PARTITIONED!PREDICT(
                i.customer_id,
                i.sum_past_12_weeks,
                i.avg_past_12_weeks,
                i.max_past_24_weeks,
                i.sum_past_24_weeks,
                i.week_of_year,
                i.avg_avg_daily_all_hours,
                i.sum_p4w,
                i.avg_past_24_weeks,
                i.pharm_super_conv,
                i.wines_liquor,
                i.groceries,
                i.max_prev2,
                i.avg_prev2,
                i.max_prev3,
                i.avg_prev3,
                i.w_m1_total,
                i.w_m2_total,
                i.w_m3_total,
                i.w_m4_total,
                i.spec_foods,
                i.prod_key,
                i.num_coolers,
                i.num_doors,
                i.max_past_4_weeks,
                i.sum_past_4_weeks,
                i.avg_past_4_weeks,
                i.max_past_12_weeks
            ) OVER (PARTITION BY i.dummy_partition)
        ) p
)
SELECT
    COUNT(*) AS TOTAL_PREDICTIONS,
    COUNT(DISTINCT customer_id) AS UNIQUE_CUSTOMERS,
    ROUND(MIN(predicted_uni_box_week), 2) AS MIN_PREDICTION,
    ROUND(MAX(predicted_uni_box_week), 2) AS MAX_PREDICTION,
    ROUND(AVG(predicted_uni_box_week), 2) AS AVG_PREDICTION,
    ROUND(STDDEV(predicted_uni_box_week), 2) AS STDDEV_PREDICTION,
    ROUND(PERCENTILE_CONT(0.25) WITHIN GROUP (ORDER BY predicted_uni_box_week), 2) AS Q1,
    ROUND(PERCENTILE_CONT(0.50) WITHIN GROUP (ORDER BY predicted_uni_box_week), 2) AS MEDIAN,
    ROUND(PERCENTILE_CONT(0.75) WITHIN GROUP (ORDER BY predicted_uni_box_week), 2) AS Q3
FROM model_predictions
"""

# Execute the aggregate query and print its single result row.
print("\nüìä Overall Statistics:")
session.sql(stats_sql).show()


## 6. Save Predictions to Inference Logs


In [None]:
print("\n" + "="*80)
print("üíæ SAVING PREDICTIONS TO INFERENCE LOGS")
print("="*80)

# Create the log table on first use. inference_timestamp is filled in by its
# column default, so the INSERT below does not supply it.
session.sql("""
    CREATE TABLE IF NOT EXISTS BD_AA_DEV.SC_STORAGE_BMX_PS.INFERENCE_LOGS (
        customer_id VARCHAR,
        week VARCHAR,
        brand_pres_ret VARCHAR,
        predicted_uni_box_week FLOAT,
        inference_timestamp TIMESTAMP_NTZ DEFAULT CURRENT_TIMESTAMP(),
        model_version VARCHAR
    )
""").collect()

# Clear previous logs (optional - for demo purposes)
# session.sql("DELETE FROM BD_AA_DEV.SC_STORAGE_BMX_PS.INFERENCE_LOGS").collect()

# Run the model and append its predictions, tagged with the model version.
# NOTE(review): the join back to INFERENCE_INPUT_TEMP is on customer_id only.
# If a customer can have more than one input row this multiplies rows -
# confirm customer_id is unique in the input.
insert_sql = f"""
INSERT INTO BD_AA_DEV.SC_STORAGE_BMX_PS.INFERENCE_LOGS
    (customer_id, week, brand_pres_ret, predicted_uni_box_week, model_version)
WITH model_predictions AS (
    SELECT 
        p.customer_id,
        p.predicted_uni_box_week
    FROM BD_AA_DEV.SC_STORAGE_BMX_PS.INFERENCE_INPUT_TEMP i,
        TABLE(
            BD_AA_DEV.MODEL_REGISTRY.UNI_BOX_REGRESSION_PARTITIONED!PREDICT(
                i.customer_id,
                {feature_list}
            ) OVER (PARTITION BY i.dummy_partition)
        ) p
)
SELECT 
    mp.customer_id,
    i.week,
    i.brand_pres_ret,
    mp.predicted_uni_box_week,
    '{model_version.version_name}'
FROM model_predictions mp
JOIN BD_AA_DEV.SC_STORAGE_BMX_PS.INFERENCE_INPUT_TEMP i 
    ON mp.customer_id = i.customer_id
"""

# The log table is append-only across runs, so COUNT(*) over the whole table
# overstates what this run saved. The INSERT statement itself returns a single
# row whose only column is the number of rows inserted - report that instead.
insert_result = session.sql(insert_sql).collect()
inserted_count = insert_result[0][0]

# Total rows now in the table, including rows from earlier runs.
log_count = session.sql("SELECT COUNT(*) as CNT FROM BD_AA_DEV.SC_STORAGE_BMX_PS.INFERENCE_LOGS").collect()[0]['CNT']
print(f"‚úÖ Saved {inserted_count:,} predictions to INFERENCE_LOGS")
print(f"   Total rows in INFERENCE_LOGS: {log_count:,}")

print("\nüìã Sample from logs:")
session.sql("""
    SELECT * FROM BD_AA_DEV.SC_STORAGE_BMX_PS.INFERENCE_LOGS 
    ORDER BY inference_timestamp DESC
    LIMIT 5
""").show()


## 7. Summary


In [None]:
# Closing report: restates the prediction count, timing and model version
# gathered by the earlier cells, followed by fixed explanatory text.
rule = "=" * 80

print(f"\n{rule}")
print("üéâ PARTITIONED INFERENCE BATCH COMPLETE!")
print(rule)

summary_report = f"""
üìä Summary:
   ‚úÖ Predictions generated: {prediction_count:,}
   ‚úÖ Inference time: {inference_time:.2f} seconds
   ‚úÖ Logs saved to: INFERENCE_LOGS
   ‚úÖ Model version: {model_version.version_name}

üí° Key Advantages of Partitioned Model:
   ‚úÖ Single model with partitioned API
   ‚úÖ Consistent inference syntax
   ‚úÖ Ready for future multi-model scenarios
   ‚úÖ SQL-native inference (no Python required)
   ‚úÖ Parallel execution handled by Snowflake

üéØ Business Impact:
   ‚Ä¢ Batch predictions for all inference records
   ‚Ä¢ Predictions stored for monitoring and analysis
   ‚Ä¢ Ready for production deployment
   ‚Ä¢ Scalable to multiple models if needed

üöÄ Next Steps:
   ‚Üí Review predictions in INFERENCE_LOGS table
   ‚Üí Set up monitoring and observability
   ‚Üí Schedule regular batch inference runs
"""
print(summary_report)

print(rule)
