In [None]:
# Databricks notebook source
import mlflow
from mlflow.tracking import MlflowClient
from pyspark.sql import SparkSession
from pyspark.sql.functions import col
import pandas as pd
import numpy as np
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
import sys
import os
import warnings
import re
import json
import hashlib
from datetime import datetime

# Silence warnings: the environment variable is set for Python processes
# started with this environment, the filter covers the current interpreter.
os.environ['PYTHONWARNINGS'] = 'ignore'
warnings.filterwarnings("ignore")

# Three-line banner marking the start of the UAT run.
print(
    "=" * 70,
    "UAT MODEL INFERENCE - STAGING VALIDATION",
    "=" * 70,
    sep="\n",
)

# =============================================================================
# CONFIGURATION
# =============================================================================
# Registered-model name in "catalog.schema.model" form.
UC_CATALOG_NAME = "workspace"
UC_SCHEMA_NAME = "ml"
MODEL_NAME = ".".join((UC_CATALOG_NAME, UC_SCHEMA_NAME, "house_price_model_uc"))

# Fully qualified name of the table used as UAT input.
DATA_CATALOG_NAME = "workspace"
DATA_SCHEMA_NAME = "default"
TABLE_NAME = "house_price_delta"
FULL_TABLE_NAME = ".".join((DATA_CATALOG_NAME, DATA_SCHEMA_NAME, TABLE_NAME))

# MLflow experiment whose runs are searched for the model to validate.
EXPERIMENT_NAME = "/Shared/House_Price_Prediction_Delta_RF"

# Columns fed to the model, and the column holding the ground truth.
FEATURE_COLUMNS = [
    'sq_feet',
    'num_bedrooms',
    'num_bathrooms',
    'year_built',
    'location_score',
]
LABEL_COLUMN = 'price'

# Quality gates: MAPE (in percent) must not exceed the first value and
# R-squared must not fall below the second.
MAX_ACCEPTABLE_MAPE = 15.0
MIN_ACCEPTABLE_R2 = 0.75

# =============================================================================
# SPARK SESSION
# =============================================================================
# Obtain (or reuse) the Spark session used to read the evaluation table.
try:
    spark = (
        SparkSession.builder
        .appName("UAT_ModelInference")
        .getOrCreate()
    )
except Exception as spark_error:
    # Data loading depends on the session, so stop with a non-zero status.
    print(f"Error initializing Spark: {spark_error}")
    sys.exit(1)
else:
    print("Spark session initialized")

# =============================================================================
# GET MODEL ALIAS FROM WIDGET
# =============================================================================
# Read the target alias from the "alias" notebook widget. If the lookup fails
# (widget missing, or `dbutils` not defined in this environment) fall back to
# "Staging".
try:
    model_alias = dbutils.widgets.get("alias")
    print(f"Model Alias from widget: {model_alias}")
except Exception:
    # `except Exception` instead of a bare `except:` so that KeyboardInterrupt
    # and SystemExit are not silently turned into the default alias.
    model_alias = "Staging"
    print(f"Widget not found, using default: {model_alias}")

# =============================================================================
# LOAD MODEL
# =============================================================================
# NOTE(review): `model_alias` is only echoed below. The model that is loaded is
# the "sklearn_rf_model" artifact of the most recent FINISHED run in
# EXPERIMENT_NAME, not a registry version resolved from the alias. Confirm
# this is intended.
print("\nLoading model for UAT validation...")
print(f"Target Alias: {model_alias}")

try:
    client = MlflowClient()

    experiment = client.get_experiment_by_name(EXPERIMENT_NAME)
    if not experiment:
        raise Exception(f"Experiment '{EXPERIMENT_NAME}' not found")

    # Ask only for the newest finished run.
    runs = client.search_runs(
        experiment_ids=[experiment.experiment_id],
        filter_string="status = 'FINISHED'",
        order_by=["start_time DESC"],
        max_results=1,
    )
    if not runs:
        raise Exception("No successful runs found. Please train a model first.")

    latest_run = runs[0]
    run_id = latest_run.info.run_id

    print(
        "\nModel Details:",
        f"  Run ID: {run_id}",
        f"  Run Name: {latest_run.info.run_name}",
        sep="\n",
    )

    print("\n  Training Parameters:")
    for param_name, param_value in latest_run.data.params.items():
        print(f"    {param_name}: {param_value}")

    print("\n  Training Metrics:")
    # Copy of the run's metrics; kept at module level for the logging step.
    training_metrics = dict(latest_run.data.metrics)
    for metric_name, metric_value in training_metrics.items():
        print(f"    {metric_name}: {metric_value:.4f}")

    model_uri = f"runs:/{run_id}/sklearn_rf_model"
    model = mlflow.sklearn.load_model(model_uri)
    print("\nModel loaded successfully for UAT validation")

except Exception as load_error:
    # Any failure above is fatal: report it with a traceback and stop.
    print(f"\nError loading model: {load_error}")
    import traceback
    traceback.print_exc()
    sys.exit(1)

# =============================================================================
# LOAD TEST DATA
# =============================================================================
print(f"\nLoading test data from: {FULL_TABLE_NAME}")

try:
    spark_df = spark.read.format("delta").table(FULL_TABLE_NAME)
    # count() is a Spark action of its own; toPandas() below reads the data again.
    print(f"Data loaded: {spark_df.count()} rows")

    # Fail fast when a feature column or the label column is absent.
    required_cols = [*FEATURE_COLUMNS, LABEL_COLUMN]
    available_cols = spark_df.columns
    missing_cols = [name for name in required_cols if name not in available_cols]
    if missing_cols:
        raise Exception(f"Missing columns: {missing_cols}")

    # Collect only the required columns into a pandas DataFrame.
    pandas_df = spark_df.select(*required_cols).toPandas()
    print(f"Converted to Pandas: {pandas_df.shape}")

except Exception as data_error:
    # Without data there is nothing to validate; stop with a non-zero status.
    print(f"\nError loading data: {data_error}")
    sys.exit(1)

# =============================================================================
# PREDICTIONS
# =============================================================================
print(f"\n{'='*70}\nRUNNING INFERENCE ON UAT DATA\n{'='*70}")

try:
    X_test = pandas_df[FEATURE_COLUMNS]
    y_actual = pandas_df[LABEL_COLUMN]
    y_pred = model.predict(X_test)

    # Attach per-row diagnostics to the evaluation frame.
    pandas_df['predicted_price'] = y_pred
    pandas_df['prediction_error'] = y_actual - y_pred
    pandas_df['absolute_error'] = pandas_df['prediction_error'].abs()
    # Error relative to the actual label, in percent. NOTE(review): for numeric
    # columns a zero label presumably yields inf/NaN here rather than an error.
    pandas_df['percentage_error'] = (
        pandas_df['absolute_error'].div(y_actual).mul(100)
    )
    print(f"\nPredictions completed: {len(y_pred)} samples")

except Exception as predict_error:
    # Inference failures are fatal: report with a traceback and stop.
    print(f"\nError during prediction: {predict_error}")
    import traceback
    traceback.print_exc()
    sys.exit(1)

# =============================================================================
# METRIC CALCULATION
# =============================================================================
print(f"\n{'='*70}\nUAT VALIDATION METRICS\n{'='*70}")

# Compute MAE, RMSE, R-squared and MAPE (in percent) of the predictions
# against the actual labels. Any failure here is fatal.
try:
    mae = mean_absolute_error(y_actual, y_pred)
    rmse = np.sqrt(mean_squared_error(y_actual, y_pred))
    r2 = r2_score(y_actual, y_pred)
    mape = (abs(y_actual - y_pred) / y_actual * 100).mean()

    # Replace unusable values with fixed pessimistic fallbacks so the later
    # formatting and comparisons always operate on ordinary numbers.
    if not np.isfinite(rmse):
        rmse = 999999.99
    if np.isnan(r2):
        r2 = 0.0
    if not np.isfinite(mape):
        mape = 100.0

    # The R-squared label previously contained mis-decoded bytes; it is
    # written here as the intended "R²".
    print(
        "\nRegression Metrics:",
        f"  MAE: ${mae:,.2f}",
        f"  RMSE: ${rmse:,.2f}",
        f"  R²: {r2:.4f}",
        f"  MAPE: {mape:.2f}%",
        sep="\n",
    )

except Exception as metric_error:
    print(f"\nError calculating metrics: {metric_error}")
    sys.exit(1)

# =============================================================================
# LOG METRICS TO MLFLOW
# =============================================================================
print("\nLogging UAT metrics to MLflow for promotion tracking...")

# Logging is best-effort: a failure is reported but does not stop the script.
try:
    # Re-open the run the model came from so the UAT results are stored on it.
    with mlflow.start_run(run_id=run_id):
        uat_metrics = {
            "uat_mae": mae,
            "uat_rmse": rmse,
            "uat_r2_score": r2,
            "uat_mape": mape,
            # Keys kept for promotion-script compatibility. Note that the
            # test_* keys are written with the UAT values, and best_cv_rmse
            # falls back to 0.0 when the training run did not record it.
            "test_rmse": rmse,
            "test_r2_score": r2,
            "best_cv_rmse": training_metrics.get("best_cv_rmse", 0.0),
        }
        for metric_name, metric_value in uat_metrics.items():
            mlflow.log_metric(metric_name, metric_value)

    # Status markers previously contained mis-decoded bytes; restored here.
    print("✅ UAT metrics successfully logged to MLflow.")
except Exception as log_error:
    print(f"⚠️ Could not log UAT metrics to MLflow: {log_error}")

# =============================================================================
# VALIDATION PASS/FAIL
# =============================================================================
print(f"\n{'='*70}\nUAT VALIDATION RESULTS\n{'='*70}")

# Compare each metric with its threshold. bool() keeps the flags plain Python
# booleans even if the metrics are NumPy scalars.
mape_ok = bool(mape <= MAX_ACCEPTABLE_MAPE)
r2_ok = bool(r2 >= MIN_ACCEPTABLE_R2)

# One human-readable line per check. Labels and markers below previously
# contained mis-decoded bytes and are written with the intended characters.
results = [
    (
        f"PASS: MAPE {mape:.2f}% <= {MAX_ACCEPTABLE_MAPE}%"
        if mape_ok
        else f"FAIL: MAPE {mape:.2f}% > {MAX_ACCEPTABLE_MAPE}%"
    ),
    (
        f"PASS: R² {r2:.4f} >= {MIN_ACCEPTABLE_R2}"
        if r2_ok
        else f"FAIL: R² {r2:.4f} < {MIN_ACCEPTABLE_R2}"
    ),
]

# The model passes only when every individual check passes.
validation_passed = mape_ok and r2_ok

for result_line in results:
    print("  " + result_line)

if validation_passed:
    print("\n✅ UAT VALIDATION: PASSED – Model ready for Production.")
else:
    print("\n🚫 UAT VALIDATION: FAILED – Retrain with better parameters/data.")

# =============================================================================
# SAVE RESULTS (UNCHANGED LOGIC)
# =============================================================================
# TODO: the result-persistence step (fingerprint + Delta save) is not included
# in this file; restore that block here, unchanged, before relying on saved
# UAT results.

# =============================================================================
# EXIT CODE
# =============================================================================
print(f"\n{'='*70}\nUAT INFERENCE COMPLETE\n{'='*70}")

# Hand the outcome to the notebook caller. If that call fails for any reason,
# a failed validation is escalated to an exception while a passed validation
# ends quietly.
# NOTE(review): the bare `except:` is kept deliberately. It also intercepts
# BaseException subclasses, and whether dbutils.notebook.exit relies on raising
# one to stop the notebook is not visible here. Confirm before narrowing it.
try:
    dbutils.notebook.exit("PASSED" if validation_passed else "FAILED")
except:
    if not validation_passed:
        raise Exception("UAT Validation Failed: Model does not meet quality thresholds")
