In [4]:
from langgraph.graph import StateGraph
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import PydanticOutputParser
from langchain_text_splitters import RecursiveCharacterTextSplitter

from typing import List, TypedDict, Dict
from pydantic import BaseModel

In [5]:
class GraphState(TypedDict):
    # Inputs
    query: str
    documents: List[Dict]

    # Intermediate
    standalone_questions: List[str]
    chunks: List[Dict]
    query_embeddings: List[List[float]]
    chunk_embeddings: List[List[float]]

    # Retriever
    filters: Dict
    matches: List[Dict]   # final ranked results

In [6]:
class RetrieverFilter(BaseModel):
    must: Dict[str, str] | None = None
    should: Dict[str, str] | None = None
    top_k: int = 5

In [None]:
filter_parser = PydanticOutputParser(pydantic_object=RetrieverFilter)

filter_prompt = ChatPromptTemplate.from_messages([
    ("system", """
Extract payload filters for vector retrieval.

Rules:
- Only infer filters if clearly implied
- Use metadata fields (only): doc_type, category, year_enacted, jurisdiction, issuing_authority, status, domain
- Do NOT hallucinate values
- Output valid JSON only
"""),
    ("human", "{query}")
])

In [8]:
def generate_filters(state: GraphState) -> GraphState:
    chain = filter_prompt | llm | filter_parser
    result = chain.invoke({"query": state["query"]})

    return {
        **state,
        "filters": result.dict()
    }

In [9]:
from qdrant_client import QdrantClient, models

qdrant = QdrantClient(url="http://localhost:6333")
COLLECTION = "md_bge_m3_dense"

In [10]:
def retrieve_and_rank(state: GraphState) -> GraphState:
    results = []

    for q_emb in state["query_embeddings"]:
        search_filter = None

        if state["filters"].get("must"):
            search_filter = models.Filter(
                must=[
                    models.FieldCondition(
                        key=k,
                        match=models.MatchValue(value=v)
                    )
                    for k, v in state["filters"]["must"].items()
                ]
            )

        hits = qdrant.search(
            collection_name=COLLECTION,
            query_vector=q_emb,
            limit=state["filters"].get("top_k", 5),
            query_filter=search_filter,
            with_payload=True
        )

        for h in hits:
            results.append({
                "id": h.id,
                "score": h.score,
                "payload": h.payload
            })

    # final ranking (global)
    results = sorted(results, key=lambda x: x["score"], reverse=True)

    return {
        **state,
        "matches": results
    }

In [11]:
graph.add_node("generate_filters", generate_filters)
graph.add_node("retrieve_and_rank", retrieve_and_rank)

graph.add_edge("embed_documents", "generate_filters")
graph.add_edge("generate_filters", "retrieve_and_rank")

graph.set_finish_point("retrieve_and_rank")

NameError: name 'graph' is not defined