In [None]:
import re
from typing import List, Optional

from langchain.chains import LLMChain
from langchain.prompts import ChatPromptTemplate
from langchain.retrievers import MultiQueryRetriever
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document
from langchain_openai import ChatOpenAI, OpenAIEmbeddings

# Embedding client. It is not referenced by the code visible in this cell;
# presumably it is used elsewhere to build the FAISS index -- TODO confirm.
embeddings = OpenAIEmbeddings()

# Chat model shared by the classifier chain, the decomposer chain and
# get_mqr_retriever below.
llm = ChatOpenAI(
    model="gpt-4o-mini",
)

# ---------------- Step 1: Query Classifier ---------------- #
# Prompt for routing a query to one of three retrieval strategies.
# The caller matches the answer by substring, so the prompt asks for the bare
# label only: a verbose answer such as 'not ambiguous, it is "normal"' would
# otherwise be routed to the wrong branch.
classifier_prompt = ChatPromptTemplate.from_template("""
You are a query classifier.
Classify the user query into exactly one of:
- "ambiguous": the query is vague or could reasonably be read in several ways
- "multi_intent": the query asks for two or more distinct things
- "normal": a single, clear request

Respond with only the label, in lowercase, and nothing else.

Query: {query}
""")

# Chain that takes {"query": ...} and returns the model's label text.
classifier = LLMChain(llm=llm, prompt=classifier_prompt)

# ---------------- Step 2a: MultiQueryRetriever ---------------- #
def get_mqr_retriever(vectorstore, k: Optional[int] = None):
    """Wrap a vector store in a MultiQueryRetriever driven by the shared ``llm``.

    Args:
        vectorstore: Vector store exposing ``as_retriever()``.
        k: Optional number of documents the underlying retriever fetches for
            each generated query. When ``None`` (the default) the vector
            store's own default is used, exactly as before.

    Returns:
        A ``MultiQueryRetriever`` built on top of ``vectorstore``.
    """
    if k is None:
        base_retriever = vectorstore.as_retriever()
    else:
        base_retriever = vectorstore.as_retriever(search_kwargs={"k": k})
    return MultiQueryRetriever.from_llm(retriever=base_retriever, llm=llm)

# ---------------- Step 2b: Multi-Intent Decomposer ---------------- #
# Prompt for splitting a multi-intent query into stand-alone sub-queries.
# The caller splits the answer on newlines and searches with each line as-is,
# so the prompt asks for plain lines without bullets, numbering or commentary.
decomposer_prompt = ChatPromptTemplate.from_template("""
Split the following query into independent sub-queries, one for each distinct
information need. If the query has a single intent, return it unchanged.
Return one sub-query per line, with no numbering, bullets, or extra commentary.

Query: {query}
""")

# Chain that takes {"query": ...} and returns the newline-separated sub-queries.
decomposer = LLMChain(llm=llm, prompt=decomposer_prompt)

# ---------------- Pipeline Function ---------------- #
# Leading list markers a model tends to add to each line ("- foo", "* foo",
# "1. foo", "2) foo"). Whitespace is required after the marker so that queries
# such as "3.5 GHz CPUs" or "-5 degrees" are left intact.
_LIST_MARKER_RE = re.compile(r"^(?:[-*\u2022]|\d+[.)])\s+")


def _parse_sub_queries(raw: str) -> List[str]:
    """Turn the decomposer's newline-separated answer into clean sub-queries."""
    stripped = (_LIST_MARKER_RE.sub("", line.strip()).strip() for line in raw.splitlines())
    return [sub_query for sub_query in stripped if sub_query]


def _dedupe_by_content(docs: List[Document]) -> List[Document]:
    """Drop documents whose ``page_content`` was already seen, keeping first-seen order."""
    seen = set()
    unique = []
    for doc in docs:
        if doc.page_content not in seen:
            seen.add(doc.page_content)
            unique.append(doc)
    return unique


def retrieve_with_flex(query: str, vectorstore: FAISS, k: int = 3) -> List[Document]:
    """Retrieve documents with a strategy chosen by classifying the query.

    The query is labelled by ``classifier`` and then handled as follows:

    * ``ambiguous``    -> multi-query retrieval via ``get_mqr_retriever``
    * ``multi_intent`` -> split by ``decomposer``, one similarity search per
      sub-query
    * anything else    -> a single similarity search

    Args:
        query: The user query.
        vectorstore: Vector store to search.
        k: Number of documents fetched per similarity search. It is not
            applied on the ``ambiguous`` path, where the multi-query
            retriever uses the vector store's default.

    Returns:
        The retrieved documents, de-duplicated by ``page_content`` in
        first-seen order.
    """
    # Step 1: classify. Hyphens and spaces are folded to underscores so that
    # answers like "multi-intent" or "Multi Intent" still select the right branch.
    label = classifier.run({"query": query}).strip().lower()
    label = label.replace("-", "_").replace(" ", "_")

    if "ambiguous" in label:
        # Step 2a: let the LLM rephrase the query several ways and merge the hits.
        docs = get_mqr_retriever(vectorstore).invoke(query)
    elif "multi_intent" in label:
        # Step 2b: search once per sub-query. Fall back to the original query
        # if the decomposer returned nothing usable.
        sub_queries = _parse_sub_queries(decomposer.run({"query": query})) or [query]
        docs = []
        for sub_query in sub_queries:
            docs.extend(vectorstore.similarity_search(sub_query, k=k))
    else:
        # Step 2c: normal single query.
        docs = vectorstore.similarity_search(query, k=k)

    # Step 3: aggregate (dedup by content).
    return _dedupe_by_content(docs)


In [None]:
from langchain_openai import ChatOpenAI
from langchain.retrievers.multi_query import MultiQueryRetriever
from langchain_core.retrievers import BaseRetriever

# Never hard-code API keys in source. With no `api_key` argument, ChatOpenAI
# reads the key from the OPENAI_API_KEY environment variable.
# NOTE: the key that was previously inlined on this line has been exposed and
# should be revoked and replaced.
llm = ChatOpenAI(model='gpt-4o-mini')

# Stand-in retriever: MultiQueryRetriever.from_llm requires a retriever, but it
# is never called when only the query-generation step is used.
class DummyRetriever(BaseRetriever):
    """Retriever that always returns no documents.

    Implements the ``_get_relevant_documents`` / ``_aget_relevant_documents``
    hooks that ``BaseRetriever`` expects subclasses to provide, rather than
    overriding the public ``get_relevant_documents`` wrappers.
    """

    def _get_relevant_documents(self, query, *, run_manager):
        """Return an empty result list for any query."""
        return []

    async def _aget_relevant_documents(self, query, *, run_manager):
        """Async counterpart of ``_get_relevant_documents``; also returns an empty list."""
        return []

from langchain_core.callbacks import CallbackManagerForRetrieverRun

# Only query generation is exercised here, so the wrapped retriever is a stub.
mqr = MultiQueryRetriever.from_llm(retriever=DummyRetriever(), llm=llm)

# `generate_queries` is the public method (there is no `_generate_queries`).
# It requires a run manager; a no-op one is sufficient outside a retrieval run.
queries = mqr.generate_queries(
    "Tell me about ML and DL",
    run_manager=CallbackManagerForRetrieverRun.get_noop_manager(),
)
print(queries)

ValidationError: 1 validation error for MultiQueryRetriever
retriever
  Input should be a valid dictionary or instance of BaseRetriever [type=model_type, input_value=<__main__.DummyRetriever ...t at 0x00000277E9040050>, input_type=DummyRetriever]
    For further information visit https://errors.pydantic.dev/2.11/v/model_type