In [None]:
# Databricks notebook source
# =============================================================
# ✅ UAT MODEL INFERENCE SCRIPT (FINAL VERSION – ALIGNED WITH STAGING LOGIC)
# =============================================================
%pip install xgboost

import mlflow
from mlflow.tracking import MlflowClient
import pandas as pd
import numpy as np
import math
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from pyspark.sql import SparkSession
from datetime import datetime
import warnings
import sys
import os

# Suppress every warning category from this point on (presumably to keep the
# notebook output readable).
# NOTE(review): a blanket "ignore" also hides numerical warnings such as
# divide-by-zero in the metric code further down — consider narrowing it.
warnings.filterwarnings("ignore")

# =============================================================
# ✅ CONFIGURATION (FIXED — SAME AS YOUR STAGING SCRIPT)
# =============================================================
# Fully qualified registered-model name (catalog.schema.model) that is
# passed to load_staging_model().
UC_CATALOG = "workspace"
UC_SCHEMA = "ml"
MODEL_NAME = f"{UC_CATALOG}.{UC_SCHEMA}.house_price_xgboost_uc2"

# Delta input table for UAT, read by load_data(). It must contain a 'price'
# column, which is split off and used as the ground truth for the metrics.
DELTA_INPUT_TABLE = "workspace.default.house_price_delta"

# Pass/fail thresholds applied by validate(); both bounds are inclusive.
MAPE_THRESHOLD = 15.0   # pass when MAPE (in percent) <= 15
R2_THRESHOLD   = 0.75   # pass when R^2 >= 0.75

# Output table that log_results() appends the metrics rows to.
OUTPUT_TABLE = "workspace.default.uat_inference_house_price_xgboost"


# =============================================================
# ✅ INITIALIZATION
# =============================================================
# Shared handles used by the functions below: the Spark session, and an
# MLflow client created after the registry URI is set to "databricks-uc".
spark = (
    SparkSession.builder
    .appName("UAT_Inference_Fixed")
    .getOrCreate()
)
mlflow.set_registry_uri("databricks-uc")
client = MlflowClient()

# Start-of-run banner: a title line framed by two 80-character rules.
print(
    "=" * 80,
    "üöÄ UAT MODEL INFERENCE STARTED ‚Äì FIXED VERSION",
    "=" * 80,
    sep="\n",
)


# =============================================================
# ✅ 1️⃣ Load model from STAGING alias (exact match with staging script)
# =============================================================
def load_staging_model(model_name):
    """Load the model version currently tagged with the ``Staging`` alias.

    The alias is resolved to a concrete version exactly once and the model is
    then loaded by that version number. This guarantees that the returned
    model, version and run ID all describe the same registry entry, even if
    the alias is re-pointed while this function is running.

    Args:
        model_name: Registered model name as understood by the configured
            MLflow registry.

    Returns:
        Tuple ``(model, version, run_id)``: the loaded pyfunc model plus the
        ``version`` and ``run_id`` attributes of the resolved model version.

    Raises:
        ValueError: If resolving the alias or loading the model fails. The
            underlying exception is attached as ``__cause__``.
    """
    print("\nüìå Loading UC model from alias: @Staging")
    try:
        mv = client.get_model_version_by_alias(model_name, "Staging")
        # Load by explicit version rather than by alias so the artifact
        # cannot differ from the metadata reported below.
        model_uri = f"models:/{model_name}/{mv.version}"
        model = mlflow.pyfunc.load_model(model_uri)
    except Exception as e:
        raise ValueError(f"‚ùå Failed to load model from staging: {e}") from e

    print(f"‚úÖ Loaded model version: v{mv.version}")
    print(f"‚úÖ Run ID: {mv.run_id}")
    return model, mv.version, mv.run_id


# =============================================================
# ✅ 2️⃣ Load Delta table for inference
# =============================================================
def load_data():
    """Read the UAT input table and split it into features and target.

    Returns:
        Tuple ``(df, X, y_true)``: the whole table as a pandas DataFrame, the
        same frame without the ``price`` column, and the ``price`` column.

    Raises:
        ValueError: If the table cannot be read or converted to pandas (the
            underlying exception is chained), or if the table has no
            ``price`` column.
    """
    print("\nüìå Loading UAT Delta input data...")

    # Guard only the Spark read/conversion. The schema check below raises its
    # own error and must not be re-wrapped as a "failed to load" error.
    try:
        df = spark.table(DELTA_INPUT_TABLE).toPandas()
    except Exception as e:
        raise ValueError(f"‚ùå Failed to load input table: {e}") from e

    if "price" not in df.columns:
        raise ValueError("‚ùå Input table MUST contain 'price' column.")

    X = df.drop(columns=["price"])
    y_true = df["price"]

    print(f"‚úÖ Loaded {len(df)} rows for inference.")
    return df, X, y_true


# =============================================================
# ✅ 3️⃣ Run inference
# =============================================================
def run_inference(model, X):
    """Score ``X`` with ``model`` and return whatever ``predict`` yields.

    A progress message is printed before and after the ``predict`` call.

    Args:
        model: Any object exposing ``predict(X)``.
        X: Feature data handed to ``predict`` unchanged.

    Returns:
        The value returned by ``model.predict(X)``.
    """
    print("\nüìå Running inference...")
    predictions = model.predict(X)
    print("‚úÖ Inference complete.")
    return predictions


# =============================================================
# ✅ 4️⃣ Calculate metrics
# =============================================================
def evaluate(y_true, y_pred):
    """Compute and print MAE, RMSE, R^2 and MAPE for the predictions.

    Both inputs are flattened to 1-D float arrays before any arithmetic, so
    the result does not depend on pandas index alignment, on predictions
    arriving as a single-column frame, or on non-float element types.

    MAPE is averaged only over rows whose actual value is non-zero, because
    the percentage error is undefined for a zero actual. The number of
    excluded rows is printed. If every actual value is zero, MAPE is NaN.

    Args:
        y_true: Actual target values (array-like of numbers).
        y_pred: Predicted values with the same number of elements.

    Returns:
        Tuple ``(mae, rmse, r2, mape)`` with ``mape`` expressed in percent.
    """
    actual = np.asarray(y_true, dtype=float).ravel()
    predicted = np.asarray(y_pred, dtype=float).ravel()

    mae = mean_absolute_error(actual, predicted)
    rmse = math.sqrt(mean_squared_error(actual, predicted))
    r2 = r2_score(actual, predicted)

    # Dividing by a zero actual would turn the whole mean into inf/NaN.
    nonzero = actual != 0
    skipped = int(actual.size - np.count_nonzero(nonzero))
    if nonzero.any():
        pct_errors = np.abs(
            (actual[nonzero] - predicted[nonzero]) / actual[nonzero]
        )
        mape = float(np.mean(pct_errors) * 100)
    else:
        mape = float("nan")

    print("\nüìä Evaluation Metrics:")
    print(f"MAE  : {mae:.3f}")
    print(f"RMSE : {rmse:.3f}")
    print(f"R¬≤   : {r2:.3f}")
    print(f"MAPE : {mape:.2f}%")
    if skipped:
        print(f"MAPE excludes {skipped} row(s) with a zero actual value")
    return mae, rmse, r2, mape


# =============================================================
# ✅ 5️⃣ Threshold validation (UAT pass/fail)
# =============================================================
def validate(mape, r2, mape_threshold=None, r2_threshold=None):
    """Decide whether the metrics meet the UAT acceptance thresholds.

    Args:
        mape: Mean absolute percentage error, in percent.
        r2: Coefficient of determination.
        mape_threshold: Largest acceptable MAPE (inclusive). When omitted,
            the module-level ``MAPE_THRESHOLD`` is used.
        r2_threshold: Smallest acceptable R^2 (inclusive). When omitted,
            the module-level ``R2_THRESHOLD`` is used.

    Returns:
        ``"PASSED"`` when both bounds hold, otherwise ``"FAILED"``. A NaN
        metric always yields ``"FAILED"`` because comparisons with NaN are
        false.
    """
    if mape_threshold is None:
        mape_threshold = MAPE_THRESHOLD
    if r2_threshold is None:
        r2_threshold = R2_THRESHOLD

    if mape <= mape_threshold and r2 >= r2_threshold:
        print("\n‚úÖ UAT PASSED ‚úÖ")
        return "PASSED"

    print("\n‚ùå UAT FAILED ‚ùå")
    return "FAILED"


# =============================================================
# ✅ 6️⃣ Log results to Delta table (dedupe included)
# =============================================================
def log_results(model_version, mae, rmse, r2, mape, status):
    """Append one UAT result row to ``OUTPUT_TABLE`` unless it is a repeat.

    The most recent existing row (highest ``timestamp``) is compared with the
    new result. The write is skipped only when the model version, the UAT
    status and all four metrics (relative tolerance 1e-6) are unchanged.

    The duplicate check is best-effort: if it fails for any reason, for
    example because the output table does not exist yet, a short notice is
    printed and the row is written anyway.

    Args:
        model_version: Model version; stored as ``int``.
        mae: Mean absolute error.
        rmse: Root mean squared error.
        r2: Coefficient of determination.
        mape: Mean absolute percentage error, in percent.
        status: UAT outcome string stored in the ``uat_status`` column.
    """
    result_df = pd.DataFrame([{
        "timestamp": datetime.now(),
        "model_version": int(model_version),
        "mae": mae,
        "rmse": rmse,
        "r2": r2,
        "mape": mape,
        "uat_status": status,
    }])

    # Prevent duplicate logs
    try:
        # Table rows have no inherent order, so pick the latest row
        # explicitly; limit(1) avoids pulling the whole table to the driver.
        latest = (
            spark.table(OUTPUT_TABLE)
            .orderBy("timestamp", ascending=False)
            .limit(1)
            .toPandas()
        )
        if not latest.empty:
            last = latest.iloc[0]
            same_run = (
                int(last.model_version) == int(model_version)
                and last.uat_status == status
            )
            same_metrics = all(
                math.isclose(previous, current, rel_tol=1e-6)
                for previous, current in (
                    (last.mae, mae),
                    (last.rmse, rmse),
                    (last.r2, r2),
                    (last.mape, mape),
                )
            )
            if same_run and same_metrics:
                print("\n‚ÑπÔ∏è Metrics unchanged ‚Üí Skipping log")
                return
    except Exception as e:
        # Deliberately non-fatal: a failed lookup must not block logging.
        print(f"\nDuplicate check skipped ({type(e).__name__}); logging anyway")

    spark_df = spark.createDataFrame(result_df)
    spark_df.write.mode("append").saveAsTable(OUTPUT_TABLE)

    print(f"\nüìù Logged results to: {OUTPUT_TABLE}")


# =============================================================
# ✅ MAIN EXECUTION FLOW
# =============================================================
# Run the UAT pipeline end to end: load the staged model, load the input
# table, predict, compute metrics, apply the thresholds and log the result.
# The intermediate values stay at module level.
try:
    model, model_version, run_id = load_staging_model(MODEL_NAME)
    df, X, y_true = load_data()
    y_pred = run_inference(model, X)
    mae, rmse, r2, mape = evaluate(y_true, y_pred)
    # NOTE(review): a "FAILED" status is only printed and logged; it does not
    # alter the exit path below — confirm that this is intended.
    status = validate(mape, r2)
    log_results(model_version, mae, rmse, r2, mape, status)

    print("\nüéØ UAT INFERENCE COMPLETED SUCCESSFULLY")

except Exception as e:
    # Any exception from the steps above is reduced to its message (no
    # traceback is printed), then sys.exit(1) is called.
    print(f"\n‚ùå UAT ERROR: {str(e)}")
    sys.exit(1)
