In [None]:
# Databricks notebook source
# =============================================================================
# 🎯 MODEL EVALUATION SCRIPT - CONFIG DRIVEN (FIXED)
# =============================================================================
# Purpose: Find best model from experiment and prepare for registration
# Now reads from pipeline_config.yml - No hardcoding!
# =============================================================================

%pip install xgboost requests

import mlflow
from mlflow.tracking import MlflowClient
import pandas as pd
import numpy as np
import sys
import os
import yaml
from datetime import datetime
from pyspark.sql import SparkSession
import traceback
import json

print("=" * 80)
print("üéØ MODEL EVALUATION SYSTEM (CONFIG-DRIVEN)")
print("=" * 80)

# =============================================================================
# ✅ LOAD PIPELINE CONFIGURATION (NEW - REPLACES ALL HARDCODING!)
# =============================================================================
# Every setting used by the steps below is read once, here, from
# pipeline_config.yml (opened relative to the current working directory).
# Any failure is reported and aborts the script with exit status 1.
print("\nüìã Loading pipeline configuration from pipeline_config.yml...")

try:
    # Decode explicitly as UTF-8 so parsing does not depend on the platform's
    # default locale encoding.
    with open("pipeline_config.yml", "r", encoding="utf-8") as f:
        pipeline_cfg = yaml.safe_load(f)

    # An empty or non-mapping YAML document does not parse to a dict; fail with
    # a clear message instead of an opaque TypeError on the lookups below.
    if not isinstance(pipeline_cfg, dict):
        raise ValueError(
            "pipeline_config.yml must contain a YAML mapping at the top level"
        )

    # Extract configuration values
    MODEL_TYPE = pipeline_cfg["model"]["type"]
    UC_CATALOG = pipeline_cfg["model"]["catalog"]
    UC_SCHEMA = pipeline_cfg["model"]["schema"]
    BASE_NAME = pipeline_cfg["model"]["base_name"]

    # Auto-generate the three-level model name: <catalog>.<schema>.<base>_<type>_uc2
    MODEL_NAME = f"{UC_CATALOG}.{UC_SCHEMA}.{BASE_NAME}_{MODEL_TYPE}_uc2"

    EXPERIMENT_NAME = pipeline_cfg["experiment"]["name"]
    MODEL_ARTIFACT_PATH = pipeline_cfg["experiment"]["artifact_path"]

    METRIC_KEY = pipeline_cfg["metrics"]["primary_metric"]
    # Multiplied by 100 wherever it is displayed as a percentage.
    IMPROVEMENT_THRESHOLD = pipeline_cfg["metrics"]["improvement_threshold"]

    # Delta Tables
    EVALUATION_LOG_TABLE = pipeline_cfg["tables"]["evaluation_log"]
    BEST_MODEL_METADATA_TABLE = pipeline_cfg["tables"]["best_model_metadata"]

    print("‚úÖ Pipeline configuration loaded successfully!")
    print("\nüìä Configuration Details:")
    print(f"   Model Type: {MODEL_TYPE.upper()}")
    print(f"   Model Name: {MODEL_NAME}")
    print(f"   Experiment: {EXPERIMENT_NAME}")
    print(f"   Metric: {METRIC_KEY} (lower is better)")
    print(f"   Improvement Threshold: {IMPROVEMENT_THRESHOLD * 100}%")
    print(f"   Metadata Table: {BEST_MODEL_METADATA_TABLE}")
    print(f"   Log Table: {EVALUATION_LOG_TABLE}")

except FileNotFoundError:
    print("‚ùå ERROR: pipeline_config.yml not found!")
    print("üí° Please create pipeline_config.yml in the same directory")
    sys.exit(1)
except Exception as e:
    # Covers malformed YAML, missing keys and the mapping check above.
    print(f"‚ùå ERROR loading configuration: {e}")
    traceback.print_exc()
    sys.exit(1)

print("=" * 80)

# =============================================================================
# ✅ INITIALIZATION
# =============================================================================
# Create the shared `spark` session and MLflow `client` used by every step
# below, then confirm that the configured experiment exists. Any failure ends
# the script with exit status 1.
try:
    # getOrCreate() returns an already-active session when there is one.
    spark = SparkSession.builder.appName("ModelEvaluation").getOrCreate()
    # NOTE(review): "databricks-uc" presumably selects the Unity Catalog
    # registry, matching the catalog.schema.name form of MODEL_NAME - confirm.
    mlflow.set_tracking_uri("databricks")
    mlflow.set_registry_uri("databricks-uc")
    # Created after the URIs are set so the client picks them up.
    client = MlflowClient()
    print("\n‚úÖ MLflow and Spark initialized")

    # Verify experiment exists
    exp = mlflow.get_experiment_by_name(EXPERIMENT_NAME)
    if exp is None:
        print(f"‚ùå ERROR: Experiment '{EXPERIMENT_NAME}' not found!")
        # List up to 20 experiments to help spot a misspelled name.
        print("\nüí° Available experiments:")
        all_exps = client.search_experiments(max_results=20)
        for e in all_exps:
            print(f"   - {e.name}")
        print(f"\nüí° Please run training script first to create the experiment")
        # sys.exit raises SystemExit, which is not an Exception subclass, so
        # the handler below does not intercept it.
        sys.exit(1)
    
    print(f"‚úÖ Experiment found: {EXPERIMENT_NAME}")
    print(f"   Experiment ID: {exp.experiment_id}")

except Exception as e:
    print(f"‚ùå Initialization failed: {e}")
    traceback.print_exc()
    sys.exit(1)

# =============================================================================
# 📊 STEP 1: GET BEST MODEL FROM EXPERIMENT
# =============================================================================
def get_best_model_from_experiment():
    """Return a description of the top-ranked run in the configured experiment.

    The module-level MLflow ``client`` is asked for up to 1000 runs of
    ``EXPERIMENT_NAME`` whose ``METRIC_KEY`` metric is greater than zero,
    ordered ascending on that metric; the first run returned is taken as the
    best. A ranking of the first ten runs is printed along the way.

    Returns:
        dict | None: ``run_id``, ``run_name``, ``metric_key``,
        ``metric_value``, ``params``, ``all_metrics``, ``timestamp`` (the
        run's raw ``start_time``), ``total_runs``, ``model_uri`` and
        ``artifact_path`` for the selected run. ``None`` when no run matches
        the filter or when any exception is raised (it is printed with a
        traceback rather than propagated).
    """
    print(f"\n{'='*70}")
    print("üìã STEP 1: Finding BEST Model From Experiment")
    print(f"{'='*70}")

    try:
        experiment = client.get_experiment_by_name(EXPERIMENT_NAME)

        # Ascending order puts the smallest metric value first.
        ranked_runs = client.search_runs(
            [experiment.experiment_id],
            filter_string=f"metrics.{METRIC_KEY} > 0",
            order_by=[f"metrics.{METRIC_KEY} ASC"],
            max_results=1000
        )

        # Guard clause: nothing to rank.
        if not ranked_runs:
            print(f"\n‚ùå ERROR: No runs found with valid '{METRIC_KEY}' metric!")
            print(f"\nüí° Please run training script first")
            print(f"   Expected experiment: {EXPERIMENT_NAME}")
            return None

        run_count = len(ranked_runs)
        print(f"‚úÖ Total runs in experiment: {run_count}")

        # Leaderboard of the ten best-ranked runs.
        print(f"\nüìä Top 10 Models (by {METRIC_KEY}):")
        print(f"{'Rank':<6} {'Run Name':<40} {METRIC_KEY.upper():<15} {'Timestamp':<20}")
        print("-" * 100)

        for position, candidate in enumerate(ranked_runs[:10], 1):
            label = candidate.info.run_name or "Unnamed"
            score = candidate.data.metrics.get(METRIC_KEY, float('inf'))
            # start_time is divided by 1000 before conversion (milliseconds -> seconds).
            started = datetime.fromtimestamp(candidate.info.start_time/1000).strftime('%Y-%m-%d %H:%M')
            marker = f"{position}." if position != 1 else "üëë BEST"
            print(f"{marker:<6} {label:<40} {score:<15.6f} {started}")

        # The head of the ordered list is the selected model.
        winner = ranked_runs[0]
        winner_id = winner.info.run_id
        winner_name = winner.info.run_name or "Unnamed"
        winner_metrics = winner.data.metrics
        winner_params = winner.data.params
        winner_score = winner_metrics.get(METRIC_KEY)

        print(f"\n‚úÖ BEST Model Selected:")
        print(f"   Run ID: {winner_id}")
        print(f"   Run Name: {winner_name}")
        print(f"   {METRIC_KEY}: {winner_score:.6f}")
        print(f"   Rank: #1 out of {run_count} runs")
        print(f"   Timestamp: {datetime.fromtimestamp(winner.info.start_time/1000)}")

        summary = {
            'run_id': winner_id,
            'run_name': winner_name,
            'metric_key': METRIC_KEY,
            'metric_value': winner_score,
            'params': winner_params,
            'all_metrics': winner_metrics,
            'timestamp': winner.info.start_time,
            'total_runs': run_count,
            'model_uri': f"runs:/{winner_id}/{MODEL_ARTIFACT_PATH}",
            'artifact_path': MODEL_ARTIFACT_PATH
        }
        return summary

    except Exception as e:
        print(f"‚ùå Error getting best model: {e}")
        traceback.print_exc()
        return None

# =============================================================================
# 🔧 HELPER: GET MODEL ALIASES SAFELY
# =============================================================================
def get_model_aliases_safe(model_name, version):
    """Return which well-known aliases currently point at ``version``.

    Each alias in a fixed candidate list is resolved through the module-level
    MLflow ``client``; an alias is reported when the version it resolves to
    equals ``version`` (compared as strings, so ``3`` and ``"3"`` match).

    Args:
        model_name: Registered model name passed to the registry lookup.
        version: Version number (int or str) to match against.

    Returns:
        list[str]: Matching aliases, in candidate order. A failed lookup for
        an alias is treated as "not pointing here", so the result may be
        empty; no ``Exception`` is propagated to the caller.
    """
    candidate_aliases = ('production', 'Staging', 'champion', 'baseline')
    found_aliases = []

    for alias in candidate_aliases:
        try:
            alias_version = client.get_model_version_by_alias(model_name, alias)
            if alias_version and str(alias_version.version) == str(version):
                found_aliases.append(alias)
        except Exception:
            # Best effort: any lookup failure just skips this alias. Catching
            # Exception (not a bare except) lets KeyboardInterrupt and
            # SystemExit propagate.
            continue

    return found_aliases

# =============================================================================
# 🏆 STEP 2: GET CURRENT REGISTERED MODEL
# =============================================================================
def get_current_registered_model():
    """Describe the registered model version a new candidate is compared to.

    All versions of ``MODEL_NAME`` are fetched through the module-level MLflow
    ``client`` and ranked by alias: ``production`` first, then ``Staging``,
    then ``champion``. When no version carries one of these aliases, the
    version with the highest version number is used. Versions are sorted
    explicitly, newest first, so the choice does not depend on the order in
    which the registry returns them.

    Returns:
        dict | None: ``version``, ``run_id``, ``metric_value`` (the run's
        ``METRIC_KEY`` metric, or ``0.0`` when the run has none) and
        ``aliases`` for the chosen version. ``None`` when the model has no
        versions, when the run behind the chosen version cannot be fetched,
        or when the registry lookup raises (treated as "nothing registered
        yet").
    """
    print(f"\n{'='*70}")
    print("üìã STEP 2: Checking Current Registered Model")
    print(f"{'='*70}")
    print(f"   Looking for: {MODEL_NAME}")

    def _version_number(model_version):
        # Sort key: versions that cannot be read as an integer sort last.
        try:
            return int(model_version.version)
        except (AttributeError, TypeError, ValueError):
            return -1

    try:
        # Search for model versions
        versions_list = list(client.search_model_versions(f"name = '{MODEL_NAME}'"))

        if not versions_list:
            print("‚ÑπÔ∏è No models in registry (first model registration)")
            return None

        print(f"‚úÖ Found {len(versions_list)} existing version(s)")

        # Newest first, so that among equally ranked versions (in particular
        # the un-aliased fallback) the highest version number wins.
        versions_list.sort(key=_version_number, reverse=True)

        # Find best priority version (Production > Staging > champion > Latest)
        alias_priority = (('production', 1), ('Staging', 2), ('champion', 3))
        unaliased_priority = 10
        best_version = None
        best_priority = 999

        for v in versions_list:
            try:
                version_aliases = get_model_aliases_safe(MODEL_NAME, v.version)
                priority = next(
                    (rank for alias, rank in alias_priority if alias in version_aliases),
                    unaliased_priority,
                )

                if priority < best_priority:
                    best_priority = priority
                    best_version = v

                    if priority == 1:
                        print(f"‚úÖ Found Production model: Version {v.version}")
                        # Nothing can outrank production; stop scanning.
                        break
                    elif priority == 2:
                        print(f"‚úÖ Found Staging model: Version {v.version}")
            except Exception as e:
                print(f"‚ö†Ô∏è Error processing version {v.version}: {e}")
                continue

        # Every version failed to process: still fall back to the newest one.
        if best_version is None:
            best_version = versions_list[0]
            best_priority = unaliased_priority

        # No aliased version was found, so the newest version is the baseline.
        if best_priority >= unaliased_priority:
            print(f"‚úÖ Using latest model: Version {best_version.version}")

        try:
            run = client.get_run(best_version.run_id)
            metric = run.data.metrics.get(METRIC_KEY)
            final_aliases = get_model_aliases_safe(MODEL_NAME, best_version.version)

            print(f"   Version: {best_version.version}")
            print(f"   Run ID: {best_version.run_id}")
            # `is not None` so a legitimate 0.0 metric is shown, not "N/A".
            print(f"   {METRIC_KEY}: {metric:.6f}" if metric is not None else "   Metric: N/A")
            print(f"   Aliases: {', '.join(final_aliases) if final_aliases else 'None'}")

            return {
                'version': best_version.version,
                'run_id': best_version.run_id,
                'metric_value': metric if metric is not None else 0.0,
                'aliases': final_aliases
            }
        except Exception as e:
            print(f"‚ö†Ô∏è Error fetching run details: {e}")
            return None

    except Exception as e:
        # Deliberate best effort: any registry failure is reported and treated
        # as "no registered model" so a first-time registration can proceed.
        print(f"‚ÑπÔ∏è No registered model found: {e}")
        print("   (This is expected for first-time registration)")
        return None

# =============================================================================
# 🔍 STEP 3: EVALUATE MODEL QUALITY
# =============================================================================
def evaluate_model(new_model, current_model):
    """Decide whether the candidate run should be handed to registration.

    Args:
        new_model: Mapping for the candidate; its ``'metric_value'`` is read
            once a baseline exists.
        current_model: Mapping for the registered baseline (its
            ``'metric_value'`` is read), or a falsy value when there is none.

    Returns:
        dict: ``should_register`` (bool), ``reason`` (str),
        ``improvement_pct`` (float) and ``decision`` (``'APPROVE'`` or
        ``'REJECT'``).
    """
    print(f"\n{'='*70}")
    print("üìã STEP 3: Model Evaluation")
    print(f"{'='*70}")

    def _outcome(approved, reason, pct):
        # Single place that shapes the result dict handed back to callers.
        return {
            'should_register': approved,
            'reason': reason,
            'improvement_pct': pct,
            'decision': 'APPROVE' if approved else 'REJECT',
        }

    # No baseline at all: nothing to compare against, so approve outright.
    if not current_model:
        print("‚úÖ APPROVED: First model (no baseline to compare)")
        return _outcome(True, 'First model registration', 0.0)

    candidate_score = new_model['metric_value']
    baseline_score = current_model['metric_value']

    # A missing or zero baseline cannot serve as the divisor below.
    if baseline_score is None or baseline_score == 0:
        print("‚ö†Ô∏è Current model has no valid metric, approving new model")
        return _outcome(True, 'Current model has invalid metric', 0.0)

    # Relative change of the metric against the baseline; positive when the
    # candidate's value is lower than a positive baseline.
    improvement = (baseline_score - candidate_score) / baseline_score
    improvement_pct = improvement * 100

    print(f"\nüìä Comparison:")
    print(f"   New Model {METRIC_KEY.upper()}: {candidate_score:.6f}")
    print(f"   Current Model {METRIC_KEY.upper()}: {baseline_score:.6f}")
    print(f"   Improvement: {improvement_pct:.2f}%")
    print(f"   Threshold: {IMPROVEMENT_THRESHOLD * 100}%")

    # Rejection is phrased as "not >=" so an unordered comparison (e.g. NaN)
    # still lands in the reject branch.
    if not (improvement >= IMPROVEMENT_THRESHOLD):
        print(f"\n‚ùå REJECTED: Insufficient improvement ({improvement_pct:.2f}%)")
        return _outcome(
            False,
            f'Insufficient improvement: {improvement_pct:.2f}%',
            improvement_pct,
        )

    print(f"\n‚úÖ APPROVED: Model improved by {improvement_pct:.2f}%")
    return _outcome(True, f'Improvement: {improvement_pct:.2f}%', improvement_pct)

# =============================================================================
# 💾 STEP 4: SAVE BEST MODEL METADATA
# =============================================================================
def save_best_model_metadata(model_info, evaluation_result):
    """Write the latest evaluation outcome as a single-row Delta table.

    The row is built from ``model_info`` and ``evaluation_result`` and written
    to ``BEST_MODEL_METADATA_TABLE`` in overwrite mode (schema included), so
    the table only ever holds the most recent evaluation.

    Args:
        model_info: Dict with ``run_id``, ``run_name``, ``model_uri``,
            ``artifact_path``, ``metric_key``, ``metric_value``,
            ``total_runs`` and ``params``.
        evaluation_result: Dict with ``should_register``, ``reason``,
            ``improvement_pct`` and ``decision``.

    Returns:
        bool: ``True`` when the table was written, ``False`` when any
        exception occurred (it is printed with a traceback, not raised).
    """
    print(f"\n{'='*70}")
    print("üìã STEP 4: Saving Best Model Metadata")
    print(f"{'='*70}")

    try:
        # One record; values are coerced to plain Python types up front.
        record = {
            "evaluation_timestamp": datetime.now(),
            "run_id": model_info['run_id'],
            "run_name": model_info['run_name'],
            "model_uri": model_info['model_uri'],
            "artifact_path": model_info['artifact_path'],
            "metric_key": model_info['metric_key'],
            "metric_value": float(model_info['metric_value']),
            "should_register": bool(evaluation_result['should_register']),
            "evaluation_reason": str(evaluation_result['reason']),
            "improvement_pct": float(evaluation_result['improvement_pct']),
            "model_name": MODEL_NAME,
            "total_runs_evaluated": int(model_info['total_runs']),
            "params_json": json.dumps(dict(model_info['params'])),
        }

        # Column -> one-element list, the layout pandas expects for one row.
        single_row = {column: [value] for column, value in record.items()}
        frame = spark.createDataFrame(pd.DataFrame(single_row))

        # Overwrite latest evaluation result
        writer = frame.write.format("delta")
        writer = writer.mode("overwrite").option("overwriteSchema", "true")
        writer.saveAsTable(BEST_MODEL_METADATA_TABLE)

        print(f"‚úÖ Metadata saved to: {BEST_MODEL_METADATA_TABLE}")
        print(f"   Run ID: {model_info['run_id']}")
        print(f"   Decision: {evaluation_result['decision']}")

    except Exception as e:
        print(f"‚ùå Failed to save metadata: {e}")
        traceback.print_exc()
        return False

    return True

# =============================================================================
# 📝 STEP 5: LOG EVALUATION HISTORY
# =============================================================================
def log_evaluation_history(model_info, current_model, evaluation_result):
    """Append one row describing this evaluation to ``EVALUATION_LOG_TABLE``.

    Logging is best effort: any failure is printed (with a traceback) and
    swallowed so the evaluation itself is not interrupted.

    Args:
        model_info: Dict for the candidate run (``run_id``, ``run_name``,
            ``metric_value``, ``total_runs`` are read).
        current_model: Dict for the registered baseline (``version``,
            ``metric_value``, ``aliases`` are read), or a falsy value when
            there is none.
        evaluation_result: Dict with ``should_register``, ``reason`` and
            ``improvement_pct``.

    Returns:
        None.
    """
    print(f"\n{'='*70}")
    print("üìã STEP 5: Logging Evaluation History")
    print(f"{'='*70}")

    try:
        from pyspark.sql.types import (StructType, StructField, StringType, 
                                       DoubleType, BooleanType, LongType, TimestampType)
        
        # Explicit schema: several columns may hold only None, which gives
        # type inference nothing to work with.
        schema = StructType([
            StructField("timestamp", TimestampType(), True),
            StructField("new_run_id", StringType(), True),
            StructField("new_run_name", StringType(), True),
            StructField("new_metric", DoubleType(), True),
            StructField("current_version", LongType(), True),
            StructField("current_metric", DoubleType(), True),
            StructField("current_alias", StringType(), True),
            StructField("should_promote", BooleanType(), True),
            StructField("promotion_reason", StringType(), True),
            StructField("improvement_pct", DoubleType(), True),
            StructField("promoted_to_staging", BooleanType(), True),
            StructField("promoted_version", LongType(), True),
            StructField("threshold_used", DoubleType(), True),
            StructField("total_runs_evaluated", LongType(), True),
            StructField("selection_method", StringType(), True)
        ])

        # Record the aliases the baseline actually carries (comma-separated)
        # rather than assuming every baseline is the 'Staging' model.
        current_alias = None
        if current_model:
            baseline_aliases = current_model.get('aliases') or []
            if baseline_aliases:
                current_alias = ', '.join(str(alias) for alias in baseline_aliases)

        log_data = {
            "timestamp": [datetime.now()],
            "new_run_id": [model_info['run_id']],
            "new_run_name": [model_info['run_name']],
            "new_metric": [float(model_info['metric_value'])],
            "current_version": [int(current_model['version']) if current_model else None],
            "current_metric": [float(current_model['metric_value']) if current_model and current_model['metric_value'] else None],
            "current_alias": [current_alias],
            "should_promote": [bool(evaluation_result['should_register'])],
            "promotion_reason": [str(evaluation_result['reason'])],
            "improvement_pct": [float(evaluation_result['improvement_pct'])],
            # This step only evaluates; the promotion columns are always
            # written as "not promoted" here.
            "promoted_to_staging": [False],
            "promoted_version": [None],
            # Stored as a percentage, matching improvement_pct.
            "threshold_used": [float(IMPROVEMENT_THRESHOLD * 100)],
            "total_runs_evaluated": [int(model_info['total_runs'])],
            "selection_method": ["ALL-TIME BEST"]
        }
        
        df = spark.createDataFrame(pd.DataFrame(log_data), schema=schema)
        
        df.write.format("delta")\
            .mode("append")\
            .option("mergeSchema", "true")\
            .saveAsTable(EVALUATION_LOG_TABLE)
        
        print(f"‚úÖ History logged to: {EVALUATION_LOG_TABLE}")

    except Exception as e:
        print(f"‚ö†Ô∏è Failed to log history: {e}")
        # Same diagnostics as the other steps, still without failing the run.
        traceback.print_exc()
        print("   (Non-critical error - continuing)")

# =============================================================================
# 🎬 MAIN EXECUTION
# =============================================================================
def main():
    """Run the five evaluation steps and report the decision.

    Steps: pick the best experiment run, look up the registered baseline,
    evaluate the candidate against it, save the result for the registration
    script, and append a history row. The decision is also published as
    Databricks job task values when ``dbutils`` is available.

    Returns:
        bool: ``False`` when no valid run was found; otherwise the
        ``should_register`` flag of the evaluation result.
    """
    print(f"\n{'='*80}")
    print("üöÄ STARTING MODEL EVALUATION")
    print(f"{'='*80}")
    
    # Step 1: Find best model from experiment
    best_model = get_best_model_from_experiment()
    if not best_model:
        print("\n‚ùå EVALUATION FAILED - No valid models found")
        return False

    # Step 2: Get current registered model
    current_model = get_current_registered_model()
    
    # Step 3: Evaluate model
    evaluation_result = evaluate_model(best_model, current_model)

    # Step 4: Save metadata for registration script
    metadata_saved = save_best_model_metadata(best_model, evaluation_result)
    
    # Step 5: Log to history
    log_evaluation_history(best_model, current_model, evaluation_result)

    # Final Summary
    print("\n" + "=" * 80)
    print("‚úÖ MODEL EVALUATION COMPLETE")
    print("=" * 80)
    print(f"üìä Selected Model:")
    print(f"   Model Type: {MODEL_TYPE.upper()}")
    print(f"   Target Registry: {MODEL_NAME}")
    print(f"   Run ID: {best_model['run_id']}")
    print(f"   Run Name: {best_model['run_name']}")
    print(f"   {METRIC_KEY.upper()}: {best_model['metric_value']:.6f}")
    print(f"   Rank: #1 from {best_model['total_runs']} runs")
    print(f"\nüéØ Evaluation Decision: {evaluation_result['decision']}")
    print(f"   Reason: {evaluation_result['reason']}")
    print(f"   Should Register: {'YES ‚úÖ' if evaluation_result['should_register'] else 'NO ‚ùå'}")
    
    if metadata_saved:
        print(f"\nüì¶ Next Step:")
        print(f"   Run Model_Registration script to register approved model")
        print(f"   Metadata saved in: {BEST_MODEL_METADATA_TABLE}")
    else:
        # The registration script reads this table, so make the failed write
        # visible in the summary instead of skipping the section silently.
        print("\nWARNING: metadata was NOT saved")
        print(f"   {BEST_MODEL_METADATA_TABLE} may still hold a previous evaluation")
    
    print("=" * 80)
    
    # Save for workflow
    try:
        dbutils.jobs.taskValues.set(key="model_type", value=MODEL_TYPE)
        dbutils.jobs.taskValues.set(key="model_name", value=MODEL_NAME)
        dbutils.jobs.taskValues.set(key="should_register", value=evaluation_result['should_register'])
        print("‚úÖ Task values saved for workflow")
    except Exception:
        # Outside a Databricks job `dbutils` is undefined (NameError) or the
        # call fails; both are fine to skip. Catching Exception rather than
        # using a bare except keeps KeyboardInterrupt/SystemExit propagating.
        print("‚ÑπÔ∏è Not running in workflow - skipping task values")
    
    return evaluation_result['should_register']

# Execute
if __name__ == "__main__":
    # Run the full evaluation, then echo the final verdict on its own line.
    approved = main()
    verdict = 'APPROVED ‚úÖ' if approved else 'REJECTED ‚ùå'
    print(f"\nüéØ MODEL APPROVAL STATUS: {verdict}")