In [None]:
# Databricks notebook source
# =============================================================================
# 🚀 AUTOMATED PRODUCTION SERVING ENDPOINT DEPLOYMENT
# =============================================================================

from databricks.sdk import WorkspaceClient
from databricks.sdk.service.serving import (
    EndpointCoreConfigInput,
    ServedEntityInput,
    TrafficConfig,
    Route
)
import mlflow
from mlflow.tracking import MlflowClient
import time
import sys
import os

print("=" * 80, "🚀 AUTOMATED PRODUCTION SERVING ENDPOINT DEPLOYMENT", "=" * 80, sep="\n")

# =============================================================================
# CONFIGURATION
# =============================================================================
# Unity Catalog location of the registered models: <catalog>.<schema>.<model>
UC_CATALOG_NAME = "workspace"
UC_SCHEMA_NAME = "ml"
# Registry alias identifying the model version to serve
PRODUCTION_ALIAS = "production"

# Endpoint configuration
WORKLOAD_SIZE = "Small"  # Options: Small, Medium, Large
SCALE_TO_ZERO_ENABLED = True  # Cost optimization

# =============================================================================
# INITIALIZE CLIENTS
# =============================================================================
try:
    w = WorkspaceClient()
    print("✓ Databricks Workspace Client initialized")

    # The registry is switched to Unity Catalog only inside a Databricks runtime;
    # elsewhere the registry URI is left to MLflow's own resolution.
    # NOTE(review): the UC model names built later only resolve against a UC
    # registry — confirm how non-Databricks runs are expected to configure it.
    if os.environ.get("DATABRICKS_RUNTIME_VERSION") is not None:
        mlflow.set_registry_uri("databricks-uc")

    mlflow_client = MlflowClient()
    print("✓ MLflow Client initialized\n")
except Exception as init_error:
    print(f"❌ Error initializing clients: {init_error}")
    print("   Ensure DATABRICKS_HOST and DATABRICKS_TOKEN are set")
    sys.exit(1)

# =============================================================================
# 1️⃣ AUTO-DETECT LATEST EXPERIMENT AND MODEL
# =============================================================================
def get_latest_model_info(mlflow_client):
    """Auto-detect latest experiment and infer model name"""
    print("🔍 Auto-detecting latest model...")
    
    experiments = mlflow_client.search_experiments(view_type=mlflow.entities.ViewType.ACTIVE_ONLY)
    latest_exp = max(experiments, key=lambda exp: exp.last_update_time)
    
    # Infer model type from experiment name
    exp_lower = latest_exp.name.lower()
    if "xgboost" in exp_lower:
        model_type = "xgboost"
    elif "rf" in exp_lower or "randomforest" in exp_lower:
        model_type = "rf"
    elif "linear" in exp_lower:
        model_type = "linear"
    else:
        model_type = "generic"
    
    model_name = f"{UC_CATALOG_NAME}.{UC_SCHEMA_NAME}.house_price_{model_type}_uc"
    endpoint_name = f"house-price-{model_type}-prod"
    
    print(f"📘 Latest Experiment: {latest_exp.name}")
    print(f"✅ Detected Model Type: {model_type.upper()}")
    print(f"✅ Model Name: {model_name}")
    print(f"✅ Endpoint Name: {endpoint_name}\n")
    
    return model_name, endpoint_name, model_type

# =============================================================================
# 2️⃣ GET PRODUCTION MODEL VERSION
# =============================================================================
def get_production_model_version(mlflow_client, model_name):
    """Find the registered model version that carries ``PRODUCTION_ALIAS``.

    Args:
        mlflow_client: MLflow client used to query the model registry.
        model_name: Fully qualified registered-model name to search.

    Returns:
        ``(version_number, ModelVersion)`` for the aliased version, or
        ``(None, None)`` when the model has no versions, no version has the
        alias, or a registry call raises (the traceback is printed).
    """
    print(f"🔍 Resolving Production model version...")

    try:
        candidates = mlflow_client.search_model_versions(f"name='{model_name}'")
        if not candidates:
            print(f"❌ Error: No versions found for model {model_name}")
            return None, None

        # Aliases are read from a per-version fetch, not from the search results.
        # The generator is lazy, so fetching stops at the first aliased match.
        detailed_versions = (
            mlflow_client.get_model_version(model_name, candidate.version)
            for candidate in candidates
        )
        match = next(
            (mv for mv in detailed_versions if PRODUCTION_ALIAS in mv.aliases),
            None,
        )

        if not match:
            print(f"❌ Error: No model version found with '{PRODUCTION_ALIAS}' alias")
            print(f"   Please run production promotion pipeline first")
            return None, None

        print(f"✅ Found Production model version: {match.version}")
        print(f"   • Run ID: {match.run_id}")
        print(f"   • Status: {match.status}")
        print(f"   • Aliases: {', '.join(match.aliases)}\n")
        return match.version, match

    except Exception as err:
        print(f"❌ Error getting production model version: {err}")
        import traceback
        traceback.print_exc()
        return None, None

# =============================================================================
# 3️⃣ CHECK IF ENDPOINT EXISTS
# =============================================================================
def check_endpoint_exists(w, endpoint_name):
    """Check whether a serving endpoint with the given name already exists.

    Args:
        w: Workspace client exposing ``serving_endpoints.list()``.
        endpoint_name: Exact endpoint name to look for.

    Returns:
        ``(exists, endpoint)``; ``endpoint`` is the listed endpoint object or
        ``None``. Errors while listing are printed as a warning, not raised;
        the endpoint is then reported as absent unless it was already matched.
    """
    print(f"🔍 Checking if endpoint '{endpoint_name}' exists...")

    existing_endpoint = None
    try:
        existing_endpoint = next(
            (ep for ep in w.serving_endpoints.list() if ep.name == endpoint_name),
            None,
        )
        if existing_endpoint is None:
            print(f"ℹ️  Endpoint does not exist - will create new\n")
        else:
            print(f"✅ Found existing endpoint: {endpoint_name}")
            state = existing_endpoint.state
            # Always print the state line (previously it was skipped, leaving the
            # 'Unknown' fallback unreachable, when the endpoint had no state).
            print(f"   • State: {state.config_update if state else 'Unknown'}\n")
    except Exception as e:
        print(f"⚠️  Error checking endpoints: {e}\n")

    return existing_endpoint is not None, existing_endpoint

# =============================================================================
# 4️⃣ CREATE OR UPDATE ENDPOINT
# =============================================================================
def deploy_endpoint(w, endpoint_name, model_name, production_version_number, endpoint_exists):
    """Create the serving endpoint, or repoint an existing one at a model version.

    Args:
        w: Workspace client used for the serving-endpoint calls.
        endpoint_name: Name of the endpoint to create or update.
        model_name: Registered model to serve.
        production_version_number: Model version to serve.
        endpoint_exists: If truthy, ``update_config`` is called; otherwise ``create``.

    Returns:
        True when the create/update call returns without raising; False if it
        raised (the traceback is printed).
    """
    rule = "=" * 80

    summary = (
        rule,
        "📦 ENDPOINT CONFIGURATION",
        rule,
        f"  • Name: {endpoint_name}",
        f"  • Model: {model_name}",
        f"  • Version: {production_version_number} (Production alias)",
        f"  • Workload Size: {WORKLOAD_SIZE}",
        f"  • Scale to Zero: {SCALE_TO_ZERO_ENABLED}\n",
    )
    for line in summary:
        print(line)

    served_entity = ServedEntityInput(
        entity_name=model_name,
        entity_version=production_version_number,
        workload_size=WORKLOAD_SIZE,
        scale_to_zero_enabled=SCALE_TO_ZERO_ENABLED,
    )

    try:
        print(rule)
        print("🔄 UPDATING EXISTING ENDPOINT" if endpoint_exists else "🆕 CREATING NEW ENDPOINT")
        print(f"{rule}\n")

        if endpoint_exists:
            w.serving_endpoints.update_config(
                name=endpoint_name,
                served_entities=[served_entity],
            )
            print(f"✅ Endpoint update initiated for: {endpoint_name}\n")
        else:
            w.serving_endpoints.create(
                name=endpoint_name,
                config=EndpointCoreConfigInput(
                    name=endpoint_name,
                    served_entities=[served_entity],
                ),
            )
            print(f"✅ Endpoint creation initiated: {endpoint_name}\n")
    except Exception as err:
        print(f"❌ Error during endpoint deployment: {err}")
        import traceback
        traceback.print_exc()
        return False

    return True

# =============================================================================
# 5️⃣ WAIT FOR ENDPOINT TO BE READY
# =============================================================================
def wait_for_endpoint_ready(w, endpoint_name):
    """Wait for endpoint to become ready"""
    print(f"{'='*80}")
    print("⏳ WAITING FOR ENDPOINT TO BE READY")
    print(f"{'='*80}")
    print("(This may take 5-10 minutes for first deployment)\n")
    
    max_wait_time = 1200  # 20 minutes
    check_interval = 15
    elapsed_time = 0
    
    while elapsed_time < max_wait_time:
        try:
            endpoint_status = w.serving_endpoints.get(name=endpoint_name)
            
            state = endpoint_status.state
            config_update = str(state.config_update) if state and state.config_update else None
            ready_state = str(state.ready) if state and state.ready else None
            
            # Check if endpoint is ready
            if config_update and "NOT_UPDATING" in config_update and ready_state and "READY" in ready_state:
                print(f"\n✅ Endpoint is READY!\n")
                return True
            elif config_update and ("UPDATE_FAILED" in config_update or "CREATION_FAILED" in config_update):
                print(f"\n❌ Endpoint deployment FAILED")
                print(f"   • State: {config_update}")
                if ready_state:
                    print(f"   • Ready State: {ready_state}\n")
                return False
            else:
                status_msg = config_update if config_update else "INITIALIZING"
                print(f"  ⏳ Status: {status_msg}, Ready: {ready_state} - elapsed: {elapsed_time}s")
                time.sleep(check_interval)
                elapsed_time += check_interval
                
        except Exception as e:
            print(f"  ⚠️  Error checking status: {e}")
            time.sleep(check_interval)
            elapsed_time += check_interval
    
    print(f"\n⏱️  Timeout: Endpoint not ready after {max_wait_time}s")
    print("   Check Databricks UI for deployment status\n")
    return False

# =============================================================================
# 6️⃣ DISPLAY ENDPOINT DETAILS
# =============================================================================
def display_endpoint_info(w, endpoint_name, model_name, production_version_number):
    """Print a summary of the deployed endpoint as reported by the workspace.

    ``model_name`` and ``production_version_number`` are accepted but unused;
    the served entities shown are read from the fetched endpoint. Any error
    while fetching or printing is reported as a warning and not raised.
    """
    rule = "=" * 80
    try:
        details = w.serving_endpoints.get(name=endpoint_name)

        print(rule)
        print("✅ ENDPOINT DEPLOYMENT SUCCESSFUL")
        print(rule)
        print(f"\n📦 Endpoint Details:")
        print(f"   • Name: {details.name}")

        state = details.state
        if state and state.ready:
            ready_status = str(state.ready)
        else:
            ready_status = 'Unknown'
        print(f"   • State: {ready_status}")

        # Only printed when the endpoint object exposes a `url` attribute.
        if hasattr(details, 'url'):
            print(f"   • URL: {details.url}")

        print(f"\n📊 Served Entities:")
        config = details.config
        entities = config.served_entities if config else None
        for entity in entities or ():
            print(f"   • {entity.entity_name} (Version: {entity.entity_version})")
            print(f"     - Workload: {entity.workload_size}")
            print(f"     - Scale to Zero: {entity.scale_to_zero_enabled}")

        print(f"\n{rule}")
        print("🎯 ENDPOINT READY FOR INFERENCE")
        print(f"{rule}\n")

    except Exception as err:
        print(f"⚠️  Could not fetch endpoint details: {err}\n")

# =============================================================================
# 7️⃣ TEST ENDPOINT
# =============================================================================
def test_endpoint(w, endpoint_name):
    """Query the endpoint with one sample house record and print the response.

    Returns:
        True if the query call returns; False if it raises, in which case a
        warning is printed instead of propagating the error.
    """
    rule = "=" * 80
    print(rule)
    print("🧪 TESTING ENDPOINT WITH SAMPLE DATA")
    print(f"{rule}\n")

    try:
        record = {
            "sq_feet": 900,
            "num_bedrooms": 3,
            "num_bathrooms": 2,
            "year_built": 2015,
            "location_score": 7.5,
        }
        sample_data = {"dataframe_records": [record]}

        print(f"📥 Sample Input:")
        field_labels = (
            ("Square Feet", "sq_feet"),
            ("Bedrooms", "num_bedrooms"),
            ("Bathrooms", "num_bathrooms"),
            ("Year Built", "year_built"),
            ("Location Score", "location_score"),
        )
        for label, key in field_labels:
            print(f"   • {label}: {record[key]}")

        response = w.serving_endpoints.query(
            name=endpoint_name,
            dataframe_records=sample_data["dataframe_records"],
        )

        print(f"\n📤 Prediction Response:")
        print(f"   {response}\n")
        print(f"✅ Endpoint test SUCCESSFUL\n")
        return True

    except Exception as err:
        print(f"\n⚠️  Warning: Endpoint test failed: {err}")
        print("   Endpoint is deployed but test query failed")
        print("   This is common in Community Edition - endpoint may still work\n")
        return False

# =============================================================================
# 8️⃣ DISPLAY USAGE INSTRUCTIONS
# =============================================================================
def display_usage_instructions(endpoint_name, model_name, production_version_number):
    """Print usage instructions for the deployed endpoint.

    Covers an SDK query example, the REST invocation URL, where to find the
    endpoint in the UI, and the served model name, version and alias.

    Args:
        endpoint_name: Serving endpoint name substituted into the examples.
        model_name: Registered model name shown under "Model Details".
        production_version_number: Model version shown under "Model Details".
    """
    print(f"{'='*80}")
    print("📖 HOW TO USE THIS ENDPOINT")
    print(f"{'='*80}")
    
    # `{{` / `}}` in this f-string are escaped literal braces in the output.
    print(f"""
1️⃣  Using Python SDK:
   from databricks.sdk import WorkspaceClient
   
   w = WorkspaceClient()
   response = w.serving_endpoints.query(
       name="{endpoint_name}",
       dataframe_records=[{{
           "sq_feet": 1000,
           "num_bedrooms": 3,
           "num_bathrooms": 2,
           "year_built": 2015,
           "location_score": 7.5
       }}]
   )

2️⃣  Using REST API:
   POST https://<databricks-instance>/serving-endpoints/{endpoint_name}/invocations
   Headers: Authorization: Bearer <token>
   Body: {{"dataframe_records": [...]}}

3️⃣  Monitor endpoint in Databricks UI:
   Machine Learning > Serving > {endpoint_name}
   
4️⃣  Model Details:
   • Model: {model_name}
   • Version: {production_version_number}
   • Alias: {PRODUCTION_ALIAS}
""")
    
    print(f"{'='*80}\n")

# =============================================================================
# MAIN EXECUTION
# =============================================================================
if __name__ == "__main__":
    # Run the deployment pipeline end to end. Expected failures stop the run via
    # sys.exit(1); SystemExit is not an Exception subclass, so the handler below
    # only reports genuinely unexpected errors.
    try:
        # Step 1: Auto-detect model
        model_name, endpoint_name, model_type = get_latest_model_info(mlflow_client)

        # Step 2: Get production version
        production_version_number, production_version = get_production_model_version(
            mlflow_client, model_name
        )

        if not production_version_number:
            print(f"❌ Cannot proceed without production model")
            print(f"   Please run production promotion pipeline first\n")
            sys.exit(1)

        # Step 3: Check if endpoint exists
        endpoint_exists, existing_endpoint = check_endpoint_exists(w, endpoint_name)

        # Step 4: Deploy endpoint
        deployment_success = deploy_endpoint(
            w, endpoint_name, model_name, production_version_number, endpoint_exists
        )

        if not deployment_success:
            sys.exit(1)

        # Step 5: Wait for endpoint to be ready
        is_ready = wait_for_endpoint_ready(w, endpoint_name)

        if not is_ready:
            print(f"❌ Endpoint deployment did not complete successfully")
            sys.exit(1)

        # Step 6: Display endpoint info
        display_endpoint_info(w, endpoint_name, model_name, production_version_number)

        # Step 7: Test endpoint (result deliberately not checked: a failed sample
        # query only produces a warning once the endpoint is READY)
        test_endpoint(w, endpoint_name)

        # Step 8: Display usage instructions
        display_usage_instructions(endpoint_name, model_name, production_version_number)

    except Exception as e:
        print(f"\n❌ UNEXPECTED ERROR: {str(e)}")
        import traceback
        traceback.print_exc()
        sys.exit(1)

    # Success exit for a calling Databricks job/notebook. `dbutils` is only
    # defined when running inside a Databricks notebook, so a NameError is the
    # only failure tolerated here. The exit call sits outside every exception
    # handler so that, if it ends the notebook by raising, nothing swallows it
    # (the previous bare `except:` would have).
    try:
        notebook_utils = dbutils  # noqa: F821 - injected by Databricks notebooks
    except NameError:
        notebook_utils = None

    if notebook_utils is not None:
        notebook_utils.notebook.exit("ENDPOINT_READY")