# RAG Routing

In this notebook there'll be covered the following **routing options**:

1. **Completion Routers** - LLM Completion Routers use an LLM completion call to return a single word that best describes the query from a list of word options provided in the prompt. This word is then used as part of an If/Else condition to control the application's flow.

2. **Function Calling Routers** - LLM Function Calling Routers leverage the function-calling ability of LLMs to pick a route to traverse. Routes are set up as functions with appropriate descriptions, and based on the query, the LLM returns the correct function to use.

3. **Semantic Routers** - Semantic Routers use embeddings and similarity searches to select the best route. Each route has associated example queries that are embedded and stored as vectors; the incoming query is embedded, and a similarity search determines the closest match.

4. **Zero Shot Classification Routers** - Zero Shot Classification Routers use a Zero-Shot Classification model to assign a label to a piece of text from a predefined set of labels. They can classify new examples from previously unseen classes, making them versatile for various queries.

5. **Language Classification Routers** - Language Classification Routers identify the language of the query and route it accordingly. They are useful for applications requiring multilingual parsing capabilities.

6. **Keyword Routers** - Keyword Routers select a route by matching keywords between the query and predefined route lists. They can be powered by LLMs or other keyword matching libraries.

7. **Logical Routers** - Logical Routers use logic checks against variables such as string lengths, file names, and value comparisons to handle query routing. They rely on existing and discrete variables rather than natural language understanding.

One day maybe I'll add some pretty graphics here ;)

In [89]:
# import modules
import os
from langchain_openai import AzureChatOpenAI
from langchain.embeddings.sentence_transformer import SentenceTransformerEmbeddings
from langchain_community.vectorstores import Chroma
from langchain.document_loaders import PyPDFLoader
from langchain.text_splitter import NLTKTextSplitter
from langchain_core.output_parsers import StrOutputParser
from langchain.prompts import PromptTemplate
from langchain_community.chat_models import ChatOllama
import json
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from transformers import AutoTokenizer, AutoModel
import torch
from sentence_transformers import SentenceTransformer

## 1. Completion Router

In [2]:
# design prompt

prompt = PromptTemplate(
    template="""

    You are a brilliant assistant who's exceptional in classification tasks.
    Your main task is to classify user's query below as either being about `Coffee`, `Tee`, `Soft Drinks`, `Alcoholic Drinks` or `Other`.

    Do not respond with more than one word.

    <user query>
    {user_query}
    </user query>

    Classification:
    """,
    input_variables=["user_query"],
)

In [16]:
user_query = "Where can I find kenyan K7 or Ruiru 11 sorts?" # K7 and Ruiru 11 are popular kenyan coffee sorts

In [17]:
# complete router

llama = ChatOllama(model="llama3", temperature=0)

completion_route_chain = prompt | llama | StrOutputParser()

input_data = {
    "user_query": user_query
}

route = completion_route_chain.invoke(input=input_data)
print(f"ROUTE: {route}")

ROUTE: Coffee


## 2. Function Calling Router

In [19]:
# To be added soon

## 3. Semantic Router

In [34]:
# use either semantic_router library or create a custom Route class from the one below

emb_model = "sentence-transformers/all-MiniLM-L6-v2"

class Route:
    def __init__(self, name, utterances, embedding_model_name=emb_model):
        self.name = name
        self.utterances = utterances
        self.embedding_model_name = embedding_model_name
        self.tokenizer = AutoTokenizer.from_pretrained(embedding_model_name)
        self.model = AutoModel.from_pretrained(embedding_model_name)
        self.embeddings = self._embed_utterances(utterances)

    def _embed_utterances(self, utterances):
        # tokenize utterances
        tokens = self.tokenizer(utterances, padding=True, truncation=True, return_tensors="pt")
        # get embeddings
        with torch.no_grad():
            embeddings = self.model(**tokens).last_hidden_state.mean(dim=1).numpy()
        return embeddings

def embed_query(query, embedding_model_name='sentence-transformers/all-MiniLM-L6-v2'):
    tokenizer = AutoTokenizer.from_pretrained(embedding_model_name)
    model = AutoModel.from_pretrained(embedding_model_name)
    tokens = tokenizer(query, return_tensors="pt")
    with torch.no_grad():
        embedding = model(**tokens).last_hidden_state.mean(dim=1).numpy()
    return embedding

def find_best_route(query, routes):
    query_embedding = embed_query(query)
    best_match_route = None
    highest_similarity = -1
    
    for route in routes:
        similarities = cosine_similarity(query_embedding, route.embeddings).flatten()
        max_similarity = np.max(similarities)
        
        if max_similarity > highest_similarity:
            highest_similarity = max_similarity
            best_match_route = route
            
    return best_match_route

# example routing
fishing = Route(
    name="fishing",
    utterances=[
        "What's the best bait for catching bass?",
        "Do you prefer freshwater or saltwater fishing?",
        "What's your favorite fishing spot?",
        "Have you ever caught a really big fish?",
        "Any tips for a beginner fisherman?",
    ],
)

hunting = Route(
    name="hunting",
    utterances=[
        "What's the best time of year for deer hunting?",
        "Do you use a bow or a rifle?",
        "What's your most memorable hunting trip?",
        "How do you track game in the wild?",
        "Any tips for staying safe while hunting?",
        "Ducks hunting tips"
    ],
)

camping = Route(
    name="camping",
    utterances=[
        "What's your favorite camping spot?",
        "Do you prefer tents or RVs for camping?",
        "How do you make a campfire?",
        "What's your go-to camping meal?",
        "Any tips for a first-time camper?",
    ],
)

routes = [fishing, hunting, camping]

query = "I am looking for a sea near-shore location for hunting ducks"
best_route = find_best_route(query, routes)
print(f"THE BEST ROUTE: {best_route.name}")

THE BEST ROUTE: hunting


## 4. Zero Shot Classification Router

The implementation can be found on the Haystack GitHub [here](https://github.com/deepset-ai/haystack/blob/main/haystack/components/routers/zero_shot_text_router.py#L130) 🙃

## 5. Language Classification Router

Practically, there are two options how to establish routing based on multiple languages.

- **Option 1**: Utilize external services for language detection (e.g. Azure Speech)
- **Option 2**: Do the translation and routing via Prompt Engineering (example below)

In [37]:
# design prompt

prompt = PromptTemplate(
    template="""

    You are a brilliant assistant who's exceptional in language identification tasks.
    Your main task is to identify the language of the user's query below and respond using one of the ISO 639 langauge codes.

    Do not respond with more than one word.

    <ISO codes>
    {iso_codes}
    </ISO codes>

    <user query>
    {user_query}
    </user query>

    Language:
    """,
    input_variables=["iso_codes", "user_query"],
)

iso_639_languages = {
    "English": "en",
    "Mandarin Chinese": "zh",
    "Hindi": "hi",
    "Spanish": "es",
    "French": "fr",
    "German": "de",
    "Standard Arabic": "ar",
    "Bengali": "bn",
    "Portuguese": "pt",
    "Russian": "ru",
    "Japanese": "ja"
}

In [38]:
# query

query = "Was macht man am Freitag Abend in Berlin?" # german

In [39]:
# language router

llama = ChatOllama(model="llama3", temperature=0)

completion_route_chain = prompt | llama | StrOutputParser()

input_data = {
    "iso_codes": iso_639_languages,
    "user_query": query   
}

route = completion_route_chain.invoke(input=input_data)
print(f"ROUTE: {route}")

ROUTE: de


## 6. Keyword Router

A keyword router will select a route by matching **keywords** between the **user's query** and **routes list**. In some specific use cases, we only need a couple of keywords to route the query to a specific module or handler. 

Why do we need to make extra LLM calls, if we can save some **latency** and **extra money**?!

In [83]:
# OPTION 1: simple keyword router

class KeywordRouter:
    def __init__(self, routes):
        self.routes = routes

    def find_keyword_route(self, query):
        query_lower = query.lower()
        for route, keywords in self.routes.items():
            if any(keyword in query_lower for keyword in keywords):
                return route
        return "default"

# define routes --> better descriptions = better routing
routes = {
    "web": ["html", "css", "javascript", "web", "website", "frontend", "backend"],
    "blockchain": ["blockchain", "crypto", "bitcoin", "ethereum", "smart contract", "decentralized"],
    "opensource": ["open-source", "open source", "github", "git", "contribution", "license"],
}

user_query = "How to be a frontend developer?"

keyword_router = KeywordRouter(routes=routes)

route = keyword_router.find_keyword_route(query=user_query)
print(f"ROUTE: {route}")

ROUTE: web


In [44]:
# OPTION 2: keyword router w/ retrieval --> 1st step is to create retrievers (simulation of multiple routes --> web/blockchain/opensource)

emb_model = SentenceTransformerEmbeddings(model_name="thenlper/gte-large")

In [62]:
# preprocessing

data_folder = "../data/rag-routing"
documents = {
    "blockchain": os.path.join(data_folder, "blockchain.pdf"),
    "opensource": os.path.join(data_folder, "opensource.pdf"),
    "web": os.path.join(data_folder, "web.pdf")
}

loaders = {name: PyPDFLoader(path) for name, path in documents.items()}
docs_content = {name: loader.load() for name, loader in loaders.items()}
text_splitter = NLTKTextSplitter()
chunked_docs = {name: text_splitter.split_documents(content) for name, content in docs_content.items()}
print(f"TOTAL CHUNKS FOR WEB: {len(chunked_docs.get('web'))}")
print(f"TOTAL CHUNKS FOR OPENSOURCE: {len(chunked_docs.get('opensource'))}")
print(f"TOTAL CHUNKS FOR BLOCKCHAIN: {len(chunked_docs.get('blockchain'))}")

TOTAL CHUNKS FOR WEB: 49
TOTAL CHUNKS FOR OPENSOURCE: 15
TOTAL CHUNKS FOR BLOCKCHAIN: 35


In [172]:
for name, chunks in chunked_docs.items():
    print(chunks)

[Document(page_content='BlockChain Technology  \nBeyond\xa0Bitcoin \n \nAbstract \n \nAblockchainisessentiallyadistributeddatabaseofrecordsorpublicledgerofalltransactionsor \xa0 \xa0\xa0 \xa0\xa0 \xa0 \xa0\xa0\xa0\xa0\xa0\xa0\xa0\xa0 \xa0\xa0\ndigitaleventsthathavebeenexecutedandsharedamongparticipatingparties.Eachtransactionin \xa0\xa0\xa0\xa0\xa0 \xa0\xa0\xa0\xa0 \xa0\xa0\xa0 \xa0\xa0\nthepublicledgerisverifiedbyconsensusofamajorityoftheparticipantsinthesystem.And,once \xa0\xa0\xa0\xa0\xa0\xa0 \xa0\xa0\xa0 \xa0\xa0\xa0 \xa0\xa0\xa0\xa0\xa0\xa0\nentered,informationcanneverbeerased.Theblockchaincontainsacertainandverifiablerecordof \xa0 \xa0\xa0\xa0\xa0\xa0\xa0 \xa0 \xa0\xa0\xa0\xa0 \xa0\xa0\xa0\neverysingletransactionevermade.Bitcoin,thedecentralizedpeer\xadto\xadpeerdigitalcurrency,isthe \xa0\xa0 \xa0\xa0\xa0\xa0\xa0 \xa0 \xa0\xa0 \xa0\xa0\xa0\nmostpopularexamplethatusesblockchaintechnology.Thedigitalcurrencybitcoinitselfishighly \xa0 \xa0 \xa0\xa0\xa0 \xa0 \xa0\xa0\xa0 \xa0\xa0\xa0\

In [77]:
# create vector stores for each doc --> simulation of different pipelines/handlers/search indexes, etc.

emb_model = SentenceTransformerEmbeddings(model_name="thenlper/gte-large")
vector_stores = {}
for name, chunks in chunked_docs.items():
    db = Chroma.from_documents(documents=chunks, embedding=emb_model)
    vector_stores[name] = db

In [82]:
# create retrievers

retrievers = {name: chroma_db.as_retriever(search_type="mmr", search_kwargs={"k": 3}) for name, chroma_db in vector_stores.items()}

web_retriever = retrievers.get('web') # route 1
blockchain_retriever = retrievers.get('blockchain') # route 2
opensource_retriever = retrievers.get('opensource') # route 3

In [170]:
web_results = web_retriever.get_relevant_documents(user_query)
web_results

[Document(page_content='About DigitalOcean \nDigitalOcean .simplifies .cloud .computing .so .\n\ndevelopers .\n\nand .\n\nbusinesses.\n\ncan\n.\n\nspend\n.\n\nmore\n.\n\ntime\n.\n\nbuilding\n.\n\nsoftware\n.\n\nthat\n.changes .\n\nthe\n.\n\nworld.\n\nWith its \nmission-critical infrastructure and fully managed offerings, DigitalOcean helps \ndevelopers, startups, and small- and medium-sized businesses (SMBs) rapidly build, deploy, and scale applications to accelerate innovation and increase productivity and agility.\n\nDigitalOcean combines the power of simplicity, community, open source, and customer support so customers can spend less time managing their infrastructure and more time building innovative applications that drive business growth.\n\nTo get started, sign .up.for.an.account .at.Digital Ocean.com.\n\nFor more information or help migrating your infrastructure \nto DigitalOcean, speak\n.\n\nto\n.\n\na\n.\n\nsales\n.\n\nrepresentative.', metadata={'page': 14, 'source': '../dat

In [165]:
user_query = " DigitalOcean Opensource models"

embedding_model = SentenceTransformer('thenlper/gte-large')

query_embedding = embedding_model.encode([query])

web_results = web_retriever.get_relevant_documents(user_query)
blockchain_results = blockchain_retriever.get_relevant_documents(user_query)
opensource_results = opensource_retriever.get_relevant_documents(user_query)


web_texts = [doc.page_content for doc in web_results] 
blockchain_texts = [doc.page_content for doc in blockchain_results] 
opensource_texts = [doc.page_content for doc in opensource_results] 

web_embeddings = embedding_model.encode(web_texts)
blockchain_embeddings = embedding_model.encode(blockchain_texts) 
opensource_embeddings = embedding_model.encode(opensource_texts) 

print(web_results)

print(blockchain_results)

# web_sim = cosine_similarity(query_embedding, web_embeddings)
# print("WEB:", web_sim)
# blockchain_sim = cosine_similarity(query_embedding, blockchain_embeddings)
# print("BLOCKCHAIN:", blockchain_sim)
# opensource_sim = cosine_similarity(query_embedding, opensource_embeddings)
# print("OPENSOURCE:", opensource_sim)

[Document(page_content='About DigitalOcean \nDigitalOcean .simplifies .cloud .computing .so .\n\ndevelopers .\n\nand .\n\nbusinesses.\n\ncan\n.\n\nspend\n.\n\nmore\n.\n\ntime\n.\n\nbuilding\n.\n\nsoftware\n.\n\nthat\n.changes .\n\nthe\n.\n\nworld.\n\nWith its \nmission-critical infrastructure and fully managed offerings, DigitalOcean helps \ndevelopers, startups, and small- and medium-sized businesses (SMBs) rapidly build, deploy, and scale applications to accelerate innovation and increase productivity and agility.\n\nDigitalOcean combines the power of simplicity, community, open source, and customer support so customers can spend less time managing their infrastructure and more time building innovative applications that drive business growth.\n\nTo get started, sign .up.for.an.account .at.Digital Ocean.com.\n\nFor more information or help migrating your infrastructure \nto DigitalOcean, speak\n.\n\nto\n.\n\na\n.\n\nsales\n.\n\nrepresentative.', metadata={'page': 14, 'source': '../dat

In [160]:


class RetrieverRouterKeywordEngine:
    def __init__(self, retrievers):
        """Initializes the keyword engine with a dictionary of retrievers.
        
        Args:
            retrievers (dict): A dictionary where keys are retriever names and values are retriever instances.
        """
        self.retrievers = retrievers
        self.retrievers_names = list(retrievers.keys())
        self.embedding_model = SentenceTransformer('thenlper/gte-large')

    def get_retriever(self, query):
        """Determines the best route based on the highest similarity score between query and retrieved documents.
        
        Args:
            query (str): User's query string.

        Returns:
            dict: Contains the name of the best retriever and the documents that led to its selection.
        """
        query_embedding = self.embedding_model.encode([query])
        max_similarity = -1
        best_retriever_key = None
        best_docs = None

        for name, retriever in self.retrievers.items():
            retrieved_docs = retriever.get_relevant_documents(query)
            if not retrieved_docs:
                continue
            
            doc_texts = [doc.page_content for doc in retrieved_docs] 
            doc_embeddings = self.embedding_model.encode(doc_texts) 
            similarities = cosine_similarity(query_embedding, doc_embeddings)
            print(f"Similarity: {similarities}")
            avg_similarity = np.mean(similarities)

            if avg_similarity > max_similarity:
                max_similarity = avg_similarity
                best_retriever_key = name
                best_docs = retrieved_docs

        if best_retriever_key:
            return {
                "retriever": best_retriever_key,
                "documents": best_docs
            }
        else:
            return None

In [161]:
router = RetrieverRouterKeywordEngine(retrievers)

print('ROUTES: ', router.retrievers_names)

# example usage
user_query = """DigitalOcean"""

retriever = router.get_retriever(user_query)
print('OPTIMAL ROUTE:', retriever['retriever'])
for doc in retriever['documents']:
    print(doc)

ROUTES:  ['blockchain', 'opensource', 'web']
Similarity: [[0.91160697 0.8148801  0.79332614]]
Similarity: [[0.91160697 0.8148801  0.79332614]]
Similarity: [[0.91160697 0.8148801  0.79332614]]
OPTIMAL ROUTE: blockchain
page_content='About DigitalOcean \nDigitalOcean .simplifies .cloud .computing .so .\n\ndevelopers .\n\nand .\n\nbusinesses.\n\ncan\n.\n\nspend\n.\n\nmore\n.\n\ntime\n.\n\nbuilding\n.\n\nsoftware\n.\n\nthat\n.changes .\n\nthe\n.\n\nworld.\n\nWith its \nmission-critical infrastructure and fully managed offerings, DigitalOcean helps \ndevelopers, startups, and small- and medium-sized businesses (SMBs) rapidly build, deploy, and scale applications to accelerate innovation and increase productivity and agility.\n\nDigitalOcean combines the power of simplicity, community, open source, and customer support so customers can spend less time managing their infrastructure and more time building innovative applications that drive business growth.\n\nTo get started, sign .up.for.an.ac