In [None]:
# Databricks notebook source
import mlflow
from mlflow.tracking import MlflowClient
import time
import os

print("üöÄ UAT Staging Promotion Started...")

# =======================================================
# ‚úÖ MLflow Registry (Unity Catalog)
# =======================================================
try:
    if "DATABRICKS_RUNTIME_VERSION" in os.environ:
        mlflow.set_registry_uri("databricks-uc")
        print("‚úÖ Using Unity Catalog Registry")
    client = MlflowClient()
except Exception as e:
    print(f"‚ùå Failed to initialize MLflow client: {e}")
    raise e


# =======================================================
# Registry coordinates — must match the names used at training and
# registration time.
# =======================================================
UC_CATALOG, UC_SCHEMA = "workspace", "ml"
MODEL_NAME = ".".join((UC_CATALOG, UC_SCHEMA, "house_price_xgboost_uc2"))

STAGING_ALIAS = "Staging"  # registry alias that marks the staging version
METRIC_KEY = "test_rmse"   # metric compared between versions (lower is better)
TOL = 1e-6                 # absolute tolerance when comparing float metrics

# =======================================================
# ✅ Helper: Wait until model version is READY
# =======================================================
def wait_until_ready(client, model_name, version, timeout=300, poll_interval=5):
    """Poll the registry until a model version finishes registration.

    Args:
        client: Registry client exposing ``get_model_version(name, version)``
            whose result has a ``status`` attribute (e.g. ``MlflowClient``).
        model_name: Registered model name.
        version: Model version to poll.
        timeout: Maximum number of seconds to wait.
        poll_interval: Seconds to sleep between status checks.

    Returns:
        True once the version reports ``READY``; False if it reports
        ``FAILED_REGISTRATION`` or ``timeout`` elapses first.
    """
    # monotonic() is unaffected by wall-clock adjustments, which could
    # otherwise shorten or stretch the wait.
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        status = client.get_model_version(model_name, version).status
        print(f"⏳ Model v{version} status = {status}")

        if status == "READY":
            return True
        if status == "FAILED_REGISTRATION":
            print("❌ Model registration failed.")
            return False

        # Never sleep past the deadline.
        time.sleep(max(0.0, min(poll_interval, deadline - time.monotonic())))

    print("⏰ Timeout: Model is still not READY")
    return False


# =======================================================
# ✅ Helper: Get metric from run (with fallback to tags)
# =======================================================
def get_metric_from_run(client, model_name, version, run_id,
                        metric_key=None, tag_key="metric_rmse"):
    """Look up a model version's evaluation metric.

    The metric is read first from the source run's logged metrics. If it is
    not there, the model version's tags are checked instead.

    Args:
        client: Registry/tracking client (e.g. ``MlflowClient``).
        model_name: Registered model name.
        version: Model version whose metric is wanted.
        run_id: Source run of the version; may be None/empty, in which case
            the run lookup is skipped.
        metric_key: Run metric name; defaults to the module's ``METRIC_KEY``.
        tag_key: Model-version tag holding the metric as a string.

    Returns:
        The metric as a float, or None if neither source provides it.
    """
    key = METRIC_KEY if metric_key is None else metric_key

    # Method 1: run metrics. Skip the call entirely when there is no run,
    # rather than asking the server for run "None".
    if run_id:
        try:
            metric_value = client.get_run(run_id).data.metrics.get(key)
        except Exception as e:
            print(f"  ⚠ Could not fetch run {run_id}: {e}")
        else:
            if metric_value is not None:
                print(f"  ✓ Metric found in run metrics: {metric_value:.6f}")
                return metric_value
    else:
        print(f"  ⚠ Version {version} has no source run; skipping run metrics")

    # Method 2: model version tag (stored as a string, so parse it).
    try:
        metric_tag = client.get_model_version(model_name, version).tags.get(tag_key)
        if metric_tag:
            metric_value = float(metric_tag)
            print(f"  ✓ Metric found in model tags: {metric_value:.6f}")
            return metric_value
    except Exception as e:
        print(f"  ⚠ Could not fetch metric from tags: {e}")

    print(f"  ⚠ No metric found for version {version}")
    return None


# =======================================================
# ✅ Step 1: Find Latest Model Version
# =======================================================
print(f"\n{'='*70}")
print("📋 STEP 1: Finding Latest Model Version")
print(f"{'='*70}")

model_versions = client.search_model_versions(f"name='{MODEL_NAME}'")

if not model_versions:
    print(f"❌ No versions found for model: {MODEL_NAME}")
    # A string argument makes SystemExit report exit status 1. A bare
    # ``raise SystemExit`` exits with status 0, which a caller checking the
    # exit status would read as success.
    raise SystemExit(f"No versions found for model: {MODEL_NAME}")

# Version numbers come back as strings; compare them numerically so that
# "10" ranks above "9".
latest_version = max(model_versions, key=lambda mv: int(mv.version))
new_version = latest_version.version
new_run_id = latest_version.run_id

print(f"\n✅ Latest Registered Model Version: v{new_version}")
print(f"   Run ID: {new_run_id}")

# Get metric for new version
print(f"\n🔍 Fetching metric for new version v{new_version}...")
new_metric = get_metric_from_run(client, MODEL_NAME, new_version, new_run_id)

if new_metric is None:
    print(f"❌ ERROR: Could not find {METRIC_KEY} for new version v{new_version}")
    print("   This version cannot be evaluated. Exiting.")
    raise SystemExit(f"Missing {METRIC_KEY} for model version v{new_version}")

print(f"✅ New Model {METRIC_KEY}: {new_metric:.6f}")


# =======================================================
# ✅ Step 2: Find existing Staging model (alias)
# =======================================================
print(f"\n{'='*70}")
print("📋 STEP 2: Checking Current STAGING Model")
print(f"{'='*70}")

# Defined up front so later steps can reference them even when no staging
# version exists.
old_version = None
old_run_id = None
old_metric = None

# Guard only the alias lookup, so that an error elsewhere in this step is not
# mistaken for "no staging model" (which makes Step 3 promote unconditionally).
# NOTE(review): any lookup failure (auth, network) is still treated as a
# missing alias; consider narrowing this to the registry's "not found" error.
try:
    staging_version = client.get_model_version_by_alias(MODEL_NAME, STAGING_ALIAS)
except Exception as e:
    staging_version = None
    print(f"\nℹ️ No current STAGING model found: {e}")
    print("   Will promote latest model to staging.")
else:
    old_version = staging_version.version
    old_run_id = staging_version.run_id

    print(f"\n📌 Current STAGING Version: v{old_version}")
    print(f"   Run ID: {old_run_id}")

    # Get metric for old version
    print(f"\n🔍 Fetching metric for staging version v{old_version}...")
    old_metric = get_metric_from_run(client, MODEL_NAME, old_version, old_run_id)

    if old_metric is not None:
        print(f"📌 Current STAGING {METRIC_KEY}: {old_metric:.6f}")
    else:
        print(f"⚠️ WARNING: Could not find {METRIC_KEY} for staging v{old_version}")
        print("   Will promote new model by default.")


# =======================================================
# ✅ Step 3: Compare metrics (lower RMSE = better)
# =======================================================
print(f"\n{'='*70}")
print("📋 STEP 3: Metric Comparison")
print(f"{'='*70}")

promote = False
promotion_reason = ""

if staging_version is None:
    promote = True
    promotion_reason = "No existing staging model"
    print("\n🟢 DECISION: PROMOTE")
    print(f"   Reason: {promotion_reason}")

elif str(old_version) == str(new_version):
    # The alias already points at the latest version: comparing it with
    # itself is meaningless, and re-pointing the alias would be a no-op.
    print("\n🟡 DECISION: NO PROMOTION")
    print(f"   Reason: Latest version v{new_version} is already STAGING")

elif old_metric is None:
    promote = True
    promotion_reason = "Staging model has no metric (old version)"
    print("\n🟢 DECISION: PROMOTE")
    print(f"   Reason: {promotion_reason}")

else:
    # Both metrics exist - compare them (positive improvement = new is better)
    improvement = old_metric - new_metric
    print("\n📊 Metric Comparison:")
    print(f"   New Model (v{new_version}):     {METRIC_KEY} = {new_metric:.6f}")
    print(f"   Staging Model (v{old_version}): {METRIC_KEY} = {old_metric:.6f}")
    print(f"   Improvement:                     {improvement:.6f}")

    if new_metric < old_metric - TOL:
        promote = True
        if old_metric:
            improvement_pct = improvement / abs(old_metric) * 100
            promotion_reason = f"New model is better (improvement: {improvement_pct:.2f}%)"
        else:
            # A zero baseline has no meaningful relative improvement.
            promotion_reason = f"New model is better (improvement: {improvement:.6f})"
        print("\n🟢 DECISION: PROMOTE")
        print(f"   Reason: {promotion_reason}")

    elif abs(improvement) <= TOL:
        print("\n🟡 DECISION: NO PROMOTION")
        print("   Reason: New model performance is same as staging (within tolerance)")
        print(f"   Keeping existing staging version v{old_version}")

    else:
        print("\n⛔ DECISION: NO PROMOTION")
        print("   Reason: New model is WORSE than staging")
        print(f"   Degradation: {-improvement:.6f}")
        print(f"   Keeping existing staging version v{old_version}")


# =======================================================
# ✅ Step 4: Promote using alias = "Staging"
# =======================================================
if promote:
    print(f"\n{'='*70}")
    print("📋 STEP 4: Promoting Model to STAGING")
    print(f"{'='*70}")

    print(f"\n⏳ Waiting for model v{new_version} to become READY...")
    if not wait_until_ready(client, MODEL_NAME, new_version):
        print("\n❌ Promotion failed: Model did not become READY.")
        # A string argument gives exit status 1; a bare SystemExit exits with
        # status 0, so a failed promotion would look successful.
        raise SystemExit(f"Promotion failed: model v{new_version} is not READY")

    # Only the registry call is guarded, so that a failure here is the only
    # thing reported as "Failed to set alias". The bare ``raise`` already
    # carries the full traceback, so it is not printed separately.
    try:
        client.set_registered_model_alias(
            name=MODEL_NAME,
            alias=STAGING_ALIAS,
            version=new_version,
        )
    except Exception as e:
        print(f"\n❌ Failed to set alias: {e}")
        raise

    print(f"\n{'='*70}")
    print("✅✅ PROMOTION SUCCESSFUL ✅✅")
    print(f"{'='*70}")
    print(f"   Model Name: {MODEL_NAME}")
    print(f"   New STAGING Version: v{new_version}")
    print(f"   {METRIC_KEY}: {new_metric:.6f}")
    print(f"   Reason: {promotion_reason}")
    print(f"{'='*70}\n")

else:
    print(f"\n{'='*70}")
    print("✅ STAGING UNCHANGED")
    print(f"{'='*70}")
    if staging_version is not None:
        print(f"   Current STAGING Version: v{old_version}")
        # ``is not None`` so that a legitimate 0.0 metric is still shown.
        if old_metric is not None:
            print(f"   {METRIC_KEY}: {old_metric:.6f}")
    print(f"{'='*70}\n")

print("✨ UAT Staging Promotion Process Completed!")