# Query Routing in RAG

Query routing in Retrieval-Augmented Generation (RAG) applications directs each user query to the data source or module best suited to answer it. In a system with multiple databases, language models, or specialized APIs, the router analyzes the query's intent and context and then selects a processing path: a specific knowledge base for factual questions, a conversational model for dialogue-based queries, or a specialized API for niche information. Good routing improves the accuracy and efficiency of retrieval, because each query reaches only the components that can actually help with it, and it spreads the computational load across sources. The result is more precise, relevant, and timely responses from a system that stays robust as more sources are added.
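
As a rough illustration of the "analyze, then dispatch" idea, here is a minimal, dependency-free sketch. The handler functions, the route names, and the keyword rules are all hypothetical stand-ins; a real router would typically use an LLM or embeddings for the classification step, as the sections below do.

```python
from typing import Callable, Dict


# Hypothetical handlers standing in for real components (vector stores, LLMs, APIs).
def answer_from_knowledge_base(query: str) -> str:
    return f"[knowledge base] {query}"


def answer_from_weather_api(query: str) -> str:
    return f"[weather API] {query}"


def answer_with_chat_model(query: str) -> str:
    return f"[chat model] {query}"


# Route name -> handler. Adding a source means adding one entry here.
ROUTES: Dict[str, Callable[[str], str]] = {
    "knowledge_base": answer_from_knowledge_base,
    "weather_api": answer_from_weather_api,
    "chat": answer_with_chat_model,
}


def classify(query: str) -> str:
    """Toy keyword-based intent classifier (illustrative only)."""
    text = query.lower()
    if "weather" in text or "forecast" in text:
        return "weather_api"
    if "capital" in text or "population" in text:
        return "knowledge_base"
    return "chat"  # fallback when no rule matches


def route(query: str) -> str:
    """Pick a route for the query and run the matching handler."""
    return ROUTES[classify(query)](query)
```

Keeping classification separate from dispatch means the classifier can later be swapped for an LLM or an embedding comparison without touching the handlers.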

In [74]:
!pip install langchain-core langchain langchain-community langchain-openai pypdf langchainhub chromadb

In [75]:
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain_core.output_parsers import StrOutputParser
from langchain import hub
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
from langchain_openai import ChatOpenAI, OpenAIEmbeddings

# Logical Routing in RAG

In logical routing, we let the LLM (Large Language Model) choose the best route from a set of predefined options. To implement this, we describe the routing decision as a Pydantic model and ask the LLM to return its answer in that structure. The `QueryRouter` model has two fields: `datasource` and `question`. The `datasource` field says where the query should go; its allowed values name the vector stores created below (`usa` and `canada`), plus `general` as a fallback that sends the query directly to the LLM. The `question` field carries the user's query through to the selected route. Constraining the LLM to a fixed set of labels keeps the routing step simple and makes its output easy to act on in code.
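
To show what a fixed label set buys, here is a small standard-library sketch of the same constraint. It is not the Pydantic model used in this notebook: `RouteDecision` and `parse_decision` are hypothetical names, and falling back to `general` on an unknown label is an assumed design choice (a Pydantic model with a `Literal` field would normally raise a validation error instead).

```python
from dataclasses import dataclass
from typing import Literal, get_args

# The only labels a routing decision may carry.
Datasource = Literal["usa", "canada", "general"]


@dataclass(frozen=True)
class RouteDecision:
    datasource: str
    question: str

    def __post_init__(self) -> None:
        # Reject any label outside the predefined set.
        if self.datasource not in get_args(Datasource):
            raise ValueError(f"unknown datasource: {self.datasource!r}")


def parse_decision(raw: dict) -> RouteDecision:
    """Build a decision from raw model output, using 'general' for unusable labels."""
    label = str(raw.get("datasource", "")).strip().lower()
    if label not in get_args(Datasource):
        label = "general"  # assumed fallback policy, not part of the original notebook
    return RouteDecision(datasource=label, question=raw["question"])
```

Because every decision carries one of three known labels, the downstream code can branch on `datasource` without guarding against free-form text.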

In [76]:
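# Assumption: this notebook runs in Google Colab, where `userdata` exposes stored secrets.
from google.colab import userdata

api_key = userdata.get("OPENAI")

def generate_vectorstores(file_path, persist_dir):
    # Load the PDF (one Document per page)
    loader = PyPDFLoader(file_path)
    documents = loader.load()

    # Split pages into small, slightly overlapping chunks for retrieval
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=20)
    text_chunks = text_splitter.split_documents(documents)

    # Embed the chunks and store them in a Chroma collection on disk
    vectorstore = Chroma.from_documents(documents=text_chunks,
                                        embedding=OpenAIEmbeddings(api_key=api_key),
                                        persist_directory=persist_dir)
    # Only needed on older Chroma versions; recent ones persist automatically
    vectorstore.persist()
    return vectorstore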


# Create a vectorstore to answer questions about the U.S.A
vectorstore_usa = generate_vectorstores("../docs/us_states_info.pdf","vector_stores/vectorstore_usa")

# Create a vectorstore to answer questions about Canada
vectorstore_canada = generate_vectorstores("../docs/canada_provinces_info.pdf","vector_stores/vectorstore_canada")

In [7]:
retriever_usa = vectorstore_usa.as_retriever(search_kwargs={"k": 3})
retriever_canada = vectorstore_canada.as_retriever(search_kwargs={"k": 3})

In [8]:
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.prompts import ChatPromptTemplate
from typing import Literal

class QueryRouter(BaseModel):
    datasource: Literal["usa", "canada", "general"] = Field(..., description="Given a user question choose which datasource would be most relevant for answering their question")
    question: str = Field(..., description="User question to be routed to the appropriate datasource")

chat = ChatOpenAI(model="gpt-4",api_key=api_key, temperature=0)
structured_chat = chat.with_structured_output(QueryRouter)

router_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "You are an expert router that can direct user queries to the appropriate datasource. Route the following user question to the appropriate datasource.\nIf it is a general question not related to the provided datasources, route it to the general datasource.\n"),
        ("user", "{question}")
    ]
)

In [9]:
router = (
      router_prompt
    | structured_chat
)

In [10]:
question = "How is Toronto like?"
result = router.invoke({"question": question})
print(result)

datasource='canada' question='How is Toronto like?'


In [34]:
qa_prompt = hub.pull('rlm/rag-prompt')

In [35]:
print(result.question)

How is Toronto like?


In [36]:
print(qa_prompt.input_variables)

['context', 'question']


In [54]:
def choose_route(result):
  chat_model = ChatOpenAI(model="gpt-4",api_key=api_key, temperature=0)
  if result.datasource == "usa":
    usa_chain = (
        {'context': retriever_usa, 'question':RunnablePassthrough()}
        | qa_prompt
        | chat_model
        | StrOutputParser()
    )
    return usa_chain.invoke(result.question)

  elif result.datasource == "canada":
    canada_chain = (
        {'context': retriever_canada, 'question':RunnablePassthrough()}
        | qa_prompt
        | chat_model
        | StrOutputParser()
    )
    return canada_chain.invoke(result.question)
  else:
    general_chain = chat_model | StrOutputParser()
    return general_chain.invoke(result.question)

In [55]:
logical_router_chain = router | RunnableLambda(choose_route)

In [60]:
question = "How is Seattle like?"
logical_router_chain.invoke({"question": question})

'The provided context does not contain information about Seattle.'

# Semantic Routing in RAG

Sometimes we maintain multiple indexes for different domains and query only the subset that is relevant to a given question. For instance, imagine one vector store index containing detailed information about U.S. states and another dedicated to Canadian provinces and territories. When a question about a particular city arrives, we first need to determine which state or province, and therefore which index, the question pertains to, and then retrieve the relevant documents from that index. Deciding which index or subset of indexes to query based on the query's context and content is known as query routing. It lets us retrieve the most relevant information efficiently by classifying the query and directing it to the correct data source.

Semantic routing relies on the degree of semantic similarity between the user query and the predefined router prompts to determine the most appropriate path for processing the query. This method involves analyzing the meaning and context of the user's input and comparing it to a set of potential responses or actions defined by the router. By identifying the closest match in terms of semantic content, the system can select the optimal route for handling the query. Let's explore how to implement this approach for a RAG application.
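
Before using real embeddings, the selection step can be sketched with plain Python. The three-dimensional vectors below are made up for illustration, and `pick_route` with its `min_score` threshold is a hypothetical helper, not part of LangChain; the notebook's own router simply takes the highest-scoring route with no threshold.

```python
import math
from typing import Dict, List, Optional


def cosine(a: List[float], b: List[float]) -> float:
    """Cosine similarity of two equal-length vectors (0.0 if either is all zeros)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0
    return dot / (norm_a * norm_b)


def pick_route(
    query_vec: List[float],
    route_vecs: Dict[str, List[float]],
    min_score: float = 0.0,
) -> Optional[str]:
    """Return the route most similar to the query, or None if below min_score."""
    best_name, best_score = None, -1.0
    for name, vec in route_vecs.items():
        score = cosine(query_vec, vec)
        if score > best_score:
            best_name, best_score = name, score
    return best_name if best_score >= min_score else None


# Made-up embeddings: real ones would come from an embedding model.
toy_routes = {"usa": [1.0, 0.1, 0.0], "canada": [0.1, 1.0, 0.0]}
```

Here `pick_route([0.9, 0.2, 0.0], toy_routes)` selects `"usa"`, while a query vector unrelated to both routes, such as `[0.0, 0.0, 1.0]` with `min_score=0.5`, yields `None`, which a caller could treat as a signal to use a general-purpose prompt.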



In [None]:
from langchain_community.utils.math import cosine_similarity
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
from langchain_openai import OpenAIEmbeddings

In [61]:
us_template = """You are a very good U.S regional expert. \
               You are great at answering questions about the U.S. \
               When you don't know the answer to a question simply say I don't know.

               Here is a question:
               {query}
               """

In [62]:
canada_template = """You are a very good Canada regional expert. \
               You are great at answering questions about Canada. \
               When you don't know the answer to a question simply say I don't know.

               Here is a question:
               {query}
               """

In [64]:
embeddings = OpenAIEmbeddings(api_key=api_key)
routes = [us_template, canada_template]
route_embeddings = embeddings.embed_documents(routes)
len(route_embeddings)

2

In [71]:
from langchain_community.utils.math import cosine_similarity
from langchain_core.prompts import PromptTemplate

# Note: this reuses the name `router`, which referred to the logical routing chain above.
def router(input):
    # Embed the user query
    query_embedding = embeddings.embed_query(input['query'])
    # Cosine similarity between the query and each route prompt (one score per route)
    similarity = cosine_similarity([query_embedding], route_embeddings)[0]
    # Index of the route with the highest similarity score
    route_id = similarity.argmax()
    if route_id == 0:
        print(f"> Asking a question about U.S.A...\nQuestion: {input['query']}\nAnswer:")
    else:
        print(f"> Asking a question about Canada...\nQuestion: {input['query']}\nAnswer:")

    return PromptTemplate.from_template(routes[route_id])


In [73]:
chat_model = ChatOpenAI(model="gpt-4",api_key=api_key, temperature=0)
semantic_router_chain = (
    {'query': RunnablePassthrough()}
    | RunnableLambda(router)
    | chat_model
    | StrOutputParser()
)

semantic_router_chain.invoke("How is Seattle like?")

> Asking a question about U.S.A...
Question: How is Seattle like?
Answer:


"Seattle, located in the Pacific Northwest region of the U.S., is known for its lush evergreen forests and beautiful waterways. It's the largest city in the state of Washington and is home to a diverse population. The city has a reputation for being rainy, although it actually receives less annual rainfall than many other U.S. cities, but the rain is spread out over many days, leading to the perception of constant drizzle.\n\nSeattle is a hub for technology and innovation, with companies like Microsoft and Amazon headquartered there. It's also known for its strong music scene, particularly grunge music, with bands like Nirvana and Pearl Jam originating from the city. The city's landmarks include the futuristic Space Needle, Pike Place Market, and the stunning Chihuly Garden and Glass museum.\n\nThe city has a strong coffee culture, with Starbucks originating there, and numerous other local coffee shops. Seattle also has a vibrant food scene, with a focus on local and sustainable ingred