### 1. Setup and Installation

In [1]:
!pip install langchain-openai langchain-core langgraph fastapi uvicorn requests pydantic -q

In [8]:
import os
import sqlite3
import datetime
import threading
import requests
import uvicorn
from typing import TypedDict, Optional, List, Any, Literal
from contextlib import contextmanager
from getpass import getpass

from fastapi import FastAPI
from pydantic import BaseModel, Field

from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage, SystemMessage, AIMessage, ToolMessage
from langchain_core.tools import tool
from langchain_core.prompts import ChatPromptTemplate
from langgraph.graph import StateGraph, START, END

from database_setup import DatabaseSetup

### 2. OpenRouter and Environment Configuration

In [4]:
# --- OpenRouter Configuration ---
os.environ["OPENROUTER_API_KEY"] = "hidden_api_key"

OPENROUTER_API_KEY = os.environ.get("OPENROUTER_API_KEY")
OPENROUTER_URL = "https://openrouter.ai/api/v1/chat/completions"
DEFAULT_OPENROUTER_MODEL = "x-ai/grok-4.1-fast:free"

# LangChain's ChatOpenAI expects the base URL ("/v1")
OPENROUTER_BASE_URL = OPENROUTER_URL.rsplit("/chat/completions", 1)[0]

# Initialize a single LLM instance for both routing and support
llm = ChatOpenAI(
    base_url=OPENROUTER_BASE_URL,
    api_key=OPENROUTER_API_KEY,
    model=DEFAULT_OPENROUTER_MODEL,
    temperature=0.0, # Using low temperature for consistent routing/support
)

router_llm = llm
support_llm = llm

print(f"LLMs configured for OpenRouter model: {DEFAULT_OPENROUTER_MODEL}")
print(f"OpenRouter Base URL: {OPENROUTER_BASE_URL}")

LLMs configured for OpenRouter model: x-ai/grok-4.1-fast:free
OpenRouter Base URL: https://openrouter.ai/api/v1


### 3. Database Creation and Sample Data

In [5]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [6]:
%cd /content/drive/MyDrive/genai_hw_MCP_A2A

/content/drive/MyDrive/genai_hw_MCP_A2A


In [9]:
def init_db(db_path: str = "support.db") -> sqlite3.Connection:
    """
    Initialize the SQLite database using DatabaseSetup
    and return a shared connection.
    """
    # Run setup (create tables, sample data, etc.)
    setup = DatabaseSetup(db_path=db_path)
    setup.connect()
    setup.create_tables()
    setup.create_triggers()
    setup.insert_sample_data()
    setup.close()

    conn = sqlite3.connect(db_path, check_same_thread=False)
    conn.row_factory = sqlite3.Row
    return conn

db_conn = init_db()

Connected to database: support.db
Tables created successfully!
Triggers created successfully!
Sample data inserted successfully!
  - 15 customers added
  - 25 tickets added
Database connection closed.


In [10]:
cur = db_conn.cursor()
cur.execute("PRAGMA table_info(customers)")
for r in cur.fetchall():
    print(dict(r))

{'cid': 0, 'name': 'id', 'type': 'INTEGER', 'notnull': 0, 'dflt_value': None, 'pk': 1}
{'cid': 1, 'name': 'name', 'type': 'TEXT', 'notnull': 1, 'dflt_value': None, 'pk': 0}
{'cid': 2, 'name': 'email', 'type': 'TEXT', 'notnull': 0, 'dflt_value': None, 'pk': 0}
{'cid': 3, 'name': 'phone', 'type': 'TEXT', 'notnull': 0, 'dflt_value': None, 'pk': 0}
{'cid': 4, 'name': 'status', 'type': 'TEXT', 'notnull': 1, 'dflt_value': "'active'", 'pk': 0}
{'cid': 5, 'name': 'created_at', 'type': 'TIMESTAMP', 'notnull': 0, 'dflt_value': 'CURRENT_TIMESTAMP', 'pk': 0}
{'cid': 6, 'name': 'updated_at', 'type': 'TIMESTAMP', 'notnull': 0, 'dflt_value': 'CURRENT_TIMESTAMP', 'pk': 0}


In [11]:
cur = db_conn.cursor()
cur.execute("PRAGMA table_info(tickets)")
rows = cur.fetchall()

for r in rows:
    print(dict(r))

{'cid': 0, 'name': 'id', 'type': 'INTEGER', 'notnull': 0, 'dflt_value': None, 'pk': 1}
{'cid': 1, 'name': 'customer_id', 'type': 'INTEGER', 'notnull': 1, 'dflt_value': None, 'pk': 0}
{'cid': 2, 'name': 'issue', 'type': 'TEXT', 'notnull': 1, 'dflt_value': None, 'pk': 0}
{'cid': 3, 'name': 'status', 'type': 'TEXT', 'notnull': 1, 'dflt_value': "'open'", 'pk': 0}
{'cid': 4, 'name': 'priority', 'type': 'TEXT', 'notnull': 1, 'dflt_value': "'medium'", 'pk': 0}
{'cid': 5, 'name': 'created_at', 'type': 'DATETIME', 'notnull': 0, 'dflt_value': 'CURRENT_TIMESTAMP', 'pk': 0}


### 4. FastAPI MCP Server and DB Functions

In [12]:
# ---------- DB functions ----------

def get_customer(cid: int):
    cur = db_conn.cursor()
    cur.execute("SELECT * FROM customers WHERE id = ?", (cid,))
    row = cur.fetchone()
    if row is None:
        return None
    return dict(row)


def list_active_customers():
    cur = db_conn.cursor()
    cur.execute("SELECT * FROM customers WHERE status = 'active'")
    rows = cur.fetchall()
    return [dict(r) for r in rows]


def update_customer(cid: int, fields: dict):
    allowed_fields = {"name", "email", "phone", "status"}

    cur = db_conn.cursor()
    for k, v in fields.items():
        if k not in allowed_fields:
            continue
        cur.execute(f"UPDATE customers SET {k} = ? WHERE id = ?", (v, cid))

    db_conn.commit()
    return True


def create_ticket(cid: int, description: str, severity: str):
    cur = db_conn.cursor()
    ts = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    cur.execute(
        """
        INSERT INTO tickets (customer_id, issue, status, priority, created_at)
        VALUES (?, ?, ?, ?, ?)
        """,
        (cid, description, "open", severity, ts),
    )
    db_conn.commit()
    return True


def get_history(cid: int):
    cur = db_conn.cursor()
    cur.execute("SELECT * FROM tickets WHERE customer_id = ?", (cid,))
    rows = cur.fetchall()
    return [dict(r) for r in rows]


In [22]:
app = FastAPI()

# ---------------- MCP SPEC ------------------

@app.get("/tools/list")
def tools_list():
    return {
        "tools": [
            "get_customer",
            "list_active_customers",
            "update_customer",
            "create_ticket",
            "get_history",
        ]
    }

class ToolInput(BaseModel):
    tool_name: str
    arguments: dict

@app.post("/tools/call")
def tools_call(body: ToolInput):
    tool = body.tool_name
    args = body.arguments

    if tool == "get_customer":
        return {"result": get_customer(args["cid"])}

    if tool == "list_active_customers":
        return {"result": list_active_customers()}

    if tool == "update_customer":
        return {"result": update_customer(args["cid"], args["fields"])}

    if tool == "create_ticket":
        return {
            "result": create_ticket(
                args["cid"], args["description"], args["severity"]
            )
        }

    if tool == "get_history":
        return {"result": get_history(args["cid"])}

    return {"error": "tool not found"}

### 5. Start the Server and Define Global State

In [23]:
def run_server():
    uvicorn.run(app, host="0.0.0.0", port=8000)

threading.Thread(target=run_server, daemon=True).start()

print("MCP Server running at http://127.0.0.1:8000")

class GlobalState(TypedDict):
    messages: List[Any]
    intent: Optional[str]
    cid: Optional[int]
    customer: Optional[dict]
    history: Optional[list]
    report: Optional[dict]

MCP Server running at http://127.0.0.1:8000


### 6. Agent Nodes (Router, Data, and Support)

In [24]:
import requests
from typing import Any, Dict, Optional, List

from langchain_core.messages import HumanMessage, SystemMessage

# Helper function to call the local MCP server
def call_tool(name: str, args: Dict[str, Any]) -> Any:
    """
    Call a tool exposed by the local MCP server and return its 'result' field.
    """
    resp = requests.post(
        "http://127.0.0.1:8000/tools/call",
        json={"tool_name": name, "arguments": args},
        timeout=10,
    )
    resp.raise_for_status()
    data = resp.json()
    if "result" not in data:
        raise RuntimeError(f"No 'result' field in MCP response: {data}")
    return data["result"]


def _get_last_human_message(state: "GlobalState") -> Optional[HumanMessage]:
    """Utility: get the last HumanMessage from the state, if any."""
    for msg in reversed(state["messages"]):
        if isinstance(msg, HumanMessage):
            return msg
    return None


def router_node(state: "GlobalState") -> "GlobalState":
    last_human = _get_last_human_message(state)
    if last_human is None:
        state["messages"].append(
            SystemMessage(content="[Router] No HumanMessage found; skipping routing.")
        )
        return {**state, "intent": None, "cid": None}

    txt = last_human.content.lower()

    # classify scenarios
    if "upgrade" in txt:
        intent = "upgrade_case"
    elif "refund" in txt or "charged" in txt:
        intent = "billing_issue"
    elif "active customers" in txt and "open tickets" in txt:
        intent = "report_case"
    elif "update my email" in txt:
        intent = "multi_op"
    else:
        intent = "basic_lookup"

    # extract cid
    cid: Optional[int] = None
    for w in txt.replace(",", " ").split():
        if w.isdigit():
            cid = int(w)
            break

    state["messages"].append(
        SystemMessage(content=f"[Router] intent={intent}, cid={cid}")
    )

    return {**state, "intent": intent, "cid": cid}


def data_node(state: "GlobalState") -> "GlobalState":
    intent = state.get("intent")
    cid = state.get("cid")

    state["messages"].append(SystemMessage(content="[DataAgent] DB ops begin"))

    # 1. simple cases (lookup customer info and history)
    if intent in ("basic_lookup", "upgrade_case", "billing_issue") and cid:
        cust = call_tool("get_customer", {"cid": cid})
        hist = call_tool("get_history", {"cid": cid})
        state["messages"].append(
            SystemMessage(
                content=f"[DataAgent] fetched customer + {len(hist) if hist else 0} tickets"
            )
        )
        return {**state, "customer": cust, "history": hist}

    # 2. report case (aggregate data from multiple customers)
    if intent == "report_case":
        active_customers = call_tool("list_active_customers", {})
        open_summary: List[Dict[str, Any]] = []

        for c in active_customers or []:
            hist = call_tool("get_history", {"cid": c["id"]})
            open_tickets = [h for h in (hist or []) if h.get("status") == "open"]
            if open_tickets:
                open_summary.append(
                    {"customer": c, "open_issues": open_tickets}
                )

        return {**state, "report": {"open_summary": open_summary}}

    # 3. multi-op (update customer field + history lookup)
    if intent == "multi_op" and cid:
        last_human = _get_last_human_message(state)
        text = last_human.content if last_human else ""
        new_email: Optional[str] = None

        for w in text.split():
            if "@" in w:
                new_email = w.replace(",", "")
                break

        if new_email:
            call_tool(
                "update_customer",
                {"cid": cid, "fields": {"email": new_email}},
            )

        hist = call_tool("get_history", {"cid": cid})
        # Refetch customer data to confirm the update
        cust = call_tool("get_customer", {"cid": cid})
        return {**state, "customer": cust, "history": hist}

    state["messages"].append(SystemMessage(content="[DataAgent] no-op"))
    return state


def support_node(state: "GlobalState") -> "GlobalState":
    prompt: List[Any] = [
        SystemMessage(
            content="You are a helpful support specialist. Be brief and clear."
        )
    ]

    # Provide context to the LLM
    if state.get("customer") is not None:
        prompt.append(SystemMessage(content=f"Customer: {state['customer']}"))

    if state.get("history") is not None:
        prompt.append(SystemMessage(content=f"History: {state['history']}"))

    if state.get("report") is not None:
        prompt.append(SystemMessage(content=f"Report: {state['report']}"))

    last_human = _get_last_human_message(state)
    if last_human:
        prompt.append(last_human)

    # Invoke the OpenRouter LLM
    reply = support_llm.invoke(prompt)  # assumes support_llm is defined elsewhere

    state["messages"].append(reply)
    state["messages"].append(SystemMessage(content="[SupportAgent] done"))

    return state

### 7. LangGraph Definition and Compilation

In [25]:
# Build the StateGraph with the updated nodes
graph = StateGraph(GlobalState)

# Register nodes
graph.add_node("router", router_node)
graph.add_node("data", data_node)
graph.add_node("support", support_node)

# Define edges (flow)
graph.add_edge(START, "router")      # Start → Router
graph.add_edge("router", "data")     # Router → DataAgent
graph.add_edge("data", "support")    # DataAgent → SupportAgent
graph.add_edge("support", END)       # SupportAgent → End

# Compile into an executable app
app = graph.compile()

### 8. Execution Function

In [26]:
def run_query(user_query: str):
    # Pretty print header
    print("\n" + "=" * 50)
    print(f"USER QUERY: {user_query}")
    print("=" * 50)

    # Initial GlobalState
    initial_state = {
        "messages": [HumanMessage(content=user_query)],
        "intent": None,
        "cid": None,
        "customer": None,
        "history": None,
        "report": None,
    }

    # Run the LangGraph workflow
    final_state = app.invoke(initial_state)

    # Print all messages in order
    for msg in final_state.get("messages", []):
        role = msg.__class__.__name__
        print(f"{role}: {msg.content}")

    return final_state


### 9. Test Queries

In [27]:
def run_all_tests():
    tests = [
        "Get customer information for ID 5",
        "I'm customer 5 and need help upgrading my account",
        "Show me all active customers who have open tickets",
        "I'm customer 5. I've been charged twice, please refund immediately!",
        "I'm customer 5, update my email to new@test.com and show my ticket history",
    ]

    for i, q in enumerate(tests, start=1):
        print(f"\n\n##### Test {i} #####")
        run_query(q)

if __name__ == "__main__":
    run_all_tests()



##### Test 1 #####

USER QUERY: Get customer information for ID 5
INFO:     127.0.0.1:57000 - "POST /tools/call HTTP/1.1" 200 OK
INFO:     127.0.0.1:57014 - "POST /tools/call HTTP/1.1" 200 OK
HumanMessage: Get customer information for ID 5
SystemMessage: [Router] intent=basic_lookup, cid=5
SystemMessage: [DataAgent] DB ops begin
SystemMessage: [DataAgent] fetched customer + 11 tickets
AIMessage: **Customer ID: 5**  
- **Name**: Charlie Brown  
- **Email**: new@email.com  
- **Phone**: +1-555-0105  
- **Status**: active  
- **Created**: 2025-12-01 23:18:40  
- **Updated**: 2025-12-01 23:42:33  

**Open Tickets (3 total open)**:  
- ID 8: Email notifications not being received (medium)  
- ID 35: Email notifications not being received (medium)  
- ID 55: Request to upgrade account (medium)  
- ID 63: Email notifications not being received (medium)  
- ID 91: Email notifications not being received (medium)  
- ID 116: Email notifications not being received (medium)  

**Resolved Tickets