<a href="https://colab.research.google.com/github/HeyMahdy/ai-agents-playground/blob/main/SupervisorAgentPattern.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install/upgrade the LangGraph and LangChain/OpenAI packages used below (quiet output).
!pip -q install -U langgraph langchain langchain-openai openai

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.7/43.7 kB[0m [31m2.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m153.3/153.3 kB[0m [31m7.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m75.0/75.0 kB[0m [31m5.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m948.4/948.4 kB[0m [31m25.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m447.5/447.5 kB[0m [31m27.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.9/43.9 kB[0m [31m2.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.8/56.8 kB[0m [31m3.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m216.7/216.7 kB[0m [31m13.3 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
import os
import getpass
from typing import TypedDict, Annotated, Sequence, Dict, List, Any

from langgraph.graph import StateGraph, END, add_messages
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain.chat_models import init_chat_model

# Set OpenAI API key (prompts once if not present)
def _ensure_openai_api_key() -> None:
    """Prompt for OPENAI_API_KEY unless it is already set to a non-empty value."""
    if os.environ.get("OPENAI_API_KEY"):
        return
    os.environ["OPENAI_API_KEY"] = getpass.getpass("Enter OPENAI_API_KEY: ")


_ensure_openai_api_key()

# ----- State -----
class MessageState(TypedDict):
    """Base graph state holding the running conversation."""

    # `add_messages` is the reducer: message lists returned by nodes are merged
    # into this channel rather than replacing it.
    messages: Annotated[Sequence[BaseMessage], add_messages]

class SupervisorState(MessageState):
    """Shared state for the supervisor/worker graph.

    Fields without a reducer are overwritten by whatever value a node returns.
    """

    # Next route chosen by the supervisor: a worker node name or "Done".
    supervisor_decision: str
    # Rendered into the supervisor prompt; no node in this file writes it.
    task_assignments: Dict[str, List[str]]
    # Result string per worker, keyed by the worker's node name.
    agent_outputs: Dict[str, Any]
    # Data written by Schedule_Production_Agent (e.g. "capacity").
    Schedule_Production_Agent_data: Dict[str, Any]
    # Data written by query_agent (e.g. "last_operation"); not shown in the supervisor prompt.
    Query_Agent_data: Dict[str, Any]
    # Data written by Supply_Chain_Agent (e.g. "status").
    Supply_Chain_Agent_data: Dict[str, Any]
    # Rendered as "Sensor Data" in the prompt; no node in this file writes it.
    Monitor_Sensors_Agent_data: Dict[str, Any]
    # Rendered into the prompt; initialised to "initial" and not updated by the nodes here.
    workflow_stage: str
    # Incremented by one each time a worker node runs.
    iterationCount: int
    # Iteration budget; run_query initialises it to 10.
    max_iteration: int
    # Initialised to "" and not written by any node shown here.
    final_output: str

# ----- LLM + Prompt -----
# Chat model the supervisor uses to pick the next worker.
llm = init_chat_model("gpt-4o-mini", model_provider="openai")

# System prompt: roster of workers, a snapshot of the shared state, and the
# exact tokens the supervisor may answer with. Placeholders are filled by
# supervisor_agent when the chain is invoked.
_SUPERVISOR_SYSTEM_PROMPT = """You are a Supervisor managing a team of agents:

1. Schedule_Production_Agent - handles product capacity calculation and schedule CSV questions
2. query_agent - performs database CRUD on the main database
3. Supply_Chain_Agent - checks for supply chain issues

Current State:
- Current Agent: {current_agent}
- Workflow Stage: {workflow_stage}
- Iteration: {iterationCount}
- Agent Outputs: {agent_outputs}
- Task Assignments: {task_assignments}
- Production Data: {Schedule_Production_Agent_data}
- Sensor Data: {Monitor_Sensors_Agent_data}
- Supply Chain Data: {Supply_Chain_Agent_data}

Instructions:
- Based on the current state and conversation, decide which agent should work next.
- Your selection must be EXACTLY one of:
  Schedule_Production_Agent
  Supply_Chain_Agent
  query_agent
- If all tasks are complete, reply exactly: Done
"""

prompt_template = ChatPromptTemplate.from_messages(
    [
        ("system", _SUPERVISOR_SYSTEM_PROMPT),
        ("human", "{task}"),
    ]
)

def create_supervisor_chain():
    """Build the supervisor pipeline: prompt -> chat model -> plain-text reply."""
    model_stage = prompt_template | llm
    return model_stage | StrOutputParser()

def choose_next_agent_from_text(text: str) -> str:
    """Map the supervisor LLM's free-text reply onto a graph route.

    Returns one of "Schedule_Production_Agent", "Supply_Chain_Agent",
    "query_agent" or "Done".

    Agent names are matched case-insensitively as substrings, in that
    priority order. "Done" is also recognised when the model wraps it in
    quotes/backticks/asterisks or adds trailing punctuation (e.g. "Done.",
    "`Done`"). It is likewise recognised when the reply is a sentence that
    contains the word "done" but names no agent. Anything else falls back to
    "Schedule_Production_Agent".
    """
    import re

    # Strip decoration LLMs commonly add around a one-word answer; without this
    # "Done." failed the exact match and was routed back to a worker.
    normalized = text.strip().strip("`'\".!*:").strip().lower()
    if normalized == "done":
        return "Done"

    for token, route in (
        ("schedule_production_agent", "Schedule_Production_Agent"),
        ("supply_chain_agent", "Supply_Chain_Agent"),
        ("query_agent", "query_agent"),
    ):
        if token in normalized:
            return route

    # Checked after the agent names: a reply that names an agent is an explicit
    # selection. A sentence such as "All tasks are done." means finish.
    if re.search(r"\bdone\b", normalized):
        return "Done"
    return "Schedule_Production_Agent"

# ----- Nodes -----
def supervisor_agent(state: SupervisorState) -> Dict[str, Any]:
    """Ask the LLM which worker should run next and record its decision.

    The most recent HumanMessage is sent as the task, and the rest of the
    shared state is rendered into the system prompt. The raw reply is mapped
    onto a route by ``choose_next_agent_from_text``. Once ``iterationCount``
    reaches ``max_iteration``, the supervisor returns "Done" without calling
    the LLM.

    Returns:
        A partial state update containing ``supervisor_decision`` and a single
        AIMessage recording that decision.
    """
    max_iteration = state.get("max_iteration")
    if max_iteration is not None and state["iterationCount"] >= max_iteration:
        # Enforce the iteration budget. Without this guard, an LLM that never
        # answers "Done" keeps the loop going until LangGraph's recursion limit
        # aborts the whole run with an error.
        next_agent = "Done"
    else:
        last_human = next(
            (m.content for m in reversed(state["messages"]) if isinstance(m, HumanMessage)),
            "",
        )
        chain = create_supervisor_chain()
        decision_text = chain.invoke({
            "current_agent": state.get("supervisor_decision", "") or "None",
            "workflow_stage": state["workflow_stage"],
            "iterationCount": state["iterationCount"],
            "agent_outputs": state["agent_outputs"],
            "task_assignments": state["task_assignments"],
            "Schedule_Production_Agent_data": state["Schedule_Production_Agent_data"],
            "Monitor_Sensors_Agent_data": state.get("Monitor_Sensors_Agent_data", {}),
            "Supply_Chain_Agent_data": state["Supply_Chain_Agent_data"],
            "task": last_human or "Decide next agent."
        })
        next_agent = choose_next_agent_from_text(decision_text)

    return {
        "supervisor_decision": next_agent,
        "messages": [AIMessage(content=f"supervisor_decision: {next_agent}")]
    }

def Schedule_Production_Agent(state: SupervisorState) -> Dict[str, Any]:
    """Stub worker that reports a fixed production capacity of 100 units.

    Builds fresh copies of ``agent_outputs`` and
    ``Schedule_Production_Agent_data`` (the incoming dicts are not mutated),
    posts the result as an AIMessage, and increments ``iterationCount``.
    """
    capacity_msg = "Production capacity is 100 units"
    outputs = {**state["agent_outputs"], "Schedule_Production_Agent": capacity_msg}
    production_data = {**state["Schedule_Production_Agent_data"], "capacity": 100}
    next_count = state["iterationCount"] + 1
    return {
        "agent_outputs": outputs,
        "Schedule_Production_Agent_data": production_data,
        "messages": [AIMessage(content=capacity_msg)],
        "iterationCount": next_count,
    }

def Supply_Chain_Agent(state: SupervisorState) -> Dict[str, Any]:
    """Stub worker that reports the supply chain as healthy.

    Builds fresh copies of ``agent_outputs`` and ``Supply_Chain_Agent_data``
    (the incoming dicts are not mutated), posts the result as an AIMessage,
    and increments ``iterationCount``.
    """
    status_msg = "Supply chain is good"
    outputs = {**state["agent_outputs"], "Supply_Chain_Agent": status_msg}
    chain_data = {**state["Supply_Chain_Agent_data"], "status": "good"}
    next_count = state["iterationCount"] + 1
    return {
        "agent_outputs": outputs,
        "Supply_Chain_Agent_data": chain_data,
        "messages": [AIMessage(content=status_msg)],
        "iterationCount": next_count,
    }

def query_agent(state: SupervisorState) -> Dict[str, Any]:
    """Stub worker that pretends to run a CRUD operation on the database.

    Builds fresh copies of ``agent_outputs`` and ``Query_Agent_data`` (the
    incoming dicts are not mutated), posts the result as an AIMessage, and
    increments ``iterationCount``.
    """
    crud_msg = "crud operation done"
    outputs = {**state["agent_outputs"], "query_agent": crud_msg}
    query_data = {**state["Query_Agent_data"], "last_operation": "crud"}
    next_count = state["iterationCount"] + 1
    return {
        "agent_outputs": outputs,
        "Query_Agent_data": query_data,
        "messages": [AIMessage(content=crud_msg)],
        "iterationCount": next_count,
    }

def router(state: SupervisorState) -> Dict[str, Any]:
    """Pass-through node that logs which route the supervisor picked."""
    target = state["supervisor_decision"]
    note = AIMessage(content=f"routing to: {target}")
    return {"messages": [note]}

# ----- Graph -----
workflow = StateGraph(SupervisorState)

# Register the supervisor, the three workers, and the logging router node.
for _node_name, _node_fn in (
    ("supervisor_agent", supervisor_agent),
    ("Schedule_Production_Agent", Schedule_Production_Agent),
    ("Supply_Chain_Agent", Supply_Chain_Agent),
    ("query_agent", query_agent),
    ("router", router),
):
    workflow.add_node(_node_name, _node_fn)

# The supervisor always hands off to the router, which then branches.
workflow.add_edge("supervisor_agent", "router")

def route_decision(state: SupervisorState) -> str:
    """Return the route key stored by the supervisor for the conditional edge."""
    decision = state["supervisor_decision"]
    return decision

# Route keys produced by route_decision -> destination node ("Done" ends the run).
_route_map = {
    "Schedule_Production_Agent": "Schedule_Production_Agent",
    "Supply_Chain_Agent": "Supply_Chain_Agent",
    "query_agent": "query_agent",
    "Done": END,
}
workflow.add_conditional_edges("router", route_decision, _route_map)

# Every worker returns control to the supervisor when it finishes.
for _worker in ("Schedule_Production_Agent", "Supply_Chain_Agent", "query_agent"):
    workflow.add_edge(_worker, "supervisor_agent")

workflow.set_entry_point("supervisor_agent")
graph = workflow.compile()

# ----- Demo run -----

In [None]:
def run_query(user_input: str, max_iteration: int = 10):
    """Run the supervisor graph on one user request and print a trace.

    Args:
        user_input: Request passed to the graph as the initial HumanMessage.
        max_iteration: Iteration budget stored in the state (default 10,
            matching the previous hard-coded value).

    Returns:
        The final state dict returned by ``graph.invoke``.
    """
    initial_state: SupervisorState = {
        "messages": [HumanMessage(content=user_input)],
        "supervisor_decision": "",
        "task_assignments": {},
        "agent_outputs": {},
        "Schedule_Production_Agent_data": {},
        "Query_Agent_data": {},
        "Supply_Chain_Agent_data": {},
        "Monitor_Sensors_Agent_data": {},
        "workflow_stage": "initial",
        "iterationCount": 0,
        "max_iteration": max_iteration,
        "final_output": ""
    }

    # Each worker round costs three graph steps (supervisor -> router -> worker),
    # plus a final supervisor -> router pass to reach END. A full 10-iteration
    # run therefore needs 32 steps, more than LangGraph's default
    # recursion_limit of 25. Size the limit from the budget instead, and never
    # go below the default.
    recursion_limit = max(25, 3 * max_iteration + 5)
    result = graph.invoke(initial_state, config={"recursion_limit": recursion_limit})

    # Print conversation trace
    for msg in result["messages"]:
        print(f"{msg.type.upper()}: {msg.content}")

    # Print agent outputs for debugging
    print("\nFinal Agent Outputs:", result["agent_outputs"])
    return result


In [None]:
# Demo: a single request. The captured trace shows one routing to
# Schedule_Production_Agent, followed by a "Done" decision.
run_query("Calculate production capacity")

HUMAN: Calculate production capacity
AI: supervisor_decision: Schedule_Production_Agent
AI: routing to: Schedule_Production_Agent
AI: Production capacity is 100 units
AI: supervisor_decision: Done
AI: routing to: Done

Final Agent Outputs: {'Schedule_Production_Agent': 'Production capacity is 100 units'}


{'messages': [HumanMessage(content='Calculate production capacity', additional_kwargs={}, response_metadata={}, id='20efb3d6-2472-4269-b852-38b8171937d5'),
  AIMessage(content='supervisor_decision: Schedule_Production_Agent', additional_kwargs={}, response_metadata={}, id='7be56328-8599-452a-b0ec-19905a5bfc47'),
  AIMessage(content='routing to: Schedule_Production_Agent', additional_kwargs={}, response_metadata={}, id='2ab73228-3f83-4334-bb3a-cb5160ddb73c'),
  AIMessage(content='Production capacity is 100 units', additional_kwargs={}, response_metadata={}, id='3be640c0-08c6-49de-b2c6-9a576305bca9'),
  AIMessage(content='supervisor_decision: Done', additional_kwargs={}, response_metadata={}, id='c2646d97-e4a8-4df3-bc7d-2f55da0f43fc'),
  AIMessage(content='routing to: Done', additional_kwargs={}, response_metadata={}, id='21c3149a-0938-4706-ae38-c6f4b61fdd3a')],
 'supervisor_decision': 'Done',
 'task_assignments': {},
 'agent_outputs': {'Schedule_Production_Agent': 'Production capacity i