# MemoRAG: Enhancing Retrieval-Augmented Generation with Memory Models

## Overview
MemoRAG is a Retrieval-Augmented Generation (RAG) framework that runs a memory model as an auxiliary step before the retrieval phase. This step is meant to close the gap in contextual understanding and reasoning that standard RAG techniques show on queries with implicit or ambiguous information needs and on unstructured external knowledge.

## Motivation
Standard RAG techniques rely heavily on lexical or semantic matching between the query and the knowledge base. While this approach works well for clear question answering tasks with structured knowledge, it often falls short when handling queries with implicit or ambiguous information (e.g., describing the relationships between main characters in a novel) or when the knowledge base is unstructured (e.g., fiction books). In such cases, lexical or semantic matching seldom produces the desired outputs.

## Key Components
1. **Memory**: A compressed representation of the database created by a long-context model, designed to handle and summarize extensive inputs efficiently.
2. **Retriever**: A standard RAG retrieval model responsible for selecting relevant context from the knowledge base to support the generator.
3. **Generator**: A generative language model that produces responses by combining the query with the retrieved context, similar to standard RAG setups.

## Method Details
### 1. Memory
- The memory module serves as an auxiliary component to enhance the retriever’s ability to identify better matches between queries and relevant parts of the database. It takes the original query and the database as inputs and produces staging answers — intermediate outputs like clues, surrogate queries, or key points — which the retriever uses instead of the original query.
- Long-term memory is constructed by running a long-context model, such as Qwen2-7B-Instruct or Mistral-7B-Instruct-v0.2, over the entire database. This process generates a compressed representation of the database through an attention mechanism.
- The compressed representation is stored as key-value pairs, facilitating efficient and accurate retrieval.
- Released memory models include memorag-qwen2-7b-inst and memorag-mistral-7b-inst, derived from Qwen2-7B-Instruct and Mistral-7B-Instruct-v0.2, respectively.

### 2. Retriever
- The retriever is a standard retrieval model, adapted to take processed queries (created by the memory module as staging answers) instead of the original query.
- It outputs the retrieved **context**, which serves as the basis for generating the final answer.


### 3. Generator
- The generator produces the final response by combining the retriever’s output (retrieved context) with the original query.
- MemoRAG ensures compatibility and consistency by using the memory module’s underlying model as the default generator.
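
The three stages described above can be sketched as a single function. This is only a minimal illustration of the data flow, not the MemoRAG implementation: `memo_rag` and the `toy_*` functions are hypothetical stand-ins for the memory model, retriever, and generator, chosen so the sketch runs without any model or index.

```python
from typing import Callable, List


def memo_rag(
    query: str,
    memory: Callable[[str], str],
    retriever: Callable[[str], List[str]],
    generator: Callable[[str, List[str]], str],
) -> str:
    """Run the three stages in order (hypothetical helper)."""
    staging_answer = memory(query)        # clues or a surrogate query
    contexts = retriever(staging_answer)  # retrieval sees the staging answer
    return generator(query, contexts)     # generation sees the original query


# Hypothetical stand-ins so the sketch runs on its own.
def toy_memory(query: str) -> str:
    return f"clues for: {query}"


def toy_retriever(staging_answer: str) -> List[str]:
    return [f"passage matching '{staging_answer}'"]


def toy_generator(query: str, contexts: List[str]) -> str:
    return f"answer to '{query}' from {len(contexts)} passage(s)"


print(memo_rag("Who wrote the book?", toy_memory, toy_retriever, toy_generator))
```

The point of the sketch is the wiring: the retriever never sees the original query, while the generator never sees the staging answer.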

## Benefits of the Approach
1. **Extended Scope of Queries:** MemoRAG's preprocessing capabilities enable it to handle complex and long-context tasks that conventional RAG methods struggle with.

2. **Improved Accuracy:** By replacing the raw query with staging answers (clues, surrogate queries, or key points) before retrieval, MemoRAG aims to improve retrieval quality over standard RAG methods.

3. **Flexibility:** Adapts to diverse tasks, datasets, and retrieval scenarios.

4. **Robustness:** Improved performance remains consistent across various generators, datasets, and query types.

5. **Efficiency:** The use of key-value compression reduces computational overhead.

## Conclusion
The memory module in MemoRAG improves comprehension of both the queries and the database, which enables more effective retrieval. By preprocessing queries, generating staging answers, and drawing on long-context memory models, MemoRAG extends retrieval-augmented generation to queries that direct query-to-passage matching handles poorly.


<div style="text-align: center;">

<img src="../images/memo_rag.svg" alt="MemoRAG" style="width:100%; height:auto;">
</div>

## Implementation

### Imports

In [1]:
import json
import os
from typing import List

from dotenv import load_dotenv
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.embeddings import OpenAIEmbeddings
from langchain_community.vectorstores import FAISS
from openai import OpenAI

from helper_functions import *

### OpenAI Setup

In [2]:
load_dotenv()
client = OpenAI(api_key=os.getenv('OPENAI_API_KEY'))

### Memory Module Classes

In [10]:
class MemoryStore:
    """A simplified realization of the Memory Module discussed in the paper.

    'memorize' asks an LLM to extract topic/detail pairs from the original text
    (the database) and stores them in a FAISS index. 'get_staging_answer' then
    uses those pairs to build the processed query that drives retrieval."""

    def __init__(self):
        self.embeddings = OpenAIEmbeddings()
        self.store = None
        self.processed_count = 0

    def memorize(self, document: str):
        """Process document into key-value pairs and store them"""
        print(f"Processing chunk {self.processed_count + 1}...")
        extraction_prompt = """Extract key topics and their detailed information from this text.
        For each key topic:
        1. Identify the main concept, entity, or potential question
        2. Provide the corresponding detailed information or answer

        Text:
        {document}

        Format each pair as:
        Topic: <topic>
        Details: <details>
        """

        response = client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[
                {"role": "system", "content": "Extract key-value pairs from document"},
                {"role": "user", "content": extraction_prompt.format(document=document)}
            ]
        )

        # Parse the response into key-value pairs
        pairs = self._parse_into_pairs(response.choices[0].message.content)

        # Collect non-empty pairs as texts plus topic metadata
        texts = []
        metadatas = []

        for topic, details in pairs:
            if not topic or not details:  # Skip empty pairs
                continue

            texts.append(f"Topic: {topic}\nDetails: {details}")
            metadatas.append({"topic": topic})

        if texts and self.store is None:
            self.store = FAISS.from_texts(texts, self.embeddings, metadatas=metadatas)
        elif texts:
            # Fetch every stored entry (the default k returns only a few),
            # so duplicate topics are detected across the whole store
            existing_docs = self.store.similarity_search("", k=self.store.index.ntotal)
            existing_topics = {doc.metadata.get("topic") for doc in existing_docs}

            # Filter out duplicates
            new_texts = []
            new_metadatas = []
            for text, metadata in zip(texts, metadatas):
                if metadata["topic"] not in existing_topics:
                    new_texts.append(text)
                    new_metadatas.append(metadata)

            if new_texts:
                self.store.add_texts(new_texts, metadatas=new_metadatas)

        self.processed_count += 1
        print(f"Processed {self.processed_count} chunks so far")

    def _parse_into_pairs(self, text: str):
        """Parse GPT response into list of (topic, details) pairs"""
        pairs = []
        lines = text.split('\n')
        current_topic = None
        current_details = []

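        for line in lines:
            line = line.strip()  # Tolerate indented or padded model output
            if line.startswith('Topic:'):
                if current_topic:  # Save previous pair
                    pairs.append((current_topic, ' '.join(current_details)))
                current_topic = line[6:].strip()
                current_details = []
            elif line.startswith('Details:'):
                detail = line[8:].strip()
                if detail:
                    current_details.append(detail)
            elif line and current_topic:
                # Continuation line of a multi-line Details block
                current_details.append(line)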

        # Add the last pair
        if current_topic:
            pairs.append((current_topic, ' '.join(current_details)))

        return pairs

    def get_staging_answer(self, query: str) -> str:
        """Generate staging answer y = Θ_mem(q, D | θ_mem)
        This should provide rough clues/outline to guide context retrieval"""

        if not self.store:
            # Nothing memorized yet: fall back to the original query
            return query

        results = self.store.similarity_search_with_score(query, k=5)
        # page_content already holds "Topic: ...\nDetails: ...", so use it directly
        relevant_info = "\n\n".join(doc.page_content for doc, _ in results)

        prompt = f"""Based on available information, generate a rough outline/staging answer.
        This should help guide retrieval of detailed context, but doesn't need to be fully accurate.

        Query: {query}

        Relevant Information:
        {relevant_info}

        Generate a rough outline that could help locate the correct answers:"""

        response = client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[
                {"role": "system", "content": "Generate rough outlines to guide information retrieval"},
                {"role": "user", "content": prompt}
            ],
            temperature=0.7  # Allow for some creativity in generating clues
        )

        return response.choices[0].message.content

    def save_store(self, path: str):
        if self.store:
            self.store.save_local(path)

    def load_store(self, path: str):
        self.store = FAISS.load_local(path, self.embeddings, allow_dangerous_deserialization=True)
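
The extraction prompt asks the model for `Topic:` / `Details:` pairs, so it can help to test the parsing step on its own, without an API call. The following is a standalone sketch of parsing that format. The function name `parse_topic_details` and the sample text are hypothetical and are not taken from the notebook's data.

```python
from typing import List, Optional, Tuple


def parse_topic_details(text: str) -> List[Tuple[str, str]]:
    """Parse 'Topic:' / 'Details:' blocks into (topic, details) pairs."""
    pairs: List[Tuple[str, str]] = []
    topic: Optional[str] = None
    details: List[str] = []
    for raw in text.splitlines():
        line = raw.strip()
        if line.startswith("Topic:"):
            if topic:  # close the previous pair
                pairs.append((topic, " ".join(details)))
            topic, details = line[len("Topic:"):].strip(), []
        elif line.startswith("Details:"):
            rest = line[len("Details:"):].strip()
            if rest:
                details.append(rest)
        elif line and topic:
            details.append(line)  # continuation of a multi-line detail
    if topic:
        pairs.append((topic, " ".join(details)))
    return pairs


# Hypothetical model output, used only for illustration.
sample = """Topic: Greenhouse gases
Details: Gases such as CO2 trap heat
in the atmosphere.

Topic: Sea level rise
Details: Melting ice raises sea levels."""

for topic, details in parse_topic_details(sample):
    print(f"{topic!r} -> {details!r}")
```

Running a check like this on a few real responses is a cheap way to see how often the model drifts from the requested format before any pairs are embedded.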

### Retrieval Function

In [4]:
def retrieve_context(staging_answer: str, vectorstore) -> List[str]:
    """Retrieve relevant context using the staging answer and the database vectorstore.
    Implements c = Γ(y, D | γ) from the paper"""

    results = vectorstore.similarity_search(staging_answer, k=5)
    contexts = [doc.page_content for doc in results]

    return contexts
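
To see why searching with a staging answer can help, consider a toy retriever that scores passages by bag-of-words cosine similarity. This is a deliberately crude stand-in for the embedding search that the vector store performs, and the passages, query, and staging answer below are invented for illustration.

```python
import math
from collections import Counter
from typing import List


def bow(text: str) -> Counter:
    """Lowercased bag-of-words vector."""
    return Counter(text.lower().split())


def cosine(a: Counter, b: Counter) -> float:
    dot = sum(a[t] * b[t] for t in a)
    norm = math.sqrt(sum(v * v for v in a.values())) * math.sqrt(
        sum(v * v for v in b.values())
    )
    return dot / norm if norm else 0.0


def toy_retrieve(search_text: str, passages: List[str], k: int = 1) -> List[str]:
    """Return the k passages most similar to search_text."""
    q = bow(search_text)
    return sorted(passages, key=lambda p: cosine(q, bow(p)), reverse=True)[:k]


# Invented data: the query shares no words with the relevant passage.
passages = [
    "carbon taxes raise the price of fossil fuels",
    "coral reefs bleach as oceans warm",
]
query = "how does warming affect marine life"
staging_answer = "oceans warm and coral reefs bleach"

for label, text in [("query", query), ("staging answer", staging_answer)]:
    scores = [round(cosine(bow(text), bow(p)), 3) for p in passages]
    print(label, scores, toy_retrieve(text, passages))
```

With the raw query, both passages score zero and the ranking is arbitrary. The staging answer shares vocabulary with the relevant passage and ranks it first. Embedding search is far less brittle than word overlap, but the same effect applies when a query only implies what it is looking for.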

### Generation Function

In [5]:
def generate_answer(query: str, contexts: List[str]) -> str:
    """Generate final answer y = Θ(q, c | θ)"""
    prompt = f"""Based on the provided context, answer the query.

Query: {query}

Retrieved Information:
{' '.join(contexts)}

Provide a clear and concise answer focusing only on the retrieved information.
"""

    response = client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[
            {"role": "system", "content": "You are a knowledgeable assistant. Provide clear, concise answers."},
            {"role": "user", "content": prompt}
        ],
        max_tokens=200,
        temperature=0.7
    )

    return response.choices[0].message.content
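
`generate_answer` joins the retrieved passages with single spaces, which hides where one passage ends and the next begins. One possible variant, sketched below, numbers each passage and caps the total context length before the prompt is sent. `build_answer_prompt` and its `max_chars` parameter are hypothetical and not part of the notebook.

```python
from typing import List


def build_answer_prompt(query: str, contexts: List[str], max_chars: int = 4000) -> str:
    """Number each passage and stop adding passages once max_chars is reached."""
    blocks = []
    used = 0
    for i, ctx in enumerate(contexts, start=1):
        block = f"[{i}] {ctx.strip()}"
        if used + len(block) > max_chars:
            break  # drop the remaining, lower-ranked passages
        blocks.append(block)
        used += len(block)
    joined = "\n\n".join(blocks)
    return (
        "Based on the provided context, answer the query.\n\n"
        f"Query: {query}\n\n"
        f"Retrieved Information:\n{joined}\n\n"
        "Provide a clear and concise answer focusing only on the retrieved information.\n"
    )


print(build_answer_prompt("What drives sea level rise?", ["Melting ice.", "Thermal expansion."]))
```

Because retrieval returns passages in ranked order, truncating from the end keeps the passages the retriever considered most relevant.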

### Query Processing Function

In [6]:
def process_query(query: str, memory_store, vectorstore):
    print("\nProcessing Query:", query)
    print("=" * 50)

    # y = Θ_mem(q, D | θ_mem)
    staging_answer = memory_store.get_staging_answer(query)
    print(f"Staging Answer:\n{staging_answer}")

    # c = Γ(y, D | γ)
    contexts = retrieve_context(staging_answer, vectorstore)
    print(f"Retrieved Context Example: {contexts[0] if contexts else 'None'}")

    # y = Θ(q, c | θ)
    final_answer = generate_answer(query, contexts)
    print(f"Final Answer: {final_answer}")

    return contexts, final_answer

### Initialize Components

In [None]:
# Initialize memory store
memory_store = MemoryStore()

# Load and process document
path = "../data/Understanding_Climate_Change.pdf"
loader = PyPDFLoader(path)
documents = loader.load()
document_text = '\n'.join([doc.page_content for doc in documents])
memory_store.memorize(document_text)
chunks_vector_store = encode_pdf(path, chunk_size=1000, chunk_overlap=200)

### Usage Examples

In [None]:
query_1 = "What are the impacts of climate change on biodiversity?"
query_2 = "Please summarize the climate change article"
query_3 = "Describe the social and economic influence of climate change."

for query in [query_1, query_2, query_3]:
    process_query(query, memory_store, chunks_vector_store)

## Comparison

### Short Context

#### Simple RAG

In [7]:
from evaluation.evalute_rag import *



In [None]:
chunks_query_retriever = chunks_vector_store.as_retriever(search_kwargs={"k": 2})

In [None]:
evaluate_rag(chunks_query_retriever)

#### MemoRAG

In [None]:
def evaluate_memo_rag(num_questions: int = 5) -> None:
    """
    Evaluate the MemoRAG system using predefined metrics.

    Args:
        num_questions (int): Number of questions to evaluate (default: 5).
    """
    q_a_file_name = "../data/q_a.json"
    with open(q_a_file_name, "r", encoding="utf-8") as json_file:
        q_a = json.load(json_file)

    questions = [qa["question"] for qa in q_a][:num_questions]
    ground_truth_answers = [qa["answer"] for qa in q_a][:num_questions]
    generated_answers = []
    retrieved_documents = []

    # Generate answers and retrieve documents for each question
    for question in questions:
        contexts, result = process_query(question, memory_store, chunks_vector_store)
        retrieved_documents.append(contexts)
        generated_answers.append(result)

    # Create test cases and evaluate
    test_cases = create_deep_eval_test_cases(questions, ground_truth_answers, generated_answers, retrieved_documents)
    evaluate(
        test_cases=test_cases,
        metrics=[correctness_metric, faithfulness_metric, relevance_metric]
    )

In [None]:
evaluate_memo_rag()

### Long Context

In [8]:
from langchain.text_splitter import RecursiveCharacterTextSplitter

# Initialize memory store
memory_store = MemoryStore()

# Load and process document
path = "../data/The_Wealth_of_Nations_Project_Gutenberg.pdf"
loader = PyPDFLoader(path)
documents = loader.load()
document_text = '\n'.join([doc.page_content for doc in documents])
# Split document into smaller chunks
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=50000,
    chunk_overlap=5000
)
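
Each chunk passed to `memorize` triggers one LLM call, so it is useful to estimate the number of chunks before running the loop. The sketch below uses a plain fixed-width sliding window. `RecursiveCharacterTextSplitter` prefers to split on separators, so its real boundaries and counts will differ; `sliding_chunks` is a hypothetical helper for rough estimation only.

```python
from typing import List


def sliding_chunks(text: str, chunk_size: int, chunk_overlap: int) -> List[str]:
    """Fixed-width chunks where consecutive chunks share chunk_overlap characters."""
    if chunk_overlap >= chunk_size:
        raise ValueError("chunk_overlap must be smaller than chunk_size")
    stride = chunk_size - chunk_overlap  # how far the window advances each step
    chunks = []
    start = 0
    while start < len(text):
        chunks.append(text[start:start + chunk_size])
        if start + chunk_size >= len(text):
            break  # the window already reached the end of the text
        start += stride
    return chunks


# A 200,000-character placeholder text with the settings used above.
text = "x" * 200_000
chunks = sliding_chunks(text, chunk_size=50_000, chunk_overlap=5_000)
print(len(chunks), [len(c) for c in chunks])
```

With a chunk size of 50,000 and an overlap of 5,000, the window advances 45,000 characters per step, so the chunk count grows by roughly one per 45,000 characters of input.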

In [27]:
# chunks = text_splitter.split_text(document_text)
#
# # Memorize each chunk
# for chunk in chunks:
#     memory_store.memorize(chunk)

Processing chunk 1...
Processed 1 chunks so far
...
Processing chunk 20...
Processed 20 chunks so far
Processing chunk 2

In [11]:
# memory_store.save_store("../data/The_Wealth_of_Nations_Project_Gutenberg_Memory_Store.faiss")
memory_store.load_store("../data/The_Wealth_of_Nations_Project_Gutenberg_Memory_Store.faiss")


In [30]:
chunks_vector_store = encode_pdf(path, chunk_size=1000, chunk_overlap=200)

#### Simple RAG

In [35]:
chunks_query_retriever = chunks_vector_store.as_retriever(search_kwargs={"k": 2})

In [42]:
evaluate_rag_long_context(chunks_query_retriever)

ImportError: cannot import name 'evaluate_rag_long_context' from 'evaluation.evalute_rag' (/Users/erantrabelci/workspace/RAG_Techniques/evaluation/evalute_rag.py)

#### MemoRAG

In [33]:
def evaluate_memo_rag_long_context(num_questions: int = 5) -> None:
    """
    Evaluate the MemoRAG system using predefined metrics.

    Args:
        num_questions (int): Number of questions to evaluate (default: 5).
    """
    q_a_file_name = "../data/q_a_smith.json"
    with open(q_a_file_name, "r", encoding="utf-8") as json_file:
        q_a = json.load(json_file)

    questions = [qa["question"] for qa in q_a][:num_questions]
    ground_truth_answers = [qa["answer"] for qa in q_a][:num_questions]
    generated_answers = []
    retrieved_documents = []

    # Generate answers and retrieve documents for each question
    for question in questions:
        contexts, result = process_query(question, memory_store, chunks_vector_store)
        retrieved_documents.append(contexts)
        generated_answers.append(result)

    # Create test cases and evaluate
    test_cases = create_deep_eval_test_cases(questions, ground_truth_answers, generated_answers, retrieved_documents)
    evaluate(
        test_cases=test_cases,
        metrics=[correctness_metric, faithfulness_metric, relevance_metric]
    )

In [34]:
evaluate_memo_rag_long_context()


Processing Query: Who is the author of 'The Wealth of Nations'?
Staging Answer:
I. Introduction
    A. Query: Who is the author of 'The Wealth of Nations'?
    
II. Overview of 'The Wealth of Nations'
    A. Author: Adam Smith
    B. Key themes and topics covered in the book
        1. Division of labor
        2. Role of markets
        3. Impact on productivity and economic development
        4. Wealth and the role of money
        5. Philosophers' System on Agriculture
        6. Commerce and Government

III. Authorship and Contribution
    A. Adam Smith's role as the author of 'The Wealth of Nations'
    B. Significance of the book in economics and philosophy

IV. Conclusion
    A. Summary of key points
    B. Clarification of the author of 'The Wealth of Nations'
Retrieved Context Example: *** ST ART OF THE PROJECT  GUTENBERG EBOOK 3300 ***
An Inquiry into
the Nature and
Causes of the
Wealth of
Nations
by Adam Smith
Contents
INTRODUCTION AND PLAN OF THE WORK.
BOOK I. OF  THE CAU

Event loop is already running. Applying nest_asyncio patch to allow async execution...


Evaluating 5 test case(s) in parallel: |          |  0% (0/5) [Time Taken: 00:00, ?test case/s]ERROR:root:OpenAI rate limit exceeded. Retrying: 1 time(s)...
ERROR:root:OpenAI rate limit exceeded. Retrying: 1 time(s)...
ERROR:root:OpenAI rate limit exceeded. Retrying: 2 time(s)...
ERROR:root:OpenAI rate limit exceeded. Retrying: 1 time(s)...
ERROR:root:OpenAI rate limit exceeded. Retrying: 1 time(s)...
ERROR:root:OpenAI rate limit exceeded



Metrics Summary

  - ✅ Correctness (GEval) (score: 0.7334047947552913, threshold: 0.5, strict: False, evaluation model: gpt-4o, reason: Actual output correctly identifies land and capital, but uses 'capital stock' and 'productive labor' instead of 'labor'., error: None)
  - ✅ Faithfulness (score: 1.0, threshold: 0.7, strict: False, evaluation model: gpt-4, reason: None, error: None)
  - ❌ Contextual Relevancy (score: 0.0, threshold: 1.0, strict: False, evaluation model: gpt-4, reason: The score is 0.00 because the context does not mention the three factors of production, as highlighted in the following feedback: 'The context does not mention anything about the 'three factors of production'.', error: None)

For test case:

  - input: What are the three factors of production mentioned?
  - actual output: The three factors of production mentioned are land, capital stock, and productive labor.
  - expected output: Land, labor, and capital.
  - context: None
  - retrieval context: ['Land 


