# LangGraph Agent Pipeline – Step by Step

This notebook runs the **LangGraph** Text-to-SQL pipeline (ReAct + chain-of-thought) one step at a time: the agent reasons and issues tool calls, the tools return results, and the loop repeats until the answer node produces the final answer.

## 1. Setup

In [1]:
from dotenv import load_dotenv
load_dotenv()

from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage

from src.data_loader import load_all_tables, get_sqlite_connection
from src.pipeline_langgraph import (
    build_graph,
    PipelineState,
    MAX_SQL_RETRIES,
)

print("✓ Imports loaded")



✓ Imports loaded


## 2. Load Data

In [2]:
tables = load_all_tables()

print("Loaded tables:")
for name, df in tables.items():
    print(f"  - {name}: {len(df)} rows, columns: {list(df.columns)}")

conn = get_sqlite_connection(tables)
print("\n✓ In-memory SQLite DB created")

Loaded tables:
  - Clients: 20 rows, columns: ['client_id', 'client_name', 'industry', 'country']
  - Invoices: 40 rows, columns: ['invoice_id', 'client_id', 'invoice_date', 'due_date', 'status', 'currency', 'fx_rate_to_usd']
  - InvoiceLineItems: 96 rows, columns: ['line_id', 'invoice_id', 'service_name', 'quantity', 'unit_price', 'tax_rate']

✓ In-memory SQLite DB created
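
The body of `get_sqlite_connection` is not shown in this notebook. As a rough, standard-library-only sketch of the idea (one SQLite table per loaded table, all held in memory), something like the following could work. The helper name `build_in_memory_db` and the `(columns, rows)` input shape are assumptions made for illustration; the real function takes the DataFrames loaded above.

```python
import sqlite3


def build_in_memory_db(tables):
    """Hypothetical helper: create one SQLite table per entry of `tables`.

    `tables` maps a table name to a (columns, rows) pair. This shape is an
    assumption for the sketch; the notebook's loader returns DataFrames.
    """
    conn = sqlite3.connect(":memory:")
    for name, (columns, rows) in tables.items():
        col_list = ", ".join(f'"{c}"' for c in columns)
        placeholders = ", ".join("?" for _ in columns)
        conn.execute(f'CREATE TABLE "{name}" ({col_list})')
        conn.executemany(f'INSERT INTO "{name}" VALUES ({placeholders})', rows)
    conn.commit()
    return conn


# Made-up rows that only reuse the Clients column names printed above.
demo_tables = {
    "Clients": (
        ["client_id", "client_name", "industry", "country"],
        [(1, "Acme", "Retail", "USA"), (2, "Globex", "Energy", "UK")],
    ),
}
demo_conn = build_in_memory_db(demo_tables)
client_count = demo_conn.execute('SELECT COUNT(*) FROM "Clients"').fetchone()[0]
print(client_count)
demo_conn.close()
```

With pandas DataFrames, `DataFrame.to_sql(name, conn, index=False)` on a `sqlite3` connection is the usual way to get the same result.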


## 3. Pick a Question and Build Initial State

In [3]:
question = "Group revenue by client country: for each country, compute the total billed amount in 2024 (including tax)."


initial_state: PipelineState = {
    "messages": [HumanMessage(content=question)],
    "question": question,
    "tables": tables,
    "conn": conn,
    "validation_retry_count": 0,
}

print(f"Question: {question}")
print("\n✓ Initial state ready (messages, question, tables, conn)")

Question: Group revenue by client country: for each country, compute the total billed amount in 2024 (including tax).

✓ Initial state ready (messages, question, tables, conn)
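
The definition of `PipelineState` lives in `src.pipeline_langgraph` and is not reproduced here. Judging only from the keys this notebook reads and writes, it could look roughly like the `TypedDict` below. The class name `SketchPipelineState`, the field types, and the use of `total=False` are assumptions; the real class may also attach a reducer to `messages`.

```python
from typing import Any, Optional, TypedDict


class SketchPipelineState(TypedDict, total=False):
    """Hypothetical stand-in for PipelineState, inferred from the notebook."""

    messages: list                # conversation so far (human, AI, tool messages)
    question: str                 # the original natural-language question
    tables: dict                  # table name -> loaded table
    conn: Any                     # open SQLite connection
    validation_retry_count: int   # validation retries used so far
    final_answer: Optional[str]   # read from the streamed state later on


sketch_state: SketchPipelineState = {
    "messages": ["placeholder for HumanMessage(content=question)"],
    "question": "placeholder question",
    "tables": {},
    "conn": None,
    "validation_retry_count": 0,
}
print(sketch_state.get("final_answer"))
```

Because `final_answer` is absent from the initial state, `state.get("final_answer")` returns `None` until some node sets it, which is what the streaming loop in section 5 checks for.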


## 4. Build the Graph

In [4]:
graph = build_graph()

print("Graph nodes: agent, tools, answer")
print("✓ Graph compiled: agent ↔ tools (retry limit =", MAX_SQL_RETRIES, "), then answer node (natural language from table data)")

Graph nodes: agent, tools, answer
✓ Graph compiled: agent ↔ tools (retry limit = 3 ), then answer node (natural language from table data)
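
How `build_graph()` wires these nodes is not shown in this notebook. The routing described in the printed line (agent ↔ tools with a retry limit, then the answer node) can be sketched as a plain routing function. The name `route_after_agent`, its parameters, and the choice to fall through to the answer node once the retries are used up are all assumptions for illustration, not the pipeline's actual code.

```python
MAX_SQL_RETRIES = 3  # value printed by the cell above


def route_after_agent(has_tool_calls: bool, validation_retry_count: int) -> str:
    """Hypothetical router: pick the node that runs after the agent node."""
    if validation_retry_count >= MAX_SQL_RETRIES:
        # Assumed behaviour: stop looping once the retry budget is spent.
        return "answer"
    if has_tool_calls:
        # The agent asked for a tool, so run it and return to the agent.
        return "tools"
    # No tool call means the agent is done; hand over to the answer node.
    return "answer"


print(route_after_agent(True, 0))
print(route_after_agent(False, 0))
print(route_after_agent(True, 3))
```

In LangGraph, a function of this kind would typically be registered as a conditional edge leaving the agent node.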


## 5. Run Step by Step (Stream)

In [5]:
def _truncate(text, limit):
    """Return text cut to `limit` characters, with '...' appended if it was longer."""
    text = str(text)
    return f"{text[:limit]}..." if len(text) > limit else text


def _summary(msg):
    """Short summary of a message for display."""
    if isinstance(msg, HumanMessage):
        return f"User: {_truncate(msg.content, 80)}"
    if isinstance(msg, SystemMessage):
        return "System: [instructions]"
    if isinstance(msg, AIMessage):
        tool_calls = getattr(msg, "tool_calls", None) or []
        if tool_calls:
            names = [tc.get("name", "") for tc in tool_calls]
            args_preview = []
            for tc in tool_calls:
                args = tc.get("args") or {}
                if args.get("sql"):
                    args_preview.append(f"sql={_truncate(args['sql'], 60)}")
                else:
                    args_preview.append(str(args)[:50])
            return f"Agent → tools: {names}  {args_preview}"
        return f"Agent: {_truncate(msg.content, 120)}" if msg.content else "Agent: (no text)"
    if isinstance(msg, ToolMessage):
        # str() guards against non-string content before slicing.
        content = str(msg.content or "")
        if content.startswith('{"'):
            return f"Tool result: {_truncate(content, 100)}"
        return f"Tool result: (length {len(content)} chars)"
    return type(msg).__name__

step_num = 0
last_state = None

print("Streaming steps (agent → tools → agent → … → answer node):\n")

for state in graph.stream(initial_state, stream_mode="values"):
    step_num += 1
    last_state = state
    if state.get("final_answer") is not None:
        print(f"Step {step_num}: Answer node → natural language answer (length {len(state['final_answer'])} chars)")
    else:
        messages = state.get("messages") or []
        if messages:
            last_msg = messages[-1]
            summary = _summary(last_msg)
            retry = state.get("validation_retry_count", 0)
            print(f"Step {step_num}: {summary}")
            if retry > 0:
                print(f"         (validation_retry_count = {retry})")
    print()

print("--- Stream done ---")

Streaming steps (agent → tools → agent → … → answer node):

Step 1: User: Group revenue by client country: for each country, compute the total billed amou...

Step 2: Agent → tools: ['get_schema']  ['{}']

Step 3: Tool result: (length 7510 chars)

Step 4: Agent → tools: ['validate_sql']  ['sql=SELECT Clients.country, SUM(InvoiceLineItems.quantity * Invo...']

Step 5: Tool result: {"valid": true, "error_message": ""}

Step 6: Agent → tools: ['execute_sql']  ['sql=SELECT Clients.country, SUM(InvoiceLineItems.quantity * Invo...']

Step 7: Tool result: {"success": true, "row_count": 15, "preview": "        country  total_billed_amount\n0     Australia...

Step 8: Agent: {"question":"Group revenue by client country: for each country, compute the total billed amount in 2024 (including tax)....

retrieved data
        country  total_billed_amount
0     Australia               5757.5
1        Brazil               2051.5
2        Canada               2426.5
3        France               4062.5


## 6. Final Result

In [6]:
if last_state is None:
    print("No state (stream produced no steps).")
else:
    messages = last_state.get("messages") or []
    validation_retry_count = last_state.get("validation_retry_count", 0)
    # Answer from explicit answer node (based on table data from execute_sql)
    answer = last_state.get("final_answer")
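    if not answer:
        # Fall back to the most recent plain-text tool result, skipping JSON
        # payloads and messages that contain "STRUCTURAL KNOWLEDGE".
        for m in reversed(messages):
            if not (isinstance(m, ToolMessage) and m.content):
                continue
            is_json = m.content.strip().startswith("{")
            is_structural = "STRUCTURAL KNOWLEDGE" in m.content
            if not is_json and not is_structural:
                answer = m.content
                break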

    # Extract last SQL
    sql = None
    for m in reversed(messages):
        if isinstance(m, AIMessage) and getattr(m, "tool_calls", None):
            for tc in m.tool_calls or []:
                if tc.get("name") in ("execute_sql", "validate_sql") and (tc.get("args") or {}).get("sql"):
                    sql = tc["args"]["sql"]
                    break
            if sql:
                break

    print("SQL generated:")
    print(sql or "(none)")
    print()
    print("Validation retries used:", validation_retry_count)
    if validation_retry_count >= MAX_SQL_RETRIES:
        print("⚠ Validation failed after max retries.")
    print()
    print("Answer (from answer node, based on table data):")
    print(answer or "(no answer)")

conn.close()
print("\n✓ Connection closed.")

SQL generated:
SELECT Clients.country, SUM(InvoiceLineItems.quantity * InvoiceLineItems.unit_price * (1 + InvoiceLineItems.tax_rate)) AS total_billed_amount
FROM Clients
JOIN Invoices ON Clients.client_id = Invoices.client_id
JOIN InvoiceLineItems ON Invoices.invoice_id = InvoiceLineItems.invoice_id
WHERE strftime('%Y', Invoices.invoice_date) = '2024'
GROUP BY Clients.country;

Validation retries used: 0

Answer (from answer node, based on table data):
Here is the total billed amount in 2024, including tax, grouped by client country:

- Australia: 5757.5
- Brazil: 2051.5
- Canada: 2426.5
- France: 4062.5
- Germany: 3287.5
- India: 2401.5
- Ireland: 6106.0
- Netherlands: 2632.0
- New Zealand: 2037.5
- Norway: 5506.0
- Portugal: 5104.0
- Spain: 3310.0
- Switzerland: 3398.5
- UK: 7420.0
- USA: 22003.0

✓ Connection closed.
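
To see what the generated query computes, the tax-inclusive formula `quantity * unit_price * (1 + tax_rate)` can be checked by hand on a few rows. The sketch below runs the same query against a tiny in-memory database. All rows are made up, only the columns the query touches are created, dates are assumed to be ISO `YYYY-MM-DD` strings (which `strftime` needs), and an `ORDER BY` is added purely to make the output order deterministic.

```python
import sqlite3

demo = sqlite3.connect(":memory:")
demo.executescript(
    """
    CREATE TABLE Clients (client_id INTEGER, country TEXT);
    CREATE TABLE Invoices (invoice_id INTEGER, client_id INTEGER, invoice_date TEXT);
    CREATE TABLE InvoiceLineItems (
        invoice_id INTEGER, quantity REAL, unit_price REAL, tax_rate REAL
    );
    INSERT INTO Clients VALUES (1, 'USA'), (2, 'UK');
    INSERT INTO Invoices VALUES
        (10, 1, '2024-03-01'), (11, 2, '2024-06-15'), (12, 2, '2023-12-31');
    INSERT INTO InvoiceLineItems VALUES
        (10, 2, 100.0, 0.25),  -- 2 * 100 * 1.25 = 250
        (10, 1, 50.0, 0.0),    -- 1 * 50  * 1.00 = 50
        (11, 4, 10.0, 0.5),    -- 4 * 10  * 1.50 = 60
        (12, 5, 20.0, 0.0);    -- invoice dated 2023, filtered out
    """
)

rows = demo.execute(
    """
    SELECT Clients.country,
           SUM(InvoiceLineItems.quantity * InvoiceLineItems.unit_price
               * (1 + InvoiceLineItems.tax_rate)) AS total_billed_amount
    FROM Clients
    JOIN Invoices ON Clients.client_id = Invoices.client_id
    JOIN InvoiceLineItems ON Invoices.invoice_id = InvoiceLineItems.invoice_id
    WHERE strftime('%Y', Invoices.invoice_date) = '2024'
    GROUP BY Clients.country
    ORDER BY Clients.country
    """
).fetchall()
print(rows)  # [('UK', 60.0), ('USA', 300.0)]
demo.close()
```

The generated query references neither `currency`, `fx_rate_to_usd`, nor `status`. Whether the totals should be converted to USD or restricted to certain invoice statuses is something the question leaves open.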
