In [None]:
# Databricks notebook source
# ==================================================================================
# 🚀 PRODUCTION PROMOTION SCRIPT — CLEAN VERSION (Single Model Architecture)
# ==================================================================================

import mlflow
from mlflow.tracking import MlflowClient
import os
import time
import sys

# Startup banner: the title line framed by two 80-character rules.
print("=" * 80, "üöÄ PRODUCTION PROMOTION STARTED", "=" * 80, sep="\n")

# ==================================================================================
# ✅ CONFIGURATION (Fixed model name — MUST match training + registration)
# ==================================================================================
# Three-part registered-model name: <catalog>.<schema>.<model>.
UC_CATALOG = "workspace"
UC_SCHEMA = "ml"
MODEL_NAME = ".".join((UC_CATALOG, UC_SCHEMA, "house_price_xgboost_uc"))

# Registry aliases this script looks up (staging) and assigns (production).
PRODUCTION_ALIAS, STAGING_ALIAS = "production", "staging"

# Run metric used to compare the staging and production versions.
METRIC_KEY = "test_rmse"
TOL = 0.000001  # improvements no larger than this are treated as a tie


# ==================================================================================
# ✅ MLflow Initialization (Unity Catalog)
# ==================================================================================
# Create the module-level registry client. When the Databricks runtime marker
# is present in the environment, the registry URI is switched to Unity Catalog
# before the client is built. Any failure is reported and then re-raised.
try:
    if os.environ.get("DATABRICKS_RUNTIME_VERSION") is not None:
        mlflow.set_registry_uri("databricks-uc")
        print("‚úÖ MLflow connected to Unity Catalog")
    client = MlflowClient()
except Exception as e:
    print(f"‚ùå MLflow client creation failed: {e}")
    raise


# ==================================================================================
# ✅ Helper: Wait until model version is READY
# ==================================================================================
def wait_until_ready(client, model_name, version, timeout=300, poll_interval=5):
    """Poll a model version until it is READY, fails registration, or times out.

    The status is always checked at least once, even when ``timeout`` is zero
    or negative, and the function never sleeps past the deadline.

    Args:
        client: Object exposing ``get_model_version(model_name, version)``
            whose result has a ``status`` attribute.
        model_name: Registered model name passed through to the client.
        version: Model version passed through to the client.
        timeout: Maximum number of seconds to keep polling.
        poll_interval: Seconds to wait between two status checks.

    Returns:
        True when the status is ``"READY"``; False when the status is
        ``"FAILED_REGISTRATION"`` or the deadline passes first.
    """
    # A monotonic clock keeps the deadline immune to wall-clock adjustments.
    deadline = time.monotonic() + timeout
    while True:
        mv = client.get_model_version(model_name, version)
        if mv.status == "READY":
            return True
        if mv.status == "FAILED_REGISTRATION":
            print("‚ùå Registration failed")
            return False

        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break
        # Cap the pause so the total wait does not overshoot the timeout.
        time.sleep(max(0.0, min(poll_interval, remaining)))

    print("‚è∞ Timeout: model not ready")
    return False


# ==================================================================================
# ✅ Helper: get metric of a run
# ==================================================================================
def get_metric(client, run_id, metric_key=None):
    """Return one metric value of a run, or None when it cannot be read.

    Args:
        client: Object exposing ``get_run(run_id)`` whose result carries a
            ``data.metrics`` mapping.
        run_id: Identifier of the run to look up.
        metric_key: Name of the metric to read; defaults to the module-level
            ``METRIC_KEY`` when omitted.

    Returns:
        The metric value, or None when the run has no such metric or the
        lookup raised an ordinary exception.
    """
    key = METRIC_KEY if metric_key is None else metric_key
    try:
        run = client.get_run(run_id)
        return run.data.metrics.get(key)
    except Exception as exc:
        # Still best-effort, but the failure is reported instead of hidden, and
        # KeyboardInterrupt / SystemExit are no longer swallowed.
        print(f"WARNING: could not read metric '{key}' for run {run_id}: {exc}")
        return None


# ==================================================================================
# ✅ STEP 1: Pick latest staging version
# ==================================================================================
def get_staging_version(client):
    """Return the highest-numbered version of MODEL_NAME carrying STAGING_ALIAS.

    Args:
        client: Registry client used to search and fetch model versions.

    Returns:
        The fetched model version with the largest integer ``version`` among
        those whose ``aliases`` contain STAGING_ALIAS, or None when no version
        carries the alias.
    """
    # Every search hit is fetched again individually and the aliases are read
    # from the fetched record (presumably because search results do not carry
    # aliases -- confirm against the registry documentation).
    tagged = [
        detail
        for detail in (
            client.get_model_version(MODEL_NAME, hit.version)
            for hit in client.search_model_versions(f"name='{MODEL_NAME}'")
        )
        if STAGING_ALIAS in detail.aliases
    ]

    if tagged:
        # Version identifiers are compared numerically, not as strings.
        staging_version = max(tagged, key=lambda candidate: int(candidate.version))
        print(f"‚úÖ Staging Version Found: v{staging_version.version}")
        return staging_version

    print("‚ùå No staging model found")
    return None


# ==================================================================================
# ✅ STEP 2: Pick current production version (if any)
# ==================================================================================
def get_prod_version(client):
    """Return the first version of MODEL_NAME carrying PRODUCTION_ALIAS.

    Args:
        client: Registry client used to search and fetch model versions.

    Returns:
        The first fetched model version (in search order) whose ``aliases``
        contain PRODUCTION_ALIAS, or None when no version carries the alias.
    """
    hits = client.search_model_versions(f"name='{MODEL_NAME}'")

    # Lazy pipeline: versions are fetched one at a time and fetching stops as
    # soon as one carrying the production alias is found.
    fetched = (client.get_model_version(MODEL_NAME, hit.version) for hit in hits)
    mv = next(
        (detail for detail in fetched if PRODUCTION_ALIAS in detail.aliases),
        None,
    )

    if mv is None:
        print("‚ÑπÔ∏è No production model exists yet")
        return None

    print(f"‚úÖ Current Production Version: v{mv.version}")
    return mv


# ==================================================================================
# ✅ STEP 3: Compare metrics (RMSE) and decide promotion
# ==================================================================================
def should_promote(new_rmse, old_rmse, tol=None):
    """Decide whether the staging model should replace the production model.

    Args:
        new_rmse: RMSE of the staging model, or None when it is unknown.
        old_rmse: RMSE of the production model, or None when there is nothing
            to compare against.
        tol: Minimum improvement required for promotion; defaults to the
            module-level ``TOL`` when omitted.

    Returns:
        True when ``old_rmse`` is None, or when ``new_rmse`` is lower than
        ``old_rmse`` by more than the tolerance. False otherwise, including
        when ``new_rmse`` is None while ``old_rmse`` is known.
    """
    if old_rmse is None:
        print("üü¢ No production model ‚Üí Promote Staging to Production")
        return True

    if new_rmse is None:
        # Comparing None with a float would raise TypeError; an unmeasured
        # staging model must not displace a measured production model.
        print("Staging RMSE is missing; cannot compare -> No promotion")
        return False

    margin = TOL if tol is None else tol

    print("\nüìä Metric Comparison")
    print(f"   New (Staging) RMSE: {new_rmse}")
    print(f"   Old (Production) RMSE: {old_rmse}")

    # Lower RMSE wins, but only when the gain exceeds the tolerance.
    if new_rmse < old_rmse - margin:
        print("üü¢ New staging model is better ‚Üí Promote")
        return True

    print("‚õî Staging model is NOT better ‚Üí No promotion")
    return False


# ==================================================================================
# ✅ STEP 4: Promote Staging → Production
# ==================================================================================
def promote_to_production(client, version):
    """Point PRODUCTION_ALIAS of MODEL_NAME at ``version`` once it is ready.

    Args:
        client: Registry client used to check readiness and set the alias.
        version: Model version to promote.

    Returns:
        True after the alias has been set; False when ``wait_until_ready``
        reports that the version did not become ready (the alias is then
        left untouched).
    """
    print(f"\n‚è≥ Waiting for v{version} to become READY...")
    became_ready = wait_until_ready(client, MODEL_NAME, version)

    if became_ready:
        client.set_registered_model_alias(
            name=MODEL_NAME,
            alias=PRODUCTION_ALIAS,
            version=version,
        )
        print(f"‚úÖ‚úÖ SUCCESS: Promoted Staging v{version} ‚Üí PRODUCTION")
        return True

    print("‚ùå Model not ready for promotion")
    return False


# ==================================================================================
# ✅ MAIN EXECUTION
# ==================================================================================
if __name__ == "__main__":
    # Promotion gate: find the staging version, compare its metric with the
    # current production version, and move the production alias only when the
    # staging model wins. Every failure path exits with a non-zero status so
    # the calling job can detect it.
    staging_mv = get_staging_version(client)
    if not staging_mv:
        sys.exit(1)

    prod_mv = get_prod_version(client)

    # Fetch metrics
    new_rmse = get_metric(client, staging_mv.run_id)
    if new_rmse is None:
        # get_metric returns None both when the metric is absent and when the
        # lookup failed. Without a value the comparison cannot run, so the
        # script refuses to promote an unevaluated model.
        print(
            f"Staging v{staging_mv.version} has no readable '{METRIC_KEY}' "
            "metric; aborting promotion"
        )
        sys.exit(1)

    old_rmse = get_metric(client, prod_mv.run_id) if prod_mv else None
    if prod_mv and old_rmse is None:
        # should_promote treats a missing production metric as "no production
        # model" and would promote unconditionally; fail closed instead.
        print(
            f"Production v{prod_mv.version} has no readable '{METRIC_KEY}' "
            "metric; aborting promotion"
        )
        sys.exit(1)

    # Decide
    if should_promote(new_rmse, old_rmse):
        # Surface a failed promotion through the exit status.
        if not promote_to_production(client, staging_mv.version):
            sys.exit(1)
    else:
        print("\n‚úÖ Production model remains unchanged.")
