In [None]:
import mlflow
from mlflow.tracking import MlflowClient
from pyspark.sql import SparkSession
from pyspark.sql.functions import col
import pandas as pd
import numpy as np
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
import sys
import os
import warnings
import re # New import for regex

# Silence Python warnings for this run: install an "ignore" filter and also
# publish the same preference through the PYTHONWARNINGS environment variable.
warnings.filterwarnings("ignore")
os.environ['PYTHONWARNINGS'] = 'ignore'

# Report header (three lines: rule, title, rule)
print("\n".join(["=" * 70, "UAT MODEL INFERENCE - STAGING VALIDATION", "=" * 70]))

# =============================================================================
# CONFIGURATION
# =============================================================================
# Registered model: <catalog>.<schema>.<model> in Unity Catalog
UC_CATALOG_NAME = "workspace"
UC_SCHEMA_NAME = "ml"
MODEL_NAME = ".".join([UC_CATALOG_NAME, UC_SCHEMA_NAME, "house_price_model_uc"])

# Delta table holding the data to score: <catalog>.<schema>.<table>
DATA_CATALOG_NAME = "workspace"
DATA_SCHEMA_NAME = "default"
TABLE_NAME = "house_price_delta"
FULL_TABLE_NAME = ".".join([DATA_CATALOG_NAME, DATA_SCHEMA_NAME, TABLE_NAME])

# MLflow experiment whose runs are searched for the model
EXPERIMENT_NAME = "/Shared/House_Price_Prediction_Delta_RF"

# Model inputs (in the order passed to the model) and the target column
FEATURE_COLUMNS = [
    'sq_feet',
    'num_bedrooms',
    'num_bathrooms',
    'year_built',
    'location_score',
]
LABEL_COLUMN = 'price'

# UAT pass/fail thresholds
MAX_ACCEPTABLE_MAPE = 15.0  # mean absolute percentage error must not exceed 15%
MIN_ACCEPTABLE_R2 = 0.75    # R2 score must be at least 0.75

# =============================================================================
# SPARK SESSION INITIALIZATION
# =============================================================================
# Obtain the Spark session through the builder; any failure is fatal because
# every later step reads or writes through `spark`.
try:
    spark = (
        SparkSession.builder
        .appName("UAT_ModelInference")
        .getOrCreate()
    )
    print("Spark session initialized")
except Exception as spark_error:
    print(f"Error initializing Spark: {spark_error}")
    sys.exit(1)

# =============================================================================
# GET MODEL ALIAS FROM WIDGET
# =============================================================================
# Read the target alias from the notebook widget. Anything that prevents a
# usable value (no `dbutils` outside Databricks, widget not defined, blank
# value) falls back to "Staging". `except Exception` is used instead of a bare
# `except` so KeyboardInterrupt / SystemExit are not swallowed.
try:
    model_alias = (dbutils.widgets.get("alias") or "").strip()
except Exception:
    model_alias = ""

if model_alias:
    print(f"Model Alias from widget: {model_alias}")
else:
    model_alias = "Staging"
    print(f"Widget not found, using default: {model_alias}")

# =============================================================================
# LOAD MODEL FROM MLFLOW RUN (Community Edition Compatible)
# =============================================================================
print(f"\nLoading model for UAT validation...")
print(f"Target Alias: {model_alias}")

# NOTE: the alias is only reported here; the model itself is taken from the
# most recently started FINISHED run of EXPERIMENT_NAME.
try:
    client = MlflowClient()

    # Resolve the experiment by name; abort when it does not exist
    experiment = client.get_experiment_by_name(EXPERIMENT_NAME)
    if not experiment:
        print(f"Error: Experiment '{EXPERIMENT_NAME}' not found")
        sys.exit(1)

    print(f"Experiment: {experiment.name}")

    # Ask for a single run: finished, newest start time first
    runs = client.search_runs(
        experiment_ids=[experiment.experiment_id],
        filter_string="status = 'FINISHED'",
        order_by=["start_time DESC"],
        max_results=1,
    )

    if not runs:
        print("Error: No successful runs found")
        print("Please run the training script first")
        sys.exit(1)

    latest_run = runs[0]
    run_id = latest_run.info.run_id

    print(f"\nModel Details:")
    print(f" 	Run ID: {run_id}")
    print(f" 	Run Name: {latest_run.info.run_name}")

    # Echo the parameters logged by the training run
    print(f"\n 	Training Parameters:")
    for param_name, param_value in latest_run.data.params.items():
        print(f" 	 	{param_name}: {param_value}")

    # Echo the logged metrics and keep a copy for the save step later on
    print(f"\n 	Training Metrics:")
    training_metrics = {}
    for metric_name, metric_value in latest_run.data.metrics.items():
        print(f" 	 	{metric_name}: {metric_value:.4f}")
        training_metrics[metric_name] = metric_value

    # The model artifact is addressed relative to the run
    model_uri = "runs:/{}/sklearn_rf_model".format(run_id)
    print(f"\nLoading model from: {model_uri}")

    model = mlflow.sklearn.load_model(model_uri)
    print("Model loaded successfully for UAT validation")

except Exception as load_error:
    # sys.exit() above raises SystemExit, which is not an Exception and so
    # passes through this handler untouched.
    print(f"\nError loading model: {load_error}")
    import traceback
    traceback.print_exc()
    sys.exit(1)

# =============================================================================
# LOAD TEST DATA FROM DELTA TABLE
# =============================================================================
print(f"\nLoading test data from: {FULL_TABLE_NAME}")

try:
    # Read the Delta table and report its size
    spark_df = spark.read.format("delta").table(FULL_TABLE_NAME)
    row_count = spark_df.count()

    print(f"Data loaded: {row_count} rows")

    # Every feature plus the label must be present before going further
    available_columns = spark_df.columns
    required_columns = FEATURE_COLUMNS + [LABEL_COLUMN]
    missing_columns = [
        name for name in required_columns if name not in available_columns
    ]

    if missing_columns:
        print(f"\nError: Missing columns: {missing_columns}")
        print(f"Available columns: {available_columns}")
        sys.exit(1)

    print("All required columns present")

    # Collect only the needed columns into a pandas DataFrame for inference
    pandas_df = spark_df.select(*required_columns).toPandas()
    print(f"Converted to Pandas: {pandas_df.shape}")

except Exception as data_error:
    print(f"\nError loading data: {data_error}")
    sys.exit(1)

# =============================================================================
# MAKE PREDICTIONS
# =============================================================================
print(f"\n{'=' * 70}")
print("RUNNING INFERENCE ON UAT DATA")
print(f"{'=' * 70}")

try:
    # Split the frame into model inputs and the observed target
    y_actual = pandas_df[LABEL_COLUMN]
    X_test = pandas_df[FEATURE_COLUMNS]

    # Score every row with the loaded model
    y_predicted = model.predict(X_test)

    print(f"\nPredictions completed: {len(y_predicted)} samples")

    # Attach the prediction and per-row error columns to the same frame:
    #   prediction_error = actual - predicted (signed)
    #   absolute_error   = |prediction_error|
    #   percentage_error = absolute_error as a percentage of the actual value
    pandas_df['predicted_price'] = y_predicted
    pandas_df['prediction_error'] = y_actual - y_predicted
    pandas_df['absolute_error'] = pandas_df['prediction_error'].abs()
    pandas_df['percentage_error'] = 100 * (pandas_df['absolute_error'] / y_actual)

except Exception as predict_error:
    print(f"\nError during prediction: {predict_error}")
    import traceback
    traceback.print_exc()
    sys.exit(1)

# =============================================================================
# CALCULATE PERFORMANCE METRICS
# =============================================================================
print(f"\n{'=' * 70}")
print("UAT VALIDATION METRICS")
print(f"{'=' * 70}")

try:
    # Headline regression metrics over the whole UAT set
    r2 = r2_score(y_actual, y_predicted)
    mae = mean_absolute_error(y_actual, y_predicted)
    rmse = np.sqrt(mean_squared_error(y_actual, y_predicted))
    # Mean of |actual - predicted| expressed as a percentage of actual
    mape = ((y_actual - y_predicted).abs() / y_actual * 100).mean()

    # Distribution of the per-row absolute errors computed during inference
    absolute_errors = pandas_df['absolute_error']
    min_error = absolute_errors.min()
    max_error = absolute_errors.max()
    median_error = absolute_errors.median()

    # Report
    print(f"\nRegression Metrics:")
    print(f" 	Mean Absolute Error (MAE): 	${mae:,.2f}")
    print(f" 	Root Mean Squared Error: 	 	 ${rmse:,.2f}")
    print(f" 	R² Score: 	 	 	 	 	 	{r2:.4f}")
    print(f" 	Mean Absolute % Error: 	 	 	 {mape:.2f}%")

    print(f"\nError Statistics:")
    print(f" 	Median Absolute Error: 	 	 	 ${median_error:,.2f}")
    print(f" 	Maximum Error: 	 	 	 	 	 ${max_error:,.2f}")
    print(f" 	Minimum Error: 	 	 	 	 	 ${min_error:,.2f}")

    print(f"\nPrediction Statistics:")
    print(f" 	Actual Price Range: 	 ${y_actual.min():,.2f} - ${y_actual.max():,.2f}")
    print(f" 	Predicted Range: 	 	${y_predicted.min():,.2f} - ${y_predicted.max():,.2f}")
    print(f" 	Mean Actual Price: 	 	${y_actual.mean():,.2f}")
    print(f" 	Mean Predicted Price: ${y_predicted.mean():,.2f}")

except Exception as metrics_error:
    print(f"\nError calculating metrics: {metrics_error}")
    sys.exit(1)

# =============================================================================
# UAT VALIDATION - PASS/FAIL CRITERIA
# =============================================================================
print(f"\n{'=' * 70}")
print("UAT VALIDATION RESULTS")
print(f"{'=' * 70}")

# The run passes only if every criterion below passes. Each check appends a
# human-readable "PASS: ..." / "FAIL: ..." line to validation_results.
validation_passed = True
validation_results = []

# Criterion 1: MAPE must not exceed the configured maximum
if mape <= MAX_ACCEPTABLE_MAPE:
    validation_results.append(f"PASS: MAPE {mape:.2f}% <= {MAX_ACCEPTABLE_MAPE}%")
else:
    validation_results.append(f"FAIL: MAPE {mape:.2f}% > {MAX_ACCEPTABLE_MAPE}%")
    validation_passed = False

# Criterion 2: R² must reach the configured minimum
if r2 >= MIN_ACCEPTABLE_R2:
    validation_results.append(f"PASS: R² {r2:.4f} >= {MIN_ACCEPTABLE_R2}")
else:
    validation_results.append(f"FAIL: R² {r2:.4f} < {MIN_ACCEPTABLE_R2}")
    validation_passed = False

# Print every criterion; the PASS/FAIL prefix is already part of each line,
# so no per-status branching is needed.
print(f"\nValidation Criteria:")
for result in validation_results:
    print(f" 	{result}")

# Final verdict
print(f"\n{'=' * 70}")
if validation_passed:
    print("UAT VALIDATION: PASSED")
    print("Model is ready for promotion to Production")
else:
    print("UAT VALIDATION: FAILED")
    print("Model does not meet quality thresholds")
    print("Please retrain with better parameters or more data")
print(f"{'=' * 70}")

# =============================================================================
# DISPLAY SAMPLE PREDICTIONS
# =============================================================================
print(f"\nSample Predictions (First 10 rows):")

# Show the inputs, the actual label and the prediction/error columns for the
# first ten rows. The label column is referenced through LABEL_COLUMN so this
# section keeps working if the configured label name changes.
display_columns = [
    *FEATURE_COLUMNS,
    LABEL_COLUMN,
    'predicted_price',
    'absolute_error',
    'percentage_error',
]
sample_df = pandas_df[display_columns].head(10).copy()

# Round the monetary / percentage columns for readability (no $ sign added)
rounded_columns = [LABEL_COLUMN, 'predicted_price', 'absolute_error', 'percentage_error']
sample_df[rounded_columns] = sample_df[rounded_columns].round(2)

print(sample_df.to_string(index=False))

# =============================================================================
# INTELLIGENT SAVE - PERFORMANCE-BASED DEDUPLICATION CHECK (MODIFIED LOGIC)
# =============================================================================
# Set to False to skip persisting the UAT results
SAVE_RESULTS = True

# Base name of the versioned results tables (e.g. uat_predictions_staging;
# versions are stored as <base>_1, <base>_2, ...). The alias comes from a
# notebook widget and ends up inside table identifiers and SQL text, so it is
# reduced to lowercase letters, digits and underscores first.
BASE_TABLE_NAME = "uat_predictions_" + re.sub(r"[^a-z0-9_]", "_", model_alias.lower())
BASE_OUTPUT_TABLE_PATH = f"{DATA_CATALOG_NAME}.{DATA_SCHEMA_NAME}"

# Helper function to find the latest table version and its fingerprint
def get_latest_version_info(spark, base_name, catalog, schema):
    """Find the highest-numbered results table and its last saved fingerprint.

    Tables are expected to be named ``<base_name>_<N>`` inside
    ``<catalog>.<schema>``. The unversioned ``<base_name>`` table counts as
    version 0 and is used when no numbered table is found.

    Args:
        spark: Spark session used to list and read tables.
        base_name: Table name prefix without the ``_<N>`` suffix.
        catalog: Catalog that holds the results tables.
        schema: Schema that holds the results tables.

    Returns:
        Tuple ``(latest_version, latest_table_name, latest_fingerprint)``.
        ``latest_table_name`` is fully qualified. ``latest_fingerprint`` is
        the ``run_fingerprint`` of the row with the newest ``saved_timestamp``
        in that table, or ``None`` when the table cannot be read or is empty.
        If listing the tables fails, ``(0, <unversioned table>, None)`` is
        returned.

    Note:
        ``base_name``, ``catalog`` and ``schema`` are interpolated into SQL
        text as-is; callers must pass trusted identifiers.
    """
    base_table = f"{catalog}.{schema}.{base_name}"

    try:
        # SHOW TABLES ... LIKE takes '*' (not the SQL '%') as its
        # "any characters" wildcard; with '%' no versioned table is matched.
        table_list_df = spark.sql(
            f"SHOW TABLES IN {catalog}.{schema} LIKE '{base_name}*'"
        ).toPandas()

        # Map version number -> fully qualified name for every <base>_<N>.
        # re.escape keeps characters in base_name from acting as regex syntax.
        version_pattern = re.compile(
            rf"^{re.escape(base_name)}_(\d+)$", re.IGNORECASE
        )
        version_map = {}
        for table_name in table_list_df.get('tableName', []):
            match = version_pattern.match(table_name)
            if match:
                version_map[int(match.group(1))] = f"{catalog}.{schema}.{table_name}"

        # Version 0 always means the unversioned base table
        latest_version = max(version_map, default=0)
        if latest_version > 0:
            latest_table_name = version_map[latest_version]
        else:
            latest_table_name = base_table

        # Fetch the fingerprint of the most recently saved row, if any
        latest_fingerprint = None
        try:
            # Sort before projecting so the sort column is still available
            newest_rows = (
                spark.read.format("delta").table(latest_table_name)
                .orderBy("saved_timestamp", ascending=False)
                .select("run_fingerprint")
                .limit(1)
                .collect()
            )
            if newest_rows:
                latest_fingerprint = newest_rows[0]['run_fingerprint']
        except Exception as read_error:
            # Table is missing, unreadable, or lacks the expected columns
            print(f"Warning: Could not read fingerprint from {latest_table_name}: {read_error}")

        return latest_version, latest_table_name, latest_fingerprint

    except Exception as e:
        print(f"Error listing tables in {catalog}.{schema}: {e}")
        # Fallback if listing fails: assume only the base table (version 0)
        return 0, base_table, None


if SAVE_RESULTS:
    import hashlib
    import json
    from datetime import datetime

    def _param_as_int(name, default):
        """Return the run parameter *name* as an int.

        Falls back to *default* when the parameter is missing or its logged
        value is not an integer literal (for example the string "None").
        """
        raw_value = latest_run.data.params.get(name)
        try:
            return int(raw_value)
        except (TypeError, ValueError):
            return default

    print(f"\nChecking if results need to be saved...")

    # Saving is best-effort: any failure is reported but does not stop the run.
    try:
        # 1. Fingerprint of this run: model parameters, training metrics,
        #    UAT metrics, data size, features, alias and run id. Two runs with
        #    the same fingerprint are treated as functionally identical.
        fingerprint_data = {
            'model_params': {
                'best_n_estimators': latest_run.data.params.get('best_n_estimators'),
                'best_max_depth': latest_run.data.params.get('best_max_depth'),
                'best_min_samples_split': latest_run.data.params.get('best_min_samples_split'),
                'best_min_samples_leaf': latest_run.data.params.get('best_min_samples_leaf')
            },
            'training_metrics': {
                'test_rmse': round(training_metrics.get('test_rmse', 0), 2),
                'test_r2_score': round(training_metrics.get('test_r2_score', 0), 4),
                'best_cv_rmse': round(training_metrics.get('best_cv_rmse', 0), 2)
            },
            'uat_metrics': {
                'rmse': round(rmse, 2),
                'r2': round(r2, 4),
                'mae': round(mae, 2),
                'mape': round(mape, 2)
            },
            'data_size': len(pandas_df),
            'feature_columns': sorted(FEATURE_COLUMNS),
            'alias': model_alias,
            'run_id': run_id
        }

        # Hash a canonical (key-sorted) JSON form so the digest is stable.
        # MD5 serves only as a change-detection checksum here.
        fingerprint_str = json.dumps(fingerprint_data, sort_keys=True)
        current_fingerprint = hashlib.md5(fingerprint_str.encode()).hexdigest()

        print(f"Current Run Fingerprint: {current_fingerprint}")

        # 2. Look up the newest existing results table and decide where to write
        latest_version, latest_table_name, latest_fingerprint = get_latest_version_info(
            spark, BASE_TABLE_NAME, DATA_CATALOG_NAME, DATA_SCHEMA_NAME
        )

        print(f"Latest existing table found: {latest_table_name} (Version {latest_version})")
        print(f"Latest saved fingerprint: {latest_fingerprint}")

        if latest_fingerprint is None:
            # Case 1: nothing to compare against. Write to the version after
            # the highest one found, so an existing numbered table whose
            # fingerprint could not be read is never appended to. With no
            # numbered tables this is version 1.
            new_version = latest_version + 1
            output_table = f"{BASE_OUTPUT_TABLE_PATH}.{BASE_TABLE_NAME}_{new_version}"
            save_mode = "append"  # append creates the table when it is absent
            if latest_version == 0:
                change_reason = "Initial save: creating version 1"
                print("\nACTION: Creating first version table.")
            else:
                change_reason = (
                    f"No fingerprint readable from version {latest_version}. "
                    f"Creating new table version {new_version}."
                )
                print(f"\nACTION: Creating new version table ({new_version}). Previous fingerprint unavailable.")

        elif current_fingerprint == latest_fingerprint:
            # Case 2: same parameters and performance -> refresh the latest table
            output_table = latest_table_name
            save_mode = "overwrite"
            change_reason = f"Identical performance/params. Overwriting existing table {latest_version}."
            print(f"\nACTION: Overwriting table {latest_version}. No functional change detected.")

        else:
            # Case 3: parameters or performance changed -> start a new version
            new_version = latest_version + 1
            output_table = f"{BASE_OUTPUT_TABLE_PATH}.{BASE_TABLE_NAME}_{new_version}"
            save_mode = "append"  # append creates the table when it is absent
            change_reason = f"Performance or parameters changed. Creating new table version {new_version}."
            print(f"\nACTION: Creating new version table ({new_version}). Change detected.")

        print(f"\n{'='*70}")
        print(f"SAVING RESULTS")
        print(f"{'='*70}")
        print(f"Target Table: {output_table}")
        print(f"Save Mode: {save_mode}")
        print(f"Reason: {change_reason}")

        # 3. Stamp every row with run metadata, then write.
        #    One timestamp is taken so the stored and the reported values agree.
        saved_at = datetime.now()
        pandas_df['run_fingerprint'] = current_fingerprint
        pandas_df['run_id'] = run_id
        pandas_df['saved_timestamp'] = saved_at
        pandas_df['model_alias'] = model_alias

        # Model hyper-parameters (the default is used when a value is missing
        # or not an integer literal)
        pandas_df['best_n_estimators'] = _param_as_int('best_n_estimators', 0)
        pandas_df['best_max_depth'] = _param_as_int('best_max_depth', 0)
        pandas_df['best_min_samples_split'] = _param_as_int('best_min_samples_split', 2)
        pandas_df['best_min_samples_leaf'] = _param_as_int('best_min_samples_leaf', 1)

        # Metrics logged by the training run
        pandas_df['training_test_rmse'] = training_metrics.get('test_rmse', 0)
        pandas_df['training_test_r2'] = training_metrics.get('test_r2_score', 0)
        pandas_df['training_cv_rmse'] = training_metrics.get('best_cv_rmse', 0)

        # Metrics measured in this UAT run
        pandas_df['uat_rmse'] = round(rmse, 2)
        pandas_df['uat_r2'] = round(r2, 4)
        pandas_df['uat_mae'] = round(mae, 2)
        pandas_df['uat_mape'] = round(mape, 2)

        pandas_df['validation_status'] = 'PASSED' if validation_passed else 'FAILED'
        pandas_df['change_reason'] = change_reason

        # Convert to Spark and write with the mode chosen above
        result_spark_df = spark.createDataFrame(pandas_df)

        result_spark_df.write \
            .format("delta") \
            .mode(save_mode) \
            .option("overwriteSchema", "true") \
            .saveAsTable(output_table)

        print(f"\n✅ UAT results saved successfully!")
        print(f" 	Run ID: {run_id}")
        print(f" 	Timestamp: {saved_at.strftime('%Y-%m-%d %H:%M:%S')}")
        print(f" 	Records: {len(pandas_df)}")
        print(f" 	Fingerprint: {current_fingerprint}")

    except Exception as e:
        print(f"\n⚠️ Warning: Could not save results: {e}")
        import traceback
        traceback.print_exc()

# =============================================================================
# DISPLAY IN DATABRICKS (if available)
# =============================================================================
# Render the full result frame with the notebook's display() helper. Only a
# NameError (e.g. display is not defined) is tolerated here.
try:
    result_spark_df = spark.createDataFrame(pandas_df)
    display(result_spark_df)
except NameError:
    print("\nNote: display() not available outside Databricks notebook")

# =============================================================================
# EXIT WITH APPROPRIATE CODE
# =============================================================================
print(f"\n{'=' * 70}")
print("UAT INFERENCE COMPLETE")
print(f"{'=' * 70}")

# NOTE(review): the bare `except:` clauses below are kept as-is on purpose;
# how dbutils.notebook.exit signals the exit is not visible here, so narrowing
# them could change the outcome of the job — confirm before tightening.
if validation_passed:
    print("\nModel validated successfully for promotion to Production")
    # Report success to the notebook runner when dbutils is usable;
    # otherwise simply fall through and finish normally.
    try:
        dbutils.notebook.exit("PASSED")
    except:
        pass
else:
    print("\nValidation failed - Model needs improvement")
    # Report failure through dbutils; if that call raises for any reason,
    # fail the script with an exception instead.
    try:
        dbutils.notebook.exit("FAILED")
    except:
        raise Exception("UAT Validation Failed: Model does not meet quality thresholds")
