In [None]:
# Databricks notebook source
# =============================================================
# 🚀 UAT MODEL INFERENCE – SMART VERSIONED UC STAGING HANDLER
# =============================================================

import mlflow
from mlflow.tracking import MlflowClient
import pandas as pd
import numpy as np
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
import math
from pyspark.sql import SparkSession
from datetime import datetime
import sys
import os
import warnings

warnings.filterwarnings("ignore")

# =============================================================
# CONFIGURATION
# =============================================================
# Catalog and schema used by infer_model_name() to build the registered
# model name "<catalog>.<schema>.house_price_<model_type>_uc".
UC_CATALOG_NAME = "workspace"
UC_SCHEMA_NAME = "ml"
# Table read by run_inference(); it must contain a "price" column (the target).
DELTA_INPUT_TABLE = "workspace.default.house_price_delta"

# Metric thresholds
# evaluate_uat_status() reports PASSED only when MAPE (in percent) is at most
# MAPE_THRESHOLD and R² is at least R2_THRESHOLD.
MAPE_THRESHOLD = 15.0
R2_THRESHOLD = 0.75

# =============================================================
# INITIALIZATION
# =============================================================
# Module-level handles shared by every function below.
# getOrCreate() reuses an already-running Spark session when there is one.
spark = SparkSession.builder.appName("UAT_Model_Inference_Auto").getOrCreate()
# Use the "databricks-uc" registry so model names are catalog.schema.model.
mlflow.set_registry_uri("databricks-uc")
client = MlflowClient()

# Startup banner.
print("="*80)
print("🚀 UAT MODEL INFERENCE – SMART VERSIONED UC STAGING HANDLER")
print("="*80)

# =============================================================
# 1️⃣ Auto-detect the most recently updated active experiment
# =============================================================
def get_latest_experiment(client):
    """Return the active experiment with the greatest ``last_update_time``.

    Args:
        client: Object exposing ``search_experiments`` (an ``MlflowClient``
            in this script).

    Returns:
        The experiment entity that was updated most recently.

    Raises:
        ValueError: If the search returns no active experiments.
    """
    experiments = client.search_experiments(view_type=mlflow.entities.ViewType.ACTIVE_ONLY)
    if not experiments:
        # max() on an empty sequence also raises ValueError, but with a message
        # that says nothing about the actual problem.
        raise ValueError("❌ No active MLflow experiments found.")

    # An experiment without a recorded update time sorts as the oldest instead
    # of breaking the comparison with a None-vs-int TypeError.
    latest_exp = max(experiments, key=lambda exp: exp.last_update_time or 0)
    print(f"📘 Latest Finished Experiment: {latest_exp.name}")
    return latest_exp

# =============================================================
# 2️⃣ Infer model type and UC model name
# =============================================================
def infer_model_name(exp_name: str, catalog=None, schema=None):
    """Derive the model type and registered model name from an experiment name.

    The experiment name is matched case-insensitively, in this order:
    ``xgboost`` -> ``rf`` -> ``linear`` -> ``generic``.

    Args:
        exp_name: Experiment name (may be a full workspace path).
        catalog: Catalog for the model name; defaults to ``UC_CATALOG_NAME``.
        schema: Schema for the model name; defaults to ``UC_SCHEMA_NAME``.

    Returns:
        Tuple ``(model_name, model_type)`` where ``model_name`` is
        ``"<catalog>.<schema>.house_price_<model_type>_uc"``.
    """
    import re  # local import keeps this block self-contained

    catalog = UC_CATALOG_NAME if catalog is None else catalog
    schema = UC_SCHEMA_NAME if schema is None else schema

    exp_lower = exp_name.lower()
    # "rf" must stand alone (not be surrounded by letters): a plain substring
    # test also fires on words such as "performance" or on user/folder names
    # that appear in the experiment path.
    is_random_forest = (
        re.search(r"random[ _-]?forest", exp_lower) is not None
        or re.search(r"(?<![a-z])rf(?![a-z])", exp_lower) is not None
    )

    if "xgboost" in exp_lower:
        model_type = "xgboost"
    elif is_random_forest:
        model_type = "rf"
    elif "linear" in exp_lower:
        model_type = "linear"
    else:
        model_type = "generic"

    model_name = f"{catalog}.{schema}.house_price_{model_type}_uc"
    print(f"✅ Detected Model Type: {model_type.upper()}")
    print(f"✅ Using Registered Model: {model_name}")
    return model_name, model_type

# =============================================================
# 3️⃣ Load latest staging model from UC
# =============================================================
def load_staging_model(client, model_name):
    """Load the model version that the ``staging`` alias points to.

    The alias is resolved with a single registry call instead of fetching
    every version of the model one by one, and the model is then loaded by
    the resolved version number so the returned version is exactly the one
    that was loaded.

    Args:
        client: ``MlflowClient`` connected to the model registry.
        model_name: Fully qualified registered model name.

    Returns:
        Tuple ``(model, version)``: the loaded pyfunc model and the version
        identifier reported by the registry.

    Raises:
        ValueError: If the alias lookup fails; the registry error is chained
            as ``__cause__``.
    """
    from mlflow.exceptions import MlflowException

    try:
        staged = client.get_model_version_by_alias(model_name, "staging")
    except MlflowException as err:
        raise ValueError(f"❌ No model with alias 'staging' found for {model_name}") from err

    version = staged.version
    print(f"⏳ Loading model version {version} from alias 'staging'...")
    # Pin the URI to the resolved version: if the alias is moved between the
    # lookup and the load, the reported version would otherwise be wrong.
    model_uri = f"models:/{model_name}/{version}"
    model = mlflow.pyfunc.load_model(model_uri)
    print(f"✅ Model version {version} loaded successfully.")
    return model, version

# =============================================================
# 4️⃣ Run inference
# =============================================================
def run_inference(model):
    """Score the UAT input table with ``model``.

    Reads ``DELTA_INPUT_TABLE`` into pandas, uses the ``price`` column as the
    target and every other column as model input.

    Args:
        model: Object exposing ``predict(features)``.

    Returns:
        Tuple ``(df, y_true, y_pred)``: the full input frame, the ``price``
        column, and whatever ``model.predict`` returned.

    Raises:
        ValueError: If the table does not exist, lacks a ``price`` column,
            or contains no rows.
    """
    print("🏁 Loading UAT input Delta table for inference...")

    # Try to load table directly
    try:
        df_spark = spark.table(DELTA_INPUT_TABLE)
        df = df_spark.toPandas()
    except Exception as e:
        error_msg = str(e)
        if "TABLE_OR_VIEW_NOT_FOUND" in error_msg or "cannot be found" in error_msg:
            raise ValueError(
                f"❌ Delta table '{DELTA_INPUT_TABLE}' does not exist.\n"
                f"   Please create the table first or verify the table name.\n"
                f"   Expected format: catalog.schema.table_name\n"
                f"   Original error: {error_msg}"
            ) from e
        # Anything else is unexpected: let it propagate unchanged.
        raise

    if "price" not in df.columns:
        raise ValueError("❌ Input Delta table must contain 'price' column as target.")

    if df.empty:
        # Fail here with a clear message rather than inside predict() or the
        # metric functions, which cannot work with zero samples.
        raise ValueError(f"❌ Delta table '{DELTA_INPUT_TABLE}' contains no rows to score.")

    X = df.drop(columns=["price"])
    y_true = df["price"]
    y_pred = model.predict(X)

    print("✅ Inference completed successfully.")
    return df, y_true, y_pred

# =============================================================
# 5️⃣ Evaluate metrics
# =============================================================
def calculate_metrics(y_true, y_pred):
    """Compute and print MAE, RMSE, R² and MAPE for a set of predictions.

    Both inputs are converted to flat float arrays first, so the element-wise
    MAPE arithmetic is positional and does not depend on pandas index or
    column alignment.

    Args:
        y_true: Actual target values (array-like).
        y_pred: Predicted values (array-like) with the same number of elements.

    Returns:
        Tuple ``(mae, rmse, r2, mape)``. ``mape`` is a percentage computed
        over rows whose actual value is non-zero; it is NaN when every actual
        value is zero.
    """
    y_true = np.asarray(y_true, dtype=float).ravel()
    y_pred = np.asarray(y_pred, dtype=float).ravel()

    mae = mean_absolute_error(y_true, y_pred)
    rmse = math.sqrt(mean_squared_error(y_true, y_pred))
    r2 = r2_score(y_true, y_pred)

    # Percentage error is undefined for a zero actual value; dividing anyway
    # would turn the whole mean into inf/NaN.
    nonzero = y_true != 0
    skipped = int(nonzero.size - nonzero.sum())
    if skipped:
        print(f"⚠️ MAPE ignores {skipped} row(s) whose actual value is 0.")
    if nonzero.any():
        pct_error = np.abs((y_true[nonzero] - y_pred[nonzero]) / y_true[nonzero])
        mape = float(np.mean(pct_error) * 100)
    else:
        mape = float("nan")

    print(f"\n📊 Evaluation Metrics:")
    print(f"   • MAE  : {mae:.3f}")
    print(f"   • RMSE : {rmse:.3f}")
    print(f"   • R²   : {r2:.3f}")
    print(f"   • MAPE : {mape:.2f}%")
    return mae, rmse, r2, mape

# =============================================================
# 6️⃣ Compare thresholds and log result
# =============================================================
def evaluate_uat_status(mape, r2):
    """Return ``"PASSED"`` or ``"FAILED"`` for the given metrics.

    The run passes only when ``mape`` does not exceed ``MAPE_THRESHOLD`` and
    ``r2`` is not below ``R2_THRESHOLD``; the verdict is also printed.
    """
    within_limits = mape <= MAPE_THRESHOLD and r2 >= R2_THRESHOLD
    if not within_limits:
        print("❌ UAT VALIDATION: FAILED (outside thresholds)")
        return "FAILED"
    print("✅ UAT VALIDATION: PASSED (within thresholds)")
    return "PASSED"

# =============================================================
# 7️⃣ Log results to Delta (skip the append when metrics are unchanged)
# =============================================================
def log_results_to_delta(model_name, model_version, mae, rmse, r2, mape, status):
    """Append one UAT result row to the per-model results table.

    The target table is ``workspace.default.uat_inference_<last segment of
    model_name>``. Before writing, the most recent row (by ``timestamp``) is
    compared with the new metrics; if all four match within a relative
    tolerance of 1e-6 the write is skipped.

    Args:
        model_name: Fully qualified registered model name.
        model_version: Model version; stored as ``int``.
        mae, rmse, r2, mape: Metric values to record.
        status: UAT verdict string stored in ``uat_status``.
    """
    output_table = f"workspace.default.uat_inference_{model_name.split('.')[-1]}"
    result_df = pd.DataFrame([{
        "timestamp": datetime.now(),
        "model_name": model_name,
        "model_version": int(model_version),
        "mae": mae,
        "rmse": rmse,
        "r2": r2,
        "mape": mape,
        "uat_status": status
    }])

    new_metrics = {"mae": mae, "rmse": rmse, "r2": r2, "mape": mape}

    # The duplicate check is best-effort: any failure falls through to the append.
    try:
        # Spark tables have no inherent row order, so fetch the newest row
        # explicitly instead of collecting the whole table and taking iloc[-1].
        latest = (
            spark.table(output_table)
            .orderBy("timestamp", ascending=False)
            .limit(1)
            .toPandas()
        )
        if not latest.empty:
            last_row = latest.iloc[0]
            unchanged = all(
                math.isclose(last_row[name], value, rel_tol=1e-6)
                for name, value in new_metrics.items()
            )
            if unchanged:
                print(f"ℹ️ Metrics unchanged. Skipping write to {output_table}")
                return
    except Exception as e:
        error_msg = str(e)
        table_missing = "TABLE_OR_VIEW_NOT_FOUND" in error_msg or "cannot be found" in error_msg
        if not table_missing:
            # A missing table is expected on the first run (it is created by the
            # append below); anything else is reported rather than hidden.
            print(f"⚠️ Duplicate check on {output_table} failed ({error_msg}); appending anyway.")

    spark_df = spark.createDataFrame(result_df)
    spark_df.write.mode("append").saveAsTable(output_table)
    print(f"📝 Results logged to Delta table: {output_table}")

# =============================================================
# MAIN EXECUTION
# =============================================================
if __name__ == "__main__":
    # Pipeline: pick experiment -> resolve model -> score -> evaluate -> log.
    # Any failure is reported and the process exits with status 1.
    try:
        latest_exp = get_latest_experiment(client)
        model_name, model_type = infer_model_name(latest_exp.name)
        model, model_version = load_staging_model(client, model_name)
        df, y_true, y_pred = run_inference(model)
        mae, rmse, r2, mape = calculate_metrics(y_true, y_pred)
        status = evaluate_uat_status(mape, r2)
        log_results_to_delta(model_name, model_version, mae, rmse, r2, mape, status)
        print("\n🎯 UAT process completed successfully!")
    except Exception as e:
        import traceback

        print(f"\n❌ ERROR: {str(e)}")
        # The message alone does not show where the failure happened; keep the
        # full traceback (written to stderr) for diagnosis.
        traceback.print_exc()
        sys.exit(1)