In [None]:
from databricks.sdk import WorkspaceClient
from databricks.sdk.service.serving import (
    EndpointCoreConfigInput,
    ServedEntityInput,
    TrafficConfig,
    Route
)
import mlflow
from mlflow.tracking import MlflowClient
import time
import sys
import os

# Announce the start of the deployment run with a 70-character rule.
_BANNER_RULE = "=" * 70
print(_BANNER_RULE)
print("PRODUCTION SERVING ENDPOINT DEPLOYMENT")
print(_BANNER_RULE)

# =============================================================================
# CONFIGURATION
# =============================================================================
# Registered model, addressed as <catalog>.<schema>.<model>.
UC_CATALOG_NAME = "workspace"
UC_SCHEMA_NAME = "ml"
MODEL_NAME = ".".join((UC_CATALOG_NAME, UC_SCHEMA_NAME, "house_price_model_uc"))
# Registry alias that identifies the model version to deploy.
PRODUCTION_ALIAS = "production"

# Name of the serving endpoint that is created or updated below.
ENDPOINT_NAME = "house-price-prediction-prod"

# Compute settings passed to the served entity (kept minimal for
# Community Edition).
WORKLOAD_SIZE = "Small"  # Options: Small, Medium, Large
SCALE_TO_ZERO_ENABLED = True  # Cost optimization

# =============================================================================
# INITIALIZE CLIENTS
# =============================================================================
try:
    # The workspace client is constructed with no explicit arguments; the
    # failure hint below suggests it relies on DATABRICKS_HOST/TOKEN.
    w = WorkspaceClient()
    print("Databricks Workspace Client initialized")

    # Point the MLflow registry at Unity Catalog only when the Databricks
    # runtime environment variable is present.
    if os.environ.get("DATABRICKS_RUNTIME_VERSION") is not None:
        mlflow.set_registry_uri("databricks-uc")

    # This client is used afterwards to resolve the production model version.
    mlflow_client = MlflowClient()
    print("MLflow Client initialized")

except Exception as init_error:
    # Without both clients nothing else in the script can run, so abort.
    print(f"Error initializing clients: {init_error}")
    print("Ensure DATABRICKS_HOST and DATABRICKS_TOKEN are set")
    sys.exit(1)

# =============================================================================
# GET PRODUCTION MODEL VERSION NUMBER
# =============================================================================
print("\nResolving Production model version...")

try:
    # Ask the registry for the aliased version directly. This replaces the
    # previous approach of listing every version and fetching each one
    # individually just to inspect its aliases (one extra call per version).
    full_version = mlflow_client.get_model_version_by_alias(
        MODEL_NAME, PRODUCTION_ALIAS
    )
except Exception as e:
    # Covers both "model/alias does not exist" and connectivity/auth errors;
    # the exception text printed below tells them apart.
    print(f"Error getting production model version: {e}")
    print(
        f"Error: No model version could be resolved for '{PRODUCTION_ALIAS}' "
        f"alias on model {MODEL_NAME}"
    )
    print("Please run production promotion pipeline first")
    import traceback
    traceback.print_exc()
    sys.exit(1)

# Version identifier used by the serving configuration further down.
production_version_number = full_version.version

if not production_version_number:
    # Defensive: a resolved alias should always carry a version.
    print(f"Error: No model version found with '{PRODUCTION_ALIAS}' alias")
    print("Please run production promotion pipeline first")
    sys.exit(1)

print(f"Found Production model version: {production_version_number}")
print(f"  Run ID: {full_version.run_id}")
print(f"  Status: {full_version.status}")

# =============================================================================
# CHECK IF ENDPOINT ALREADY EXISTS
# =============================================================================
print(f"\nChecking if endpoint '{ENDPOINT_NAME}' exists...")

# Defaults used when no endpoint matches or when the listing call fails.
existing_endpoint = None
endpoint_exists = False

try:
    # Take the first listed endpoint whose name matches, or None.
    existing_endpoint = next(
        (
            candidate
            for candidate in w.serving_endpoints.list()
            if candidate.name == ENDPOINT_NAME
        ),
        None,
    )
    endpoint_exists = existing_endpoint is not None

    if endpoint_exists:
        print(f"Found existing endpoint: {ENDPOINT_NAME}")
        print(f"  State: {existing_endpoint.state.config_update if existing_endpoint.state else 'Unknown'}")
    else:
        print("Endpoint does not exist - will create new")
except Exception as lookup_error:
    # A failed lookup is only reported; the script continues as though the
    # endpoint does not exist.
    print(f"Error checking endpoints: {lookup_error}")

# =============================================================================
# PREPARE ENDPOINT CONFIGURATION
# =============================================================================
print("\nPreparing endpoint configuration...")

# Explicit name for the served entity. Naming it here (rather than relying on
# a service-side default) guarantees that the traffic route below refers to
# the entity that is actually deployed. Dots in the Unity Catalog model name
# are replaced with underscores to form the name.
served_entity_name = f"{MODEL_NAME.replace('.', '_')}-{production_version_number}"

# Model serving configuration.
# The concrete version number is used instead of the alias.
served_entity = ServedEntityInput(
    name=served_entity_name,
    entity_name=MODEL_NAME,
    entity_version=production_version_number,  # Use actual version number
    workload_size=WORKLOAD_SIZE,
    scale_to_zero_enabled=SCALE_TO_ZERO_ENABLED,
)

# Traffic configuration (100% to production model). The route targets the
# served entity by the explicit name assigned above.
traffic_config = TrafficConfig(
    routes=[
        Route(
            served_model_name=served_entity_name,
            traffic_percentage=100,
        )
    ]
)

# Core endpoint configuration. The traffic configuration is now attached;
# previously it was constructed but never passed to anything.
endpoint_config = EndpointCoreConfigInput(
    name=ENDPOINT_NAME,
    served_entities=[served_entity],
    traffic_config=traffic_config,
)

print("\nEndpoint Configuration:")
print(f"  Name: {ENDPOINT_NAME}")
print(f"  Model: {MODEL_NAME}")
print(f"  Version: {production_version_number} (Production alias)")
print(f"  Workload Size: {WORKLOAD_SIZE}")
print(f"  Scale to Zero: {SCALE_TO_ZERO_ENABLED}")

# =============================================================================
# CREATE OR UPDATE ENDPOINT
# =============================================================================
print(f"\n{'=' * 70}")


def _state_text(member):
    """Return the bare value of an endpoint-state enum, or None.

    Yields e.g. ``"NOT_READY"`` rather than ``"EndpointStateReady.NOT_READY"``
    so that callers can compare states for equality. Substring checks on the
    qualified form are unsafe: ``"READY"`` is contained in ``"NOT_READY"``.
    """
    if member is None:
        return None
    return str(getattr(member, "value", member)).rsplit(".", 1)[-1]


try:
    if endpoint_exists:
        print("UPDATING EXISTING ENDPOINT")
        print(f"{'=' * 70}")

        # Update endpoint configuration
        w.serving_endpoints.update_config(
            name=ENDPOINT_NAME,
            served_entities=[served_entity]
        )

        print(f"Endpoint update initiated for: {ENDPOINT_NAME}")

    else:
        print("CREATING NEW ENDPOINT")
        print(f"{'=' * 70}")

        # Create new endpoint
        endpoint = w.serving_endpoints.create(
            name=ENDPOINT_NAME,
            config=endpoint_config
        )

        print(f"Endpoint creation initiated: {ENDPOINT_NAME}")

    # ==========================================================================
    # WAIT FOR ENDPOINT TO BE READY
    # ==========================================================================
    print("\nWaiting for endpoint to be ready...")
    print("(This may take 5-10 minutes for first deployment)")

    max_wait_time = 600  # 10 minutes
    check_interval = 15
    # Measure real elapsed time so slow status calls count toward the timeout.
    wait_started = time.monotonic()
    elapsed_time = 0
    endpoint_ready = False

    # Config-update states after which waiting any longer is pointless.
    terminal_failure_states = ("UPDATE_FAILED", "UPDATE_CANCELED", "CREATION_FAILED")

    while elapsed_time < max_wait_time:
        try:
            endpoint_status = w.serving_endpoints.get(name=ENDPOINT_NAME)

            state = endpoint_status.state
            config_update = _state_text(state.config_update) if state else None
            ready_state = _state_text(state.ready) if state else None

            # Exact comparison: the previous substring test treated
            # NOT_READY as READY because "READY" is a substring of it.
            if config_update == "NOT_UPDATING" and ready_state == "READY":
                print("\nEndpoint is READY!")
                endpoint_ready = True
                break

            if config_update in terminal_failure_states:
                print("\nEndpoint deployment FAILED")
                print(f"State: {config_update}")
                if ready_state:
                    print(f"Ready State: {ready_state}")
                # SystemExit is not an Exception subclass, so it is not
                # swallowed by the handlers below.
                sys.exit(1)

            status_msg = config_update if config_update else "INITIALIZING"
            print(f"  Status: {status_msg}, Ready: {ready_state} - elapsed: {elapsed_time}s")

        except Exception as e:
            # Transient status-check errors are reported and retried.
            print(f"  Error checking status: {e}")

        time.sleep(check_interval)
        elapsed_time = int(time.monotonic() - wait_started)

    if not endpoint_ready:
        print(f"\nTimeout: Endpoint not ready after {max_wait_time}s")
        print("Check Databricks UI for deployment status")
        sys.exit(1)

    # ==========================================================================
    # GET ENDPOINT DETAILS
    # ==========================================================================
    endpoint_info = w.serving_endpoints.get(name=ENDPOINT_NAME)

    print(f"\n{'=' * 70}")
    print("ENDPOINT DEPLOYMENT SUCCESSFUL")
    print(f"{'=' * 70}")
    print("\nEndpoint Details:")
    print(f"  Name: {endpoint_info.name}")

    # Convert enum to its bare value for display
    ready_status = (
        _state_text(endpoint_info.state.ready)
        if endpoint_info.state and endpoint_info.state.ready
        else 'Unknown'
    )
    print(f"  State: {ready_status}")

    # Print the endpoint URL only if the returned object exposes one
    if hasattr(endpoint_info, 'url'):
        print(f"  URL: {endpoint_info.url}")

    print("\n  Served Entities:")
    if endpoint_info.config and endpoint_info.config.served_entities:
        for entity in endpoint_info.config.served_entities:
            print(f"    - {entity.entity_name} (Version: {entity.entity_version})")
            print(f"      Workload: {entity.workload_size}")
            print(f"      Scale to Zero: {entity.scale_to_zero_enabled}")

    print(f"\n{'=' * 70}")
    print("ENDPOINT READY FOR INFERENCE")
    print(f"{'=' * 70}")

except Exception as e:
    print(f"\nError during endpoint deployment: {e}")
    import traceback
    traceback.print_exc()
    sys.exit(1)

# =============================================================================
# TEST ENDPOINT WITH SAMPLE DATA
# =============================================================================
print("\nTesting endpoint with sample data...")

try:
    # A single example house in the "dataframe_records" layout that is
    # passed to the query call below.
    sample_data = {
        "dataframe_records": [
            {
                "sq_feet": 900,
                "num_bedrooms": 3,
                "num_bathrooms": 2,
                "year_built": 2015,
                "location_score": 7.5
            }
        ]
    }
    sample_record = sample_data["dataframe_records"][0]

    # Echo the input, one labelled feature per line.
    print("\nSample Input:")
    for feature_label, feature_key in (
        ("Square Feet", "sq_feet"),
        ("Bedrooms", "num_bedrooms"),
        ("Bathrooms", "num_bathrooms"),
        ("Year Built", "year_built"),
        ("Location Score", "location_score"),
    ):
        print(f"  {feature_label}: {sample_record[feature_key]}")

    # Send the record to the deployed endpoint.
    response = w.serving_endpoints.query(
        name=ENDPOINT_NAME,
        dataframe_records=sample_data["dataframe_records"],
    )

    print("\nPrediction Response:")
    print(f"  {response}")

    print("\nEndpoint test SUCCESSFUL")

except Exception as query_error:
    # A failed smoke test is reported as a warning only; the script goes on.
    print(f"\nWarning: Endpoint test failed: {query_error}")
    print("Endpoint is deployed but test query failed")
    print("This is common in Community Edition - endpoint may still work")

# =============================================================================
# USAGE INSTRUCTIONS
# =============================================================================
# Closing help text: how to call the endpoint that was just deployed.
_usage_rule = "=" * 70

print("\n" + _usage_rule)
print("HOW TO USE THIS ENDPOINT")
print(_usage_rule)

# Doubled braces in the template are literal braces in the printed text.
_usage_text = f"""
1. Using Python SDK:
   from databricks.sdk import WorkspaceClient
   
   w = WorkspaceClient()
   response = w.serving_endpoints.query(
       name="{ENDPOINT_NAME}",
       dataframe_records=[{{
           "sq_feet": 2000,
           "num_bedrooms": 3,
           "num_bathrooms": 2,
           "year_built": 2015,
           "location_score": 7.5
       }}]
   )

2. Using REST API:
   POST https://<databricks-instance>/serving-endpoints/{ENDPOINT_NAME}/invocations
   Headers: Authorization: Bearer <token>
   Body: {{"dataframe_records": [...]}}

3. Monitor endpoint in Databricks UI:
   Machine Learning > Serving > {ENDPOINT_NAME}
   
4. Model Details:
   Model: {MODEL_NAME}
   Version: {production_version_number}
   Alias: {PRODUCTION_ALIAS}
"""
print(_usage_text)

print(_usage_rule)

# Success exit
# Report success to a calling notebook/job when running inside Databricks.
try:
    dbutils.notebook.exit("ENDPOINT_READY")
except (NameError, AttributeError):
    # `dbutils` (or its `notebook` helper) is unavailable, i.e. the script is
    # not running as a Databricks notebook; there is nothing to signal.
    # Any other exception is deliberately NOT swallowed, so whatever
    # `dbutils.notebook.exit` itself raises can propagate.
    pass