In [None]:
# Databricks notebook source
import mlflow
from mlflow.tracking import MlflowClient
from pyspark.sql import SparkSession
from pyspark.sql.functions import col
import pandas as pd
import numpy as np
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
import sys
import os
import warnings
import re
import json
import hashlib
from datetime import datetime

# Silence warnings in this interpreter and in any child Python processes
warnings.filterwarnings("ignore")
os.environ['PYTHONWARNINGS'] = 'ignore'

_BANNER_RULE = "=" * 70
for _banner_line in (_BANNER_RULE, "UAT MODEL INFERENCE - STAGING VALIDATION", _BANNER_RULE):
    print(_banner_line)

# =============================================================================
# CONFIGURATION
# =============================================================================
# Unity Catalog location of the registered model: <catalog>.<schema>.<name>
UC_CATALOG_NAME = "workspace"
UC_SCHEMA_NAME = "ml"
MODEL_NAME = ".".join((UC_CATALOG_NAME, UC_SCHEMA_NAME, "house_price_model_uc"))

# Delta table that supplies the evaluation rows
DATA_CATALOG_NAME = "workspace"
DATA_SCHEMA_NAME = "default"
TABLE_NAME = "house_price_delta"
FULL_TABLE_NAME = ".".join((DATA_CATALOG_NAME, DATA_SCHEMA_NAME, TABLE_NAME))

# Experiment holding the training runs (the UAT runs below are logged here too)
EXPERIMENT_NAME = "/Shared/House_Price_Prediction_Delta_RF"

# Model inputs and regression target
FEATURE_COLUMNS = ['sq_feet', 'num_bedrooms', 'num_bathrooms', 'year_built', 'location_score']
LABEL_COLUMN = 'price'

# UAT acceptance thresholds: MAPE is in percent, R² is a fraction
MAX_ACCEPTABLE_MAPE = 15.0
MIN_ACCEPTABLE_R2 = 0.75

# Registry alias given to the model version registered by this run
STAGING_ALIAS = "staging"

# =============================================================================
# SPARK SESSION
# =============================================================================
# getOrCreate() reuses an already-active session if there is one; any failure
# here is fatal for the job.
try:
    spark = (
        SparkSession.builder
        .appName("UAT_ModelInference")
        .getOrCreate()
    )
except Exception as spark_err:
    print(f"❌ Error initializing Spark: {spark_err}")
    sys.exit(1)
else:
    print("✓ Spark session initialized")

# =============================================================================
# MLFLOW SETUP
# =============================================================================
# When running on a Databricks cluster, use the Unity Catalog model registry.
# set_experiment() also makes EXPERIMENT_NAME the target of later start_run().
try:
    on_databricks = "DATABRICKS_RUNTIME_VERSION" in os.environ
    if on_databricks:
        mlflow.set_registry_uri("databricks-uc")

    mlflow.set_experiment(EXPERIMENT_NAME)
    client = MlflowClient()
except Exception as mlflow_err:
    print(f"❌ Error setting up MLflow: {mlflow_err}")
    sys.exit(1)
else:
    print("✓ MLflow configured")

# =============================================================================
# GET MODEL ALIAS FROM WIDGET
# =============================================================================
# NOTE(review): model_alias only names/tags the UAT run; registration below
# always applies STAGING_ALIAS regardless of this value.
try:
    # `dbutils` exists only inside Databricks (NameError elsewhere); a missing
    # widget raises a Databricks-side error instead. Catch Exception rather
    # than using a bare except so Ctrl-C / SystemExit are not swallowed.
    model_alias = dbutils.widgets.get("alias")
except Exception:
    model_alias = STAGING_ALIAS
    print(f"Widget not found, using default: {model_alias}")
else:
    if model_alias:
        print(f"Model Alias from widget: {model_alias}")
    else:
        # An empty widget would otherwise produce a run named "UAT_Validation_"
        model_alias = STAGING_ALIAS
        print(f"Widget empty, using default: {model_alias}")

# =============================================================================
# LOAD LATEST TRAINED MODEL & GET TRAINING METRICS
# =============================================================================
print(f"\n{'='*70}")
print("LOADING MODEL FOR UAT VALIDATION")
print(f"{'='*70}")

try:
    experiment = client.get_experiment_by_name(EXPERIMENT_NAME)
    if not experiment:
        raise Exception(f"Experiment '{EXPERIMENT_NAME}' not found")

    # The UAT runs this notebook creates are logged into the same experiment
    # and also end as FINISHED, so the newest FINISHED run may be a UAT run
    # (which has no 'sklearn_rf_model' artifact). Page through FINISHED runs
    # newest-first and take the first one not tagged validation_stage=UAT.
    training_run = None
    page_token = None
    while True:
        page = client.search_runs(
            experiment_ids=[experiment.experiment_id],
            filter_string="status = 'FINISHED'",
            order_by=["start_time DESC"],
            max_results=100,
            page_token=page_token,
        )
        training_run = next(
            (r for r in page if r.data.tags.get("validation_stage") != "UAT"),
            None,
        )
        page_token = page.token
        if training_run is not None or not page_token:
            break

    if training_run is None:
        raise Exception("No successful runs found. Please train a model first.")

    training_run_id = training_run.info.run_id

    print(f"\n✓ Found Latest Training Run:")
    print(f"  Run ID: {training_run_id}")
    print(f"  Run Name: {training_run.info.run_name}")

    # Metrics/params recorded by the training run; reused later for logging
    # and for the model fingerprint.
    training_metrics = training_run.data.metrics
    training_params = training_run.data.params

    print("\n  Training Parameters:")
    for k, v in training_params.items():
        print(f"    {k}: {v}")

    print("\n  Training Metrics:")
    for k, v in training_metrics.items():
        print(f"    {k}: {v:.4f}")

    # Load the sklearn model saved by the training run
    model_uri = f"runs:/{training_run_id}/sklearn_rf_model"
    model = mlflow.sklearn.load_model(model_uri)
    print("\n✓ Model loaded successfully")

except Exception as e:
    print(f"\n❌ Error loading model: {e}")
    import traceback
    traceback.print_exc()
    sys.exit(1)

# =============================================================================
# LOAD TEST DATA
# =============================================================================
print(f"\n{'='*70}")
print("LOADING TEST DATA")
print(f"{'='*70}")
print(f"Table: {FULL_TABLE_NAME}")

# Feature columns followed by the label, in the order the model expects
required_cols = [*FEATURE_COLUMNS, LABEL_COLUMN]

try:
    spark_df = spark.read.format("delta").table(FULL_TABLE_NAME)
    total_rows = spark_df.count()
    print(f"✓ Data loaded: {total_rows} rows")

    # Fail early with a clear message if the table schema lacks any column
    available_cols = spark_df.columns
    missing_cols = [name for name in required_cols if name not in available_cols]
    if missing_cols:
        raise Exception(f"Missing columns: {missing_cols}")

    pandas_df = spark_df.select(*required_cols).toPandas()
except Exception as load_err:
    print(f"\n❌ Error loading data: {load_err}")
    sys.exit(1)
else:
    print(f"✓ Converted to Pandas: {pandas_df.shape}")

# =============================================================================
# RUN PREDICTIONS
# =============================================================================
print(f"\n{'='*70}")
print("RUNNING UAT INFERENCE")
print(f"{'='*70}")

try:
    X_test = pandas_df[FEATURE_COLUMNS]
    y_actual = pandas_df[LABEL_COLUMN]
    y_pred = model.predict(X_test)

    # Per-row diagnostics: signed error, its magnitude, and magnitude as a
    # percentage of the actual value (inf where the actual value is 0).
    residual = y_actual - y_pred
    pandas_df['predicted_price'] = y_pred
    pandas_df['prediction_error'] = residual
    pandas_df['absolute_error'] = residual.abs()
    pandas_df['percentage_error'] = (pandas_df['absolute_error'] / y_actual) * 100
except Exception as predict_err:
    print(f"\n❌ Error during prediction: {predict_err}")
    import traceback
    traceback.print_exc()
    sys.exit(1)
else:
    print(f"✓ Predictions completed: {len(y_pred)} samples")

# =============================================================================
# CALCULATE UAT METRICS
# =============================================================================
print(f"\n{'='*70}")
print("UAT VALIDATION METRICS")
print(f"{'='*70}")

try:
    mae = mean_absolute_error(y_actual, y_pred)
    rmse = np.sqrt(mean_squared_error(y_actual, y_pred))
    r2 = r2_score(y_actual, y_pred)
    # MAPE in percent; a zero actual value makes this infinite
    mape = (abs(y_actual - y_pred) / y_actual * 100).mean()

    # Replace non-finite results with fixed sentinel values so logging and
    # the threshold checks below always receive ordinary numbers.
    if not np.isfinite(rmse):
        rmse = 999999.99
    if np.isnan(r2):
        r2 = 0.0
    if not np.isfinite(mape):
        mape = 100.0

    metric_report = (
        f"\nUAT Metrics:",
        f"  MAE: ${mae:,.2f}",
        f"  RMSE: ${rmse:,.2f}",
        f"  R²: {r2:.4f}",
        f"  MAPE: {mape:.2f}%",
    )
    for report_line in metric_report:
        print(report_line)

except Exception as metric_err:
    print(f"\n❌ Error calculating metrics: {metric_err}")
    sys.exit(1)

# =============================================================================
# CREATE NEW UAT RUN AND LOG METRICS
# =============================================================================
print(f"\n{'='*70}")
print("LOGGING UAT METRICS TO MLFLOW")
print(f"{'='*70}")

# Resolved before the try so it is defined for later steps even if logging
# fails part-way through.
best_cv_rmse = training_metrics.get("best_cv_rmse", 0.0)

try:
    # A separate run for UAT validation; the training run is not modified.
    with mlflow.start_run(run_name=f"UAT_Validation_{model_alias}") as uat_run:
        uat_run_id = uat_run.info.run_id

        # Batched logging calls instead of one call per metric/param/tag.
        mlflow.log_metrics({
            "uat_mae": mae,
            "uat_rmse": rmse,
            "uat_r2_score": r2,
            "uat_mape": mape,
            # Metric names the promotion script expects
            "test_rmse": rmse,
            "test_r2_score": r2,
            # Training-time metric carried over for reference
            "best_cv_rmse": best_cv_rmse,
        })

        # Training params copied for traceability
        mlflow.log_params(training_params)

        # Link back to the training run and mark this run as a UAT run
        mlflow.set_tags({
            "training_run_id": training_run_id,
            "validation_stage": "UAT",
            "model_alias": model_alias,
        })

        print(f"\n✓ UAT Run Created:")
        print(f"  UAT Run ID: {uat_run_id}")
        print(f"  Linked Training Run: {training_run_id}")

    print("✓ UAT metrics logged successfully")

except Exception as e:
    # Best-effort: continue, but the registration step needs uat_run_id.
    print(f"⚠ Warning: Could not log UAT metrics: {e}")
    import traceback
    traceback.print_exc()

# =============================================================================
# REGISTER MODEL FROM UAT RUN WITH METRICS
# =============================================================================
print(f"\n{'='*70}")
print("REGISTERING MODEL WITH UAT METRICS")
print(f"{'='*70}")

try:
    import time
    from mlflow.models import infer_signature

    # Unity Catalog only accepts model versions that carry a signature, so
    # derive one from the evaluation features and the predictions made above.
    signature = infer_signature(X_test, y_pred)

    # Log the model inside the UAT run so the registered version's run_id is
    # the run holding the UAT metrics. A NameError on uat_run_id means the
    # UAT run was never created (that step only warns on failure).
    with mlflow.start_run(run_id=uat_run_id):
        model_info = mlflow.sklearn.log_model(
            model,
            "model",
            signature=signature,
            registered_model_name=MODEL_NAME,
        )

    print(f"✓ Model logged in UAT run: {uat_run_id}")

    time.sleep(3)  # Give the registry a moment to surface the new version

    model_versions = client.search_model_versions(f"name='{MODEL_NAME}'")

    # Prefer the version registered from our UAT run; fall back to the newest.
    uat_version = next((v for v in model_versions if v.run_id == uat_run_id), None)
    if uat_version is None:
        uat_version = max(model_versions, key=lambda v: int(v.version))
        print(f"⚠ Warning: Could not find version by UAT run_id, using latest version")

    print(f"✓ Model registered as version: {uat_version.version}")
    print(f"✓ Linked to UAT Run ID: {uat_version.run_id}")

    # Look up the current alias holder with a single call instead of one
    # get_model_version() call per registered version. set_registered_model_alias
    # reassigns an existing alias, so no separate delete is needed (deleting
    # first would leave a window in which the alias resolves to nothing).
    try:
        previous_staging = client.get_model_version_by_alias(MODEL_NAME, STAGING_ALIAS)
    except Exception:
        previous_staging = None  # alias not assigned to any version yet
    if previous_staging is not None and str(previous_staging.version) != str(uat_version.version):
        print(f"ℹ Moving '{STAGING_ALIAS}' alias from old version {previous_staging.version}")

    # NOTE(review): the alias is set before the UAT thresholds are evaluated
    # further down, so a version that fails validation still becomes staging.
    client.set_registered_model_alias(MODEL_NAME, STAGING_ALIAS, uat_version.version)
    print(f"✓ Alias '{STAGING_ALIAS}' set on version {uat_version.version}")

    # Re-read the version and confirm the alias is actually attached
    time.sleep(1)
    verified_version = client.get_model_version(MODEL_NAME, uat_version.version)
    if STAGING_ALIAS not in (verified_version.aliases or []):
        print(f"⚠ Warning: alias '{STAGING_ALIAS}' not visible on version {uat_version.version} yet")

    print(f"\n{'='*70}")
    print("✅ MODEL REGISTRATION SUCCESSFUL")
    print(f"{'='*70}")
    print(f"Model Details:")
    print(f"  Name: {MODEL_NAME}")
    print(f"  Version: {uat_version.version}")
    print(f"  Alias: {STAGING_ALIAS}")
    print(f"  UAT Run ID: {uat_version.run_id}")
    print(f"  Status: {verified_version.status}")
    print(f"\nMetrics Available in Run:")
    print(f"  test_r2_score: {r2:.4f}")
    print(f"  test_rmse: ${rmse:,.2f}")
    # Read from training_metrics so this does not depend on the logging step
    print(f"  best_cv_rmse: ${training_metrics.get('best_cv_rmse', 0.0):,.2f}")
    print(f"{'='*70}")

except Exception as e:
    print(f"❌ Error during model registration: {e}")
    import traceback
    traceback.print_exc()
    sys.exit(1)

# =============================================================================
# VALIDATION PASS/FAIL
# =============================================================================
print(f"\n{'='*70}")
print("UAT VALIDATION RESULTS")
print(f"{'='*70}")

# Each entry: (check passed?, message when passed, message when failed)
uat_checks = [
    (
        mape <= MAX_ACCEPTABLE_MAPE,
        f"✓ PASS: MAPE {mape:.2f}% <= {MAX_ACCEPTABLE_MAPE}%",
        f"✗ FAIL: MAPE {mape:.2f}% > {MAX_ACCEPTABLE_MAPE}%",
    ),
    (
        r2 >= MIN_ACCEPTABLE_R2,
        f"✓ PASS: R² {r2:.4f} >= {MIN_ACCEPTABLE_R2}",
        f"✗ FAIL: R² {r2:.4f} < {MIN_ACCEPTABLE_R2}",
    ),
]
results = [pass_msg if ok else fail_msg for ok, pass_msg, fail_msg in uat_checks]
validation_passed = all(ok for ok, _, _ in uat_checks)

print("\n".join(f"  {res}" for res in results))

if validation_passed:
    headline = "✅ UAT VALIDATION: PASSED"
    follow_up = ["Model is ready for Production promotion"]
else:
    headline = "❌ UAT VALIDATION: FAILED"
    follow_up = ["Model does not meet quality thresholds", "Please retrain with better parameters/data"]

print(f"\n{'='*70}")
print(headline)
print(f"{'='*70}")
for follow_line in follow_up:
    print(follow_line)

# =============================================================================
# PREDICTION SAMPLE
# =============================================================================
print(f"\n{'='*70}")
print("SAMPLE PREDICTIONS")
print(f"{'='*70}")

# First 10 rows: actual label beside the prediction and its errors. Uses
# LABEL_COLUMN instead of a hard-coded 'price' so the configured label is
# honored consistently with the rest of the script.
sample_columns = [LABEL_COLUMN, 'predicted_price', 'prediction_error', 'percentage_error']
sample_df = pandas_df.head(10)[sample_columns]
print(sample_df.to_string(index=False))

# =============================================================================
# FINGERPRINT LOGIC: Check if model is different from previous run
# =============================================================================
print(f"\n{'='*70}")
print("MODEL FINGERPRINT VALIDATION")
print(f"{'='*70}")

def calculate_model_fingerprint(params_dict, metrics_dict):
    """Return an MD5 hex digest identifying a model configuration and result.

    The digest covers every entry of ``params_dict`` plus ``test_r2_score``
    (rounded to 4 decimals) and ``test_rmse`` (rounded to 2 decimals) read
    from ``metrics_dict``; a missing metric counts as 0. Keys are serialized
    in sorted order, so dict insertion order does not affect the result.
    MD5 is used only as a change detector, not for security.
    """
    rounded_metrics = {
        'r2': round(metrics_dict.get('test_r2_score', 0), 4),
        'rmse': round(metrics_dict.get('test_rmse', 0), 2),
    }
    payload = json.dumps(
        {'params': dict(sorted(params_dict.items())), 'metrics': rounded_metrics},
        sort_keys=True,
    )
    digest = hashlib.md5(payload.encode())
    return digest.hexdigest()

# Fingerprint of the model just validated: training params + rounded UAT R²/RMSE
current_fingerprint = calculate_model_fingerprint(
    training_params,
    {'test_r2_score': r2, 'test_rmse': rmse}
)

print(f"Current Model Fingerprint: {current_fingerprint}")

# Compare against the most recent row in the results table, if any
RESULTS_TABLE = f"{DATA_CATALOG_NAME}.{DATA_SCHEMA_NAME}.uat_inference_results"
previous_fingerprint = None
is_new_model = True

try:
    existing_results = spark.read.format("delta").table(RESULTS_TABLE)
    # first() returns None for an empty table, so no separate count() job
    latest_result = existing_results.orderBy(col("inference_timestamp").desc()).first()

    if latest_result is None:
        print("ℹ No previous results found - Will CREATE new table")
        is_new_model = True
    else:
        previous_fingerprint = latest_result['model_fingerprint']
        print(f"Previous Model Fingerprint: {previous_fingerprint}")

        is_new_model = current_fingerprint != previous_fingerprint
        if is_new_model:
            print("✓ Model fingerprint DIFFERENT - Will CREATE new table version")
        else:
            print("✓ Model fingerprint MATCHES previous run - Will UPDATE existing table")

except Exception as e:
    # Usually the table simply does not exist yet, but any read failure ends
    # up here and leads to an overwrite below, so report the actual cause.
    reason_lines = str(e).strip().splitlines()
    print(f"ℹ Results table doesn't exist yet - Will CREATE new table")
    print(f"  Reason: {reason_lines[0] if reason_lines else type(e).__name__}")
    is_new_model = True

# =============================================================================
# SAVE RESULTS TO DELTA TABLE (Based on Fingerprint)
# =============================================================================
print(f"\n{'='*70}")
print("SAVING UAT RESULTS TO DELTA TABLE")
print(f"{'='*70}")

try:
    # Per-row predictions enriched with run-level metadata and UAT metrics
    results_df = pandas_df.copy()
    results_df['inference_timestamp'] = datetime.now()
    # Version registered in the registration step (the only version object
    # defined in this script)
    results_df['model_version'] = uat_version.version
    results_df['model_run_id'] = uat_run_id
    results_df['model_fingerprint'] = current_fingerprint
    results_df['uat_r2_score'] = r2
    results_df['uat_rmse'] = rmse
    results_df['uat_mae'] = mae
    results_df['uat_mape'] = mape
    results_df['validation_passed'] = validation_passed

    spark_results_df = spark.createDataFrame(results_df)

    if is_new_model:
        # New fingerprint: replace the table contents. Earlier rows are
        # presumably still reachable via Delta time travel (subject to
        # retention) — NOTE(review): confirm that is the intended history.
        print("Action: OVERWRITE - Creating new table version")
        spark_results_df.write \
            .format("delta") \
            .mode("overwrite") \
            .option("overwriteSchema", "true") \
            .saveAsTable(RESULTS_TABLE)
        print(f"✓ New results saved to: {RESULTS_TABLE}")

    else:
        # Same fingerprint: keep prior rows and add this run's rows
        print("Action: APPEND - Updating existing table")
        spark_results_df.write \
            .format("delta") \
            .mode("append") \
            .saveAsTable(RESULTS_TABLE)
        print(f"✓ Results appended to: {RESULTS_TABLE}")

    result_count = spark.read.format("delta").table(RESULTS_TABLE).count()
    print(f"\nTable Statistics:")
    print(f"  Table: {RESULTS_TABLE}")
    print(f"  Total Rows: {result_count}")
    print(f"  Model Fingerprint: {current_fingerprint}")
    print(f"  Strategy: {'NEW VERSION (Overwrite)' if is_new_model else 'UPDATE (Append)'}")

except Exception as e:
    # Best-effort: a persistence failure does not change the UAT verdict
    print(f"⚠ Warning: Could not save results to Delta table: {e}")
    import traceback
    traceback.print_exc()

# =============================================================================
# EXIT
# =============================================================================
print(f"\n{'='*70}")
print("UAT INFERENCE COMPLETE")
print(f"{'='*70}")

exit_status = "PASSED" if validation_passed else "FAILED"
print(f"Status: {exit_status}")
print(f"Model Fingerprint: {current_fingerprint}")
print(f"Results Table: {RESULTS_TABLE}")

try:
    # Hand the status back to the calling Databricks job/notebook
    dbutils.notebook.exit(exit_status)
except NameError:
    # Outside Databricks `dbutils` is undefined; raise on failure so a plain
    # Python runner still reports an error. Only NameError is caught so that
    # whatever dbutils.notebook.exit itself raises is not swallowed
    # (NOTE(review): it may use an exception internally to stop the notebook).
    if not validation_passed:
        raise Exception("UAT Validation Failed: Model does not meet quality thresholds")