In [1]:
from langchain_google_community import BigQueryLoader
from google.cloud import bigquery
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser, StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_core.messages import HumanMessage
from langchain_google_vertexai import ChatVertexAI
from langgraph.graph import Graph, END
from typing import TypedDict, Annotated, Union, Optional
import pandas as pd
import json
import os
import re
import re
import warnings
warnings.filterwarnings("ignore")  # NOTE: silences every warning (incl. deprecations) for the whole kernel

# Set up credentials. setdefault keeps a GOOGLE_APPLICATION_CREDENTIALS value the
# user has already exported and only falls back to the author's local key file,
# instead of unconditionally overwriting it.
os.environ.setdefault(
    "GOOGLE_APPLICATION_CREDENTIALS",
    "/Users/mncedisimncwabe/Downloads/hallowed-span-459710-s1-c41d79c9b56b.json",
)

# Initialize BigQuery client (no explicit project is passed, so the library's
# default project resolution applies)
client = bigquery.Client()

def get_all_schemas(dataset_id):
    """Collect column metadata for every table in a BigQuery dataset.

    Args:
        dataset_id: Dataset identifier accepted by ``client.list_tables``.

    Returns:
        Dict mapping each table id to a list of
        ``{"name", "type", "description"}`` dicts, one per top-level schema
        field. Missing descriptions become "No description". Tables whose
        metadata cannot be fetched are reported on stdout and left out.
    """
    collected = {}
    for listed in client.list_tables(dataset_id):
        try:
            fields = client.get_table(listed).schema
            collected[listed.table_id] = [
                {
                    "name": fld.name,
                    "type": fld.field_type,
                    "description": fld.description or "No description",
                }
                for fld in fields
            ]
        except Exception as exc:
            # Best effort: one unreadable table must not abort the whole scan.
            print(f"Error fetching schema for {listed.table_id}: {str(exc)}")
    return collected

# Update these. Defined before the schema fetch so the dataset name lives in one place.
BQ_PROJECT = "hallowed-span-459710-s1"
BQ_DATASET = "test_clustering"
LOCATION = "us-central1"
TARGET_TABLES = {
    "user-engagement": "User engagement data",
    "dim_date": "Date dimension table",
    "fact_user_metrics": "Aggregated user metrics"
}

# Get all schemas. The dataset is fully qualified ("project.dataset") so the
# lookup does not depend on whichever default project the client resolved.
schemas = get_all_schemas(f"{BQ_PROJECT}.{BQ_DATASET}")

# Initialize Vertex AI LLM
llm = ChatVertexAI(
    model_name="gemini-2.0-flash-001",
    temperature=0,  # minimise sampling randomness; higher values tend to add unnecessary clauses
    max_output_tokens=2048,  # upper bound on response length, in tokens (roughly 1,500 words)
    project=BQ_PROJECT,
    location=LOCATION
)

def format_schema_for_prompt(schema_data):
    """Render a ``{table: [column dict, ...]}`` mapping as indented plain text.

    Each table yields a ``Table <name>:`` header followed by one
    ``  - <name> (<type>): <description>`` line per column; lines are joined
    with newlines (no trailing newline).
    """
    lines = []
    for tbl, cols in schema_data.items():
        lines.append(f"Table {tbl}:")
        lines.extend(
            f"  - {c['name']} ({c['type']}): {c['description']}" for c in cols
        )
    return "\n".join(lines)

def get_table_reference(table_name: str) -> str:
    """Return the backtick-quoted ``project.dataset.table`` path for *table_name*.

    The backticks are required because the configured project id (and some
    table names, e.g. "user-engagement") contain hyphens.
    """
    path = f"{BQ_PROJECT}.{BQ_DATASET}.{table_name}"
    return f"`{path}`"

# Agent instructions prompt (system message for SQL generation).
# This f-string is evaluated once, when the cell runs: it embeds the table
# schemas fetched above (format_schema_for_prompt(schemas)) and fully-qualified
# table references from get_table_reference(), so re-run the cell after schema
# or config changes. The worked examples steer the LLM toward joining *_date
# columns to dim_date.
system_prompt = f"""You are an advanced BigQuery SQL expert with data modeling intuition. Key capabilities:

1. Schema Reasoning:
- Automatically detect date fields that should join to dim_date (e.g., first_seen_date → dim_date.date)
- Recognize common patterns (user_id for joins, *_date for time dimensions)
- Identify fact vs dimension tables based on structure

2. Intelligent Defaults:
- For time-based questions, default to appropriate date granularity (month/quarter/year)
- For user metrics, consider both raw (user-engagement) and aggregated (fact_user_metrics) sources
- When counting distinct values, automatically add LIMIT based on expected cardinality

3. Self-Correction:
- If initial query returns unexpected zeros/null values:
  1. Check date formatting
  2. Verify join conditions
  3. Consider alternative source tables

4. Analytical Best Practices:
- Prefer COUNT(DISTINCT) over COUNT() for user metrics
- Use appropriate date functions (EXTRACT, DATE_TRUNC)
- Apply CASE WHEN for conditional logic

Available tables:
{format_schema_for_prompt(schemas)}

Examples of Intelligent Behavior:
Q: "How many unique months per user?"
A: SELECT 
     u.user_id,
     COUNT(DISTINCT FORMAT_DATE('%Y-%m', d.date)) AS unique_months
   FROM {get_table_reference('user-engagement')} u
   JOIN {get_table_reference('dim_date')} d 
     ON u.first_seen_date = d.date
   GROUP BY u.user_id

Q: "Find users active in Q2 but not Q3"
A: WITH q2_users AS (
     SELECT DISTINCT user_id 
     FROM {get_table_reference('user-engagement')} u
     JOIN {get_table_reference('dim_date')} d 
       ON u.first_seen_date = d.date
     WHERE d.quarter = 'Q2'
   )
   SELECT q2.user_id
   FROM q2_users q2
   WHERE NOT EXISTS (
     SELECT 1 
     FROM {get_table_reference('user-engagement')} u2
     JOIN {get_table_reference('dim_date')} d2 
       ON u2.first_seen_date = d2.date
     WHERE d2.quarter = 'Q3'
     AND u2.user_id = q2.user_id
   )
"""

# ChatPromptTemplate parses "{...}" in message text as template variables. The
# system prompt is already fully rendered (it is an f-string), so any literal
# braces it contains -- e.g. from BigQuery column descriptions -- are escaped to
# keep them from being treated as missing variables. "{question}" in the human
# turn remains the only real template variable.
prompt_template = ChatPromptTemplate.from_messages([
    ("system", system_prompt.replace("{", "{{").replace("}", "}}")),
    ("human", "{question}")
])

# Define the agent state
class AgentState(TypedDict):
    """Shared state dict passed between the nodes of the SQL agent graph."""
    question: str  # natural-language question being answered
    sql_query: Optional[str]  # latest SQL text from the LLM (may still be wrapped in a ```sql fence)
    query_result: Union[pd.DataFrame, str, None]  # DataFrame on success, error/diagnostic string otherwise
    validation_errors: list[str]  # messages from validate_sql; empty list means the checks passed
    attempts: int  # incremented by generate_sql; the "analyze" router stops retrying at 3
    needs_correction: bool  # set by execute_query / analyze_results when the query should be repaired
    human_approved: Optional[bool]  # reviewer verdict; None = not reviewed yet
    human_feedback: Optional[str]  # reviewer's suggestion ("Yes but ...") or rejection note
    needs_modification: Optional[bool]  # True when the reviewer approved but asked for changes

# Helper function to clean SQL query
def clean_sql_query(query: str) -> str:
    """Strip a Markdown code fence (and optional ``sql`` tag) from LLM output.

    LLMs usually answer with a ```sql ... ``` block, and BigQuery rejects the
    fence markers. Surrounding whitespace is removed *before* looking for the
    fence. The previous ``startswith/endswith`` test missed the very common
    case of a trailing newline after the closing fence. It also used
    ``strip("`")``, which could eat a trailing backtick-quoted identifier.

    Args:
        query: Raw SQL text, possibly fenced. ``None`` is treated as empty.

    Returns:
        The bare SQL with leading/trailing whitespace removed. Text without a
        complete fence is returned stripped but otherwise unchanged.
    """
    if query is None:
        return ""
    text = query.strip()
    # Opening fence with an optional language tag, lazily captured body, closing fence.
    fenced = re.fullmatch(
        r"```(?:sql)?[ \t]*\n?(.*?)\s*```", text, re.DOTALL | re.IGNORECASE
    )
    if fenced:
        text = fenced.group(1)
    return text.strip()

# Add new helper function for human interaction
def get_human_approval(state: AgentState) -> AgentState:
    """Ask a human reviewer on stdin to approve, amend, or reject the generated SQL.

    Accepted replies:
      * anything starting with "y" (e.g. "Y", "Yes")  -> approve as-is
      * "Yes but <suggestion>"                         -> approve, then modify per suggestion
      * anything else (e.g. "N", "No, <reason>")       -> reject, keeping the reason

    Args:
        state: Current graph state; ``sql_query`` and ``question`` are shown.

    Returns:
        A copy of ``state`` with ``human_approved`` and ``needs_modification``
        set and, when a suggestion or reason is given, ``human_feedback``.
    """
    print("\n" + "="*50)
    print("HUMAN APPROVAL REQUIRED")
    print("Generated SQL Query:")
    print(state["sql_query"])

    question = state.get("question", "Unknown question")
    print("\nQuestion being answered:", question)

    # Allow for human nuanced feedback
    print("\nYou can type:")
    print("- 'Y' or 'Yes' to approve")
    print("- 'N' or 'No' to reject")
    print("- 'Yes but...' followed by your feedback/suggestions to modify the query")

    feedback = input("Your response: ").strip()

    if feedback.lower().startswith("y"):
        # Look for "but" as a whole word: a substring test also matched words
        # such as "attribute" or "button" and sliced the suggestion mid-word.
        but_match = re.search(r"\bbut\b", feedback, re.IGNORECASE)
        suggestion = (
            feedback[but_match.end():].lstrip(" \t,.:;-").strip() if but_match else ""
        )
        if suggestion:
            return {
                **state,
                "human_approved": True,
                "human_feedback": suggestion,
                "needs_modification": True,
            }
        # Plain approval (or "Yes but" with nothing after it): run the query as-is.
        return {**state, "human_approved": True, "needs_modification": False}

    # Rejection: drop a leading "n"/"no"/"nope" token and separators so only the
    # reason remains. The old fixed feedback[2:] slice kept junk such as ", ..."
    # after "No" and chopped the first letters off any other reply.
    reason = re.sub(
        r"^(?:nope|no|n)\b[\s,.:;-]*", "", feedback, flags=re.IGNORECASE
    ).strip()
    return {
        **state,
        "human_approved": False,
        "human_feedback": reason or "Query rejected",
        "needs_modification": False,  # Will go to correction process
    }

# Initialize the graph.
# A plain langgraph Graph (not StateGraph): there is no state reducer, so every
# node below copies the incoming dict ({**state, ...}) and returns the full state.
workflow = Graph()

# Define nodes
def generate_sql(state: AgentState) -> AgentState:
    """LLM node: draft a SQL query for ``state["question"]``.

    The question is rendered into ``prompt_template`` (schema-aware system
    prompt plus the question as the human turn) and sent to ``llm``. The raw
    text is stored unmodified. It may still be wrapped in a Markdown fence,
    which later nodes strip with ``clean_sql_query``.

    Returns:
        A copy of ``state`` with the new ``sql_query``, ``attempts`` increased
        by one, and ``query_result`` / ``validation_errors`` cleared.
    """
    print(f"\nGenerating SQL for: {state['question']}")
    pipeline = (
        {"question": lambda current: current["question"]}
        | prompt_template
        | llm
        | StrOutputParser()
    )
    raw_sql = pipeline.invoke(state)
    print(f"Generated SQL (raw): {raw_sql}")

    updated = dict(state)  # keep every existing field, including the question
    updated.update(
        sql_query=raw_sql,
        attempts=state.get("attempts", 0) + 1,
        query_result=None,
        validation_errors=[],
    )
    return updated

def validate_sql(state: AgentState) -> AgentState:
    """Run cheap static checks on the generated SQL before it reaches BigQuery.

    Checks:
      * the statement starts with SELECT or WITH (read-only queries only);
      * at least one backtick-quoted table reference is present, either
        fully qualified (`project.dataset.table`) or bare (`table`);
      * if the query contains a JOIN that needs a condition, an ON or USING
        clause follows one (CROSS JOIN needs neither and is ignored).

    Returns:
        A copy of ``state`` with ``validation_errors`` (an empty list means all
        checks passed) and ``query_result`` reset to ``None``. ``sql_query`` is
        left as the raw, uncleaned text.
    """
    query = clean_sql_query(state["sql_query"])
    errors = []

    # Accept either a fully-qualified `project.dataset.table` or a bare `table`.
    table_ref_pattern = re.compile(
        r"(`" + re.escape(BQ_PROJECT) + r"\." + re.escape(BQ_DATASET) + r"\.[a-zA-Z0-9_-]+`|`[a-zA-Z0-9_-]+`)",
        re.IGNORECASE
    )

    # Validation checks
    if not query.lower().startswith(("select", "with")):
        errors.append("Must start with SELECT/WITH")

    if not table_ref_pattern.search(query):
        errors.append(f"Missing valid table reference (expected format: `{BQ_PROJECT}.{BQ_DATASET}.table_name` or `table_name`)")

    # Match "join" as a whole word: the old substring test fired on identifiers
    # such as "joined_at" and then reported a missing ON clause for a query with
    # no JOIN at all. CROSS JOIN has no join condition, so drop it before checking.
    conditional_sql = re.sub(r"\bcross\s+join\b", " ", query, flags=re.IGNORECASE)
    if re.search(r"\bjoin\b", conditional_sql, re.IGNORECASE) and not re.search(
        r"\bjoin\b[\s\S]+?\b(?:on|using)\b", conditional_sql, re.IGNORECASE
    ):
        errors.append("JOIN missing ON clause")

    return {
        **state,
        "validation_errors": errors,
        "sql_query": state["sql_query"],
        "query_result": None
    }

def execute_query(state: AgentState) -> AgentState:
    """Run the cleaned SQL on BigQuery unless validation already failed.

    Returns:
        A copy of ``state`` where ``query_result`` is the result DataFrame and
        ``needs_correction`` is False on success. Otherwise ``query_result`` is
        a "Validation errors: ..." or "Execution Error: ..." string and
        ``needs_correction`` is True. ``sql_query`` is kept raw.
    """
    problems = state["validation_errors"]
    if problems:
        # Do not send a query that failed the static checks to BigQuery.
        print(f"Validation errors: {problems}")
        outcome, retry = f"Validation errors: {', '.join(problems)}", True
    else:
        sql = clean_sql_query(state["sql_query"])
        print(f"Executing query: {sql}")
        try:
            outcome = client.query(sql).result().to_dataframe()
            retry = False
        except Exception as err:
            # Any BigQuery/client failure becomes feedback for the correction loop.
            outcome, retry = f"Execution Error: {str(err)}", True

    return {
        **state,
        "query_result": outcome,
        "sql_query": state["sql_query"],
        "needs_correction": retry,
    }

def analyze_results(state: AgentState) -> AgentState:
    """Decide whether an executed query's outcome needs correction.

    * ``query_result`` is a string (validation/execution error) -> correction.
    * Empty DataFrame -> correction; ``query_result`` becomes
      "Query returned empty results".
    * Every value column is 0 -> correction; ``query_result`` becomes
      "Query returned all zeros". Value columns are all columns except the
      first (presumably an id/label) when there are several; a one-column
      frame is checked as-is.
    * Any other DataFrame -> accepted (``needs_correction`` False).

    Returns:
        A copy of ``state`` with ``needs_correction`` set. It is left
        untouched when ``query_result`` is neither a string nor a DataFrame.
    """
    outcome = state["query_result"]
    result_update = {
        **state,
        "sql_query": state["sql_query"],
        "query_result": outcome,
    }

    if isinstance(outcome, str):
        print(f"Problem detected: {outcome}")
        result_update["needs_correction"] = True
        return result_update

    if isinstance(outcome, pd.DataFrame):
        # Slicing off the first column of a one-column frame leaves zero columns,
        # and .all().all() over nothing is vacuously True. That used to reject
        # every single-column result (a bare COUNT, a list of ids).
        value_columns = outcome.iloc[:, 1:] if outcome.shape[1] > 1 else outcome
        if outcome.empty:
            result_update["needs_correction"] = True
            result_update["query_result"] = "Query returned empty results"
        elif (value_columns == 0).all().all():
            result_update["needs_correction"] = True
            result_update["query_result"] = "Query returned all zeros"
        else:
            result_update["needs_correction"] = False

    return result_update

def correct_query(state: AgentState) -> AgentState:
    """Ask the LLM to repair the current SQL query.

    The LLM receives the failing query, the user's question and an error
    context. That context is the latest ``query_result`` error/diagnostic
    string or, when the human reviewer rejected the query, a rejection note
    plus the reviewer's feedback. The feedback used to be discarded.

    Each correction counts as an attempt. The "analyze" router only loops back
    here while ``attempts < 3``. Without this increment the counter stayed at
    the value set by ``generate_sql``, so a persistently failing query kept
    cycling until the graph's recursion limit was hit.

    Returns:
        A copy of ``state`` with the LLM's corrected ``sql_query`` (raw text,
        possibly fenced) and ``attempts`` incremented. ``query_result``,
        ``validation_errors`` and ``human_approved`` are reset.
    """
    previous_result = state.get("query_result")
    error_context = (
        previous_result if isinstance(previous_result, str) and previous_result
        else "Unknown error"
    )

    # Special case for human disapproval: forward the reviewer's reason too.
    if state.get("human_approved") is False:
        error_context = "Human reviewer rejected the generated SQL query"
        reviewer_feedback = state.get("human_feedback")
        if reviewer_feedback:
            error_context += f". Reviewer feedback: {reviewer_feedback}"

    original_query = state.get("sql_query") or "No query generated yet"
    print(f"\nAttempting to correct query. Error: {error_context}")

    question = state.get("question", "Unknown question")

    correction_prompt = f"""Correct this SQL query based on the feedback:

    Error Context: {error_context}
    Original Query: {original_query}

    User Question: {question}

    Provide ONLY the corrected SQL query:"""

    corrected = llm.invoke(correction_prompt)
    return {
        **state,
        "sql_query": corrected.content.strip(),
        "attempts": state.get("attempts", 0) + 1,
        "query_result": None,
        "validation_errors": [],
        "human_approved": None  # Reset approval state so the corrected query is reviewed again
    }

# Add a new node for handling human modification suggestions
def modify_query_based_on_feedback(state: AgentState) -> AgentState:
    """Rewrite the approved SQL according to the reviewer's "Yes but ..." suggestion.

    The graph routes this node straight to "execute". The rewritten query is
    therefore run through ``validate_sql`` here. Previously it skipped the
    guardrails, and ``execute_query`` acted on the *previous* query's
    ``validation_errors``. With empty feedback the approved query is kept and
    no LLM call is made.

    Returns:
        A copy of ``state`` with the (possibly) modified ``sql_query``, fresh
        ``validation_errors``, ``needs_modification`` cleared and
        ``human_approved`` set to True.
    """
    original_query = state["sql_query"]
    feedback = (state.get("human_feedback") or "").strip()

    if not feedback:
        print("\nNo modification feedback provided; keeping the approved query.")
        new_query = original_query
    else:
        print(f"\nModifying query based on feedback: {feedback}")

        # Construct prompt for the LLM to modify the query
        modification_prompt = f"""
    I have a SQL query that needs to be modified based on human feedback.

    Original SQL query:
    {original_query}

    Human feedback: {feedback}

    Please modify the query according to this feedback. Return ONLY the modified SQL query.
    """
        new_query = llm.invoke(modification_prompt).content.strip()

    return validate_sql({
        **state,
        "sql_query": new_query,
        "needs_modification": False,  # Reset flag after modification
        "human_approved": True  # Consider it approved after modification
    })

# Add nodes to workflow (node name -> function taking and returning the state dict)
workflow.add_node("generate", generate_sql)
workflow.add_node("validate", validate_sql)
workflow.add_node("execute", execute_query)
workflow.add_node("analyze", analyze_results)
workflow.add_node("correct", correct_query)
workflow.add_node("human_approval", get_human_approval)
workflow.add_node("modify", modify_query_based_on_feedback)

# Set up workflow edges: generate -> validate -> human_approval
workflow.set_entry_point("generate")
workflow.add_edge("generate", "validate")
workflow.add_edge("validate", "human_approval") 

# Add conditional edges.
# Reviewer verdict: approved with a modification request -> modify;
# approved -> execute; otherwise (rejected) -> correct.
workflow.add_conditional_edges(
    "human_approval",
    lambda x: "modify" if x.get("human_approved") and x.get("needs_modification", False) else 
              "execute" if x.get("human_approved") else "correct",
    {"modify": "modify", "execute": "execute", "correct": "correct"}
)

# NOTE: a modified query goes straight to execution; it is not sent back
# through the "validate" / "human_approval" nodes.
workflow.add_edge("modify", "execute")
workflow.add_edge("execute", "analyze")

# Retry through "correct" while the result looks wrong and fewer than 3
# attempts have been recorded; otherwise finish.
workflow.add_conditional_edges(
    "analyze",
    lambda x: "correct" if x.get("needs_correction", False) and x.get("attempts", 0) < 3 else END,
    {"correct": "correct", END: END}
)

# Corrected queries are re-validated and shown to the reviewer again.
workflow.add_edge("correct", "validate")

# Compile the graph
app = workflow.compile()

def bigquery_agent(question: str, max_attempts: int = 3) -> Union[pd.DataFrame, str]:
    """Answer *question* with the SQL agent graph.

    The compiled graph (``app``) already retries internally through its
    analyze -> correct loop. This wrapper re-invokes it up to
    ``max_attempts`` times, carrying the final state forward.

    Args:
        question: Natural-language question about the BigQuery dataset.
        max_attempts: Maximum number of graph invocations.

    Returns:
        The result DataFrame on success. Otherwise a string describing the
        last failure, or "Max attempts reached without success" when no result
        was recorded (never ``None``).
    """
    state = {
        "question": question,
        "sql_query": None,
        "query_result": None,
        "validation_errors": [],
        "attempts": 0,
        "needs_correction": False,
        "human_approved": None,
        "human_feedback": None,
        "needs_modification": False
    }

    for attempt in range(max_attempts):
        print(f"\nAttempt {attempt + 1}/{max_attempts}")
        result_state = app.invoke(state, {"recursion_limit": 50})

        # Ensure result_state is not None
        if result_state is None:
            print("Warning: Workflow returned None state. Using previous state.")
            break
        state = result_state

        # Handle early exit if human rejects final attempt
        if state.get("human_approved") is False and attempt == max_attempts - 1:
            return "Query rejected by human reviewer"

        if not state.get("needs_correction", False):
            result = state.get("query_result")
            if isinstance(result, pd.DataFrame):
                return result
            return f"Final result: {result}"

    # "query_result" is always present in state (initialised to None above), so
    # a .get() default would never apply and None could leak to callers. Compare
    # with None explicitly: a DataFrame's truth value is ambiguous.
    final_result = state.get("query_result")
    if final_result is None:
        return "Max attempts reached without success"
    return final_result

# Test questions: each one runs the full agent (LLM, human approval, BigQuery).
questions = ["Give me a list of top 5 users by engagement time"]

for question in questions:
    print(f"\n{'=' * 50}\nProcessing: {question}\n{'=' * 50}")
    result = bigquery_agent(question)
    print("\nFinal Result:")
    # DataFrames are previewed; error/status strings are printed verbatim.
    print(result.head() if isinstance(result, pd.DataFrame) else result)


Processing: Give me a list of top 5 users by engagement time

Attempt 1/3

Generating SQL for: Give me a list of top 5 users by engagement time
Generated SQL (raw): ```sql
SELECT 
    user_id, 
    SUM(engagement_time_minutes) AS total_engagement_time
FROM 
    `hallowed-span-459710-s1.test_clustering.user-engagement`
GROUP BY 
    user_id
ORDER BY 
    total_engagement_time DESC
LIMIT 5
```

HUMAN APPROVAL REQUIRED
Generated SQL Query:
```sql
SELECT 
    user_id, 
    SUM(engagement_time_minutes) AS total_engagement_time
FROM 
    `hallowed-span-459710-s1.test_clustering.user-engagement`
GROUP BY 
    user_id
ORDER BY 
    total_engagement_time DESC
LIMIT 5
```

Question being answered: Give me a list of top 5 users by engagement time

You can type:
- 'Y' or 'Yes' to approve
- 'N' or 'No' to reject
- 'Yes but...' followed by your feedback/suggestions to modify the query
Executing query: SELECT 
    user_id, 
    SUM(engagement_time_minutes) AS total_engagement_time
FROM 
    `hallowe

### LangGraph

LangChain Handles LLM Interactions:

- Generates the initial SQL

- Powers the correction mechanism

- Formats prompts with schema context

LangGraph Manages the Process:

- Retries failed queries with full context

- Maintains state (attempts, errors, last query)

- Decides when to terminate


Flow:
##### 1. User query:
- The user types a question, e.g. "Monthly active users by country"
- The LLM (Gemini) uses the schema-aware prompt (`system_prompt`) to pick the correct tables/columns, then generates a raw SQL query

##### 2. SQL Cleaning:
- LLMs often wrap SQL in Markdown fences (e.g. `` ```sql SELECT ... ``` ``). BigQuery rejects Markdown-formatted SQL, so this step standardizes the input before validation
- `clean_sql_query()` strips the fences, leaving plain `SELECT ...`

##### 3. Validation (Guardrails)
- It then checks for critical errors before execution. The query must start with SELECT/WITH and contain a valid table reference (to prevent "table not found" errors). Any JOIN must have a join condition (ON clause).

##### 3b. Pre-Execution Approval (Human Checkpoint)
- The system pauses and shows the generated query to a human, who can approve it or provide feedback
- ✅ Approve ("Y"/"Yes") → proceeds to execution
- "Yes but ..." → suggests modifications (e.g. "Yes but order the results in descending order"); the suggestion is sent to the LLM, which rewrites the query
- ❌ Reject ("N"/"No") → triggers auto-correction


##### 4. BigQuery Execution
- Sends the approved, cleaned SQL to BigQuery, runs it, and returns the output as a pandas DataFrame

##### 5. SQL Query Results Analysis
- Checks for semantic issues the validator can't catch, e.g. empty results or all-zero results

##### 6. Self-Correction Loop
- If suspicious results or execution failures are found, contextual feedback is sent back to the LLM so it can correct the query.
For example, the LLM receives the original query, the user's question, and an error context such as "Execution Error: Unrecognized column 'last_active_date' ..."
- The LLM outputs a revised SQL query, and the flow restarts from validation

##### 7. Final Output
- Returns a pandas DataFrame (ready for visualization)

In [None]:
# agent_orchestrator.py
from typing import Dict, Any, Optional, List, Union
from pydantic import BaseModel
import uuid
from datetime import datetime, timedelta
import pandas as pd
import importlib.util
import sys
import os
from pathlib import Path
from langchain_google_vertexai import ChatVertexAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser

# --------------------------
# Configuration
# --------------------------

GCP_PROJECT = "hallowed-span-459710-s1"
LOCATION = "us-central1"

# Set up credentials. setdefault keeps a GOOGLE_APPLICATION_CREDENTIALS value the
# user has already exported and only falls back to the author's local key file,
# instead of unconditionally overwriting it.
os.environ.setdefault(
    "GOOGLE_APPLICATION_CREDENTIALS",
    "/Users/mncedisimncwabe/Downloads/hallowed-span-459710-s1-c41d79c9b56b.json",
)

# Initialize Vertex AI LLM used for agent routing
llm = ChatVertexAI(
    model_name="gemini-2.0-flash-001",
    temperature=0,  # minimise sampling randomness so routing answers stay stable
    max_output_tokens=2048,
    project=GCP_PROJECT,
    location=LOCATION
)

# --------------------------
# Protocol Definitions
# --------------------------

class ModelCapability(BaseModel):
    """Describes what an agent can do.

    Wrapped agents advertise a list of these; the registry shows each
    capability's ``name`` to the routing LLM.
    """
    name: str  # capability identifier, e.g. "sql_generation_and_execution"
    description: str  # human-readable summary of the capability
    data_sources: List[str] = []  # What data this agent can access
    query_types: List[str] = []   # Types of queries it can handle
    input_schema: Dict[str, Any]  # JSON-Schema-like description of accepted inputs
    output_schema: Dict[str, Any]  # JSON-Schema-like description of produced outputs
    parameters: Dict[str, Any] = {}  # extra settings; not read anywhere in the visible code

class ModelContext(BaseModel):
    """Standardized context object for MCP.

    Used both as the request handed to ``MCPWrapper.execute`` and as the
    result it returns.
    """
    session_id: str  # ties a request to its result (execute copies it onto the result)
    timestamp: datetime  # creation time; this file always sets it with naive datetime.now()
    context_data: Dict[str, Any]  # request: {"question", "max_attempts"}; result: {"question", "result", "status"}
    source_agent: Optional[str] = None  # agent that produced this context
    target_agents: Optional[List[str]] = None  # intended recipient agent name(s)

class A2AMessage(BaseModel):
    """Standardized agent-to-agent message"""
    message_id: str  # unique id (a uuid4 string when built by A2AWrapper)
    sender: str  # name of the sending agent
    recipients: List[str]  # names of the receiving agents
    content: Dict[str, Any]  # payload; for requests it becomes the ModelContext.context_data
    context: Optional[ModelContext] = None  # full result context, attached to responses
    requires_response: bool = False  # if True, the receiving A2AWrapper runs its agent and replies
    expiration: Optional[datetime] = None  # messages past this (naive) time are dropped unprocessed

# --------------------------
# Wrappers
# --------------------------

class MCPWrapper:
    """Wraps existing agents with MCP capabilities.

    ``agent_instance`` must expose ``spanner_agent(question, max_attempts)``
    when ``agent_name`` contains "spanner" (case-insensitive), and
    ``bigquery_agent(question, max_attempts)`` otherwise.
    """

    def __init__(self, agent_instance, agent_name: str, agent_description: str,
                 data_sources: Optional[List[str]] = None, query_types: Optional[List[str]] = None):
        """
        Args:
            agent_instance: Object providing the agent entry point
                (presumably an imported agent module -- TODO confirm).
            agent_name: Unique name used for registry lookup and messaging.
            agent_description: Free-text description shown to the routing LLM.
            data_sources: Data sources the agent can query (default: none).
            query_types: Kinds of questions the agent handles (default: none).
        """
        self.agent = agent_instance
        self.agent_name = agent_name
        self.agent_description = agent_description
        self.data_sources = data_sources or []
        self.query_types = query_types or []
        self.capabilities = self._define_capabilities()

    def _define_capabilities(self) -> List[ModelCapability]:
        """Define what this agent can do with detailed information"""
        return [
            ModelCapability(
                name="sql_generation_and_execution",
                description=f"Generates and executes SQL queries for {', '.join(self.data_sources)} data",
                data_sources=self.data_sources,
                query_types=self.query_types,
                input_schema={
                    "type": "object",
                    "properties": {
                        "question": {"type": "string"},
                        "max_attempts": {"type": "integer", "default": 3}
                    },
                    "required": ["question"]
                },
                output_schema={
                    "type": "object",
                    "properties": {
                        "sql_query": {"type": "string"},
                        "query_result": {"type": ["object", "string"]},
                        "status": {"type": "string"}
                    }
                }
            )
        ]

    def execute(self, context: ModelContext) -> ModelContext:
        """Run the wrapped agent on ``context.context_data["question"]``.

        Returns:
            A new ModelContext in the same session whose ``context_data`` holds
            ``question``, ``result`` and ``status``. ``result`` is a list of
            row dicts for a DataFrame, otherwise a string. ``status`` is
            "success", or "error" when the agent raised or no question was given.
        """
        question = context.context_data.get("question")
        max_attempts = context.context_data.get("max_attempts", 3)

        if not question:
            # "question" is required by the input schema; fail fast rather than
            # sending None on to the agent (and from there to the LLM/database).
            result_data = "Missing 'question' in context_data"
            status = "error"
        else:
            try:
                if "spanner" in self.agent_name.lower():
                    result = self.agent.spanner_agent(question, max_attempts)
                else:
                    result = self.agent.bigquery_agent(question, max_attempts)

                # DataFrames are converted to JSON-friendly row dicts.
                if isinstance(result, pd.DataFrame):
                    result_data = result.to_dict(orient='records')
                else:
                    result_data = str(result)

                # NOTE(review): bigquery_agent reports failures as plain strings;
                # those still count as "success" here -- confirm this is intended.
                status = "success"
            except Exception as e:
                result_data = str(e)
                status = "error"

        # Package results into MCP format
        return ModelContext(
            session_id=context.session_id,
            timestamp=datetime.now(),
            context_data={
                "question": question,
                "result": result_data,
                "status": status
            },
            source_agent=self.agent_name
        )

class A2AWrapper:
    """Enables agent-to-agent communication.

    Incoming messages queue in ``inbox``. Responses produced by
    :meth:`process_messages` queue in ``outbox`` until they are returned.
    """

    def __init__(self, mcp_wrapper: MCPWrapper):
        self.mcp_agent = mcp_wrapper
        self.inbox: List[A2AMessage] = []
        self.outbox: List[A2AMessage] = []  # responses not yet handed to the caller

    def send_message(self, recipients: List[str], content: Dict[str, Any], requires_response: bool = False) -> A2AMessage:
        """Build (but do not deliver) a message from this agent.

        The message gets a fresh uuid4 id and expires one hour after creation.
        Delivery is up to the caller, e.g. via the recipient's ``receive_message``.
        """
        return A2AMessage(
            message_id=str(uuid.uuid4()),
            sender=self.mcp_agent.agent_name,
            recipients=recipients,
            content=content,
            requires_response=requires_response,
            expiration=datetime.now() + timedelta(hours=1)
        )

    def receive_message(self, message: A2AMessage):
        """Receive a message from another agent (queue it in the inbox)."""
        self.inbox.append(message)

    def process_messages(self) -> Optional[A2AMessage]:
        """Handle queued messages and return the next pending response, if any.

        Expired messages are dropped. Every unexpired message with
        ``requires_response=True`` is executed through the wrapped MCP agent
        and removed from the inbox, and its response is queued in ``outbox``.
        Messages that do not require a response stay in the inbox.

        Returns:
            The oldest queued response (removed from ``outbox``), or ``None``.
            Previously every response after the first in a single call was
            silently discarded. Those now stay in ``outbox`` and are returned
            by later calls (or all at once by :meth:`process_all_messages`).
        """
        for message in self.inbox[:]:  # iterate a copy: the inbox is mutated below
            if message.expiration and message.expiration < datetime.now():
                self.inbox.remove(message)
                continue

            if not message.requires_response:
                continue

            # Create context from message
            context = ModelContext(
                session_id=message.message_id,
                timestamp=datetime.now(),
                context_data=message.content,
                source_agent=message.sender,
                target_agents=[self.mcp_agent.agent_name]
            )

            # Execute the agent and queue the reply for the sender
            response_context = self.mcp_agent.execute(context)
            self.outbox.append(A2AMessage(
                message_id=str(uuid.uuid4()),
                sender=self.mcp_agent.agent_name,
                recipients=[message.sender],
                content=response_context.context_data,
                context=response_context
            ))
            self.inbox.remove(message)

        return self.outbox.pop(0) if self.outbox else None

    def process_all_messages(self) -> List[A2AMessage]:
        """Process the inbox and return (and clear) every queued response, oldest first."""
        first = self.process_messages()
        if first is None:
            return []
        responses = [first, *self.outbox]
        self.outbox.clear()
        return responses

# --------------------------
# Registry with LLM Routing
# --------------------------

class AgentRegistry:
    """Central registry for agent discovery and management with LLM-based routing.

    An agent counts as online while its last heartbeat is younger than
    ``heartbeat_timeout`` seconds. Only online agents are offered to, and
    accepted from, the routing LLM.
    """

    def __init__(self, llm: ChatVertexAI, heartbeat_timeout: int = 300):
        self.agents: Dict[str, Dict[str, Any]] = {}
        self.heartbeat_timeout = heartbeat_timeout
        self.llm = llm

        # Enhanced LLM prompt for agent matching
        self.routing_prompt = ChatPromptTemplate.from_messages([
            ("system", """You are an expert at matching tasks to agent capabilities. Analyze the task description and available agents to determine the best match.

Available Agents:
{agents}

Instructions:
1. Match the task to agent capabilities, data sources, and query types
2. Consider keywords like "users", "engagement", "countries", "transactions"
3. Return ONLY the exact agent name that best matches
4. If multiple agents could work, pick the most specialized one
5. If no good match exists, return "NO_MATCH"

Examples:
- "top 5 users by engagement" → look for agents with user/engagement data
- "countries by transaction count" → look for agents with transaction/country data
- "sales analysis" → look for agents with sales data"""),
            ("human", "Task: {task}")
        ])

    @staticmethod
    def _capability_to_dict(cap) -> Dict[str, Any]:
        """Serialise a capability with pydantic v2 ``model_dump``, falling back to v1 ``dict``."""
        dump = getattr(cap, "model_dump", None)
        return dump() if callable(dump) else cap.dict()

    @staticmethod
    def _match_agent_name(answer: str, candidates: List[str]) -> Optional[str]:
        """Map the router's free-text answer onto one of *candidates*, or ``None``.

        Tolerates wrapping quotes/backticks/asterisks, a trailing period, case
        differences, and a short sentence mentioning exactly one candidate.
        "NO_MATCH" anywhere in the answer means no agent.
        """
        if not answer or "NO_MATCH" in answer.upper():
            return None
        cleaned = answer.strip().strip("`'\"*").strip().rstrip(".").strip()
        if cleaned in candidates:
            return cleaned
        by_lower = {name.lower(): name for name in candidates}
        if cleaned.lower() in by_lower:
            return by_lower[cleaned.lower()]
        mentioned = [name for name in candidates if name.lower() in answer.lower()]
        return mentioned[0] if len(mentioned) == 1 else None

    def register_agent(self, mcp_wrapper: MCPWrapper, endpoint: Optional[str] = None):
        """Register an agent in the registry (marks it online as of now)."""
        self.agents[mcp_wrapper.agent_name] = {
            "description": mcp_wrapper.agent_description,
            "capabilities": [self._capability_to_dict(cap) for cap in mcp_wrapper.capabilities],
            "data_sources": mcp_wrapper.data_sources,
            "query_types": mcp_wrapper.query_types,
            "endpoint": endpoint,
            "last_heartbeat": datetime.now(),
            "wrapper": mcp_wrapper,
            "active": True
        }

    def update_heartbeat(self, agent_name: str):
        """Update the last active timestamp for an agent (no-op for unknown names)."""
        if agent_name in self.agents:
            self.agents[agent_name]["last_heartbeat"] = datetime.now()
            self.agents[agent_name]["active"] = True

    def check_agent_online(self, agent_name: str) -> bool:
        """Check whether an agent's last heartbeat is within ``heartbeat_timeout`` seconds."""
        agent = self.agents.get(agent_name)
        if not agent:
            return False
        return (datetime.now() - agent["last_heartbeat"]).total_seconds() < self.heartbeat_timeout

    def find_best_agent_for_task(self, task_description: str) -> Optional[str]:
        """
        Use LLM to find the best agent for a task.
        Returns the name of an online agent, or None if no suitable agent is found.
        """
        # Format agent information for the prompt with detailed capabilities
        agents_info = []
        online_agents = []
        for name, details in self.agents.items():
            if not self.check_agent_online(name):
                continue
            online_agents.append(name)

            # Build detailed agent description
            agent_desc = [
                f"Agent: {name}",
                f"Description: {details['description']}",
            ]
            if details['data_sources']:
                agent_desc.append(f"Data Sources: {', '.join(details['data_sources'])}")
            if details['query_types']:
                agent_desc.append(f"Query Types: {', '.join(details['query_types'])}")

            capabilities = [cap['name'] for cap in details['capabilities']]
            agent_desc.append(f"Capabilities: {', '.join(capabilities)}")

            last_active_mins = (datetime.now() - details['last_heartbeat']).total_seconds() / 60
            agent_desc.append(f"Status: ONLINE (last active {last_active_mins:.1f} minutes ago)")

            agents_info.append('\n'.join(agent_desc))

        if not agents_info:
            print("DEBUG: No online agents found")
            return None

        agents_text = '\n\n'.join(agents_info)
        print(f"DEBUG: Sending to LLM:\nTask: {task_description}\nAgents:\n{agents_text}")

        # Get LLM routing
        try:
            chain = self.routing_prompt | self.llm | StrOutputParser()
            raw_answer = chain.invoke({
                "agents": agents_text,
                "task": task_description
            }).strip()

            print(f"DEBUG: LLM recommended: '{raw_answer}'")

            # Accept only agents that were actually offered (i.e. online), and
            # tolerate formatting noise around the name.
            recommended_agent = self._match_agent_name(raw_answer, online_agents)
            if recommended_agent is None:
                print(f"DEBUG: No valid match found. Available agents: {online_agents}")
                return None

            return recommended_agent

        except Exception as e:
            print(f"DEBUG: Error in LLM routing: {e}")
            return None

    def route_task(self, task_description: str, agent_name: Optional[str] = None) -> Optional[ModelContext]:
        """
        Automated task routing.
        1. Finds the best agent (unless ``agent_name`` is given)
        2. Creates execution context
        3. Returns result context

        Args:
            task_description: Question/task forwarded to the agent.
            agent_name: Optional pre-selected registered agent, e.g. from an
                earlier ``find_best_agent_for_task`` call. Passing it avoids a
                second LLM routing call that could pick a different agent.
        """
        if agent_name is None:
            agent_name = self.find_best_agent_for_task(task_description)
        elif agent_name not in self.agents:
            print(f"DEBUG: Unknown agent '{agent_name}'")
            return None
        if not agent_name:
            return None

        # Create and execute context
        context = ModelContext(
            session_id=str(uuid.uuid4()),
            timestamp=datetime.now(),
            context_data={
                "question": task_description,
                "max_attempts": 3
            },
            target_agents=[agent_name]
        )

        wrapper = self.agents[agent_name]["wrapper"]
        return wrapper.execute(context)

    def discover_agents(self) -> List[str]:
        """Discover all registered agents (online or not)."""
        return list(self.agents.keys())

    def get_agent(self, agent_name: str) -> Optional[Dict[str, Any]]:
        """Get agent information"""
        return self.agents.get(agent_name)

    def get_agent_wrapper(self, agent_name: str) -> Optional[MCPWrapper]:
        """Get agent wrapper for direct execution"""
        agent_info = self.agents.get(agent_name)
        return agent_info["wrapper"] if agent_info else None

# --------------------------
# Test Cases
# --------------------------

def test_llm_routing():
    """Exercise LLM-based routing end to end for a few sample tasks.

    For each task, prints the agent recommended by
    ``registry.find_best_agent_for_task`` and then runs the task through
    ``registry.route_task``. It prints the resulting status plus either a
    200-character preview of the result or the error payload.

    Note: ``route_task`` asks the LLM again, so the agent that actually runs
    the task is chosen independently of the one printed as recommended.
    """
    print("\n=== Testing LLM Routing ===")

    tasks = [
        "Give me a list of top 5 users by engagement time",
        "Show me the top 5 countries by transaction count"
    ]

    for task in tasks:
        print(f"\nTask: {task}")
        best_agent = registry.find_best_agent_for_task(task)
        if not best_agent:
            print("❌ No suitable agent found")
            continue

        print(f"✅ Recommended agent: {best_agent}")

        # Execute the task
        print("Executing task...")
        result = registry.route_task(task)
        if result is None:
            # Surface routing failures instead of silently moving on.
            print("❌ Routing returned no result")
            continue

        # .get() so a malformed result context is reported rather than raising KeyError.
        status = result.context_data.get('status')
        print(f"Status: {status}")
        if status == 'success':
            result_str = str(result.context_data.get('result'))
            print(f"Result preview: {result_str[:200]}{'...' if len(result_str) > 200 else ''}")
        else:
            print(f"Error: {result.context_data.get('result')}")

def test_heartbeat_system():
    """Check heartbeat-based online/offline tracking and its effect on routing.

    Steps:
    1. Print each registered agent's online status and last heartbeat time.
    2. Backdate SpannerSQLAgent's heartbeat by 400 s to simulate it going offline.
    3. Ask the registry to pick an agent for a Spanner-style task and report
       whether the offline agent was skipped.

    The Spanner agent's heartbeat is refreshed via ``registry.update_heartbeat``
    afterwards, even if routing raises.
    """
    print("\n=== Testing Heartbeat System ===")

    # Check initial status
    print("\nAgent statuses:")
    for name, info in registry.agents.items():
        status = "ONLINE" if registry.check_agent_online(name) else "OFFLINE"
        last_seen = info["last_heartbeat"].strftime('%H:%M:%S')
        print(f"- {name}: {status} (last seen {last_seen})")

    offline_agent = "SpannerSQLAgent"

    # Simulate agent going offline. 400 s exceeds the heartbeat_timeout=300
    # the registry is created with in __main__.
    print("\nSimulating Spanner agent going offline...")
    registry.agents[offline_agent]["last_heartbeat"] = datetime.now() - timedelta(seconds=400)

    try:
        # Verify routing skips offline agents
        task = "Get transaction counts by country"
        print(f"\nRouting task: {task}")
        best_agent = registry.find_best_agent_for_task(task)
        print(f"Selected agent: {best_agent}")
        if best_agent == offline_agent:
            print("❌ Offline agent was still selected")
        else:
            print("✅ Offline agent was skipped")
    finally:
        # Reset agent status even if routing failed, so later tests see it online.
        registry.update_heartbeat(offline_agent)


def test_a2a_communication():
    """Send a single request from the Spanner agent to the BigQuery agent.

    Both MCP agents are wrapped in ``A2AWrapper``. The Spanner side sends a
    data-validation request that expects a reply. The message is delivered to
    the BigQuery wrapper's inbox, and the reply from ``process_messages`` (if
    any) is summarised on stdout.
    """
    print("\n=== Testing Agent-to-Agent Communication ===")

    sender = A2AWrapper(spanner_mcp)
    receiver = A2AWrapper(bq_mcp)

    print("\n1. Testing message sending between agents")

    validation_request = {
        "request_type": "data_validation",
        "query": "Can you validate user engagement data for user_id '6c3dbd5cb2393a74d1b5d1fc3289f4b92deea4f92b9b2994399aabf172c500d5'?",
        "context": "Cross-referencing transaction data with engagement metrics",
    }
    outgoing = sender.send_message(
        recipients=["BigQuerySQLAgent"],
        content=validation_request,
        requires_response=True,
    )

    print(f"✅ Message sent from {outgoing.sender} to {outgoing.recipients}")
    print(f"   Message ID: {outgoing.message_id}")
    print(f"   Content: {outgoing.content['request_type']}")

    # Hand the request to the BigQuery side's inbox.
    receiver.receive_message(outgoing)
    print("✅ Message received by BigQuerySQLAgent")
    print(f"   Inbox size: {len(receiver.inbox)}")

    print("\n2. Processing messages")
    reply = receiver.process_messages()

    if not reply:
        print("❌ No response generated")
        return

    print(f"✅ Response generated by {reply.sender}")
    print(f"   Response to: {reply.recipients}")
    print(f"   Response content preview: {str(reply.content)[:100]}...")

def test_advanced_a2a_scenarios():
    """Exercise multi-step agent-to-agent (A2A) flows.

    1. Data correlation: the Spanner agent asks the BigQuery agent for the top
       users by engagement. If a reply arrives, a follow-up Spanner query is
       executed with the reply content attached as context.
    2. Message expiration: an already-expired message is placed in the Spanner
       inbox, ``process_messages`` is run, and inbox sizes before and after
       are printed.
    3. Broadcast: the BigQuery agent sends a status notice that does not
       require a response to the Spanner agent.
    """
    print("\n=== Testing Advanced A2A Scenarios ===")

    # Create A2A wrappers
    spanner_a2a = A2AWrapper(spanner_mcp)
    bq_a2a = A2AWrapper(bq_mcp)

    print("\n1. Testing data correlation scenario")

    # Scenario: Find users with high engagement but low transaction counts
    correlation_message = spanner_a2a.send_message(
        recipients=["BigQuerySQLAgent"],
        content={
            "request_type": "data_correlation",
            "question": "Get top 10 users by engagement time",
            "purpose": "Will cross-reference with transaction data"
        },
        requires_response=True
    )

    print("✅ Correlation request sent to BigQuery agent")

    # BigQuery processes the request
    bq_a2a.receive_message(correlation_message)
    engagement_response = bq_a2a.process_messages()

    if engagement_response:
        print("✅ Engagement data retrieved")

        # Spanner agent uses this data to find patterns
        follow_up_context = ModelContext(
            session_id=str(uuid.uuid4()),
            timestamp=datetime.now(),
            context_data={
                "question": "Show transaction counts for users with high engagement",
                "context_data": engagement_response.content,
                # Pass the retry budget explicitly (same value route_task uses),
                # as every other context built in this file does.
                "max_attempts": 3
            }
        )

        spanner_result = spanner_mcp.execute(follow_up_context)
        print("✅ Cross-reference analysis completed")
        print(f"   Status: {spanner_result.context_data['status']}")
    else:
        # Report a missing reply instead of skipping the scenario silently.
        print("❌ No engagement data returned; skipping cross-reference analysis")

    print("\n2. Testing message expiration")

    # Create message with an expiration already in the past
    expired_message = A2AMessage(
        message_id=str(uuid.uuid4()),
        sender="TestAgent",
        recipients=["SpannerSQLAgent"],
        content={"test": "This message should expire"},
        expiration=datetime.now() - timedelta(seconds=1)
    )

    spanner_a2a.receive_message(expired_message)
    print(f"✅ Expired message added to inbox (size: {len(spanner_a2a.inbox)})")

    spanner_a2a.process_messages()
    print(f"✅ After processing: inbox size: {len(spanner_a2a.inbox)}")

    print("\n3. Testing multi-agent broadcast")

    # Create a broadcast message (no reply expected)
    broadcast_message = bq_a2a.send_message(
        recipients=["SpannerSQLAgent"],
        content={
            "broadcast_type": "system_status",
            "message": "System maintenance scheduled",
            "timestamp": datetime.now().isoformat()
        },
        requires_response=False
    )

    print("✅ Broadcast message created")
    print(f"   Recipients: {broadcast_message.recipients}")
    print(f"   Requires response: {broadcast_message.requires_response}")

def test_registry_management():
    """Print what the registry knows about its agents and how it routes sample queries.

    Lists the discovered agent names, prints each agent's description, data
    sources, query types and online flag, then shows which agent
    ``find_best_agent_for_task`` selects for two sample queries.
    """
    print("\n=== Testing Registry Management ===")

    print("\n1. Agent Discovery")
    names = registry.discover_agents()
    print(f"✅ Discovered agents: {names}")

    print("\n2. Agent Details")
    for name in names:
        info = registry.get_agent(name)
        if not info:
            # Entry vanished or is empty: nothing to describe.
            continue
        print(f"\n{name}:")
        print(f"   Description: {info['description']}")
        print(f"   Data Sources: {info['data_sources']}")
        print(f"   Query Types: {info['query_types']}")
        print(f"   Online: {registry.check_agent_online(name)}")

    print("\n3. Capability Matching")
    samples = (
        "Find users with most transactions",
        "Analyze engagement patterns over time",
    )
    for sample in samples:
        choice = registry.find_best_agent_for_task(sample)
        print(f"   '{sample}' → {choice or 'No match'}")

def test_error_handling():
    """Check that query failures and offline agents are handled without crashing.

    1. Sends a query against a non-existent table to the Spanner wrapper, with
       a single attempt, and reports the returned status and a truncated message.
       An exception escaping ``execute`` is reported as unhandled.
    2. Backdates SpannerSQLAgent's heartbeat by 400 s and checks that
       ``registry.route_task`` returns ``None`` for a transaction-data task.
       The original heartbeat is restored even if routing raises.
       Note: ``None`` is also returned when the LLM finds no match, so a
       positive result here does not by itself prove the offline check fired.
    """
    print("\n=== Testing Error Handling ===")

    print("\n1. Invalid SQL query handling")
    error_context = ModelContext(
        session_id=str(uuid.uuid4()),
        timestamp=datetime.now(),
        context_data={
            "question": "SELECT * FROM nonexistent_table_xyz",
            "max_attempts": 1
        }
    )

    # Keep only the call under test inside the try so reporting bugs below are
    # not misreported as "unhandled" agent errors.
    try:
        result = spanner_mcp.execute(error_context)
    except Exception as e:
        print(f"❌ Unhandled error: {e}")
    else:
        status = result.context_data.get('status')
        # The result payload is not guaranteed to be a string; stringify before slicing.
        message = str(result.context_data.get('result'))
        if status == 'success':
            print("❌ Query against a nonexistent table unexpectedly succeeded")
        else:
            print("✅ Error handled gracefully")
        print(f"   Status: {status}")
        print(f"   Error message: {message[:100]}...")

    print("\n2. Offline agent handling")
    # Temporarily mark agent as offline (400 s exceeds the 300 s timeout used in __main__)
    spanner_entry = registry.agents["SpannerSQLAgent"]
    original_heartbeat = spanner_entry["last_heartbeat"]
    spanner_entry["last_heartbeat"] = datetime.now() - timedelta(seconds=400)

    try:
        task_result = registry.route_task("Get transaction data")
    finally:
        # Restore agent status even if routing raised
        spanner_entry["last_heartbeat"] = original_heartbeat

    if task_result is None:
        print("✅ Offline agent correctly excluded from routing")
    else:
        print("❌ Offline agent was still used")

# --------------------------
# Main Execution
# --------------------------

if __name__ == "__main__":
    # Load existing agents
    def load_agent(agent_file: str):
        """Import ``agent_file`` as a module named after its file stem and return it.

        The module is registered in ``sys.modules`` before its code runs, as in
        the importlib "import a source file directly" recipe. If execution
        fails, that half-initialised entry is removed before the error propagates.

        Raises:
            ImportError: if no import spec/loader can be built for the path.
            Any exception raised while executing the module's code.
        """
        module_name = Path(agent_file).stem
        spec = importlib.util.spec_from_file_location(module_name, agent_file)
        if spec is None or spec.loader is None:
            raise ImportError(f"Cannot load agent module from {agent_file!r}")
        module = importlib.util.module_from_spec(spec)
        sys.modules[module_name] = module
        try:
            spec.loader.exec_module(module)
        except Exception:
            # Don't leave a broken module cached under this name.
            sys.modules.pop(module_name, None)
            raise
        return module

    try:
        spanner_agent = load_agent("spanner_agent.py")
        bq_agent = load_agent("bq_agent.py")
    except Exception as e:
        print(f"Error loading agents: {e}")
        sys.exit(1)

    # Initialize wrappers with detailed metadata
    spanner_mcp = MCPWrapper(
        agent_instance=spanner_agent,
        agent_name="SpannerSQLAgent",
        agent_description="Generates and executes Google Cloud Spanner SQL queries for transactional data",
        data_sources=["transactions", "users", "countries", "payments", "orders"],
        query_types=["aggregation", "filtering", "grouping", "joins", "analytics"]
    )

    bq_mcp = MCPWrapper(
        agent_instance=bq_agent,
        agent_name="BigQuerySQLAgent",
        agent_description="Generates and executes BigQuery SQL queries for analytics and reporting",
        data_sources=["user_engagement", "web_analytics", "logs", "events", "metrics"],
        query_types=["analytics", "reporting", "time_series", "aggregation", "data_warehouse"]
    )

    # Initialize registry with LLM; agents are considered offline after 300 s without a heartbeat
    registry = AgentRegistry(llm=llm, heartbeat_timeout=300)
    registry.register_agent(spanner_mcp)
    registry.register_agent(bq_mcp)

    # Run tests. test_registry_management and test_error_handling were defined
    # but never invoked; they run last because they only read registry state or
    # restore what they change.
    test_llm_routing()
    test_a2a_communication()
    test_advanced_a2a_scenarios()
    test_heartbeat_system()
    test_registry_management()
    test_error_handling()