In [None]:
# Databricks notebook source
# ================================================================
# 🚀 SMART PRODUCTION BATCH INFERENCE (SERVING ENDPOINT VERSION)
#    - Auto-detects model type
#    - Uses UC-based production serving endpoint
#    - Saves predictions to Delta with duplicate protection
# ================================================================

%pip install xgboost
dbutils.library.restartPython()

from databricks.sdk import WorkspaceClient
from pyspark.sql import SparkSession
import pandas as pd
import numpy as np
from datetime import datetime
import math
import mlflow
from mlflow.tracking import MlflowClient

# Start-of-run banner.
print("=" * 80)
print("üöÄ SMART PRODUCTION BATCH INFERENCE")
print("=" * 80)

# =============================================================================
# CONFIGURATION
# =============================================================================
# Catalog and schema used to build the registered model name
# (`{UC_CATALOG}.{UC_SCHEMA}.house_price_<model_type>_uc`).
UC_CATALOG = "workspace"
UC_SCHEMA = "ml"
# Catalog and schema that hold the input table and the per-model output table
# (`{DATA_CATALOG}.{DATA_SCHEMA}.prod_inference_<model_type>`).
DATA_CATALOG = "workspace"
DATA_SCHEMA = "default"
# Name of the Delta table read as inference input.
INPUT_TABLE = "house_price_delta"

# Number of input rows sent to the serving endpoint per query call.
BATCH_SIZE = 100

# NOTE(review): neither threshold is read anywhere in the visible notebook;
# presumably they were meant as quality gates for the metrics section — confirm.
MAPE_THRESHOLD = 15.0
R2_THRESHOLD = 0.75

# =============================================================================
# INIT CLIENTS
# =============================================================================
# Build the Databricks workspace client and obtain the Spark session. Any
# failure ends the notebook run with an INIT_FAILED status string.
try:
    w = WorkspaceClient()
    spark = (
        SparkSession.builder
        .appName("ProdInference")
        .getOrCreate()
    )
    print("‚úì Workspace & Spark initialized")
except Exception as init_error:
    dbutils.notebook.exit(f"INIT_FAILED: {init_error}")

# =============================================================================
# DETECT MODEL TYPE + ENDPOINT
# =============================================================================
# Infer the model family from the name of the most recently updated active
# MLflow experiment, then derive the registered-model name, the serving
# endpoint name and the output table name from it. Any failure ends the
# notebook run with a MODEL_DETECTION_FAILED status string.
try:
    mlflow.set_registry_uri("databricks-uc")
    client = MlflowClient()

    experiments = client.search_experiments(view_type=mlflow.entities.ViewType.ACTIVE_ONLY)
    if not experiments:
        # max() on an empty list only reports "max() arg is an empty
        # sequence"; raise something the exit status can actually explain.
        raise RuntimeError("no active MLflow experiments found")

    # last_update_time may be unset (None) on an experiment; treat that as
    # "oldest" so the comparison never mixes None with int.
    latest_exp = max(experiments, key=lambda exp: exp.last_update_time or 0)
    exp_name = latest_exp.name.lower()

    # First matching keyword wins; the order mirrors the original if/elif chain.
    # NOTE(review): these are plain substring tests, so "rf" also matches
    # unrelated names that merely contain those letters (e.g. "performance").
    keyword_to_type = (
        ("xgboost", "xgboost"),
        ("rf", "rf"),
        ("randomforest", "rf"),
        ("linear", "linear"),
    )
    model_type = next(
        (mapped for keyword, mapped in keyword_to_type if keyword in exp_name),
        "generic",
    )

    model_name = f"{UC_CATALOG}.{UC_SCHEMA}.house_price_{model_type}_uc"
    ENDPOINT_NAME = f"house-price-{model_type}-prod"
    OUTPUT_TABLE = f"{DATA_CATALOG}.{DATA_SCHEMA}.prod_inference_{model_type}"

    print(f"üìò Model Type     : {model_type.upper()}")
    print(f"‚úÖ Using Endpoint : {ENDPOINT_NAME}")
    print(f"‚úÖ Output Table   : {OUTPUT_TABLE}")

except Exception as e:
    dbutils.notebook.exit(f"MODEL_DETECTION_FAILED: {e}")

# =============================================================================
# VERIFY ENDPOINT STATUS
# =============================================================================
print("\nüîç Checking endpoint readiness...")

# Look up the serving endpoint, report whether it is ready, and list the
# entities it serves. A lookup failure ends the run with ENDPOINT_ERROR.
try:
    endpoint = w.serving_endpoints.get(name=ENDPOINT_NAME)

    # state.ready is an enum (READY / NOT_READY) in the Databricks SDK, and
    # enum members are always truthy, so a bare `if endpoint.state.ready:`
    # can never take the "not ready" branch. Compare the value explicitly.
    # getattr() also tolerates a missing state object.
    ready_flag = getattr(endpoint.state, "ready", None)
    is_ready = (
        ready_flag is True
        or str(getattr(ready_flag, "value", ready_flag)).upper() == "READY"
    )

    if is_ready:
        print(f"‚úÖ Endpoint '{ENDPOINT_NAME}' is READY")
    else:
        print(f"‚ö†Ô∏è Endpoint NOT fully ready ‚Üí Proceeding cautiously")

    # config / served_entities can be absent; listing them is informational
    # only, so an absent list is treated as empty rather than aborting the run.
    served_entities = getattr(endpoint.config, "served_entities", None) or []
    for m in served_entities:
        print(f"   ‚Ä¢ {m.entity_name} ‚Üí Version {m.entity_version}")

except Exception as e:
    dbutils.notebook.exit(f"ENDPOINT_ERROR: {e}")

# =============================================================================
# LOAD INPUT DATA
# =============================================================================
print("\nüì¶ Loading input data...")

# Read the whole input Delta table and materialise it as a pandas DataFrame
# on the driver. A read failure ends the run with DATA_LOAD_FAILED.
try:
    input_table_name = f"{DATA_CATALOG}.{DATA_SCHEMA}.{INPUT_TABLE}"
    df_spark = spark.read.format("delta").table(input_table_name)
    df = df_spark.toPandas()
    print(f"‚úÖ Loaded {len(df)} records")
except Exception as load_error:
    dbutils.notebook.exit(f"DATA_LOAD_FAILED: {load_error}")

# Keep the ground-truth column (when present) for the metrics section.
if "price" in df.columns:
    y_true = df["price"]
else:
    y_true = None

# Every column except the target and the bookkeeping columns is a feature.
non_feature_columns = {"price", "id", "timestamp"}
FEATURE_COLUMNS = [name for name in df.columns if name not in non_feature_columns]

print(f"üîç Feature Columns: {FEATURE_COLUMNS}")

# =============================================================================
# BATCH INFERENCE
# =============================================================================
print("\nüöÄ Performing inference...")

# Query the serving endpoint in slices of BATCH_SIZE rows and collect the
# predictions in input order. Any failure ends the run with INFERENCE_FAILED.
all_predictions = []
num_batches = (len(df) + BATCH_SIZE - 1) // BATCH_SIZE

try:
    for batch_idx in range(num_batches):
        start = batch_idx * BATCH_SIZE
        end = min(start + BATCH_SIZE, len(df))
        batch_df = df.iloc[start:end][FEATURE_COLUMNS]

        response = w.serving_endpoints.query(
            name=ENDPOINT_NAME,
            dataframe_records=batch_df.to_dict("records")
        )

        # A missing or short payload would otherwise only surface later as an
        # uncaught length-mismatch error when the column is assigned; fail
        # here so the run exits with a clear INFERENCE_FAILED status instead.
        predictions = response.predictions or []
        if len(predictions) != len(batch_df):
            raise RuntimeError(
                f"batch {batch_idx + 1}: expected {len(batch_df)} predictions, "
                f"got {len(predictions)}"
            )
        all_predictions.extend(predictions)

        print(f"   ‚Üí Batch {batch_idx+1}/{num_batches} complete")

except Exception as e:
    dbutils.notebook.exit(f"INFERENCE_FAILED: {e}")

# Attach predictions and run metadata to the input rows.
df["predicted_price"] = all_predictions
df["prediction_timestamp"] = datetime.now()
df["endpoint_name"] = ENDPOINT_NAME
df["inference_method"] = "serving_endpoint"

print(f"‚úÖ Generated {len(all_predictions)} predictions")

# =============================================================================
# SAVE RESULTS TO DELTA (AVOID DUPLICATES)
# =============================================================================
print("\nüíæ Saving predictions...")

# Append the predictions to the output Delta table unless the last stored
# prediction matches the last new one (the notebook's duplicate heuristic).
# A write failure ends the run with SAVE_FAILED; a detected duplicate ends it
# with SKIPPED_DUPLICATE.
is_duplicate = False

try:
    df["prediction_date"] = datetime.now().strftime('%Y-%m-%d')
    df["batch_id"] = datetime.now().strftime('%Y-%m-%d_%H%M%S')

    # Duplicate check on last prediction row. It is best-effort: a failure of
    # the check itself is reported and the save proceeds, but it is no longer
    # misreported as "table does not exist".
    try:
        if spark.catalog.tableExists(OUTPUT_TABLE):
            # tail(1) brings back only the final row instead of loading the
            # whole output table into pandas on the driver.
            prev_tail = (
                spark.read.table(OUTPUT_TABLE)
                .select("predicted_price")
                .tail(1)
            )
            if prev_tail and len(df) > 0:
                last_pred_prev = prev_tail[0]["predicted_price"]
                last_pred_new = df["predicted_price"].iloc[-1]
                is_duplicate = math.isclose(last_pred_prev, last_pred_new, rel_tol=1e-6)
        else:
            print("‚ÑπÔ∏è Output table does not exist ‚Üí Creating new one")
    except Exception as check_error:
        is_duplicate = False
        print(f"Duplicate check could not be completed ({check_error}); saving anyway")

    if is_duplicate:
        print("‚ÑπÔ∏è Duplicate batch detected ‚Äî skipping save")
    else:
        spark_df = spark.createDataFrame(df)
        spark_df.write.mode("append").format("delta").option("mergeSchema", "true").saveAsTable(OUTPUT_TABLE)
        print(f"‚úÖ Saved to {OUTPUT_TABLE}")

except Exception as e:
    dbutils.notebook.exit(f"SAVE_FAILED: {e}")

# Exit outside every try block: a surrounding handler would intercept anything
# the exit call raises, and the run would then continue into the append.
if is_duplicate:
    dbutils.notebook.exit("SKIPPED_DUPLICATE")

# =============================================================================
# PERFORMANCE METRICS
# =============================================================================
# Report MAE, RMSE and MAPE when the input table carried ground-truth prices.
if y_true is not None:
    print("\nüìä Evaluating model performance...")
    y_pred = df["predicted_price"]

    residuals = y_true - y_pred
    mae = np.mean(np.abs(residuals))
    rmse = np.sqrt(np.mean(residuals ** 2))

    # MAPE is undefined where the true value is 0 (division by zero yields
    # inf/NaN and poisons the mean), so those rows are left out of MAPE only.
    nonzero_target = y_true != 0
    if nonzero_target.any():
        mape = np.mean(np.abs(residuals[nonzero_target] / y_true[nonzero_target])) * 100
    else:
        mape = float("nan")

    print(f"   MAE  : {mae:.3f}")
    print(f"   RMSE : {rmse:.3f}")
    print(f"   MAPE : {mape:.2f}%")

# =============================================================================
# SUMMARY
# =============================================================================
# Closing banner that summarises the run, followed by the success status
# returned to whatever invoked the notebook.
rule = "=" * 80
summary_lines = [
    "\n" + rule,
    "üéØ PRODUCTION INFERENCE COMPLETE",
    rule,
    f"Model Type     : {model_type.upper()}",
    f"Endpoint Used  : {ENDPOINT_NAME}",
    f"Output Table   : {OUTPUT_TABLE}",
    f"Rows Processed : {len(df)}",
    f"Timestamp      : {datetime.now()}",
    rule,
]
print(*summary_lines, sep="\n")

dbutils.notebook.exit("SUCCESS")
