In [None]:
# Databricks notebook source
import mlflow
from mlflow.tracking import MlflowClient
import os
import time
import sys
import traceback

# Title banner: a 70-char rule, the title, and another rule, one per line.
print("=" * 70, "PRODUCTION MODEL PROMOTION", "=" * 70, sep="\n")

# =============================================================================
# CONFIGURATION
# =============================================================================
# Registered model location, assembled as "<catalog>.<schema>.<model>".
UC_CATALOG_NAME = "workspace"
UC_SCHEMA_NAME = "ml"
MODEL_NAME = ".".join((UC_CATALOG_NAME, UC_SCHEMA_NAME, "house_price_model_uc"))

# Registry aliases: the candidate that is validated, and the alias it is moved to.
STAGING_ALIAS, PRODUCTION_ALIAS = "staging", "production"

# Production promotion gates.
MIN_R2_FOR_PROD = 0.80        # lowest acceptable test R²
MAX_RMSE_FOR_PROD = 50_000.0  # highest acceptable test RMSE

# =============================================================================
# MLflow Setup
# =============================================================================
# When a Databricks runtime is detected (DATABRICKS_RUNTIME_VERSION is set),
# the registry URI is switched to Unity Catalog ("databricks-uc"); otherwise
# MLflow's default registry is left in place. Any failure while configuring
# MLflow or creating the client ends the notebook with exit status 1.
try:
    if os.environ.get("DATABRICKS_RUNTIME_VERSION") is not None:
        mlflow.set_registry_uri("databricks-uc")
        print("✓ MLflow configured for Unity Catalog")

    client = MlflowClient()
    print("✓ MLflow client initialized")
except Exception as init_error:
    print(f"❌ Error initializing MLflow: {init_error}")
    sys.exit(1)

# =============================================================================
# VALIDATE STAGING MODEL BEFORE PROMOTION
# =============================================================================
# Finds the version carrying STAGING_ALIAS, reads the metrics logged on its
# source run, and applies the production quality gates. Any failure ends the
# notebook with exit status 1. Defines globals used by later sections:
# model_versions, staging_version, r2_score, rmse, cv_rmse.
print(f"\nValidating Staging model before promotion...")
print(f"Model: {MODEL_NAME}")

try:
    # Get all model versions
    model_versions = client.search_model_versions(f"name='{MODEL_NAME}'")
    if not model_versions:
        print(f"❌ Error: No versions found for model {MODEL_NAME}")
        sys.exit(1)

    # Aliases are read from a per-version get_model_version call rather than
    # from the search results. NOTE(review): presumably the search results do
    # not populate aliases for this registry — confirm before simplifying.
    staging_versions = []
    for v in model_versions:
        version_detail = client.get_model_version(MODEL_NAME, v.version)
        if STAGING_ALIAS in version_detail.aliases:
            staging_versions.append(version_detail)

    if not staging_versions:
        print(f"❌ Error: No model version found with alias '{STAGING_ALIAS}'")
        print("Please run UAT pipeline first to register and tag a staging model.")
        sys.exit(1)

    # Pick latest version numerically (int() so "10" sorts after "9").
    staging_version = max(staging_versions, key=lambda v: int(v.version))

    print(f"\n✓ Found Staging Model:")
    print(f"  Version: {staging_version.version}")
    print(f"  Status: {staging_version.status}")
    print(f"  Run ID: {staging_version.run_id}")
    print(f"  Aliases: {', '.join(staging_version.aliases)}")

    # Fetch metrics from MLflow run
    run = client.get_run(staging_version.run_id)
    metrics = run.data.metrics

    if not metrics:
        print(f"\n❌ No metrics found for run ID {staging_version.run_id}")
        sys.exit(1)

    # Missing metrics default to NaN rather than a plausible-looking number
    # (0 or inf): NaN fails every threshold comparison below and prints as
    # "nan", so an unlogged metric can neither pass a gate nor be mistaken for
    # a measured score in the report.
    missing_metrics = [
        name for name in ('test_r2_score', 'test_rmse') if name not in metrics
    ]
    r2_score = metrics.get('test_r2_score', float('nan'))
    rmse = metrics.get('test_rmse', float('nan'))
    cv_rmse = metrics.get('best_cv_rmse', float('nan'))  # reported only, not gated

    print(f"\nModel Training Metrics:")
    print(f"  Test R² Score: {r2_score:.4f}")
    print(f"  Test RMSE: ${rmse:,.2f}")
    print(f"  CV RMSE: ${cv_rmse:,.2f}")

    # Display training parameters
    print(f"\nModel Parameters:")
    for param in ['best_n_estimators', 'best_max_depth', 'best_min_samples_split', 'best_min_samples_leaf']:
        value = run.data.params.get(param, 'N/A')
        print(f"  {param}: {value}")

    # =============================================================================
    # QUALITY CHECKS
    # =============================================================================
    print(f"\nProduction Quality Checks:")

    all_passed = True
    if missing_metrics:
        print(f"  ✗ FAIL: Required metric(s) not logged on run: {', '.join(missing_metrics)}")
        all_passed = False

    # A NaN value (missing or non-finite metric) compares False and therefore
    # lands in the FAIL branch of each gate.
    if r2_score >= MIN_R2_FOR_PROD:
        print(f"  ✓ PASS: R² {r2_score:.4f} >= {MIN_R2_FOR_PROD}")
    else:
        print(f"  ✗ FAIL: R² {r2_score:.4f} < {MIN_R2_FOR_PROD}")
        all_passed = False

    if rmse <= MAX_RMSE_FOR_PROD:
        print(f"  ✓ PASS: RMSE ${rmse:,.2f} <= ${MAX_RMSE_FOR_PROD:,.2f}")
    else:
        print(f"  ✗ FAIL: RMSE ${rmse:,.2f} > ${MAX_RMSE_FOR_PROD:,.2f}")
        all_passed = False

    if not all_passed:
        print(f"\n{'=' * 70}")
        print("❌ PROMOTION DENIED - Model did not meet production quality thresholds")
        print(f"{'=' * 70}")
        sys.exit(1)

    print(f"\n✓ Quality validation PASSED - Model eligible for production.")

except Exception as e:
    # sys.exit() raises SystemExit, which is not an Exception subclass, so the
    # deliberate exits above pass straight through this handler.
    print(f"❌ Error during validation: {e}")
    traceback.print_exc()
    sys.exit(1)

# =============================================================================
# WAIT FOR MODEL TO BE READY
# =============================================================================
# Poll the staging version until the registry reports READY. The timeout is
# measured against a monotonic wall clock, so time spent inside the registry
# calls counts toward the limit. Counting only sleep intervals would let the
# real wait run well past max_wait_time. Exits with status 1 on registration
# failure, timeout, or a registry error.
print(f"\nChecking model readiness...")

max_wait_time = 300   # seconds of wall-clock time before giving up
check_interval = 5    # seconds between status polls
deadline = time.monotonic() + max_wait_time

while True:
    try:
        status = client.get_model_version(MODEL_NAME, staging_version.version).status
    except Exception as e:
        print(f"✗ Error checking model status: {e}")
        traceback.print_exc()
        sys.exit(1)

    if status == "READY":
        print(f"✓ Model version {staging_version.version} is READY")
        break
    if status == "FAILED_REGISTRATION":
        print(f"✗ Error: Model registration failed")
        sys.exit(1)
    if time.monotonic() >= deadline:
        print(f"✗ Timeout: Model not ready after {max_wait_time}s")
        sys.exit(1)

    print(f"  Status: {status} - waiting {check_interval}s...")
    time.sleep(check_interval)

# =============================================================================
# PROMOTE TO PRODUCTION
# =============================================================================
# Moves PRODUCTION_ALIAS onto the validated staging version, then reads the
# alias back. If the alias cannot be confirmed, the run fails with exit
# status 1 instead of continuing to the success exit, so an unverified
# promotion is never reported as successful.
print(f"\n{'=' * 70}")
print("PROMOTING MODEL TO PRODUCTION")
print(f"{'=' * 70}")

try:
    # Find the version currently holding the production alias (if any) by
    # reading each version's aliases individually.
    prod_versions = []
    for v in model_versions:
        version_detail = client.get_model_version(MODEL_NAME, v.version)
        if PRODUCTION_ALIAS in version_detail.aliases:
            prod_versions.append(version_detail)

    old_prod_version = (
        max(prod_versions, key=lambda v: int(v.version)) if prod_versions else None
    )
    if old_prod_version is not None:
        print(f"\nℹ Previous production version found: v{old_prod_version.version}")
        # int() on both sides so "3" and 3 compare equal.
        if int(old_prod_version.version) == int(staging_version.version):
            print(f"ℹ Version {staging_version.version} already holds the "
                  f"'{PRODUCTION_ALIAS}' alias; re-applying it.")

    # Assign Production alias to the staging version
    print(f"\nSetting '{PRODUCTION_ALIAS}' alias on version {staging_version.version}...")
    client.set_registered_model_alias(MODEL_NAME, PRODUCTION_ALIAS, staging_version.version)
    print("✓ Production alias set successfully")

    # Read the alias back, retrying a few times in case the read lags the
    # write, rather than trusting a single fixed delay.
    verify_attempts = 5
    verify_interval = 2  # seconds
    updated_version = None
    for _ in range(verify_attempts):
        time.sleep(verify_interval)
        candidate = client.get_model_version(MODEL_NAME, staging_version.version)
        if PRODUCTION_ALIAS in candidate.aliases:
            updated_version = candidate
            break

    if updated_version is None:
        print(f"✗ Error: '{PRODUCTION_ALIAS}' alias not confirmed on version "
              f"{staging_version.version} after {verify_attempts} checks. "
              f"Please verify in MLflow UI.")
        sys.exit(1)

    print(f"\n{'=' * 70}")
    print("✅ PROMOTION SUCCESSFUL")
    print(f"{'=' * 70}")
    print(f"\nModel Details:")
    print(f"  Version: {staging_version.version}")
    print(f"  Aliases: {', '.join(updated_version.aliases)}")
    print(f"  Status: PRODUCTION READY")
    print(f"\nModel Performance:")
    print(f"  R² Score: {r2_score:.4f}")
    print(f"  RMSE: ${rmse:,.2f}")
    print(f"  CV RMSE: ${cv_rmse:,.2f}")
    print(f"\nModel URI:")
    print(f"  models:/{MODEL_NAME}@{PRODUCTION_ALIAS}")

    if old_prod_version is not None:
        print(f"\nVersion History:")
        print(f"  Previous Production: v{old_prod_version.version}")
        print(f"  New Production: v{staging_version.version}")
    print(f"{'=' * 70}")

except Exception as e:
    # SystemExit from sys.exit() above is not an Exception, so it propagates.
    print(f"\n✗ Error during promotion: {e}")
    traceback.print_exc()
    sys.exit(1)

# =============================================================================
# FINAL MESSAGE
# =============================================================================
print(f"\n{'=' * 70}")
print("NEXT STEPS:")
print(f"{'=' * 70}")
print("1. Create/Update Serving Endpoint with Production model")
print("2. Run production inference tests")
print("3. Monitor model performance in production")
print("4. Set up alerts for model drift")
print(f"{'=' * 70}")

# Exit cleanly for Databricks pipeline. `dbutils` only exists inside a
# Databricks notebook, so outside Databricks the NameError is the expected
# case and is ignored. Anything else must propagate; a bare `except:` would
# also swallow SystemExit/KeyboardInterrupt. NOTE(review):
# dbutils.notebook.exit is understood to end the run by raising internally,
# which a bare `except:` would cancel — confirm.
try:
    dbutils.notebook.exit("PROMOTION_SUCCESS")
except NameError:
    pass