In [None]:
from databricks.sdk import WorkspaceClient
from pyspark.sql import SparkSession
from pyspark.sql.functions import col
import pandas as pd
import numpy as np
from datetime import datetime
import sys

# Opening banner: rule, title, rule -- emitted with a single print call.
print(
    "=" * 70,
    "PRODUCTION BATCH INFERENCE VIA SERVING ENDPOINT",
    "=" * 70,
    sep="\n",
)

# =============================================================================
# CONFIGURATION
# =============================================================================
# Serving endpoint that is queried for predictions.
ENDPOINT_NAME = "house-price-prediction-prod"

# Three-part table identifiers are assembled as catalog.schema.table.
DATA_CATALOG_NAME = "workspace"
DATA_SCHEMA_NAME = "default"
INPUT_TABLE_NAME = "house_price_delta"
OUTPUT_TABLE_NAME = "production_predictions"

FULL_INPUT_TABLE = ".".join((DATA_CATALOG_NAME, DATA_SCHEMA_NAME, INPUT_TABLE_NAME))
FULL_OUTPUT_TABLE = ".".join((DATA_CATALOG_NAME, DATA_SCHEMA_NAME, OUTPUT_TABLE_NAME))

# Columns sent to the endpoint for every row (kept as a list so it can be
# used directly as a DataFrame column selector).
FEATURE_COLUMNS = [
    "sq_feet",
    "num_bedrooms",
    "num_bathrooms",
    "year_built",
    "location_score",
]
# Number of rows sent per endpoint request, to avoid API limits.
BATCH_SIZE = 100

# =============================================================================
# INITIALIZE CLIENTS
# =============================================================================
# Create the workspace client first, then the Spark session; any failure in
# either step is reported and terminates the run with exit status 1.
try:
    w = WorkspaceClient()
    print("Databricks Workspace Client initialized")

    spark = (
        SparkSession.builder
        .appName("BatchInferenceViaEndpoint")
        .getOrCreate()
    )
    print("Spark session initialized")
except Exception as init_error:
    print(f"Error initializing clients: {init_error}")
    raise SystemExit(1)

# =============================================================================
# CHECK ENDPOINT STATUS
# =============================================================================
# Look up the serving endpoint, report whether it is ready and which model
# entities it serves. Any failure to reach the endpoint aborts the run.
print(f"\nChecking endpoint status: {ENDPOINT_NAME}")

try:
    endpoint = w.serving_endpoints.get(name=ENDPOINT_NAME)

    if endpoint.state and endpoint.state.ready:
        ready_status = str(endpoint.state.ready)
        # Compare the state name exactly. A substring test for "READY" also
        # matches "NOT_READY", which would hide the warning below.
        # Prefer the enum's .value when present; otherwise fall back to the
        # last dotted component of the string form.
        ready_value = str(
            getattr(endpoint.state.ready, "value", ready_status)
        ).rsplit(".", 1)[-1]
        if ready_value == "READY":
            print(f"  Endpoint is READY")
        else:
            print(f"  Warning: Endpoint status is {ready_status}")
            print(f"  Proceeding anyway...")

    # Get model info from endpoint
    if endpoint.config and endpoint.config.served_entities:
        for entity in endpoint.config.served_entities:
            print(f"  Model: {entity.entity_name}")
            print(f"  Version: {entity.entity_version}")

except Exception as e:
    print(f"Error: Cannot access endpoint: {e}")
    print(f"Ensure endpoint '{ENDPOINT_NAME}' exists and is ready")
    sys.exit(1)

# =============================================================================
# LOAD INPUT DATA
# =============================================================================
# Read the input Delta table, confirm every feature column is present, and
# pull the feature columns (plus 'price' when the table has it) into pandas.
print(f"\nLoading data from: {FULL_INPUT_TABLE}")

try:
    spark_df = spark.read.format("delta").table(FULL_INPUT_TABLE)
    row_count = spark_df.count()
    print(f"  Loaded {row_count} rows")

    # Verify columns
    available_columns = set(spark_df.columns)
    missing_cols = [name for name in FEATURE_COLUMNS if name not in available_columns]
    if missing_cols:
        print(f"Error: Missing columns: {missing_cols}")
        sys.exit(1)

    # Select features; carry the actual price along when it is available
    selected_columns = list(FEATURE_COLUMNS)
    if 'price' in available_columns:
        selected_columns.append('price')
    pandas_df = spark_df.select(*selected_columns).toPandas()

    print(f"  Features: {', '.join(FEATURE_COLUMNS)}")

except Exception as load_error:
    print(f"Error loading data: {load_error}")
    sys.exit(1)

# =============================================================================
# MAKE PREDICTIONS VIA ENDPOINT
# =============================================================================
# Score every row of pandas_df through the serving endpoint in fixed-size
# batches, then attach the predictions and run metadata as new columns.
# Any failure (including an empty input or a malformed endpoint response)
# is reported with a traceback and terminates the run with exit status 1.
print(f"\nMaking predictions via serving endpoint...")
print(f"  Processing {len(pandas_df)} samples in batches of {BATCH_SIZE}")

all_predictions = []

try:
    total_rows = len(pandas_df)
    if total_rows == 0:
        # Fail with a clear message instead of an opaque error from min()
        # on an empty list further down.
        raise ValueError(f"No rows loaded from {FULL_INPUT_TABLE}; nothing to score")

    # Process in batches (ceiling division so the last partial batch is kept)
    num_batches = (total_rows + BATCH_SIZE - 1) // BATCH_SIZE

    for batch_idx in range(num_batches):
        start_idx = batch_idx * BATCH_SIZE
        end_idx = min(start_idx + BATCH_SIZE, total_rows)

        batch_df = pandas_df.iloc[start_idx:end_idx]

        # Prepare batch data: one dict per row, feature columns only
        batch_records = batch_df[FEATURE_COLUMNS].to_dict('records')

        # Call endpoint
        response = w.serving_endpoints.query(
            name=ENDPOINT_NAME,
            dataframe_records=batch_records
        )

        # Extract predictions; fall back to the raw response when it has no
        # 'predictions' attribute (response might be in a different format)
        if hasattr(response, 'predictions'):
            predictions = response.predictions
        else:
            predictions = response

        # Validate per batch so a bad response is reported against the rows
        # that caused it, rather than as a length mismatch after all batches.
        if predictions is None:
            raise ValueError(
                f"Endpoint returned no predictions for rows {start_idx}-{end_idx - 1}"
            )
        if len(predictions) != len(batch_records):
            raise ValueError(
                f"Endpoint returned {len(predictions)} predictions for "
                f"{len(batch_records)} rows (rows {start_idx}-{end_idx - 1})"
            )

        all_predictions.extend(predictions)

        # Progress line every 5 batches and after the final batch
        if (batch_idx + 1) % 5 == 0 or batch_idx == num_batches - 1:
            print(f"  Processed {end_idx}/{len(pandas_df)} samples...")

    # Add predictions to dataframe
    pandas_df['predicted_price'] = all_predictions
    pandas_df['prediction_timestamp'] = datetime.now()
    pandas_df['inference_method'] = 'serving_endpoint'
    pandas_df['endpoint_name'] = ENDPOINT_NAME

    print(f"  Predictions completed")
    print(f"  Price Range: {min(all_predictions):,.2f} - {max(all_predictions):,.2f}")
    print(f"  Mean Prediction: {np.mean(all_predictions):,.2f}")

except Exception as e:
    print(f"Error during prediction: {e}")
    import traceback
    traceback.print_exc()
    sys.exit(1)

# =============================================================================
# SAVE PREDICTIONS
# =============================================================================
# Persist the scored rows to the output Delta table: append when the table
# already exists, otherwise create it partitioned by prediction_date.
# Any failure is reported with a traceback and terminates the run (exit 1).
print(f"\nSaving predictions to: {FULL_OUTPUT_TABLE}")

try:
    # Take a single clock reading so the date and the batch ID can never
    # disagree (e.g. when the run straddles midnight).
    prediction_datetime = datetime.now()
    prediction_date = prediction_datetime.strftime('%Y-%m-%d')
    batch_id = f"{prediction_date}_endpoint_{prediction_datetime.strftime('%H%M%S')}"

    pandas_df['prediction_date'] = prediction_date
    pandas_df['batch_id'] = batch_id

    print(f"  Batch ID: {batch_id}")
    print(f"  Prediction Date: {prediction_date}")

    result_df = spark.createDataFrame(pandas_df)

    # Ask the catalog whether the table exists. Probing with a read and
    # treating *any* exception as "missing" would send a transient or
    # permission error down the create path and risk replacing history.
    table_exists = spark.catalog.tableExists(FULL_OUTPUT_TABLE)

    if table_exists:
        existing_count = spark.read.format("delta").table(FULL_OUTPUT_TABLE).count()
        print(f"  Existing records: {existing_count}")

        # mergeSchema lets new columns be added to the existing table
        result_df.write \
            .format("delta") \
            .mode("append") \
            .option("mergeSchema", "true") \
            .saveAsTable(FULL_OUTPUT_TABLE)
        print(f"  Mode: APPEND")
    else:
        print(f"  Table doesn't exist - will create new")

        # "errorifexists" creates the table but refuses to clobber one that
        # appeared between the existence check and this write.
        result_df.write \
            .format("delta") \
            .mode("errorifexists") \
            .partitionBy("prediction_date") \
            .saveAsTable(FULL_OUTPUT_TABLE)
        print(f"  Mode: CREATE with partitioning")

    print(f"  Predictions saved successfully")
    print(f"  Rows written: {len(pandas_df)}")

except Exception as e:
    print(f"Error: Save operation failed: {e}")
    import traceback
    traceback.print_exc()
    sys.exit(1)

# =============================================================================
# DISPLAY SAMPLE RESULTS
# =============================================================================
# Show the first ten scored rows: feature columns plus the prediction,
# with the prediction rendered as a thousands-separated 2-decimal string.
print("\n" + "=" * 70)
print("SAMPLE PREDICTIONS (First 10 rows)")
print("=" * 70)

display_columns = FEATURE_COLUMNS + ['predicted_price']
sample = pandas_df[display_columns].head(10).copy()
sample['predicted_price'] = sample['predicted_price'].map("{:,.2f}".format)

print(sample.to_string(index=False))

# =============================================================================
# PERFORMANCE MONITORING
# =============================================================================
# When actual prices are available, compare them with the predictions and
# report MAE, MAPE and RMSE; warn when MAPE exceeds the 15% threshold.
if 'price' in pandas_df.columns:
    print(f"\n{'=' * 70}")
    print("PRODUCTION PERFORMANCE MONITORING")
    print(f"{'=' * 70}")

    actual = pandas_df['price']
    pred = pandas_df['predicted_price']

    abs_error = (actual - pred).abs()
    mae = abs_error.mean()
    rmse_calc = np.sqrt(((actual - pred) ** 2).mean())

    # MAPE is undefined where the actual price is zero or missing: dividing
    # by zero yields inf (a false "degraded" alarm). Restrict it to rows
    # with a usable actual, and divide by its magnitude.
    has_actual = actual.notna() & (actual != 0)
    if has_actual.any():
        mape = (abs_error[has_actual] / actual[has_actual].abs() * 100).mean()
    else:
        mape = float('nan')

    print(f"\nPerformance Metrics:")
    print(f"  MAE: {mae:,.2f}")
    if pd.isna(mape):
        print(f"  MAPE: n/a (no rows with a non-zero actual price)")
    else:
        print(f"  MAPE: {mape:.2f}%")
    print(f"  RMSE: {rmse_calc:,.2f}")

    if pd.isna(mape):
        # NaN compares False against the threshold, which would otherwise
        # be reported as "within acceptable range".
        print(f"\n  WARNING: MAPE could not be computed; model performance not assessed")
    elif mape > 15.0:
        print(f"\n  WARNING: Model performance degraded!")
        print(f"  MAPE {mape:.2f}% exceeds 15% threshold")
    else:
        print(f"\n  Model performance is within acceptable range")

# =============================================================================
# SUMMARY
# =============================================================================
# Closing summary of the run, followed by a short list of reasons for
# scoring through the serving endpoint. Each item is printed on its own line.
print(
    "\n" + "=" * 70,
    "BATCH INFERENCE COMPLETE (VIA SERVING ENDPOINT)",
    "=" * 70,
    f"Predictions Made: {len(all_predictions)} rows",
    f"Endpoint: {ENDPOINT_NAME}",
    f"Output Table: {FULL_OUTPUT_TABLE}",
    f"Timestamp: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
    "=" * 70,
    sep="\n",
)

print(
    "\nBENEFITS OF USING SERVING ENDPOINT:",
    "  - No sklearn version mismatch issues",
    "  - Consistent environment (isolated)",
    "  - Same model used by API and batch",
    "  - Easier to monitor and debug",
    sep="\n",
)

# Best-effort: hand a success marker back to the caller of this notebook.
# `dbutils` is neither defined nor imported in this file -- presumably it is
# injected by the Databricks notebook runtime (verify for your environment).
try:
    dbutils.notebook.exit("INFERENCE_SUCCESS")
except (NameError, AttributeError):
    # NameError: no `dbutils` global (e.g. run as a plain script).
    # AttributeError: a `dbutils` object without the notebook utility.
    # Anything else (including SystemExit / KeyboardInterrupt, which a bare
    # `except:` would swallow) is allowed to propagate.
    pass