##################################################################################
# Model Deployment Notebook
#
# This notebook deploys a validated Prophet forecasting model to a Model Serving
# endpoint and optionally promotes it to the "champion" alias.
#
# Parameters:
#
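# * env                        - Environment (dev, staging, prod)
# * model_name                 - Three-level Unity Catalog model name (catalog.schema.model)
# * model_version              - Model version to deploy (optional; if neither the "Train" task value nor
#                                this widget provides one, the version behind the "challenger" alias is used)
# * serving_endpoint_name      - Name of the Model Serving endpoint to create or update
# * promote_to_champion        - Whether to promote the deployed version to the "champion" alias (true/false)
# * test_endpoint              - Whether to send a sample query to the endpoint after deployment (true/false)
# * endpoint_config_json       - Optional JSON object overriding the endpoint defaults
#                                (scale_to_zero_enabled, workload_size, workload_type, auto_capture_enabled)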
##################################################################################

# Enable IPython autoreload so edited helper modules are re-imported automatically.
# NOTE(review): the install cell below calls dbutils.library.restartPython(), which
# restarts the Python process and presumably discards this extension — confirm,
# and move these magics after the restart if autoreload is actually needed.
# MAGIC %load_ext autoreload
# MAGIC %autoreload 2


In [None]:
# DBTITLE 1, Install dependencies
# Install the libraries this notebook relies on, then restart Python so later
# cells import the freshly installed versions.
# NOTE(review): versions are unpinned, so reruns may resolve different releases;
# consider pinning for reproducible deployments. `requests` is not imported by
# any cell visible here — confirm it is still needed.
# MAGIC %pip install prophet databricks-sdk mlflow pandas requests
dbutils.library.restartPython()


In [None]:
# DBTITLE 1, Notebook arguments
# Register every notebook/job parameter as a widget, in display order.
# Boolean flags are dropdowns limited to the strings "true"/"false"; all other
# parameters are free-text fields.
_widgets = dbutils.widgets
_bool_options = ["true", "false"]

_widgets.text("env", "dev", "Environment")
_widgets.text("model_name", "johannes_oehler.vectorlab.prophet_forecast", "Model Name")
_widgets.text("model_version", "", "Model Version (leave empty to use 'challenger')")
_widgets.text("serving_endpoint_name", "forecast_joe", "Serving Endpoint Name")
_widgets.dropdown("promote_to_champion", "true", _bool_options, "Promote to Champion")
_widgets.dropdown("test_endpoint", "true", _bool_options, "Test Endpoint After Deployment")
_widgets.text("endpoint_config_json", "", "Endpoint Config JSON (optional)")


In [None]:
# DBTITLE 1, Get parameters
import mlflow
from mlflow.tracking.client import MlflowClient

# Point both the explicit client and the fluent API at the Unity Catalog registry.
client = MlflowClient(registry_uri="databricks-uc")
mlflow.set_registry_uri('databricks-uc')

# Read widget parameters. Boolean flags arrive as the strings "true"/"false".
env = dbutils.widgets.get("env")
model_name = dbutils.widgets.get("model_name").strip()
serving_endpoint_name = dbutils.widgets.get("serving_endpoint_name").strip()
promote_to_champion = dbutils.widgets.get("promote_to_champion").lower() == "true"
test_endpoint = dbutils.widgets.get("test_endpoint").lower() == "true"
endpoint_config_json = dbutils.widgets.get("endpoint_config_json")

# Resolve the model version. Precedence: the "Train" task value, then the
# model_version widget, then the "challenger" alias.
# `default=""` keeps taskValues.get from raising when the Train task did not
# publish a value (e.g. a deploy-only run); `debugValue` covers interactive runs.
train_model_version = dbutils.jobs.taskValues.get(
    "Train", "model_version", default="", debugValue=""
)
# The task value may be published as an int; normalise to a stripped string so
# the emptiness checks and the serving/registry calls below all see a str.
model_version = "" if train_model_version is None else str(train_model_version).strip()
if not model_version:
    model_version = dbutils.widgets.get("model_version").strip()

# If no version was specified anywhere, fall back to the "challenger" alias.
if not model_version:
    print("No model version specified. Using 'challenger' alias...")
    try:
        challenger_model = client.get_model_version_by_alias(model_name, "challenger")
    except Exception as e:
        print(f"Error: Could not find 'challenger' alias for model {model_name}")
        print(f"Details: {e}")
        # Chain the registry error so the root cause stays in the traceback.
        raise ValueError(
            "No model version specified and no 'challenger' alias found. Please specify a model version."
        ) from e
    model_version = str(challenger_model.version)
    print(f"Found challenger model version: {model_version}")

model_uri = f"models:/{model_name}/{model_version}"

print("\n=== Deployment Configuration ===")
print(f"Environment: {env}")
print(f"Model Name: {model_name}")
print(f"Model Version: {model_version}")
print(f"Model URI: {model_uri}")
print(f"Endpoint Name: {serving_endpoint_name}")
print(f"Promote to Champion: {promote_to_champion}")
print(f"Test Endpoint: {test_endpoint}")


In [None]:
# DBTITLE 1, Setup Databricks SDK
import json
import time

from databricks.sdk import WorkspaceClient
from databricks.sdk.service.serving import (
    AutoCaptureConfigInput,
    EndpointCoreConfigInput,
    ServedEntityInput,
)

# One workspace client shared by every later cell. It is constructed without
# explicit host/token arguments, so it relies on the SDK's default authentication.
w = WorkspaceClient()
print("Databricks Workspace Client initialized")


In [None]:
# DBTITLE 1, Check if endpoint exists
def endpoint_exists(endpoint_name):
    """Return True if a serving endpoint named ``endpoint_name`` exists.

    Only a "not found" response from the API counts as "does not exist".
    Any other failure (authentication, permissions, network) is re-raised
    rather than reported as False. A False result would make the deploy step
    try to create an endpoint that may already exist.

    Args:
        endpoint_name: Name of the Model Serving endpoint to look up.

    Returns:
        True if the endpoint can be fetched, False if the API reports it missing.
    """
    # ResourceDoesNotExist (HTTP 404) is a subclass of NotFound in the SDK.
    from databricks.sdk.errors import NotFound

    try:
        w.serving_endpoints.get(endpoint_name)
    except NotFound:
        return False
    return True


def wait_for_endpoint_ready(endpoint_name, timeout_seconds=1800, poll_interval_seconds=30):
    """Poll a serving endpoint until its configuration update settles.

    The endpoint is always checked at least once. Sleeps are capped at the
    remaining time, so the final check happens at the deadline rather than
    overshooting it.

    Args:
        endpoint_name: Name of the Model Serving endpoint to poll.
        timeout_seconds: Maximum total time to wait.
        poll_interval_seconds: Delay between consecutive status checks.

    Returns:
        True once the endpoint reports config_update NOT_UPDATING and ready READY.
        False if the config update reports UPDATE_FAILED or the timeout elapses.
    """
    from databricks.sdk.service.serving import EndpointStateConfigUpdate, EndpointStateReady

    # Monotonic clock: immune to wall-clock adjustments during a long wait.
    deadline = time.monotonic() + timeout_seconds
    while True:
        endpoint = w.serving_endpoints.get(endpoint_name)
        state = endpoint.state.config_update if endpoint.state else None
        ready = endpoint.state.ready if endpoint.state else None

        # Compare enum members, not their string forms.
        if state == EndpointStateConfigUpdate.NOT_UPDATING and ready == EndpointStateReady.READY:
            print(f"✓ Endpoint '{endpoint_name}' is READY")
            return True
        if state == EndpointStateConfigUpdate.UPDATE_FAILED:
            print(f"✗ Endpoint '{endpoint_name}' update FAILED")
            return False

        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break
        print(f"  Waiting for endpoint... State: {state}, Ready: {ready}")
        time.sleep(min(poll_interval_seconds, remaining))

    print(f"✗ Timeout waiting for endpoint '{endpoint_name}' to be ready")
    return False


# Decide once whether the deploy cell must create or update the endpoint.
endpoint_already_exists = endpoint_exists(serving_endpoint_name)
_exists_message = f"Endpoint '{serving_endpoint_name}' exists: {endpoint_already_exists}"
print(_exists_message)


In [None]:
# DBTITLE 1, Prepare endpoint configuration
# Optional overrides from endpoint_config_json. The keys read below are:
# scale_to_zero_enabled, workload_size, workload_type, auto_capture_enabled.
# Invalid JSON, or JSON that is not an object, falls back to the defaults.
endpoint_config = {}
if endpoint_config_json and endpoint_config_json.strip():
    try:
        parsed_endpoint_config = json.loads(endpoint_config_json)
    except json.JSONDecodeError as e:
        print(f"Error parsing endpoint_config_json: {e}")
        print("Falling back to default configuration")
    else:
        # `.get` below requires a mapping; reject lists/scalars explicitly.
        if isinstance(parsed_endpoint_config, dict):
            endpoint_config = parsed_endpoint_config
            print("Using custom endpoint configuration")
        else:
            print(
                "endpoint_config_json must be a JSON object, got "
                f"{type(parsed_endpoint_config).__name__}"
            )
            print("Falling back to default configuration")

# The inference table is written into the model's own catalog and schema. That
# requires a three-level Unity Catalog name (catalog.schema.model); fail early
# with a clear message instead of an IndexError or a wrong table location.
model_name_parts = model_name.split(".")
if len(model_name_parts) != 3 or not all(model_name_parts):
    raise ValueError(
        "model_name must be a three-level Unity Catalog name "
        f"(catalog.schema.model), got {model_name!r}"
    )
model_catalog, model_schema, _ = model_name_parts

# Build served entity configuration
served_entity = ServedEntityInput(
    entity_name=model_name,
    entity_version=model_version,
    scale_to_zero_enabled=endpoint_config.get("scale_to_zero_enabled", True),
    workload_size=endpoint_config.get("workload_size", "Small"),
    workload_type=endpoint_config.get("workload_type", "CPU"),
)

# Enable inference table logging for monitoring
auto_capture_config = AutoCaptureConfigInput(
    catalog_name=model_catalog,
    schema_name=model_schema,
    table_name_prefix=serving_endpoint_name,
    enabled=endpoint_config.get("auto_capture_enabled", True),
)

print("\n=== Endpoint Configuration ===")
print(f"Scale to Zero: {served_entity.scale_to_zero_enabled}")
print(f"Workload Size: {served_entity.workload_size}")
print(f"Workload Type: {served_entity.workload_type}")
print(f"Auto Capture Enabled: {auto_capture_config.enabled}")
print(f"Inference Table: {auto_capture_config.catalog_name}.{auto_capture_config.schema_name}.{auto_capture_config.table_name_prefix}_*")


In [None]:
# DBTITLE 1, Create or update serving endpoint
# Create the endpoint on first deploy; otherwise push the new served entity and
# inference-capture config onto the existing one. Then block until it is ready.
_divider = "=" * 60
print(f"\n{_divider}")
if not endpoint_already_exists:
    print(f"Creating new endpoint: {serving_endpoint_name}")
    print(f"{_divider}\n")

    w.serving_endpoints.create(
        name=serving_endpoint_name,
        config=EndpointCoreConfigInput(
            served_entities=[served_entity],
            auto_capture_config=auto_capture_config,
        ),
    )
    print(f"✓ Creation initiated for endpoint '{serving_endpoint_name}'")
else:
    print(f"Updating existing endpoint: {serving_endpoint_name}")
    print(f"{_divider}\n")

    w.serving_endpoints.update_config(
        name=serving_endpoint_name,
        served_entities=[served_entity],
        auto_capture_config=auto_capture_config,
    )
    print(f"✓ Update initiated for endpoint '{serving_endpoint_name}'")

print("\nWaiting for endpoint to be ready...")
_endpoint_is_ready = wait_for_endpoint_ready(serving_endpoint_name)
if not _endpoint_is_ready:
    raise Exception(f"Endpoint '{serving_endpoint_name}' failed to become ready")

print(f"\n{_divider}")
print(f"✓ Endpoint '{serving_endpoint_name}' is now ready and serving model version {model_version}")
print(_divider)


In [None]:
# DBTITLE 1, Promote model to champion (optional)
# Point the "champion" alias at the deployed version and append a promotion
# note to that version's description. Any failure is reported and re-raised.
if not promote_to_champion:
    print("\nSkipping promotion to champion (promote_to_champion=false)")
else:
    _rule = "=" * 60
    print(f"\n{_rule}")
    print(f"Promoting model version {model_version} to 'champion' alias")
    print(f"{_rule}\n")

    try:
        # Informational only: report which version currently holds the alias.
        try:
            previous_champion = client.get_model_version_by_alias(model_name, "champion")
            print(f"Current champion version: {previous_champion.version}")
        except Exception:
            print("No current champion version found")

        client.set_registered_model_alias(model_name, "champion", model_version)
        print(f"✓ Model version {model_version} promoted to 'champion' alias")

        existing_description = client.get_model_version(model_name, model_version).description or ""
        promoted_at = time.strftime('%Y-%m-%d %H:%M:%S UTC', time.gmtime())
        promotion_note = (
            "\n\n---\n\n"
            "Promotion to Champion: SUCCESS\n"
            f"Environment: {env}\n"
            f"Endpoint: {serving_endpoint_name}\n"
            f"Timestamp: {promoted_at}"
        )

        client.update_model_version(
            name=model_name,
            version=model_version,
            description=f"{existing_description}{promotion_note}",
        )
        print("✓ Model description updated with deployment information")

    except Exception as e:
        print(f"✗ Error promoting model to champion: {e}")
        raise


In [None]:
# DBTITLE 1, Test the deployed endpoint (optional)
# Send a small synthetic payload to the endpoint as a smoke test. This is best
# effort: a failed query is reported but does not fail the notebook, because
# the endpoint has already been deployed at this point.
if test_endpoint:
    print(f"\n{'='*60}")
    print("Testing deployed endpoint")
    print(f"{'='*60}\n")

    import pandas as pd
    from datetime import datetime, timedelta

    # 31 days of synthetic history ending today, using Prophet's 'ds'/'y' columns.
    # NOTE(review): confirm these columns match the logged model's input signature.
    end_date = datetime.now().date()
    start_date = end_date - timedelta(days=30)
    test_dates = pd.date_range(start=start_date, end=end_date, freq='D')

    test_data = pd.DataFrame({
        'ds': test_dates,
        # Linear trend plus a weekly pattern, sized to the date range.
        'y': [100 + i * 2 + (i % 7) * 5 for i in range(len(test_dates))],
    })

    print("Sample input data:")
    print(test_data.head())
    print(f"\nTotal records: {len(test_data)}")

    try:
        # Query the endpoint
        print("\nQuerying endpoint...")

        # `dataframe_records` expects a list of row dicts. Dates are sent as ISO
        # strings and values via .tolist(), which yields native Python types,
        # so the payload is JSON-serializable (pandas Timestamps are not).
        dataframe_records = [
            {'ds': ds, 'y': y}
            for ds, y in zip(
                test_data['ds'].dt.strftime('%Y-%m-%d').tolist(),
                test_data['y'].tolist(),
            )
        ]

        response = w.serving_endpoints.query(
            name=serving_endpoint_name,
            dataframe_records=dataframe_records,
        )

        print("\n✓ Endpoint test successful!")
        print("\nSample predictions:")

        # The response may have no `predictions` payload; in that case show the
        # raw response instead of an empty DataFrame.
        predictions = getattr(response, 'predictions', None)
        if predictions is not None:
            predictions_df = pd.DataFrame(predictions)
            print(predictions_df.head(10))
            print(f"\nTotal forecast points: {len(predictions_df)}")
        else:
            print(response)

    except Exception as e:
        print(f"\n✗ Endpoint test failed: {e}")
        print("\nNote: The endpoint is deployed but the test query failed.")
        print("This might be due to the specific input format expected by your model.")
        print("Please verify the endpoint manually or adjust the test data format.")
else:
    print("\nSkipping endpoint testing (test_endpoint=false)")


In [None]:
# DBTITLE 1, Set task values and return deployment info
import json

# Public invocation URL of the endpoint. Strip any trailing slash from the host
# so the URL never contains "//".
endpoint_url = f"{w.config.host.rstrip('/')}/serving-endpoints/{serving_endpoint_name}/invocations"

deployment_info = {
    "endpoint_name": serving_endpoint_name,
    "endpoint_url": endpoint_url,
    "model_name": model_name,
    "model_version": model_version,
    "model_uri": model_uri,
    "promoted_to_champion": promote_to_champion,
    "environment": env,
}

# Publish task values for downstream job tasks.
dbutils.jobs.taskValues.set("endpoint_name", serving_endpoint_name)
dbutils.jobs.taskValues.set("endpoint_url", endpoint_url)
dbutils.jobs.taskValues.set("model_version_deployed", model_version)
dbutils.jobs.taskValues.set("champion_promoted", str(promote_to_champion))

print(f"\n{'='*60}")
print("DEPLOYMENT SUMMARY")
print(f"{'='*60}")
print(f"Endpoint Name: {serving_endpoint_name}")
print(f"Endpoint URL: {endpoint_url}")
print(f"Model: {model_name} (version {model_version})")
print(f"Model URI: {model_uri}")
print(f"Promoted to Champion: {promote_to_champion}")
print(f"Environment: {env}")
print("Status: ✓ DEPLOYED SUCCESSFULLY")
print(f"{'='*60}")

# Return the deployment info as the notebook's JSON exit value.
dbutils.notebook.exit(json.dumps(deployment_info))
