In [None]:
# Databricks notebook source
# ================================================================
# 🚀 SMART PRODUCTION BATCH INFERENCE (VIA SERVING ENDPOINT)
#    Auto-adaptive model type, dynamic endpoint + result table
# ================================================================
%pip install xgboost
dbutils.library.restartPython()

from databricks.sdk import WorkspaceClient
from pyspark.sql import SparkSession
import pandas as pd
import numpy as np
from datetime import datetime
import sys, math

print("=" * 80)
print("🚀 SMART PRODUCTION BATCH INFERENCE VIA SERVING ENDPOINT")
print("=" * 80)

# =============================================================================
# CONFIGURATION
# =============================================================================
# Unity Catalog location used to build the registered model name.
UC_CATALOG = "workspace"
UC_SCHEMA = "ml"
# Catalog/schema holding both the input table and the prediction output table.
DATA_CATALOG = "workspace"
DATA_SCHEMA = "default"
INPUT_TABLE = "house_price_delta"

# Number of rows sent to the serving endpoint per request.
BATCH_SIZE = 100

# Metric thresholds (for quick performance checks)
MAPE_THRESHOLD = 15.0  # percent
R2_THRESHOLD = 0.75

# =============================================================================
# INITIALIZE CLIENTS
# =============================================================================
# Every later step needs both the workspace client (serving endpoints) and the
# Spark session (Delta I/O), so any failure here aborts the run with exit code 1.
try:
    w = WorkspaceClient()
    print("✓ Databricks Workspace Client initialized")

    spark = SparkSession.builder.appName("SmartBatchInferenceProd").getOrCreate()
    print("✓ Spark session initialized")
except Exception as e:
    print(f"❌ Error initializing clients: {e}")
    sys.exit(1)

# =============================================================================
# AUTO-DETECT MODEL TYPE & ENDPOINT NAME
# =============================================================================
import re

try:
    # Pick the most recently updated active MLflow experiment; its name decides
    # the model family and therefore the endpoint and output table used below.
    import mlflow
    from mlflow.tracking import MlflowClient
    client = MlflowClient()
    mlflow.set_registry_uri("databricks-uc")

    experiments = client.search_experiments(view_type=mlflow.entities.ViewType.ACTIVE_ONLY)
    if not experiments:
        raise RuntimeError("no active MLflow experiments found")
    # last_update_time may be unset on some experiments; treat those as oldest
    # instead of letting max() fail on a None comparison.
    latest_exp = max(experiments, key=lambda exp: exp.last_update_time or 0)

    # Experiment names may be workspace paths (e.g. "/Users/<user>/<name>"), so
    # only the final path segment is inspected: folder or user names such as
    # "/Users/rfoster/..." must not influence detection.
    exp_name = latest_exp.name.lower().rsplit("/", 1)[-1]
    # Separator-free form so "xg_boost" / "random-forest" also match.
    compact_name = re.sub(r"[^a-z0-9]", "", exp_name)

    if "xgboost" in compact_name:
        model_type = "xgboost"
    # "rf" must be a standalone token, not part of a word such as "performance".
    elif re.search(r"(?<![a-z])rf(?![a-z])", exp_name) or "randomforest" in compact_name:
        model_type = "rf"
    elif "linear" in compact_name:
        model_type = "linear"
    else:
        model_type = "generic"

    model_name = f"{UC_CATALOG}.{UC_SCHEMA}.house_price_{model_type}_uc"
    ENDPOINT_NAME = f"house-price-{model_type}-prod"
    OUTPUT_TABLE = f"{DATA_CATALOG}.{DATA_SCHEMA}.prod_inference_{model_type}"

    print(f"📘 Latest Experiment: {latest_exp.name}")
    print(f"✅ Detected Model Type: {model_type.upper()}")
    print(f"✅ Using Endpoint: {ENDPOINT_NAME}")
    print(f"✅ Output Table: {OUTPUT_TABLE}")

except Exception as e:
    print(f"❌ Error detecting model type or endpoint: {e}")
    sys.exit(1)

# =============================================================================
# CHECK ENDPOINT STATUS
# =============================================================================
print("\n🔍 Checking endpoint readiness...")
try:
    endpoint = w.serving_endpoints.get(name=ENDPOINT_NAME)

    # In the Databricks SDK `state.ready` is an enum (READY / NOT_READY).
    # Compare its value exactly: a substring test such as
    # `"READY" in str(state.ready)` is also true for NOT_READY.
    ready_state = endpoint.state.ready if endpoint.state else None
    ready_value = getattr(ready_state, "value", ready_state)

    if ready_value == "READY":
        print(f"✅ Endpoint '{ENDPOINT_NAME}' is READY")
    else:
        # Not fatal here; a failing query during inference aborts the run.
        print(f"⚠️ Endpoint '{ENDPOINT_NAME}' may not be ready. Proceeding with caution...")

    if endpoint.config and endpoint.config.served_entities:
        for entity in endpoint.config.served_entities:
            print(f"   • Model: {entity.entity_name} | Version: {entity.entity_version}")

except Exception as e:
    print(f"❌ Cannot access endpoint '{ENDPOINT_NAME}': {e}")
    sys.exit(1)

# =============================================================================
# LOAD INPUT DATA
# =============================================================================
print("\n📦 Loading input data...")
FULL_INPUT_TABLE = f"{DATA_CATALOG}.{DATA_SCHEMA}.{INPUT_TABLE}"

try:
    df_spark = spark.read.format("delta").table(FULL_INPUT_TABLE)
    # Collects the whole table to the driver; assumes it fits in driver memory.
    df = df_spark.toPandas()
    print(f"✅ Loaded {len(df)} records from {FULL_INPUT_TABLE}")
except Exception as e:
    print(f"❌ Failed to load input Delta table: {e}")
    sys.exit(1)

# Fail fast with a clear message; an empty table would otherwise fail later
# with an opaque "min() arg is an empty sequence" during inference.
if df.empty:
    print(f"❌ Input table {FULL_INPUT_TABLE} is empty — nothing to score")
    sys.exit(1)

# Ground truth is optional: without it the run scores rows but skips metrics.
if "price" not in df.columns:
    print("⚠️ No 'price' column found — proceeding with inference-only mode")
    y_true = None
else:
    y_true = df["price"]

# Infer feature columns automatically (exclude target and identifiers)
NON_FEATURE_COLUMNS = {"price", "id", "timestamp"}
FEATURE_COLUMNS = [c for c in df.columns if c not in NON_FEATURE_COLUMNS]
print(f"🔍 Using features: {', '.join(FEATURE_COLUMNS)}")

# =============================================================================
# MAKE PREDICTIONS VIA ENDPOINT
# =============================================================================
print("\n🚀 Performing batch inference via serving endpoint...")
all_predictions = []
num_batches = (len(df) + BATCH_SIZE - 1) // BATCH_SIZE  # ceil division

try:
    for batch_idx in range(num_batches):
        start = batch_idx * BATCH_SIZE
        end = min(start + BATCH_SIZE, len(df))
        batch_records = df.iloc[start:end][FEATURE_COLUMNS].to_dict("records")

        # Query serving endpoint
        response = w.serving_endpoints.query(name=ENDPOINT_NAME, dataframe_records=batch_records)

        # Extract predictions and validate them per batch, so a bad response is
        # reported with its row range instead of surfacing later as a TypeError
        # (None) or a column-length mismatch when assigning to df.
        predictions = response.predictions if hasattr(response, "predictions") else response
        if predictions is None:
            raise ValueError(f"endpoint returned no predictions for rows {start}-{end - 1}")
        predictions = list(predictions)
        if len(predictions) != end - start:
            raise ValueError(
                f"endpoint returned {len(predictions)} predictions for {end - start} rows "
                f"(rows {start}-{end - 1})"
            )
        all_predictions.extend(predictions)

        if (batch_idx + 1) % 5 == 0 or batch_idx == num_batches - 1:
            print(f"   Processed {end}/{len(df)} samples...")

    df["predicted_price"] = all_predictions
    # Scalar assignments broadcast, so every row of this run shares one value.
    df["prediction_timestamp"] = datetime.now()
    df["endpoint_name"] = ENDPOINT_NAME
    df["inference_method"] = "serving_endpoint"

    print(f"✅ Predictions complete: {len(df)} rows")
    if all_predictions:  # min/max/mean are undefined for an empty run
        print(f"   Range: {min(all_predictions):,.2f} - {max(all_predictions):,.2f}")
        print(f"   Mean: {np.mean(all_predictions):,.2f}")

except Exception as e:
    print(f"❌ Error during inference: {e}")
    sys.exit(1)

# =============================================================================
# SAVE RESULTS TO DELTA (with duplicate check)
# =============================================================================
print("\n💾 Saving predictions to Delta table...")

try:
    from pyspark.sql import functions as F

    # One timestamp for both parts of batch_id; two separate now() calls could
    # straddle midnight and yield an inconsistent date/time pair.
    run_time = datetime.now()
    prediction_date = run_time.strftime('%Y-%m-%d')
    batch_id = f"{prediction_date}_{run_time.strftime('%H%M%S')}"

    df["prediction_date"] = prediction_date
    df["batch_id"] = batch_id

    # Check existence explicitly. Treating *any* read error as "table missing"
    # would switch to overwrite mode and could wipe the prediction history.
    table_exists = spark.catalog.tableExists(OUTPUT_TABLE)
    if not table_exists:
        print(f"ℹ️ Table does not exist yet. Creating new...")
    else:
        existing = spark.read.table(OUTPUT_TABLE)
        if {"batch_id", "predicted_price"}.issubset(existing.columns):
            # Fingerprint the most recent batch (the batch_id format
            # YYYY-MM-DD_HHMMSS sorts chronologically) using Spark-side
            # aggregates. This avoids collecting the whole table to the driver,
            # and unlike the table's "last row" it does not depend on row order,
            # which Spark does not guarantee.
            latest = (
                existing.groupBy("batch_id")
                .agg(
                    F.count(F.lit(1)).alias("n_rows"),
                    F.sum("predicted_price").alias("total"),
                )
                .orderBy(F.col("batch_id").desc())
                .first()
            )
            if (
                latest is not None
                and latest["n_rows"] == len(df)
                and latest["total"] is not None
                and math.isclose(
                    float(latest["total"]),
                    float(df["predicted_price"].sum()),
                    rel_tol=1e-6,
                )
            ):
                print(f"ℹ️ Duplicate predictions detected, skipping save.")
                sys.exit(0)

    spark_df = spark.createDataFrame(df)
    mode = "append" if table_exists else "overwrite"
    spark_df.write.mode(mode).format("delta").option("mergeSchema", "true").saveAsTable(OUTPUT_TABLE)

    print(f"✅ Saved predictions to {OUTPUT_TABLE} (mode={mode.upper()})")

except Exception as e:
    print(f"❌ Save operation failed: {e}")
    sys.exit(1)

# =============================================================================
# PERFORMANCE MONITORING
# =============================================================================
# Only possible when the input carried ground-truth prices (y_true not None).
if y_true is not None:
    print("\n📊 Evaluating model performance...")
    actual = np.asarray(y_true, dtype=float)
    predicted = np.asarray(df["predicted_price"], dtype=float)

    # Drop rows with a missing label or prediction so a single NaN does not
    # turn every metric into NaN.
    valid = np.isfinite(actual) & np.isfinite(predicted)
    actual, predicted = actual[valid], predicted[valid]

    if actual.size == 0:
        print("⚠️ No rows with both an actual and a predicted price — skipping metrics.")
    else:
        errors = actual - predicted
        mae = np.mean(np.abs(errors))
        rmse = np.sqrt(np.mean(errors ** 2))

        # MAPE is undefined where the actual price is 0; exclude those rows.
        nonzero = actual != 0
        mape = (
            np.mean(np.abs(errors[nonzero] / actual[nonzero])) * 100
            if nonzero.any()
            else float("nan")
        )

        # R² = 1 - SS_res / SS_tot; undefined when all actual prices are equal.
        ss_tot = np.sum((actual - actual.mean()) ** 2)
        r2 = 1.0 - np.sum(errors ** 2) / ss_tot if ss_tot > 0 else float("nan")

        print(f"   MAE  : {mae:,.2f}")
        print(f"   RMSE : {rmse:,.2f}")
        print(f"   MAPE : {mape:.2f}%")
        print(f"   R²   : {r2:.4f}")

        healthy = True
        # NaN compares False against any threshold, so handle it explicitly
        # rather than silently reporting "acceptable".
        if math.isnan(mape):
            print("⚠️ WARNING: MAPE could not be computed (all actual prices are 0).")
            healthy = False
        elif mape > MAPE_THRESHOLD:
            print(f"⚠️ WARNING: MAPE {mape:.2f}% exceeds {MAPE_THRESHOLD}% threshold!")
            healthy = False
        if not math.isnan(r2) and r2 < R2_THRESHOLD:
            print(f"⚠️ WARNING: R² {r2:.4f} is below {R2_THRESHOLD} threshold!")
            healthy = False

        if healthy:
            print(f"✅ Model performance within acceptable range.")

# =============================================================================
# SUMMARY
# =============================================================================
# Human-readable recap of this run's configuration and volume.
print("\n" + "=" * 80)
print("🎯 PRODUCTION BATCH INFERENCE COMPLETE")
print("=" * 80)
print(f"Model Type     : {model_type.upper()}")
print(f"Endpoint Used  : {ENDPOINT_NAME}")
print(f"Output Table   : {OUTPUT_TABLE}")
print(f"Records        : {len(df)}")
print(f"Timestamp      : {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
print("=" * 80)
print("\nBENEFITS:")
print("  • Unified model deployment (same as API)")
print("  • Environment consistency via UC endpoint")
print("  • Automatic endpoint + table detection")
print("  • Smart duplicate prevention")
print("  • Continuous production monitoring\n")

# Report success to a calling notebook/job. Outside Databricks `dbutils` is not
# defined, so only that NameError is ignored. Any other exception raised by
# notebook.exit (which Databricks may use to stop the run) is no longer
# swallowed by a bare `except:`.
try:
    dbutils.notebook.exit("INFERENCE_SUCCESS")
except NameError:
    pass
