**Reference Link:** [RAG Systems Essentials (Analytics Vidhya)](https://courses.analyticsvidhya.com/courses/take/rag-systems-essentials/lessons/60148017-hands-on-deep-dive-into-rag-evaluation-metrics-generator-metrics-i)

# Build a Contextual Retrieval based RAG System

## Install OpenAI and LangChain dependencies

In [0]:
!pip install langchain==0.3.10
!pip install langchain-openai==0.2.12
!pip install langchain-community==0.3.11
!pip install jq==1.8.0
!pip install pymupdf==1.25.1

## Install Chroma Vector DB and LangChain wrapper

In [0]:
!pip install langchain-chroma==0.1.4

## Enter OpenAI API Key

In [0]:
from getpass import getpass

OPENAI_KEY = getpass('Enter Open AI API Key: ')

## Setup Environment Variables

In [0]:
import os

os.environ['OPENAI_API_KEY'] = OPENAI_KEY

### OpenAI Embedding Models

LangChain gives us access to OpenAI's embedding models, including the newer ones: the smaller, highly efficient `text-embedding-3-small` model and the larger, more powerful `text-embedding-3-large` model.

In [0]:
from langchain_openai import OpenAIEmbeddings

# details here: https://openai.com/blog/new-embedding-models-and-api-updates
openai_embed_model = OpenAIEmbeddings(model='text-embedding-3-small')

## Loading and Processing the Data

### Get the dataset

In [0]:
# If the download command below fails, download the file manually from
# https://drive.google.com/file/d/1aZxZejfteVuofISodUrY2CDoyuPLYDGZ
# and upload it to Colab.
!gdown 1aZxZejfteVuofISodUrY2CDoyuPLYDGZ

In [0]:
!unzip rag_docs.zip

### Load and Process JSON Documents

In [0]:
from langchain.document_loaders import JSONLoader

loader = JSONLoader(file_path='./rag_docs/wikidata_rag_demo.jsonl',
                    jq_schema='.',
                    text_content=False,
                    json_lines=True)
wiki_docs = loader.load()

In [0]:
len(wiki_docs)

In [0]:
wiki_docs[3]

In [0]:
import json
from langchain.docstore.document import Document
# Each loaded entry holds one raw JSON record in page_content. Decode it and
# rebuild a Document whose text is the record's paragraphs joined by spaces.
wiki_docs_processed = []

for raw_doc in wiki_docs:
    record = json.loads(raw_doc.page_content)
    record_metadata = dict(
        title=record['title'],
        id=record['id'],
        source="Wikipedia",
        page=1,
    )
    article_text = ' '.join(record['paragraphs'])
    wiki_docs_processed.append(
        Document(page_content=article_text, metadata=record_metadata)
    )

In [0]:
wiki_docs_processed[3]

### Load and Process PDF documents

#### Create Chunk Contexts for Contextual Retrieval

![](https://i.imgur.com/LRhKHzk.png)

In [0]:
from langchain_openai import ChatOpenAI

chatgpt = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)

In [0]:
# create chunk context generation chain
from langchain.prompts import ChatPromptTemplate
from langchain.schema import StrOutputParser


def generate_chunk_context(document, chunk):
    """Ask the chat model for a short, retrieval-oriented context for one chunk.

    Args:
        document: Text substituted into the prompt's ``{paper}`` slot.
        chunk: Text substituted into the prompt's ``{chunk}`` slot.

    Returns:
        The output of the prompt -> ``chatgpt`` -> ``StrOutputParser`` pipeline.
    """
    situate_chunk_prompt = """You are an AI assistant specializing in research paper analysis.
                            Your task is to provide brief, relevant context for a chunk of text
                            based on the following research paper.

                            Here is the research paper:
                            <paper>
                            {paper}
                            </paper>

                            Here is the chunk we want to situate within the whole document:
                            <chunk>
                            {chunk}
                            </chunk>

                            Provide a concise context (3-4 sentences max) for this chunk,
                            considering the following guidelines:

                            - Give a short succinct context to situate this chunk within the overall document
                            for the purposes of improving search retrieval of the chunk.
                            - Answer only with the succinct context and nothing else.
                            - Context should be mentioned like 'Focuses on ....'
                            do not mention 'this chunk or section focuses on...'

                            Context:
                        """

    # prompt template -> chat model -> plain string
    context_chain = (
        ChatPromptTemplate.from_template(situate_chunk_prompt)
        | chatgpt
        | StrOutputParser()
    )
    return context_chain.invoke({'paper': document, 'chunk': chunk})

In [0]:
from langchain.document_loaders import PyMuPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
import uuid

def create_contextual_chunks(file_path, chunk_size=3500, chunk_overlap=0):
    """Load a PDF, split it into chunks, and prepend generated context to each chunk.

    Args:
        file_path: Path of the PDF handed to ``PyMuPDFLoader``.
        chunk_size: ``chunk_size`` passed to ``RecursiveCharacterTextSplitter``.
        chunk_overlap: ``chunk_overlap`` passed to the same splitter.

    Returns:
        A list of ``Document`` objects. Each one's text is the generated
        context, a newline, then the chunk text. Its metadata holds a fresh
        UUID ``id``, plus the chunk's ``page`` and ``source`` and a ``title``
        taken from the last ``/``-separated part of ``source``.
    """
    print('Loading pages:', file_path)
    pdf_loader = PyMuPDFLoader(file_path)
    pages = pdf_loader.load()

    print('Chunking pages:', file_path)
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )
    pieces = text_splitter.split_documents(pages)

    print('Generating contextual chunks:', file_path)
    # The "whole document" shown to the model is every chunk text joined by newlines.
    full_text = '\n'.join(piece.page_content for piece in pieces)

    contextual_chunks = []
    for piece in pieces:
        piece_text = piece.page_content
        piece_meta = piece.metadata
        new_meta = dict(
            id=str(uuid.uuid4()),
            page=piece_meta['page'],
            source=piece_meta['source'],
            title=piece_meta['source'].split('/')[-1],
        )
        situating_context = generate_chunk_context(full_text, piece_text)
        contextual_chunks.append(
            Document(
                page_content='\n'.join((situating_context, piece_text)),
                metadata=new_meta,
            )
        )

    print('Finished processing:', file_path)
    print()
    return contextual_chunks

In [0]:
from glob import glob

pdf_files = glob('./rag_docs/*.pdf')
pdf_files

In [0]:
paper_docs = []
for fp in pdf_files:
    paper_docs.extend(create_contextual_chunks(file_path=fp, chunk_size=3500))

In [0]:
len(paper_docs)

In [0]:
paper_docs[0]

### Combine all document chunks in one list

In [0]:
len(wiki_docs_processed)

In [0]:
total_docs = wiki_docs_processed + paper_docs
len(total_docs)

## Index Document Chunks and Embeddings in Vector DB

Here we create a Chroma vector DB from the document chunks and their embeddings. Because we also want the index saved to disk, we pass the directory where the data should be persisted.

In [0]:
from langchain_chroma import Chroma

# create vector DB of docs and embeddings - takes < 30s on Colab
chroma_db = Chroma.from_documents(documents=total_docs,
                                  collection_name='my_context_db',
                                  embedding=openai_embed_model,
                                  # need to set the distance function to cosine else it uses euclidean by default
                                  # check https://docs.trychroma.com/guides#changing-the-distance-function
                                  collection_metadata={"hnsw:space": "cosine"},
                                  persist_directory="./my_context_db")

### Load Vector DB from disk

This shows that once a vector database has been saved to disk, you can load it and connect to it at any time.

In [0]:
# load from disk
chroma_db = Chroma(persist_directory="./my_context_db",
                   collection_name='my_context_db',
                   embedding_function=openai_embed_model)

In [0]:
chroma_db

### Semantic Similarity based Retrieval

We use cosine similarity here and retrieve the top 5 most similar documents for the user's query.

In [0]:
similarity_retriever = chroma_db.as_retriever(search_type="similarity",
                                              search_kwargs={"k": 5})

In [0]:
from IPython.display import display, Markdown

def display_docs(docs):
    """Print each document's metadata and render the first 1000 characters of its text as Markdown."""
    preview_chars = 1000
    for item in docs:
        print('Metadata:', item.metadata)
        print('Content Brief:')
        preview = item.page_content[:preview_chars]
        display(Markdown(preview))
        print()

In [0]:
query = "what is machine learning?"
top_docs = similarity_retriever.invoke(query)
display_docs(top_docs)

In [0]:
query = "what is the difference between transformers and vision transformers?"
top_docs = similarity_retriever.invoke(query)
display_docs(top_docs)

## Build the RAG Pipeline

In [0]:
from langchain_core.prompts import ChatPromptTemplate

rag_prompt = """You are an assistant who is an expert in question-answering tasks.
                Answer the following question using only the following pieces of retrieved context.
                If the answer is not in the context, do not make up answers, just say that you don't know.
                Keep the answer detailed and well formatted based on the information from the context.

                Question:
                {question}

                Context:
                {context}

                Answer:
            """

rag_prompt_template = ChatPromptTemplate.from_template(rag_prompt)

In [0]:
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import ChatOpenAI

chatgpt = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)

def format_docs(docs):
    """Return the ``page_content`` of every document, separated by blank lines."""
    texts = [item.page_content for item in docs]
    return "\n\n".join(texts)

# The incoming query is sent to the retriever (whose documents are joined into
# one context string) and is also passed through unchanged as the question.
# Both values fill the RAG prompt, which is then sent to the chat model.
qa_rag_chain = (
    {
        "context": similarity_retriever | format_docs,
        "question": RunnablePassthrough(),
    }
    | rag_prompt_template
    | chatgpt
)

In [0]:
from IPython.display import display, Markdown

query = "What is machine learning?"
result = qa_rag_chain.invoke(query)
display(Markdown(result.content))

In [0]:
query = "What is a CNN?"
result = qa_rag_chain.invoke(query)
display(Markdown(result.content))

In [0]:
query = "How is a resnet better than a CNN?"
result = qa_rag_chain.invoke(query)
display(Markdown(result.content))

In [0]:
query = "What is NLP and its relation to linguistics?"
result = qa_rag_chain.invoke(query)
display(Markdown(result.content))

In [0]:
query = "What is the difference between AI, ML and DL?"
result = qa_rag_chain.invoke(query)
display(Markdown(result.content))

In [0]:
query = "What is the difference between transformers and vision transformers?"
result = qa_rag_chain.invoke(query)
display(Markdown(result.content))

In [0]:
query = "How is self-attention important in transformers?"
result = qa_rag_chain.invoke(query)
display(Markdown(result.content))

In [0]:
query = "How does a resnet work?"
result = qa_rag_chain.invoke(query)
display(Markdown(result.content))

In [0]:
query = "What is LangGraph?"
result = qa_rag_chain.invoke(query)
display(Markdown(result.content))

In [0]:
query = "What is an Agentic AI System?"
result = qa_rag_chain.invoke(query)
display(Markdown(result.content))

In [0]:
query = "What is LangChain?"
result = qa_rag_chain.invoke(query)
display(Markdown(result.content))

# Build a RAG System with Sources

In [0]:
from langchain_core.prompts import ChatPromptTemplate

rag_prompt = """You are an assistant who is an expert in question-answering tasks.
                Answer the following question using only the following pieces of retrieved context.
                If the answer is not in the context, do not make up answers, just say that you don't know.
                Keep the answer detailed and well formatted based on the information from the context.

                Question:
                {question}

                Context:
                {context}

                Answer:
            """

rag_prompt_template = ChatPromptTemplate.from_template(rag_prompt)

In [0]:
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import ChatOpenAI
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_core.runnables import RunnableLambda
from operator import itemgetter


chatgpt = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)

def format_docs(docs):
    """Concatenate the documents' ``page_content`` values with a blank line between them."""
    blank_line = "\n\n"
    return blank_line.join([entry.page_content for entry in docs])

# Answer-generation step. Takes a dict with 'context' (retrieved documents)
# and 'question', joins the document texts, fills the RAG prompt, calls the
# chat model, and parses the reply to a string.
src_rag_response_chain = (
    {
        "context": itemgetter('context') | RunnableLambda(format_docs),
        "question": itemgetter("question"),
    }
    | rag_prompt_template
    | chatgpt
    | StrOutputParser()
)

# Retrieve documents for the raw query, then add the generated answer under
# 'response' alongside the existing 'context' and 'question' entries.
rag_chain_w_sources = (
    {"context": similarity_retriever, "question": RunnablePassthrough()}
    | RunnablePassthrough.assign(response=src_rag_response_chain)
)

In [0]:
query = "What is machine learning?"
result = rag_chain_w_sources.invoke(query)
result

In [0]:
from IPython.display import display, Markdown

def display_results(result_obj):
    """Show the question, the generated response, and every source document.

    Args:
        result_obj: Mapping with 'question', 'response' and 'context' entries.
            Each item of 'context' exposes ``metadata`` and ``page_content``.
    """
    def show_markdown(text):
        display(Markdown(text))

    print('Query:')
    show_markdown(result_obj['question'])
    print()
    print('Response:')
    show_markdown(result_obj['response'])
    print('=' * 50)
    print('Sources:')
    for src_doc in result_obj['context']:
        print('Metadata:', src_doc.metadata)
        print('Content Brief:')
        show_markdown(src_doc.page_content)
        print()


In [0]:
query = "What is machine learning?"
result = rag_chain_w_sources.invoke(query)
display_results(result)

In [0]:
query = "What is the difference between AI, ML and DL?"
result = rag_chain_w_sources.invoke(query)
display_results(result)

In [0]:
query = "What is the difference between transformers and vision transformers?"
result = rag_chain_w_sources.invoke(query)
display_results(result)

In [0]:
query = "What is an Agentic AI System?"
result = rag_chain_w_sources.invoke(query)
display_results(result)

# Build a RAG System with Source Citations Agentic Pipeline

In [0]:
from langchain_core.prompts import ChatPromptTemplate

rag_prompt = """You are an assistant who is an expert in question-answering tasks.
                Answer the following question using only the following pieces of retrieved context.
                If the answer is not in the context, do not make up answers, just say that you don't know.
                Keep the answer detailed and well formatted based on the information from the context.

                Question:
                {question}

                Context:
                {context}

                Answer:
            """

rag_prompt_template = ChatPromptTemplate.from_template(rag_prompt)
rag_prompt_template.pretty_print()

In [0]:
citations_prompt = """You are an assistant who is an expert in analyzing answers to questions
                      and finding out referenced citations from context articles.

                      Given the following question, context and generated answer,
                      analyze the generated answer and quote citations from context articles
                      that can be used to justify the generated answer.

                      Question:
                      {question}

                      Context Articles:
                      {context}

                      Answer:
                      {answer}
                  """

cite_prompt_template = ChatPromptTemplate.from_template(citations_prompt)
cite_prompt_template.pretty_print()

In [0]:
from pydantic import BaseModel, Field
from typing import List

# One citation: identifies a single context article (id, source, title, page)
# and carries the sentences quoted from it. Used as the element type of
# QuotedCitations.citations.
# NOTE(review): the description strings below are presumably surfaced to the
# model through with_structured_output — treat them as prompt text, not comments.
class Citation(BaseModel):
    # ID of the cited context article, as a string.
    id: str = Field(description="""The string ID of a SPECIFIC context article
                                   which justifies the answer.""")
    # Source label of the cited context article.
    source: str = Field(description="""The source of the SPECIFIC context article
                                       which justifies the answer.""")
    # Title of the cited context article.
    title: str = Field(description="""The title of the SPECIFIC context article
                                      which justifies the answer.""")
    # Page number of the cited context article.
    page: int = Field(description="""The page number of the SPECIFIC context article
                                     which justifies the answer.""")
    # Quoted sentences, held as one string (not a list).
    quotes: str = Field(description="""The VERBATIM sentences from the SPECIFIC context article
                                      that are used to generate the answer.
                                      Should be exact sentences from context article without missing words.""")


# Top-level structured-output schema: a wrapper around a list of Citation items.
# NOTE(review): the docstring and description below are presumably part of the
# schema shown to the model — confirm before rewording them.
class QuotedCitations(BaseModel):
    """Quote citations from given context articles
       that can be used to justify the generated answer. Can be multiple articles."""
    # All citations for one answer; may contain several entries.
    citations: List[Citation] = Field(description="""Citations (can be multiple) from the given
                                                     context articles that justify the answer.""")

In [0]:
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import ChatOpenAI
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_core.runnables import RunnableLambda
from operator import itemgetter


chatgpt = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)
structured_chatgpt = chatgpt.with_structured_output(QuotedCitations)


def format_docs_with_metadata(docs: List[Document]) -> str:
    """Render documents as labelled text blocks for use in a prompt.

    Each document becomes a block listing its ``id``, ``source``, ``title``
    and ``page`` metadata values followed by its ``page_content``.

    Args:
        docs: Documents exposing ``metadata`` (with the four keys above) and
            ``page_content``.

    Returns:
        A string that starts with a blank line and holds one block per
        document, with blocks separated by blank lines. For an empty input
        only the leading ``"\\n\\n"`` is returned.

    Raises:
        KeyError: If a document's metadata lacks one of the four keys.
    """
    # Iterate the documents directly; the positional index was never used.
    formatted_docs = [
        f"""Context Article ID: {doc.metadata['id']}
            Context Article Source: {doc.metadata['source']}
            Context Article Title: {doc.metadata['title']}
            Context Article Page: {doc.metadata['page']}
            Context Article Details: {doc.page_content}
         """
        for doc in docs
    ]
    return "\n\n" + "\n\n".join(formatted_docs)

# Answer-generation step. Takes a dict with 'context' (retrieved documents)
# and 'question', renders the documents together with their metadata, fills
# the RAG prompt, calls the chat model, and parses the reply to a string.
rag_response_chain = (
    {
        "context": itemgetter('context') | RunnableLambda(format_docs_with_metadata),
        "question": itemgetter("question"),
    }
    | rag_prompt_template
    | chatgpt
    | StrOutputParser()
)

# Citation-extraction step. Takes a dict with 'context' (retrieved documents),
# 'question' and 'answer', fills the citations prompt, and returns the
# structured QuotedCitations output.
cite_response_chain = (
    {
        # Render the documents with the same labelled layout used for answer
        # generation. Passing the raw list would interpolate its Python repr
        # into the prompt (escaped newlines and quotes), which makes truly
        # verbatim quotes harder for the model to produce.
        "context": (itemgetter('context')
                        |
                    RunnableLambda(format_docs_with_metadata)),
        "question": itemgetter("question"),
        "answer": itemgetter("answer")
    }
        |
    cite_prompt_template
        |
    structured_chatgpt
)

# Full pipeline: retrieve documents for the raw query, add the generated
# answer under 'answer', then add the extracted citations under 'citations'.
rag_chain_w_citations = (
    {"context": similarity_retriever, "question": RunnablePassthrough()}
    | RunnablePassthrough.assign(answer=rag_response_chain)
    | RunnablePassthrough.assign(citations=cite_response_chain)
)

In [0]:
query = "What is machine learning"
result = rag_chain_w_citations.invoke(query)
result

In [0]:
# .dict() is deprecated in Pydantic v2; model_dump() is its replacement.
result['citations'].model_dump()['citations']

In [0]:
import re
# used mostly for nice display formatting, ignore if not needed
def get_cited_context(result_obj):
    """Group citations by article and attach highlighted source text.

    Args:
        result_obj: Mapping with two entries:
            'citations' — an object whose ``model_dump()`` (or, failing that,
            ``dict()``) returns ``{'citations': [...]}``, where every item has
            'id', 'title', 'source', 'page' and 'quotes' keys;
            'context' — an iterable of documents exposing ``metadata`` (with
            'id', 'page', 'source', 'title') and ``page_content``.

    Returns:
        A list with one dict per unique (source, title) pair, in first-seen
        order. Each dict has 'title', 'source' and 'citations'; every citation
        is ``{'id', 'page', 'quote': [...], 'context': ...}``. 'context' is
        the whitespace-normalised document text with the quoted phrases
        wrapped in ``**``. It stays ``None`` when no document in
        ``result_obj['context']`` matches the citation's id and page.
    """
    def highlight_text(context, quote):
        """Bold every sentence-like phrase of *quote* that occurs in *context*."""
        # Collapse runs of whitespace so line breaks do not defeat matching.
        quote = re.sub(r'\s+', ' ', quote).strip()
        context = re.sub(r'\s+', ' ', context).strip()

        # Split the quote at sentence punctuation and drop empty pieces.
        phrases = [piece.strip() for piece in re.split(r'[.!?]', quote) if piece.strip()]

        for phrase in phrases:
            # Lookarounds rather than \b: \b needs a word character on one
            # side, so a phrase that starts or ends with punctuation such as
            # ")" or a quote mark would never match. These anchors only
            # require that the phrase is not glued to a neighbouring word
            # character.
            pattern = re.compile(
                r'(?<!\w)' + re.escape(phrase) + r'(?!\w)',
                re.IGNORECASE,
            )
            context = pattern.sub(lambda m: f"**{m.group(0)}**", context)

        return context

    # Prefer Pydantic v2's model_dump(); fall back to dict() for objects that
    # only provide the older method.
    citations_obj = result_obj['citations']
    dump = getattr(citations_obj, 'model_dump', None) or citations_obj.dict
    cited_items = dump()['citations']

    # (source, title) -> {'title', 'source', 'citations': [...]}
    source_with_citations = {}
    for cite in cited_items:
        cite_id = cite['id']
        title = cite['title']
        source = cite['source']
        page = cite['page']
        quote = cite['quotes']

        group = source_with_citations.setdefault(
            (source, title),
            {'title': title, 'source': source, 'citations': []},
        )

        # Merge quotes that refer to the same (id, page) of this article.
        citation_entry = next(
            (c for c in group['citations'] if c['id'] == cite_id and c['page'] == page),
            None,
        )
        if citation_entry is None:
            group['citations'].append(
                {'id': cite_id, 'page': page, 'quote': [quote], 'context': None}
            )
        else:
            citation_entry['quote'].append(quote)

    # Attach highlighted text from the matching retrieved document, if any.
    for context in result_obj['context']:
        context_id = context.metadata['id']
        context_page = context.metadata['page']
        source = context.metadata['source']
        title = context.metadata['title']
        page_content = context.page_content

        group = source_with_citations.get((source, title))
        if group is None:
            continue
        for citation in group['citations']:
            # Compare ids as strings: the cited id is typed as str, while the
            # stored metadata id is not guaranteed to be one.
            if str(citation['id']) == str(context_id) and citation['page'] == context_page:
                highlighted_content = page_content
                for quote in citation['quote']:
                    highlighted_content = highlight_text(highlighted_content, quote)
                citation['context'] = highlighted_content

    return list(source_with_citations.values())


In [0]:
get_cited_context(result)

In [0]:
from IPython.display import display, Markdown

def display_results(result_obj):
    """Show the question, the generated answer, and the cited sources.

    Args:
        result_obj: Mapping with 'question' and 'answer' strings, plus the
            'citations' and 'context' entries that ``get_cited_context`` reads.
    """
    print('Query:')
    display(Markdown(result_obj['question']))
    print()
    print('Response:')
    display(Markdown(result_obj['answer']))
    print('='*50)
    print('Sources:')
    for source in get_cited_context(result_obj):
        print('Title:', source['title'], ' ', 'Source:', source['source'])
        print('Citations:')
        for citation in source['citations']:
            print('ID:', citation['id'], ' ', 'Page:', citation['page'])
            print('Cited Quotes:')
            display(Markdown('*'+' '.join(citation['quote'])+'*'))
            print('Cited Context:')
            cited_text = citation['context']
            if cited_text is None:
                # get_cited_context leaves 'context' as None when the cited
                # id/page matches no retrieved document. Say so explicitly
                # instead of handing None to Markdown.
                print('(no retrieved document matched this citation)')
            else:
                display(Markdown(cited_text))
            print()


In [0]:
display_results(result)

In [0]:
query = "What is AI, ML and DL?"
result = rag_chain_w_citations.invoke(query)
display_results(result)

In [0]:
query = "How is Machine learning related to supervised learning and clustering?"
result = rag_chain_w_citations.invoke(query)
display_results(result)

In [0]:
query = "What is the difference between transformers and vision transformers?"
result = rag_chain_w_citations.invoke(query)
display_results(result)