In [1]:
# %% [markdown]
# # LangGraph + Multi-Vector Retriever + MCP SQL Client
#
# This notebook-style script shows:
# 1. MCP-like client/server for SQL:
#    - `run_sql`
#    - `validate_sql`
# 2. Multi-vector retriever:
#    - One FAISS store per document in ../data/data_information
#    - A multi-store retriever with global ranking
# 3. LangGraph pipeline:
#    - Plan → Execute steps (retriever + MCP SQL) → Retry / Replan on failure
#    - Return final dataframe or best partial result
#    - Optionally summarize data using LLM
# 4. Test cases:
#    - Retriever accuracy (Precision, Recall, MAP, MRR, nDCG)
#    - Simple pipeline tests for LLM + graph

# %%
import os
from pathlib import Path
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, TypedDict, Literal, Set

import sqlite3
import json
import textwrap
import traceback

import pandas as pd

# LangChain / embeddings / vector stores
from langchain_community.document_loaders import TextLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_core.documents import Document
from langchain_core.messages import HumanMessage, SystemMessage

# LangGraph
from langgraph.graph import StateGraph, END

# LLM (Ollama-based, similar to the sample notebook)
try:
    from langchain_ollama import OllamaLLM
    print("Using langchain_ollama.OllamaLLM")
except ImportError:
    from langchain_community.llms import Ollama as OllamaLLM
    print("Using legacy langchain_community.llms.Ollama")

# -----------------------------------------------------------------------------
# 0. LLM Setup (Ollama)
# -----------------------------------------------------------------------------

# You can change this to whatever you have installed in Ollama
MODEL_NAME = "qwen2.5-coder:7b"

# Module-level LLM client; `call_llm` below sends every request through it.
llm = OllamaLLM(
    model=MODEL_NAME,
    # num_ctx is supported by some models; harmless if ignored
    num_ctx=4096,
)

def call_llm(
    system_prompt: str,
    user_prompt: str,
    max_retries: int = 3,
) -> str:
    """Call the module-level ``llm`` and return its non-empty text reply.

    Args:
        system_prompt: Content of the system message.
        user_prompt: Content of the human message.
        max_retries: Maximum number of attempts before giving up.

    Returns:
        The stripped response text of the first attempt that produces any.

    Raises:
        RuntimeError: If every attempt either raised or returned empty text.
            The message carries the most recent failure reason.
    """
    messages = [
        SystemMessage(content=system_prompt),
        HumanMessage(content=user_prompt),
    ]
    last_error: Optional[Exception] = None
    for attempt in range(1, max_retries + 1):
        try:
            print(f"[LLM] Calling model (attempt {attempt}/{max_retries})...")
            out = llm.invoke(messages)
            # Plain LLMs return a str; chat models return an object with `.content`.
            text = out if isinstance(out, str) else getattr(out, "content", "")
            text = (text or "").strip()
            if text:
                return text
            # An empty reply is a failed attempt too: record it so the final
            # error does not report "None" or a stale earlier exception.
            last_error = ValueError("model returned an empty response")
            print(f"[LLM] Error: {last_error}")
        except Exception as e:
            last_error = e
            print(f"[LLM] Error: {e}")
    raise RuntimeError(f"LLM call failed after {max_retries} attempts: {last_error}")


# -----------------------------------------------------------------------------
# 1. Minimal MCP-like abstractions: Tool, Server, Client
#    (SQL run + SQL validate)
# -----------------------------------------------------------------------------

@dataclass
class Tool:
    """A named callable that the MCP-like server can dispatch to."""

    # Unique key under which the tool is registered.
    name: str
    # Human-readable summary of what the tool does.
    description: str
    # Handler invoked with the request's argument dict.
    func: Any  # func: Callable[[Dict[str, Any]], Any]

    def __call__(self, args: Dict[str, Any]) -> Any:
        """Invoke the wrapped handler with ``args`` and return its result."""
        handler = self.func
        return handler(args)


class MCPServer:
    """
    Minimal in-process 'server': keeps a registry of tools and dispatches
    incoming tool-invocation requests to them.
    """

    def __init__(self) -> None:
        # Registry of tools keyed by tool name.
        self.tools: Dict[str, Tool] = {}

    def register_tool(self, tool: Tool) -> None:
        """Add ``tool`` to the registry; raise ValueError on a duplicate name."""
        name = tool.name
        if name in self.tools:
            raise ValueError(f"Tool {name!r} already registered")
        self.tools[name] = tool

    def handle_request(self, request: Dict[str, Any]) -> Dict[str, Any]:
        """
        Dispatch one request and return a response envelope.

        Request format:
        {
            "id": "<string>",
            "tool": "<tool_name>",
            "args": { ... }
        }

        The response always echoes "id" and carries "ok". On success it has
        "result"; on failure it has "error" (plus "traceback" when the tool
        itself raised).
        """
        req_id = request.get("id")
        tool_name = request.get("tool")
        args = request.get("args") or {}

        response: Dict[str, Any] = {"id": req_id}

        if tool_name not in self.tools:
            response["ok"] = False
            response["error"] = f"Unknown tool: {tool_name}"
            return response

        handler = self.tools[tool_name]
        try:
            outcome = handler(args)
        except Exception as exc:
            # Tool failures are reported in the envelope, never raised to the caller.
            tb = traceback.format_exc()
            response["ok"] = False
            response["error"] = str(exc)
            response["traceback"] = tb
        else:
            response["ok"] = True
            response["result"] = outcome
        return response


class MCPClient:
    """
    Minimal in-process client: wraps tool calls into request dicts and hands
    them to an MCPServer.
    """

    def __init__(self, server: MCPServer) -> None:
        self.server = server
        # Monotonic counter used to generate request ids ("1", "2", ...).
        self._next_id = 1

    def call_tool(self, tool_name: str, args: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """Send one tool invocation to the server and return its response dict."""
        payload = {} if args is None else args
        current_id = self._next_id
        self._next_id = current_id + 1
        return self.server.handle_request(
            {
                "id": str(current_id),
                "tool": tool_name,
                "args": payload,
            }
        )


# -----------------------------------------------------------------------------
# 2. SQL Tools: run_sql & validate_sql
# -----------------------------------------------------------------------------

# Location of the SQLite database used by the SQL tools. The path is relative,
# so it resolves against the current working directory of the process.
DB_PATH = Path("../data") / "Employee_Information.db"  # adjust if needed
# Startup diagnostics: show the configured path and what sits next to it.
print(f"[MCP] DB path: {DB_PATH!r}")
print(f"[MCP] Available files in ../data: {list(DB_PATH.parent.glob('*'))}")

def _get_connection() -> sqlite3.Connection:
    """Open a new connection to the database at DB_PATH.

    Raises:
        FileNotFoundError: If no file exists at DB_PATH.
    """
    if DB_PATH.exists():
        return sqlite3.connect(DB_PATH)
    raise FileNotFoundError(f"Database not found at {DB_PATH!r}")

def run_sql_tool(args: Dict[str, Any]) -> Dict[str, Any]:
    """
    Execute one SQL statement and return its result set.

    args:
      {
        "query": "<SQL>",
        "params": [ ... ]  # optional
      }

    Returns a dict with "query", "columns", "rows" and "rowcount" (the number
    of fetched rows).

    Raises:
        ValueError: If "query" is missing or not a string.
        Any exception raised while connecting or executing the statement
        propagates to the caller.
    """
    from contextlib import closing

    query = args.get("query")
    if not isinstance(query, str):
        raise ValueError("run_sql: 'query' must be a string")

    params = args.get("params") or []

    # `with conn:` only commits/rolls back the transaction; it does NOT close
    # the connection. `closing(...)` makes sure the handle is released after
    # every call instead of leaking one connection per request.
    with closing(_get_connection()) as conn, conn:
        cur = conn.cursor()
        cur.execute(query, params)

        # Some statements don't return rows
        try:
            rows = cur.fetchall()
        except sqlite3.ProgrammingError:
            rows = []

        # `description` is None for statements that produce no result set.
        columns = [desc[0] for desc in (cur.description or [])]

    return {
        "query": query,
        "columns": columns,
        "rows": rows,
        "rowcount": len(rows),
    }

def validate_sql_tool(args: Dict[str, Any]) -> Dict[str, Any]:
    """
    Check whether a SQL statement is accepted by SQLite.

    args:
      {
        "query": "<SQL>"
      }

    The query is prefixed with EXPLAIN and executed. SQLite then compiles the
    statement and reports how it would run, without running the statement
    itself.

    Returns a dict with "query", "valid" (bool) and "error" (None when valid,
    otherwise the exception text). Any failure, including a missing database
    file, is reported as ``valid=False`` rather than raised.

    Raises:
        ValueError: If "query" is missing or not a string.
    """
    from contextlib import closing

    query = args.get("query")
    if not isinstance(query, str):
        raise ValueError("validate_sql: 'query' must be a string")

    try:
        # sqlite3's own context manager does not close the connection, so
        # close it explicitly to avoid leaking a handle per validation.
        with closing(_get_connection()) as conn:
            conn.execute("EXPLAIN " + query)
        return {
            "query": query,
            "valid": True,
            "error": None,
        }
    except Exception as e:
        return {
            "query": query,
            "valid": False,
            "error": str(e),
        }

# Register tools on MCP server
# Module-level singletons: one server holding both SQL tools, and one client
# bound to it. `execute_step_node` issues its SQL calls through `mcp_client`.
mcp_server = MCPServer()
mcp_server.register_tool(Tool(
    name="run_sql",
    description="Execute SQL query on the SQLite database.",
    func=run_sql_tool,
))
mcp_server.register_tool(Tool(
    name="validate_sql",
    description="Validate SQL syntax using SQLite.",
    func=validate_sql_tool,
))
mcp_client = MCPClient(mcp_server)


# -----------------------------------------------------------------------------
# 3. Multi-Vector Retriever (one vector store per file + global ranking)
# -----------------------------------------------------------------------------

# Directory holding the documentation files; one vector store is built per
# file found here. A missing directory only produces a warning at this point.
DATA_DIR = Path("../data/data_information")
if not DATA_DIR.exists():
    print(f"[Retriever] WARNING: {DATA_DIR} does not exist; adjust path if needed.")
else:
    print(f"[Retriever] Loading docs from: {DATA_DIR}")

def load_single_file(path: Path) -> List[Document]:
    """Load ``path`` as UTF-8 text and tag each document with its file name."""
    documents = TextLoader(str(path), encoding="utf-8").load()
    file_name = path.name
    for document in documents:
        # Recorded so retrieval results can be traced back to their file.
        document.metadata["source_file"] = file_name
    return documents

def build_multi_vector_stores(data_dir: Path) -> Dict[str, FAISS]:
    """
    Build one FAISS vector store per file in data_dir.

    Files are processed in sorted order so the resulting mapping (and any
    tie-breaking that depends on its iteration order) is deterministic.
    Hidden files (names starting with ".", e.g. ".DS_Store") are skipped
    because they are not documentation text.

    Returns dict: filename -> FAISS vectorstore. The dict is empty when
    data_dir does not exist; files that yield no chunks get no entry.
    """
    vectorstores_by_file: Dict[str, FAISS] = {}
    # Bail out before constructing the embedding model: there is nothing to
    # embed, so loading the model would be wasted work.
    if not data_dir.exists():
        return vectorstores_by_file

    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=800,
        chunk_overlap=150,
        length_function=len,
    )

    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

    for file_path in sorted(data_dir.iterdir()):
        if not file_path.is_file() or file_path.name.startswith("."):
            continue
        docs = load_single_file(file_path)
        chunks = text_splitter.split_documents(docs)
        for c in chunks:
            # Make sure every chunk carries its origin, whatever the splitter
            # did with the parent document's metadata.
            c.metadata["source_file"] = file_path.name
        if chunks:
            vs = FAISS.from_documents(chunks, embeddings)
            vectorstores_by_file[file_path.name] = vs

    print(f"[Retriever] Built {len(vectorstores_by_file)} vectorstores.")
    return vectorstores_by_file


# Built once when this module/cell runs; maps file name -> FAISS store.
vectorstores_by_file = build_multi_vector_stores(DATA_DIR)


class MultiStoreRetriever:
    """
    Retriever that fans a query out to several per-file FAISS stores:
      - runs a similarity search against every store
      - gathers (doc, score, filename) triples from all of them
      - orders the combined list by score and truncates it
    """

    def __init__(
        self,
        vs_by_file: Dict[str, FAISS],
        k_per_store: int = 4,
        k_total: int = 10,
    ) -> None:
        self.vs_by_file = vs_by_file
        # Number of hits requested from each individual store.
        self.k_per_store = k_per_store
        # Number of hits kept after the global ranking.
        self.k_total = k_total

    def retrieve_with_scores(
        self, query: str
    ) -> List[Tuple[Document, float, str]]:
        """Return up to ``k_total`` (doc, score, filename) triples, best first."""
        gathered: List[Tuple[Document, float, str]] = []
        for store_name, store in self.vs_by_file.items():
            try:
                hits = store.similarity_search_with_score(
                    query, k=self.k_per_store
                )
                gathered.extend((doc, score, store_name) for doc, score in hits)
            except Exception as e:
                # Best effort: a failing store is reported and skipped.
                print(f"[Retriever] Error retrieving from {store_name}: {e}")

        # FAISS similarity_search_with_score returns smaller scores = more similar
        ranked = sorted(gathered, key=lambda hit: hit[1])
        return ranked[: self.k_total]

    def get_top_docs(self, query: str) -> List[Document]:
        """Return only the documents of ``retrieve_with_scores``, same order."""
        return [hit[0] for hit in self.retrieve_with_scores(query)]


# Shared retriever instance using the default k_per_store / k_total settings.
mv_retriever = MultiStoreRetriever(vectorstores_by_file)


# -----------------------------------------------------------------------------
# 4. Retriever evaluation metrics (Precision/Recall/MAP/MRR/nDCG)
# -----------------------------------------------------------------------------

def precision_at_k(pred: List[str], relevant: Set[str], k: int) -> float:
    """Return the fraction of the top-``k`` predictions that are relevant.

    The denominator is always ``k``, even when fewer than ``k`` predictions
    exist. A non-positive ``k`` yields 0.0; a negative ``k`` used to slice
    from the end of ``pred`` and divide by a negative number.
    """
    if k <= 0:
        return 0.0
    hits = sum(1 for p in pred[:k] if p in relevant)
    return hits / k

def recall_at_k(pred: List[str], relevant: Set[str], k: int) -> float:
    """Return the fraction of relevant items found in the top-``k`` predictions.

    Yields 0.0 when ``relevant`` is empty or ``k`` is not positive. A negative
    ``k`` used to score ``pred[:-|k|]`` silently.
    """
    if not relevant or k <= 0:
        return 0.0
    hits = sum(1 for p in pred[:k] if p in relevant)
    return hits / len(relevant)

def average_precision(pred: List[str], relevant: Set[str]) -> float:
    """Return average precision of ``pred``, normalised by ``len(relevant)``.

    Yields 0.0 when ``relevant`` is empty.
    """
    if not relevant:
        return 0.0
    found = 0
    running_total = 0.0
    for rank, item in enumerate(pred, start=1):
        if item not in relevant:
            continue
        found += 1
        # Precision at the rank of each relevant hit.
        running_total += found / rank
    return running_total / len(relevant)

def reciprocal_rank(pred: List[str], relevant: Set[str]) -> float:
    """Return 1/rank of the first relevant prediction, or 0.0 if none is."""
    first_hit_scores = (
        1.0 / rank
        for rank, item in enumerate(pred, start=1)
        if item in relevant
    )
    return next(first_hit_scores, 0.0)

def ndcg_at_k(pred: List[str], relevant: Set[str], k: int) -> float:
    """Binary relevance nDCG@k; 0.0 when the ideal DCG is zero."""
    from math import log2

    # Discounted gain actually achieved by the top-k predictions.
    achieved = 0.0
    for position, item in enumerate(pred[:k], start=1):
        if item in relevant:
            achieved += 1.0 / log2(position + 1)

    # Ideal DCG: every relevant item (at most k of them) ranked first.
    ideal = 0.0
    for position in range(1, min(len(relevant), k) + 1):
        ideal += 1.0 / log2(position + 1)

    if ideal > 0:
        return achieved / ideal
    return 0.0

def rank_sources_for_query(query: str, retriever: MultiStoreRetriever) -> List[str]:
    """
    Run the multi-store retriever and return the source files of its hits,
    in ranking order, keeping only the first occurrence of each file.
    """
    hits = retriever.retrieve_with_scores(query)
    # dict.fromkeys keeps first-seen order, which gives an ordered de-duplication.
    unique_sources = dict.fromkeys(
        doc.metadata.get("source_file", filename)
        for doc, _score, filename in hits
    )
    return list(unique_sources)


# Example test cases based on the typical e-commerce tables in data_information
# (You can adjust relevant_sources if your file names differ.)
# Each case holds a display "name", the "query" text sent to the retriever, and
# the set of file names counted as relevant when scoring the ranked sources.
# NOTE(review): these file names are not checked against DATA_DIR; a name that
# does not exist there can never be retrieved and lowers recall. Confirm.
retriever_test_cases: List[Dict[str, Any]] = [
    {
        "name": "User demographics",
        "query": "What data describes the demographics of a user?",
        "relevant_sources": {
            "users.txt",
            "user_profiles.txt",
            "addresses.txt",
            "countries.txt",
        },
    },
    {
        "name": "Orders and payments",
        "query": "Which tables track orders, items in orders, and how they are paid?",
        "relevant_sources": {
            "orders.txt",
            "order_items.txt",
            "payments.txt",
        },
    },
    {
        "name": "Support tickets",
        "query": "Where is information about customer support issues stored?",
        "relevant_sources": {
            "support_tickets.txt",
            "users.txt",
        },
    },
    {
        "name": "Product catalog",
        "query": "Which table defines the available products and their prices?",
        "relevant_sources": {
            "products.txt",
        },
    },
]

def evaluate_retriever(
    test_cases: List[Dict[str, Any]],
    retriever: MultiStoreRetriever,
    top_n: int = 10,
    k_metrics: int = 5,
) -> Tuple[List[Dict[str, Any]], Dict[str, float]]:
    """Score the retriever on every test case.

    Args:
        test_cases: Dicts with "name", "query" and "relevant_sources".
        retriever: Retriever whose ranked source files are evaluated.
        top_n: Number of ranked sources kept per query before scoring.
        k_metrics: Cut-off used for P@k, R@k and nDCG@k.

    Returns:
        ``(per_case, summary)``: one result dict per test case, and the mean
        of each metric over all cases (keys "P@k", "R@k", "MAP", "MRR",
        "nDCG@k").
    """
    # (summary key, per-case key) pairs; MAP/MRR are the means of AP/RR.
    summary_sources = (
        ("P@k", "P@k"),
        ("R@k", "R@k"),
        ("MAP", "AP"),
        ("MRR", "RR"),
        ("nDCG@k", "nDCG@k"),
    )
    totals: Dict[str, float] = {summary_key: 0.0 for summary_key, _ in summary_sources}
    per_case: List[Dict[str, Any]] = []

    for case in test_cases:
        question = case["query"]
        gold = set(case["relevant_sources"])
        predicted = rank_sources_for_query(question, retriever)[:top_n]

        scores = {
            "P@k": precision_at_k(predicted, gold, k_metrics),
            "R@k": recall_at_k(predicted, gold, k_metrics),
            "AP": average_precision(predicted, gold),
            "RR": reciprocal_rank(predicted, gold),
            "nDCG@k": ndcg_at_k(predicted, gold, k_metrics),
        }
        for summary_key, case_key in summary_sources:
            totals[summary_key] += scores[case_key]

        per_case.append({
            "name": case["name"],
            "query": question,
            "relevant_sources": gold,
            "pred_sources": predicted,
            **scores,
        })

    # Guard against division by zero when no test cases were supplied.
    divisor = max(len(test_cases), 1)
    summary = {metric: total / divisor for metric, total in totals.items()}
    return per_case, summary


# -----------------------------------------------------------------------------
# 5. LangGraph State + Nodes
# -----------------------------------------------------------------------------

class AgentState(TypedDict, total=False):
    # State dict passed between the graph nodes; every key is optional
    # (total=False).
    query: str  # the user's request
    plan: List[Dict[str, Any]]  # steps produced by planning_node
    step_index: int  # index into `plan` of the next step to execute
    retries_for_step: int  # failures recorded so far for the current step
    plan_attempts: int  # how many times planning_node has run
    max_retries_per_step: int  # read by control_router (default 3 when absent)
    max_plan_attempts: int  # read by control_router (default 2 when absent)
    last_error: Optional[str]  # message of the most recent step failure
    best_partial_step: Optional[int]  # index of the last successful step
    best_partial_result: Optional[Dict[str, Any]]  # e.g. serialized df / info
    final_dataframe: Optional[pd.DataFrame]  # result of the latest SQL step
    final_answer: Optional[str]  # LLM summary written by summary_node
    status: Literal["planning", "running", "success", "failed"]


def _parse_plan_steps(raw: str) -> List[Dict[str, Any]]:
    """Extract the validated list of step objects from raw LLM output.

    Raises:
        RuntimeError: If no JSON object can be parsed, or the parsed plan has
            no non-empty ``steps`` list made up of JSON objects.
    """
    try:
        plan_json = json.loads(raw)
    except ValueError:
        # The model may wrap the JSON in prose or code fences; fall back to
        # the outermost {...} substring.
        start = raw.find("{")
        end = raw.rfind("}")
        if start < 0 or end <= start:
            raise RuntimeError(f"Could not parse plan as JSON:\n{raw}") from None
        try:
            plan_json = json.loads(raw[start : end + 1])
        except ValueError:
            raise RuntimeError(f"Could not parse plan as JSON:\n{raw}") from None

    # A JSON array or scalar at the root has no "steps" member.
    steps = plan_json.get("steps", []) if isinstance(plan_json, dict) else []
    if not isinstance(steps, list) or not steps:
        raise RuntimeError(f"Plan contains no steps: {plan_json}")
    if not all(isinstance(step, dict) for step in steps):
        raise RuntimeError(f"Plan steps must be JSON objects: {plan_json}")
    return steps


def planning_node(state: AgentState) -> AgentState:
    """
    Use LLM to generate a JSON plan: list of steps.
    Each step is a dict with fields like:
      - "id": int
      - "type": "sql" | "doc_retrieval"
      - "description": str

    Every call counts as one plan attempt. On success the returned state
    holds the new plan, with step_index/retries_for_step reset and status
    "running". If the LLM call fails or its output is not a usable plan, the
    failure is recorded in ``last_error`` and the state comes back with an
    empty plan and status "planning", so the router can replan or give up
    instead of the whole graph run raising.
    """
    user_query = state["query"]
    system_prompt = """You are a planning assistant.

Given a user query, create a small JSON plan with a list of steps.
Each step must be a JSON object with fields:
- id: integer step number starting from 1
- type: one of ["sql", "doc_retrieval"]
- description: natural language description of what to do

Rules:
- Use type "sql" when you need to run SQL or inspect a database.
- Use type "doc_retrieval" when you need to consult documentation text files.
Return ONLY valid JSON: {"steps": [...]}"""

    user_prompt = f"User query: {user_query}\nCreate the plan."

    new_state: AgentState = dict(state)
    new_state["step_index"] = 0
    new_state["retries_for_step"] = 0
    new_state["plan_attempts"] = state.get("plan_attempts", 0) + 1

    try:
        raw = call_llm(system_prompt, user_prompt)
        steps = _parse_plan_steps(raw)
    except RuntimeError as e:
        err_msg = f"{type(e).__name__}: {e}"
        print(f"[Planner] ERROR: {err_msg}")
        # Drop any stale plan so the router cannot execute it by mistake.
        new_state["plan"] = []
        new_state["last_error"] = err_msg
        new_state["status"] = "planning"
        return new_state

    print("[Planner] Plan created:")
    print(json.dumps(steps, indent=2)[:1000])

    new_state["plan"] = steps
    new_state["status"] = "running"
    return new_state


def _strip_sql_fences(text: str) -> str:
    """Return ``text`` without a surrounding Markdown code fence, if any.

    Code models often answer with a fenced ```sql block even when asked for
    raw SQL; the fence lines are not valid SQL and would fail validation.
    """
    cleaned = text.strip()
    if not cleaned.startswith("```"):
        return cleaned
    # Drop the opening fence line (it may carry a language tag such as "sql").
    lines = cleaned.splitlines()[1:]
    if lines and lines[-1].strip() == "```":
        lines = lines[:-1]
    unfenced = "\n".join(lines).strip()
    # A single-line fenced answer leaves nothing behind; keep the original then.
    return unfenced or cleaned


def execute_step_node(state: AgentState) -> AgentState:
    """
    Execute the current step:
      - 'sql' steps via MCP client (validate_sql + run_sql)
      - 'doc_retrieval' steps via multi-vector retriever

    On success the returned state has ``step_index`` advanced,
    ``retries_for_step`` reset, ``last_error`` cleared and
    ``best_partial_result`` set (SQL steps also set ``final_dataframe``).
    On any failure, including a malformed step, the error is stored in
    ``last_error`` and ``retries_for_step`` is incremented; nothing is raised.
    If no step remains, ``state`` is returned unchanged.
    """
    plan = state["plan"]
    idx = state["step_index"]
    if idx >= len(plan):
        # nothing to do
        return state

    new_state: AgentState = dict(state)
    try:
        step = plan[idx]
        # Checked inside the try so a malformed plan counts as a step failure
        # (retry/replan) instead of crashing the graph run.
        if not isinstance(step, dict):
            raise RuntimeError(f"Malformed plan step (expected an object): {step!r}")
        step_type = step.get("type")
        desc = step.get("description", "")
        print(f"[Executor] Executing step {idx+1}/{len(plan)}: type={step_type}, desc={desc!r}")

        if step_type == "sql":
            # Use LLM to propose a SQL query given description + user query
            sql_prompt = f"""
You are a SQL expert. User query: {state['query']}

Current step description:
{desc}

Write a single valid SQLite SQL query (no explanation). Use tables from the Employee_Information.db schema.
Return ONLY the SQL.
"""
            # NOTE(review): the generated SQL is executed as-is, so a write or
            # DDL statement would run too. Consider restricting to SELECT.
            sql_query = _strip_sql_fences(
                call_llm(
                    "You generate pure SQL queries for SQLite.",
                    sql_prompt,
                )
            )

            # validate
            val_resp = mcp_client.call_tool("validate_sql", {"query": sql_query})
            if not val_resp.get("ok"):
                raise RuntimeError(f"validate_sql failed: {val_resp.get('error')}")
            val_payload = val_resp["result"]
            if not val_payload.get("valid", False):
                raise RuntimeError(f"SQL invalid: {val_payload.get('error')}")

            # execute
            run_resp = mcp_client.call_tool("run_sql", {"query": sql_query})
            if not run_resp.get("ok"):
                raise RuntimeError(f"run_sql failed: {run_resp.get('error')}")

            rows = run_resp["result"]["rows"]
            cols = run_resp["result"]["columns"]
            df = pd.DataFrame(rows, columns=cols)

            new_state["final_dataframe"] = df
            new_state["best_partial_step"] = idx
            new_state["best_partial_result"] = {
                "type": "sql",
                "step_index": idx,
                "sql": sql_query,
                "columns": cols,
                "rows_preview": rows[:10],
            }
            new_state["last_error"] = None
            new_state["retries_for_step"] = 0
            new_state["step_index"] = idx + 1

        elif step_type == "doc_retrieval":
            query_for_docs = f"{state['query']} -- {desc}"
            docs = mv_retriever.get_top_docs(query_for_docs)
            preview = [
                {
                    "source_file": d.metadata.get("source_file", "unknown"),
                    "snippet": d.page_content[:200],
                }
                for d in docs
            ]
            new_state["best_partial_step"] = idx
            new_state["best_partial_result"] = {
                "type": "doc_retrieval",
                "step_index": idx,
                "results_preview": preview,
            }
            new_state["last_error"] = None
            new_state["retries_for_step"] = 0
            new_state["step_index"] = idx + 1
        else:
            raise RuntimeError(f"Unknown step type: {step_type}")

    except Exception as e:
        err_msg = f"{type(e).__name__}: {e}"
        print(f"[Executor] ERROR: {err_msg}")
        new_state["last_error"] = err_msg
        new_state["retries_for_step"] = state.get("retries_for_step", 0) + 1

    return new_state


def control_router(state: AgentState) -> Literal["execute_step", "plan", "summary", "failed", END]:
    """
    Decide what to do next based on:
      - remaining steps
      - retries_for_step vs max_retries_per_step
      - plan_attempts vs max_plan_attempts
      - current status

    Returns the name of the next node: "plan", "execute_step", "summary" or
    "failed". The function only reads ``state``. A routing function should
    have no side effects, and the resets needed before a replan (step index,
    retry counter, status) are performed by ``planning_node`` itself.
    """
    status = state.get("status", "planning")
    retries = state.get("retries_for_step", 0)
    max_retries = state.get("max_retries_per_step", 3)
    plan_attempts = state.get("plan_attempts", 0)
    max_plan_attempts = state.get("max_plan_attempts", 2)
    plan = state.get("plan", [])
    idx = state.get("step_index", 0)

    if status == "planning" or not plan:
        if plan_attempts >= max_plan_attempts:
            print("[Router] Plan attempts exhausted; failing.")
            return "failed"
        print("[Router] Need to (re)plan.")
        return "plan"

    # We have a plan:
    if idx < len(plan):
        # More steps remaining
        # `retries` counts failures of the current step, so a step gets one
        # initial attempt plus up to `max_retries` retries.
        if retries > max_retries:
            # Step failed too many times; replan if allowed
            if plan_attempts >= max_plan_attempts:
                print("[Router] Step retries exhausted and no more plan attempts; failing.")
                return "failed"
            print("[Router] Step retries exhausted; replan.")
            return "plan"
        # Otherwise execute current step
        return "execute_step"

    # All steps complete
    print("[Router] All steps completed; move to summary.")
    return "summary"


def summary_node(state: AgentState) -> AgentState:
    """
    If user asked a question, produce a natural language answer using LLM
    and the resulting dataframe (if any). Otherwise just keep dataframe.

    Returns a copy of ``state`` with status "success". ``final_answer`` is
    added only when a dataframe exists and the query looks like a question.
    """
    df = state.get("final_dataframe")
    query = state["query"]

    # Heuristic: if query ends with ? or starts with typical question word, we summarize.
    # Normalise once so leading/trailing whitespace cannot defeat either check.
    normalized = query.strip()
    question_like = (
        normalized.endswith("?")
        or normalized.lower().startswith(("what", "how", "why", "which", "who", "when"))
    )

    new_state: AgentState = dict(state)
    new_state["status"] = "success"

    # No dataframe or not clearly a question → just mark success
    if df is None or not question_like:
        return new_state

    # Limit rows for prompt
    preview_df = df.head(20)
    try:
        table_md = preview_df.to_markdown(index=False)
    except ImportError:
        # to_markdown needs the optional `tabulate` package; fall back to a
        # plain-text table rather than failing the whole pipeline.
        table_md = preview_df.to_string(index=False)

    system_prompt = "You are a data analyst who explains query results clearly and concisely."
    user_prompt = f"""
User question:
{query}

Here is a sample of the dataframe result (up to 20 rows):

{table_md}

Explain the answer to the user based on this data.
Respond in a few short paragraphs plus any key bullet points if helpful.
"""

    new_state["final_answer"] = call_llm(system_prompt, user_prompt)
    return new_state


def failed_node(state: AgentState) -> AgentState:
    """
    Terminal node reached once retries and plan attempts are used up.
    Everything already in the state (including the best partial result) is
    carried over; only the status changes.
    """
    print("[Failed] Pipeline exhausted retries and plan attempts.")
    return {**state, "status": "failed"}


# -----------------------------------------------------------------------------
# 6. Build the LangGraph
# -----------------------------------------------------------------------------

# Graph over AgentState with four nodes. `control_router` is used both as the
# entry-point router and as the router after "plan" and "execute_step".
builder = StateGraph(AgentState)

builder.add_node("plan", planning_node)
builder.add_node("execute_step", execute_step_node)
builder.add_node("summary", summary_node)
builder.add_node("failed", failed_node)

# Control router as a conditional entry point
builder.set_conditional_entry_point(
    control_router,
    {
        "plan": "plan",
        "execute_step": "execute_step",
        "summary": "summary",
        "failed": "failed",
        END: END,
    },
)

# After planning, route back to control router
builder.add_conditional_edges(
    "plan",
    control_router,
    {
        "execute_step": "execute_step",
        "plan": "plan",
        "summary": "summary",
        "failed": "failed",
        END: END,
    },
)

# After each step execution, route back to control router
builder.add_conditional_edges(
    "execute_step",
    control_router,
    {
        "execute_step": "execute_step",
        "plan": "plan",
        "summary": "summary",
        "failed": "failed",
        END: END,
    },
)

# Summary and failed both go to END
builder.add_edge("summary", END)
builder.add_edge("failed", END)

# Compiled graph object; `run_pipeline` invokes it.
graph = builder.compile()


def run_pipeline(user_query: str) -> AgentState:
    """
    Run the compiled graph once for ``user_query`` and return its final state.

    The run starts in status "planning" with an empty plan, zeroed counters,
    at most 3 retries per step and at most 2 plan attempts.
    """
    # Fields that start out empty.
    unset_fields = dict.fromkeys(
        (
            "last_error",
            "best_partial_step",
            "best_partial_result",
            "final_dataframe",
            "final_answer",
        )
    )
    init_state: AgentState = dict(
        query=user_query,
        plan=[],
        step_index=0,
        retries_for_step=0,
        plan_attempts=0,
        max_retries_per_step=3,
        max_plan_attempts=2,
        status="planning",
        **unset_fields,
    )
    return graph.invoke(init_state)


# -----------------------------------------------------------------------------
# 7. Test Cases
# -----------------------------------------------------------------------------

def run_retriever_tests():
    """Evaluate the shared retriever on `retriever_test_cases` and print a report.

    Does nothing except print a notice when no vector stores were built.
    """
    if not vectorstores_by_file:
        print("[Tests] Skipping retriever tests; no vectorstores built.")
        return

    print("\n=== Retriever Evaluation ===")
    per_case, averages = evaluate_retriever(
        retriever_test_cases,
        mv_retriever,
        top_n=10,
        k_metrics=5,
    )

    print("Summary metrics:")
    for metric_name, metric_value in averages.items():
        print(f"  {metric_name}: {metric_value:.3f}")

    print("\nPer-query results:")
    for case in per_case:
        metrics_line = (
            f"  P@k: {case['P@k']:.3f}, R@k: {case['R@k']:.3f}, "
            f"AP: {case['AP']:.3f}, RR: {case['RR']:.3f}, "
            f"nDCG@k: {case['nDCG@k']:.3f}"
        )
        print(f"\nTest case: {case['name']}")
        print(f"  Query: {case['query']}")
        print(metrics_line)
        print(f"  Relevant: {sorted(case['relevant_sources'])}")
        print(f"  Predicted: {case['pred_sources']}")


def run_pipeline_tests():
    """
    Smoke-test the end-to-end pipeline with a few fixed queries.
    Assumes Employee_Information.db exists and is sensible. An error in one
    query is printed and does not stop the remaining queries.
    """
    separator = "\n" + "=" * 80
    test_queries = (
        "List all departments and how many employees they have.",
        "What is the average salary per department?",
        "Show the top 5 highest paid employees and their departments.",
    )
    for question in test_queries:
        print(separator)
        print(f"[Pipeline Test] Query: {question}")
        try:
            outcome = run_pipeline(question)
            print(f"  Status: {outcome.get('status')}")
            # Prefer the LLM answer, then the dataframe, then the partial result.
            if outcome.get("final_answer"):
                print("\nFinal Answer:")
                print(outcome["final_answer"])
                continue
            if outcome.get("final_dataframe") is not None:
                print("\nFinal DataFrame (head):")
                print(outcome["final_dataframe"].head())
                continue
            print("\nNo final answer or dataframe; best partial result:")
            partial_json = json.dumps(outcome.get("best_partial_result"), indent=2)
            print(partial_json[:1000])
        except Exception as e:
            print(f"  ERROR: {e}")


# -----------------------------------------------------------------------------
# 8. Example usage (manual + tests)
# -----------------------------------------------------------------------------

# Script entry point: one manual pipeline run, then the retriever evaluation,
# then the pipeline smoke tests.
if __name__ == "__main__":
    # Example manual run
    example_query = "For each department, what is the average salary and number of employees?"
    print("\n" + "#" * 80)
    print(f"Manual example query: {example_query}")
    final_state = run_pipeline(example_query)
    print(f"Final status: {final_state.get('status')}")

    # Report the most informative output available: the LLM answer, else the
    # dataframe, else the best partial result (truncated to 1000 characters).
    if final_state.get("final_answer"):
        print("\nFinal answer:")
        print(final_state["final_answer"])
    elif final_state.get("final_dataframe") is not None:
        print("\nFinal DataFrame (head):")
        print(final_state["final_dataframe"].head())
    else:
        print("\nBest partial result:")
        print(json.dumps(final_state.get("best_partial_result"), indent=2)[:1000])

    # Run retriever tests
    run_retriever_tests()

    # Run pipeline tests (LLM + graph + MCP)
    run_pipeline_tests()


  from .autonotebook import tqdm as notebook_tqdm


Using langchain_ollama.OllamaLLM
[MCP] DB path: PosixPath('../data/Employee_Information.db')
[MCP] Available files in ../data: [PosixPath('../data/create_db.ipynb'), PosixPath('../data/.DS_Store'), PosixPath('../data/data_information'), PosixPath('../data/data_documentation.txt'), PosixPath('../data/Employee_Information.db')]
[Retriever] Loading docs from: ../data/data_information
[Retriever] Built 10 vectorstores.

################################################################################
Manual example query: For each department, what is the average salary and number of employees?
[Router] Need to (re)plan.
[LLM] Calling model (attempt 1/3)...
[Planner] Plan created:
[
  {
    "id": 1,
    "type": "sql",
    "description": "Write a SQL query to calculate the average salary and number of employees for each department."
  }
]
[Executor] Executing step 1/1: type=sql, desc='Write a SQL query to calculate the average salary and number of employees for each department.'
[LLM] Calling