# In [None]:
import os, json
from datetime import datetime
from zoneinfo import ZoneInfo
from dotenv import load_dotenv
load_dotenv()

from typing_extensions import TypedDict, Literal, Annotated
from pydantic import BaseModel, Field

from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage, SystemMessage, AIMessage

from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages


# ---------------------------
# 1) LLM
# ---------------------------
# Single chat model instance; both `planner` and `exploration_router` below are
# derived from it via with_structured_output().
llm = ChatOpenAI(model="gpt-4o", temperature=0)

# Time zone in which today_ny_str() computes "today".
NY_TZ = ZoneInfo("America/New_York")

def today_ny_str() -> str:
    """Return the current calendar date in NY_TZ as ISO text, e.g. "2026-01-29"."""
    now_in_ny = datetime.now(tz=NY_TZ)
    return now_in_ny.date().isoformat()


# ---------------------------
# 2) Graph State
# ---------------------------
class State(TypedDict):
    """Shared graph state read and returned by every node in this file."""
    # Conversation so far; annotated with langgraph's add_messages.
    messages: Annotated[list, add_messages]
    # Scratchpad: pending "next_*" requests written by the orchestrator and
    # "*_result" values written by the agent-call nodes.
    work: dict
    # Count of orchestrator decisions made; orchestrator_node stops at 5.
    steps: int


# ---------------------------
# 3) Orchestrator decision schema
# ---------------------------
# Request handed to forecasting_agent(): only the forecast horizon.
class ForecastPayload(BaseModel):
    horizon_days: int = Field(description="How many future days to forecast")


# Structured output of the planner: one decision per orchestrator step.
class OrchestratorDecision(BaseModel):
    # Which branch orchestrator_node takes next.
    action: Literal["CALL_FORECASTING", "CALL_EXPLORATION", "FINISH"] = Field(description="Next action")
    # Free-text rationale; only echoed into the DEBUG message.
    reasoning: str
    # Read when action == "CALL_FORECASTING"; orchestrator_node falls back to
    # {"horizon_days": 2} if it is missing.
    forecasting_payload: ForecastPayload | None = None
    # Read when action == "CALL_EXPLORATION"; orchestrator_node falls back to a
    # default ticket-count question if it is missing/empty.
    exploration_query: str | None = None
    # Read when action == "FINISH"; orchestrator_node falls back to "Finished.".
    final_answer: str | None = None


# Planner used by orchestrator_node: the shared llm configured to answer with
# an OrchestratorDecision.
planner = llm.with_structured_output(OrchestratorDecision)


# ---------------------------
# 4) Forecasting Agent (hardcoded)
# ---------------------------
def forecasting_agent(payload: ForecastPayload) -> dict:
    """Return a hardcoded ticket forecast truncated to the requested horizon.

    Args:
        payload: Request whose ``horizon_days`` says how many forecast rows
            are wanted.

    Returns:
        A dict with keys ``agent`` (fixed label), ``payload_received`` (the
        request as a plain dict) and ``forecast`` (list of per-day rows).
        Only two canned rows exist, so any horizon above 2 yields 2 rows.
    """
    # Clamp at zero: a negative horizon would otherwise become a negative
    # slice bound and drop rows from the END of the list (e.g. -1 -> 1 row)
    # instead of returning no rows.
    horizon_days = max(0, int(payload.horizon_days))

    canned_forecast = [
        {"day": 1, "date": "2026-01-30", "forecast_ticket_count": 31},
        {"day": 2, "date": "2026-01-31", "forecast_ticket_count": 35},
    ]

    return {
        "agent": "forecasting_agent",
        "payload_received": payload.model_dump(),
        "forecast": canned_forecast[:horizon_days],
    }


# ---------------------------
# 5) Exploration Agent: router + 2 tools
# ---------------------------
# Tool A: DB tool expects dict payload
def exploration_db_tool(payload: dict) -> dict:
    """Pretend DB lookup: return canned ticket counts for the requested window.

    Expected payload shape: {"Date": "2026-01-29", "Forecast": 2}.
    "Date" labels the first row; "Forecast" (default 2) caps the row count.
    """
    start_date = payload.get("Date")
    requested_days = payload.get("Forecast", 2)

    # Always return at least one row; only two canned rows exist.
    row_limit = max(1, int(requested_days))
    canned_rows = [
        {"date": start_date, "count": 34},
        {"date": "NEXT_DAY", "count": 29},  # placeholder for the second day
    ]

    return {
        "tool": "db_tool",
        "payload_received": payload,
        "result": {"ticket_count": canned_rows[:row_limit]},
    }

# Tool B: Knowledge tool expects string query
def exploration_knowledge_tool(query: str) -> dict:
    """Pretend knowledge-base lookup: echo the query alongside a canned answer."""
    canned_answer = "NET_500 usually indicates a network/server-side error. (Hardcoded KB answer)"
    response = dict(
        tool="knowledge_tool",
        query_received=query,
        answer=canned_answer,
    )
    return response


# Exploration agent tool-choice schema (structured)
# Input for exploration_db_tool; exploration_agent passes it on as a plain dict
# via model_dump(), so the field names match the keys the tool reads.
class DBPayload(BaseModel):
    Date: str = Field(description="Start date in YYYY-MM-DD")
    Forecast: int = Field(description="Number of days requested (e.g., 2 for next 2 days)")


# Structured output of the exploration router: which tool to run and its input.
class ExplorationDecision(BaseModel):
    tool: Literal["DB", "KNOWLEDGE"] = Field(description="Which exploration tool to use")
    reasoning: str = Field(description="1 line reason")
    # Used when tool == "KNOWLEDGE"; exploration_agent falls back to the raw
    # user query if this is missing/empty.
    knowledge_query: str | None = None
    # Used when tool == "DB"; exploration_agent falls back to today's date and
    # a 2-day window if this is missing.
    db_payload: DBPayload | None = None


# Router used by exploration_agent: the shared llm configured to answer with an
# ExplorationDecision.
exploration_router = llm.with_structured_output(ExplorationDecision)

def _exploration_system_message() -> SystemMessage:
    """Build the Exploration Agent system prompt with the current NY date."""
    return SystemMessage(content=f"""
You are the Exploration Agent. You have TWO tools:

1) DB tool (DB):
   Use when user asks about counts/stats/history from tickets DB, like "ticket count for last 2 days".
   You MUST output db_payload with:
     - Date (YYYY-MM-DD)
     - Forecast (int days)
   IMPORTANT: If user says "from today" or "today till N days", set Date to today's date.
   Today's date (America/New_York) is: {today_ny_str()}

2) Knowledge tool (KNOWLEDGE):
   Use for explanations/definitions like "what does NET_500 mean?"
   You MUST output knowledge_query as plain string.

Return only the structured decision.
""")


# Import-time snapshot, kept so existing references to EXPLORATION_SYSTEM still
# resolve. exploration_agent() does not use it: the date embedded here is fixed
# when the module loads and goes stale once the process runs past midnight.
EXPLORATION_SYSTEM = _exploration_system_message()


def exploration_agent(query: str) -> dict:
    """Route a query to the DB tool or the knowledge tool and run that tool.

    Args:
        query: Natural-language request from the orchestrator.

    Returns:
        A dict with ``agent`` (fixed label), ``router_decision`` (the router's
        choice, reasoning and tool inputs) and ``tool_output`` (whatever the
        selected tool returned).
    """
    # Rebuild the system prompt on every call so "today's date" is current.
    decision = exploration_router.invoke([
        _exploration_system_message(),
        HumanMessage(content=query),
    ])

    debug = {
        "agent": "exploration_agent",
        "router_decision": {
            "tool": decision.tool,
            "reasoning": decision.reasoning,
            "knowledge_query": decision.knowledge_query,
            "db_payload": decision.db_payload.model_dump() if decision.db_payload else None,
        },
    }

    if decision.tool == "KNOWLEDGE":
        # Fall back to the raw query if the router left knowledge_query empty.
        q = decision.knowledge_query or query
        tool_out = exploration_knowledge_tool(q)
        return {**debug, "tool_output": tool_out}

    # DB: fall back to "today, 2 days" if the router omitted the payload.
    if decision.db_payload:
        payload = decision.db_payload.model_dump()
    else:
        payload = {"Date": today_ny_str(), "Forecast": 2}
    tool_out = exploration_db_tool(payload)
    return {**debug, "tool_output": tool_out}


# ---------------------------
# 6) Orchestrator node (loop)
# ---------------------------
# System prompt for the planner. orchestrator_node sends it first, followed by
# the conversation messages and a JSON dump of state["work"].
ORCH_SYSTEM = SystemMessage(content="""
You are an orchestrator agent that plans step-by-step.

You can call:
- Forecasting Agent: for future forecasts
- Exploration Agent: for DB stats or knowledge lookups

Use state.work to see what you already collected.
Stop when you have enough and return FINISH with final_answer.
Max 5 steps.
""")

def orchestrator_node(state: State):
    """Ask the planner for the next action and record it in the graph state.

    Reads ``work`` and ``steps`` from ``state``, stops once 5 steps have been
    taken, otherwise invokes ``planner`` with the system prompt, the
    conversation and a JSON dump of ``work``. On FINISH the final answer is
    appended as a message; on CALL_FORECASTING / CALL_EXPLORATION a pending
    ``next_*`` entry is written into a copy of ``work`` for next_step_router
    to dispatch on. Every non-stopped path increments ``steps`` by one.
    """
    collected = state.get("work", {})
    print('work :', collected)
    step_count = state.get("steps", 0)
    print('step:', step_count)

    # Hard stop so the orchestrator <-> agent cycle cannot run forever.
    if step_count >= 5:
        stop_notice = AIMessage(content="Stopped after 5 steps to avoid looping.")
        return {"messages": [stop_notice], "work": collected, "steps": step_count}

    work_json = json.dumps(collected, indent=2, default=str)
    print('What i have already collected(work_json) :', work_json)

    # System prompt first, then the conversation, then the current scratchpad.
    planner_input = [
        ORCH_SYSTEM,
        *state["messages"],
        SystemMessage(content=f"Current state.work JSON:\n{work_json}"),
    ]
    print("What is planner is invoking: ", planner_input)

    # The planner answers with an OrchestratorDecision whose action is one of
    # "CALL_FORECASTING", "CALL_EXPLORATION", "FINISH".
    decision = planner.invoke(planner_input)
    print("My decision is ", decision)

    forecast_request = decision.forecasting_payload
    forecast_request_dump = forecast_request.model_dump() if forecast_request else None
    debug_msg = "".join([
        "DEBUG Orchestrator Decision:\n",
        f"action={decision.action}\n",
        f"reasoning={decision.reasoning}\n",
        f"forecasting_payload={forecast_request_dump}\n",
        f"exploration_query={decision.exploration_query}\n",
    ])
    print(debug_msg)

    debug_note = AIMessage(content=debug_msg)
    next_step_count = step_count + 1

    if decision.action == "FINISH":
        closing = AIMessage(content=decision.final_answer or "Finished.")
        return {
            "messages": [debug_note, closing],
            "work": collected,
            "steps": next_step_count,
        }

    # Copy before writing so the incoming state's dict is left untouched.
    pending = dict(collected)
    print("what is inside decision.forecasting_payload :", forecast_request)
    print("What is inside new work", pending)

    if decision.action == "CALL_FORECASTING":
        if forecast_request:
            pending["next_forecasting_payload"] = forecast_request.model_dump()
        else:
            pending["next_forecasting_payload"] = {"horizon_days": 2}
    elif decision.action == "CALL_EXPLORATION":
        pending["next_exploration_query"] = (
            decision.exploration_query or "Get ticket count from today till 2 days"
        )

    return {"messages": [debug_note], "work": pending, "steps": next_step_count}


# ---------------------------
# 7) Agent-call nodes
# ---------------------------
def call_forecasting_node(state: State):
    """Run the Forecasting Agent on the pending request and record its result.

    Reads ``work["next_forecasting_payload"]`` (defaulting to a 2-day horizon),
    calls forecasting_agent(), and returns updated state:

    * ``work["forecast_result"]`` holds the latest result.
    * ``work["forecast_history"]`` holds every result so far, oldest first, so
      a repeated forecasting call no longer erases what was collected earlier.
    * the pending ``next_forecasting_payload`` entry is removed.
    * ``steps`` is passed through unchanged.
    """
    work = dict(state.get("work", {}))
    # pop() reads and clears the pending request in one step, so
    # next_step_router will not dispatch the same request again.
    request = work.pop("next_forecasting_payload", {"horizon_days": 2})
    result = forecasting_agent(ForecastPayload(**request))

    work["forecast_result"] = result
    # Build a new list rather than appending in place: `work` is only a shallow
    # copy, so the previous state's list must not be mutated.
    work["forecast_history"] = [*work.get("forecast_history", []), result]

    # default=str matches how orchestrator_node serializes work.
    report = json.dumps(result, indent=2, default=str)
    return {
        "messages": [AIMessage(content=f"DEBUG Forecasting Agent returned:\n{report}")],
        "work": work,
        "steps": state.get("steps", 0),
    }


def call_exploration_node(state: State):
    """Run the Exploration Agent on the pending query and record its result.

    Reads ``work["next_exploration_query"]`` (defaulting to a 2-day ticket
    count question), calls exploration_agent(), and returns updated state:

    * ``work["exploration_result"]`` holds the latest result.
    * ``work["exploration_history"]`` holds every result so far, oldest first.
      A question needing both a DB lookup and a knowledge lookup takes two
      exploration calls; without the history the second would overwrite the
      first in ``work``.
    * the pending ``next_exploration_query`` entry is removed.
    * ``steps`` is passed through unchanged.
    """
    work = dict(state.get("work", {}))
    # pop() reads and clears the pending query in one step, so
    # next_step_router will not dispatch the same query again.
    query = work.pop("next_exploration_query", "Get ticket count from today till 2 days")
    result = exploration_agent(query)

    work["exploration_result"] = result
    # Build a new list rather than appending in place: `work` is only a shallow
    # copy, so the previous state's list must not be mutated.
    work["exploration_history"] = [*work.get("exploration_history", []), result]

    # default=str matches how orchestrator_node serializes work.
    report = json.dumps(result, indent=2, default=str)
    return {
        "messages": [AIMessage(content=f"DEBUG Exploration Agent returned:\n{report}")],
        "work": work,
        "steps": state.get("steps", 0),
    }


def next_step_router(state: State):
    """Choose the node to run after the orchestrator from pending work entries.

    Returns "call_forecasting" or "call_exploration" when the matching
    ``next_*`` key is present in ``state["work"]``, otherwise END.
    """
    pending = state.get("work", {})
    # Checked in order: a forecasting request wins over an exploration request.
    dispatch = (
        ("next_forecasting_payload", "call_forecasting"),
        ("next_exploration_query", "call_exploration"),
    )
    for work_key, node_name in dispatch:
        if work_key in pending:
            return node_name
    return END


# ---------------------------
# 8) Build Graph
# ---------------------------
# Graph layout: START -> orchestrator -> (call_forecasting | call_exploration | END),
# with both agent-call nodes looping back to the orchestrator.
builder = StateGraph(State)
builder.add_node("orchestrator", orchestrator_node)
builder.add_node("call_forecasting", call_forecasting_node)
builder.add_node("call_exploration", call_exploration_node)

builder.add_edge(START, "orchestrator")
# next_step_router's return value is looked up in this mapping to pick the
# next node (or END).
builder.add_conditional_edges(
    "orchestrator",
    next_step_router,
    {"call_forecasting": "call_forecasting", "call_exploration": "call_exploration", END: END}
)
# After an agent runs, control returns to the orchestrator for the next decision.
builder.add_edge("call_forecasting", "orchestrator")
builder.add_edge("call_exploration", "orchestrator")

# Compiled, invokable graph used by the test block below.
app = builder.compile()


# ---------------------------
# 9) Test
# ---------------------------
if __name__ == "__main__":
    # Smoke test: one question that needs both a DB lookup and a KB lookup.
    question = "from today till 2 days give me ticket count. also what does NET_500 mean?"
    init_state: State = {"messages": [HumanMessage(content=question)], "work": {}, "steps": 0}
    result = app.invoke(init_state)

    print("\n================ FINAL OUTPUT ================\n")
    for m in result["messages"]:
        print(f"{type(m).__name__}:\n{m.content}\n")

    print("\n================ FINAL WORK JSON ================\n")
    print(json.dumps(result["work"], indent=2, default=str))

    # Show the workflow. The compiled graph in this file is `app`; Image and
    # display come from IPython, which is only available in notebook-style
    # environments, so import them here and degrade to Mermaid text otherwise.
    try:
        from IPython.display import Image, display
    except ImportError:
        print(app.get_graph().draw_mermaid())
    else:
        display(Image(app.get_graph().draw_mermaid_png()))
