In [None]:
# Databricks notebook source
# =============================================================================
# 🧪 UAT MODEL INFERENCE - CONFIG DRIVEN (COMPLETE FIXED VERSION)
# =============================================================================
# Purpose: Validate staging model performance on UAT data
# Now reads from pipeline_config.yml - No hardcoding!
# Prerequisites: Run 04_uat_staging.py first
# =============================================================================

# COMMAND ----------
%pip install xgboost requests

# COMMAND ----------
# Restart the Python process so the packages installed by the %pip cell above are importable.
dbutils.library.restartPython()

# COMMAND ----------
import mlflow
from mlflow.tracking import MlflowClient
import pandas as pd
import numpy as np
import math
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from pyspark.sql import SparkSession
from datetime import datetime
import warnings
import sys
import os
import requests
import traceback
import yaml

warnings.filterwarnings("ignore")

print("=" * 80)
print("üß™ UAT MODEL INFERENCE (CONFIG-DRIVEN)")
print("=" * 80)

# =============================================================================
# LOAD PIPELINE CONFIGURATION
# -----------------------------------------------------------------------------
# pipeline_config.yml is looked up next to this file first and, if it is not
# there, in <project_root>/dev_env/. Every module-level setting used below is
# read from it; a missing file or any loading error exits with status 1.
# =============================================================================
print("\nüìã Loading pipeline configuration from pipeline_config.yml...")

try:
    # __file__ is undefined when this runs as a notebook cell; use the CWD then.
    try:
        current_dir = os.path.dirname(os.path.abspath(__file__))
    except NameError:
        current_dir = os.getcwd()
    project_root = os.path.abspath(os.path.join(current_dir, ".."))

    local_config = os.path.join(current_dir, "pipeline_config.yml")
    config_path = (
        local_config
        if os.path.exists(local_config)
        else os.path.join(project_root, "dev_env", "pipeline_config.yml")
    )

    with open(config_path, "r") as cfg_file:
        pipeline_cfg = yaml.safe_load(cfg_file)

    print(f"‚úÖ Loaded pipeline_config.yml from: {config_path}")

    # Model identity; the UC name follows <catalog>.<schema>.<base>_<type>_uc2.
    model_cfg = pipeline_cfg["model"]
    MODEL_TYPE = model_cfg["type"]
    UC_CATALOG = model_cfg["catalog"]
    UC_SCHEMA = model_cfg["schema"]
    BASE_NAME = model_cfg["base_name"]
    MODEL_NAME = f"{UC_CATALOG}.{UC_SCHEMA}.{BASE_NAME}_{MODEL_TYPE}_uc2"

    STAGING_ALIAS = pipeline_cfg["aliases"]["staging"]

    # Input data: source table, feature columns and target column.
    data_cfg = pipeline_cfg["data"]
    DELTA_INPUT_TABLE = data_cfg["input_table"]
    FEATURE_COLS = data_cfg["features"]
    LABEL_COL = data_cfg["label"]

    # Pass/fail gates applied by validate_uat().
    uat_cfg = pipeline_cfg["uat"]
    MAPE_THRESHOLD = uat_cfg["mape_threshold"]
    R2_THRESHOLD = uat_cfg["r2_threshold"]

    # Delta table that receives one row per UAT run.
    OUTPUT_TABLE = pipeline_cfg["tables"]["uat_results"]

    print(f"‚úÖ Configuration loaded successfully!")
    print(f"\nüìä Configuration Details:")
    for _label, _value in (
        ("Model Type", MODEL_TYPE.upper()),
        ("Model Name", MODEL_NAME),
        ("Staging Alias", f"@{STAGING_ALIAS}"),
        ("Input Table", DELTA_INPUT_TABLE),
        ("Output Table", OUTPUT_TABLE),
        ("Features", FEATURE_COLS),
        ("Label", LABEL_COL),
        ("MAPE Threshold", f"‚â§ {MAPE_THRESHOLD}%"),
        ("R¬≤ Threshold", f"‚â• {R2_THRESHOLD}"),
    ):
        print(f"   {_label}: {_value}")

except FileNotFoundError:
    print("‚ùå ERROR: pipeline_config.yml not found!")
    print("üí° Please create pipeline_config.yml in the same directory or in dev_env/")
    sys.exit(1)
except Exception as e:
    print(f"‚ùå ERROR loading configuration: {e}")
    traceback.print_exc()
    sys.exit(1)

print("=" * 80)

# =============================================================================
# SLACK NOTIFICATION SETUP
# =============================================================================
def get_slack_webhook():
    """Look up the Slack webhook URL in the Databricks secret scopes.

    Scopes are tried in order ("shared-scope", then "dev-scope"). The first
    non-blank value of the ``SLACK_WEBHOOK_URL`` secret wins. Any lookup error
    is printed and the next scope is tried.

    Returns:
        The webhook URL string, or ``None`` when no scope provides one.
    """
    for scope_name in ("shared-scope", "dev-scope"):
        try:
            url = dbutils.secrets.get(scope_name, "SLACK_WEBHOOK_URL")
            if not (url and url.strip()):
                continue
            print(f"‚úì Slack webhook configured from scope '{scope_name}'")
            return url
        except Exception as err:
            print(f"‚ö†Ô∏è Slack webhook not found in scope '{scope_name}': {err}")
    return None


# Resolved once at load time; send_slack_notification() reads this global.
SLACK_WEBHOOK_URL = get_slack_webhook()

def send_slack_notification(message, level="info"):
    """Post ``message`` to Slack through the configured incoming webhook.

    The text is prefixed with an emoji chosen by ``level`` ("info", "success",
    "warning", "error"); unknown levels use the info emoji. When no webhook
    is configured the message is only printed. Delivery problems (non-200
    response or request error) are printed and never raised.
    """
    if not SLACK_WEBHOOK_URL:
        print(f"‚ö†Ô∏è Slack webhook not configured")
        print(f"üì¢ Message: {message}")
        return

    level_prefixes = {
        "info": "‚ÑπÔ∏è",
        "success": "‚úÖ",
        "warning": "‚ö†Ô∏è",
        "error": "‚ùå",
    }
    prefix = level_prefixes.get(level, "‚ÑπÔ∏è")
    payload = {"text": f"{prefix} {message}"}

    try:
        reply = requests.post(SLACK_WEBHOOK_URL, json=payload, timeout=5)
        delivered = reply.status_code == 200
        if delivered:
            print(f"‚úÖ Slack notification sent: {level}")
        else:
            print(f"‚ö†Ô∏è Slack notification failed: {reply.status_code}")
    except Exception as e:
        print(f"‚ö†Ô∏è Error sending Slack notification: {e}")

# =============================================================================
# INITIALIZATION
# =============================================================================
# Reuse the active Spark session (or create one) for all table I/O.
spark = (
    SparkSession.builder
    .appName("UAT_Inference")
    .getOrCreate()
)

# The registry URI must be set before the client is constructed: MlflowClient()
# without arguments picks up the registry URI configured here (Unity Catalog).
mlflow.set_registry_uri("databricks-uc")
client = MlflowClient()

print("\n‚úÖ MLflow and Spark initialized")

# =============================================================================
# STEP 1: LOAD MODEL FROM STAGING ALIAS
# =============================================================================
def _find_staging_version_by_search():
    """Return the highest-numbered version of MODEL_NAME carrying STAGING_ALIAS.

    This is the fallback used by ``load_staging_model`` when the direct alias
    lookup raises. Alias names are compared case-insensitively.

    Each version returned by ``search_model_versions`` is re-fetched once with
    ``get_model_version`` to read its aliases. The search results may not carry
    aliases; verify this against the MLflow client version in use. The fetched
    objects are reused for the diagnostic listing when nothing matches.

    Raises:
        ValueError: if the model has no versions, or none carries the alias.
    """
    model_versions = client.search_model_versions(f"name='{MODEL_NAME}'")
    if not model_versions:
        raise ValueError(
            f"‚ùå No model versions found for {MODEL_NAME}\n"
            f"üí° Solution: Run Model_Registration script first"
        )

    print(f"\nüîç Searching through {len(model_versions)} version(s)...")
    wanted_alias = STAGING_ALIAS.lower()
    fetched_versions = []
    staging_versions = []
    for v in model_versions:
        full_version = client.get_model_version(MODEL_NAME, v.version)
        fetched_versions.append(full_version)
        version_aliases = full_version.aliases or []
        if any(alias.lower() == wanted_alias for alias in version_aliases):
            staging_versions.append(full_version)
            print(f"   ‚úì Version v{v.version} has @{STAGING_ALIAS} alias")

    if not staging_versions:
        # List what *is* registered so the misconfiguration is easy to spot.
        print(f"\n‚ùå No model with alias '@{STAGING_ALIAS}' found!")
        print(f"\nüìã Available versions for {MODEL_NAME}:")
        for full_v in fetched_versions[:10]:
            v_aliases = full_v.aliases if full_v.aliases else ["No aliases"]
            print(f"   Version v{full_v.version}: Aliases = {v_aliases}")
        raise ValueError(
            f"\n‚ùå No model with alias '@{STAGING_ALIAS}' found\n"
            f"üí° Solution: Run 04_uat_staging.py first"
        )

    latest = max(staging_versions, key=lambda mv: int(mv.version))
    print(f"\n‚úÖ Found {len(staging_versions)} version(s) with @{STAGING_ALIAS} alias")
    print(f"   Loading latest: v{latest.version}")
    return latest


def load_staging_model():
    """Load the model version currently referenced by the staging alias.

    The version is resolved in one of two ways:

    1. ``client.get_model_version_by_alias(MODEL_NAME, STAGING_ALIAS)``.
    2. If that raises, the highest version whose aliases match
       ``STAGING_ALIAS`` case-insensitively (see
       ``_find_staging_version_by_search``).

    The model is then loaded by its resolved version number
    (``models:/<name>/<version>``) rather than by alias. This means the
    fallback works even when the registered alias differs only in case, and
    the loaded artifact is exactly the version that was reported.

    Returns:
        tuple: ``(model, version, run_id)`` — the pyfunc model, the version
        identifier and the run ID recorded on that version.

    Raises:
        ValueError: if no versions exist or none carries the staging alias.
        Exception: any MLflow lookup/load error is re-raised after
            troubleshooting hints are printed.
    """
    print(f"\n{'='*80}")
    print(f"üìã STEP 1: Loading Model from @{STAGING_ALIAS}")
    print(f"{'='*80}")

    try:
        print(f"‚è≥ Attempting to load: models:/{MODEL_NAME}@{STAGING_ALIAS}")

        try:
            model_version = client.get_model_version_by_alias(MODEL_NAME, STAGING_ALIAS)
            print(f"‚úÖ Found model with @{STAGING_ALIAS} alias")
            print(f"   Version: v{model_version.version}")
            print(f"   Run ID: {model_version.run_id}")
        except Exception as e:
            print(f"‚ö†Ô∏è Direct alias lookup failed: {e}")
            print(f"   Trying alternative search method...")
            model_version = _find_staging_version_by_search()

        version = model_version.version
        run_id = model_version.run_id

        # Pin the load to the version resolved above. Re-resolving the alias
        # would fail when the fallback matched it case-insensitively, and could
        # pick up a different version if the alias is moved in between.
        model_uri = f"models:/{MODEL_NAME}/{version}"
        print(f"\n‚è≥ Loading model...")
        print(f"   URI: {model_uri}")
        model = mlflow.pyfunc.load_model(model_uri)

        print(f"\n{'='*80}")
        print("‚úÖ MODEL LOADED SUCCESSFULLY")
        print(f"{'='*80}")
        print(f"   Model: {MODEL_NAME}")
        print(f"   Model Type: {MODEL_TYPE.upper()}")
        print(f"   Version: v{version}")
        print(f"   Run ID: {run_id}")
        print(f"   Status: {model_version.status}")

        # Training RMSE is only available if the registration step tagged it.
        metric_tag = (model_version.tags or {}).get("metric_rmse", "N/A")
        print(f"   Training RMSE: {metric_tag}")
        print(f"{'='*80}\n")

        return model, version, run_id

    except Exception as e:
        print(f"\n{'='*80}")
        print("‚ùå FAILED TO LOAD MODEL")
        print(f"{'='*80}")
        print(f"Error: {e}")
        print(f"\nüí° Troubleshooting Steps:")
        print(f"   1. Verify model exists: {MODEL_NAME}")
        print(f"   2. Run 04_uat_staging.py to promote a model to @{STAGING_ALIAS}")
        print(f"   3. Verify alias is exactly '{STAGING_ALIAS}' (case-sensitive)")
        print(f"{'='*80}\n")
        traceback.print_exc()
        raise

# =============================================================================
# STEP 2: LOAD UAT DATA
# =============================================================================
def load_uat_data():
    """Read DELTA_INPUT_TABLE into pandas and split out features and label.

    The whole table is collected onto the driver with ``toPandas()``.

    Returns:
        tuple: ``(df, X, y_true)`` — the full pandas DataFrame, the
        ``FEATURE_COLS`` sub-frame and the ``LABEL_COL`` column.

    Raises:
        ValueError: if a configured feature column or the label column is
            missing from the table.
        Exception: Spark errors (e.g. table not found) are re-raised after a
            diagnostic message is printed.
    """
    rule = "=" * 80
    print(f"\n{rule}")
    print("üìã STEP 2: Loading UAT Data")
    print(rule)

    try:
        print(f"   Loading from: {DELTA_INPUT_TABLE}")
        df = spark.table(DELTA_INPUT_TABLE).toPandas()

        print(f"   Total rows: {len(df)}")
        print(f"   Columns: {list(df.columns)}")

        # Fail fast with a clear message before indexing by column name.
        absent_features = [name for name in FEATURE_COLS if name not in df.columns]
        if absent_features:
            raise ValueError(f"Missing feature columns: {absent_features}")
        if LABEL_COL not in df.columns:
            raise ValueError(f"Missing label column: {LABEL_COL}")

        X = df[FEATURE_COLS]
        y_true = df[LABEL_COL]

        print(f"\n{rule}")
        print("‚úÖ DATA LOADED SUCCESSFULLY")
        print(rule)
        print(f"   Features shape: {X.shape}")
        print(f"   Labels shape: {y_true.shape}")
        print(f"   Sample features:\n{X.head(3)}")
        print(f"{rule}\n")

        return df, X, y_true

    except Exception as exc:
        reason = str(exc)
        print(f"\n{rule}")
        print("‚ùå FAILED TO LOAD DATA")
        print(rule)

        # Spark reports a missing table with one of these two phrasings.
        table_missing = "TABLE_OR_VIEW_NOT_FOUND" in reason or "cannot be found" in reason
        if not table_missing:
            print(f"   Error: {exc}")
        else:
            print(f"   Delta table '{DELTA_INPUT_TABLE}' does not exist")
            print(f"\nüí° Solution:")
            print(f"   1. Create the table first")
            print(f"   2. Verify the table name in pipeline_config.yml")

        print(f"{rule}\n")
        traceback.print_exc()
        raise

# =============================================================================
# STEP 3: RUN INFERENCE
# =============================================================================
def run_inference(model, X):
    """Run ``model.predict`` on the UAT features and return one value per row.

    The raw output of ``model.predict`` is normalised to a 1-D float
    ``numpy.ndarray``. A model may return an ndarray, a list, a pandas Series
    or a single-column DataFrame. Normalising here means the summary
    statistics below and the metric computation downstream always see a plain
    positional vector: no pandas index alignment and no ``(n, 1)`` shape.

    Args:
        model: Object exposing ``predict(X)`` (e.g. an MLflow pyfunc model).
        X: pandas DataFrame of feature columns.

    Returns:
        numpy.ndarray: 1-D float array holding ``len(X)`` predictions.

    Raises:
        ValueError: if the output is not exactly one numeric value per input
            row (multi-column output, wrong length, non-numeric values).
        Exception: any error from ``model.predict`` is re-raised after logging.
    """
    print(f"\n{'='*80}")
    print("üìã STEP 3: Running Inference")
    print(f"{'='*80}")

    try:
        print(f"   Generating predictions for {len(X)} samples...")
        y_pred = np.asarray(model.predict(X), dtype=float)

        # Accept an (n, 1) column as a plain vector; wider output is ambiguous.
        if y_pred.ndim == 2 and y_pred.shape[1] == 1:
            y_pred = y_pred[:, 0]
        if y_pred.ndim != 1:
            raise ValueError(
                f"Expected one prediction per row, got output of shape {y_pred.shape}"
            )
        if len(y_pred) != len(X):
            raise ValueError(
                f"Model returned {len(y_pred)} predictions for {len(X)} input rows"
            )

        print(f"\n{'='*80}")
        print("‚úÖ INFERENCE COMPLETE")
        print(f"{'='*80}")
        print(f"   Predictions generated: {len(y_pred)}")
        print(f"   Sample predictions: {y_pred[:5]}")
        # min()/max() raise on an empty array, so only summarise real output.
        if y_pred.size:
            print(f"   Min: {y_pred.min():.2f}, Max: {y_pred.max():.2f}, Mean: {y_pred.mean():.2f}")
        print(f"{'='*80}\n")

        return y_pred

    except Exception as e:
        print(f"\n{'='*80}")
        print("‚ùå INFERENCE FAILED")
        print(f"{'='*80}")
        print(f"   Error: {e}")
        print(f"{'='*80}\n")
        traceback.print_exc()
        raise

# =============================================================================
# STEP 4: CALCULATE METRICS
# =============================================================================
def evaluate_model(y_true, y_pred):
    """Compute MAE, RMSE, R² and MAPE for the UAT predictions.

    Both inputs are first converted to 1-D float numpy arrays. The metrics
    are therefore computed position by position, independent of any pandas
    index carried by ``y_true`` or ``y_pred``.

    MAPE is ``mean(|y_true - y_pred| / |y_true|) * 100``, taken over rows with
    a non-zero target. A zero target would divide by zero and turn the whole
    MAPE into ``inf``/``nan``, so those rows are excluded and their count is
    printed. If every target is zero, MAPE is ``nan``.

    Returns:
        tuple[float, float, float, float]: ``(mae, rmse, r2, mape)``, where
        MAPE is a percentage.

    Raises:
        ValueError: e.g. from scikit-learn when the inputs differ in length,
            are empty or contain NaN, or when values are not numeric.
    """
    print(f"\n{'='*80}")
    print("üìã STEP 4: Evaluating Model Performance")
    print(f"{'='*80}")

    try:
        actual = np.asarray(y_true, dtype=float).ravel()
        predicted = np.asarray(y_pred, dtype=float).ravel()

        # sklearn validates lengths/finiteness here, before MAPE indexes by mask.
        mae = float(mean_absolute_error(actual, predicted))
        rmse = math.sqrt(mean_squared_error(actual, predicted))
        r2 = float(r2_score(actual, predicted))

        nonzero = actual != 0
        excluded = int(actual.size - np.count_nonzero(nonzero))
        if nonzero.any():
            ape = np.abs((actual[nonzero] - predicted[nonzero]) / actual[nonzero])
            mape = float(np.mean(ape) * 100)
        else:
            mape = float("nan")

        print(f"\nüìä Evaluation Metrics:")
        print(f"   MAE  : {mae:>12,.2f}")
        print(f"   RMSE : {rmse:>12,.2f}")
        print(f"   R¬≤   : {r2:>12.4f}")
        print(f"   MAPE : {mape:>12.2f}%")
        if excluded:
            print(f"   Note : MAPE excludes {excluded} row(s) with a zero target")
        print(f"{'='*80}\n")

        return mae, rmse, r2, mape

    except Exception as e:
        print(f"\n‚ùå Evaluation failed: {e}")
        traceback.print_exc()
        raise

# =============================================================================
# STEP 5: UAT VALIDATION
# =============================================================================
def validate_uat(mape, r2, model_version):
    """Check UAT metrics against the configured thresholds and announce the result.

    The model passes only when ``mape <= MAPE_THRESHOLD`` and
    ``r2 >= R2_THRESHOLD``. The verdict is printed and sent to Slack.

    Args:
        mape: Mean absolute percentage error, in percent.
        r2: Coefficient of determination.
        model_version: Registered model version, used in messages only.

    Returns:
        str: ``"PASSED"`` or ``"FAILED"``.
    """
    rule = "=" * 80
    print(f"\n{rule}")
    print("üìã STEP 5: UAT Validation")
    print(rule)

    print(f"\nüìè Validation Thresholds:")
    print(f"   MAPE: ‚â§ {MAPE_THRESHOLD}%")
    print(f"   R¬≤:   ‚â• {R2_THRESHOLD}")

    print(f"\nüìä Actual Performance:")
    mape_pass = mape <= MAPE_THRESHOLD
    r2_pass = r2 >= R2_THRESHOLD
    pass_label, fail_label = '‚úÖ PASS', '‚ùå FAIL'
    print(f"   MAPE: {mape:.2f}% {pass_label if mape_pass else fail_label}")
    print(f"   R¬≤:   {r2:.4f}  {pass_label if r2_pass else fail_label}")

    # One human-readable reason per missed threshold; empty means UAT passed.
    fail_reasons = []
    if not mape_pass:
        fail_reasons.append(f"MAPE too high ({mape:.2f}% > {MAPE_THRESHOLD}%)")
    if not r2_pass:
        fail_reasons.append(f"R¬≤ too low ({r2:.4f} < {R2_THRESHOLD})")

    metrics_line = f"üìä MAPE: {mape:.2f}%, R¬≤: {r2:.4f}\n"

    if fail_reasons:
        print(f"\n{rule}")
        print("‚ùå‚ùå UAT FAILED ‚ùå‚ùå")
        print(rule)
        print(f"   Failure reasons:")
        for reason in fail_reasons:
            print(f"   ‚Ä¢ {reason}")
        print(f"{rule}\n")

        send_slack_notification(
            f"‚ùå Model `{MODEL_NAME}` (Type: {MODEL_TYPE.upper()}) v{model_version} FAILED UAT\n"
            + metrics_line
            + f"üö´ Reasons: {', '.join(fail_reasons)}",
            level="error",
        )
        return "FAILED"

    print(f"\n{rule}")
    print("‚úÖ‚úÖ UAT PASSED ‚úÖ‚úÖ")
    print(rule)
    print(f"   Model v{model_version} is ready for production!")
    print(f"{rule}\n")

    send_slack_notification(
        f"‚úÖ Model `{MODEL_NAME}` (Type: {MODEL_TYPE.upper()}) v{model_version} PASSED UAT\n"
        + metrics_line
        + f"üöÄ Ready for production promotion!",
        level="success",
    )
    return "PASSED"

# =============================================================================
# STEP 6: LOG RESULTS
# =============================================================================
def log_results(model_version, run_id, mae, rmse, r2, mape, status):
    """Append one UAT result row to ``OUTPUT_TABLE`` (best effort).

    The row records model identity, the four metrics, the UAT status and the
    thresholds in force.

    When the output table already exists, every column it shares with the new
    row is cast to the table's declared type before appending. This keeps the
    append schema-compatible even if, for example, ``model_version`` is stored
    as an integer there. Columns the table lacks keep their inferred type and
    are added through ``mergeSchema``. A missing table is created by the first
    append, with ``model_version`` as a string.

    Any failure is printed with a traceback and swallowed, so a logging
    problem does not fail the UAT run.
    """
    print(f"\n{'='*80}")
    print("üìã STEP 6: Logging Results")
    print(f"{'='*80}")

    try:
        # Probe the table; count() avoids collecting every row onto the driver.
        try:
            existing_df = spark.table(OUTPUT_TABLE)
            row_count = existing_df.count()
            table_exists = True
            print(f"   Table exists: Yes  |  Rows: {row_count}")
        except Exception:
            existing_df = None
            table_exists = False
            print(f"   Table exists: No (will be created)")

        result_df = pd.DataFrame([{
            "timestamp": datetime.now(),
            "model_name": MODEL_NAME,
            "model_type": MODEL_TYPE,
            "model_version": str(model_version),
            "run_id": run_id,
            "mae": float(mae),
            "rmse": float(rmse),
            "r2": float(r2),
            "mape": float(mape),
            "uat_status": status,
            "mape_threshold": float(MAPE_THRESHOLD),
            "r2_threshold": float(R2_THRESHOLD)
        }])
        spark_df = spark.createDataFrame(result_df)

        if table_exists:
            # Align each shared column to the type the table actually declares.
            target_types = {field.name: field.dataType for field in existing_df.schema}
            spark_df = spark_df.select([
                spark_df[name].cast(target_types[name]).alias(name)
                if name in target_types
                else spark_df[name]
                for name in spark_df.columns
            ])
            spark_df.write.mode("append").option("mergeSchema", "true").saveAsTable(OUTPUT_TABLE)
        else:
            spark_df.write.mode("append").saveAsTable(OUTPUT_TABLE)

        print(f"\n{'='*80}")
        print("‚úÖ RESULTS LOGGED SUCCESSFULLY")
        print(f"{'='*80}")
        print(f"   Output Table: {OUTPUT_TABLE}")
        print(f"   Model: {MODEL_NAME}")
        print(f"   Model Type: {MODEL_TYPE.upper()}")
        print(f"   Version: v{model_version}")
        print(f"   UAT Status: {status}")
        print(f"{'='*80}\n")

    except Exception as e:
        # Deliberately best-effort: report, but do not fail the UAT run.
        print(f"\n‚ö†Ô∏è Failed to log results: {e}")
        traceback.print_exc()


# =============================================================================
# MAIN EXECUTION
# =============================================================================
def main():
    """Run the UAT inference pipeline end to end.

    Steps:
      1. Load the staging model.
      2. Load the UAT data.
      3. Predict.
      4. Compute metrics.
      5. Validate against the thresholds.
      6. Log the results.

    A UAT *failure* is a normal outcome: the status is "FAILED" and the
    function returns. Only an exception raised by one of the steps sends an
    error notification and terminates the process with exit code 1. Status,
    MAPE and R² are also published as Databricks job task values when that
    API is available.
    """
    try:
        print("\n" + "="*80)
        print("üé¨ STARTING UAT INFERENCE PIPELINE")
        print("="*80 + "\n")

        model, model_version, run_id = load_staging_model()
        df, X, y_true = load_uat_data()
        y_pred = run_inference(model, X)
        mae, rmse, r2, mape = evaluate_model(y_true, y_pred)
        status = validate_uat(mape, r2, model_version)
        log_results(model_version, run_id, mae, rmse, r2, mape, status)

        print("\n" + "="*80)
        print("‚ú® UAT INFERENCE COMPLETED SUCCESSFULLY ‚ú®")
        print("="*80)
        print(f"\nüìä Final Summary:")
        print(f"   Model: {MODEL_NAME}")
        print(f"   Model Type: {MODEL_TYPE.upper()}")
        print(f"   Version: v{model_version}")
        print(f"   Run ID: {run_id}")
        print(f"   UAT Status: {status}")
        print(f"   Metrics:")
        print(f"     ‚Ä¢ RMSE: {rmse:,.2f}")
        print(f"     ‚Ä¢ MAPE: {mape:.2f}%")
        print(f"     ‚Ä¢ R¬≤:   {r2:.4f}")
        print(f"     ‚Ä¢ MAE:  {mae:,.2f}")

        if status == "PASSED":
            print(f"\nüìå Next Step:")
            print(f"   Run 06_production_promotion.py to promote to production")

        print("="*80 + "\n")

        # Publish results for downstream workflow tasks. Outside a job run this
        # presumably fails; that is reported, not raised. `except Exception`
        # (not a bare except) so KeyboardInterrupt/SystemExit still propagate.
        try:
            dbutils.jobs.taskValues.set(key="uat_status", value=status)
            # Plain Python floats: task values must be JSON-serializable.
            dbutils.jobs.taskValues.set(key="uat_mape", value=float(mape))
            dbutils.jobs.taskValues.set(key="uat_r2", value=float(r2))
            print("‚úÖ Task values saved for workflow")
        except Exception as task_err:
            print("‚ÑπÔ∏è Not running in workflow - skipping task values")
            print(f"   Reason: {task_err}")

    except Exception as e:
        print("\n" + "="*80)
        print("‚ùå UAT INFERENCE FAILED")
        print("="*80)
        print(f"Error: {str(e)}")
        print("="*80 + "\n")

        send_slack_notification(
            f"‚ùå UAT pipeline failed for `{MODEL_NAME}` (Type: {MODEL_TYPE.upper()})\n"
            f"Error: {str(e)}",
            level="error"
        )

        sys.exit(1)

# =============================================================================
# EXECUTE
# =============================================================================
# Run the pipeline when this file is executed directly rather than imported.
if __name__ == "__main__":
    main()