In [None]:
# Databricks notebook source
# =============================================================================
# 🏆 MODEL REGISTRATION SCRIPT - CONFIG DRIVEN (FIXED)
# =============================================================================
# Purpose: Register approved models from evaluation pipeline
# Now reads from pipeline_config.yml - No hardcoding!
# Prerequisites: Run model_evaluation_FIXED.py first
# =============================================================================

import mlflow
from mlflow.tracking import MlflowClient
import sys
import os
import yaml
import requests
import traceback
from typing import Dict, Optional, Any
from datetime import datetime
from pyspark.sql import SparkSession
from pyspark.sql.functions import col
from delta.tables import DeltaTable
from IPython import get_ipython

print("=" * 80)
print("üèÜ MODEL REGISTRATION SYSTEM (CONFIG-DRIVEN)")
print("=" * 80)

# =============================================================================
# LOAD PIPELINE CONFIGURATION
# =============================================================================
# pipeline_config.yml is opened relative to the current working directory.
# The parsed document is kept in the module-level name `pipeline_cfg`, which
# the Config class reads. Any failure to load it ends the script with status 1.
print("\nüìã Loading pipeline configuration from pipeline_config.yml...")

try:
    with open("pipeline_config.yml", "r") as config_file:
        pipeline_cfg = yaml.safe_load(config_file)

    print("‚úÖ Pipeline configuration loaded successfully!")

except FileNotFoundError:
    # Missing file gets a dedicated hint instead of a traceback
    print("‚ùå ERROR: pipeline_config.yml not found!")
    print("üí° Please create pipeline_config.yml in the same directory")
    sys.exit(1)
except Exception as load_error:
    # Anything else (e.g. a YAML parse error) is reported with its traceback
    print(f"‚ùå ERROR loading configuration: {load_error}")
    traceback.print_exc()
    sys.exit(1)

# =============================================================================
# ✅ CONFIGURATION CLASS (NOW DYNAMIC!)
# =============================================================================
class Config:
    """Settings for the registration pipeline.

    Registry name, aliases, table names, artifact path and metric key are
    read from the module-level ``pipeline_cfg`` mapping. The metric tolerance
    is a fixed constant and the Slack webhook comes from a secret scope.
    """

    def __init__(self):
        # Model identity fields from the "model" section of the config
        model_type = pipeline_cfg["model"]["type"]
        catalog = pipeline_cfg["model"]["catalog"]
        schema = pipeline_cfg["model"]["schema"]
        base_name = pipeline_cfg["model"]["base_name"]

        # Unity Catalog location and the three-level registered model name
        self.UC_CATALOG = catalog
        self.UC_SCHEMA = schema
        self.MODEL_NAME = f"{catalog}.{schema}.{base_name}_{model_type}_uc2"

        # Registry aliases
        self.STAGING_ALIAS = pipeline_cfg["aliases"]["staging"]
        self.PRODUCTION_ALIAS = pipeline_cfg["aliases"]["production"]

        # Delta tables written by the evaluation step
        self.BEST_MODEL_METADATA_TABLE = pipeline_cfg["tables"]["best_model_metadata"]
        self.EVALUATION_LOG_TABLE = pipeline_cfg["tables"]["evaluation_log"]

        # Model/metric settings; TOL is the absolute difference below which
        # two metric values are treated as equal
        self.ARTIFACT_PATH = pipeline_cfg["experiment"]["artifact_path"]
        self.METRIC_KEY = pipeline_cfg["metrics"]["primary_metric"]
        self.TOL = 1e-6

        # Slack webhook (None when no secret is available)
        self.SLACK_WEBHOOK_URL = self._get_slack_webhook()

        # Kept for display and tagging
        self.MODEL_TYPE = model_type

    def _get_slack_webhook(self) -> Optional[str]:
        """Return the first non-blank SLACK_WEBHOOK_URL secret, or None.

        The scopes "shared-scope" and "dev-scope" are tried in that order;
        any error while reading a scope simply moves on to the next one.
        """
        for scope_name in ("shared-scope", "dev-scope"):
            try:
                secret = dbutils.secrets.get(scope_name, "SLACK_WEBHOOK_URL")
                is_usable = bool(secret and secret.strip())
                if is_usable:
                    print(f"‚úÖ Slack webhook configured from scope '{scope_name}'")
                    return secret
            except Exception:
                continue

        print("‚ÑπÔ∏è Slack notifications disabled")
        return None

# Build the configuration object shared by every step of the pipeline
config = Config()

# Echo the resolved settings so the run output shows what will be used
print("\nüìä Configuration Details:")
print(
    f"   Model Type: {config.MODEL_TYPE.upper()}",
    f"   Model Name: {config.MODEL_NAME}",
    f"   Staging Alias: @{config.STAGING_ALIAS}",
    f"   Production Alias: @{config.PRODUCTION_ALIAS}",
    f"   Metadata Table: {config.BEST_MODEL_METADATA_TABLE}",
    f"   Log Table: {config.EVALUATION_LOG_TABLE}",
    f"   Metric: {config.METRIC_KEY}",
    sep="\n",
)
print("=" * 80)

# =============================================================================
# ✅ SLACK NOTIFICATION HELPER
# =============================================================================
class SlackNotifier:
    """Sends pipeline status messages to a Slack webhook.

    When constructed with ``None`` the notifier is disabled: ``send`` only
    prints the message and reports ``False``.
    """

    def __init__(self, webhook_url: Optional[str]):
        self.webhook_url = webhook_url
        self.enabled = webhook_url is not None

    def send(self, message: str, level: str = "info") -> bool:
        """Post ``message`` to Slack, prefixed with the marker for ``level``.

        Args:
            message: Text to send.
            level: One of "info", "success", "warning", "error"; unknown
                values use the "info" marker.

        Returns:
            True only when the webhook answered with HTTP 200. False when
            the notifier is disabled, the response status differs, or the
            request raised.
        """
        if not self.enabled:
            print(f"üì¢ [SLACK DISABLED] {message}")
            return False

        level_prefix = {
            "info": "‚ÑπÔ∏è",
            "success": "‚úÖ",
            "warning": "‚ö†Ô∏è",
            "error": "‚ùå",
        }.get(level, "‚ÑπÔ∏è")
        body = {"text": f"{level_prefix} {message}"}

        try:
            response = requests.post(self.webhook_url, json=body, timeout=5)
            delivered = response.status_code == 200
            if delivered:
                print(f"üì¢ Slack notification sent: {level}")
            else:
                print(f"‚ö†Ô∏è Slack error: {response.status_code}")
            return delivered
        except Exception as send_error:
            # Notification problems must never interrupt the pipeline
            print(f"‚ùå Slack notification failed: {send_error}")
            return False

# Module-wide notifier; it only prints when no webhook URL was configured
slack = SlackNotifier(config.SLACK_WEBHOOK_URL)

# =============================================================================
# INITIALIZATION
# =============================================================================
# Obtain the Spark session, point MLflow tracking at "databricks" and the
# model registry at "databricks-uc", and create the shared MlflowClient.
# Any failure here ends the script with status 1.
try:
    spark = (
        SparkSession.builder
        .appName("ModelRegistration")
        .getOrCreate()
    )
    mlflow.set_tracking_uri("databricks")
    mlflow.set_registry_uri("databricks-uc")
    client = MlflowClient()
    print("\n‚úÖ MLflow and Spark initialized")

except Exception as init_error:
    print(f"‚ùå Initialization failed: {init_error}")
    traceback.print_exc()
    sys.exit(1)

# =============================================================================
# 🔧 HELPER: GET MODEL ALIASES SAFELY
# =============================================================================
def get_model_aliases_safe(model_name: str, version: int) -> list:
    """Return the known aliases that currently point at a model version.

    The registry is probed alias by alias. The candidates are the aliases
    configured for this pipeline (``config.STAGING_ALIAS`` and
    ``config.PRODUCTION_ALIAS``) followed by a fixed set of common names, so
    the aliases this script actually sets are always checked.

    Args:
        model_name: Fully qualified registered model name.
        version: Version number to match (compared as a string).

    Returns:
        The matching alias names in probe order; an empty list when none
        match or when anything unexpected goes wrong. Never raises.
    """
    try:
        # dict.fromkeys removes repeats while keeping the first-seen order
        candidate_aliases = [
            alias
            for alias in dict.fromkeys([
                config.STAGING_ALIAS,
                config.PRODUCTION_ALIAS,
                'production', 'Staging', 'champion', 'baseline',
            ])
            if alias
        ]

        found_aliases = []
        for alias in candidate_aliases:
            try:
                alias_version = client.get_model_version_by_alias(model_name, alias)
                if alias_version and str(alias_version.version) == str(version):
                    found_aliases.append(alias)
            except Exception:
                # A failed lookup (presumably because the alias is not set on
                # this model) just means "no match" for this candidate
                continue
        return found_aliases
    except Exception:
        return []

# =============================================================================
# 📋 STEP 1: READ EVALUATION RESULTS
# =============================================================================
def get_evaluation_results() -> Optional[Dict]:
    """Read the most recent evaluation record from the metadata Delta table.

    The table named by ``config.BEST_MODEL_METADATA_TABLE`` is checked for
    existence, then the row with the latest ``evaluation_timestamp`` is
    loaded, printed, and converted to a plain dict.

    Returns:
        A dict with keys ``run_id``, ``run_name``, ``model_uri``,
        ``artifact_path``, ``metric_key``, ``metric_value``,
        ``should_register``, ``reason``, ``improvement_pct``, ``total_runs``,
        ``evaluation_time`` and ``params_json`` ("{}" when the column is
        absent). None when the table is missing or empty, or when reading
        fails (the error and traceback are printed).
    """
    print(f"\n{'='*70}")
    print("üìã STEP 1: Reading Evaluation Results")
    print(f"{'='*70}")

    try:
        # Check the configured table itself rather than a fixed
        # "default.best_model_metadata", so the check follows the config.
        # Catalog.tableExists needs Spark 3.3 or newer.
        if not spark.catalog.tableExists(config.BEST_MODEL_METADATA_TABLE):
            print(f"‚ùå Table '{config.BEST_MODEL_METADATA_TABLE}' not found!")
            print("\nüí° Please run model_evaluation_FIXED.py first")
            return None

        df = spark.read.format("delta").table(config.BEST_MODEL_METADATA_TABLE)

        # first() yields None for an empty table, so one Spark action covers
        # both the emptiness check and fetching the newest row
        latest = df.orderBy(df.evaluation_timestamp.desc()).first()
        if latest is None:
            print("‚ùå No evaluation results found in table!")
            print("\nüí° Please run model_evaluation_FIXED.py first")
            return None

        print(f"‚úÖ Evaluation Results Found:")
        print(f"   Evaluated At: {latest.evaluation_timestamp}")
        print(f"   Model Type: {config.MODEL_TYPE.upper()}")
        print(f"   Target Registry: {config.MODEL_NAME}")
        print(f"   Run ID: {latest.run_id}")
        print(f"   Run Name: {latest.run_name}")
        print(f"   Model URI: {latest.model_uri}")
        print(f"   Metric ({latest.metric_key}): {latest.metric_value:.6f}")
        print(f"   Should Register: {'YES ‚úÖ' if latest.should_register else 'NO ‚ùå'}")
        print(f"   Reason: {latest.evaluation_reason}")
        print(f"   Improvement: {latest.improvement_pct:.2f}%")
        print(f"   Total Runs Evaluated: {latest.total_runs_evaluated}")

        return {
            'run_id': latest.run_id,
            'run_name': latest.run_name,
            'model_uri': latest.model_uri,
            'artifact_path': latest.artifact_path,
            'metric_key': latest.metric_key,
            'metric_value': float(latest.metric_value),
            'should_register': bool(latest.should_register),
            'reason': latest.evaluation_reason,
            'improvement_pct': float(latest.improvement_pct),
            'total_runs': int(latest.total_runs_evaluated),
            'evaluation_time': latest.evaluation_timestamp,
            # params_json is optional in the table schema
            'params_json': latest.params_json if hasattr(latest, 'params_json') else "{}"
        }

    except Exception as e:
        print(f"‚ùå Failed to read evaluation results: {e}")
        traceback.print_exc()
        return None

# =============================================================================
# 🔍 STEP 2: CHECK FOR DUPLICATE VERSIONS
# =============================================================================
def check_duplicate(eval_results: Dict) -> Optional[Any]:
    """Look for an already-registered version created from the same run.

    Every existing version of ``config.MODEL_NAME`` is inspected. A version
    whose ``run_id`` equals ``eval_results['run_id']`` is a duplicate: it is
    reported (stdout and Slack) and returned. For other versions the source
    run's ``config.METRIC_KEY`` metric is compared with the new metric and a
    "similar model" notice is printed when they differ by at most
    ``config.TOL``; this is informational only.

    Args:
        eval_results: Evaluation record; ``run_id`` and ``metric_value`` are
            read.

    Returns:
        The existing model version object for a duplicate, otherwise None
        (also when the registry lookup fails or finds no versions).
    """
    print(f"\n{'='*70}")
    print("üìã STEP 2: Checking for Duplicates")
    print(f"{'='*70}")

    try:
        versions_list = list(
            client.search_model_versions(f"name = '{config.MODEL_NAME}'")
        )
    except Exception as e:
        # Best effort: a failed lookup is treated as "nothing registered yet",
        # but the cause is shown so a real registry problem is not hidden
        print(f"‚ÑπÔ∏è No existing model versions (first registration)")
        print(f"   (registry lookup raised: {e})")
        return None

    if not versions_list:
        print("‚ÑπÔ∏è No existing versions found (first registration)")
        return None

    print(f"‚úÖ Found {len(versions_list)} existing version(s)")

    new_run_id = eval_results['run_id']
    new_metric = eval_results['metric_value']

    for mv in versions_list:
        try:
            if mv.run_id == new_run_id:
                version_aliases = get_model_aliases_safe(config.MODEL_NAME, mv.version)
                aliases_str = ', '.join(version_aliases) if version_aliases else 'None'

                print(f"\n‚ö†Ô∏è DUPLICATE DETECTED!")
                print(f"   Existing Version: v{mv.version}")
                print(f"   Run ID: {mv.run_id}")
                print(f"   Aliases: {aliases_str}")
                print(f"\n   ‚Üí Model already registered, skipping registration")

                slack.send(
                    f"‚ö†Ô∏è Duplicate detected ‚Äî using existing version *v{mv.version}* "
                    f"for `{config.MODEL_NAME}`",
                    level="warning"
                )
                return mv

            try:
                run = client.get_run(mv.run_id)
                old_metric = run.data.metrics.get(config.METRIC_KEY)
            except Exception as e:
                # The similarity notice is optional; skip it for this version
                print(f"   (could not load metrics for v{mv.version}: {e})")
                continue

            # `is not None` so a legitimate metric of exactly 0.0 is compared
            if old_metric is not None and abs(old_metric - new_metric) <= config.TOL:
                print(f"\n‚ö†Ô∏è Similar model found!")
                print(f"   Version: v{mv.version}")
                print(f"   Metric difference: {abs(old_metric - new_metric):.8f}")
                print(f"   (Within tolerance: {config.TOL})")

        except Exception as e:
            print(f"‚ö†Ô∏è Error checking version {mv.version}: {e}")
            continue

    print("\n‚úÖ No duplicates found - proceeding with registration")
    return None

# =============================================================================
# 🚀 STEP 3: REGISTER MODEL TO UNITY CATALOG
# =============================================================================
def register_model(eval_results: Dict) -> Optional[Any]:
    """Register the evaluated model under ``config.MODEL_NAME``.

    Nothing is registered when the evaluation did not approve the model. If
    ``check_duplicate`` finds a version from the same run, that version is
    returned instead of creating a new one.

    Args:
        eval_results: Evaluation record; ``should_register``, ``reason`` and
            ``model_uri`` are read here.

    Returns:
        The new (or already existing) model version, or None when the model
        was not approved or registration raised. Outcomes are also sent to
        Slack.
    """
    print(f"\n{'='*70}")
    print("üìã STEP 3: Registering Model to Unity Catalog")
    print(f"{'='*70}")

    registry_name = config.MODEL_NAME

    # Guard 1: evaluation must have approved the model
    if not eval_results['should_register']:
        print("‚ùå Model NOT APPROVED for registration")
        slack.send(
            f"‚è≠Ô∏è Model registration skipped for `{registry_name}`\n"
            f"Reason: {eval_results['reason']}",
            level="warning"
        )
        return None

    # Guard 2: reuse a version that was already created from this run
    existing_version = check_duplicate(eval_results)
    if existing_version:
        return existing_version

    try:
        source_uri = eval_results['model_uri']
        print(f"\n‚è≥ Registering model to: {registry_name}")
        print(f"   Model URI: {source_uri}")

        registered = mlflow.register_model(source_uri, registry_name)

        print(f"\n‚úÖ MODEL REGISTERED SUCCESSFULLY!")
        print(f"   Model: {registry_name}")
        print(f"   Version: v{registered.version}")
        print(f"   Model Type: {config.MODEL_TYPE.upper()}")

        slack.send(
            f"‚úÖ Model *{registry_name}* registered as version *v{registered.version}*\n"
            f"Model Type: {config.MODEL_TYPE.upper()}",
            level="success"
        )
        return registered

    except Exception as register_error:
        print(f"‚ùå Registration failed: {register_error}")
        traceback.print_exc()
        slack.send(
            f"‚ùå Model registration failed for `{registry_name}`: {register_error}",
            level="error"
        )
        return None

# =============================================================================
# 🏷️ STEP 4: SET STAGING ALIAS & ADD TAGS
# =============================================================================
def set_staging_alias_and_tags(version_number: int, eval_results: Dict) -> bool:
    """Point the staging alias at a version and attach metadata tags.

    The alias ``config.STAGING_ALIAS`` is set first. Tags are then written
    one at a time; a tag that fails is reported and the rest are still
    attempted.

    Args:
        version_number: Model version to alias and tag.
        eval_results: Evaluation record supplying the tag values.

    Returns:
        True when the alias was set and the tag loop completed, even if
        individual tags failed (those are listed in the output). False when
        setting the alias or building the tags raised.
    """
    print(f"\n{'='*70}")
    print("üìã STEP 4: Setting Staging Alias and Tags")
    print(f"{'='*70}")

    try:
        # Set staging alias
        print(f"   Setting @{config.STAGING_ALIAS} alias...")
        client.set_registered_model_alias(
            config.MODEL_NAME,
            config.STAGING_ALIAS,
            version_number
        )
        print(f"   ‚úì Alias set: @{config.STAGING_ALIAS}")

        # Add tags
        print(f"   Adding metadata tags...")
        tags = {
            "model_type": config.MODEL_TYPE,
            "registered_from": "registration_pipeline",
            "evaluation_reason": eval_results['reason'],
            "improvement_pct": f"{eval_results['improvement_pct']:.2f}",
            "registration_timestamp": datetime.now().isoformat(),
            # Legacy key name: it holds the primary metric value whatever the
            # configured metric is, so "metric_key" below names the metric
            "metric_rmse": str(eval_results['metric_value']),
            "metric_key": eval_results.get('metric_key') or config.METRIC_KEY,
            "source_run_id": eval_results['run_id'],
            "source_run_name": eval_results['run_name'],
            "total_runs_evaluated": str(eval_results['total_runs']),
            "artifact_path": eval_results['artifact_path'],
            "evaluation_timestamp": str(eval_results['evaluation_time'])
        }

        failed_tags = []
        for key, value in tags.items():
            try:
                # Coerce to str so a None or numeric value is still a valid tag
                client.set_model_version_tag(
                    config.MODEL_NAME, version_number, key, str(value)
                )
            except Exception as e:
                print(f"   ‚ö†Ô∏è Failed to set tag '{key}': {e}")
                failed_tags.append(key)

        # Only claim success when every tag was actually written
        if failed_tags:
            print(
                f"   {len(failed_tags)} of {len(tags)} tags could not be set: "
                f"{', '.join(failed_tags)}"
            )
        else:
            print(f"   ‚úì Tags added successfully")
        return True

    except Exception as e:
        print(f"‚ùå Failed to set alias/tags: {e}")
        traceback.print_exc()
        return False

# =============================================================================
# 📝 STEP 5: UPDATE EVALUATION LOG
# =============================================================================
def update_evaluation_log(version_number: int, eval_results: Dict) -> bool:
    """Mark the evaluated run as promoted in the evaluation log table.

    Rows of ``config.EVALUATION_LOG_TABLE`` whose ``new_run_id`` equals
    ``eval_results['run_id']`` get ``promoted_to_staging`` set to true and
    ``promoted_version`` set to ``version_number``.

    Args:
        version_number: Registered model version (int or numeric string).
        eval_results: Evaluation record; only ``run_id`` is read.

    Returns:
        True when the update ran; False when it raised (the error and
        traceback are printed, the pipeline continues).
    """
    print(f"\n{'='*70}")
    print("üìã STEP 5: Updating Evaluation Log")
    print(f"{'='*70}")

    try:
        delta_table = DeltaTable.forName(spark, config.EVALUATION_LOG_TABLE)
        delta_table.update(
            condition=f"new_run_id = '{eval_results['run_id']}'",
            # DeltaTable.update only accepts SQL-expression strings or Column
            # objects as assignment values; a raw Python bool/int raises
            # TypeError, so the values are passed as SQL literals.
            set={
                "promoted_to_staging": "true",
                "promoted_version": str(int(version_number))
            }
        )
        print(f"‚úÖ Evaluation log updated")
        return True

    except Exception as e:
        print(f"‚ö†Ô∏è Failed to update log: {e}")
        traceback.print_exc()
        return False

# =============================================================================
# 📊 STEP 6: DISPLAY REGISTRATION SUMMARY
# =============================================================================
def display_summary(eval_results: Dict, version_number: int) -> None:
    """Print the final report for a completed registration.

    Shows the source run (type, run id, run name, primary metric), the
    registered model (registry name, version, staging alias) and the
    follow-up steps.

    Args:
        eval_results: Evaluation record; ``run_id``, ``run_name`` and
            ``metric_value`` are read.
        version_number: Version shown as the registered one.
    """
    rule = "=" * 80

    print("\n" + rule)
    print("‚úÖ‚úÖ MODEL REGISTRATION COMPLETE ‚úÖ‚úÖ")
    print(rule)

    # Where the model came from
    print("\nüìä Source Model:")
    print(f"   Model Type: {config.MODEL_TYPE.upper()}")
    print(f"   Run ID: {eval_results['run_id']}")
    print(f"   Run Name: {eval_results['run_name']}")
    metric_label = config.METRIC_KEY.upper()
    print(f"   {metric_label}: {eval_results['metric_value']:.6f}")

    # Where it ended up
    print("\nüèÜ Registered Model:")
    print(f"   Registry: {config.MODEL_NAME}")
    print(f"   Version: v{version_number}")
    print(f"   Alias: @{config.STAGING_ALIAS}")

    # What to run next
    print("\nüìå Next Steps:")
    for next_step in (
        "   1. Run UAT Staging Promotion",
        "   2. Run UAT Inference",
        "   3. If UAT passes ‚Üí Production Promotion",
    ):
        print(next_step)
    print(rule)

# =============================================================================
# 🎬 MAIN EXECUTION
# =============================================================================
def exit_notebook_friendly(code=0):
    """Terminate the process with ``code`` unless an IPython shell is active.

    When ``get_ipython()`` returns a shell this function simply returns None
    and the caller keeps running; only outside IPython does it call
    ``sys.exit(code)``.

    NOTE(review): under IPython a non-zero ``code`` is discarded, so a failed
    run does not surface as a failure through this call. Confirm that this
    is intended for scheduled jobs.
    """
    if get_ipython() is None:
        sys.exit(code)

def main():
    """Run the registration pipeline from start to finish.

    Steps: read the latest evaluation record; stop when it is missing or the
    model was not approved; register the model (or reuse the version already
    created from the same run); set the staging alias and tags; update the
    evaluation log; print the summary; publish ``model_type``,
    ``model_name`` and ``model_version`` as job task values.

    ``exit_notebook_friendly`` returns instead of exiting when an IPython
    shell is active, so each early-exit call is followed by an explicit
    ``return``. Without it execution would fall through into the later steps
    with missing data.
    """
    try:
        # Step 1: Read evaluation results
        eval_results = get_evaluation_results()
        if not eval_results:
            print("\n‚ùå No evaluation results found")
            exit_notebook_friendly(1)
            return

        # Step 2: Check if approved
        if not eval_results['should_register']:
            print("\n‚è≠Ô∏è REGISTRATION SKIPPED (model not approved)")
            print(f"   Reason: {eval_results['reason']}")
            exit_notebook_friendly(0)
            return

        # Step 3: Register model
        new_version = register_model(eval_results)
        if not new_version:
            print("\n‚ùå Registration failed")
            exit_notebook_friendly(1)
            return

        # Step 4: Set alias and tags
        set_staging_alias_and_tags(new_version.version, eval_results)

        # Step 5: Update logs
        update_evaluation_log(new_version.version, eval_results)

        # Step 6: Display summary
        display_summary(eval_results, new_version.version)

        # Save for workflow. Task values are optional: any failure here
        # (including `dbutils` not being defined) is treated as "not running
        # in a workflow". `except Exception` rather than a bare `except` so
        # KeyboardInterrupt/SystemExit are not swallowed.
        try:
            dbutils.jobs.taskValues.set(key="model_type", value=config.MODEL_TYPE)
            dbutils.jobs.taskValues.set(key="model_name", value=config.MODEL_NAME)
            dbutils.jobs.taskValues.set(key="model_version", value=new_version.version)
            print("‚úÖ Task values saved for workflow")
        except Exception:
            print("‚ÑπÔ∏è Not running in workflow - skipping task values")

        exit_notebook_friendly(0)

    except Exception as e:
        # sys.exit raises SystemExit, which is not an Exception subclass, so
        # deliberate exits above pass through this handler untouched
        print(f"\n‚ùå Registration pipeline failed: {e}")
        traceback.print_exc()
        exit_notebook_friendly(1)

# Execute the pipeline only when this file is run as the main module,
# not when it is imported.
if __name__ == "__main__":
    main()