In [None]:
# Databricks notebook source
import os
import sys
import time

import mlflow
from mlflow.exceptions import MlflowException
from mlflow.tracking import MlflowClient

# Job banner: a 70-character rule, the title, and a closing rule.
print("=" * 70, "PRODUCTION MODEL PROMOTION", "=" * 70, sep="\n")

# =============================================================================
# CONFIGURATION
# =============================================================================
# Unity Catalog location of the registered model: <catalog>.<schema>.<model>.
UC_CATALOG_NAME = "workspace"
UC_SCHEMA_NAME = "ml"
MODEL_NAME = ".".join((UC_CATALOG_NAME, UC_SCHEMA_NAME, "house_price_model_uc"))

# Registry aliases: the version carrying STAGING_ALIAS is the promotion
# candidate; PRODUCTION_ALIAS is pointed at it when the quality gate passes.
STAGING_ALIAS, PRODUCTION_ALIAS = "staging", "production"

# Production quality gate -- both conditions must hold.
MIN_R2_FOR_PROD = 0.80        # minimum acceptable R² score
MAX_RMSE_FOR_PROD = 50_000.0  # maximum acceptable RMSE

# =============================================================================
# MLflow Setup
# =============================================================================
# Create the registry client. On a Databricks cluster (detected through the
# DATABRICKS_RUNTIME_VERSION environment variable) the registry is switched to
# Unity Catalog first; elsewhere MLflow's default registry URI is kept.
# Any failure aborts the script with exit status 1.
try:
    on_databricks = "DATABRICKS_RUNTIME_VERSION" in os.environ
    if on_databricks:
        mlflow.set_registry_uri("databricks-uc")
        print("MLflow configured for Unity Catalog")

    client = MlflowClient()
    print("MLflow client initialized")
except Exception as init_error:
    print(f"Error initializing MLflow: {init_error}")
    sys.exit(1)

# =============================================================================
# VALIDATE STAGING MODEL BEFORE PROMOTION
# =============================================================================
# Locate the version carrying STAGING_ALIAS, print its run metrics and params,
# and apply the production quality gate (MIN_R2_FOR_PROD / MAX_RMSE_FOR_PROD).
# Exits with status 1 when the model/alias cannot be found, when a gate metric
# was not logged, or when the gate fails.
# Defines for later steps: model_versions, staging_version, r2_score, rmse, cv_rmse.
print("\nValidating Staging model before promotion...")
print(f"Model: {MODEL_NAME}")

try:
    # Fail fast when the model has no registered versions at all.
    model_versions = client.search_model_versions(f"name='{MODEL_NAME}'")

    if not model_versions:
        print(f"Error: No versions found for model {MODEL_NAME}")
        sys.exit(1)

    # Resolve the staging alias with one registry lookup rather than fetching
    # every version's details and scanning their aliases.
    try:
        aliased_version = client.get_model_version_by_alias(MODEL_NAME, STAGING_ALIAS)
    except MlflowException as alias_error:
        print(f"Error: No model version found with '{STAGING_ALIAS}' alias")
        print(f"  Registry response: {alias_error}")
        print("Please run UAT pipeline first to set Staging alias")
        sys.exit(1)

    # Full version record (status, run_id, aliases) for reporting and later steps.
    staging_version = client.get_model_version(MODEL_NAME, aliased_version.version)

    print("\nFound Staging Model:")
    print(f"  Version: {staging_version.version}")
    print(f"  Status: {staging_version.status}")
    print(f"  Run ID: {staging_version.run_id}")
    print(f"  Aliases: {', '.join(staging_version.aliases)}")

    run = client.get_run(staging_version.run_id)
    metrics = run.data.metrics

    # Metric keys logged by the UAT training script. A missing gate metric
    # falls back to a value that fails the gate (fail closed), and is reported
    # as missing instead of being mistaken for a poor score.
    missing_metrics = [key for key in ("test_r2_score", "test_rmse") if key not in metrics]
    r2_score = metrics.get("test_r2_score", 0.0)
    rmse = metrics.get("test_rmse", float("inf"))
    # Informational only; NaN (not 0.0) when absent so it never reads as perfect.
    cv_rmse = metrics.get("best_cv_rmse", float("nan"))

    def _metric_text(key, spec, prefix=""):
        """Return prefix + metrics[key] formatted with ``spec``, or "N/A" if not logged."""
        return prefix + format(metrics[key], spec) if key in metrics else "N/A"

    print("\nModel Training Metrics:")
    print(f"  Test R² Score: {_metric_text('test_r2_score', '.4f')}")
    print(f"  Test RMSE: {_metric_text('test_rmse', ',.2f')}")
    print(f"  CV RMSE: {_metric_text('best_cv_rmse', ',.2f')}")

    # Hyperparameters recorded on the training run (display only).
    print("\nModel Parameters:")
    params_to_show = ['best_n_estimators', 'best_max_depth', 'best_min_samples_split', 'best_min_samples_leaf']
    for param in params_to_show:
        value = run.data.params.get(param, 'N/A')
        print(f"  {param}: {value}")

    print("\nProduction Quality Checks:")

    quality_checks = [
        f"  ✗ FAIL: metric '{key}' was not logged on run {staging_version.run_id}"
        for key in missing_metrics
    ]
    all_passed = not missing_metrics

    # Threshold checks run only for metrics that were actually logged.
    # A NaN metric compares False on both checks and therefore fails.
    if "test_r2_score" in metrics:
        if r2_score >= MIN_R2_FOR_PROD:
            quality_checks.append(f"  ✓ PASS: R² {r2_score:.4f} >= {MIN_R2_FOR_PROD}")
        else:
            quality_checks.append(f"  ✗ FAIL: R² {r2_score:.4f} < {MIN_R2_FOR_PROD}")
            all_passed = False

    if "test_rmse" in metrics:
        if rmse <= MAX_RMSE_FOR_PROD:
            quality_checks.append(f"  ✓ PASS: RMSE ${rmse:,.2f} <= {MAX_RMSE_FOR_PROD:,.2f}")
        else:
            quality_checks.append(f"  ✗ FAIL: RMSE ${rmse:,.2f} > {MAX_RMSE_FOR_PROD:,.2f}")
            all_passed = False

    for check in quality_checks:
        print(check)

    if not all_passed:
        print(f"\n{'=' * 70}")
        print("PROMOTION DENIED")
        print(f"{'=' * 70}")
        print("Model does not meet production quality standards")
        print("\nRequired:")
        print(f"  - R² Score >= {MIN_R2_FOR_PROD}")
        print(f"  - RMSE <= ${MAX_RMSE_FOR_PROD:,.2f}")
        print("\nActual:")
        print(f"  - R² Score: {_metric_text('test_r2_score', '.4f')}")
        print(f"  - RMSE: {_metric_text('test_rmse', ',.2f', prefix='$')}")
        if missing_metrics:
            print(f"\nMissing metrics on run {staging_version.run_id}: {', '.join(missing_metrics)}")
            print("Please check that the UAT training run logged them")
        else:
            print("\nPlease retrain model with better parameters")
        print(f"{'=' * 70}")
        sys.exit(1)

    print("\n✓ Quality validation PASSED - Model eligible for production")

except Exception as e:
    # sys.exit raises SystemExit (not an Exception), so the deliberate exits
    # above pass through; only unexpected errors land here.
    print(f"Error during validation: {e}")
    import traceback
    traceback.print_exc()
    sys.exit(1)

# =============================================================================
# WAIT FOR MODEL TO BE READY
# =============================================================================
# Poll the staging version until the registry reports it READY.
# Exits with status 1 on FAILED_REGISTRATION, or once max_wait_time seconds
# have elapsed without reaching READY.
print("\nChecking model readiness...")

max_wait_time = 300   # seconds before giving up
check_interval = 5    # seconds between status polls

# Elapsed time comes from a monotonic clock, so time spent in registry calls
# counts toward the timeout. The status is always re-checked after the last
# sleep before a timeout is declared.
_wait_start = time.monotonic()
while True:
    current_version = client.get_model_version(MODEL_NAME, staging_version.version)
    status = current_version.status
    elapsed_time = time.monotonic() - _wait_start

    if status == "READY":
        print(f"✓ Model version {staging_version.version} is READY")
        break
    if status == "FAILED_REGISTRATION":
        print("✗ Error: Model registration failed")
        sys.exit(1)
    if elapsed_time >= max_wait_time:
        print(f"✗ Timeout: Model not ready after {max_wait_time}s")
        sys.exit(1)

    print(f"  Status: {status} - waiting {check_interval}s...")
    time.sleep(check_interval)

# =============================================================================
# PROMOTE TO PRODUCTION
# =============================================================================
# Point PRODUCTION_ALIAS at the validated staging version, confirm the alias
# is visible on that version, and print a promotion summary.
# Exits with status 1 if the alias cannot be set or cannot be verified, so the
# pipeline never reports success for an unconfirmed promotion.
print(f"\n{'=' * 70}")
print("PROMOTING MODEL TO PRODUCTION")
print(f"{'=' * 70}")

try:
    # Find the version currently holding the production alias (if any) with a
    # single registry lookup. A lookup error is treated as "no previous
    # production version": the result is only reported, and a genuine registry
    # problem still surfaces at set_registered_model_alias below.
    old_prod_version = None
    try:
        current_prod = client.get_model_version_by_alias(MODEL_NAME, PRODUCTION_ALIAS)
    except MlflowException:
        current_prod = None

    if current_prod is not None and str(current_prod.version) != str(staging_version.version):
        old_prod_version = current_prod.version
        print(f"\nℹ Previous production version found: {old_prod_version}")
        print(f"  Will move '{PRODUCTION_ALIAS}' alias to version {staging_version.version}")

    print(f"\nSetting '{PRODUCTION_ALIAS}' alias on version {staging_version.version}...")

    client.set_registered_model_alias(
        MODEL_NAME,
        PRODUCTION_ALIAS,
        staging_version.version,
    )

    print("✓ Production alias set successfully")

    # Verify the alias is visible on the version. Re-read a few times with a
    # short wait to allow for registry consistency lag; if it never appears,
    # fail the job instead of continuing to report PROMOTION_SUCCESS.
    updated_version = None
    for _attempt in range(5):
        time.sleep(2)
        updated_version = client.get_model_version(MODEL_NAME, staging_version.version)
        if PRODUCTION_ALIAS in updated_version.aliases:
            break
    else:
        print("⚠ Warning: Alias may not have been set properly")
        print("Please verify in MLflow UI")
        print("✗ Promotion could not be verified - failing pipeline")
        sys.exit(1)

    print(f"\n{'=' * 70}")
    print("✅ PROMOTION SUCCESSFUL")
    print(f"{'=' * 70}")
    print("\nModel Details:")
    print(f"  Version: {staging_version.version}")
    print(f"  Aliases: {', '.join(updated_version.aliases)}")
    print("  Status: PRODUCTION READY")
    print("\nModel Performance:")
    print(f"  R² Score: {r2_score:.4f}")
    print(f"  RMSE: ${rmse:,.2f}")
    print(f"  CV RMSE: ${cv_rmse:,.2f}")
    print("\nModel URI:")
    print(f"  models:/{MODEL_NAME}@{PRODUCTION_ALIAS}")

    if old_prod_version is not None:
        print("\nVersion History:")
        print(f"  Previous Production: v{old_prod_version}")
        print(f"  New Production: v{staging_version.version}")

    print(f"{'=' * 70}")

    # Promotion metadata (printed only; not written to the registry).
    print("\nPromotion Metadata:")
    print("  Promoted By: Automated Pipeline")
    print(f"  Promotion Time: {time.strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"  Source Alias: {STAGING_ALIAS}")
    print(f"  Target Alias: {PRODUCTION_ALIAS}")
    print(f"  Model Quality: R²={r2_score:.4f}, RMSE=${rmse:,.2f}")
    print(f"  Run ID: {staging_version.run_id}")

except Exception as e:
    # SystemExit from the verification failure above is not an Exception and
    # passes through unchanged.
    print(f"\n✗ Error during promotion: {e}")
    import traceback
    traceback.print_exc()
    sys.exit(1)

# =============================================================================
# FINAL MESSAGE
# =============================================================================
# Closing checklist for the operator.
print(f"\n{'=' * 70}")
print("NEXT STEPS:")
print(f"{'=' * 70}")
print("1. Create/Update Serving Endpoint with Production model")
print("2. Run production inference tests")
print("3. Monitor model performance in production")
print("4. Set up alerts for model drift")
print(f"{'=' * 70}")

# Report success to the calling job via the notebook exit value. `dbutils`
# only exists inside Databricks notebooks; elsewhere the name is undefined and
# the call is skipped. Only NameError is suppressed: a bare `except:` would
# also swallow SystemExit/KeyboardInterrupt and anything
# dbutils.notebook.exit may raise to stop the notebook.
try:
    dbutils.notebook.exit("PROMOTION_SUCCESS")
except NameError:
    pass
