In [None]:
# main_workflow.py

import re
import json
from langgraph.graph import StateGraph, END
from langchain_core.runnables import RunnableLambda
from agent_prompts import *
from llm_model import OllamaModel
from utils import ResultStep, StageExecute

# Shared base model: every agent below is produced by `ollama.model_(prompt)`
# with its own prompt text. temperature=0 is passed to OllamaModel.
ollama = OllamaModel(model_name="llama3.2", temperature=0)

# --- AGENTS ---
planner_agent = ollama.model_(PLANNER_PROMPT)
orchestrator_agent = ollama.model_(ORCHESTRATOR_PROMPT)
inspector_agent = ollama.model_(INSPECTOR_PROMPT)
aggregator_agent = ollama.model_(AGGREGATOR_PROMPT)

# Worker agents: the common WORKER_PROMPT followed by a task-specific prompt
# (content_ordering / text_structuring / surface_realization presumably come
# from the star import of agent_prompts). Keys are the worker names the
# orchestrator is allowed to pick.
worker_agents = {
    worker_name: ollama.model_(WORKER_PROMPT + "\n" + task_prompt)
    for worker_name, task_prompt in (
        ("content ordering", content_ordering),
        ("text structuring", text_structuring),
        ("surface realization", surface_realization),
    )
}

# --- NODES ---

# "Thought: ... Plan: ```json [...] ```" -- whitespace around the fences is
# optional and the ``json`` language tag may be omitted.
_PLAN_PATTERN = re.compile(
    r"Thought:\s*(.*?)\s*Plan:\s*```(?:json)?\s*(.*?)\s*```", re.DOTALL
)


def run_planner(input: str) -> StageExecute:
    """Ask the planner agent for a plan and build the initial workflow state.

    ``input`` may be the raw task text. As the graph's entry node, however, this
    function is handed the whole initial state mapping; in that case the task
    text is read from its ``"input"`` key and its ``"recursion_limit"`` (if any)
    is carried over instead of being reset.

    Only the first ``Thought:``/``Plan:`` pair in the planner's reply is used.
    If no fenced plan is found, or it is not valid JSON, a warning is printed
    and the plan is empty.

    Returns:
        A fresh state dict with all bookkeeping fields reset.
    """
    recursion_limit = 100
    if isinstance(input, dict):
        # Called by LangGraph with the full state rather than a string.
        recursion_limit = input.get("recursion_limit", recursion_limit)
        input = input.get("input", "")

    plan_output = planner_agent.invoke({"input": input}).content
    print(f"Planner output: {plan_output}")

    plans = []
    match = _PLAN_PATTERN.search(plan_output)
    if match is None:
        print("WARNING: Planner output has no fenced JSON plan; using an empty plan.")
    else:
        try:
            plans = json.loads(match.group(2))
        except json.JSONDecodeError as err:
            print(f"WARNING: Planner plan is not valid JSON ({err}); using an empty plan.")

    return {
        "input": input,
        "plan": plans,
        "result_steps": [],
        "chat_history": [],
        "next": "",
        "next_input": "",
        "response": "",
        "agent_outcome": "",
        "team_iterations": 0,
        "recursion_limit": recursion_limit,
        "current_step": 0,
    }

def run_orchestrator(state: StageExecute) -> StageExecute:
    """Choose the next worker (or ``"FINISH"``) and the input to send it.

    Steps, in order:

    1. If ``current_step`` exceeds ``recursion_limit``, force ``"FINISH"``
       without calling the LLM.
    2. If the inspector rejected the previous output (``agent_outcome`` is set
       and is not ``"CORRECT"``), return the state with ``next`` unchanged so
       the router re-runs the same worker. The retry counts as a step.
    3. Otherwise, ask the orchestrator agent for a reply of the form
       ``Thought: ... Worker: <name> Worker Input: <text>``. The
       ``Worker Input:`` part may be omitted only when ``<name>`` is ``FINISH``.

    Raises:
        ValueError: if the reply lacks the expected markers or names a worker
            that is not in ``valid_workers``.
    """
    valid_workers = {"content ordering", "text structuring", "surface realization", "FINISH"}
    current_step = state.get("current_step", 0)

    # Check the budget first. It must also bound inspector retries, and there
    # is no point paying for an LLM call once the budget is exhausted.
    if current_step > state.get("recursion_limit", 100):
        print("Reached max recursion limit. Forcing FINISH.")
        return {
            **state,
            "next": "FINISH",
            "next_input": "",
            "thought": "Recursion limit reached, finishing workflow.",
        }

    # Inspector rejected the last output: keep `next` so the same worker runs
    # again. Count the retry so a persistently rejected worker cannot loop forever.
    agent_outcome = state.get("agent_outcome", "")
    if agent_outcome and agent_outcome.strip() != "CORRECT":
        print(f"Inspector rejected output. Repeating step: {state['next']}")
        return {**state, "current_step": current_step + 1}

    # Otherwise, ask orchestrator to choose the next worker
    response = orchestrator_agent.invoke({
        "input": state["input"],
        "result_steps": state["result_steps"],
        "plan": state["plan"]
    }).content
    print(f"Orchestrator response: {response}")

    segments = response.split("Worker:")
    if len(segments) < 2:
        raise ValueError("Orchestrator output missing expected format")

    thought = segments[0].replace("Thought:", "").strip()
    # partition() instead of a 2-way unpack of split(): a missing marker gets a
    # clear error and a repeated marker no longer crashes.
    worker_name, marker, worker_input = segments[1].partition("Worker Input:")
    worker_name = worker_name.strip()
    if not marker and worker_name != "FINISH":
        raise ValueError("Orchestrator output missing 'Worker Input:' section")

    if worker_name not in valid_workers:
        raise ValueError(f"Invalid worker selected by orchestrator: '{worker_name}'")

    return {
        **state,
        "next": worker_name,
        "next_input": worker_input.strip(),
        "thought": thought,
        "agent_outcome": "",  # Clear the inspector verdict for the new step
        "current_step": current_step + 1,
    }
    



def run_worker(state: StageExecute) -> StageExecute:
    """Run the worker named by ``state["next"]`` and append its result step.

    On a retry, ``agent_outcome`` holds inspector feedback other than
    ``"CORRECT"``. That feedback is appended to the prompt so the worker can act
    on it; re-sending the identical input to a temperature-0 model would likely
    reproduce the rejected output.

    The recorded step keeps the plain ``next_input``, so the inspector judges
    the output against the task rather than against the feedback text.

    Raises:
        ValueError: if no worker agent is registered under ``state["next"]``.
    """
    worker_name = state["next"]
    agent = worker_agents.get(worker_name)
    if agent is None:
        raise ValueError(f"No agent found for worker '{worker_name}'")

    task_input = state["next_input"]
    prompt_input = task_input
    feedback = state.get("agent_outcome", "")
    if feedback and feedback.strip() != "CORRECT":
        prompt_input = (
            f"{task_input}\n\n"
            f"Feedback on your previous attempt (address it in this answer):\n{feedback}"
        )

    output = agent.invoke({"input": prompt_input}).content
    print(f"Worker '{worker_name}' output: {output}")
    result_step = {
        "agent": worker_name,
        "input": task_input,
        "output": output,
        "thought": ""
    }

    # Return a new state; the existing result_steps list is not mutated.
    return {
        **state,
        "result_steps": state["result_steps"] + [result_step]
    }


def run_inspector(state: StageExecute) -> StageExecute:
    """Ask the inspector agent to judge the most recent worker output.

    The verdict is stored in ``agent_outcome``:

    - ``"CORRECT"`` when the inspector accepts. Surrounding whitespace,
      markdown emphasis, quotes and trailing ``.``/``!`` are tolerated.
    - Otherwise, the inspector's full feedback text, which the orchestrator
      treats as a request to retry the same worker.

    ``team_iterations`` is incremented on every call.

    Raises:
        ValueError: if there is no result step to inspect.
    """
    if not state["result_steps"]:
        raise ValueError("No result_steps found to inspect.")

    last_step = state["result_steps"][-1]

    # NOTE(review): the prompt's `result_steps` variable receives only the last
    # step's output -- presumably intended; confirm against INSPECTOR_PROMPT.
    feedback = inspector_agent.invoke({
        "input": last_step["input"],
        "result_steps": last_step["output"],
        "plan": state["plan"]
    }).content.strip()

    print(f"Inspector feedback: {feedback}")

    updated_state = {
        **state,
        "team_iterations": state.get("team_iterations", 0) + 1
    }

    # LLMs often decorate a bare verdict ("CORRECT.", "**Correct**"). An exact
    # match would reject those and force pointless retries.
    verdict = feedback.strip(" \t\r\n*_`'\".!").upper()
    if verdict == "CORRECT":
        updated_state["agent_outcome"] = "CORRECT"
    else:
        # Store the feedback; the orchestrator will re-run the current worker.
        # NOTE(review): empty feedback is stored as "" and the orchestrator treats
        # that as "not rejected", so the step advances -- confirm that is intended.
        updated_state["agent_outcome"] = feedback
        print(f"🔁 Repeating worker '{state['next']}' due to feedback.")

    return updated_state




def run_aggregator(state: StageExecute) -> dict:
    """Combine the plan and all recorded worker steps into the final answer.

    The stripped answer is returned under ``"Final Answer"``, the key the
    ``__main__`` block reads. It is also returned under ``"response"``, a field
    the initial state already carries.

    NOTE(review): LangGraph applies only update keys declared in the state
    schema. Writing ``response`` keeps the answer in the final state even if
    ``"Final Answer"`` is not a ``StageExecute`` key; confirm against utils.
    """
    final_output = aggregator_agent.invoke({
        "input": state["input"],
        "plan": state["plan"],
        "result_steps": state["result_steps"]
    }).content.strip()
    return {"Final Answer": final_output, "response": final_output}

# --- GRAPH ---
workflow = StateGraph(StageExecute)

# Register each node function wrapped in a RunnableLambda (registration order
# is planner, orchestrator, worker, inspector, aggregator).
for _node_name, _node_fn in (
    ("planner", run_planner),
    ("orchestrator", run_orchestrator),
    ("worker", run_worker),
    ("inspector", run_inspector),
    ("aggregator", run_aggregator),
):
    workflow.add_node(_node_name, RunnableLambda(_node_fn))
del _node_name, _node_fn  # keep the loop variables out of the module namespace

# --- EDGES ---
# Execution starts at the planner, which always hands off to the orchestrator.
workflow.set_entry_point("planner")
workflow.add_edge("planner", "orchestrator")

def routing_function(state):
    """Map the orchestrator's ``state["next"]`` choice to a route label.

    Returns the worker name or ``"FINISH"`` unchanged when it is a known route.
    Any other value, including a missing or empty ``next``, prints a warning and
    returns ``"orchestrator"`` so the orchestrator runs again.
    """
    choice = state.get("next", "")
    is_known_worker = choice in {"content ordering", "text structuring", "surface realization"}
    if is_known_worker or choice == "FINISH":
        return choice
    print("WARNING: Invalid or empty 'next' step. Repeating orchestrator.")
    return "orchestrator"  # Safe retry

# After the orchestrator: every worker name routes to the single "worker" node
# (run_worker dispatches on state["next"]), FINISH goes to the aggregator, and
# the "orchestrator" label is routing_function's retry fallback.
workflow.add_conditional_edges(
    "orchestrator",
    routing_function,
    {
        **dict.fromkeys(
            ("content ordering", "text structuring", "surface realization"), "worker"
        ),
        "FINISH": "aggregator",
        "orchestrator": "orchestrator",
    },
)

# Each worker result is inspected, and the inspector always reports back to
# the orchestrator; the aggregator terminates the graph.
workflow.add_edge("worker", "inspector")
workflow.add_edge("inspector", "orchestrator")
workflow.add_edge("aggregator", END)

graph_executor = workflow.compile()

# --- RUN ---
if __name__ == "__main__":
    example_input = """<page_title> 2009 NASCAR Camping World Truck Series </page_title> 
<section_title> Schedule </section_title> 
<table> 
    <cell> August 1 <col_header> Date </col_header> </cell> 
    <cell> Toyota Tundra 200 <col_header> Event </col_header> </cell> 
    <cell> Nashville Superspeedway <col_header> Venue </col_header> </cell> 
</table>"""

    input_data = {
        "input": example_input,
        "plan": [],
        "result_steps": [],
        "chat_history": [],
        "next": "",
        "next_input": "",
        "response": "",
        "agent_outcome": "",
        "team_iterations": 0,
        # Single entry (it was listed twice, and the later value 60 always won).
        # Also passed to LangGraph as its recursion limit below.
        "recursion_limit": 60,
        "current_step": 0,
    }

    final_result = graph_executor.invoke(input_data, config={"recursion_limit": input_data["recursion_limit"]})

    # Fall back to "response" in case "Final Answer" is not kept in the final
    # state (it is only present if it is a StageExecute key).
    print(final_result.get("Final Answer", final_result.get("response", "")))


Planner output: Thought: The input data appears to be a structured representation of a sports summary page from the Rotowire website. We need to break down this data into smaller tasks that can be completed by each worker.

Plan:
```json
[
    {
        "step": "Content Ordering: Identify relevant sections and data points",
        "worker": "content ordering"
    },
    {
        "step": "Text Structuring: Organize selected data into a coherent summary",
        "worker": "text structuring"
    },
    {
        "step": "Surface Realization: Produce fluent natural language text from structured content",
        "worker": "surface realization"
    }
]
```
Thought: The first step involves identifying the relevant sections and data points in the input data. This includes extracting the page title, section titles, table headers, and cell values.

Plan:
```json
[
    {
        "step": "Content Ordering: Extract page title",
        "worker": "content ordering"
    },
    {
        "step": "