In [1]:
%%capture
%pip install -U langchain
%pip install -U openai
%pip install -U ragas
%pip install -U arxiv
%pip install -U pymupdf
%pip install -U chromadb
%pip install -U tiktoken
%pip install -U accelerate
%pip install -U bitsandbytes
%pip install -U datasets
%pip install -U sentence_transformers
%pip install -U FlagEmbedding
%pip install -U ninja
%pip install -U flash_attn --no-build-isolation
%pip install -U tqdm
%pip install -U rank_bm25
%pip install -U transformers

In [3]:
import os
import openai
from getpass import getpass

openai.api_key = getpass("Please provide your OpenAI Key: ")
os.environ["OPENAI_API_KEY"] = openai.api_key

In [1]:
from langchain.document_loaders import ArxivLoader

base_docs = ArxivLoader(query="Retrieval Augmented Generation", load_max_docs=5).load()
len(base_docs)

5

In [7]:
from pprint import pprint
for doc in base_docs:
  pprint(doc)
  break

Document(page_content='A Survey on Retrieval-Augmented Text Generation\nHuayang Li♥,∗\nYixuan Su♠,∗\nDeng Cai♦,∗\nYan Wang♣,∗\nLemao Liu♣,∗\n♥Nara Institute of Science and Technology\n♠University of Cambridge\n♦The Chinese University of Hong Kong\n♣Tencent AI Lab\nli.huayang.lh6@is.naist.jp, ys484@cam.ac.uk\nthisisjcykcd@gmail.com, brandenwang@tencent.com\nlemaoliu@gmail.com\nAbstract\nRecently, retrieval-augmented text generation\nattracted increasing attention of the compu-\ntational linguistics community.\nCompared\nwith conventional generation models, retrieval-\naugmented text generation has remarkable ad-\nvantages and particularly has achieved state-of-\nthe-art performance in many NLP tasks. This\npaper aims to conduct a survey about retrieval-\naugmented text generation. It ﬁrstly highlights\nthe generic paradigm of retrieval-augmented\ngeneration, and then it reviews notable ap-\nproaches according to different tasks including\ndialogue response generation, machine trans-\nla

In [9]:
from langchain_community.vectorstores import Chroma
from langchain_openai import OpenAIEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter

text_splitter = RecursiveCharacterTextSplitter(chunk_size=250)

docs = text_splitter.split_documents(base_docs)

vectorstore = Chroma.from_documents(docs, OpenAIEmbeddings())

In [41]:
print(len(docs))
print(docs[:2])

4756
[Document(page_content='A Survey on Retrieval-Augmented Text Generation\nHuayang Li♥,∗\nYixuan Su♠,∗\nDeng Cai♦,∗\nYan Wang♣,∗\nLemao Liu♣,∗\n♥Nara Institute of Science and Technology\n♠University of Cambridge\n♦The Chinese University of Hong Kong\n♣Tencent AI Lab', metadata={'Published': '2022-02-13', 'Title': 'A Survey on Retrieval-Augmented Text Generation', 'Authors': 'Huayang Li, Yixuan Su, Deng Cai, Yan Wang, Lemao Liu', 'Summary': 'Recently, retrieval-augmented text generation attracted increasing attention\nof the computational linguistics community. Compared with conventional\ngeneration models, retrieval-augmented text generation has remarkable\nadvantages and particularly has achieved state-of-the-art performance in many\nNLP tasks. This paper aims to conduct a survey about retrieval-augmented text\ngeneration. It firstly highlights the generic paradigm of retrieval-augmented\ngeneration, and then it reviews notable approaches according to different tasks\nincluding dia

In [42]:
print(max([len(chunk.page_content) for chunk in docs]))

249


In [12]:
base_retriever = vectorstore.as_retriever(search_kwargs={"k" : 2})

In [17]:
relevant_docs = base_retriever.get_relevant_documents("What is Retrieval Augmented Generation?")
relevant_docs

[Document(page_content='arXiv:2004.04906.\nPatrick Lewis, Ethan Perez, Aleksandra Piktus, Fabio\nPetroni, Vladimir Karpukhin, Naman Goyal, Hein-\nrich Küttler, Mike Lewis, Wen-tau Yih, Tim Rock-\ntäschel, et al. 2020. Retrieval-augmented generation', metadata={'Authors': 'Xuchao Zhang, Menglin Xia, Camille Couturier, Guoqing Zheng, Saravan Rajmohan, Victor Ruhle', 'Published': '2023-08-08', 'Summary': "Retrieval augmented models show promise in enhancing traditional language\nmodels by improving their contextual understanding, integrating private data,\nand reducing hallucination. However, the processing time required for retrieval\naugmented large language models poses a challenge when applying them to tasks\nthat require real-time responses, such as composition assistance.\n  To overcome this limitation, we propose the Hybrid Retrieval-Augmented\nGeneration (HybridRAG) framework that leverages a hybrid setting that combines\nboth client and cloud models. HybridRAG incorporates retrie

In [70]:
from langchain.prompts import ChatPromptTemplate

template = """Answer the question based only on the following context. If you cannot answer the question with the context, please respond with 'I don't know':

### CONTEXT
{context}

### QUESTION
Question: {question}
"""

prompt = ChatPromptTemplate.from_template(template)

In [80]:
from operator import itemgetter

from langchain_openai import ChatOpenAI
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnableLambda, RunnablePassthrough, RunnableParallel

primary_qa_llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0)

retriever =  RunnableParallel({"context":  base_retriever, "question":RunnablePassthrough()})

retrieval_augmented_qa_chain = retriever | {"response": prompt | primary_qa_llm, "context": itemgetter("context")}


In [81]:
question = "What is RAG?"

result = retrieval_augmented_qa_chain.invoke(question)

pprint(result)

{'context': [Document(page_content='2020). RAG consists of three primary components:\nTool Retrieval, Plan Generation, and Execution.1\nIn this study, we focus on enhancing tool retrieval,\nwith the goal of achieving subsequent improve-\nments in plan generation.', metadata={'Authors': 'Raviteja Anantha, Tharun Bethi, Danil Vodianik, Srinivas Chappidi', 'Published': '2023-12-09', 'Summary': "Large language models (LLMs) have the remarkable ability to solve new tasks\nwith just a few examples, but they need access to the right tools. Retrieval\nAugmented Generation (RAG) addresses this problem by retrieving a list of\nrelevant tools for a given task. However, RAG's tool retrieval step requires\nall the required information to be explicitly present in the query. This is a\nlimitation, as semantic search, the widely adopted tool retrieval method, can\nfail when the query is incomplete or lacks context. To address this limitation,\nwe propose Context Tuning for RAG, which employs a smart c

# Ground Truth Dataset Generation

In [33]:
from langchain.output_parsers import ResponseSchema
from langchain.output_parsers import StructuredOutputParser

question_schema = ResponseSchema(
    name="question",
    description="a question about the context."
)

question_response_schemas = [
    question_schema,
]

question_output_parser = StructuredOutputParser.from_response_schemas(question_response_schemas)
format_instructions = question_output_parser.get_format_instructions()

In [34]:
question_generation_llm = ChatOpenAI(model="gpt-3.5-turbo-16k")

bare_prompt_template = "{content}"
bare_template = ChatPromptTemplate.from_template(template=bare_prompt_template)

In [37]:
from langchain.prompts import ChatPromptTemplate

qa_template = """\
You are a University Professor creating a test for advanced students. For each context, create a question that is specific to the context. Avoid creating generic or general questions.

question: a question about the context.

Format the output as JSON with the following keys:
question

context: {context}
"""

prompt_template = ChatPromptTemplate.from_template(template=qa_template)

question_generation_chain = bare_template | question_generation_llm

from tqdm import tqdm

qac_triples = []

for text in tqdm(docs[:10]):
  messages = prompt_template.format_messages(
      context=text,
      format_instructions=format_instructions
  )
  response = question_generation_chain.invoke({"content" : messages})
  try:
    output_dict = question_output_parser.parse(response.content)
  except Exception as e:
    continue
  output_dict["context"] = text
  qac_triples.append(output_dict)

100%|██████████| 10/10 [00:44<00:00,  4.48s/it]


In [45]:
print(len(qac_triples))
qac_triples[:2]

10


[{'question': "What is the focus of the paper 'A Survey on Retrieval-Augmented Text Generation'?",
  'context': Document(page_content='A Survey on Retrieval-Augmented Text Generation\nHuayang Li♥,∗\nYixuan Su♠,∗\nDeng Cai♦,∗\nYan Wang♣,∗\nLemao Liu♣,∗\n♥Nara Institute of Science and Technology\n♠University of Cambridge\n♦The Chinese University of Hong Kong\n♣Tencent AI Lab', metadata={'Published': '2022-02-13', 'Title': 'A Survey on Retrieval-Augmented Text Generation', 'Authors': 'Huayang Li, Yixuan Su, Deng Cai, Yan Wang, Lemao Liu', 'Summary': 'Recently, retrieval-augmented text generation attracted increasing attention\nof the computational linguistics community. Compared with conventional\ngeneration models, retrieval-augmented text generation has remarkable\nadvantages and particularly has achieved state-of-the-art performance in many\nNLP tasks. This paper aims to conduct a survey about retrieval-augmented text\ngeneration. It firstly highlights the generic paradigm of retrieval

In [61]:
answer_generation_llm = ChatOpenAI(model="gpt-4-1106-preview", temperature=0)

answer_schema = ResponseSchema(
    name="answer",
    description="an answer to the question"
)

answer_response_schemas = [
    answer_schema,
]

answer_output_parser = StructuredOutputParser.from_response_schemas(answer_response_schemas)
format_instructions = answer_output_parser.get_format_instructions()

qa_template = """\
You are a University Professor creating a test for advanced students. For each question and context, create an answer.

answer: a answer about the context.

Format the output as JSON with the following keys:
answer

question: {question}
context: {context}
"""

prompt_template = ChatPromptTemplate.from_template(template=qa_template)

messages = prompt_template.format_messages(
    context=qac_triples[0]["context"],
    question=qac_triples[0]["question"],
    format_instructions=format_instructions
)

answer_generation_chain = bare_template | answer_generation_llm

response = answer_generation_chain.invoke({"content" : messages})
output_dict = answer_output_parser.parse(response.content)

In [62]:
response.content

'```json\n{\n  "answer": "The focus of the paper \'A Survey on Retrieval-Augmented Text Generation\' is to provide a comprehensive overview of the field of retrieval-augmented text generation, which has gained significant attention in the computational linguistics community. The paper highlights the generic paradigm of retrieval-augmented generation models, reviews notable approaches across various natural language processing tasks such as dialogue response generation and machine translation, and discusses the state-of-the-art performance achieved by these models. Additionally, the paper identifies and suggests important future research directions in this area.",\n  "question": "What is the focus of the paper \'A Survey on Retrieval-Augmented Text Generation\'?"\n}\n```'

In [47]:
for k, v in output_dict.items():
  print(k)
  print(v)

answer
The focus of the paper 'A Survey on Retrieval-Augmented Text Generation' is to provide a comprehensive overview of the recent advancements in retrieval-augmented text generation within the field of computational linguistics. It highlights the generic paradigm of retrieval-augmented generation models, reviews notable approaches across various natural language processing (NLP) tasks such as dialogue response generation and machine translation, and discusses the state-of-the-art performance achieved by these models. Additionally, the paper identifies and suggests important future research directions in this area.
question
What is the focus of the paper 'A Survey on Retrieval-Augmented Text Generation'?


In [63]:
for triple in tqdm(qac_triples):
  messages = prompt_template.format_messages(
      context=triple["context"],
      question=triple["question"],
      format_instructions=format_instructions
  )
  response = answer_generation_chain.invoke({"content" : messages})
  try:
    output_dict = answer_output_parser.parse(response.content)
  except Exception as e:
    continue
  triple["answer"] = output_dict["answer"]

100%|██████████| 10/10 [00:46<00:00,  4.60s/it]


In [64]:
import pandas as pd
from datasets import Dataset

ground_truth_qac_set = pd.DataFrame(qac_triples)
ground_truth_qac_set["context"] = ground_truth_qac_set["context"].map(lambda x: str(x.page_content))
ground_truth_qac_set = ground_truth_qac_set.rename(columns={"answer" : "ground_truth"})


eval_dataset = Dataset.from_pandas(ground_truth_qac_set)

  from .autonotebook import tqdm as notebook_tqdm


In [65]:
eval_dataset

Dataset({
    features: ['question', 'context', 'ground_truth'],
    num_rows: 10
})

In [66]:
eval_dataset[0]

{'question': "What is the focus of the paper 'A Survey on Retrieval-Augmented Text Generation'?",
 'context': 'A Survey on Retrieval-Augmented Text Generation\nHuayang Li♥,∗\nYixuan Su♠,∗\nDeng Cai♦,∗\nYan Wang♣,∗\nLemao Liu♣,∗\n♥Nara Institute of Science and Technology\n♠University of Cambridge\n♦The Chinese University of Hong Kong\n♣Tencent AI Lab',
 'ground_truth': "The focus of the paper 'A Survey on Retrieval-Augmented Text Generation' is to provide a comprehensive overview of the field of retrieval-augmented text generation, which has gained significant attention in the computational linguistics community. The paper highlights the generic paradigm of retrieval-augmented generation models, reviews notable approaches across various natural language processing tasks such as dialogue response generation and machine translation, and discusses the state-of-the-art performance achieved by these models. Additionally, the paper identifies and suggests important future research directions 

In [67]:
eval_dataset.to_csv("groundtruth_eval_dataset.csv")

Creating CSV from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 84.52ba/s]


6825

# Evaluating RAG pipeline

In [68]:
# To load the dataset from disk
# from datasets import Dataset
# eval_dataset = Dataset.from_csv("groundtruth_eval_dataset.csv")

In [84]:
from ragas.metrics import (
    answer_relevancy,
    faithfulness,
    context_recall,
    context_precision,
    context_relevancy,
    answer_correctness,
    answer_similarity
)

from ragas.metrics.critique import harmfulness
from ragas import evaluate

def create_ragas_dataset(rag_pipeline, eval_dataset):
  rag_dataset = []
  for row in tqdm(eval_dataset):
    answer = rag_pipeline.invoke( row["question"])
    rag_dataset.append(
        {"question" : row["question"],
         "answer" : answer["response"].content,
         "contexts" : [context.page_content for context in answer["context"]],
         "ground_truths" : [row["ground_truth"]]
         }
    )
  rag_df = pd.DataFrame(rag_dataset)
  rag_eval_dataset = Dataset.from_pandas(rag_df)
  return rag_eval_dataset

def evaluate_ragas_dataset(ragas_dataset):
  result = evaluate(
    ragas_dataset,
    metrics=[
        context_precision,
        faithfulness,
        answer_relevancy,
        context_recall,
        context_relevancy,
        answer_correctness,
        answer_similarity
    ],
  )
  return result

In [85]:
from tqdm import tqdm
import pandas as pd

basic_qa_ragas_dataset = create_ragas_dataset(retrieval_augmented_qa_chain, eval_dataset)

100%|██████████| 10/10 [00:14<00:00,  1.49s/it]


In [86]:
basic_qa_ragas_dataset

Dataset({
    features: ['question', 'answer', 'contexts', 'ground_truths'],
    num_rows: 10
})

In [87]:
basic_qa_ragas_dataset[0]

{'question': "What is the focus of the paper 'A Survey on Retrieval-Augmented Text Generation'?",
 'answer': "The focus of the paper 'A Survey on Retrieval-Augmented Text Generation' is to conduct a survey about retrieval-augmented text generation.",
 'contexts': ['paper aims to conduct a survey about retrieval-\naugmented text generation. It ﬁrstly highlights\nthe generic paradigm of retrieval-augmented\ngeneration, and then it reviews notable ap-\nproaches according to different tasks including',
  'A Survey on Retrieval-Augmented Text Generation\nHuayang Li♥,∗\nYixuan Su♠,∗\nDeng Cai♦,∗\nYan Wang♣,∗\nLemao Liu♣,∗\n♥Nara Institute of Science and Technology\n♠University of Cambridge\n♦The Chinese University of Hong Kong\n♣Tencent AI Lab'],
 'ground_truths': ["The focus of the paper 'A Survey on Retrieval-Augmented Text Generation' is to provide a comprehensive overview of the field of retrieval-augmented text generation, which has gained significant attention in the computational ling

In [88]:
basic_qa_ragas_dataset.to_csv("basic_qa_ragas_dataset.csv")

Creating CSV from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 364.47ba/s]


11040

In [89]:
basic_qa_result = evaluate_ragas_dataset(basic_qa_ragas_dataset)

evaluating with [context_precision]


100%|██████████| 1/1 [00:08<00:00,  8.23s/it]


evaluating with [faithfulness]


100%|██████████| 1/1 [00:23<00:00, 23.31s/it]


evaluating with [answer_relevancy]


100%|██████████| 1/1 [00:12<00:00, 12.19s/it]


evaluating with [context_recall]


100%|██████████| 1/1 [00:20<00:00, 20.50s/it]


evaluating with [context_relevancy]


100%|██████████| 1/1 [00:01<00:00,  1.78s/it]


evaluating with [answer_correctness]


100%|██████████| 1/1 [00:18<00:00, 18.16s/it]


evaluating with [answer_similarity]


100%|██████████| 1/1 [00:01<00:00,  1.91s/it]


In [90]:
basic_qa_result

{'context_precision': 0.3000, 'faithfulness': 0.3000, 'answer_relevancy': 1.0000, 'context_recall': 0.8667, 'context_relevancy': 0.0664, 'answer_correctness': 0.4830, 'answer_similarity': 0.8857}

# Trying out other retrievers

In [110]:
def create_qa_chain(retriever):
  primary_qa_llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0)
  retriever = RunnableParallel({"context":  retriever, "question":RunnablePassthrough()})

  created_qa_chain =retriever | { "response": prompt | primary_qa_llm,"context": itemgetter("context"), }
  

  return created_qa_chain

#### Parent Document Retriever

In [91]:
from langchain.retrievers import ParentDocumentRetriever
from langchain.storage import InMemoryStore

parent_splitter = RecursiveCharacterTextSplitter(chunk_size=1500)
child_splitter = RecursiveCharacterTextSplitter(chunk_size=200)

vectorstore = Chroma(collection_name="split_parents", embedding_function=OpenAIEmbeddings())

store = InMemoryStore()

In [92]:
parent_document_retriever = ParentDocumentRetriever(
    vectorstore=vectorstore,
    docstore=store,
    child_splitter=child_splitter,
    parent_splitter=parent_splitter,
)

In [98]:
parent_document_retriever.add_documents(base_docs)


KeyboardInterrupt: 

Lets test the chain using the created retriever

In [101]:
parent_document_retriever_qa_chain = create_qa_chain(parent_document_retriever)
parent_document_retriever_qa_chain.invoke("What is RAG?")["response"].content

'RAG refers to the Hybrid Retrieval-Augmented Generation (HybridRAG) framework.'

In [102]:
pdr_qa_ragas_dataset = create_ragas_dataset(parent_document_retriever_qa_chain, eval_dataset)
pdr_qa_ragas_dataset.to_csv("pdr_qa_ragas_dataset.csv")

100%|██████████| 10/10 [00:19<00:00,  1.98s/it]
Creating CSV from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 131.06ba/s]


60073

In [103]:
pdr_qa_result = evaluate_ragas_dataset(pdr_qa_ragas_dataset)
pdr_qa_result


evaluating with [context_precision]


100%|██████████| 1/1 [00:07<00:00,  7.85s/it]


evaluating with [faithfulness]


100%|██████████| 1/1 [00:26<00:00, 26.22s/it]


evaluating with [answer_relevancy]


100%|██████████| 1/1 [00:16<00:00, 16.88s/it]


evaluating with [context_recall]


100%|██████████| 1/1 [00:20<00:00, 20.98s/it]


evaluating with [context_relevancy]


100%|██████████| 1/1 [00:18<00:00, 18.22s/it]


evaluating with [answer_correctness]


100%|██████████| 1/1 [00:19<00:00, 19.77s/it]


evaluating with [answer_similarity]


100%|██████████| 1/1 [00:01<00:00,  1.76s/it]


{'context_precision': 0.2583, 'faithfulness': 0.4917, 'answer_relevancy': 0.9938, 'context_recall': 0.9000, 'context_relevancy': 0.0151, 'answer_correctness': 0.4939, 'answer_similarity': 0.8955}

#### Ensembe retriever

In [94]:
%pip install -q -U rank_bm25

Note: you may need to restart the kernel to use updated packages.


In [104]:
from langchain.retrievers import BM25Retriever, EnsembleRetriever

text_splitter = RecursiveCharacterTextSplitter(chunk_size=450, chunk_overlap=75)
docs = text_splitter.split_documents(base_docs)

bm25_retriever = BM25Retriever.from_documents(docs)
bm25_retriever.k = 2

embedding = OpenAIEmbeddings()
vectorstore = Chroma.from_documents(docs, embedding)
chroma_retriever = vectorstore.as_retriever(search_kwargs={"k": 3})

ensemble_retriever = EnsembleRetriever(retrievers=[bm25_retriever, chroma_retriever], weights=[0.75, 0.25])

In [105]:
ensemble_retriever_qa_chain = create_qa_chain(ensemble_retriever)
ensemble_retriever_qa_chain.invoke( "What is RAG?")["response"].content

'RAG stands for Retrieval Augmented Generation.'

In [106]:
ensemble_qa_ragas_dataset = create_ragas_dataset(ensemble_retriever_qa_chain, eval_dataset)
ensemble_qa_ragas_dataset.to_csv("ensemble_qa_ragas_dataset.csv")

100%|██████████| 10/10 [00:17<00:00,  1.77s/it]
Creating CSV from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 278.86ba/s]


24719

In [107]:
ensemble_qa_result = evaluate_ragas_dataset(ensemble_qa_ragas_dataset)
ensemble_qa_result

evaluating with [context_precision]


100%|██████████| 1/1 [00:11<00:00, 11.14s/it]


evaluating with [faithfulness]


100%|██████████| 1/1 [00:13<00:00, 13.86s/it]


evaluating with [answer_relevancy]


100%|██████████| 1/1 [00:12<00:00, 12.51s/it]


evaluating with [context_recall]


100%|██████████| 1/1 [00:20<00:00, 20.92s/it]


evaluating with [context_relevancy]


100%|██████████| 1/1 [00:07<00:00,  7.11s/it]


evaluating with [answer_correctness]


100%|██████████| 1/1 [00:16<00:00, 16.43s/it]


evaluating with [answer_similarity]


100%|██████████| 1/1 [00:01<00:00,  1.57s/it]


{'context_precision': 0.7087, 'faithfulness': 0.5333, 'answer_relevancy': 0.8980, 'context_recall': 0.8000, 'context_relevancy': 0.0188, 'answer_correctness': 0.3826, 'answer_similarity': 0.8835}

In [108]:
basic_qa_result

{'context_precision': 0.3000, 'faithfulness': 0.3000, 'answer_relevancy': 1.0000, 'context_recall': 0.8667, 'context_relevancy': 0.0664, 'answer_correctness': 0.4830, 'answer_similarity': 0.8857}

In [109]:
pdr_qa_result

{'context_precision': 0.2583, 'faithfulness': 0.4917, 'answer_relevancy': 0.9938, 'context_recall': 0.9000, 'context_relevancy': 0.0151, 'answer_correctness': 0.4939, 'answer_similarity': 0.8955}