In [None]:
# Databricks notebook source
# =============================================================
# ✅ UAT MODEL INFERENCE SCRIPT (ALIGNED WITH REGISTRATION & STAGING)
# =============================================================
# COMMAND ----------
%pip install xgboost

# COMMAND ----------
import mlflow
from mlflow.tracking import MlflowClient
import pandas as pd
import numpy as np
import math
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from pyspark.sql import SparkSession
from datetime import datetime
import warnings
import sys
import os

# Suppress every Python warning raised from this point on.
# NOTE(review): this blanket filter also hides deprecation notices from the
# imported libraries -- presumably done to keep notebook output short; confirm.
warnings.filterwarnings("ignore")

# =============================================================
# ✅ CONFIGURATION (ALIGNED WITH REGISTRATION & STAGING SCRIPTS)
# =============================================================
# --- Registered model under test -------------------------------------------
# MODEL_NAME is the three-part "<catalog>.<schema>.<model>" name; the alias
# selects which registered version gets evaluated.
UC_CATALOG = "workspace"
UC_SCHEMA = "ml"
MODEL_NAME = ".".join([UC_CATALOG, UC_SCHEMA, "house_price_xgboost_uc2"])
STAGING_ALIAS = "Staging"

# --- Tables ------------------------------------------------------------------
# Source rows scored during UAT, and the table the run summary is appended to.
DELTA_INPUT_TABLE = "workspace.default.house_price_delta"
OUTPUT_TABLE = "workspace.default.uat_inference_house_price_xgboost"

# --- Columns -----------------------------------------------------------------
# Model inputs (kept as a list so it can be used directly as a pandas column
# selector) and the ground-truth column.
# NOTE(review): expected to mirror the training script -- not verifiable here.
FEATURE_COLS = [
    "sq_feet",
    "num_bedrooms",
    "num_bathrooms",
    "year_built",
    "location_score",
]
LABEL_COL = "price"

# --- Acceptance limits -------------------------------------------------------
# validate() passes a run when MAPE <= MAPE_THRESHOLD (percent) and
# R2 >= R2_THRESHOLD.
MAPE_THRESHOLD = 15.0
R2_THRESHOLD = 0.75


# =============================================================
# ✅ INITIALIZATION
# =============================================================
# Opening banner for the run.
print(
    "=" * 80,
    "üöÄ UAT MODEL INFERENCE - ALIGNED VERSION",
    "=" * 80,
    sep="\n",
)

# Obtain the Spark session, select the "databricks-uc" model registry, and only
# then build the registry client used by the functions below.
spark = (
    SparkSession.builder
    .appName("UAT_Inference_Aligned")
    .getOrCreate()
)
mlflow.set_registry_uri("databricks-uc")
client = MlflowClient()

# Echo the effective configuration so it is recorded in the run output.
print(
    f"\nüìã Configuration:",
    f"   Model: {MODEL_NAME}",
    f"   Alias: {STAGING_ALIAS}",
    f"   Input Table: {DELTA_INPUT_TABLE}",
    f"   Output Table: {OUTPUT_TABLE}",
    f"   Feature Columns: {FEATURE_COLS}",
    sep="\n",
)


# =============================================================
# ✅ 1️⃣ Load model from STAGING alias
# =============================================================
def load_staging_model(model_name, alias):
    """
    Load the model version that a registry alias currently points to.

    The alias is resolved to a concrete version first and the model is then
    loaded by that version number. Resolving and loading in two independent
    alias lookups would allow the alias to be reassigned in between, making
    the reported version/run ID describe a different artifact than the one
    that was actually loaded.

    Args:
        model_name: Fully qualified registered model name.
        alias: Registry alias to resolve.

    Returns:
        Tuple ``(model, version, run_id)``: the loaded pyfunc model, the
        resolved model version, and the run ID recorded on that version.

    Raises:
        ValueError: If the alias cannot be resolved or the model cannot be
            loaded. The underlying exception is chained as ``__cause__``.
    """
    print(f"\n{'='*70}")
    print(f"üìã STEP 1: Loading Model from @{alias}")
    print(f"{'='*70}")

    try:
        model_uri = f"models:/{model_name}@{alias}"
        print(f"   Model URI: {model_uri}")

        # Resolve the alias exactly once, then pin the load to that version so
        # the metadata below is guaranteed to match the loaded model.
        mv = client.get_model_version_by_alias(model_name, alias)
        pinned_uri = f"models:/{model_name}/{mv.version}"
        model = mlflow.pyfunc.load_model(pinned_uri)

        print(f"\n‚úÖ Model Loaded Successfully!")
        print(f"   Version: v{mv.version}")
        print(f"   Run ID: {mv.run_id}")
        print(f"   Status: {mv.status}")

        # Training metric is optional metadata; tolerate a version without tags.
        metric_tag = (mv.tags or {}).get("metric_rmse", "N/A")
        print(f"   Training RMSE: {metric_tag}")

        return model, mv.version, mv.run_id

    except Exception as e:
        print(f"\n‚ùå Failed to load model from {alias}: {e}")
        import traceback
        traceback.print_exc()
        # Chain explicitly so the root cause stays attached to the ValueError.
        raise ValueError(f"Model loading failed: {e}") from e


# =============================================================
# ✅ 2️⃣ Load Delta table for inference
# =============================================================
def load_data():
    """
    Read the UAT input table and split it into features and labels.

    The table named by ``DELTA_INPUT_TABLE`` is read through Spark and
    converted to a pandas DataFrame. The frame must be non-empty and contain
    every column in ``FEATURE_COLS`` as well as ``LABEL_COL``.

    Returns:
        Tuple ``(df, X, y_true)``: the full pandas DataFrame, the feature
        columns selected in ``FEATURE_COLS`` order, and the label column.

    Raises:
        ValueError: If the table cannot be read, has no rows, or lacks a
            required column. The underlying exception is chained as
            ``__cause__``.
    """
    print(f"\n{'='*70}")
    print("üìã STEP 2: Loading UAT Data")
    print(f"{'='*70}")

    try:
        df_spark = spark.table(DELTA_INPUT_TABLE)
        df = df_spark.toPandas()

        print(f"   Total rows loaded: {len(df)}")
        print(f"   Columns: {list(df.columns)}")

        # Validate required columns exist
        missing_features = [col for col in FEATURE_COLS if col not in df.columns]
        if missing_features:
            raise ValueError(f"Missing feature columns: {missing_features}")

        if LABEL_COL not in df.columns:
            raise ValueError(f"Missing label column: {LABEL_COL}")

        # Fail here with a clear message instead of letting an empty frame
        # reach inference/metric code, where the error would be unrelated.
        if df.empty:
            raise ValueError(f"Input table {DELTA_INPUT_TABLE} contains no rows")

        # Select only required features and label
        X = df[FEATURE_COLS]
        y_true = df[LABEL_COL]

        print(f"\n‚úÖ Data Loaded Successfully!")
        print(f"   Features shape: {X.shape}")
        print(f"   Labels shape: {y_true.shape}")

        return df, X, y_true

    except Exception as e:
        print(f"\n‚ùå Failed to load input table: {e}")
        import traceback
        traceback.print_exc()
        # Chain explicitly so the root cause stays attached to the ValueError.
        raise ValueError(f"Data loading failed: {e}") from e


# =============================================================
# ✅ 3️⃣ Run inference
# =============================================================
def run_inference(model, X):
    """
    Score the UAT feature set with the given model.

    Args:
        model: Object exposing ``predict(X)``.
        X: Feature rows to score; must support ``len()``.

    Returns:
        Whatever ``model.predict(X)`` returns, unchanged.

    Raises:
        Exception: Any error from prediction or from summarising the result
            is printed with its traceback and then re-raised as-is.
    """
    divider = "=" * 70
    print(f"\n{divider}")
    print("üìã STEP 3: Running Inference")
    print(divider)

    try:
        sample_count = len(X)
        print(f"   Running predictions on {sample_count} samples...")
        predictions = model.predict(X)

        # The summary stays inside the try so a result that cannot be measured
        # or sliced is reported through the same failure path.
        print(f"\n‚úÖ Inference Complete!")
        print(f"   Predictions generated: {len(predictions)}")
        print(f"   Sample predictions: {predictions[:5]}")
        return predictions

    except Exception as err:
        print(f"\n‚ùå Inference failed: {err}")
        import traceback
        traceback.print_exc()
        raise


# =============================================================
# ✅ 4️⃣ Calculate metrics
# =============================================================
def evaluate(y_true, y_pred):
    """
    Compute regression metrics for the UAT predictions.

    Both inputs are flattened to 1-D float arrays before any arithmetic, so
    the result does not depend on pandas index alignment or on predictions
    arriving with shape ``(n, 1)``.

    MAPE is computed only over rows whose actual value is non-zero: a zero
    label would otherwise turn the whole metric into ``inf``/``nan``. The
    number of excluded rows is printed. If every label is zero, MAPE is
    ``nan``.

    Args:
        y_true: Actual label values (array-like).
        y_pred: Predicted values (array-like) with the same number of elements.

    Returns:
        Tuple ``(mae, rmse, r2, mape)``; ``mape`` is expressed in percent.

    Raises:
        Exception: Any error raised while computing the metrics is printed
            and then re-raised as-is.
    """
    print(f"\n{'='*70}")
    print("üìã STEP 4: Evaluating Model Performance")
    print(f"{'='*70}")

    try:
        actual = np.asarray(y_true, dtype=float).ravel()
        predicted = np.asarray(y_pred, dtype=float).ravel()

        mae = mean_absolute_error(actual, predicted)
        rmse = math.sqrt(mean_squared_error(actual, predicted))
        r2 = r2_score(actual, predicted)

        # Percentage error is undefined where the actual value is zero.
        nonzero = actual != 0
        skipped = int(actual.size - np.count_nonzero(nonzero))
        if skipped:
            print(f"   Note: {skipped} row(s) with a zero label excluded from MAPE")
        if nonzero.any():
            relative_error = (actual[nonzero] - predicted[nonzero]) / actual[nonzero]
            mape = float(np.mean(np.abs(relative_error)) * 100)
        else:
            mape = float("nan")

        print(f"\nüìä Evaluation Metrics:")
        print(f"   MAE  : {mae:.3f}")
        print(f"   RMSE : {rmse:.3f}")
        print(f"   R¬≤   : {r2:.3f}")
        print(f"   MAPE : {mape:.2f}%")

        return mae, rmse, r2, mape

    except Exception as e:
        print(f"\n‚ùå Evaluation failed: {e}")
        raise


# =============================================================
# ✅ 5️⃣ Threshold validation (UAT pass/fail)
# =============================================================
def validate(mape, r2, mape_threshold=None, r2_threshold=None):
    """
    Decide whether the model passes UAT.

    A run passes when ``mape <= mape_threshold`` and ``r2 >= r2_threshold``.
    Each criterion is evaluated once and the failure report is derived from
    the negated result, so a NaN metric (which compares False against any
    threshold) is both counted as failing and listed as the failing
    criterion.

    Args:
        mape: Mean absolute percentage error, in percent.
        r2: Coefficient of determination.
        mape_threshold: Upper limit for ``mape``; defaults to the module-level
            ``MAPE_THRESHOLD`` when omitted.
        r2_threshold: Lower limit for ``r2``; defaults to the module-level
            ``R2_THRESHOLD`` when omitted.

    Returns:
        ``"PASSED"`` when both criteria hold, otherwise ``"FAILED"``.
    """
    if mape_threshold is None:
        mape_threshold = MAPE_THRESHOLD
    if r2_threshold is None:
        r2_threshold = R2_THRESHOLD

    mape_ok = bool(mape <= mape_threshold)
    r2_ok = bool(r2 >= r2_threshold)

    print(f"\n{'='*70}")
    print("üìã STEP 5: UAT Validation")
    print(f"{'='*70}")

    print(f"\nüìè Validation Thresholds:")
    print(f"   MAPE threshold: ‚â§ {mape_threshold}%")
    print(f"   R¬≤ threshold:   ‚â• {r2_threshold}")

    print(f"\nüìä Actual Performance:")
    print(f"   MAPE: {mape:.2f}% {'‚úÖ' if mape_ok else '‚ùå'}")
    print(f"   R¬≤:   {r2:.3f}  {'‚úÖ' if r2_ok else '‚ùå'}")

    if mape_ok and r2_ok:
        print(f"\n{'='*70}")
        print("‚úÖ‚úÖ UAT PASSED ‚úÖ‚úÖ")
        print(f"{'='*70}")
        return "PASSED"

    print(f"\n{'='*70}")
    print("‚ùå‚ùå UAT FAILED ‚ùå‚ùå")
    print(f"{'='*70}")

    # Report from the negated pass checks (not `>` / `<`) so NaN metrics are
    # listed too instead of failing without any stated reason.
    if not mape_ok:
        print(f"   ‚ö†Ô∏è MAPE too high: {mape:.2f}% > {mape_threshold}%")
    if not r2_ok:
        print(f"   ‚ö†Ô∏è R¬≤ too low: {r2:.3f} < {r2_threshold}")

    return "FAILED"


# =============================================================
# ✅ 6️⃣ Log results to Delta table (with deduplication)
# =============================================================
def log_results(model_version, run_id, mae, rmse, r2, mape, status):
    """
    Append one UAT result row to ``OUTPUT_TABLE`` unless it repeats the latest run.

    The most recent existing row is selected by ordering on the ``timestamp``
    column in descending order and fetching a single row. A Spark table has
    no inherent row order, so "the last row" must be defined explicitly; this
    also avoids pulling the whole history to the driver. The new result is
    skipped when that row has the same model version and all four metrics
    match within a relative tolerance of 1e-6.

    The duplicate check is best-effort: if it fails for any reason (for
    example the table does not exist yet) the reason is printed and the row
    is written anyway.

    Args:
        model_version: Version of the evaluated model (convertible to int).
        run_id: Run ID recorded on that model version.
        mae: Mean absolute error.
        rmse: Root mean squared error.
        r2: Coefficient of determination.
        mape: Mean absolute percentage error, in percent.
        status: UAT outcome string stored in ``uat_status``.

    Returns:
        None.

    Raises:
        Exception: Any error while building or writing the row is printed
            with its traceback and then re-raised as-is.
    """
    print(f"\n{'='*70}")
    print("üìã STEP 6: Logging Results")
    print(f"{'='*70}")

    try:
        # Check for duplicates against the most recent logged run only.
        try:
            latest_rows = (
                spark.table(OUTPUT_TABLE)
                .orderBy("timestamp", ascending=False)
                .limit(1)
                .collect()
            )
            if latest_rows:
                last = latest_rows[0]

                # Check if metrics are identical to last run
                is_duplicate = (
                    int(last["model_version"]) == int(model_version) and
                    math.isclose(float(last["mae"]), mae, rel_tol=1e-6) and
                    math.isclose(float(last["rmse"]), rmse, rel_tol=1e-6) and
                    math.isclose(float(last["r2"]), r2, rel_tol=1e-6) and
                    math.isclose(float(last["mape"]), mape, rel_tol=1e-6)
                )

                if is_duplicate:
                    print("\n‚ÑπÔ∏è Duplicate Entry Detected")
                    print("   Metrics unchanged from last run ‚Üí Skipping log")
                    return
        except Exception as e:
            print(f"   Note: Could not check for duplicates (table may not exist): {e}")

        result_df = pd.DataFrame([{
            "timestamp": datetime.now(),
            "model_version": int(model_version),
            "run_id": run_id,
            "mae": float(mae),
            "rmse": float(rmse),
            "r2": float(r2),
            "mape": float(mape),
            "uat_status": status
        }])

        # Write to Delta table
        spark_df = spark.createDataFrame(result_df)
        spark_df.write.mode("append").saveAsTable(OUTPUT_TABLE)

        print(f"\n‚úÖ Results Logged Successfully!")
        print(f"   Output Table: {OUTPUT_TABLE}")
        print(f"   Model Version: v{model_version}")
        print(f"   UAT Status: {status}")

    except Exception as e:
        print(f"\n‚ùå Failed to log results: {e}")
        import traceback
        traceback.print_exc()
        raise


# =============================================================
# ✅ MAIN EXECUTION FLOW
# =============================================================
def main():
    """
    Run the six UAT steps in order and print a run summary.

    Steps: load the staged model, load the input data, predict, compute
    metrics, compare them with the thresholds, and log the outcome. Any
    exception from a step is reported on stdout and ends the process with
    exit code 1.
    """
    rule = "=" * 80
    try:
        print(f"\n{rule}")
        print("üé¨ STARTING UAT INFERENCE PIPELINE")
        print(rule)

        # Steps 1-2: model and data (the full frame is not needed here).
        model, model_version, run_id = load_staging_model(MODEL_NAME, STAGING_ALIAS)
        _, X, y_true = load_data()

        # Steps 3-5: predict, score, and apply the pass/fail thresholds.
        y_pred = run_inference(model, X)
        mae, rmse, r2, mape = evaluate(y_true, y_pred)
        status = validate(mape, r2)

        # Step 6: persist the outcome.
        log_results(model_version, run_id, mae, rmse, r2, mape, status)

        print(f"\n{rule}")
        print("‚ú® UAT INFERENCE COMPLETED SUCCESSFULLY ‚ú®")
        print(rule)
        print(f"\nüìä Summary:")
        print(f"   Model Version: v{model_version}")
        print(f"   UAT Status: {status}")
        print(f"   RMSE: {rmse:.3f}")
        print(f"   MAPE: {mape:.2f}%")
        print(f"   R¬≤: {r2:.3f}")
        print(f"{rule}\n")

    except Exception as e:
        print(f"\n{rule}")
        print("‚ùå UAT INFERENCE FAILED")
        print(rule)
        print(f"Error: {e!s}")
        print(f"{rule}\n")
        sys.exit(1)


# =============================================================
# ✅ EXECUTE
# =============================================================
# Start the pipeline only when this file runs as the main program; importing
# it as a module defines the functions without executing main().
if __name__ == "__main__":
    main()