# MedGemma Grounding Agent Demo

This notebook demonstrates the LangGraph ReAct agent for grounding clinical trial criteria using MedGemma 1.5 4B with UMLS MCP tools.

## Setup

In [None]:
# Install dependencies
!pip install -q langchain-core langgraph langchain-huggingface transformers accelerate bitsandbytes jinja2 pydantic

In [None]:
import os
import sys
from pathlib import Path

# kaggle_secrets can only be imported where that package is installed;
# fall back to None so the rest of the cell still runs elsewhere.
try:
    from kaggle_secrets import UserSecretsClient
except ImportError:
    UserSecretsClient = None

# Add project to path. Each insert goes to position 0, so the component
# inserted last ("shared") ends up first on sys.path.
project_root = Path().absolute().parent
for component in ("grounding-service", "shared"):
    sys.path.insert(0, str(project_root / "components" / component / "src"))

# Set environment variables
os.environ["USE_AI_GROUNDING"] = "true"
os.environ["MEDGEMMA_QUANTIZATION"] = "4bit"

# Set UMLS API key and Hugging Face token (use Kaggle secrets in production).
# Each secret is fetched independently so that one missing secret neither
# blocks the other from loading nor causes a warning that names both.
SECRET_NAMES = ("UMLS_API_KEY", "HF_TOKEN")
if UserSecretsClient is None:
    print("⚠️ kaggle_secrets not available - set UMLS_API_KEY and HF_TOKEN manually")
else:
    unavailable = []
    try:
        secrets = UserSecretsClient()
    except Exception:
        # Best-effort: without a client none of the secrets can be read.
        secrets = None
        unavailable = list(SECRET_NAMES)
    if secrets is not None:
        for secret_name in SECRET_NAMES:
            try:
                os.environ[secret_name] = secrets.get_secret(secret_name)
            except Exception:
                # Best-effort: remember the failure and keep going.
                unavailable.append(secret_name)
    if unavailable:
        print(f"⚠️ Secrets not available - set {' and '.join(unavailable)} manually")


## Load Grounding Agent

In [None]:
from grounding_service.agent import GroundingAgent

# Model checkpoint and weight precision used for this demo.
# 4-bit quantization is chosen for the Kaggle T4 GPU.
MODEL_PATH = "google/medgemma-1.5-4b-it"
QUANTIZATION = "4bit"

agent = GroundingAgent(model_path=MODEL_PATH, quantization=QUANTIZATION)

print("✅ Agent loaded successfully")

## Demo: Ground Clinical Trial Criteria

In [None]:

# Example 1: Age criterion
criterion1 = "Age >= 18 years"
result1 = await agent.ground(criterion1, "inclusion")

print("Criterion:", criterion1)
print("\nSNOMED Codes:", result1.snomed_codes)
print("\nField Mappings:")
for mapping in result1.field_mappings:
    print(f"  - {mapping.field} {mapping.relation} {mapping.value} (confidence: {mapping.confidence:.2f})")
print("\nReasoning:", result1.reasoning[:200] + "..." if len(result1.reasoning) > 200 else result1.reasoning)

In [None]:
# Example 2: BMI criterion
criterion2 = "BMI >= 30 kg/m²"
result2 = await agent.ground(criterion2, "inclusion")

print("Criterion:", criterion2)
print("\nSNOMED Codes:", result2.snomed_codes)
print("\nField Mappings:")
for mapping in result2.field_mappings:
    print(f"  - {mapping.field} {mapping.relation} {mapping.value} (confidence: {mapping.confidence:.2f})")
    if mapping.umls_cui:
        print(f"    UMLS CUI: {mapping.umls_cui}")

In [None]:
# Example 3: Complex criterion
criterion3 = "ECOG performance status <= 2"
result3 = await agent.ground(criterion3, "inclusion")

print("Criterion:", criterion3)
print("\nSNOMED Codes:", result3.snomed_codes)
print("\nField Mappings:")
for mapping in result3.field_mappings:
    print(f"  - {mapping.field} {mapping.relation} {mapping.value} (confidence: {mapping.confidence:.2f})")

## Comparison: AI vs Baseline

Compare AI-powered grounding with the regex baseline.

In [None]:
from grounding_service import umls_client

test_criterion = "HbA1c < 7.0%"

# Baseline (regex)
baseline_mappings = umls_client.propose_field_mapping(test_criterion)
print("Baseline (regex) mappings:", baseline_mappings)

# AI agent
ai_result = await agent.ground(test_criterion, "inclusion")
print("\nAI agent mappings:")
for mapping in ai_result.field_mappings:
    print(f"  - {mapping.field} {mapping.relation} {mapping.value}")
    print(f"    Confidence: {mapping.confidence:.2f}, UMLS CUI: {mapping.umls_cui}")