In [None]:
# Step 1: Import Necessary Libraries

import os
import json
from dotenv import load_dotenv
load_dotenv()

from typing_extensions import Literal, TypedDict
from pydantic import BaseModel, Field
from langchain_core.messages import HumanMessage, SystemMessage
from langgraph.graph import StateGraph, START, END
from langchain_openai import ChatOpenAI

# OpenAI key from the environment / .env (loaded above). Note: it is not passed
# to ChatOpenAI below -- ChatOpenAI reads OPENAI_API_KEY from the environment itself.
openai_api_key = os.getenv("OPENAI_API_KEY")

# Step 2: Initialize the OpenAI chat model (gpt-4o).
# temperature=0 keeps the EXPLORATION/FORECASTING routing as stable as possible
# for a given question, matching the planner configuration used in the next cell.
llm = ChatOpenAI(model="gpt-4o", temperature=0)

#----------------------------------------------- Step 3: Create Router Schema---------------------------------------------------------------
#-----------------------------------------------1. Orchestrator Agent Routing---------------------------------------------------------------
# Labels the one-shot router may emit; one per downstream agent node.
RouteStep = Literal["EXPLORATION", "FORECASTING"]


class Route(BaseModel):
    # Structured-output schema for the router. A docstring is deliberately
    # omitted: pydantic would put it into the JSON schema sent to the LLM.
    step: RouteStep = Field(description="Which agent to run next")


# Chat model bound to the Route schema; its .invoke(...) result exposes `.step`.
router = llm.with_structured_output(Route)

# ----------------------------------------------Step 5: Tools--------------------------------------------------------------------------------
# ------------------------------------- Tool 1 : MYSQL Query Tool ---------------------------------------------------------------------------
def mysql_query_tool():
    """Open a LangChain ``SQLDatabase`` over the project's MySQL database.

    The connection URL uses the ``mysql+mysqlconnector`` driver and is built
    from the ``mysql_user``, ``mysql_password``, ``mysql_host`` and
    ``mysql_db`` settings.

    Returns:
        SQLDatabase: wrapper around the SQLAlchemy engine.
    """
    from urllib.parse import quote_plus

    # NOTE(review): create_engine, SQLDatabase and the mysql_* settings are not
    # defined anywhere in this notebook -- presumably sqlalchemy,
    # langchain_community.utilities and values loaded from .env. TODO add them.

    # URL-encode the credentials: characters such as '@', ':' or '/' in a raw
    # password would otherwise be parsed as URL delimiters and break the URL.
    user = quote_plus(str(mysql_user))
    password = quote_plus(str(mysql_password))
    engine = create_engine(
        f"mysql+mysqlconnector://{user}:{password}@{mysql_host}/{mysql_db}"
    )
    return SQLDatabase(engine)


#--------------------------------------------------Step 6: Create Nodes--------------------------------------------------------------------
# -------------------------------------Node 1 : Orchestrator Agent (Analyst Agent)----------------------------------------------------------

def orchestrator_router(state: "State"):
    """LangGraph node: classify the user request as EXPLORATION or FORECASTING.

    Sends a fixed routing instruction plus ``state["input"]`` to ``router``
    (the LLM bound to the ``Route`` schema) and stores the chosen label.

    Returns:
        dict: ``{"decision": <Route.step>}`` -- a partial state update that
        ``route_decision`` reads to pick the next node.
    """
    # The annotation is quoted because ``State`` is defined further down in
    # this cell; an unquoted name is evaluated when the ``def`` executes and
    # raises NameError on a fresh kernel.
    decision = router.invoke([
        SystemMessage(content=(
            "Route the user request to:\n"
            "- EXPLORATION: DB questions, stats, analysis, explanations, RAG info lookups\n"
            "- FORECASTING: prediction/forecast/future counts/next days values\n"
            "Return ONLY EXPLORATION or FORECASTING."
        )),
        HumanMessage(content=state["input"]),
    ])
    return {"decision": decision.step}

# Conditional edge function to route to the appropriate node
def route_decision(state: "State"):
    """Conditional-edge function: map the router's label to a node name.

    Returns:
        str: ``"exploration_agent"`` when ``state["decision"]`` is
        ``"EXPLORATION"``, otherwise ``"forecasting_agent"``.
    """
    # Quoted annotation: ``State`` is defined later in this cell, and an
    # unquoted annotation is evaluated at def time (NameError on a fresh kernel).
    return "exploration_agent" if state["decision"] == "EXPLORATION" else "forecasting_agent"

# ------------------------------------- Node 2 : Exploration Agent ---------------------------------------------------------------------------
def exploration_agent(state: "State"):
    """LangGraph node: answer a data question with a SQL agent over MySQL.

    Builds a database handle via ``mysql_query_tool``, wraps it in a SQL
    toolkit, runs a zero-shot ReAct SQL agent on ``state["input"]`` and
    returns the agent's answer.

    Returns:
        dict: ``{"output": <answer text>}`` -- a partial state update, in the
        same shape the forecasting node returns.
    """
    # NOTE(review): SQLDatabaseToolkit, create_sql_agent and AgentType are not
    # imported anywhere in this notebook -- presumably from
    # langchain_community.agent_toolkits / langchain.agents. TODO add imports.
    # Quoted annotation: ``State`` is defined later in this cell.
    db = mysql_query_tool()
    toolkit = SQLDatabaseToolkit(db=db, llm=llm)
    agent = create_sql_agent(
        llm=llm,
        toolkit=toolkit,
        verbose=True,
        agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
        # `prefix` is a named parameter of create_sql_agent. The previous
        # `agent_kwargs={"prefix": ...}` is not a parameter, so the custom
        # prefix never reached the prompt.
        # NOTE(review): this replaces the library's default SQL prefix, which
        # also tells the model not to issue DML statements -- confirm intended.
        prefix=(
            "You are a helpful data assistant with access to an SQL database.\n"
            "Answer all questions using the database.\n"
            "If you do not know the answer, say 'I don't know'."
        ),
        # Parsing-error recovery is an AgentExecutor option, configured here once.
        agent_executor_kwargs={"handle_parsing_errors": True},
    )
    result = agent.invoke(state["input"])
    # The executor returns a dict with the answer under "output"; previously
    # the node returned None and the answer was discarded.
    answer = result.get("output", result) if isinstance(result, dict) else result
    return {"output": str(answer)}

# ------------------------------------- Node 3 : Forecasting Agent ---------------------------------------------------------------------------

def forecasting_agent(state: "State"):
    """LangGraph node: request a forecast and return it as JSON text.

    Sends ``state["input"]`` with a fixed 7-day horizon to
    ``sagemaker_forecast_tool`` and serializes the request together with the
    tool's result.

    Returns:
        dict: ``{"output": <JSON string>}`` with keys ``agent``, ``tool``,
        ``payload`` and ``result``.
    """
    # NOTE(review): sagemaker_forecast_tool is not defined anywhere in this
    # notebook -- TODO define or import it before running this cell.
    # Quoted annotation: ``State`` is defined later in this cell.
    payload = {"query": state["input"], "horizon": 7}
    fc = sagemaker_forecast_tool(payload)
    report = {"agent": "forecasting", "tool": "sagemaker", "payload": payload, "result": fc}
    # default=str lets values json cannot encode natively be stringified.
    return {"output": json.dumps(report, default=str)}

#--------------------------------------------Step 7: Build a complete graph--------------------------------------------------------------------
# State schema
# Graph state for the one-shot router graph:
#   input    -- the user's question
#   decision -- label written by orchestrator_router
#   output   -- result text written by the chosen agent node
State = TypedDict(
    "State",
    {
        "input": str,
        "decision": str,
        "output": str,
    },
)

# Build graph
builder = StateGraph(State)

# Register each node under the same name as the function implementing it.
for node_name, node_fn in (
    ("orchestrator_router", orchestrator_router),
    ("exploration_agent", exploration_agent),
    ("forecasting_agent", forecasting_agent),
):
    builder.add_node(node_name, node_fn)

# START -> router -> (exploration | forecasting) -> END
builder.add_edge(START, "orchestrator_router")
builder.add_conditional_edges(
    "orchestrator_router",
    route_decision,
    {target: target for target in ("exploration_agent", "forecasting_agent")},
)
for worker in ("exploration_agent", "forecasting_agent"):
    builder.add_edge(worker, END)

app = builder.compile()

In [3]:
import os, json
from dotenv import load_dotenv
load_dotenv()

from typing_extensions import TypedDict, Literal, Annotated
from pydantic import BaseModel, Field

from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage, SystemMessage, AIMessage

from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages


# Planner model; temperature 0 minimises sampling randomness in its decisions.
llm = ChatOpenAI(temperature=0, model="gpt-4o")

# Graph state for the iterative orchestrator:
#   messages -- chat transcript; add_messages merges node updates into it
#   work     -- scratchpad of pending requests and collected agent results
#   steps    -- number of orchestrator (planner) turns taken so far
State = TypedDict(
    "State",
    {
        "messages": Annotated[list, add_messages],
        "work": dict,
        "steps": int,
    },
)


# Make the forecasting payload a strict, typed model (no raw dict).
# No docstring on purpose: pydantic would add it to the JSON schema the LLM sees.
class ForecastPayload(BaseModel):
    horizon_days: Annotated[int, Field(description="How many future days to forecast")]


# Structured decision returned by the planner on every orchestrator turn.
# Only the field matching `action` is expected to be filled; the rest stay None.
# (No docstring: pydantic would inject it into the schema sent to the LLM.)
class OrchestratorDecision(BaseModel):
    action: Annotated[
        Literal["CALL_FORECASTING", "CALL_EXPLORATION", "FINISH"],
        Field(description="Next action for the orchestrator"),
    ]
    reasoning: Annotated[str, Field(description="1-2 lines why you chose this action")]
    forecasting_payload: ForecastPayload | None = None
    exploration_query: str | None = None
    final_answer: str | None = None


# LLM whose replies are parsed into OrchestratorDecision instances.
planner = llm.with_structured_output(OrchestratorDecision)


def forecasting_agent(payload: "ForecastPayload") -> dict:
    """Stub forecasting agent: return canned daily ticket-count forecasts.

    Args:
        payload: the forecast request; only ``payload.horizon_days`` and
            ``payload.model_dump()`` are used.

    Returns:
        dict with ``agent``, ``payload_received`` (the payload as a dict) and
        ``forecast`` -- the first ``horizon_days`` canned entries. The stub only
        has two days of data, so larger horizons yield those two entries, and a
        horizon <= 0 yields an empty list.
    """
    # Clamp at 0: a negative value would otherwise act as a slice-from-the-end
    # index (e.g. -1 -> day 1 only) instead of meaning "no days".
    horizon_days = max(int(payload.horizon_days), 0)
    canned = [
        {"day": 1, "date": "2026-01-30", "forecast_ticket_count": 31},
        {"day": 2, "date": "2026-01-31", "forecast_ticket_count": 35},
    ]
    return {
        "agent": "forecasting_agent",
        "payload_received": payload.model_dump(),
        "forecast": canned[:horizon_days],
    }


def exploration_agent(query: str) -> dict:
    """Stub exploration agent: echo the query with canned historical counts.

    Args:
        query: the DB question the orchestrator asked; returned unchanged.

    Returns:
        dict with ``agent``, ``query_received`` and
        ``last_2_days_ticket_count`` (two hard-coded date/count rows).
    """
    canned_rows = (("2026-01-28", 29), ("2026-01-29", 34))
    recent_counts = [{"date": day, "ticket_count": count} for day, count in canned_rows]

    reply = {"agent": "exploration_agent"}
    reply["query_received"] = query
    reply["last_2_days_ticket_count"] = recent_counts
    return reply


# Planner instructions. The "Max 5 steps" line mirrors the hard step limit
# enforced in orchestrator_node; the forecast_result / db_result keys are the
# ones the call_* nodes write into state["work"].
_ORCH_PROMPT_TEXT = """
You are an orchestrator agent that plans step-by-step.

You can call two agents:
1) Forecasting Agent (CALL_FORECASTING):
   Use if user asks for forecast/predict/future/next days.
   Provide forecasting_payload with:
   - horizon_days (int)

2) Exploration Agent (CALL_EXPLORATION):
   Use if user asks for past/current counts, DB stats, summaries.
   Provide exploration_query as a short DB question.

Use state.work to see what you already collected:
- If forecast_result already exists, do NOT call forecasting again.
- If db_result already exists, do NOT call exploration again.

Stop when you have enough and return FINISH with final_answer.
Max 5 steps.
"""

ORCH_SYSTEM = SystemMessage(content=_ORCH_PROMPT_TEXT)


def orchestrator_node(state: State):
    """Planner node: ask the LLM for the next action and record it in ``work``.

    Feeds ``ORCH_SYSTEM``, the transcript in ``state["messages"]`` and a JSON
    dump of ``state["work"]`` to ``planner`` and turns its
    ``OrchestratorDecision`` into a state update:

    * ``FINISH``           -> append the decision trace and the final answer.
    * ``CALL_FORECASTING`` -> set ``work["next_forecasting_payload"]``.
    * ``CALL_EXPLORATION`` -> set ``work["next_exploration_query"]``.

    ``next_step_router`` reads those ``next_*`` keys to choose the next node.
    Every planner turn increments ``steps``; once ``steps`` reaches 5 the
    planner is skipped and a stop message is returned with ``work`` unchanged.
    """
    scratch = state.get("work", {})
    steps = state.get("steps", 0)

    # Hard cap on planner turns so a confused model cannot loop forever.
    if steps >= 5:
        halt = AIMessage(content="Stopped after 5 steps to avoid looping.")
        return {"messages": [halt], "work": scratch, "steps": steps}

    # Show the planner what has been collected so far as a trailing system note.
    snapshot = json.dumps(scratch, indent=2, default=str)
    work_note = SystemMessage(content="Current state.work JSON:\n" + snapshot)
    decision = planner.invoke([ORCH_SYSTEM] + state["messages"] + [work_note])

    requested = decision.forecasting_payload
    requested_dump = requested.model_dump() if requested else None
    trace = AIMessage(
        content=(
            "DEBUG Orchestrator Decision:\n"
            f"action={decision.action}\n"
            f"reasoning={decision.reasoning}\n"
            f"forecasting_payload={requested_dump}\n"
            f"exploration_query={decision.exploration_query}\n"
        )
    )

    if decision.action == "FINISH":
        closing = AIMessage(content=decision.final_answer or "Finished.")
        return {"messages": [trace, closing], "work": scratch, "steps": steps + 1}

    pending = {**scratch}
    if decision.action == "CALL_FORECASTING":
        # Default to a 2-day horizon if the model omitted the payload.
        pending["next_forecasting_payload"] = (
            requested.model_dump() if requested else {"horizon_days": 2}
        )
    elif decision.action == "CALL_EXPLORATION":
        pending["next_exploration_query"] = (
            decision.exploration_query or "Get ticket count for last 2 days"
        )

    return {"messages": [trace], "work": pending, "steps": steps + 1}


def call_forecasting_node(state: State):
    """Run the forecasting agent on the payload the planner queued in ``work``.

    Consumes ``work["next_forecasting_payload"]`` (default ``{"horizon_days": 2}``),
    stores the agent's reply under ``work["forecast_result"]`` and echoes it as
    a DEBUG message. ``steps`` is passed through unchanged; only the
    orchestrator increments it.
    """
    scratch = {**state.get("work", {})}
    request = scratch.pop("next_forecasting_payload", {"horizon_days": 2})
    reply = forecasting_agent(ForecastPayload(**request))
    scratch["forecast_result"] = reply

    note = AIMessage(content="DEBUG Forecasting Agent returned:\n" + json.dumps(reply, indent=2))
    return {"messages": [note], "work": scratch, "steps": state.get("steps", 0)}


def call_exploration_node(state: State):
    """Run the exploration agent on the query the planner queued in ``work``.

    Consumes ``work["next_exploration_query"]`` (default: a last-2-days
    ticket-count question), stores the agent's reply under
    ``work["db_result"]`` and echoes it as a DEBUG message. ``steps`` is passed
    through unchanged; only the orchestrator increments it.
    """
    scratch = {**state.get("work", {})}
    question = scratch.pop("next_exploration_query", "Get ticket count for last 2 days")
    reply = exploration_agent(question)
    scratch["db_result"] = reply

    note = AIMessage(content="DEBUG Exploration Agent returned:\n" + json.dumps(reply, indent=2))
    return {"messages": [note], "work": scratch, "steps": state.get("steps", 0)}


def next_step_router(state: State):
    """Pick the node after the orchestrator from the pending requests in ``work``.

    A queued forecasting payload takes priority over a queued exploration
    query; with neither pending the graph ends.
    """
    pending = state.get("work", {})
    # Checked in priority order.
    dispatch = (
        ("next_forecasting_payload", "call_forecasting"),
        ("next_exploration_query", "call_exploration"),
    )
    for key, target in dispatch:
        if key in pending:
            return target
    return END


builder = StateGraph(State)

# Nodes: the planner plus one wrapper per sub-agent.
for node_name, node_fn in (
    ("orchestrator", orchestrator_node),
    ("call_forecasting", call_forecasting_node),
    ("call_exploration", call_exploration_node),
):
    builder.add_node(node_name, node_fn)

# Every run starts at the planner; next_step_router chooses the next hop.
builder.add_edge(START, "orchestrator")
builder.add_conditional_edges(
    "orchestrator",
    next_step_router,
    {"call_forecasting": "call_forecasting", "call_exploration": "call_exploration", END: END},
)
# After a sub-agent runs, control always returns to the planner.
for worker in ("call_forecasting", "call_exploration"):
    builder.add_edge(worker, "orchestrator")

app = builder.compile()


if __name__ == "__main__":
    # Demo question that needs both sub-agents: a forecast and historical counts.
    user_question = "show me the forecast for next two days and show me the ticket count for last 2 days"
    init_state: State = {
        "messages": [HumanMessage(content=user_question)],
        "work": {},
        "steps": 0,
    }
    result = app.invoke(init_state)

    # Full transcript accumulated by the graph, in order.
    print("\n================ FINAL OUTPUT ================\n")
    for m in result["messages"]:
        print(f"{type(m).__name__}:\n{m.content}", end="\n\n")

    # Scratchpad with the collected agent results.
    print("\n================ FINAL WORK JSON ================\n")
    print(json.dumps(result["work"], indent=2, default=str))




HumanMessage:
show me the forecast for next two days and show me the ticket count for last 2 days

AIMessage:
DEBUG Orchestrator Decision:
action=CALL_FORECASTING
reasoning=The user requested a forecast for the next two days, so I will call the forecasting agent first.
forecasting_payload={'horizon_days': 2}
exploration_query=None


AIMessage:
DEBUG Forecasting Agent returned:
{
  "agent": "forecasting_agent",
  "payload_received": {
    "horizon_days": 2
  },
  "forecast": [
    {
      "day": 1,
      "date": "2026-01-30",
      "forecast_ticket_count": 31
    },
    {
      "day": 2,
      "date": "2026-01-31",
      "forecast_ticket_count": 35
    }
  ]
}

AIMessage:
DEBUG Orchestrator Decision:
action=CALL_EXPLORATION
reasoning=I have obtained the forecast for the next two days. Now, I need to get the ticket count for the last two days as requested by the user.
forecasting_payload=None
exploration_query=ticket count for last 2 days.


AIMessage:
DEBUG Exploration Agent returned: