# RAG Fusion

In [1]:
# !pip -q install langchain huggingface_hub openai tiktoken pypdf
# !pip -q install google-generativeai chromadb unstructured
# !pip install sentence_transformers
# !pip -q install -U FlagEmbedding


## Download the Data & Utilities

In [2]:
import getpass
import os
import textwrap
import zipfile
from io import BytesIO
from operator import itemgetter

import requests
from langchain.chat_models import ChatGooglePalm
from langchain.document_loaders import DirectoryLoader
from langchain.embeddings import HuggingFaceBgeEmbeddings
from langchain.llms import HuggingFaceHub
from langchain.load import dumps, loads
from langchain.prompts import ChatPromptTemplate, SystemMessagePromptTemplate, HumanMessagePromptTemplate, PromptTemplate
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnablePassthrough
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores.chroma import Chroma

def download_and_extract_zip(url, target_folder, timeout=60):
    """Download a ZIP archive and unpack it into ``target_folder``.

    The response body is kept in memory and handed to ``zipfile`` directly, so
    no temporary archive file is written to disk.

    Args:
        url: Address of the ZIP archive to fetch.
        target_folder: Directory to extract into; created when missing.
        timeout: Seconds to wait on the server before giving up. Without a
            timeout ``requests.get`` can block indefinitely on a stalled
            connection.

    Raises:
        Exception: If the server answers with a status code other than 200.
        zipfile.BadZipFile: If the downloaded payload is not a valid ZIP file.
    """
    # exist_ok avoids the check-then-create race of a separate os.path.exists test.
    os.makedirs(target_folder, exist_ok=True)

    response = requests.get(url, timeout=timeout)
    if response.status_code != 200:
        raise Exception(
            f"Failed to download file: {url} (HTTP {response.status_code})"
        )

    with zipfile.ZipFile(BytesIO(response.content)) as zip_ref:
        zip_ref.extractall(target_folder)

    print(f"Files extracted to {target_folder}")

In [3]:
# Download locations (Dropbox direct-download links, dl=1) for the two archives
# used by this notebook: the raw text files and a prebuilt Chroma database.
text_files_url = "https://www.dropbox.com/scl/fi/av3nw07o5mo29cjokyp41/singapore_text_files_languages.zip?rlkey=xqdy5f1modtbnrzzga9024jyw&dl=1"
chroma_db_url = 'https://www.dropbox.com/scl/fi/3kep8mo77h642kvpum2p7/singapore_chroma_db.zip?rlkey=4ry4rtmeqdcixjzxobtmaajzo&dl=1'
# Local folder that receives the extracted text files.
text_files_folder = "singapore_text"
# Folder name of the Chroma store.
# NOTE(review): the later cells shown here hardcode '.' and "./chroma_db"
# instead of using this variable.
chroma_db_folder = "chroma_db"

In [4]:
# Fetch and unpack the raw text files into their own folder.
download_and_extract_zip(text_files_url, text_files_folder)
# Unpack the database archive into the current directory.
# NOTE(review): presumably the archive itself contains a top-level chroma_db/
# folder, since a later cell opens "./chroma_db" — confirm.
download_and_extract_zip(chroma_db_url, '.')

Files extracted to singapore_text
Files extracted to .


In [5]:
# Provide the Hugging Face Hub API token without storing it in the notebook.
# A value already present in the environment is kept; otherwise the user is
# prompted, so the secret never ends up in the saved file or its history.
# NOTE(review): a token was previously hardcoded in this cell in plain text;
# it should be treated as leaked and revoked.
if not os.environ.get('HUGGINGFACEHUB_API_TOKEN'):
    os.environ['HUGGINGFACEHUB_API_TOKEN'] = getpass.getpass("Hugging Face Hub API token: ")

## Load documents

In [6]:
# Load every *.txt file found in the English folder of the extracted archive;
# show_progress=True displays a progress bar while the files are read.
loader = DirectoryLoader('singapore_text/Textfiles3/English/', glob="*.txt", show_progress=True)
docs = loader.load()

100%|████████████████████████████████████████████████████████████████████████████████| 646/646 [01:49<00:00,  5.88it/s]


## Concatenate and split text

In [7]:
# Concatenate the non-empty documents into one string. A blank line is placed
# between documents so the last word of one file is not fused with the first
# word of the next.
raw_text = '\n\n'.join([doc.page_content for doc in docs if doc.page_content])
# Split into chunks of at most 500 characters (measured with len) with a
# 100-character overlap between neighbouring chunks.
text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100, length_function=len, is_separator_regex=False)
texts = text_splitter.split_text(raw_text)

## Load embeddings and vector database

In [8]:
# Embedding model configuration: the small English BGE model, run on the CPU,
# with normalize_embeddings=True passed through to the encode step.
model_name = "BAAI/bge-small-en-v1.5"
encode_kwargs = {'normalize_embeddings': True}
bge_embeddings = HuggingFaceBgeEmbeddings(model_name=model_name, model_kwargs={'device': 'cpu'}, encode_kwargs=encode_kwargs)

  from tqdm.autonotebook import tqdm, trange


In [9]:
# Open the Chroma store persisted under ./chroma_db with bge_embeddings as its
# embedding function.
# NOTE(review): the path is hardcoded here rather than taken from the
# chroma_db_folder variable defined in an earlier cell.
db = Chroma(persist_directory="./chroma_db", embedding_function=bge_embeddings)

## Retriever and Chat Model setup

In [10]:
# Return the 5 closest chunks per query. The result count has to be supplied
# through search_kwargs; a bare k=5 keyword on as_retriever does not set it,
# which leaves the retriever on its default result count.
retriever = db.as_retriever(search_kwargs={"k": 5})
# Chat model used to produce the final answer.
# NOTE(review): the recorded run of this notebook failed with API_KEY_INVALID,
# so a valid Google API key presumably has to be configured before this cell.
model = ChatGooglePalm()

In [11]:
# Prompt template for RAG
# Two placeholders are filled at run time: {context} and {question}. The model
# is instructed to answer from the supplied context only.
template = """Answer the question based only on the following context:
{context}

Question: {question}
"""
prompt = ChatPromptTemplate.from_template(template)

In [12]:
# Chain that expands one user question into several search queries.
# The system message sets the task and the human message asks for 4 queries
# related to {question}.
generate_queries_prompt = ChatPromptTemplate(input_variables=['question'], messages=[
    SystemMessagePromptTemplate(prompt=PromptTemplate(input_variables=[], template='You are a helpful assistant that generates multiple search queries based on a single input query.')),
    HumanMessagePromptTemplate(prompt=PromptTemplate(input_variables=['question'], template='Generate multiple search queries related to: {question} \n OUTPUT (4 queries):'))
])
# The model reply is split into one query per line. Blank and whitespace-only
# lines are dropped so that empty strings are never sent to the retriever.
generate_queries = (
    generate_queries_prompt
    | ChatGooglePalm(temperature=0)
    | StrOutputParser()
    | (lambda x: [line.strip() for line in x.split("\n") if line.strip()])
)


## Reciprocal Rank Fusion function

In [13]:
# Reciprocal Rank Fusion function
def reciprocal_rank_fusion(results, k=60):
    """Fuse several ranked lists into one list ordered by combined score.

    Every appearance of a document at 0-based position ``p`` in one of the
    lists contributes ``1 / (p + k)`` to that document's score. Documents are
    identified by their ``dumps`` serialization, so the same document found in
    several lists accumulates a single score.

    Args:
        results: Iterable of ranked lists of documents.
        k: Constant added to the position in the denominator.

    Returns:
        A list of ``(document, score)`` tuples sorted by descending score;
        documents are restored from their serialized form with ``loads``.
    """
    score_by_doc = {}
    for ranked_docs in results:
        for position, document in enumerate(ranked_docs):
            key = dumps(document)
            score_by_doc[key] = score_by_doc.get(key, 0) + 1 / (position + k)

    # Stable sort: documents with equal scores keep first-seen order.
    ordered = sorted(score_by_doc.items(), key=itemgetter(1), reverse=True)
    return [(loads(key), score) for key, score in ordered]

In [14]:
# RAG Fusion chain
# Pipeline: generate the search queries, run the retriever once per query
# (retriever.map()), then merge the per-query result lists with
# reciprocal_rank_fusion.
ragfusion_chain = generate_queries | retriever.map() | reciprocal_rank_fusion

In [15]:
# Full RAG Fusion chain with prompt
# The chain is invoked with a dict. "context" receives the fused retrieval
# results; "question" extracts only the question string from the input, so the
# prompt's {question} slot is not filled with the whole input dict.
full_rag_fusion_chain = (
    {
        "context": ragfusion_chain,
        "question": itemgetter("question")
    }
    | prompt
    | model
    | StrOutputParser()
)

In [16]:
# Run the full chain on a sample question; the result is the answer text
# produced by the final StrOutputParser step.
# NOTE(review): the "original_query" key is not referenced by either prompt
# template defined in this notebook — confirm whether it is still needed.
query = "Tell me about Universal Studios Singapore?"
response = full_rag_fusion_chain.invoke({"question": query, "original_query": query})

Retrying langchain.chat_models.google_palm.chat_with_retry.<locals>._chat_with_retry in 2.0 seconds as it raised InvalidArgument: 400 API key not valid. Please pass a valid API key. [reason: "API_KEY_INVALID"
domain: "googleapis.com"
metadata {
  key: "service"
  value: "generativelanguage.googleapis.com"
}
].
Retrying langchain.chat_models.google_palm.chat_with_retry.<locals>._chat_with_retry in 4.0 seconds as it raised InvalidArgument: 400 API key not valid. Please pass a valid API key. [reason: "API_KEY_INVALID"
domain: "googleapis.com"
metadata {
  key: "service"
  value: "generativelanguage.googleapis.com"
}
].
Retrying langchain.chat_models.google_palm.chat_with_retry.<locals>._chat_with_retry in 8.0 seconds as it raised InvalidArgument: 400 API key not valid. Please pass a valid API key. [reason: "API_KEY_INVALID"
domain: "googleapis.com"
metadata {
  key: "service"
  value: "generativelanguage.googleapis.com"
}
].
Retrying langchain.chat_models.google_palm.chat_with_retry.<loca

InvalidArgument: 400 API key not valid. Please pass a valid API key. [reason: "API_KEY_INVALID"
domain: "googleapis.com"
metadata {
  key: "service"
  value: "generativelanguage.googleapis.com"
}
]

In [None]:
# Wrap and print the response
def wrap_text(text, width=90):
    """Wrap each line of ``text`` to at most ``width`` characters.

    The text is split on newline characters and every piece is wrapped on its
    own, so line breaks already present in the input are kept.

    Args:
        text: The string to wrap.
        width: Maximum line length passed to ``textwrap.fill``.

    Returns:
        The wrapped pieces joined back together with newline characters.
    """
    return '\n'.join(
        textwrap.fill(segment, width=width) for segment in text.split('\n')
    )

In [None]:
# Show the chain's answer with each line wrapped by wrap_text.
print(wrap_text(response))