In [1]:
from __future__ import annotations

import dotenv
dotenv.load_dotenv()

import os

import json
import re
import uuid
import logging
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Generator, Iterable

import pickle

from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.schema import SystemMessage, AIMessage, HumanMessage, BaseMessage
from langchain.agents import Tool
from langgraph.graph import END, StateGraph

In [2]:
# ---------------------------------------------------------------------------
#  CONFIG 
# ---------------------------------------------------------------------------

DEBUG = False            # planner + high‑level debug
DEBUG_VERBOSE_RAW = False  # print raw tool outputs (can be large)

logging.basicConfig(level=logging.INFO, format="%(message)s")
logger = logging.getLogger(__name__)

In [3]:
# ---------------------------------------------------------------------------
# STATE
# ---------------------------------------------------------------------------

@dataclass
class CompareState:
    """Holds everything the graph knows so far."""

    user_prompt: str  # original complex prompt from the user
    conversation: List[Dict[str, str]] = field(default_factory=list)

    # Working memory --------------------------------------------------------
    claims: List[str] = field(default_factory=list)  # extracted from doc‑A
    claim_status: Dict[str, str] = field(default_factory=dict)  # claim -> status
    claim_attempts: Dict[str, int] = field(default_factory=dict)  # claim -> tries
    new_drivers: List[str] = field(default_factory=list)  # extracted from doc‑B

    # Control ---------------------------------------------------------------
    next_action: Optional[Dict[str, str]] = None  # {tool, query, result?}
    finished: bool = False  # when True the planner will route to END
    steps: int = 0  # safeguard counter to avoid infinite loops

    # Debug trace -----------------------------------------------------------
    trace: List[Dict[str, Any]] = field(default_factory=list)

# ---------------------------------------------------------------------------
# TOOLS
# ---------------------------------------------------------------------------

def _wrap_query_engine(name: str, engine):
    """Turns a LlamaIndex QueryEngineTool into a LangChain Tool."""

    def _run(query: str) -> str:  # synchronous wrapper; adapt if async
        response = engine.query(query)
        return str(response)

    return Tool(name=name, func=_run, description=f"Queries {name}")


with open(r"C:\Users\info\Desktop\harlus\server\data\AAPL\Investment_Case_AMAT_pdf\tools\doc_search\tool.pkl", "rb") as file:
    doc_tool_thesis = pickle.load(file)
with open(r"C:\Users\info\Desktop\harlus\server\data\AAPL\AMAT_Q3_24_Earnings_Release_pdf\tools\doc_search\tool.pkl", "rb") as file:
    doc_tool_evidence = pickle.load(file)

query_engine_a = doc_tool_thesis.tool.query_engine
query_engine_b = doc_tool_evidence.tool.query_engine

if query_engine_a is Ellipsis or query_engine_b is Ellipsis:
    raise ValueError("Please supply `query_engine_a` and `query_engine_b`.")

predictions_tool = _wrap_query_engine("predictions_doc", query_engine_a)
newinfo_tool = _wrap_query_engine("new_info_doc", query_engine_b)

TOOLS_BY_NAME = {
    predictions_tool.name: predictions_tool,
    newinfo_tool.name: newinfo_tool,
}


# ---------------------------------------------------------------------------
# LLM
# ---------------------------------------------------------------------------

llm = ChatOpenAI(
    model_name="gpt-4o-mini", 
    temperature=0.0,
    api_key=os.getenv("OPENAI_API_KEY"),
)

PLAN_SYS_PROMPT = (
    "You are the planning module of an autonomous agent that compares two "
    "financial documents. Decide the *single* next action. Return **ONLY** a "
    "JSON dict with keys: 'tool' (predictions_doc | new_info_doc | none) and "
    "'query'. Do not wrap in ``` fences."
)

plan_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", PLAN_SYS_PROMPT),
        MessagesPlaceholder(variable_name="memory"),
    ]
)

# ---------------------------------------------------------------------------
# HELPERS
# ---------------------------------------------------------------------------


def _extract_json(text: str) -> Optional[Any]:
    try:
        return json.loads(text)
    except Exception:
        pass
    text = re.sub(r"^```[a-zA-Z0-9]*|```$", "", text.strip(), flags=re.MULTILINE)
    m = re.search(r"(\{[\s\S]*?\}|\[[\s\S]*?\])", text)
    if m:
        try:
            return json.loads(m.group(1))
        except Exception:
            return None
    return None


def _normalise_status(raw: str) -> Optional[str]:
    raw = raw.lower()
    if "confirm" in raw:
        return "confirmed"
    if any(k in raw for k in ("contradict", "refut", "disagree")):
        return "contradicted"
    if any(k in raw for k in ("not_discuss", "not discuss", "irreleva", "no evidence")):
        return "not_discussed"
    return None


def _memory_as_messages(state: CompareState) -> List[BaseMessage]:
    msgs: List[BaseMessage] = []
    for it in state.conversation:
        line = (
            f"TOOL {it['tool']} ⇒ {it['query'][:60].strip()}… :: "
            f"{it['result'][:120].strip()}"
        )
        msgs.append(SystemMessage(content=line))
    if not msgs:
        msgs.append(SystemMessage(content="⟨no prior actions⟩"))
    return msgs

# ---------------------------------------------------------------------------
# GRAPH NODES
# ---------------------------------------------------------------------------

def _llm_plan(state: CompareState) -> Optional[Dict[str, str]]:
    memory_msgs = _memory_as_messages(state)
    msgs = plan_prompt.format_messages(memory=memory_msgs)
    resp = llm(msgs)
    payload = _extract_json(resp.content)
    if isinstance(payload, dict):
        tool = payload.get("tool")
        query = payload.get("query")
        if tool in TOOLS_BY_NAME or tool == "none":
            return {"tool": tool, "query": query}
    return None


def _fallback_plan(state: CompareState) -> Dict[str, str]:
    if not state.claims:
        return {"tool": predictions_tool.name, "query": (
            "Return ONLY a JSON array of distinct claims/hypotheses from the document."
        )}
    for cl in state.claims:
        if cl not in state.claim_status and state.claim_attempts.get(cl, 0) < 3:
            return {"tool": newinfo_tool.name, "query": (
                f"Does the document CONFIRM, CONTRADICT, or NOT DISCUSS this claim?"
                f" Respond with JSON dict {{'status':...,'justification':...}} Claim: {cl}"
            )}
    if not state.new_drivers:
        joined = "\n".join(state.claims)
        return {"tool": newinfo_tool.name, "query": (
            f"List NEW drivers in this doc not in list:\n{joined}\nReturn JSON array only."
        )}
    return {"tool": "none", "query": ""}


def planner_node(state: CompareState) -> CompareState:
    if state.steps >= 60:
        state.finished = True
        return state
    act = _llm_plan(state) or _fallback_plan(state)
    if act["tool"] == "none":
        state.finished = True
        state.next_action = None
    else:
        state.next_action = act
    return state


def run_tool_node(state: CompareState) -> CompareState:
    spec = state.next_action
    if not spec or spec.get("tool") == "none":
        return state
    tool = TOOLS_BY_NAME[spec["tool"]]
    res = tool.run(spec["query"])
    state.conversation.append({"tool": tool.name, "query": spec["query"], "result": res})
    state.conversation = state.conversation[-6:]
    state.trace.append({"step": state.steps, "tool": tool.name,
                        "query": spec["query"], "raw": res})
    state.next_action["result"] = res
    return state


def parse_update_node(state: CompareState) -> CompareState:
    spec = state.next_action
    if not spec or "result" not in spec:
        return state
    payload = _extract_json(spec["result"])
    if spec["tool"] == predictions_tool.name and isinstance(payload, list):
        state.claims = [str(c).strip() for c in payload if str(c).strip()]
    elif spec["tool"] == newinfo_tool.name:
        if isinstance(payload, dict):
            status = _normalise_status(str(payload.get("status", "")))
            cl = spec["query"].split("Claim:")[-1].strip()
            if status:
                state.claim_status[cl] = status
            else:
                state.claim_attempts[cl] = state.claim_attempts.get(cl, 0) + 1
                if state.claim_attempts[cl] >= 3:
                    state.claim_status[cl] = "not_discussed"
        elif isinstance(payload, list):
            state.new_drivers = [str(x).strip() for x in payload if str(x).strip()]
    # Clear next_action and bump step
    state.next_action = None
    state.steps += 1
    return state

# ---------------------------------------------------------------------------
# GRAPH DEFINITION
# ---------------------------------------------------------------------------

graph = StateGraph(CompareState)

graph.add_node("plan", planner_node)
graph.add_node("run_tool", run_tool_node)
graph.add_node("parse", parse_update_node)

graph.set_entry_point("plan")

# plan → run_tool or END
graph.add_conditional_edges("plan", lambda s: "run_tool" if not s.finished else END)
# run_tool → parse → plan
graph.add_edge("run_tool", "parse")
graph.add_edge("parse", "plan")

compare_chain = graph.compile()

# ---------------------------------------------------------------------------
# HELPER TO GENERATE FINAL ANSWER
# ---------------------------------------------------------------------------

def _md_bullets(lst: List[str]) -> str:
    return "\n".join(f"- {x}" for x in lst) if lst else "_None_"

def format_final_answer(state: CompareState) -> str:
    confirmed = [c for c, st in state.claim_status.items() if st == "confirmed"]
    contradicted = [c for c, st in state.claim_status.items() if st == "contradicted"]
    undiscussed = [c for c, st in state.claim_status.items() if st == "not_discussed"]

    return (
        "## Comparison Result\n\n"  # H2 header
        "### 1. Claims confirmed by new information\n" + _md_bullets(confirmed) + "\n\n"
        "### 2. Claims contradicted by new information\n" + _md_bullets(contradicted) + "\n\n"
        "### 3. Claims not discussed in new information\n" + _md_bullets(undiscussed) + "\n\n"
        "### 4. New value drivers introduced in new information\n" + _md_bullets(state.new_drivers)
    )


def compare_documents(
    user_prompt: str,
    *,
    recursion_limit: int = 100,
    debug: bool | None = None,
) -> str:
    """High‑level helper. Returns *only* the final formatted Markdown answer.

    Set ``debug=True`` to enable stdout tracing (planner + tool calls)."""

    global DEBUG
    if debug is not None:
        DEBUG = debug

    initial_state = CompareState(user_prompt=user_prompt)
    final_state = compare_chain.invoke(
        initial_state,
        config={"recursion_limit": recursion_limit},
    )
    return format_final_answer(final_state)


# ---- 🆕 STREAMING / INTROSPECTION API -------------------------------------

def compare_documents_stream(
    user_prompt: str,
    *,
    recursion_limit: int = 100,
) -> Iterable[CompareState]:
    """Generator that yields **every** intermediate state, so you can watch
    progress or analyse after‑the‑fact.

    Example::

        for st in compare_documents_stream(prompt):
            print(st.steps, st.next_action)  # or log whatever
    """

    initial_state = CompareState(user_prompt=user_prompt)
    for chunk in compare_chain.stream(
        initial_state,
        config={"recursion_limit": recursion_limit},
        stream_mode="values",  # yields plain state values
    ):
        yield chunk

Note: NumExpr detected 20 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 16.
NumExpr defaulting to 16 threads.
  from .autonotebook import tqdm as notebook_tqdm
  llm = ChatOpenAI(


In [4]:
# Basic one‑shot -----------------------------------------------------
prompt = (
    "Compare document A (predictions) with document B (earnings report) "
    "and classify predictions as confirmed/contradicted/undiscussed, plus "
    "list new drivers."
)

print("\n=== FINAL ANSWER (quiet run) ===\n")
print(compare_documents(prompt))

# Verbose streaming --------------------------------------------------
print("\n=== STREAMED STEPS ===\n")
for st in compare_documents_stream(prompt):
    print(f"step {st.steps:02d}", st.trace[-1] if st.trace else "<planning>")


=== FINAL ANSWER (quiet run) ===



  resp = llm(msgs)
HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"


AttributeError: 'AddableValuesDict' object has no attribute 'claim_status'

In [None]:
import os
os.getenv("OPENAI_API_KEY")
