In [None]:
# Databricks notebook source
# =============================================================================
# 🚀 PRODUCTION BATCH INFERENCE - CONFIG DRIVEN (FIXED)
# =============================================================================
# Purpose: Run batch inference using production serving endpoint
# Reads model, data, and batch settings from pipeline_config.yml instead of hardcoding them.
# Prerequisites: Run 07_create_serving_endpoint.py first
# =============================================================================

# COMMAND ----------
%pip install xgboost

# COMMAND ----------
# Restart Python after the %pip install above, presumably so the newly installed
# package is importable in the cells that follow.
# NOTE(review): nothing later in this notebook imports xgboost — predictions come
# from the serving endpoint — so confirm whether this install/restart is still needed.
dbutils.library.restartPython()

# COMMAND ----------
import math
import os
import sys
import traceback
from datetime import datetime

import numpy as np
import pandas as pd
import yaml
from databricks.sdk import WorkspaceClient
from pyspark.sql import SparkSession

print("=" * 80)
print("üöÄ PRODUCTION BATCH INFERENCE (CONFIG-DRIVEN)")
print("=" * 80)

# =============================================================================
# ✅ LOAD PIPELINE CONFIGURATION (Dynamic Path)
# =============================================================================
# Reads pipeline_config.yml and derives the module-level settings used by the
# later steps: MODEL_TYPE, ENDPOINT_NAME, input/output tables, FEATURE_COLS,
# LABEL_COL and BATCH_SIZE. Any failure stops the notebook through
# dbutils.notebook.exit with a CONFIG_NOT_FOUND / CONFIG_ERROR code.
print("\nüìã Loading pipeline configuration from pipeline_config.yml...")

try:
    # Resolve the directory this code lives in; `__file__` may be undefined
    # (e.g. when run as a notebook), in which case fall back to the cwd.
    try:
        current_dir = os.path.dirname(os.path.abspath(__file__))
    except NameError:
        current_dir = os.getcwd()
    project_root = os.path.abspath(os.path.join(current_dir, ".."))

    # Search order: next to this file first, then <project_root>/dev_env.
    candidate_paths = [
        os.path.join(current_dir, "pipeline_config.yml"),
        os.path.join(project_root, "dev_env", "pipeline_config.yml"),
    ]
    config_path = next((p for p in candidate_paths if os.path.exists(p)), None)
    if config_path is None:
        # Report every location tried, not just the last one.
        raise FileNotFoundError(
            "pipeline_config.yml not found at any of: " + ", ".join(candidate_paths)
        )

    with open(config_path, "r", encoding="utf-8") as f:
        pipeline_cfg = yaml.safe_load(f)

    # safe_load returns None for an empty file; fail with a clear message
    # instead of a "'NoneType' object is not subscriptable" error below.
    if not isinstance(pipeline_cfg, dict):
        raise ValueError(f"{config_path} must contain a YAML mapping at the top level")

    print(f"‚úÖ Loaded pipeline_config.yml from: {config_path}")

    # -----------------------------
    # Extract Config Parameters
    # -----------------------------
    MODEL_TYPE = pipeline_cfg["model"]["type"]
    BASE_NAME = pipeline_cfg["model"]["base_name"]

    # Must match the name the serving endpoint was created with (presumably by
    # 07_create_serving_endpoint.py — keep the two in sync).
    ENDPOINT_NAME = f"{BASE_NAME.replace('_', '-')}-{MODEL_TYPE}-prod"

    # Data Configuration: input_table must be a three-part catalog.schema.table name.
    data_cfg = pipeline_cfg["data"]
    table_parts = data_cfg["input_table"].split(".")
    if len(table_parts) != 3:
        raise ValueError(
            "data.input_table must be a fully qualified 'catalog.schema.table' "
            f"name, got {data_cfg['input_table']!r}"
        )
    DATA_CATALOG, DATA_SCHEMA, INPUT_TABLE = table_parts
    FEATURE_COLS = data_cfg["features"]
    LABEL_COL = data_cfg["label"]

    # Output Configuration
    OUTPUT_TABLE = f"{DATA_CATALOG}.{DATA_SCHEMA}.prod_inference_{MODEL_TYPE}"

    # Batch Configuration. `inference:` may be absent or null in the YAML, and
    # batch_size must be a positive int because Step 3 divides by it.
    inference_cfg = pipeline_cfg.get("inference") or {}
    BATCH_SIZE = inference_cfg.get("batch_size", 100)
    if isinstance(BATCH_SIZE, bool) or not isinstance(BATCH_SIZE, int) or BATCH_SIZE < 1:
        raise ValueError(
            f"inference.batch_size must be a positive integer, got {BATCH_SIZE!r}"
        )

    # Print summary
    print(f"‚úÖ Configuration loaded successfully!")
    print(f"\nüìä Configuration Details:")
    print(f"   Model Type: {MODEL_TYPE.upper()}")
    print(f"   Endpoint Name: {ENDPOINT_NAME}")
    print(f"   Input Table: {DATA_CATALOG}.{DATA_SCHEMA}.{INPUT_TABLE}")
    print(f"   Output Table: {OUTPUT_TABLE}")
    print(f"   Features: {FEATURE_COLS}")
    print(f"   Label: {LABEL_COL}")
    print(f"   Batch Size: {BATCH_SIZE}")

except FileNotFoundError as e:
    print(f"‚ùå ERROR: {e}")
    print("üí° Please ensure pipeline_config.yml exists in the same or dev_env directory.")
    dbutils.notebook.exit("CONFIG_NOT_FOUND")

except Exception as e:
    print(f"‚ùå ERROR loading configuration: {e}")
    traceback.print_exc()
    dbutils.notebook.exit(f"CONFIG_ERROR: {e}")


print("=" * 80)

# =============================================================================
# ✅ CLIENTS INITIALIZATION
# =============================================================================
# Create the WorkspaceClient (used below for serving-endpoint calls) and the
# Spark session (used for Delta reads/writes); stop the notebook with an
# INIT_FAILED code if either cannot be created.
try:
    w = WorkspaceClient()
    spark_app_name = "{}_Inference".format(MODEL_TYPE.upper())
    spark = SparkSession.builder.appName(spark_app_name).getOrCreate()
except Exception as init_error:
    failure_reason = f"{init_error}"
    print("‚ùå Initialization failed: " + failure_reason)
    dbutils.notebook.exit("INIT_FAILED: " + failure_reason)
else:
    print("\n‚úÖ Workspace & Spark initialized")

# =============================================================================
# ✅ STEP 1: VERIFY ENDPOINT STATUS
# =============================================================================
# Look up the serving endpoint and report its readiness. A missing endpoint or
# API error stops the notebook (ENDPOINT_ERROR); a not-ready or unknown state
# only prints a warning and inference proceeds.
print(f"\n{'='*80}")
print("üìã STEP 1: Verifying Endpoint Readiness")
print(f"{'='*80}")
print(f"üîç Endpoint: {ENDPOINT_NAME}")

try:
    endpoint = w.serving_endpoints.get(name=ENDPOINT_NAME)

    # The SDK reports readiness as an enum (EndpointStateReady.READY / NOT_READY).
    # Enum members are always truthy, so `if state.ready:` would treat NOT_READY
    # as ready — compare the enum's underlying value instead.
    endpoint_state = getattr(endpoint, "state", None)
    ready_flag = getattr(endpoint_state, "ready", None)
    ready_value = getattr(ready_flag, "value", ready_flag)

    if ready_flag is None:
        print(f"‚ö†Ô∏è Cannot determine endpoint state ‚Üí Proceeding cautiously")
    elif ready_value == "READY" or ready_value is True:
        print(f"‚úÖ Endpoint '{ENDPOINT_NAME}' is READY")
        print(f"   State: {endpoint_state}")
    else:
        print(f"‚ö†Ô∏è Endpoint NOT fully ready")
        print(f"   State: {endpoint_state}")
        print(f"   Proceeding cautiously...")

except Exception as e:
    error_msg = f"‚ùå Endpoint error: {e}"
    print(error_msg)
    print(f"\nüí° Troubleshooting:")
    print(f"   1. Verify endpoint exists: {ENDPOINT_NAME}")
    print(f"   2. Run 07_create_serving_endpoint.py first")
    print(f"   3. Check endpoint status in Databricks UI")
    traceback.print_exc()
    dbutils.notebook.exit(f"ENDPOINT_ERROR: {e}")

# =============================================================================
# ✅ STEP 2: LOAD INPUT DATA
# =============================================================================
# Read the configured Delta table into pandas (`df`), fail fast when any
# configured feature column is absent, and keep the label column as `y_true`
# when present (None otherwise) so Step 6 can decide whether to score.
print(f"\n{'='*80}")
print("üìã STEP 2: Loading Input Data")
print(f"{'='*80}")
source_table = f"{DATA_CATALOG}.{DATA_SCHEMA}.{INPUT_TABLE}"
print(f"üîç Table: {source_table}")

try:
    df_spark = spark.read.format("delta").table(source_table)
    df = df_spark.toPandas()

    print("‚úÖ Data loaded successfully")
    print(f"   Total rows: {len(df):,}")
    print(f"   Columns: {list(df.columns)}")

    # Every configured feature must exist in the table.
    missing_features = [feature for feature in FEATURE_COLS if feature not in df.columns]
    if missing_features:
        raise ValueError(f"Missing feature columns: {missing_features}")

    # The label is optional: without it, predictions are still produced but
    # no metrics are computed.
    if LABEL_COL in df.columns:
        y_true = df[LABEL_COL]
        print(f"   Label column '{LABEL_COL}' found ‚Üí Will calculate metrics")
    else:
        y_true = None
        print(f"   Label column '{LABEL_COL}' not found ‚Üí No metrics")

    print(f"   Feature columns: {FEATURE_COLS}")

except Exception as load_error:
    print(f"‚ùå Data loading failed: {load_error}")
    print("\nüí° Troubleshooting:")
    print(f"   1. Verify table exists: {source_table}")
    print("   2. Check feature columns in pipeline_config.yml")
    traceback.print_exc()
    dbutils.notebook.exit(f"DATA_LOAD_FAILED: {load_error}")

# =============================================================================
# ✅ STEP 3: BATCH INFERENCE
# =============================================================================
# Send the feature columns to the serving endpoint in chunks of BATCH_SIZE rows
# and collect the predictions in input order in `all_predictions`. Any failure
# (including a short or missing response) stops the notebook (INFERENCE_FAILED).
print(f"\n{'='*80}")
print("üìã STEP 3: Running Batch Inference")
print(f"{'='*80}")

all_predictions = []
num_batches = (len(df) + BATCH_SIZE - 1) // BATCH_SIZE  # ceil(len(df) / BATCH_SIZE)

print(f"üîÑ Processing {len(df)} rows in {num_batches} batch(es) of size {BATCH_SIZE}")

try:
    for batch_idx, start in enumerate(range(0, len(df), BATCH_SIZE)):
        batch_df = df.iloc[start:start + BATCH_SIZE][FEATURE_COLS]

        # Call serving endpoint
        response = w.serving_endpoints.query(
            name=ENDPOINT_NAME,
            dataframe_records=batch_df.to_dict("records"),
        )

        # Step 4 attaches predictions to `df` positionally, so each batch must
        # yield exactly one prediction per row sent. Catch a mismatch here,
        # inside the error handler, rather than letting the later column
        # assignment crash the notebook.
        predictions = response.predictions
        if predictions is None or len(predictions) != len(batch_df):
            received = "none" if predictions is None else len(predictions)
            raise ValueError(
                f"Batch {batch_idx+1}/{num_batches}: sent {len(batch_df)} rows "
                f"but received {received} predictions"
            )
        all_predictions.extend(predictions)

        print(f"   ‚úì Batch {batch_idx+1}/{num_batches} complete ({len(predictions)} predictions)")

    print(f"\n‚úÖ Inference complete!")
    print(f"   Total predictions: {len(all_predictions):,}")
    print(f"   Sample predictions: {all_predictions[:5]}")

except Exception as e:
    error_msg = f"‚ùå Inference failed: {e}"
    print(error_msg)
    traceback.print_exc()
    dbutils.notebook.exit(f"INFERENCE_FAILED: {e}")

# =============================================================================
# ✅ STEP 4: PREPARE RESULTS
# =============================================================================
# Attach predictions and run metadata columns to `df`, which Step 5 writes out.
print(f"\n{'='*80}")
print("üìã STEP 4: Preparing Results")
print(f"{'='*80}")

# Read the clock once so prediction_timestamp, prediction_date and batch_id
# always describe the same instant; separate datetime.now() calls could
# straddle a second (or midnight) boundary and disagree.
run_ts = datetime.now()

# Predictions are assigned positionally; pandas raises if the number of
# predictions differs from the number of rows in df.
df["predicted_price"] = all_predictions
df["prediction_timestamp"] = run_ts
df["endpoint_name"] = ENDPOINT_NAME
df["model_type"] = MODEL_TYPE.upper()
df["inference_method"] = "serving_endpoint"
df["prediction_date"] = run_ts.strftime('%Y-%m-%d')
df["batch_id"] = run_ts.strftime('%Y-%m-%d_%H%M%S')

print(f"‚úÖ Results prepared")
print(f"   Rows with predictions: {len(df):,}")

# =============================================================================
# ✅ STEP 5: SAVE RESULTS TO DELTA (AVOID DUPLICATES)
# =============================================================================
# Append this run's results to OUTPUT_TABLE unless the most recently saved
# batch already holds the same predictions (e.g. a re-run on unchanged input).
# A failed write is reported but deliberately non-fatal so metrics still run.
print(f"\n{'='*80}")
print("üìã STEP 5: Saving Results to Delta")
print(f"{'='*80}")
print(f"üîç Output table: {OUTPUT_TABLE}")

# --- Duplicate check ---------------------------------------------------------
# Compare against the latest saved batch only. batch_id is formatted as
# '%Y-%m-%d_%H%M%S', so its string maximum is the newest batch. Reading just
# that batch avoids pulling the whole table history to the driver. Comparing
# the full sorted prediction set avoids depending on the row order of a Delta
# read. An error here only disables the check; it is no longer misreported as
# "table does not exist".
is_duplicate = False
try:
    if not spark.catalog.tableExists(OUTPUT_TABLE):
        print("‚ÑπÔ∏è Output table does not exist ‚Üí Creating new one")
    elif len(df) > 0:
        prev_sdf = spark.read.table(OUTPUT_TABLE)
        latest_batch = prev_sdf.agg({"batch_id": "max"}).first()[0]
        if latest_batch is not None:
            prev_preds = (
                prev_sdf.filter(prev_sdf["batch_id"] == latest_batch)
                .select("predicted_price")
                .toPandas()["predicted_price"]
                .to_numpy(dtype=float)
            )
            new_preds = df["predicted_price"].to_numpy(dtype=float)
            is_duplicate = len(prev_preds) == len(new_preds) and bool(
                np.allclose(np.sort(prev_preds), np.sort(new_preds), rtol=1e-6, atol=0.0)
            )
except Exception as e:
    print(f"‚ö†Ô∏è Duplicate check skipped: {e}")

# Exit is kept outside every try/except so no handler can intercept it and fall
# through to the append below.
if is_duplicate:
    print("‚ÑπÔ∏è Duplicate batch detected ‚Äî skipping save")
    dbutils.notebook.exit("SKIPPED_DUPLICATE")

# --- Write -------------------------------------------------------------------
try:
    spark_df = spark.createDataFrame(df)
    spark_df.write.mode("append") \
        .format("delta") \
        .option("mergeSchema", "true") \
        .saveAsTable(OUTPUT_TABLE)

    print(f"‚úÖ Results saved to {OUTPUT_TABLE}")

except Exception as e:
    print(f"‚ö†Ô∏è Save warning: {e}")
    print("   Continuing with metrics calculation...")
    traceback.print_exc()

# =============================================================================
# ✅ STEP 6: CALCULATE PERFORMANCE METRICS (IF LABELS AVAILABLE)
# =============================================================================
# Score predictions against the label column (when Step 2 found one) with
# MAE, RMSE, MAPE and R². Failures are reported but do not stop the notebook.
if y_true is not None and len(y_true) > 0:
    print(f"\n{'='*80}")
    print("üìã STEP 6: Calculating Performance Metrics")
    print(f"{'='*80}")

    try:
        # Work on plain float arrays and drop rows whose label or prediction is
        # missing: NaN would otherwise turn every metric into NaN and make
        # r2_score raise.
        y_obs = y_true.to_numpy(dtype=float, na_value=np.nan)
        y_hat = df["predicted_price"].to_numpy(dtype=float, na_value=np.nan)
        valid = ~(np.isnan(y_obs) | np.isnan(y_hat))
        n_skipped = int((~valid).sum())
        if n_skipped:
            print(f"   Skipping {n_skipped:,} row(s) with a missing label or prediction")
        y_obs, y_hat = y_obs[valid], y_hat[valid]
        if y_obs.size == 0:
            raise ValueError("no rows have both a label and a prediction")

        residuals = y_obs - y_hat
        mae = np.mean(np.abs(residuals))
        rmse = np.sqrt(np.mean(residuals ** 2))

        # MAPE is undefined for zero labels. Leave those rows out instead of
        # dividing by 1, which would add raw (unit-bearing) errors to a percentage.
        nonzero = y_obs != 0
        mape = (
            np.mean(np.abs(residuals[nonzero] / y_obs[nonzero])) * 100
            if nonzero.any()
            else float("nan")
        )

        from sklearn.metrics import r2_score
        r2 = r2_score(y_obs, y_hat)

        print(f"\nüìä Production Inference Metrics:")
        print(f"   MAE  : {mae:,.3f}")
        print(f"   RMSE : {rmse:,.3f}")
        print(f"   MAPE : {mape:.2f}%")
        print(f"   R¬≤   : {r2:.4f}")

    except Exception as e:
        print(f"‚ö†Ô∏è Metrics calculation failed: {e}")
        traceback.print_exc()
else:
    print(f"\n‚ÑπÔ∏è No ground truth labels available - skipping metrics")

# =============================================================================
# ✅ FINAL SUMMARY
# =============================================================================
# Print a run summary, publish task values for downstream job tasks, and end
# the notebook with "SUCCESS".
print(f"\n{'='*80}")
print("‚úÖ‚úÖ PRODUCTION INFERENCE COMPLETE ‚úÖ‚úÖ")
print(f"{'='*80}")
print(f"\nüìä Execution Summary:")
print(f"   Model Type: {MODEL_TYPE.upper()}")
print(f"   Endpoint: {ENDPOINT_NAME}")
print(f"   Input Table: {DATA_CATALOG}.{DATA_SCHEMA}.{INPUT_TABLE}")
print(f"   Output Table: {OUTPUT_TABLE}")
print(f"   Rows Processed: {len(df):,}")
print(f"   Batch Size: {BATCH_SIZE}")
print(f"   Timestamp: {datetime.now()}")
print(f"{'='*80}")

# Save for workflow. A failure is treated as non-fatal. Catch Exception rather
# than using a bare `except:` so SystemExit/KeyboardInterrupt still propagate,
# and print the actual reason, since "not running in workflow" is only a guess.
try:
    dbutils.jobs.taskValues.set(key="inference_rows", value=len(df))
    dbutils.jobs.taskValues.set(key="output_table", value=OUTPUT_TABLE)
    print("\n‚úÖ Task values saved for workflow")
except Exception as e:
    print("\n‚ÑπÔ∏è Not running in workflow - skipping task values")
    print(f"   Reason: {type(e).__name__}: {e}")

dbutils.notebook.exit("SUCCESS")