# Document Augmentation through Question Generation for Enhanced Retrieval

## Overview

This implementation demonstrates a text augmentation technique that uses generated questions to improve document retrieval within a vector database. By generating questions related to each text document or fragment and indexing them alongside the original text, the system extends the standard retrieval process and increases the likelihood of finding relevant documents that can be used as context for generative question answering.

## Motivation

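By enriching text fragments with related questions, we aim to improve the accuracy of identifying the document sections that contain answers to user queries. Because user queries are usually phrased as questions, they are likely to lie closer in embedding space to generated questions than to raw text fragments.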

## Prerequisites

This approach requires a local large language model, such as Mistral or Llama 3.1, running within the Ollama framework (https://ollama.com). Alternatively, you can use the OpenAI API, but it will require minor modifications to this notebook.

## Key Components

1.	<b>PDF Processing and Text Chunking</b>: Handling PDF documents and dividing them into manageable text fragments.
2.	<b>Question Augmentation</b>: Generating relevant questions at both the document and fragment levels using a generative model.
3.	<b>Vector Store Creation</b>: Calculating embeddings for documents and creating a FAISS vector store.
4.	<b>Retrieval and Answer Generation</b>: Finding the most relevant document using FAISS and generating answers based on the context provided.

## Method Details

### Document Preprocessing

1.	Convert the PDF to a string.
2.	Split the text into overlapping text documents (text_document), which serve as context for answer generation, and then split each document into overlapping text fragments (text_fragment), which are used for retrieval and semantic search.
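
The two-level split can be sketched without any dependencies. This is a minimal illustration, assuming the same regex word tokenization as the notebook's `split_document`; `split_words` and `two_level_split` are hypothetical names. Each fragment is kept together with its parent document, which mirrors how the pipeline stores the parent text in each `Document`'s `text` metadata field.

```python
import re


def split_words(text: str, chunk_size: int, chunk_overlap: int) -> list[str]:
    """Sliding window over regex word tokens; assumes chunk_overlap < chunk_size."""
    tokens = re.findall(r"\b\w+\b", text)
    step = chunk_size - chunk_overlap  # distance between consecutive window starts
    chunks = []
    for start in range(0, len(tokens), step):
        chunks.append(" ".join(tokens[start:start + chunk_size]))
        if start + chunk_size >= len(tokens):  # this window already reaches the end
            break
    return chunks


def two_level_split(text: str, doc_size: int, doc_overlap: int,
                    frag_size: int, frag_overlap: int) -> list[tuple[str, str]]:
    """Return (fragment, parent_document) pairs: fragments are searched, parents become context."""
    pairs = []
    for parent in split_words(text, doc_size, doc_overlap):
        for fragment in split_words(parent, frag_size, frag_overlap):
            pairs.append((fragment, parent))
    return pairs


# Toy input: twenty "words" named w0 to w19
text = " ".join(f"w{i}" for i in range(20))
pairs = two_level_split(text, doc_size=12, doc_overlap=2, frag_size=5, frag_overlap=1)
for fragment, parent in pairs:
    print(fragment, "| parent starts at", parent.split()[0])
```

The small sizes are only for readability; the notebook uses the same structure with `DOCUMENT_MAX_TOKENS` and `FRAGMENT_MAX_TOKENS`.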

### Document Augmentation

1.	Generate questions at either the document level or the text fragment level, selected with the QUESTION_GENERATION constant.
2.	Configure the number of questions to request per document or fragment using the QUESTIONS_PER_DOCUMENT constant.
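
Generated questions arrive as one block of text that has to be split and cleaned. The notebook's `clean_and_filter_questions` strips only `N.` prefixes; the sketch below is a hypothetical, slightly more tolerant variant (`parse_questions` is not part of the notebook) that also accepts `N)` and bullet prefixes and drops duplicates while preserving order. Which numbering style a model actually produces varies, so treat the accepted prefixes as an assumption.

```python
import re

# Assumed prefixes: "1.", "1)", "-" or "*", each optionally followed by spaces
PREFIX = re.compile(r"^\s*(?:\d+[.)]|[-*])\s*")


def parse_questions(raw: str) -> list[str]:
    """Split raw model output into unique questions, keeping the model's order."""
    seen = set()
    questions = []
    for line in raw.splitlines():  # handles both "\n" and "\r\n" line breaks
        question = PREFIX.sub("", line).strip()
        # Keep only lines that end with a question mark and were not seen before
        if question.endswith("?") and question not in seen:
            seen.add(question)
            questions.append(question)
    return questions


raw = ("Here are some questions:\r\n"
       "1. What is climate change?\n"
       "2) Why do glaciers retreat?\n"
       "- What is climate change?\n"
       "\n"
       "Thanks!")
print(parse_questions(raw))
```

The intro line, the duplicate and the closing remark are discarded, leaving two questions.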

### Vector Store Creation

1.	Use the OllamaEmbeddings class with models hosted on the Ollama framework to compute document embeddings.
2.	Create a FAISS vector store from these embeddings.
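
To make the indexing step concrete without Ollama or FAISS installed, here is a toy sketch. `HashingEmbeddings` is a hypothetical stand-in that exposes the same `embed_documents`/`embed_query` method names as the notebook's `OllamaEmbeddings` but produces hashed bag-of-words vectors with no real semantics, and `nearest` performs the exhaustive nearest-neighbour search that a flat FAISS index carries out in optimized native code. Neither is the FAISS or LangChain API.

```python
import hashlib
import math
import re


class HashingEmbeddings:
    """Hypothetical toy stand-in for OllamaEmbeddings: hashed bag-of-words vectors.

    The vectors only reflect shared words, not meaning; useful for offline experiments.
    """

    def __init__(self, dim: int = 256):
        self.dim = dim

    def _embed(self, text: str) -> list[float]:
        vec = [0.0] * self.dim
        for word in re.findall(r"\w+", text.lower()):
            # md5 gives a stable bucket across runs (built-in hash() of str is salted per process)
            bucket = int(hashlib.md5(word.encode("utf-8")).hexdigest(), 16) % self.dim
            vec[bucket] += 1.0
        norm = math.sqrt(sum(v * v for v in vec)) or 1.0  # avoid dividing by zero on empty text
        return [v / norm for v in vec]

    def embed_documents(self, docs: list[str]) -> list[list[float]]:
        return [self._embed(d) for d in docs]

    def embed_query(self, query: str) -> list[float]:
        return self._embed(query)


def nearest(index: list[list[float]], query_vec: list[float]) -> int:
    """Exhaustive search: position of the stored vector closest to query_vec (squared L2)."""
    def sq_l2(a: list[float], b: list[float]) -> float:
        return sum((x - y) ** 2 for x, y in zip(a, b))
    return min(range(len(index)), key=lambda i: sq_l2(index[i], query_vec))


texts = [
    "How are freshwater ecosystems affected by changes in precipitation?",
    "What caused the end of the last ice age?",
]
embeddings = HashingEmbeddings()
index = embeddings.embed_documents(texts)  # "indexing": embed every stored text once
best = nearest(index, embeddings.embed_query("freshwater ecosystems and precipitation changes"))
print(texts[best])
```

Because the vectors are normalized to unit length, ranking by squared L2 distance gives the same order as ranking by cosine similarity: for unit vectors, ||a - b||^2 = 2 - 2 cos(a, b).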

### Retrieval and Generation

1.	Retrieve the most relevant document from the FAISS store based on the given query.
2.	Use the retrieved document as context for generating answers with a generative model.
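
The retrieve-then-read step can be sketched in plain Python. Word-overlap (Jaccard) scoring is swapped in here as a stand-in for embedding similarity, and the records are hypothetical dictionaries shaped like the notebook's `Document` objects. Whichever record wins, the generative model receives its parent `text`, not the short fragment or question that matched; the prompt template is copied from the notebook's `generate_answer`.

```python
import re


def words(text: str) -> set[str]:
    return set(re.findall(r"\w+", text.lower()))


def jaccard(a: str, b: str) -> float:
    """Word-set overlap, a stand-in for embedding similarity in this sketch."""
    wa, wb = words(a), words(b)
    return len(wa & wb) / len(wa | wb) if wa | wb else 0.0


# Hypothetical records: searchable page_content, parent text document under metadata["text"]
records = [
    {"page_content": "Freshwater ecosystems include rivers, lakes and wetlands.",
     "metadata": {"type": "ORIGINAL ", "text": "PARENT A"}},
    {"page_content": "How are freshwater ecosystems affected by temperature?",
     "metadata": {"type": "AUGMENTED", "text": "PARENT A"}},
    {"page_content": "What ended the last ice age?",
     "metadata": {"type": "AUGMENTED", "text": "PARENT B"}},
]

query = "How do freshwater ecosystems change with temperature?"
hit = max(records, key=lambda r: jaccard(query, r["page_content"]))  # k=1 retrieval
context = hit["metadata"]["text"]  # replace the short hit with its parent document
# Same prompt template as generate_answer in this notebook
prompt = f"Using the context data: {context}, provide a brief and precise answer to the question: {query}."
print(hit["metadata"]["type"], "->", context)
```

Here the augmented question shares more words with the query than the original fragment does, so it wins the search, yet the answer is still generated from the full parent document.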

## Benefits of This Approach

1.	Enhanced Retrieval Process: Increases the probability of finding the most relevant FAISS document for a given query.
2.	Flexible Context Adjustment: Allows for easy adjustment of the context window size for both text documents and fragments.

## Conclusion

This technique provides a method to improve the quality of information retrieval in vector-based document search systems. By generating additional questions similar to user queries, it potentially leads to better comprehension and more accurate responses in subsequent tasks, such as question answering.

### Import libraries and set constants

import sys
import os
from langchain.docstore.document import Document
from langchain.vectorstores import FAISS
from enum import Enum

sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..'))) # Add the parent directory to the path since we work with notebooks

from helper_functions import *

MODEL = "mistral"
#MODEL = "llama3.1:8b"

class QuestionGeneration(Enum):
    """
    Enum class to specify the level of question generation for document processing.

    Attributes:
        DOCUMENT_LEVEL (int): Represents question generation at the entire document level.
        FRAGMENT_LEVEL (int): Represents question generation at the individual text fragment level.
    """
    DOCUMENT_LEVEL = 1
    FRAGMENT_LEVEL = 2

#Depending on the model: for Mistral 7B it can be max 8000, for Llama 3.1 8B 128k
DOCUMENT_MAX_TOKENS = 4000
DOCUMENT_OVERLAP_TOKENS = 100

#Embeddings and text similarity calculated on shorter texts
FRAGMENT_MAX_TOKENS = 128
FRAGMENT_OVERLAP_TOKENS = 16

#Questions generated on document or fragment level
QUESTION_GENERATION = QuestionGeneration.DOCUMENT_LEVEL
#How many questions to request for each document or fragment (the prompt asks for at least this many)
QUESTIONS_PER_DOCUMENT = 40
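
A quick way to see what these constants imply is to count the windows produced by the notebook's sliding-window loop. The helper below (hypothetical name `count_chunks`) replicates the loop in `split_document` without building any text. For more than C tokens the count equals ceil((N - C) / (C - O)) + 1, so a full 4000-word document split with C = 128 and O = 16 yields ceil(3872 / 112) + 1 = 36 fragments, which matches the log further down for documents 0 and 1 (document 2 is the shorter remainder). At `FRAGMENT_LEVEL`, each of those fragments triggers its own `generate_questions` call, so a full document needs 36 generation calls instead of one.

```python
def count_chunks(n_tokens: int, chunk_size: int, chunk_overlap: int) -> int:
    """Number of windows split_document's loop emits for n_tokens tokens."""
    step = chunk_size - chunk_overlap  # stride between window starts
    count = 0
    for start in range(0, n_tokens, step):
        count += 1
        if start + chunk_size >= n_tokens:  # this window reaches the end of the text
            break
    return count


fragments_per_document = count_chunks(4000, 128, 16)  # DOCUMENT_MAX_TOKENS, FRAGMENT_* values
print(fragments_per_document)
```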

### Define classes and functions used by this pipeline

import ollama
import re
import os
from langchain.embeddings.base import Embeddings
from typing import Any, List

class OllamaEmbeddings(Embeddings):
    """
    A class for generating embeddings using the Ollama API.

    This class extends the `Embeddings` base class to provide functionality for generating
    embeddings for both documents and queries using the Ollama API. It includes methods for
    embedding multiple documents, a single query, and also allows the class instance to be
    used as a callable function to embed queries.

    Note:
        The model name passed to the Ollama API is taken from the module-level MODEL constant.
    """
    
    def __init__(self):
        """
        Initializes the OllamaEmbeddings class.

        No initialization is needed; the model name is read from the MODEL constant on each call.
        """
        pass

    def embed_documents(self, docs: list[str]) -> list[list[float]]:
        """
        Generates embeddings for a list of documents using the Ollama API.

        Args:
            docs (list of str): A list of document strings to be embedded.

        Returns:
            list of list of float: A list of embeddings, where each embedding is a list of floats
                                    corresponding to the input documents.
        """
        response = ollama.embed(model=MODEL, input=docs)
        return response['embeddings']
    
    def embed_query(self, query: str) -> list[float]:
        """
        Generates an embedding for a single query using the Ollama API.

        Args:
            query (str): The query string to be embedded.

        Returns:
            list of float: The embedding for the query as a list of floats.
        """
        response = ollama.embed(model=MODEL, input=[query])
        return response['embeddings'][0]
    
    def __call__(self, query: str) -> list[float]:
        """
        Allows the instance to be used as a callable to generate an embedding for a query.

        This method provides a convenient way to get the embedding for a query without directly
        calling `embed_query`.

        Args:
            query (str): The query string to be embedded.

        Returns:
            list of float: The embedding for the query as a list of floats.
        """
        return self.embed_query(query)
    
def clean_and_filter_questions(questions: List[str]) -> List[str]:
    """
    Cleans and filters a list of questions by removing leading numbers and ensuring each
    question ends with a question mark.

    This function processes each question in the input list by removing any leading digits and
    optional periods or spaces. It then checks if the resulting string ends with a question mark,
    appending it to the results list if it does.

    Args:
        questions (List[str]): A list of questions as strings that need to be cleaned and filtered.

    Returns:
        List[str]: A list of cleaned and filtered questions that end with a question mark.
    """
    cleaned_questions = []
    for question in questions:
        cleaned_question = re.sub(r'^\d+\.\s*', '', question.strip())
        if cleaned_question.endswith('?'):
            cleaned_questions.append(cleaned_question)
    return cleaned_questions

def generate_questions(text: str) -> List[str]:
    """
    Generates a list of questions based on the provided text using the Ollama API.

    This function sends a prompt to the Ollama API to generate a list of all possible questions
    that can be asked about the given context. The questions should be directly answerable from
    the provided context and should not include any answers or headers.

    Args:
        text (str): The context data from which questions are generated.

    Returns:
        List[str]: A list of unique, filtered questions extracted from the API response.
    """
    output = ollama.generate(
        model=MODEL,
        prompt=f'Using the context data: {text}, generate a list of at least {QUESTIONS_PER_DOCUMENT} possible questions that can be asked about this context. Ensure the questions are directly answerable within the context and do not include any answers or headers. Separate the questions with a new line character'
    )
    result = output['response']
    questions = result.splitlines()  # handles '\n' and '\r\n'; splitting on os.linesep would not split model output on Windows
    filtered_questions = clean_and_filter_questions(questions)
    unique_questions = list(dict.fromkeys(filtered_questions))  # deduplicate while keeping the model's order
    return unique_questions

def generate_answer(content: str, question: str) -> str:
    """
    Generates an answer to a given question based on the provided context using the Ollama API.

    This function sends a prompt to the Ollama API to generate a precise answer to the specified
    question using the provided context. The answer is extracted from the API response.

    Args:
        content (str): The context data used to generate the answer.
        question (str): The question for which the answer is generated.

    Returns:
        str: The precise answer to the question based on the provided context.
    """
    output = ollama.generate(
        model=MODEL,
        prompt=f'Using the context data: {content}, provide a brief and precise answer to the question: {question}.'
    )
    result = output['response']
    return result

def split_document(document: str, chunk_size: int, chunk_overlap: int) -> List[str]:
    """
    Splits a document into smaller chunks of text based on the specified chunk size and overlap.

    This function tokenizes the input document into words and then divides it into chunks of a
    specified size with a defined overlap between consecutive chunks. The chunks are then joined
    back into strings.

    Args:
        document (str): The text of the document to be split.
        chunk_size (int): The size of each chunk in terms of the number of tokens.
        chunk_overlap (int): The number of overlapping tokens between consecutive chunks.

    Returns:
        List[str]: A list of text chunks, where each chunk is a string of the document content.
    """
    # Simple regex tokenization: punctuation is dropped, so the 'tokens' counted here are words, not model tokens
    tokens = re.findall(r'\b\w+\b', document)
    chunks = []
    for i in range(0, len(tokens), chunk_size - chunk_overlap):
        chunk_tokens = tokens[i:i + chunk_size]
        chunks.append(chunk_tokens)
        if i + chunk_size >= len(tokens):
            break
    subtexts = [" ".join(chunk) for chunk in chunks]
    return subtexts

def print_document(comment: str, document: Any) -> None:
    """
    Prints a comment followed by the content of a document.

    This function prints the comment, the document's "type" and "index" metadata, and its
    page content on a single line.

    Args:
        comment (str): The comment or description to print before the document details.
        document (Any): A LangChain Document (or any object with `metadata` and `page_content`
                        attributes) whose metadata contains "type" and "index" keys.

    Returns:
        None: This function does not return any value.
    """
    print(f'{comment} (type: {document.metadata["type"]}, index: {document.metadata["index"]}): {document.page_content}')

### Main pipeline

#Load sample PDF document to string variable
path = "../data/Understanding_Climate_Change.pdf"
content = read_pdf_to_string(path)

#Instantiate Ollama Embeddings class that will be used by FAISS
embedding_model = OllamaEmbeddings()

#Split the whole text content into text documents not longer than DOCUMENT_MAX_TOKENS tokens
#and with DOCUMENT_OVERLAP_TOKENS overlapping
text_documents = split_document(content, DOCUMENT_MAX_TOKENS, DOCUMENT_OVERLAP_TOKENS)

print(f'Text content split into: {len(text_documents)} documents')

documents = []
counter = 0
for i, text_document in enumerate(text_documents):
    text_fragments = split_document(text_document, FRAGMENT_MAX_TOKENS, FRAGMENT_OVERLAP_TOKENS)
    text_fragments_len = len(text_fragments)
    print(f'Text document {i} - split into: {text_fragments_len} fragments')
    for j, text_fragment in enumerate(text_fragments):
        document = Document(page_content=text_fragment, metadata={"type": "ORIGINAL ", "index": counter, "text": text_document})
        documents.append(document)
        counter += 1
        
        if QUESTION_GENERATION == QuestionGeneration.FRAGMENT_LEVEL:
            questions = generate_questions(text_fragment)
            for question in questions:
                document = Document(page_content=question, metadata={"type": "AUGMENTED", "index": counter, "text": text_document})
                documents.append(document)
                counter += 1
            print(f'Text document {i} Text fragment {j} - generated: {len(questions)} questions')
    
    if QUESTION_GENERATION == QuestionGeneration.DOCUMENT_LEVEL:
        questions = generate_questions(text_document)
        for question in questions:
            document = Document(page_content=question, metadata={"type": "AUGMENTED", "index": counter, "text": text_document})
            documents.append(document)
            counter += 1
        print(f'Text document {i} - generated: {len(questions)} questions')

for document in documents:
    print_document("Dataset", document)

print(f'Creating store, calculating embeddings for {len(documents)} FAISS documents')
vectorstore = FAISS.from_documents(documents, embedding_model)

print("Creating retriever returning the most relevant FAISS document")
document_query_retriever = vectorstore.as_retriever(search_kwargs={"k": 1})

Text content split into: 3 documents
Text document 0 - split into: 36 fragments
Text document 0 - generated: 50 questions
Text document 1 - split into: 36 fragments
Text document 1 - generated: 40 questions
Text document 2 - split into: 15 fragments
Text document 2 - generated: 40 questions
Dataset (type: ORIGINAL , index: 0): Understanding Climate Change Chapter 1 Introduction to Climate Change Climate change refers to significant long term changes in the global climate The term global climate encompasses the planet s overall weather patterns including temperature precipitation and wind patterns over an extended period Over the past century human activities particularly the burning of fossil fuels and deforestation have significantly contributed to climate change Historical Context The Earth s climate has changed throughout history Over the past 650 000 years there have been seven cycles of glacial advance and retreat with the abrupt end of the last ice age about 11 700 years ago mark

### Find the most relevant FAISS document in the store.

In most cases, this will be an augmented question rather than an original text fragment.

query = "How do freshwater ecosystems change due to alterations in climatic factors?"
print(f'Question:{os.linesep}{query}{os.linesep}')
retrieved_documents = document_query_retriever.invoke(query)

for doc in retrieved_documents:
    print_document("Relevant fragment retrieved", doc)

Question:
How do freshwater ecosystems change due to alterations in climatic factors?

Relevant fragment retrieved (type: AUGMENTED, index: 39): How are freshwater ecosystems affected by changes in precipitation patterns, temperature, and water flow?


### Find the parent text document and use it as context for the generative model to generate an answer to the question.

context = retrieved_documents[0].metadata['text']  # parent text document of the best hit
print(f'{os.linesep}Context:{os.linesep}{context}')
answer = generate_answer(context, query)
print(f'{os.linesep}Answer:{os.linesep}{answer}')


Context:

Answer:
 Freshwater ecosystems, which include rivers, lakes, and wetlands, are affected by changes in precipitation patterns, temperature, and water flow. These changes can lead to altered water quality, habitat loss, and reduced biodiversity. Freshwater species like fish and amphibians are particularly at risk due to these alterations. Strategies for conservation include establishing and managing protected areas to preserve these fragile ecosystems and promote their resilience against the impacts of climate change.
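
The notebook retrieves with `search_kwargs={"k": 1}`, so there is exactly one parent text. If `k` were raised, several hits (an original fragment and a few augmented questions, say) would often share the same parent `text`, and concatenating them naively would repeat the context. A minimal sketch of order-preserving deduplication, assuming hits shaped like the notebook's metadata; `collect_contexts` and `max_parents` are hypothetical, and with LangChain `Document` objects you would read `doc.metadata['text']` instead of a dictionary key.

```python
def collect_contexts(hits: list[dict], max_parents: int = 2) -> str:
    """Join distinct parent texts in rank order, keeping at most max_parents of them."""
    parents = []
    for hit in hits:  # hits are assumed to be sorted best-first, as a retriever returns them
        text = hit["metadata"]["text"]
        if text not in parents:
            parents.append(text)
        if len(parents) == max_parents:
            break
    return "\n\n".join(parents)


hits = [
    {"page_content": "How are freshwater ecosystems affected by temperature?",
     "metadata": {"type": "AUGMENTED", "text": "PARENT A"}},
    {"page_content": "Freshwater ecosystems include rivers, lakes and wetlands.",
     "metadata": {"type": "ORIGINAL ", "text": "PARENT A"}},
    {"page_content": "What ended the last ice age?",
     "metadata": {"type": "AUGMENTED", "text": "PARENT B"}},
    {"page_content": "Which species are most at risk?",
     "metadata": {"type": "AUGMENTED", "text": "PARENT C"}},
]
print(collect_contexts(hits))
```

Capping the number of parents keeps the prompt within the model's context window, which matters because each parent can be up to `DOCUMENT_MAX_TOKENS` words long.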
