In [None]:
import re
import sys
from pathlib import Path
import json
import time


from typing import Tuple, Dict, Any, List
from typing_extensions import TypedDict

from langgraph.types import interrupt
from langgraph.graph import StateGraph, START, END

sys.path.insert(0, str(Path.cwd().parent / "src"))
from assurhabitat_agents.model.llm_model_loading import llm_inference
from assurhabitat_agents.config.tool_config import VALIDATION_TOOLS


In [None]:
class ValidationReActState(TypedDict):
    images_path: list[str]
    history: list[str]  # L'historique des échanges (Thought, Action, Observation)
    last_action: str | None  # Le nom de l'outil à appeler (si applicable)
    last_arguments: dict | None  # Les arguments à passer à l'outil
    last_observation: str | None  # Le résultat de l'outil appelé
    
    parsed_declaration: dict  # Resultat de l'agent Declaration (Ne peut pas etre None, Le superviseur ne peut appeler cet agent que si cet etat est connu)

    # results from tools
    image_conformity: dict | None       # {"match": bool, "raw_output": str}
    guarantee_report: dict | None       # {"is_garanteed": bool, "guarantee": {...}}

In [None]:
def format_prompt(state: ValidationReActState, tools) -> str:
    """
    Render the ReAct prompt sent to the LLM for the next validation step.

    The prompt lists the available tools, the output-format rules and the
    decision rules, followed by a context summary built from `state`: the last
    10 history entries, the parsed declaration, the image-conformity result,
    the image paths and the guarantee report. Sections whose value is missing
    or empty are omitted, except the image list, which is always printed.

    Args:
        state: Current agent state. Every key is read with `.get`, so missing
            keys are tolerated.
        tools: Iterable of tool names to advertise. A falsy value prints a
            "no tools available" placeholder instead.

    Returns:
        The prompt as a single newline-joined string.
    """
    max_history_entries = 10
    recent_history = state.get("history", [])[-max_history_entries:]
    declaration = state.get("parsed_declaration")
    conformity_result = state.get("image_conformity")
    guarantee_result = state.get("guarantee_report")
    image_paths = state.get("images_path", [])

    if tools:
        tool_listing = "\n".join(f"- {tool}" for tool in tools)
    else:
        tool_listing = "- (no tools available)"

    prompt_lines = [
        "You are the Validation Agent for AssurHabitat. Decide the next step: either",
        "1) call a tool (Action) OR 2) give the final answer (Réponse).",
        "",
        "Available tools:",
        tool_listing,
        "",
        "Rules:",
        "- If you call a tool, use a single line: Action: TOOL_NAME",
        "- If arguments are needed, write: Arguments: then either a JSON object or key=value lines",
        "- If you return the final reply to the user, write: Réponse: <text>",
        "",
        "Decision rules:",
        "- If image_conformity is None: you MUST call CheckConformity first.",
        "- If guarantee_report is None but image_conformity.match == True: then call CheckGuarantee.",
        "- If both are completed: produce a final 'Réponse:' for the supervisor.",
        "",
        "Context summary:",
    ]

    if recent_history:
        prompt_lines += ["Recent history:", "\n".join(recent_history)]

    if declaration:
        # Fall back to the plain repr when the declaration cannot be serialised.
        try:
            declaration_text = json.dumps(declaration, ensure_ascii=False)
        except Exception:
            declaration_text = str(declaration)
        prompt_lines += [
            "Current parsed_declaration JSON: (you can find the sinistre type inside for CheckConformity)",
            declaration_text,
        ]

    if conformity_result:
        prompt_lines.append("Conformity: " + json.dumps(conformity_result, ensure_ascii=False))

    # The image list is always shown, even when it is empty.
    prompt_lines += [
        "Images available in the state (use them for CheckConformity):",
        json.dumps(image_paths, ensure_ascii=False),
    ]

    if guarantee_result:
        prompt_lines.append("Guarantee: " + json.dumps(guarantee_result, ensure_ascii=False))

    prompt_lines += ["", "Now propose the next single Thought + Action (or final Réponse)."]
    return "\n".join(prompt_lines)


In [None]:
def _truncate_at_first_marker(raw_args: str) -> str:
    """Drop everything from the earliest trailing-section marker onwards."""
    markers = ("Observation from", "LLM output", "Action:", "Thought:")
    positions = [pos for pos in (raw_args.find(marker) for marker in markers) if pos != -1]
    if not positions:
        return raw_args
    # Cut at the earliest marker in the text, not the first one in `markers`:
    # otherwise a "Thought:" that precedes an "Action:" would survive the cut
    # and make the remaining text unparseable.
    return raw_args[: min(positions)].strip()


def _parse_tool_arguments(raw_args: str) -> dict:
    """Turn the text that follows 'Arguments:' into a dict of tool arguments."""
    raw_args = _truncate_at_first_marker(raw_args.strip())
    if not raw_args:
        # Nothing left once trailing sections are removed: no arguments.
        return {}

    # Preferred form: a JSON document.
    try:
        decoded = json.loads(raw_args)
    except ValueError:
        pass
    else:
        return decoded if isinstance(decoded, dict) else {"raw": decoded}

    # Fallback: one "key = value" (or "key=value") pair per line.
    parsed_args = {}
    for line in raw_args.splitlines():
        line = line.strip()
        if not line:
            continue
        pair = re.match(r"^\s*([^=]+?)\s*=\s*(.+)$", line)
        if pair is None:
            # Unparseable line: keep it verbatim so nothing is silently lost.
            parsed_args.setdefault("_raw_lines", []).append(line)
            continue
        key = pair.group(1).strip()
        value = pair.group(2).strip()
        # Interpret the value as JSON when possible (numbers, lists, ...).
        try:
            parsed_args[key] = json.loads(value)
        except ValueError:
            parsed_args[key] = value
    return parsed_args or {"raw": raw_args}


def parse_output(output: str) -> Tuple[str, Any, Any]:
    """
    Classify a raw LLM reply as an action, a final answer or a plain thought.

    Recognised forms, in priority order:
      1. A line starting with "Action:" -> ("action", tool_name, tool_args).
         `tool_args` comes from the text after a line starting with
         "Arguments:", truncated at the earliest "Observation from",
         "LLM output", "Action:" or "Thought:" marker, then parsed as JSON or,
         failing that, as key=value lines. A non-dict JSON value is wrapped as
         {"raw": value}; missing or empty arguments yield {}.
      2. A line starting with "Réponse:" or "Answer:" ->
         ("answer", answer_text, None).
      3. The whole reply being a JSON object with an "action" key ->
         ("action", obj["action"], obj.get("args", {})), or with an "answer"
         key -> ("answer", obj["answer"], None).
      4. Anything else -> ("thought", stripped_text, None).

    Line prefixes are matched case-insensitively.

    Args:
        output: Raw text returned by the LLM.

    Returns:
        A 3-tuple (step_type, payload, tool_args) as described above.
    """
    text = output.strip()

    action_match = re.search(r"(?mi)^Action:\s*(?P<tool>[^\n\r]+)", text)
    if action_match:
        # Arguments, when present, run from the "Arguments:" line to the end.
        args_match = re.search(r"(?mi)^Arguments:\s*(?P<args>[\s\S]+)$", text)
        tool_args = _parse_tool_arguments(args_match.group("args")) if args_match else {}
        return ("action", action_match.group("tool").strip(), tool_args)

    answer_match = re.search(r"(?mi)^(Réponse|Answer):\s*(?P<ans>[\s\S]+)$", text)
    if answer_match:
        return ("answer", answer_match.group("ans").strip(), None)

    # Last structured form: the whole reply is a JSON object.
    try:
        payload = json.loads(text)
    except ValueError:
        payload = None
    if isinstance(payload, dict):
        if "action" in payload:
            return ("action", payload.get("action"), payload.get("args", {}))
        if "answer" in payload:
            return ("answer", payload.get("answer"), None)

    return ("thought", text, None)

In [None]:
# Tool registry used by the validation agent.
tools = VALIDATION_TOOLS
# Its keys are the tool names advertised to the LLM in the prompt.
tool_names = [*tools.keys()]

def node_thought_action(state: ValidationReActState) -> ValidationReActState:
    """
    Ask the LLM for the next ReAct step and record its decision in the state.

    Builds the prompt with `format_prompt(state, tool_names)`, calls
    `llm_inference`, parses the reply with `parse_output`, then updates `state`
    in place:
      - the raw LLM output is always appended to `history`;
      - "action": `last_action` / `last_arguments` are set for the tool node;
      - "answer": `last_action`, `last_arguments` and `last_observation` are
        cleared and the answer text is appended to `history`;
      - anything else: only a "Thought: ..." history entry is added.

    Args:
        state: Current agent state (mutated in place).

    Returns:
        The same `state` object, updated.
    """
    prompt = format_prompt(state, tool_names)
    output = llm_inference(prompt)

    # parse_output returns ("action", tool_name, tool_args),
    # ("answer", answer_text, None) or ("thought", thought_text, None).
    step_type, *content = parse_output(output)

    # Append the raw LLM output to history for traceability.
    history = state.setdefault("history", [])
    history.append(f"LLM output: {output}")

    if step_type == "action":
        tool_name, tool_args = content
        # Store the requested tool call for the tool-execution node.
        state["last_action"] = tool_name
        state["last_arguments"] = tool_args or {}
        history.append(f"Action: call tool: {tool_name} with args: {tool_args}")
    elif step_type == "answer":
        # Final textual answer: there is nothing left to execute.
        state["last_action"] = None
        state["last_arguments"] = None
        state["last_observation"] = None
        history.append(f"Answer: {content[0]}")
    else:
        # Thought only: no action requested, the loop keeps running.
        history.append(f"Thought: {content[0] if content else ''}")
    return state

def node_tool_execution(state: ValidationReActState) -> ValidationReActState:
    """
    Execute the tool stored in state['last_action'] with state['last_arguments'].

    The tool is looked up by name in the module-level `tools` registry and
    called with the arguments as keyword arguments. Any exception raised by the
    tool, and any unknown tool name, is turned into an error string instead of
    being propagated.

    Updates, in place:
      - state['last_observation']: `str()` of the tool result (or error text);
      - state['history']: an "Observation from ..." entry, plus a failure note
        when CheckConformity / CheckGuarantee did not return a dict;
      - state['image_conformity']: set when CheckConformity returns a dict;
      - state['guarantee_report']: set when CheckGuarantee returns a dict;
      - state['last_action'] / state['last_arguments']: reset to None.

    Args:
        state: Current agent state (mutated in place).

    Returns:
        The same `state` object, updated.
    """
    history = state.setdefault("history", [])
    tool_name = state.get("last_action")
    tool_args = state.get("last_arguments") or {}

    # Nothing to execute.
    if not tool_name:
        history.append("No action to execute.")
        return state

    # Call the tool if it is registered; failures become the observation text
    # so the LLM can see them on its next turn.
    if tool_name in tools:
        try:
            observation = tools[tool_name](**tool_args)
        except Exception as e:
            observation = f"Error during tool {tool_name}: {e}"
    else:
        observation = f"Error: Unknown tool {tool_name}"

    # Store the observation and log it.
    state["last_observation"] = str(observation)
    history.append(f"Observation from {tool_name}: {state['last_observation']}")

    # Keep the structured result only when the tool returned a dict.
    if tool_name == "CheckConformity":
        if isinstance(observation, dict):
            state["image_conformity"] = observation
        else:
            history.append("check_conformity failed.")
    elif tool_name == "CheckGuarantee":
        if isinstance(observation, dict):
            state["guarantee_report"] = observation
        else:
            history.append("CheckGuarantee failed.")

    # Reset the pending action so the next thought step decides what comes next.
    state["last_action"] = None
    state["last_arguments"] = None

    return state

In [None]:
def build_graph():
    """
    Build and compile the two-node ReAct graph of the validation agent.

    Nodes:
      - "thought": `node_thought_action` (asks the LLM for the next step);
      - "action": `node_tool_execution` (runs the requested tool).

    Edges: START -> "thought"; "action" -> "thought"; after "thought" the
    route is chosen by `decide_from_thought`.

    Returns:
        The compiled graph returned by `StateGraph.compile()`.
    """
    graph_builder = StateGraph(ValidationReActState)
    graph_builder.add_node("thought", node_thought_action)
    graph_builder.add_node("action", node_tool_execution)

    graph_builder.add_edge(START, "thought")

    def decide_from_thought(runtime_state: ValidationReActState):
        """Route after a thought step: finish, run the requested tool, or think again."""
        # Both tool results are available: validation is complete.
        if runtime_state.get("guarantee_report") and runtime_state.get("image_conformity"):
            return END
        if runtime_state.get("last_action"):
            return "action"
        # The thought node logs a final reply as "Answer: ..."; stop there too,
        # otherwise a final answer given before both reports exist would send
        # the graph back to "thought" indefinitely.
        history = runtime_state.get("history") or []
        if history and history[-1].startswith("Answer: "):
            return END
        return "thought"

    graph_builder.add_conditional_edges("thought", decide_from_thought)
    graph_builder.add_edge("action", "thought")
    return graph_builder.compile()

In [None]:

def run_graph(graph, initial_state: ValidationReActState, max_steps: int = 10):
    """
    Drive the validation ReAct loop manually and print the history as it grows.

    The loop calls `node_thought_action` and, when a tool was requested,
    `node_tool_execution`. It stops when:
      - both `guarantee_report` and `image_conformity` are set,
      - the LLM produced a final answer (last history entry starts with
        "Answer: "), or
      - `max_steps` thought steps have been executed.

    Args:
        graph: Accepted for interface compatibility but not used: this runner
            calls the node functions directly instead of the compiled graph.
        initial_state: Starting state; the node functions mutate it in place.
        max_steps: Maximum number of thought steps.

    Returns:
        The final state.
    """
    state = initial_state
    printed_count = 0

    def print_new_history(current_state, already_printed):
        """Print history entries added since the last call; return the new count."""
        history = current_state.get("history", [])
        for line in history[already_printed:]:
            print(line)
        return len(history)

    for _ in range(max_steps):
        state = node_thought_action(state)
        printed_count = print_new_history(state, printed_count)

        if state.get("guarantee_report") and state.get("image_conformity"):
            break

        if state.get("last_action"):
            state = node_tool_execution(state)
            printed_count = print_new_history(state, printed_count)
            # The next iteration runs the thought step again.
            continue

        # A final answer ends the run; asking the LLM again would only burn
        # the remaining steps.
        history = state.get("history") or []
        if history and history[-1].startswith("Answer: "):
            break

        # Thought only: let the LLM try again on the next iteration. The short
        # sleep avoids a busy loop in the notebook.
        time.sleep(0.01)

    print("\n--- FINAL STATE ---")
    print("guarantee_report:", state.get("guarantee_report"))
    print("image_conformity:", state.get("image_conformity"))
    return state

In [None]:
# Demo run: a theft/vandalism declaration with no images attached.
initial_state = {
    "images_path": [],
    "history": [],
    "last_action": None,
    "last_arguments": None,
    "last_observation": None,
    # Hand-written stand-in for the output of the Declaration agent.
    "parsed_declaration": {
        "sinistre_type": "vol_vandalisme",
        "sinistre_confidence": 0.99,
        "sinistre_explain": "cambriolage via vélux, appareils électroniques volés",
        "candidates": [{"type": "vol_vandalisme", "score": 0.99}],
        "extracted": {
            "date_sinistre": "2024-06-13",
            "lieu": "chambre",
            "description": "cambriolage via vélux, appareils électroniques volés",
            "photos": [],
            "biens_impactes": ["appareils électroniques"],
            "police_report_number": "123456789",
        },
    },  # this trailing comma was missing, which made the cell a SyntaxError
    "image_conformity": None,
    "guarantee_report": None,
}

graph = build_graph()

final_state = run_graph(graph, initial_state, max_steps=10)