In [0]:
# ===================================================================
# CELL 1: Install Packages (Development)
# ===================================================================

%pip install \
    mlflow[databricks]>=2.16.0 \
    databricks-agents \
    databricks-openai \
    openai \
    pydantic \
    unitycatalog-ai \
    uv \
    --upgrade --quiet

dbutils.library.restartPython()

## Clean up unsuccessfully deployed versions

In [0]:
# ===================================================================
# CELL 2: Clean Up Old Versions (Optional)
# ===================================================================

from mlflow.tracking import MlflowClient

client = MlflowClient()
UC_MODEL_NAME = "dev_kiddo.silver.CareGapsModel"

# Delete failed conda versions (versions 1 through 26 inclusive)
failed_versions = list(range(1, 27))

for version in failed_versions:
    try:
        client.delete_model_version(UC_MODEL_NAME, str(version))
        print(f"✓ Deleted version {version}")
    except Exception as e:
        error_msg = str(e)
        # Only describe the failure as "already deleted" when the error really
        # says the version is missing; anything else (permissions, network,
        # wrong registry) is reported so it is not silently hidden.
        missing = (
            "RESOURCE_DOES_NOT_EXIST" in error_msg
            or "does not exist" in error_msg.lower()
            or "not found" in error_msg.lower()
        )
        if missing:
            print(f"  Version {version}: already deleted or doesn't exist")
        else:
            print(f"  Version {version}: NOT deleted ({type(e).__name__}: {error_msg})")

print("\n✓ Cleanup complete")

### Run this cell only if you want to show synthetic data

In [0]:
import os 
# Set the data-mode environment variable for this process ("demo" here; the
# other value mentioned by the original author is "real").
os.environ.update({"CAREGAPS_DATA_MODE": "demo"})

# Log model with pre-deployment check

In [0]:
# ===================================================================
# VALIDATED DEPLOYMENT WITH PRE-DEPLOYMENT CHECK
# ===================================================================

import mlflow
from mlflow.models.resources import DatabricksFunction, DatabricksServingEndpoint

# Register models in Unity Catalog.
mlflow.set_registry_uri("databricks-uc")

UC_MODEL_NAME = "dev_kiddo.silver.CareGapsModel"
ENDPOINT_NAME = "agents_dev_kiddo-silver-CareGapsModel"
LLM_ENDPOINT_NAME = "databricks-meta-llama-3-3-70b-instruct"

# Find agent.py: prefer a copy next to the notebook, otherwise fall back to
# the fixed workspace location.
import os
if os.path.exists("agent.py"):
    agent_path = "agent.py"
else:
    agent_path = "/Workspace/Users/adminjkhan@akronchildrens.org/CareGaps/CareGapsAgents/agent.py"

print(f"Agent path: {agent_path}")
print(f"File exists: {os.path.exists(agent_path)}")

# Fail fast with a clear message instead of letting the later log_model call
# fail on a path that does not exist.
if not os.path.exists(agent_path):
    raise FileNotFoundError(
        f"agent.py not found in the working directory or at {agent_path}"
    )

# Resources declared for auth passthrough: every Unity Catalog function the
# agent may call, plus the LLM serving endpoint.
UC_TOOL_NAMES = [
    "dev_kiddo.silver.get_top_providers",
    "dev_kiddo.silver.get_patient_360",
    "dev_kiddo.silver.get_gap_categories",
    "dev_kiddo.silver.get_provider_gaps",
    "dev_kiddo.silver.get_long_open_gaps",
    "dev_kiddo.silver.get_outreach_needed",
    "dev_kiddo.silver.get_appointments_with_gaps",
    "dev_kiddo.silver.get_critical_gaps",
    "dev_kiddo.silver.search_patients",
    "dev_kiddo.silver.get_gaps_by_type",
    "dev_kiddo.silver.get_gap_statistics",
    "dev_kiddo.silver.get_department_summary",
    "dev_kiddo.silver.get_gaps_by_age",
    "dev_kiddo.silver.get_gaps_no_appointments",
    "dev_kiddo.silver.get_patient_gaps",
    "dev_kiddo.silver.get_campaign_statistics",
    "dev_kiddo.silver.search_campaign_opportunities",
    "dev_kiddo.silver.get_campaign_opportunities",
    "dev_kiddo.silver.get_patient_campaign_history",
]

# The serving endpoint comes first, followed by one DatabricksFunction per
# tool, in the same order as UC_TOOL_NAMES.
resources = [
    DatabricksServingEndpoint(endpoint_name=LLM_ENDPOINT_NAME),
    *(DatabricksFunction(function_name=fn_name) for fn_name in UC_TOOL_NAMES),
]

print("=" * 60)
print("STEP 1: LOG MODEL")
print("=" * 60)

# Close any run left open by a previous execution before starting a new one.
if mlflow.active_run():
    mlflow.end_run()

with mlflow.start_run(run_name="validated_deployment"):

    # Log the agent file as a pyfunc model and register it under
    # UC_MODEL_NAME in one call. No signature or input example is supplied.
    model_info = mlflow.pyfunc.log_model(
        name="agent",
        python_model=agent_path,
        registered_model_name=UC_MODEL_NAME,
        resources=resources,
        pip_requirements=[
            "mlflow[databricks]>=2.16.0",
            "databricks-openai>=0.2.0",
            "openai>=1.0.0",
            "pydantic>=2.0.0",
            "unitycatalog-ai>=0.1.0",
        ],
        input_example=None,
        signature=None,
    )

    # Keep the identifiers of what was just logged for the report below.
    run_id = model_info.run_id
    version = model_info.registered_model_version

    print(
        "✓ Model logged",
        f"  Version: {version}",
        f"  Run ID: {run_id}",
        f"  Resources: {len(resources)} (1 endpoint + {len(UC_TOOL_NAMES)} UC functions)",
        sep="\n",
    )

## Run validation test to see if model is returning responses

In [0]:
#===================================================================                                                                                                                                                                                                 
# DEBUG: Inspect raw response structure (run this ONCE to see format)

#===================================================================
import mlflow

from mlflow.tracking import MlflowClient                                             

mlflow.set_registry_uri("databricks-uc")
client = MlflowClient()

# Look up every registered version and pick the one with the highest number.
model_versions = client.search_model_versions("name='dev_kiddo.silver.CareGapsModel'")
if not model_versions:
    # max() on an empty result would raise an unhelpful
    # "max() arg is an empty sequence"; say what is actually wrong.
    raise ValueError(
        "No registered versions found for dev_kiddo.silver.CareGapsModel; log the model first"
    )
latest = max(model_versions, key=lambda x: int(x.version))
run_id = latest.run_id
model_uri = f"runs:/{run_id}/agent"

print(f"Model version {latest.version}, run: {run_id}")
print("Loading model...\n")

# Load the logged agent in-process and send it a single sample question.
loaded_model = mlflow.pyfunc.load_model(model_uri)
response = loaded_model.predict(
    {"input": [{"role": "user", "content": "How many care gaps are there?"}]}
)

print(f"Type: {type(response)}")

# The response is a list of dicts — extract text from message items
def extract_text(response) -> str:
    """Extract the assistant text from a model response.

    Accepted shapes:
      * ``None`` -> ``""``
      * ``str`` -> returned unchanged
      * an object with an ``.output`` attribute, a dict with an ``"output"``
        key, or a bare list: the list of output items is scanned
      * anything else -> ``str(response)``

    Only items whose ``type`` is ``"message"`` contribute text. Items and
    their content parts may be dicts or attribute-style objects. A message
    whose ``content`` is a plain string contributes that string; a missing or
    ``None`` content, or a part whose ``text`` is not a string, is skipped
    instead of raising.

    Returns:
        The collected text pieces joined with newlines.
    """
    if response is None:
        return ""
    if isinstance(response, str):
        return response

    # Normalize to the list of output items.
    items = response
    if hasattr(response, "output"):
        items = response.output
    elif isinstance(response, dict) and "output" in response:
        items = response["output"]

    # If it's not a list at this point, fall back to the string form.
    if not isinstance(items, list):
        return str(response)

    def _field(obj, name):
        """Read ``name`` from a dict key or an object attribute (None if absent)."""
        if isinstance(obj, dict):
            return obj.get(name)
        return getattr(obj, name, None)

    parts = []
    for item in items:
        if _field(item, "type") != "message":
            continue
        content = _field(item, "content")
        if isinstance(content, str):
            # Message content given directly as a string.
            parts.append(content)
            continue
        if not isinstance(content, (list, tuple)):
            # None or an unexpected type: nothing to extract from this item.
            continue
        for part in content:
            text = _field(part, "text")
            # Non-string text would make "\n".join() fail below.
            if isinstance(text, str):
                parts.append(text)
    return "\n".join(parts)

# Flatten the response to plain text, then report its size and show the first
# 300 characters.
text = extract_text(response)
print(
    f"Extracted: {len(text)} chars",
    f"Preview: {text[:300]}",
    sep="\n",
)

In [0]:
#===================================================================
# PRE-DEPLOYMENT VALIDATION SUITE
# Tests decision-tree routing, response format, scope, and campaigns
#===================================================================                                                                                            
import mlflow
import time
import json
import signal
from mlflow.tracking import MlflowClient

mlflow.set_registry_uri("databricks-uc")
client = MlflowClient()

# Get latest model: the registered version with the highest version number.
model_versions = client.search_model_versions("name='dev_kiddo.silver.CareGapsModel'")
if not model_versions:
    # max() on an empty result would raise an unhelpful
    # "max() arg is an empty sequence"; say what is actually wrong.
    raise ValueError(
        "No registered versions found for dev_kiddo.silver.CareGapsModel; log the model first"
    )
latest = max(model_versions, key=lambda x: int(x.version))
run_id = latest.run_id
model_uri = f"runs:/{run_id}/agent"

TEST_TIMEOUT = 300  # 5 minutes per test

print(f"Testing model version {latest.version} (run: {run_id})")
print("Loading model in-process...")

# The model is loaded once here and reused by every test.
loaded_model = mlflow.pyfunc.load_model(model_uri)

print(f"Model loaded. Timeout per test: {TEST_TIMEOUT}s")
print(f"{'='*70}\n")


# -------------------------------------------------------------------
# Timeout handler
# -------------------------------------------------------------------
class TestTimeout(Exception):
    """Raised by the alarm signal handler when a test runs past its time limit."""

def _timeout_handler(signum, frame):
    """Signal handler that aborts the running test by raising TestTimeout.

    Both arguments (signal number and current stack frame) are supplied by
    the signal machinery and are not used.
    """
    message = f"Test exceeded {TEST_TIMEOUT}s timeout"
    raise TestTimeout(message)


# -------------------------------------------------------------------
# Extract text from response (list-of-dicts format)
# -------------------------------------------------------------------
def extract_text(response) -> str:
    """Extract the assistant text from message items in a model response.

    Accepted shapes:
      * ``None`` -> ``""``
      * ``str`` -> returned unchanged
      * an object with an ``.output`` attribute, a dict with an ``"output"``
        key, or a bare list: the list of output items is scanned
      * anything else -> ``str(response)``

    Only items whose ``type`` is ``"message"`` contribute text. Items and
    their content parts may be dicts (``{"type": "message", "content":
    [{"text": "..."}]}``) or attribute-style objects (``item.content[].text``).
    A message whose ``content`` is a plain string contributes that string; a
    missing or ``None`` content, or a part whose ``text`` is not a string, is
    skipped instead of raising.

    Returns:
        The collected text pieces joined with newlines.
    """
    if response is None:
        return ""
    if isinstance(response, str):
        return response

    # Normalize to a list of output items
    items = response
    if hasattr(response, "output"):
        items = response.output
    elif isinstance(response, dict) and "output" in response:
        items = response["output"]

    if not isinstance(items, list):
        return str(response)

    def _field(obj, name):
        """Read ``name`` from a dict key or an object attribute (None if absent)."""
        if isinstance(obj, dict):
            return obj.get(name)
        return getattr(obj, name, None)

    # Extract text from message items only
    parts = []
    for item in items:
        if _field(item, "type") != "message":
            continue
        content = _field(item, "content")
        if isinstance(content, str):
            # Message content given directly as a string.
            parts.append(content)
            continue
        if not isinstance(content, (list, tuple)):
            # None or an unexpected type: nothing to extract from this item.
            continue
        for part in content:
            text = _field(part, "text")
            # Non-string text would make "\n".join() fail below.
            if isinstance(text, str):
                parts.append(text)

    return "\n".join(parts)


# -------------------------------------------------------------------
# Run a single test with timeout
# -------------------------------------------------------------------
def run_test(query: str) -> dict:
    """Send one user query to the loaded model and time the call.

    A SIGALRM-based timeout of ``TEST_TIMEOUT`` seconds is armed around the
    prediction when the platform allows it. ``signal.signal`` raises
    ``ValueError`` when called outside the main thread, and ``SIGALRM`` does
    not exist on Windows; in those cases the test still runs, just without
    timeout enforcement, instead of aborting the whole suite.

    Args:
        query: The user message to send to the model.

    Returns:
        A dict with keys ``"text"`` (extracted response text, ``""`` on
        failure), ``"elapsed"`` (seconds as float) and ``"error"`` (``None``
        on success, otherwise a description of the timeout or exception).
    """
    start = time.time()

    old_handler = None
    timeout_armed = False
    try:
        old_handler = signal.signal(signal.SIGALRM, _timeout_handler)
        signal.alarm(TEST_TIMEOUT)
        timeout_armed = True
    except (AttributeError, ValueError):
        print("  (timeout not enforced: SIGALRM is unavailable in this context)")

    try:
        response = loaded_model.predict(
            {"input": [{"role": "user", "content": query}]}
        )
        if timeout_armed:
            # Cancel right away so the alarm cannot fire during post-processing.
            signal.alarm(0)
        elapsed = time.time() - start
        text = extract_text(response)
        return {"text": text, "elapsed": elapsed, "error": None}
    except TestTimeout:
        elapsed = time.time() - start
        return {"text": "", "elapsed": elapsed, "error": f"TIMEOUT after {TEST_TIMEOUT}s"}
    except Exception as e:
        elapsed = time.time() - start
        return {"text": "", "elapsed": elapsed, "error": str(e)}
    finally:
        if timeout_armed:
            # Always cancel any pending alarm before restoring the handler so a
            # late SIGALRM can never reach the previous handler.
            signal.alarm(0)
            # signal.signal() returns None when the previous handler was not
            # installed from Python; None cannot be passed back, so restore
            # the default disposition in that case.
            signal.signal(
                signal.SIGALRM,
                old_handler if old_handler is not None else signal.SIG_DFL,
            )


# -------------------------------------------------------------------
# Validation checks
# -------------------------------------------------------------------
def check_has_content(text, min_len=50):
    """Check that the response is at least ``min_len`` characters long.

    Returns:
        ``(ok, message)`` where the message reports the actual and minimum
        lengths.
    """
    length = len(text)
    return length >= min_len, f"Response length {length} chars (min {min_len})"

def check_has_table(text):
    """Check for a markdown table: the text must contain both "|" and "---".

    Returns:
        ``(ok, message)``.
    """
    if "|" in text and "---" in text:
        return True, "Contains markdown table"
    return False, "Missing markdown table"

def check_has_next_actions(text):
    """Check (case-insensitively) for a "next best action" phrase or a "### next" heading.

    Returns:
        ``(ok, message)``.
    """
    lowered = text.lower()
    if "next best action" in lowered or "### next" in lowered:
        return True, "Has Next Best Actions"
    return False, "Missing Next Best Actions"

def check_contains_any(text, keywords, label=""):
    """Check that at least one keyword occurs in the text (case-insensitive substring match).

    Args:
        text: Response text to search.
        keywords: Candidate substrings.
        label: Optional prefix for the message.

    Returns:
        ``(ok, message)``; the message lists the keywords found, or all
        keywords when none matched, prefixed with ``"<label>: "`` if a label
        was given.
    """
    haystack = text.lower()
    hits = [kw for kw in keywords if kw.lower() in haystack]
    if hits:
        detail = f"Found: {hits}"
    else:
        detail = f"None of {keywords} found"
    message = f"{label}: {detail}" if label else detail
    return bool(hits), message

def check_not_contains(text, keywords, label=""):
    """Check that none of the keywords occur in the text (case-insensitive substring match).

    Args:
        text: Response text to search.
        keywords: Substrings that must be absent.
        label: Optional prefix for the message.

    Returns:
        ``(ok, message)``; on failure the message lists the offending
        keywords, prefixed with ``"<label>: "`` if a label was given.
    """
    haystack = text.lower()
    offenders = [kw for kw in keywords if kw.lower() in haystack]
    if offenders:
        detail = f"Unexpectedly found: {offenders}"
    else:
        detail = "Correctly absent"
    message = f"{label}: {detail}" if label else detail
    return not offenders, message

def check_redirects_to_dashboard(text):
    """Check that the response mentions one of the dashboard-related words.

    The words tested (case-insensitively, as substrings) are "campaign",
    "sidebar", "dashboard" and "navigate".

    Returns:
        ``(ok, message)``.
    """
    lowered = text.lower()
    for word in ("campaign", "sidebar", "dashboard", "navigate"):
        if word in lowered:
            return True, "Redirects to dashboard"
    return False, "Missing dashboard redirect"


# ===================================================================
# TEST DEFINITIONS (14 tests, 10 categories)
#
# Each entry has an "id", a "category", the "query" sent to the agent, and a
# list of "checks". Every check is a callable that takes the response text
# and returns (ok, message).
# ===================================================================
tests = [
    # --- Step 1: CAMPAIGN ROUTING ---
    {
        "id": "C1",
        "category": "Campaign Routing",
        "query": "How is the flu vaccine piggybacking campaign going?",
        "checks": [
            lambda text: check_has_content(text),
            lambda text: check_contains_any(
                text,
                ["pending", "sent", "approved", "opportunities", "total", "campaign"],
                "Campaign data",
            ),
            lambda text: check_not_contains(
                text,
                ["gap_type", "care gap type", "well child"],
                "No care-gap leakage",
            ),
        ],
    },
    {
        "id": "C2",
        "category": "Campaign Routing",
        "query": "Show flu vaccine opportunities with asthma patients",
        "checks": [
            lambda text: check_has_content(text),
            lambda text: check_has_table(text),
            lambda text: check_contains_any(text, ["asthma", "j45"], "Asthma reference"),
        ],
    },
    {
        "id": "C3",
        "category": "Campaign Routing",
        "query": "Search for sibling vaccine opportunities at Beachwood",
        "checks": [
            lambda text: check_has_content(text),
            lambda text: check_contains_any(
                text,
                ["beachwood", "opportunity", "sibling", "campaign", "vaccine"],
                "Location search",
            ),
        ],
    },
    # --- Step 2: PATIENT ROUTING ---
    {
        "id": "P1",
        "category": "Patient Routing",
        "query": "Find patient 2886348",
        "checks": [
            lambda text: check_has_content(text),
            lambda text: check_contains_any(text, ["2886348", "patient", "mrn"], "Patient reference"),
        ],
    },
    {
        "id": "P2",
        "category": "Patient Routing",
        "query": "Show me all gaps for patient 2886348",
        "checks": [
            lambda text: check_has_content(text),
            lambda text: check_has_table(text),
            lambda text: check_contains_any(text, ["gap", "2886348"], "Patient gaps data"),
        ],
    },
    # --- Step 3: DEPARTMENT / PROVIDER ---
    {
        "id": "D1",
        "category": "Department Routing",
        "query": "Which departments have the most open care gaps?",
        "checks": [
            lambda text: check_has_content(text),
            lambda text: check_has_table(text),
            lambda text: check_contains_any(text, ["department", "dept"], "Department data"),
            lambda text: check_not_contains(text, ["vaccine", "flu", "campaign"], "No campaign leakage"),
        ],
    },
    {
        "id": "D2",
        "category": "Provider Routing",
        "query": "Show me the top providers by care gap count",
        "checks": [
            lambda text: check_has_content(text),
            lambda text: check_has_table(text),
            lambda text: check_contains_any(text, ["provider", "dr", "md"], "Provider data"),
        ],
    },
    # --- Step 4: GAP TYPES ---
    {
        "id": "G1",
        "category": "Gap Types",
        "query": "What are the top gap types by volume?",
        "checks": [
            lambda text: check_has_content(text),
            lambda text: check_has_table(text),
            lambda text: check_has_next_actions(text),
        ],
    },
    # --- Step 5: URGENCY / OUTREACH ---
    {
        "id": "U1",
        "category": "Urgency / Outreach",
        "query": "Show me critical care gaps that need immediate attention",
        "checks": [
            lambda text: check_has_content(text),
            lambda text: check_has_table(text),
            lambda text: check_contains_any(
                text,
                ["critical", "urgent", "overdue", "high"],
                "Urgency indicators",
            ),
        ],
    },
    {
        "id": "U2",
        "category": "Urgency / Outreach",
        "query": "Show patients with gaps but no upcoming appointments",
        "checks": [
            lambda text: check_has_content(text),
            lambda text: check_has_table(text),
            lambda text: check_contains_any(
                text,
                ["no appointment", "appointment", "patient"],
                "No-appointment data",
            ),
        ],
    },
    # --- Step 6: GENERAL OVERVIEW ---
    {
        "id": "O1",
        "category": "General Overview",
        "query": "How many care gaps are there overall?",
        "checks": [
            lambda text: check_has_content(text),
            lambda text: check_contains_any(
                text,
                ["total", "statistic", "gap", "open", "count"],
                "Statistics data",
            ),
        ],
    },
    # --- SCOPE BOUNDARY ---
    {
        "id": "S1",
        "category": "Scope Boundary",
        "query": "What is the weather in Akron today?",
        "checks": [
            lambda text: check_contains_any(
                text,
                ["care gap", "caregaps", "only help", "cannot", "can only"],
                "Scope rejection",
            ),
        ],
    },
    # --- DASHBOARD REDIRECT ---
    {
        "id": "R1",
        "category": "Dashboard Redirect",
        "query": "Approve the flu vaccine opportunity for MRN 12345",
        "checks": [
            lambda text: check_redirects_to_dashboard(text),
        ],
    },
    # --- RESPONSE FORMAT ---
    {
        "id": "F1",
        "category": "Response Format",
        "query": "Show flu vaccine opportunities that are pending",
        "checks": [
            lambda text: check_has_content(text),
            lambda text: check_has_table(text),
            lambda text: check_has_next_actions(text),
            lambda text: check_not_contains(text, ["function_call", "tool_call"], "No raw JSON leaked"),
        ],
    },
]


# ===================================================================
# RUN ALL TESTS
# ===================================================================
# Per-test outcome records plus running pass/fail counters.
results = []
total_pass = 0
total_fail = 0

for case in tests:
    case_id = case["id"]
    category = case["category"]
    prompt = case["query"]

    print(f"[{case_id}] {category}")
    print(f"  Query: {prompt}")

    outcome = run_test(prompt)

    if outcome["error"]:
        # The model call itself failed or timed out: record it and skip the checks.
        print(f"  ERROR: {outcome['error']} ({outcome['elapsed']:.1f}s)")
        total_fail += 1
        results.append({"id": case_id, "category": category, "status": "FAIL", "reason": outcome["error"]})
    else:
        reply = outcome["text"]
        duration = outcome["elapsed"]
        print(f"  Response: {len(reply)} chars in {duration:.1f}s")
        print(f"  Preview: {reply[:150]}...")

        # Run every check (no short-circuit) so all failures are printed.
        failed_messages = []
        for check in case["checks"]:
            passed, message = check(reply)
            print(f"    [{'PASS' if passed else 'FAIL'}] {message}")
            if not passed:
                failed_messages.append(message)

        if failed_messages:
            total_fail += 1
            results.append({
                "id": case_id,
                "category": category,
                "status": "FAIL",
                "reason": "; ".join(failed_messages),
            })
        else:
            total_pass += 1
            results.append({"id": case_id, "category": category, "status": "PASS", "elapsed": duration})

    print()


# ===================================================================
# SUMMARY
# ===================================================================
print("=" * 70)
print("VALIDATION SUMMARY")
print("=" * 70)

# Totals and pass rate (percentage of all defined tests).
print(
    f"\n  Total:   {len(tests)}",
    f"  Passed:  {total_pass}",
    f"  Failed:  {total_fail}",
    f"  Rate:    {100*total_pass/len(tests):.0f}%\n",
    sep="\n",
)

if total_fail > 0:
    # List each failed test with its recorded reason, then the verdict.
    print("FAILED TESTS:")
    for record in results:
        if record["status"] != "FAIL":
            continue
        print(f"  [{record['id']}] {record['category']}: {record.get('reason', '')}")
    print()
    print(f">>> {total_fail} TEST(S) FAILED - REVIEW BEFORE DEPLOYING <<<")
else:
    print(">>> ALL TESTS PASSED - SAFE TO DEPLOY <<<")

print("=" * 70)


## Deploy the agent as a serving endpoint

In [0]:
# ============================================================
# STEP 3: DEPLOY TO SERVING ENDPOINT
# ============================================================

from databricks.sdk import WorkspaceClient
from databricks.sdk.service.serving import ServedEntityInput, EndpointCoreConfigInput
from mlflow.tracking import MlflowClient

# This cell calls mlflow directly, so import it here; otherwise the cell only
# works when an earlier cell in the same session has already imported it.
import mlflow

mlflow.set_registry_uri("databricks-uc")

# Configuration
UC_MODEL_NAME = "dev_kiddo.silver.CareGapsModel"
ENDPOINT_NAME = "agents_dev_kiddo-silver-CareGapsModel"

print(f"{'='*60}")
print("DEPLOYING AGENT TO SERVING ENDPOINT")
print('='*60)

# Get latest registered version (highest version number)
client = MlflowClient()
model_versions = client.search_model_versions(f"name='{UC_MODEL_NAME}'")
if not model_versions:
    # max() on an empty result would raise an unhelpful
    # "max() arg is an empty sequence"; say what is actually wrong.
    raise ValueError(
        f"No registered versions found for {UC_MODEL_NAME}; log the model before deploying"
    )
latest_version = max(model_versions, key=lambda x: int(x.version))

print(f"\nModel: {UC_MODEL_NAME}")
print(f"Version: {latest_version.version}")
print(f"Endpoint: {ENDPOINT_NAME}\n")

# Create WorkspaceClient
w = WorkspaceClient()

# Describe what the endpoint should serve: the latest version of the
# registered model, on a "Small" workload with scale-to-zero turned off.
served_entity = ServedEntityInput(
    entity_name=UC_MODEL_NAME,
    entity_version=str(latest_version.version),
    workload_size="Small",
    scale_to_zero_enabled=False,
)

print("Deploying...")

# Update-then-create: try to point the existing endpoint at the new version,
# and create the endpoint only when the update fails because it is missing.
try:
    w.serving_endpoints.update_config(
        name=ENDPOINT_NAME,
        served_entities=[served_entity],
    )
except Exception as e:
    error_msg = str(e)

    # Any failure other than "endpoint not found" is reported and re-raised.
    if not ("RESOURCE_DOES_NOT_EXIST" in error_msg or "does not exist" in error_msg.lower()):
        print(f"✗ Deployment failed: {error_msg}")
        raise

    print("Endpoint doesn't exist, creating...")
    w.serving_endpoints.create(
        name=ENDPOINT_NAME,
        config=EndpointCoreConfigInput(
            name=ENDPOINT_NAME,
            served_entities=[served_entity],
        ),
    )
    print("✓ Created new endpoint")
else:
    print("✓ Updated existing endpoint")

# Neither call above is waited on here, so this reports that deployment has
# started, not that it has finished.
print(f"\n{'='*60}")
print("✓✓✓ DEPLOYMENT INITIATED ✓✓✓")
print("=" * 60)
print(
    f"\nEndpoint: {ENDPOINT_NAME}",
    f"Version: {latest_version.version}",
    "\nContainer build will take 3-5 minutes...",
    f"Monitor at: Serving > Endpoints > {ENDPOINT_NAME}",
    sep="\n",
)