In [None]:
# Databricks notebook source
# =============================================================================
# 🚀 UAT STAGING PROMOTION - CONFIG DRIVEN (FIXED)
# =============================================================================
# Purpose: Promote latest registered model to Staging alias
# Now reads from pipeline_config.yml - No hardcoding!
# Prerequisites: Run Model_Registration script first
# =============================================================================

import mlflow
from mlflow.tracking import MlflowClient
import time
import os
import yaml
import sys
import traceback

print("=" * 80)
print("üöÄ UAT STAGING PROMOTION (CONFIG-DRIVEN)")
print("=" * 80)

# =============================================================================
# ✅ LOAD PIPELINE CONFIGURATION
# =============================================================================
# All run-specific values (model identity, staging alias, metric, tolerance)
# come from pipeline_config.yml. A missing file, an empty file, or a missing
# key aborts the job with exit code 1.
print("\nüìã Loading pipeline configuration from pipeline_config.yml...")

try:
    # Explicit UTF-8 so parsing does not depend on the cluster's locale.
    with open("pipeline_config.yml", "r", encoding="utf-8") as f:
        pipeline_cfg = yaml.safe_load(f)

    # safe_load returns None for an empty file; report that clearly instead of
    # failing later with "'NoneType' object is not subscriptable".
    if not isinstance(pipeline_cfg, dict):
        raise ValueError("pipeline_config.yml is empty or is not a YAML mapping")

    # Extract configuration
    MODEL_TYPE = pipeline_cfg["model"]["type"]
    UC_CATALOG = pipeline_cfg["model"]["catalog"]
    UC_SCHEMA = pipeline_cfg["model"]["schema"]
    BASE_NAME = pipeline_cfg["model"]["base_name"]

    # Auto-generate the three-level Unity Catalog model name.
    MODEL_NAME = f"{UC_CATALOG}.{UC_SCHEMA}.{BASE_NAME}_{MODEL_TYPE}_uc2"

    STAGING_ALIAS = pipeline_cfg["aliases"]["staging"]
    METRIC_KEY = pipeline_cfg["metrics"]["primary_metric"]

    # Float tolerance for treating two metric values as equal. The optional
    # metrics.tolerance key overrides the historical default of 1e-6.
    TOL = float(pipeline_cfg["metrics"].get("tolerance", 1e-6))

    print("‚úÖ Configuration loaded successfully!")
    print("\nüìä Configuration Details:")
    print(f"   Model Type: {MODEL_TYPE.upper()}")
    print(f"   Model Name: {MODEL_NAME}")
    print(f"   Staging Alias: @{STAGING_ALIAS}")
    print(f"   Metric Key: {METRIC_KEY}")

except FileNotFoundError:
    print("‚ùå ERROR: pipeline_config.yml not found!")
    print("üí° Please create pipeline_config.yml in the same directory")
    sys.exit(1)
except KeyError as e:
    # Name the missing key explicitly; the bare KeyError text is just "'model'".
    print(f"‚ùå ERROR: missing required key in pipeline_config.yml: {e}")
    traceback.print_exc()
    sys.exit(1)
except Exception as e:
    print(f"‚ùå ERROR loading configuration: {e}")
    traceback.print_exc()
    sys.exit(1)

print("=" * 80)

# =============================================================================
# ✅ INITIALIZE MLFLOW
# =============================================================================
# On a Databricks cluster (detected via DATABRICKS_RUNTIME_VERSION) the model
# registry is switched to Unity Catalog ("databricks-uc"); elsewhere MLflow's
# default registry URI is left as-is. Any failure is reported and re-raised.
try:
    on_databricks = "DATABRICKS_RUNTIME_VERSION" in os.environ
    if on_databricks:
        mlflow.set_registry_uri("databricks-uc")
        print("\n‚úÖ Using Unity Catalog Registry")
    client = MlflowClient()
except Exception as init_error:
    print(f"‚ùå Failed to initialize MLflow client: {init_error}")
    raise init_error

# =============================================================================
# ✅ HELPER: WAIT UNTIL MODEL VERSION IS READY
# =============================================================================
def wait_until_ready(client, model_name, version, timeout=300, poll_interval=5):
    """Poll the model registry until a model version reaches READY status.

    Args:
        client: MLflow client exposing ``get_model_version(name, version)``.
        model_name: Registered model name.
        version: Model version to check.
        timeout: Maximum number of seconds to keep polling.
        poll_interval: Seconds to sleep between polls (default 5).

    Returns:
        True once the version reports ``READY``. False if it reports
        ``FAILED_REGISTRATION``, or if it is still not READY when ``timeout``
        seconds have elapsed.
    """
    # monotonic() is unaffected by wall-clock adjustments (NTP sync, manual
    # changes) that could otherwise shorten or stretch the wait.
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        status = client.get_model_version(model_name, version).status
        print(f"‚è≥ Model v{version} status = {status}")

        if status == "READY":
            return True
        if status == "FAILED_REGISTRATION":
            print("‚ùå Model registration failed.")
            return False

        # Never sleep past the deadline; the loop condition then ends the wait.
        remaining = deadline - time.monotonic()
        if remaining > 0:
            time.sleep(min(poll_interval, remaining))

    print("‚è∞ Timeout: Model is still not READY")
    return False

# =============================================================================
# ✅ HELPER: GET METRIC FROM RUN
# =============================================================================
def get_metric_from_run(client, model_name, version, run_id, metric_key=None):
    """Return the evaluation metric recorded for a registered model version.

    Lookup order:
      1. ``metric_key`` in the metrics of the MLflow run that produced the version.
      2. The ``metric_<metric_key>`` tag on the model version itself
         (e.g. ``metric_rmse`` when the metric key is ``rmse``).

    Args:
        client: MLflow client providing ``get_run`` and ``get_model_version``.
        model_name: Registered model name.
        version: Model version number.
        run_id: Source run id of the version. When it is None/empty, the run
            lookup is skipped.
        metric_key: Metric name to read. Defaults to the configured METRIC_KEY.

    Returns:
        The metric as a float, or None when no finite value is found. Lookup
        errors are printed and treated as "not found" rather than raised.
    """
    import math

    if metric_key is None:
        metric_key = METRIC_KEY

    def _usable(value):
        # NaN/inf would make every later comparison meaningless (NaN compares
        # False to everything), so treat them the same as a missing metric.
        return value is not None and math.isfinite(value)

    # Method 1: metrics logged on the source run.
    if run_id:
        try:
            value = client.get_run(run_id).data.metrics.get(metric_key)
            if _usable(value):
                print(f"  ‚úì Metric found in run metrics: {value:.6f}")
                return float(value)
        except Exception as e:
            print(f"  ‚ö† Could not fetch run {run_id}: {e}")

    # Method 2: tag on the model version. The tag name follows the configured
    # metric, so a non-RMSE config never picks up an RMSE value.
    tag_key = f"metric_{metric_key}"
    try:
        tag_value = client.get_model_version(model_name, version).tags.get(tag_key)
        if tag_value:
            value = float(tag_value)
            if _usable(value):
                print(f"  ‚úì Metric found in model tags: {value:.6f}")
                return value
    except Exception as e:
        print(f"  ‚ö† Could not fetch metric from tags: {e}")

    print(f"  ‚ö† No metric found for version {version}")
    return None

# =============================================================================
# ✅ STEP 1: FIND LATEST MODEL VERSION
# =============================================================================
# Selects the highest-numbered registered version and fetches its metric.
# The job exits with code 1 if the model has no versions or the metric for
# the newest version cannot be found.
print("\n" + "=" * 70)
print("üìã STEP 1: Finding Latest Model Version")
print("=" * 70)

try:
    model_versions = client.search_model_versions(f"name='{MODEL_NAME}'")

    if not model_versions:
        print(f"‚ùå No versions found for model: {MODEL_NAME}")
        print("\nüí° Please run Model_Registration script first")
        sys.exit(1)

    # Compare version numbers as integers so that "10" outranks "9".
    latest_version = max(model_versions, key=lambda mv: int(mv.version))
    new_version, new_run_id = latest_version.version, latest_version.run_id

    print(f"\n‚úÖ Latest Registered Model Version: v{new_version}")
    print(f"   Run ID: {new_run_id}")
    print(f"   Status: {latest_version.status}")

    print(f"\nüîç Fetching metric for new version v{new_version}...")
    new_metric = get_metric_from_run(client, MODEL_NAME, new_version, new_run_id)

    if new_metric is None:
        print(f"‚ùå ERROR: Could not find {METRIC_KEY} for new version v{new_version}")
        print("   This version cannot be evaluated. Exiting.")
        sys.exit(1)

    print(f"‚úÖ New Model {METRIC_KEY}: {new_metric:.6f}")

except Exception as e:
    # sys.exit() raises SystemExit, which is not an Exception subclass, so the
    # deliberate exits above pass straight through this handler.
    print(f"‚ùå Error finding latest version: {e}")
    traceback.print_exc()
    sys.exit(1)

# =============================================================================
# ✅ STEP 2: FIND EXISTING STAGING MODEL
# =============================================================================
# Resolves the model version currently holding the staging alias and its
# metric. On failure, staging_version and old_metric are both set to None.
print("\n" + "=" * 70)
print("üìã STEP 2: Checking Current STAGING Model")
print("=" * 70)

try:
    staging_version = client.get_model_version_by_alias(MODEL_NAME, STAGING_ALIAS)
    old_version, old_run_id = staging_version.version, staging_version.run_id

    print(f"\nüìå Current STAGING Version: v{old_version}")
    print(f"   Run ID: {old_run_id}")

    print(f"\nüîç Fetching metric for staging version v{old_version}...")
    old_metric = get_metric_from_run(client, MODEL_NAME, old_version, old_run_id)

    if old_metric is None:
        print(f"‚ö†Ô∏è WARNING: Could not find {METRIC_KEY} for staging v{old_version}")
        print("   Will promote new model by default.")
    else:
        print(f"üìå Current STAGING {METRIC_KEY}: {old_metric:.6f}")

except Exception as e:
    # NOTE(review): every failure here, not only a missing alias, is treated
    # as "no staging model". A transient registry/auth error therefore leads
    # to promotion in the next step. Consider narrowing this to the
    # registry's not-found error.
    print(f"\n‚ÑπÔ∏è No current STAGING model found: {e}")
    print("   Will promote latest model to staging.")
    staging_version = None
    old_metric = None

# =============================================================================
# ✅ STEP 3: COMPARE METRICS (LOWER IS BETTER BY DEFAULT, e.g. RMSE)
# =============================================================================
# Decides whether the latest version should replace the current staging one.
# The metric is treated as lower-is-better unless the config sets
# metrics.greater_is_better: true (e.g. for R2 or accuracy).
print(f"\n{'='*70}")
print("üìã STEP 3: Metric Comparison")
print(f"{'='*70}")

promote = False
promotion_reason = ""

# Optional config flag; defaults to the original lower-is-better rule.
GREATER_IS_BETTER = bool(pipeline_cfg.get("metrics", {}).get("greater_is_better", False))

if staging_version is None:
    promote = True
    promotion_reason = "No existing staging model"
    print(f"\nüü¢ DECISION: PROMOTE")
    print(f"   Reason: {promotion_reason}")

elif old_metric is None:
    promote = True
    promotion_reason = "Staging model has no metric (old version)"
    print(f"\nüü¢ DECISION: PROMOTE")
    print(f"   Reason: {promotion_reason}")

elif str(new_version) == str(old_version):
    # The newest registered version already holds the staging alias, so a
    # metric comparison would only compare the version with itself.
    print(f"\nüü° DECISION: NO PROMOTION")
    print(f"   Reason: Latest version v{new_version} is already the staging version")
    print(f"   Keeping existing staging version v{old_version}")

else:
    # Signed gain: positive means the new model is better, whichever
    # direction the metric runs.
    gain = (new_metric - old_metric) if GREATER_IS_BETTER else (old_metric - new_metric)

    print(f"\nüìä Metric Comparison:")
    print(f"   New Model (v{new_version}):     {METRIC_KEY} = {new_metric:.6f}")
    print(f"   Staging Model (v{old_version}): {METRIC_KEY} = {old_metric:.6f}")
    print(f"   Improvement:                     {gain:.6f}")

    if gain > TOL:
        promote = True
        if old_metric != 0:
            # abs() keeps the percentage positive for a negative baseline.
            improvement_pct = gain / abs(old_metric) * 100
            promotion_reason = f"New model is better (improvement: {improvement_pct:.2f}%)"
        else:
            # A relative improvement over a zero baseline is undefined.
            promotion_reason = f"New model is better (improvement: {gain:.6f})"
        print(f"\nüü¢ DECISION: PROMOTE")
        print(f"   Reason: {promotion_reason}")

    elif abs(gain) <= TOL:
        print(f"\nüü° DECISION: NO PROMOTION")
        print(f"   Reason: New model performance is same as staging (within tolerance)")
        print(f"   Keeping existing staging version v{old_version}")

    else:
        print(f"\n‚õî DECISION: NO PROMOTION")
        print(f"   Reason: New model is WORSE than staging")
        print(f"   Degradation: {-gain:.6f}")
        print(f"   Keeping existing staging version v{old_version}")

# =============================================================================
# ✅ STEP 4: PROMOTE USING ALIAS
# =============================================================================


def _publish_task_values(version, metric):
    """Best-effort export of the staging version and metric as job task values.

    ``dbutils`` is not imported in this file; it is expected to be injected by
    the Databricks notebook runtime. Outside a workflow the call fails (e.g.
    with NameError), and the failure is reported and ignored so it never
    aborts the promotion.

    Returns:
        True if both task values were set, False otherwise.
    """
    try:
        dbutils.jobs.taskValues.set(key="staging_version", value=version)
        dbutils.jobs.taskValues.set(key="staging_metric", value=metric)
    except Exception:
        # Catch Exception rather than everything, so KeyboardInterrupt and
        # SystemExit still propagate.
        print("‚ÑπÔ∏è Not running in workflow - skipping task values")
        return False
    return True


if promote:
    print(f"\n{'='*70}")
    print("üìã STEP 4: Promoting Model to STAGING")
    print(f"{'='*70}")

    print(f"\n‚è≥ Waiting for model v{new_version} to become READY...")
    if not wait_until_ready(client, MODEL_NAME, new_version):
        print("\n‚ùå Promotion failed: Model did not become READY in time.")
        sys.exit(1)

    try:
        # Moving the alias is the only registry mutation in this script.
        client.set_registered_model_alias(
            name=MODEL_NAME,
            alias=STAGING_ALIAS,
            version=new_version,
        )
    except Exception as e:
        print(f"\n‚ùå Failed to set alias: {e}")
        traceback.print_exc()
        sys.exit(1)

    print(f"\n{'='*70}")
    print("‚úÖ‚úÖ PROMOTION SUCCESSFUL ‚úÖ‚úÖ")
    print(f"{'='*70}")
    print(f"   Model: {MODEL_NAME}")
    print(f"   Model Type: {MODEL_TYPE.upper()}")
    print(f"   New STAGING Version: v{new_version}")
    print(f"   {METRIC_KEY}: {new_metric:.6f}")
    print(f"   Reason: {promotion_reason}")
    print(f"{'='*70}\n")

    # Save for workflow
    if _publish_task_values(new_version, new_metric):
        print("‚úÖ Task values saved for workflow")

else:
    print(f"\n{'='*70}")
    print("‚úÖ STAGING UNCHANGED")
    print(f"{'='*70}")
    if staging_version is not None:
        print(f"   Current STAGING Version: v{old_version}")
        # "is not None" so a legitimate 0.0 metric is still reported.
        if old_metric is not None:
            print(f"   {METRIC_KEY}: {old_metric:.6f}")
    print(f"{'='*70}\n")

    # Save for workflow even if no promotion; a 0.0 metric is kept (not None).
    _publish_task_values(
        old_version if staging_version is not None else None,
        old_metric,
    )

print("\nüìå Next Step:")
print("   Run 05_uat_inference.py to validate the staging model")
print("=" * 80)