In [None]:
x: int = 5  # NOTE(review): scratch value; not referenced anywhere else in the visible file — confirm before removing

: 

In [None]:
# graph_builder.py
from __future__ import annotations

from typing import TypedDict, Literal, Optional, List, Dict, Any
from typing_extensions import Annotated
import operator
import logging

from langgraph.graph import StateGraph, START, END

In [None]:
from RAG_integrated.session_context import SessionContext
from RAG_integrated.rag_functions import (
    retrieve_hybrid_stm, 
    retrieve_hybrid_hcm, 
    retrieve_hybrid_ltm,
    rerank_with_mmr_and_recency, 
    insert_short_term,
)

In [None]:
# ---------- Logging ----------
# Module-level logger used by every node below.
logger: logging.Logger = logging.getLogger(__name__)
# NOTE(review): basicConfig at import time configures the *root* logger for the
# whole process. That is fine in a notebook; a reusable module would normally
# leave this to the application's entry point.
logging.basicConfig(level=logging.INFO)

# Closed set of memory stores a request may be routed to. These strings are also
# the labels topics_router emits and the retrieval subgraph's path map expects.
AllowedTopic = Literal["healthcare", "long-term", "short-term"]

# ---------- State ----------
class GraphState(TypedDict):
    """State shared by every node of the retrieval / insertion graphs.

    Nodes return partial dicts that LangGraph merges into this state. Only
    ``candidates`` declares a reducer (``operator.add``, i.e. list
    concatenation), so parallel retrieval branches can each contribute results.
    Every other key is overwritten by whichever update writes it.
    """

    # ---- Inputs (supplied by the caller of ``invoke``) ----
    session: SessionContext  # nodes read db_engine, elderly_id and embedder from it
    input_text: str  # the user's text: retrieval query and/or content to insert
    qa_type: Literal["question", "statement"]  # "question" -> retrieve only; otherwise insert + retrieve
    topics: List[AllowedTopic]  # stores to query; invalid/duplicate entries are dropped by topics_router
    # ---- Outputs (written by nodes) ----
    # Raw retrieval hits. operator.add concatenates the lists returned by the
    # retrieval nodes instead of letting the last writer win.
    candidates: Annotated[List[Dict[str, Any]], operator.add]
    final_chunks: List[Dict[str, Any]]  # reranked result written by reranker_node
    # set True if insertion runs
    inserted: bool


# ---------- Nodes ----------

def _run_retrieval(state: GraphState, retriever: Any, node_name: str) -> Dict[str, Any]:
    """Run one hybrid retriever for the session and wrap its hits as a state update.

    Args:
        state: graph state; reads ``session`` and ``input_text``.
        retriever: one of the ``retrieve_hybrid_*`` functions. It is called with
            ``engine``, ``elderly_id`` and ``query`` keyword arguments.
        node_name: public node name, used only in the log lines.

    Returns:
        ``{"candidates": <retriever result>}``. GraphState's ``operator.add``
        reducer concatenates it with hits from parallel branches.
    """
    session = state["session"]
    query = state["input_text"]
    logger.info("Called %s...", node_name)
    results = retriever(
        engine=session.db_engine,
        elderly_id=session.elderly_id,
        query=query,
    )
    logger.info("%s successful! Retrieved %d results.", node_name, len(results))
    return {"candidates": results}


def retrieval_node_health(state: GraphState) -> Dict[str, Any]:
    """Node for the "healthcare" topic: query via ``retrieve_hybrid_hcm``."""
    return _run_retrieval(state, retrieve_hybrid_hcm, "retrieval_node_health")


def retrieval_node_longterm(state: GraphState) -> Dict[str, Any]:
    """Node for the "long-term" topic: query via ``retrieve_hybrid_ltm``."""
    return _run_retrieval(state, retrieve_hybrid_ltm, "retrieval_node_longterm")


def retrieval_node_shortterm(state: GraphState) -> Dict[str, Any]:
    """Node for the "short-term" topic: query via ``retrieve_hybrid_stm``."""
    return _run_retrieval(state, retrieve_hybrid_stm, "retrieval_node_shortterm")



def reranker_node(state: GraphState) -> Dict[str, Any]:
    """Rerank the accumulated ``candidates`` and publish them as ``final_chunks``.

    Delegates to ``rerank_with_mmr_and_recency``, using the user's text as the query.
    """
    sess, query = state["session"], state["input_text"]
    pool = state.get("candidates", [])
    logger.info("Called reranker_node with %d candidates...", len(pool))
    # NOTE(review): the session's *embedder* is handed over as ``cross_encoder``;
    # confirm rerank_with_mmr_and_recency really expects that object.
    ranked = rerank_with_mmr_and_recency(
        query=query,
        candidates=pool,
        cross_encoder=sess.embedder,
    )
    logger.info("reranker_node successful! Produced %d chunks.", len(ranked))
    return {"final_chunks": ranked}



def insertion_node(state: GraphState) -> Dict[str, Any]:
    """Store ``input_text`` through ``insert_short_term`` and mark ``inserted``."""
    sess, content = state["session"], state["input_text"]
    logger.info("Called insertion_node...")
    # NOTE(review): every retrieval call is scoped with elderly_id, but nothing
    # here passes it. Confirm insert_short_term associates the row with the
    # right person.
    insert_short_term(
        engine=sess.db_engine,
        content=content,
        embedder=sess.embedder,
    )
    logger.info("insertion_node successful! Inserted new content.")
    return {"inserted": True}


# ---------- Routers ----------

def qa_router(state: GraphState) -> str:
    """Map ``qa_type`` to a branch label. Anything other than "question" counts as a statement."""
    if state["qa_type"] == "question":
        return "question"
    return "statement"

def topics_router(state: GraphState) -> List[str]:
    """Return the valid, de-duplicated topic labels to fan out to.

    Unknown topics are dropped and first-occurrence order is kept. The result may
    be empty if ``topics`` is missing or contains no valid entry.
    """
    allowed = ("healthcare", "long-term", "short-term")
    valid = [topic for topic in state.get("topics", []) if topic in allowed]
    # dict.fromkeys removes duplicates while preserving first-seen order.
    return list(dict.fromkeys(valid))

def statement_fork_router(state: GraphState) -> List[str]:
    """Send a statement to both the insertion branch and the retrieval branch.

    The state is not inspected. The same two labels are returned as a fresh list
    on every call.
    """
    branches = ("do_insertion", "route_topics_for_statement")
    return list(branches)


# ---------- Subgraph Builders ----------

def build_retrieval_subgraph(name_prefix: str = "") -> StateGraph[GraphState]:
    """
    Build (without compiling) the retrieval mini-graph:

        START --topics_router--> one retrieval node per selected topic (parallel)
              --> rerank --> END

    Routing is a conditional edge out of START. ``topics_router`` returns a list
    of labels, and a list is not a valid *node* return value (nodes must return
    a state-update dict). Registering it as a node therefore failed on every run.

    Args:
        name_prefix: prepended to every node name (e.g. "q_" / "s_") so the
            copies can be told apart in traces.

    Returns:
        The uncompiled ``StateGraph``; the caller compiles it.
    """
    rerank = f"{name_prefix}rerank"
    # topic label (as emitted by topics_router) -> (node name, node function)
    topic_nodes = {
        "healthcare": (f"{name_prefix}retrieve_healthcare", retrieval_node_health),
        "long-term": (f"{name_prefix}retrieve_long_term", retrieval_node_longterm),
        "short-term": (f"{name_prefix}retrieve_short_term", retrieval_node_shortterm),
    }

    g = StateGraph(GraphState)

    # Nodes first, edges after.
    for node_name, node_fn in topic_nodes.values():
        g.add_node(node_name, node_fn)
    g.add_node(rerank, reranker_node)

    # Fan-out: one branch per selected topic. If topics_router returns no labels,
    # no retrieval node is scheduled (and rerank does not run).
    g.add_conditional_edges(
        START,
        topics_router,
        {label: node_name for label, (node_name, _) in topic_nodes.items()},
    )

    # Fan-in: the selected branches run in the same step, so rerank sees the
    # candidates they produced, merged by GraphState's operator.add reducer.
    for node_name, _ in topic_nodes.values():
        g.add_edge(node_name, rerank)
    g.add_edge(rerank, END)

    return g


def build_insertion_subgraph(name_prefix: str = "") -> StateGraph[GraphState]:
    """
    Build (without compiling) a single-node graph: START -> {prefix}insert -> END.

    To make insertion topic-aware later, reuse the conditional fan-out pattern
    from the retrieval subgraph.
    """
    insert_key = f"{name_prefix}insert"
    graph = StateGraph(GraphState)
    graph.add_node(insert_key, insertion_node)
    for src, dst in ((START, insert_key), (insert_key, END)):
        graph.add_edge(src, dst)
    return graph


# ---------- Unified Graph ----------

def build_unified_graph() -> Any:
    """
    Build and compile the top-level graph.

    - qa_type == "question": run the "q_" retrieval subgraph only.
    - otherwise (a statement): run the "s_" insertion subgraph and the "s_"
      retrieval subgraph in parallel.

    Each compiled subgraph is invoked from a wrapper node that returns only the
    keys that subgraph produced. Mounting the compiled subgraphs directly made
    each of them write back the *entire* state. With two parallel branches, that
    meant two writes to single-value keys such as ``session`` in one step, which
    LangGraph rejects. It also re-appended ``candidates`` through its
    ``operator.add`` reducer.
    """
    retrieval_q_app = build_retrieval_subgraph("q_").compile()  # question flow
    retrieval_s_app = build_retrieval_subgraph("s_").compile()  # statement retrieval flow
    insertion_app = build_insertion_subgraph("s_").compile()    # statement insertion flow

    def _retrieval_node(app: Any) -> Any:
        """Wrap a compiled retrieval subgraph as a node that reports only its output keys."""

        def run(state: GraphState) -> Dict[str, Any]:
            out = app.invoke(state)
            # The subgraph's candidates start from the parent's list. Forward only
            # the newly retrieved tail so the parent's operator.add reducer does
            # not append the existing candidates a second time.
            already = len(state.get("candidates") or [])
            update: Dict[str, Any] = {"candidates": list(out.get("candidates", []))[already:]}
            if "final_chunks" in out:
                update["final_chunks"] = out["final_chunks"]
            return update

        return run

    def _insertion_node(state: GraphState) -> Dict[str, Any]:
        """Run the insertion subgraph and report only the ``inserted`` flag."""
        out = insertion_app.invoke(state)
        return {"inserted": out.get("inserted", False)}

    def _route_from_start(state: GraphState) -> List[str]:
        # Routing functions belong on conditional edges, not on nodes: a node
        # that returns a str/list is not a valid state update.
        if qa_router(state) == "question":
            return ["question"]
        return statement_fork_router(state)  # both branches, in parallel

    g = StateGraph(GraphState)
    g.add_node("retrieval_question", _retrieval_node(retrieval_q_app))
    g.add_node("retrieval_statement", _retrieval_node(retrieval_s_app))
    g.add_node("insertion_statement", _insertion_node)

    g.add_conditional_edges(
        START,
        _route_from_start,
        {
            "question": "retrieval_question",
            "do_insertion": "insertion_statement",
            "route_topics_for_statement": "retrieval_statement",
        },
    )

    # Explicit termination for every branch.
    for node_name in ("retrieval_question", "retrieval_statement", "insertion_statement"):
        g.add_edge(node_name, END)

    return g.compile()


In [None]:
# ---------- Example usage ----------

# def run_example(session: SessionContext):
#     graph = build_unified_graph()
#     # QUESTION example
#     out_q = graph.invoke({
#         "session": session,
#         "input_text": "What meds are currently prescribed?",
#         "qa_type": "question",
#         "topics": ["healthcare", "short-term"],
#         "candidates": [],
#         "final_chunks": [],
#         "inserted": False,
#     })
#     # STATEMENT example
#     out_s = graph.invoke({
#         "session": session,
#         "input_text": "Add note: medication taken at 9:00 PM.",
#         "qa_type": "statement",
#         "topics": ["short-term"],
#         "candidates": [],
#         "final_chunks": [],
#         "inserted": False,
#     })
#     return out_q, out_s