In [19]:
from typing import List, TypedDict, Dict
from pydantic import BaseModel
from langgraph.graph import StateGraph

from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import PydanticOutputParser
from langchain_text_splitters import RecursiveCharacterTextSplitter

import requests

In [16]:
# One topic of a decomposed query, expressed as a single question.
# (Documented with comments rather than a docstring so the pydantic-generated
# schema for this model stays exactly as it was.)
class Topic(BaseModel):
    question: str  # the question text; read back as `topic.question` by decompose_query

# Structured result of query decomposition; this is the model handed to
# PydanticOutputParser, so the chat model's reply is parsed into it.
class DecomposedQuery(BaseModel):
    topics: List[Topic]  # the topics found in the query, one Topic each

In [17]:
class _GraphInputs(TypedDict):
    """Keys the caller must provide when invoking the graph."""

    # Inputs
    query: str
    documents: List[Dict]   # [{"doc_id": str, "text": str}]


class GraphState(_GraphInputs, total=False):
    """State dictionary passed from node to node.

    ``query`` and ``documents`` (inherited) are required. The keys declared
    here are optional because they do not exist in the initial state: each
    one is added by a node as the pipeline runs.
    """

    # Intermediate
    standalone_questions: List[str]
    chunks: List[Dict]      # [{"doc_id", "chunk_id", "text"}]

    # Outputs
    query_embeddings: List[List[float]]
    chunk_embeddings: List[List[float]]

In [20]:
class LMStudioBgeM3Dense:
    """Small HTTP client that requests embeddings from ``{base_url}/embeddings``.

    Args:
        base_url: Root URL of the API, e.g. ``http://127.0.0.1:1234/v1``.
            A trailing slash is tolerated.
        model: Model identifier sent with every request.
        timeout: Seconds passed to ``requests.post`` as its timeout, so a
            stalled server raises instead of hanging the kernel forever.
            Pass ``None`` to wait indefinitely.
    """

    def __init__(self, base_url, model, timeout=300.0):
        # Strip a trailing slash so the URL never contains "//embeddings".
        self.url = f"{base_url.rstrip('/')}/embeddings"
        self.model = model
        self.timeout = timeout

    def embed_documents(self, texts):
        """Embed a batch of texts with a single request.

        Args:
            texts: List of raw strings to embed.

        Returns:
            A list with one embedding per input, in input order. An empty
            input yields an empty list without contacting the server.

        Raises:
            requests.HTTPError: If the server answers with an error status.
        """
        if not texts:
            return []

        response = requests.post(
            self.url,
            json={
                "model": self.model,
                "input": texts  # MUST be raw strings
            },
            timeout=self.timeout,
        )
        response.raise_for_status()

        items = response.json()["data"]
        # Each item may carry an "index" naming the input it belongs to; sort on
        # it so results line up with `texts`. The sort is stable, so items
        # without an index keep the order the server sent them in.
        items = sorted(items, key=lambda item: item.get("index", 0))
        return [item["embedding"] for item in items]

    def embed_query(self, text):
        """Embed one string and return its embedding."""
        return self.embed_documents([text])[0]

In [21]:
# Root URL of the local OpenAI-compatible server; the chat client and the
# embedding client both talk to it.
LM_STUDIO_BASE_URL = "http://127.0.0.1:1234/v1"

# Chat model used for query decomposition. temperature=0.0 is the lowest
# sampling temperature; the api_key is a placeholder string.
llm = ChatOpenAI(
    model="meta-llama-3.1-8b-instruct",
    base_url=LM_STUDIO_BASE_URL,
    api_key="not-needed",
    temperature=0.0,
)

# Embedding client for questions and document chunks.
embedding_model = LMStudioBgeM3Dense(
    model="text-embedding-bge-m3",
    base_url=LM_STUDIO_BASE_URL,
)

# Turns the chat model's reply into a DecomposedQuery instance.
parser = PydanticOutputParser(pydantic_object=DecomposedQuery)

In [22]:
# Decomposition prompt. The system message carries the rules plus the output
# parser's format instructions, so the model is told which JSON shape the parser
# expects instead of having to guess it. `format_instructions` is pre-filled via
# `.partial`, so callers still supply only `query`.
prompt = ChatPromptTemplate.from_messages([
    ("system", """
You are a query decomposition agent.

Rules:
- Split the query into distinct semantic topics
- Rewrite each topic as a standalone question
- One topic per question
- No overlap
- Do not add new information
- Output ONLY valid JSON

{format_instructions}
"""),
    ("human", "Query:\n{query}")
]).partial(format_instructions=parser.get_format_instructions())

In [23]:
def decompose_query(state: GraphState) -> GraphState:
    """Graph node: rewrite ``state["query"]`` as standalone questions.

    Pipes the prompt into the chat model and the output parser, then returns a
    copy of ``state`` whose ``standalone_questions`` holds the ``question`` of
    every parsed topic.
    """
    decomposition = (prompt | llm | parser).invoke({"query": state["query"]})

    updated = dict(state)
    updated["standalone_questions"] = [
        topic.question for topic in decomposition.topics
    ]
    return updated


def chunk_documents(state: GraphState) -> GraphState:
    """Graph node: split every document in ``state["documents"]`` into chunks.

    Each chunk is recorded as ``{"doc_id", "chunk_id", "text"}``, where
    ``chunk_id`` is the chunk's position within its own document (numbering
    restarts at 0 for every document). Returns a copy of ``state`` with the
    collected list stored under ``chunks``.
    """
    text_splitter = RecursiveCharacterTextSplitter(
        separators=["\n\n", "\n", ". ", " "],
        chunk_size=800,
        chunk_overlap=100,
    )

    collected = []
    for document in state["documents"]:
        source_id, body = document["doc_id"], document["text"]
        collected.extend(
            {"doc_id": source_id, "chunk_id": position, "text": piece}
            for position, piece in enumerate(text_splitter.split_text(body))
        )

    updated = dict(state)
    updated["chunks"] = collected
    return updated


def embed_queries(state: GraphState) -> GraphState:
    """Graph node: embed the standalone questions.

    Passes ``state["standalone_questions"]`` to
    ``embedding_model.embed_documents`` as one batch and returns a copy of
    ``state`` with the result stored under ``query_embeddings``.
    """
    questions = state["standalone_questions"]
    vectors = embedding_model.embed_documents(questions)

    updated = dict(state)
    updated["query_embeddings"] = vectors
    return updated


def embed_documents(state: GraphState) -> GraphState:
    """Graph node: embed the text of every chunk.

    Collects the ``text`` of each entry in ``state["chunks"]``, passes the list
    to ``embedding_model.embed_documents`` as one batch, and returns a copy of
    ``state`` with the result stored under ``chunk_embeddings``.
    """
    chunk_texts = []
    for chunk in state["chunks"]:
        chunk_texts.append(chunk["text"])

    vectors = embedding_model.embed_documents(chunk_texts)

    updated = dict(state)
    updated["chunk_embeddings"] = vectors
    return updated

In [24]:
graph = StateGraph(GraphState)

# Node names and functions, listed in the order they run. The pipeline is a
# straight line: each node hands its state to the one after it.
_pipeline = [
    ("decompose_query", decompose_query),
    ("chunk_documents", chunk_documents),
    ("embed_queries", embed_queries),
    ("embed_documents", embed_documents),
]

for _node_name, _node_fn in _pipeline:
    graph.add_node(_node_name, _node_fn)

_order = [name for name, _ in _pipeline]

# The first node is the entry point, consecutive nodes are linked by an edge,
# and the last node is the finish point.
graph.set_entry_point(_order[0])
for _source, _target in zip(_order, _order[1:]):
    graph.add_edge(_source, _target)
graph.set_finish_point(_order[-1])

app = graph.compile()

In [27]:
# Demo run: decompose a sample query on its own, then run the compiled graph on
# the same initial state and report how many chunks/embeddings came out.
if __name__ == "__main__":
    initial_state = {
        "query": "Explain CRISPR, its ethical concerns, and its use in cancer treatment.",
        # NOTE: the sample document below is disabled, so "documents" is an empty
        # list and the chunking / chunk-embedding nodes have nothing to process
        # (the run reports 0 chunks and 0 chunk embeddings). Uncomment it to
        # exercise the document path.
        "documents": [
            # {
            #     "doc_id": "paper_1",
            #     "text": (
            #         "CRISPR is a powerful gene-editing technology that allows scientists "
            #         "to modify DNA with high precision. Ethical concerns include germline "
            #         "editing, unintended off-target effects, and unequal access. "
            #         "In cancer treatment, CRISPR is being explored for immunotherapy, "
            #         "including engineered T-cells."
            #     )
            # }
        ]
    }
    
    # Run the decomposition chain by itself first so its output can be inspected
    # before the full graph executes. The graph's "decompose_query" node builds
    # and invokes the same chain again, so the chat model is called twice for
    # this query.
    chain = prompt | llm | parser
    result = chain.invoke({"query": initial_state["query"]})

    print("\nStandalone Questions:")
    for q in result.topics:
        print("-", q.question)
    
    # Run the full pipeline; final_state is what the compiled graph returns and
    # is expected to contain the keys read below.
    final_state = app.invoke(initial_state)
    print(f"\nNumber of chunks: {len(final_state['chunks'])}")
    print(f"Number of query embeddings: {len(final_state['query_embeddings'])}")
    print(f"Number of chunk embeddings: {len(final_state['chunk_embeddings'])}")


Standalone Questions:
- What is CRISPR?
- What are the ethical concerns surrounding CRISPR?
- How is CRISPR used in cancer treatment?

Number of chunks: 0
Number of query embeddings: 3
Number of chunk embeddings: 0
