In [None]:
# Databricks notebook source
# =============================================================================
# 🚀 AUTOMATED PRODUCTION PROMOTION - UAT VALIDATION BASED
# =============================================================================
%pip install xgboost
import mlflow
from mlflow.tracking import MlflowClient
import os
import time
import sys
import traceback
from pyspark.sql import SparkSession
from datetime import datetime

# Run banner: rule, title, rule — emitted as three separate lines.
print(
    "=" * 80,
    "🚀 AUTOMATED PRODUCTION PROMOTION - UAT VALIDATION BASED",
    "=" * 80,
    sep="\n",
)

# =============================================================================
# CONFIGURATION
# =============================================================================
# Unity Catalog catalog and schema; registered model names are built as
# "<UC_CATALOG_NAME>.<UC_SCHEMA_NAME>.<model>".
UC_CATALOG_NAME = "workspace"
UC_SCHEMA_NAME = "ml"

# Registered-model aliases: the candidate version is the one carrying
# STAGING_ALIAS; promotion points PRODUCTION_ALIAS at that same version.
STAGING_ALIAS = "staging"
PRODUCTION_ALIAS = "production"

# UAT validation thresholds (same as inference script)
# NOTE(review): neither constant is referenced in the visible part of this
# script — confirm whether they are still needed here.
MAPE_THRESHOLD = 15.0
R2_THRESHOLD = 0.75

# Production quality thresholds (stricter than UAT)
# Promotion requires R² >= MIN_R2_FOR_PROD and MAPE (in percent) <= MAX_MAPE_FOR_PROD.
MIN_R2_FOR_PROD = 0.80
MAX_MAPE_FOR_PROD = 10.0

# =============================================================================
# INITIALIZATION
# =============================================================================
# Obtain (or reuse) the Spark session used for all Delta table reads/writes below.
spark = (
    SparkSession.builder
    .appName("Auto_Production_Promotion")
    .getOrCreate()
)

try:
    # The Unity Catalog registry is selected only when the Databricks runtime
    # marker is present in the environment; otherwise MLflow defaults apply.
    if os.environ.get("DATABRICKS_RUNTIME_VERSION") is not None:
        mlflow.set_registry_uri("databricks-uc")
        print("✓ MLflow configured for Unity Catalog")

    client = MlflowClient()
    print("✓ MLflow client initialized\n")
except Exception as init_error:
    # Nothing else in the script can work without a client, so stop here.
    print(f"❌ Error initializing MLflow: {init_error}")
    sys.exit(1)

# =============================================================================
# 1️⃣ AUTO-DETECT LATEST EXPERIMENT AND MODEL TYPE
# =============================================================================
def get_latest_experiment_and_model(client):
    """Select the most recently updated active experiment and derive the model name.

    The model type is inferred by substring-matching known algorithm keys
    against the lower-cased experiment name; the first key (in declaration
    order) that matches wins. When nothing matches, the type is "generic" and
    the generic registered-model name is used.

    Args:
        client: MLflow client exposing ``search_experiments``.

    Returns:
        Tuple ``(latest_exp, model_name, model_type)`` where ``model_name`` is
        ``"<UC_CATALOG_NAME>.<UC_SCHEMA_NAME>.<model>"``.

    Raises:
        ValueError: If the search returns no active experiments.
    """
    experiments = client.search_experiments(view_type=mlflow.entities.ViewType.ACTIVE_ONLY)
    if not experiments:
        # max() on an empty sequence also raises ValueError, but with a message
        # that says nothing about the actual problem.
        raise ValueError("No active MLflow experiments found; cannot infer which model to promote")

    # An experiment may carry no last_update_time; rank such experiments as
    # oldest instead of failing on a None-vs-int comparison.
    latest_exp = max(experiments, key=lambda exp: exp.last_update_time or 0)

    # Experiment-name keyword -> registered model (unqualified) name.
    model_map = {
        "xgboost": "house_price_xgboost_uc",
        "randomforest": "house_price_rf_uc",
        "lightgbm": "house_price_lightgbm_uc",
        "catboost": "house_price_catboost_uc",
        "gradientboosting": "house_price_gb_uc",
        "linear": "house_price_linear_uc",
        "decisiontree": "house_price_dt_uc",
        "logistic": "house_price_logreg_uc",
        "svm": "house_price_svm_uc",
    }

    # Infer model type from experiment name (first matching keyword wins).
    exp_lower = latest_exp.name.lower()
    matched_key = next((key for key in model_map if key in exp_lower), None)
    model_type = matched_key if matched_key is not None else "generic"
    short_name = model_map.get(matched_key, "house_price_generic_uc")
    model_name = f"{UC_CATALOG_NAME}.{UC_SCHEMA_NAME}.{short_name}"

    print(f"📘 Latest Experiment: {latest_exp.name}")
    print(f"✅ Detected Model Type: {model_type.upper()}")
    print(f"✅ Model Name: {model_name}\n")

    return latest_exp, model_name, model_type

# =============================================================================
# 2️⃣ CHECK UAT VALIDATION RESULTS FROM DELTA TABLE
# =============================================================================
def check_uat_validation_results(model_name, model_type):
    """Read the model's UAT inference table and report its newest row.

    The table is ``workspace.default.uat_inference_<model>``, where ``<model>``
    is the last dot-separated component of ``model_name``. Rows are ordered by
    the ``timestamp`` column, newest first.

    Args:
        model_name: Fully qualified registered model name.
        model_type: Accepted for interface compatibility; not used here.

    Returns:
        ``(latest_result, uat_df)`` — the newest row and the whole table as a
        pandas DataFrame — or ``(None, None)`` when the table is empty or any
        error occurs while reading or printing it.
    """
    table_suffix = model_name.rsplit('.', 1)[-1]
    uat_table = "workspace.default.uat_inference_" + table_suffix

    print(f"🔍 Checking UAT validation results from: {uat_table}")

    try:
        uat_df = spark.table(uat_table).toPandas()

        if not uat_df.empty:
            newest_first = uat_df.sort_values('timestamp', ascending=False)
            latest_result = newest_first.iloc[0]

            print("\n📊 Latest UAT Validation Result:")
            print(f"   • Timestamp: {latest_result['timestamp']}")
            print(f"   • Model Version: {latest_result['model_version']}")
            print(f"   • MAE: {latest_result['mae']:.3f}")
            print(f"   • RMSE: {latest_result['rmse']:.3f}")
            print(f"   • R²: {latest_result['r2']:.3f}")
            print(f"   • MAPE: {latest_result['mape']:.2f}%")
            print(f"   • UAT Status: {latest_result['uat_status']}")

            return latest_result, uat_df

        print(f"❌ No UAT results found in {uat_table}")
    except Exception as e:
        # Best effort: a missing table or column is reported, not raised.
        print(f"❌ Error reading UAT results: {e}")

    return None, None

# =============================================================================
# 3️⃣ GET STAGING MODEL VERSION
# =============================================================================
def get_staging_model_version(client, model_name):
    """Return the model version that currently carries the staging alias.

    The alias is resolved with a single ``get_model_version_by_alias`` call
    rather than fetching every registered version one by one and inspecting
    its aliases; a registered-model alias resolves to one version, so the
    result is the same with one registry round trip.

    Args:
        client: MLflow client for the model registry.
        model_name: Fully qualified registered model name.

    Returns:
        The model version entity for ``STAGING_ALIAS``, or ``None`` when the
        model or alias does not exist or the registry call fails (the error is
        printed, not raised).
    """
    try:
        # Raises when the model or the alias is missing; reported below.
        staging_version = client.get_model_version_by_alias(model_name, STAGING_ALIAS)

        print(f"\n✅ Found Staging Model:")
        print(f"   • Version: {staging_version.version}")
        print(f"   • Status: {staging_version.status}")
        print(f"   • Run ID: {staging_version.run_id}")
        print(f"   • Aliases: {', '.join(staging_version.aliases)}")

        return staging_version

    except Exception as e:
        print(f"❌ Error fetching staging model (alias '{STAGING_ALIAS}'): {e}")
        return None

# =============================================================================
# 4️⃣ VALIDATE PRODUCTION READINESS
# =============================================================================
def validate_production_readiness(uat_result, staging_version, client):
    """Run the production gate checks and print one line per check.

    Checks 1-4 are blocking: UAT status is 'PASSED', R² meets
    ``MIN_R2_FOR_PROD``, MAPE is within ``MAX_MAPE_FOR_PROD`` and the staging
    version status is "READY". Check 5 compares the staging version with the
    version recorded in the UAT result and only warns on mismatch.

    Args:
        uat_result: Mapping-like UAT row with ``uat_status``, ``r2``, ``mape``
            and ``model_version`` entries.
        staging_version: Model version entity exposing ``status`` and ``version``.
        client: Accepted for interface compatibility; not used here.

    Returns:
        ``True`` when every blocking check passes, otherwise ``False``.
    """
    print(f"\n{'='*80}")
    print("🔍 PRODUCTION READINESS VALIDATION")
    print(f"{'='*80}")

    failed_checks = []

    # CHECK 1 — UAT must have passed.
    uat_passed = uat_result['uat_status'] == 'PASSED'
    if uat_passed:
        print(f"✓ CHECK 1: UAT Validation Status = PASSED")
    else:
        print(f"✗ CHECK 1: UAT Validation Status = {uat_result['uat_status']} (Expected: PASSED)")
        failed_checks.append(1)

    # CHECK 2 — R² floor. The comparison is kept in ">=" form so a NaN metric fails.
    r2_score = uat_result['r2']
    r2_ok = r2_score >= MIN_R2_FOR_PROD
    if r2_ok:
        print(f"✓ CHECK 2: R² Score {r2_score:.4f} >= {MIN_R2_FOR_PROD} (Production threshold)")
    else:
        print(f"✗ CHECK 2: R² Score {r2_score:.4f} < {MIN_R2_FOR_PROD}")
        failed_checks.append(2)

    # CHECK 3 — MAPE ceiling. Kept in "<=" form for the same NaN reason.
    mape = uat_result['mape']
    mape_ok = mape <= MAX_MAPE_FOR_PROD
    if mape_ok:
        print(f"✓ CHECK 3: MAPE {mape:.2f}% <= {MAX_MAPE_FOR_PROD}% (Production threshold)")
    else:
        print(f"✗ CHECK 3: MAPE {mape:.2f}% > {MAX_MAPE_FOR_PROD}%")
        failed_checks.append(3)

    # CHECK 4 — the registry must report the version as READY.
    version_ready = staging_version.status == "READY"
    if version_ready:
        print(f"✓ CHECK 4: Model Status = READY")
    else:
        print(f"✗ CHECK 4: Model Status = {staging_version.status}")
        failed_checks.append(4)

    # CHECK 5 — advisory only: a mismatch is reported but does not block.
    # NOTE(review): on mismatch the metrics above were measured on a different
    # version than the one that will be promoted — confirm this is intended.
    versions_match = int(staging_version.version) == int(uat_result['model_version'])
    if versions_match:
        print(f"✓ CHECK 5: Model Version Match (v{staging_version.version})")
    else:
        print(f"⚠ CHECK 5: Model Version Mismatch (Staging: v{staging_version.version}, UAT: v{uat_result['model_version']})")
        print(f"   Will promote staging version v{staging_version.version}")

    print(f"{'='*80}\n")
    return not failed_checks

# =============================================================================
# 5️⃣ PROMOTE MODEL TO PRODUCTION
# =============================================================================
def promote_to_production(client, model_name, staging_version, uat_result, model_type):
    """Point the production alias at the staging version and verify it.

    Steps: look up the version currently holding ``PRODUCTION_ALIAS`` (for
    reporting only), set the alias on ``staging_version``, wait briefly,
    re-read the version to confirm the alias is present, then write an audit
    row via ``log_promotion_to_delta``.

    Args:
        client: MLflow client for the model registry.
        model_name: Fully qualified registered model name.
        staging_version: Model version entity to promote.
        uat_result: UAT row whose metrics are recorded in the audit log.
        model_type: Model type key used to pick the audit table.

    Returns:
        ``True`` when the alias is verified on the promoted version, ``False``
        when verification fails or the registry update raises. A failure of the
        audit log alone does not change the result.
    """
    from mlflow.exceptions import MlflowException

    print(f"{'='*80}")
    print("🚀 PROMOTING MODEL TO PRODUCTION")
    print(f"{'='*80}\n")

    try:
        # Resolve the current production version with one alias lookup instead
        # of fetching every version. The value is informational, so a failed
        # lookup (typically: alias not set yet) is treated as "no previous".
        try:
            old_prod_version = client.get_model_version_by_alias(model_name, PRODUCTION_ALIAS)
        except MlflowException:
            old_prod_version = None

        if old_prod_version is not None:
            print(f"ℹ️ Previous production version: v{old_prod_version.version}")

        client.set_registered_model_alias(model_name, PRODUCTION_ALIAS, staging_version.version)
        print("✅ Production alias set successfully\n")

        # Short pause before reading the version back to verify the alias.
        time.sleep(2)
        updated_version = client.get_model_version(model_name, staging_version.version)

        if PRODUCTION_ALIAS not in updated_version.aliases:
            print("⚠️ Alias may not have been set properly.")
            return False

        print(f"✅ PROMOTION SUCCESSFUL for {model_name}")

        # The alias is already verified at this point; an audit-log problem
        # must not be reported as a failed promotion.
        try:
            log_promotion_to_delta(model_name, staging_version.version,
                                   old_prod_version.version if old_prod_version is not None else None,
                                   uat_result, model_type)
        except Exception as log_error:
            print(f"⚠️ Promotion succeeded but audit logging failed: {log_error}\n")

        return True
    except Exception as e:
        print(f"❌ Error during promotion: {e}")
        traceback.print_exc()
        return False

# =============================================================================
# 6️⃣ LOG PROMOTION TO DELTA TABLE (UPDATED)
# =============================================================================
def log_promotion_to_delta(model_name, new_version, old_version, uat_result, model_type):
    """Append one promotion record to the model-type-specific Delta table.

    The table is ``workspace.default.model_promotion_<model_type>``. The row is
    written with an explicit schema so that a missing previous version (first
    promotion, ``old_version`` is ``None``) is stored as a NULL integer instead
    of producing a column whose type Spark cannot infer.

    Args:
        model_name: Fully qualified registered model name.
        new_version: Version that was promoted (int or numeric string).
        old_version: Version that previously held production, or ``None``.
        uat_result: Mapping-like UAT row with ``r2``, ``rmse``, ``mae``, ``mape``.
        model_type: Model type key appended to the table name.

    Returns:
        None. Logging is best effort: any failure is printed, not raised.
    """
    from pyspark.sql.types import (
        DoubleType,
        LongType,
        StringType,
        StructField,
        StructType,
        TimestampType,
    )

    # Table name now depends on model type (e.g., model_promotion_rf)
    promotion_table = f"workspace.default.model_promotion_{model_type}"

    promotion_schema = StructType([
        StructField("timestamp", TimestampType(), True),
        StructField("model_name", StringType(), True),
        StructField("promoted_version", LongType(), True),
        StructField("previous_version", LongType(), True),
        StructField("r2_score", DoubleType(), True),
        StructField("rmse", DoubleType(), True),
        StructField("mae", DoubleType(), True),
        StructField("mape", DoubleType(), True),
        StructField("promotion_status", StringType(), True),
    ])

    try:
        # Built inside the try so a bad or missing metric is reported as a
        # logging problem rather than escaping to the caller.
        promotion_row = (
            datetime.now(),
            model_name,
            int(new_version),
            int(old_version) if old_version is not None else None,
            float(uat_result['r2']),
            float(uat_result['rmse']),
            float(uat_result['mae']),
            float(uat_result['mape']),
            "SUCCESS",
        )

        spark_df = spark.createDataFrame([promotion_row], schema=promotion_schema)
        spark_df.write \
            .format("delta") \
            .mode("append") \
            .option("mergeSchema", "true") \
            .saveAsTable(promotion_table)
        print(f"✅ Promotion logged to Delta table: {promotion_table}\n")
    except Exception as e:
        print(f"⚠️ Could not log promotion to Delta: {e}\n")

# =============================================================================
# MAIN EXECUTION
# =============================================================================
if __name__ == "__main__":
    # Pipeline: detect model -> read UAT result -> find staging version ->
    # gate checks -> promote. Every failure path exits with status 1 so the
    # calling job can tell that no promotion happened.
    try:
        latest_exp, model_name, model_type = get_latest_experiment_and_model(client)

        uat_result, _ = check_uat_validation_results(model_name, model_type)
        if uat_result is None:
            print("❌ Missing UAT validation results.")
            sys.exit(1)

        staging_version = get_staging_model_version(client, model_name)
        if staging_version is None:
            print("❌ Missing staging model version.")
            sys.exit(1)

        if not validate_production_readiness(uat_result, staging_version, client):
            print("❌ PROMOTION DENIED – Model did not meet criteria.")
            sys.exit(1)

        success = promote_to_production(client, model_name, staging_version, uat_result, model_type)
        if not success:
            # Without this branch a failed promotion ended with exit status 0.
            print("❌ Production promotion did not complete.")
            sys.exit(1)

        print("🎯 Production Promotion Completed Successfully.")

    except Exception as e:
        # sys.exit raises SystemExit, which is not an Exception subclass, so
        # the deliberate exits above pass through this handler untouched.
        print(f"\n❌ UNEXPECTED ERROR: {e}")
        traceback.print_exc()
        sys.exit(1)
