RAG Fusion with Local LLM.


Install the Python packages required for LangChain.

In [None]:
# Install the third-party packages this notebook relies on: LangChain and its companion
# packages, the Chroma vector store, sentence-transformers, `unstructured`, and
# OCR-related packages.
# NOTE(review): `langchain` appears twice in the package list below.
! pip install langchain langchain-core langchain_community tiktoken langchainhub chromadb langchain unstructured sentence-transformers pytesseract unstructured_pytesseract tesseract

Set up the OS environment variables.
These enable tracing so that LangChain invocations are visible on smith.langchain.com.

In [None]:
import os

# LangSmith endpoint that receives the traces of LangChain invocations.
os.environ['LANGCHAIN_ENDPOINT'] = 'https://api.smith.langchain.com'
# Turn off parallelism in the `tokenizers` library used by the embedding model.
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

# SECURITY: an API key must never be stored in the notebook itself. Provide it through the
# environment instead (e.g. `export LANGCHAIN_API_KEY=...` before starting Jupyter).
# The key that used to be hard-coded here should be treated as leaked and revoked.
if os.environ.get('LANGCHAIN_API_KEY'):
    os.environ['LANGCHAIN_TRACING_V2'] = 'true'
else:
    # Without a key there is nothing to authenticate trace uploads with, so switch
    # tracing off rather than fail on every LangChain call.
    os.environ['LANGCHAIN_TRACING_V2'] = 'false'
    print("LANGCHAIN_API_KEY is not set; LangSmith tracing is disabled.")

Set up indexing for the vector store db

In [None]:
from langchain_community.document_loaders import DirectoryLoader, UnstructuredPDFLoader
from langchain_community.document_loaders import UnstructuredFileLoader
from langchain_community.vectorstores import Chroma
from langchain_text_splitters import CharacterTextSplitter
from langchain_community.embeddings.sentence_transformer import (SentenceTransformerEmbeddings,)
# --- 1. Load the source documents ----------------------------------------------------
# Markdown files, searched recursively under ./documents/markdown.
loader = DirectoryLoader(
    "./documents/markdown",
    glob="**/*.md",
    show_progress=True,
    loader_cls=UnstructuredFileLoader,
)
documents = loader.load()

# PDF files, searched recursively under ./documents/pdf.
pdf_loader = DirectoryLoader(
    './documents/pdf',
    glob="**/*.pdf",
    show_progress=True,
    loader_cls=UnstructuredPDFLoader,
)
pdf_docs = pdf_loader.load()

# --- 2. Split into chunks ------------------------------------------------------------
# Both splitters use the same settings: chunk_size=1000 with chunk_overlap=100.
text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
pdf_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
markdown_chunks = text_splitter.split_documents(documents)
pdf_chunks = pdf_splitter.split_documents(pdf_docs)
split_docs = markdown_chunks + pdf_chunks

# --- 3. Embed and index --------------------------------------------------------------
# Chunks are embedded with the all-MiniLM-L6-v2 sentence-transformer model and stored
# in a Chroma collection persisted under ./db.
embedding_function = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
vectorstore = Chroma.from_documents(
    documents=split_docs,
    embedding=embedding_function,
    persist_directory="./db",
)
retriever = vectorstore.as_retriever()

Query the vector store to retrieve the documents most similar to the query.

In [None]:
# Sanity check of the index: run a plain similarity search and dump every hit.
query = "What is the code A_100?"
print(vectorstore)
docs = vectorstore.similarity_search(query)
for doc in docs:
    # One print call per hit; the lines are joined with newlines so the output is the
    # same as printing the metadata, the content and the separator one after another.
    print(
        f"Document source: {doc.metadata}",
        f"Document page_content: {doc.page_content}\n",
        "--------------------------------------------",
        sep="\n",
    )

Generate multiple search queries based on the user's input question. This will then be used to retrieve documents from the vectorstore related to the question.

In [None]:
from langchain_core.messages import AIMessage
from langchain_community.chat_models import ChatOllama
from langchain.prompts import ChatPromptTemplate

# Prompt template that asks the LLM to break the input question down into 3 search
# queries. `{question}` is the only placeholder. The template ends with the same
# "The three queries are (3 queries):" heading that introduces the worked example,
# presumably so the model continues with a numbered list in that format.
template = """You are a helpful assistant that generates multiple sub-questions related to an input question.
The goal is to break down the input into a set of sub-problems / sub-questions that can be answers in isolation.
Generate multiple search queries related to: {question}

You must only generate 3 queries. No more than 3 is allowed.

Example:
The three queries are (3 queries):
1. This is the first query.
2. This is the second query.
3. This is the third query.

The three queries are (3 queries):"""
# Chat prompt built from the template above; formatted with `question=...` by generate_queries.
prompt_rag_fusion = ChatPromptTemplate.from_template(template)
from langchain_core.output_parsers import StrOutputParser

# Generates search queries based on the initial question to be used for RAG. Searching with
# several related queries makes it more likely that the document related to the original
# question is retrieved.
def generate_queries(question):
    """Ask the local LLM for search queries related to ``question``.

    The LLM is called up to ``max_attempts`` times because it sometimes returns an empty
    (or too short) reply; the first reply with at least three non-blank lines is used.

    Args:
        question: The user's question. Either a string, or a dict holding the string
            under the ``"question"`` key (the shape this function receives when it runs
            as the first step of a chain invoked with ``{"question": ...}``).

    Returns:
        list[str]: The non-blank lines of the LLM reply, unmodified. Besides the numbered
        queries this may include any heading line the model emitted. If no attempt
        produced three non-blank lines, the lines of the last attempt are returned
        (possibly an empty list).
    """
    # Unwrap chain-style input so the prompt contains the question text rather than
    # the repr of a dict.
    if isinstance(question, dict):
        question = question.get("question", question)

    max_attempts = 5
    min_queries = 3
    # Neither the model handle nor the prompt changes between attempts: build them once.
    llm = ChatOllama(model="llama3")
    prompt_text = prompt_rag_fusion.format(question=question)

    queries = []
    for _ in range(max_attempts):
        llm_response: AIMessage = llm.invoke(prompt_text)
        reply = StrOutputParser().parse(text=llm_response.content)
        # Count only non-blank lines, so blank padding lines cannot satisfy the threshold.
        queries = [line for line in reply.split("\n") if line.strip()]
        if len(queries) >= min_queries:
            break
    return queries



Experiment with generating multiple queries. This is not used further down.

In [None]:
# Smoke test of generate_queries: print the queries produced for a sample question.
# The printed result is not reused by any later cell.
print(generate_queries(question="What is the code A_100?"))


Retrieve the related documents to the three queries provided by the LLM. 
Perform ranking of the retrieved documents.

In [None]:
from langchain.load import dumps, loads


def reciprocal_rank_fusion(results: list[list], k=60):
    """Merge several ranked document lists into one ranking (Reciprocal Rank Fusion).

    Every document earns ``1 / (rank + k)`` for each list it appears in, where ``rank``
    is its zero-based position in that list. The contributions are summed, so documents
    that rank highly in several lists end up on top.

    Args:
        results: One ranked list of documents per query, best match first.
        k: Constant added to the rank in the denominator; larger values flatten the
            difference between top and lower ranks. ``k=0`` would divide by zero for
            the first document of a list.

    Returns:
        A list of ``(document, fused_score)`` tuples sorted by descending score.
        Documents with equal scores keep the order in which they were first seen.
    """
    fused_scores = {}
    for docs in results:
        for rank, doc in enumerate(docs):
            # Serialise the document so it can be used as a dict key; the same document
            # returned for different queries then accumulates into a single entry.
            doc_str = dumps(doc)
            fused_scores[doc_str] = fused_scores.get(doc_str, 0) + 1 / (rank + k)

    ranked = sorted(fused_scores.items(), key=lambda item: item[1], reverse=True)
    # Turn the serialised keys back into document objects for the caller.
    return [(loads(doc_str), score) for doc_str, score in ranked]

# RAG-fusion retrieval pipeline, composed with the `|` operator:
#   1. generate_queries       - turns the chain input into a list of query strings
#   2. retriever.map()        - applies the retriever to each query (one result list per query)
#   3. reciprocal_rank_fusion - merges those lists into a single list of (document, score)
# The chain is invoked with a dict of the form {"question": "..."}; that dict is what
# generate_queries receives as its argument.
retrieval_fusion_chain = generate_queries | retriever.map() | reciprocal_rank_fusion


Experiment with document retrieval. This is not used further down.

In [None]:
# Exploratory run of the retrieval chain; nothing below depends on its output.
docs = retrieval_fusion_chain.invoke({"question":"What is the code A_100?"})
print(f"Number of docs retrieved: {len(docs)}\n")
for doc in docs:
    # Each entry is a (document, fused score) pair.
    document, score = doc[0], doc[1]
    report = [
        f"Document score: {score}",
        f"Document source: {document.metadata}",
        "Document page content:",
        f"\t{document.page_content}",
        "--------------------------------------------\n",
    ]
    print("\n".join(report))

Pass the documents back into the context for the LLM and include the original question

In [None]:
from operator import itemgetter

# RAG answer prompt: {context} receives the retrieved documents and {question} the
# user's original question.
# NOTE: this rebinds the module-level name `template`, which earlier held the
# query-generation prompt; `prompt_rag_fusion` was already built from that earlier string.
template = """Answer the following question based on this context:

{context}

Question: {question}
"""

# Chat prompt that generate_answer formats with `context=...` and `question=...`.
prompt = ChatPromptTemplate.from_template(template)

def generate_answer(question):
    """Answer ``question`` with the local LLM, using RAG-fusion retrieval as context.

    Args:
        question: The user's question as a string.

    Returns:
        The text of the LLM's answer, or ``''`` if all ``max_attempts`` attempts
        produced an empty reply.
    """
    max_attempts = 3

    # Retrieval does not depend on the answer attempt, so run it once up front instead of
    # repeating the query generation and vector searches on every retry.
    context = retrieval_fusion_chain.invoke({"question": question})
    prompt_text = prompt.format(context=context, question=question)
    llm = ChatOllama(model="llama3")

    response = ''
    # Retry only the answer generation; stop at the first non-empty reply.
    for _ in range(max_attempts):
        llm_response: AIMessage = llm.invoke(prompt_text)
        response = StrOutputParser().parse(text=llm_response.content)
        if response:
            break
    return response


Try it out!

In [None]:
# End-to-end run: retrieve context for the sample question and print the LLM's answer.
print(generate_answer(question="What is the code A_100?"))