# MemoRAG: Enhancing Retrieval-Augmented Generation with Memory Models

## Overview
MemoRAG is a Retrieval-Augmented Generation (RAG) framework that adds a memory model as an auxiliary step before retrieval. The memory gives the system a global view of the knowledge base. This narrows the gap in contextual understanding and reasoning that standard RAG faces on queries with implicit or ambiguous information needs and on unstructured external knowledge.

## Motivation
Standard RAG techniques rely heavily on lexical or semantic matching between the query and the knowledge base. This works well for explicit question-answering tasks over structured knowledge. It often falls short when the information need is implicit or ambiguous (e.g., describing the relationships between the main characters of a novel) or when the knowledge base is unstructured (e.g., fiction books). In such cases, direct matching against the query seldom retrieves the evidence needed for a good answer.
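A toy calculation makes this failure mode concrete. The sketch below scores two invented passages against an implicit query. It uses Jaccard word overlap as a crude stand-in for lexical matching, which is an assumption: real retrievers use BM25 or dense embeddings. The passage that actually bears on the relationship scores zero. An uninformative passage that happens to repeat the query's words wins.

```python
import re


def tokens(text: str) -> set:
    """Lowercase word set; a deliberately crude tokenizer for illustration."""
    return set(re.findall(r"[a-z]+", text.lower()))


def jaccard(query: str, passage: str) -> float:
    """Jaccard overlap |Q & P| / |Q | P| between the two word sets."""
    q, p = tokens(query), tokens(passage)
    return len(q & p) / len(q | p)


query = "Describe the relationships between the main characters"
# Both passages are made up for this example.
informative = "Elizabeth rejects Darcy's first proposal at Hunsford."
decoy = "The main characters are listed in the table of contents."

print(jaccard(query, informative))  # 0.0: shares no words with the query
print(jaccard(query, decoy))        # 0.25: matches "the", "main", "characters"
```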

## Key Components
1. **Memory**: A compressed representation of the database created by a long-context model, designed to handle and summarize extensive inputs efficiently.
2. **Retriever**: A standard RAG retrieval model responsible for selecting relevant context from the knowledge base to support the generator.
3. **Generator**: A generative language model that produces responses by combining the query with the retrieved context, similar to standard RAG setups.
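The three components compose into a simple pipeline. The sketch below wires them together with plain callables. The names `memorag_answer`, `toy_memory`, `toy_retriever`, and `toy_generator` are hypothetical. The toy components stand in for a long-context memory model, a real retriever, and an LLM.

```python
import re
from typing import Callable, List

# Hypothetical signatures for the three MemoRAG stages.
Memory = Callable[[str], List[str]]           # query -> staging answers (clues)
Retriever = Callable[[List[str]], List[str]]  # clues -> retrieved context
Generator = Callable[[str, List[str]], str]   # (query, context) -> answer


def memorag_answer(query: str, memory: Memory, retriever: Retriever, generator: Generator) -> str:
    clues = memory(query)             # 1. memory drafts staging answers
    context = retriever(clues)        # 2. retrieval is driven by the clues
    return generator(query, context)  # 3. generation sees the original query plus context


# Toy database, invented for illustration.
database = [
    "Elizabeth rejects Darcy's first proposal.",
    "The novel is set in rural England.",
]


def words(text: str) -> set:
    return set(re.findall(r"[a-z]+", text.lower()))


def toy_memory(query: str) -> List[str]:
    # A real memory model would derive this clue from its compressed view of the book.
    return ["Elizabeth and Darcy proposal"]


def toy_retriever(clues: List[str]) -> List[str]:
    # Keep every chunk that shares at least one word with any clue.
    clue_words = set().union(*(words(c) for c in clues))
    return [chunk for chunk in database if words(chunk) & clue_words]


def toy_generator(query: str, context: List[str]) -> str:
    return f"{query} -> {' | '.join(context)}"


print(memorag_answer("How do the two leads relate?", toy_memory, toy_retriever, toy_generator))
```

Here the retriever never sees the vague query itself, only the clue produced by the memory stage. The generator still receives the original query.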

## Method Details
### 1. Memory
- The memory module serves as an auxiliary component to enhance the retriever’s ability to identify better matches between queries and relevant parts of the database. It takes the original query and the database as inputs and produces staging answers — intermediate outputs like clues, surrogate queries, or key points — which the retriever uses instead of the original query.
- Long-term memory is constructed by running a long-context model, such as Qwen2-7B-Instruct or Mistral-7B-Instruct-v0.2, over the entire database. This process generates a compressed representation of the database through an attention mechanism.
- The compressed representation is stored as key-value pairs, facilitating efficient and accurate retrieval.
- Released memory models include memorag-qwen2-7b-inst and memorag-mistral-7b-inst, derived from Qwen2-7B-Instruct and Mistral-7B-Instruct-v0.2, respectively.
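One way to represent staging answers and turn them into retrieval queries is sketched below. The `StagingAnswers` container and the `to_retrieval_queries` helper are hypothetical. The filtering rule mirrors what the implementation in this notebook does: it drops fragments of three words or fewer and keeps the original query as a fallback.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class StagingAnswers:
    """Hypothetical container for the memory module's intermediate outputs."""
    clues: List[str] = field(default_factory=list)              # evidence-like text spans
    surrogate_queries: List[str] = field(default_factory=list)  # sharper questions standing in for q
    key_points: List[str] = field(default_factory=list)


def to_retrieval_queries(staging: StagingAnswers, original_query: str, min_words: int = 4) -> List[str]:
    """Flatten staging answers into de-duplicated search queries."""
    queries: List[str] = []
    for text in staging.clues + staging.surrogate_queries + staging.key_points:
        text = text.strip()
        # Very short fragments tend to be noisy search queries
        if len(text.split()) >= min_words and text not in queries:
            queries.append(text)
    # Keep the original query as a fallback
    if original_query not in queries:
        queries.append(original_query)
    return queries


staging = StagingAnswers(
    clues=["Ice caps melt faster each decade.", "ok"],
    surrogate_queries=["Which habitats are lost to rising seas?"],
)
print(to_retrieval_queries(staging, "Impacts on biodiversity?"))
```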

### 2. Retriever
- The retriever is a standard retrieval model, adapted to take processed queries (created by the memory module as staging answers) instead of the original query.
- It outputs the retrieved **context**, which serves as the basis for generating the final answer.


### 3. Generator
- The generator produces the final response by combining the retriever’s output (retrieved context) with the original query.
- MemoRAG ensures compatibility and consistency by using the memory module’s underlying model as the default generator.

## Benefits of the Approach
1. **Extended Scope of Queries:** MemoRAG's preprocessing capabilities enable it to handle complex and long-context tasks that conventional RAG methods struggle with.

2. **Improved Accuracy:** By simplifying and adjusting queries before retrieval, MemoRAG enhances performance over standard RAG methods.

3. **Flexibility:** Adapts to diverse tasks, datasets, and retrieval scenarios.

4. **Robustness:** Improved performance remains consistent across various generators, datasets, and query types.

5. **Efficiency:** The use of key-value compression reduces computational overhead.
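For a sense of scale, the size of an uncompressed key-value (KV) cache grows linearly with the number of cached tokens. The arithmetic below assumes a model shape of 32 layers, 8 key-value heads, and head dimension 128, with 16-bit values. This roughly matches Mistral-7B. The compression ratio of 16 is hypothetical, and MemoRAG's actual compression scheme and ratios may differ.

```python
def kv_cache_bytes(num_layers: int, num_kv_heads: int, head_dim: int,
                   num_tokens: int, bytes_per_value: int = 2) -> int:
    """Bytes needed to cache keys and values for `num_tokens` tokens."""
    # Factor 2: one key vector and one value vector per KV head, per layer, per token.
    return 2 * num_layers * num_kv_heads * head_dim * num_tokens * bytes_per_value


# Assumed model shape (roughly Mistral-7B with grouped-query attention), fp16 values.
LAYERS, KV_HEADS, HEAD_DIM = 32, 8, 128

per_token = kv_cache_bytes(LAYERS, KV_HEADS, HEAD_DIM, num_tokens=1)
full = kv_cache_bytes(LAYERS, KV_HEADS, HEAD_DIM, num_tokens=100_000)
compressed = full // 16  # hypothetical 16x compression of the cached context

print(per_token)                        # 131072 bytes (128 KiB) per token
print(f"{full / 2**30:.1f} GiB")        # 12.2 GiB for a 100k-token database
print(f"{compressed / 2**30:.2f} GiB")  # 0.76 GiB after 16x compression
```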

## Conclusion
The memory module in MemoRAG improves the system's understanding of both the query and the database, which makes retrieval more effective. MemoRAG preprocesses queries into staging answers using a long-context memory model, and this helps it produce better-grounded responses. That makes it a notable step forward in the evolution of retrieval-augmented generation.


<div style="text-align: center;">

<img src="../images/memo_rag.svg" alt="MemoRAG" style="width:100%; height:auto;">
</div>

## Implementation

### Imports

```python
import os
from itertools import chain
from typing import List

from dotenv import load_dotenv
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.embeddings import OpenAIEmbeddings
from langchain_community.vectorstores import FAISS
from openai import OpenAI

# Project helper that splits a PDF into chunks and indexes them in a FAISS vector store
from helper_functions import encode_pdf
```


### OpenAI Setup

```python
load_dotenv()  # reads OPENAI_API_KEY from a local .env file
client = OpenAI(api_key=os.getenv('OPENAI_API_KEY'))
```

### Memory Module Classes

```python
class MemoryStore:
    """The MemoryStore class is a realization of the Memory Module discussed in the paper.
    Its 'memorize' method creates a prompt cache of topic/details pairs that mimics a
    key-value compression of the original text (database). 'create_retrieval_queries'
    then uses this cache to build the processed queries used later for retrieval."""

    def __init__(self):
        self.embeddings = OpenAIEmbeddings()
        self.store = None          # FAISS index over the topic/details pairs
        self.seen_topics = set()   # topics already stored, used to skip duplicates
        self.processed_count = 0   # number of chunks memorized so far

    def memorize(self, document: str):
        """Process a document (or chunk) into key-value pairs and store them"""
        print(f' doc head: {document[0:300]}')

        system_prompt = "Extract key-value pairs from document"

        kv_cache_prompt = """
        You are provided with a long article, chunk by chunk. Read each chunk carefully and extract key topics and their detailed information.
        For each key topic:
        1. Identify the main concept, entity, or potential question
        2. Provide the corresponding detailed information or answer

        Note: the aim is to mimic kv cache memory creation

        Now, the article begins:
        {document}

        The article ends here.

        FORMAT IS IMPORTANT. Each pair MUST be formatted EXACTLY as:
        Topic: <topic>
        Details: <details>

        Ensure there is a blank line between each pair.
        """

        print(f"Processing chunk {self.processed_count + 1}...")

        response = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": kv_cache_prompt.format(document=document)},
            ],
        )
        print(f'memorize response head: {response}')

        # Parse the response into (topic, details) pairs
        pairs = self._parse_into_pairs(response.choices[0].message.content)
        print(f'first pairs: {pairs[0:5]}')

        texts, metadatas = [], []
        for topic, details in pairs:
            # Skip empty pairs and topics that are already stored (from this or earlier chunks)
            if not topic or not details or topic in self.seen_topics:
                continue
            self.seen_topics.add(topic)
            texts.append(f"Topic: {topic}\nDetails: {details}")
            metadatas.append({"topic": topic})
        print(len(texts), len(metadatas))

        if texts:
            if self.store is None:
                self.store = FAISS.from_texts(texts, self.embeddings, metadatas=metadatas)
            else:
                self.store.add_texts(texts, metadatas=metadatas)

        self.processed_count += 1
        print(f"Processed {self.processed_count} chunks so far")

    def _parse_into_pairs(self, text: str):
        """Parse the GPT response into a list of (topic, details) pairs"""
        pairs = []
        current_topic = None
        current_details = []

        for raw_line in text.split('\n'):
            line = raw_line.strip()
            if line.startswith('```'):  # ignore code fences the model may add
                continue
            if line.startswith('Topic:'):
                if current_topic:  # save the previous pair
                    pairs.append((current_topic, ' '.join(current_details)))
                current_topic = line[len('Topic:'):].strip()
                current_details = []
            elif line.startswith('Details:'):
                current_details.append(line[len('Details:'):].strip())
            elif line and current_details:
                # Continuation line of a multi-line Details field
                current_details.append(line)

        # Add the last pair
        if current_topic:
            pairs.append((current_topic, ' '.join(current_details)))

        return pairs

    def create_retrieval_queries(self, query: str) -> List[str] | None:
        """Generate staging answers y = Θ_mem(q, D | θ_mem).
        Returns clue spans and surrogate questions (plus the original query) that
        guide context retrieval, or None if nothing has been memorized yet."""

        if not self.store:
            return None

        # Consider a larger k (e.g. 30, or a fraction of the store) for long documents
        results = self.store.similarity_search_with_score(query, k=min(10, self.store.index.ntotal))
        # page_content already holds "Topic: ...\nDetails: ...", so it is used as-is
        relevant_info = "\n\n".join(doc.page_content for doc, _ in results)

        memorag_span_prompt = """
            You are given a question related to an article. To answer it effectively, you need to use specific details from the article. You are not provided with the whole article. Instead, you are provided with specific and relevant information from it.
            Your task is to identify and extract one or more specific clue texts from the provided information that are relevant to the question.

            ### Question: {question}
            ### Information: {relevant_info}
            ### Instructions:
            1. You have a general understanding of the provided information. Your task is to generate one or more specific clues that will help in searching for supporting evidence within the article.
            2. The clues are in the form of text spans that will assist in answering the question.
            3. Only output the clues. If there are multiple clues, separate them with a newline.
            """
        memorag_sur_prompt = """
            You are given a question related to an article. To answer it effectively, you need to use specific details from the article. You are not provided with the whole article. Instead, you are provided with specific and relevant information from it.
            Your task is to generate precise clue questions that can help locate the necessary information for answering the question.

            ### Question: {question}
            ### Information: {relevant_info}
            ### Instructions:
            1. You have a general understanding of the provided information. Your task is to generate one or more specific clues that will help in searching for supporting evidence within the article.
            2. The clues are in the form of precise surrogate questions that clarify the original question.
            3. Only output the clues. If there are multiple clues, separate them with a newline.
            """

        # Short system prompt; the task details live in the user prompts above
        system_prompt = "You generate concise search clues that help locate supporting evidence in an article."

        text_spans = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": memorag_span_prompt.format(question=query, relevant_info=relevant_info)},
            ],
        ).choices[0].message.content
        surrogate_queries = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": memorag_sur_prompt.format(question=query, relevant_info=relevant_info)},
            ],
            temperature=0.5,  # allow for some creativity in generating clues
        ).choices[0].message.content

        retrieval_query = text_spans.split("\n") + surrogate_queries.split("\n")
        print(20*'*')
        print(f'retrieval query all: {retrieval_query}')
        # Drop blank lines and very short fragments (three words or fewer)
        retrieval_query = [q.strip() for q in retrieval_query if len(q.split()) > 3]
        print(20*'*')
        print(f'retrieval query no shorts: {retrieval_query}')
        retrieval_query.append(query)  # keep the original query as a fallback
        print(20*'*')
        print(f'retrieval query with original query: {retrieval_query}')

        return retrieval_query

    def save_store(self, path: str):
        if self.store:
            self.store.save_local(path)

    def load_store(self, path: str):
        # Loading unpickles the docstore, so only load indexes you created yourself.
        # Note: seen_topics is not restored, so re-memorized topics may be duplicated.
        self.store = FAISS.load_local(path, self.embeddings, allow_dangerous_deserialization=True)
```
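`memorize` can be called repeatedly. `processed_count` and the topic de-duplication only matter when a document is fed in several pieces, but the notebook passes the whole PDF text in one call. For documents that exceed the model's context window, a hypothetical `pack_lines` helper could group consecutive lines into chunks under a character budget. This is a sketch, not part of the original implementation.

```python
from typing import List


def pack_lines(text: str, max_chars: int) -> List[str]:
    """Greedily pack consecutive lines into chunks of at most `max_chars` characters.

    A single line longer than `max_chars` becomes its own (oversized) chunk.
    """
    chunks: List[str] = []
    current = ""
    for line in text.splitlines():
        candidate = f"{current}\n{line}" if current else line
        if len(candidate) <= max_chars or not current:
            current = candidate
        else:
            chunks.append(current)
            current = line
    if current:
        chunks.append(current)
    return chunks


print(pack_lines("aaaa\nbbbb\ncc", max_chars=9))  # ['aaaa\nbbbb', 'cc']
```

With the notebook's objects, usage would look like `for chunk in pack_lines(document_text, max_chars=20_000): memory_store.memorize(chunk)`. The 20,000-character budget is an arbitrary placeholder.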

### Retrieval Function

```python
def retrieve_context(retrieval_query: str, vectorstore) -> List[str]:
    """Retrieve relevant context using the staging answer and the database vectorstore.
    Implements c = Γ(y, D | γ) from the paper"""

    results = vectorstore.similarity_search(retrieval_query, k=5)
    contexts = [doc.page_content for doc in results]

    return contexts
```

### Generation Function

```python
def generate_answer(query: str, contexts: List[str]) -> str:
    """Generate final answer y = Θ(q, c | θ)"""
    # Separate retrieved chunks with blank lines so their boundaries stay visible to the model
    joined_contexts = '\n\n'.join(contexts)
    prompt = f"""Based on the provided context, answer the query.

Query: {query}

Retrieved Information:
{joined_contexts}

Provide a clear and concise answer focusing only on the retrieved information.
"""

    response = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[
            {"role": "system", "content": "You are a knowledgeable assistant. Provide clear, concise answers."},
            {"role": "user", "content": prompt}
        ],
        max_tokens=500
    )

    return response.choices[0].message.content
```

### Query Processing Function

```python
def process_query(query: str, memory_store, vectorstore):
    print("\nProcessing Query:", query)
    print("=" * 50)

    # y = Θ_mem(q, D | θ_mem): staging answers from the memory module
    retrieval_queries = memory_store.create_retrieval_queries(query)
    if not retrieval_queries:  # empty memory: fall back to the original query
        retrieval_queries = [query]
    print(f"Retrieval Queries:\n{retrieval_queries}")

    # c = Γ(y, D | γ): retrieve for every processed query, merge, and drop duplicate chunks
    contexts = list(dict.fromkeys(chain.from_iterable(
        retrieve_context(retrieval_query, vectorstore) for retrieval_query in retrieval_queries
    )))
    if contexts:
        print(f"Retrieved Context Example: {contexts[0]}")

    # y = Θ(q, c | θ)
    final_answer = generate_answer(query, contexts)
    print(f"\nFinal Answer: {final_answer}")

    return contexts, final_answer
```
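Retrieving once per processed query and then merging the results can be checked offline. The sketch below swaps the FAISS index for a hypothetical `StubStore`. Its `similarity_search(query, k)` mimics the vector-store call used in `retrieve_context` by ranking documents on shared words and returning objects with a `page_content` attribute. `merge_retrievals` shows the order-preserving de-duplication.

```python
from dataclasses import dataclass
from itertools import chain
from typing import List


@dataclass
class Doc:
    page_content: str


class StubStore:
    """Hypothetical offline stand-in for a vector store: ranks docs by shared words."""

    def __init__(self, texts: List[str]):
        self.docs = [Doc(t) for t in texts]

    def similarity_search(self, query: str, k: int = 4) -> List[Doc]:
        q = set(query.lower().split())
        # sorted() is stable, so ties keep insertion order
        ranked = sorted(self.docs, key=lambda d: -len(q & set(d.page_content.lower().split())))
        return ranked[:k]


def merge_retrievals(queries: List[str], store, k: int = 1) -> List[str]:
    """Retrieve for every query, then de-duplicate while keeping first-seen order."""
    hits = chain.from_iterable(store.similarity_search(q, k=k) for q in queries)
    return list(dict.fromkeys(doc.page_content for doc in hits))


store = StubStore([
    "ice caps are melting",
    "coral reefs bleach in warm water",
    "fossil fuels emit carbon dioxide",
])
print(merge_retrievals(["melting ice caps", "warm water coral", "ice caps"], store))
```

`dict.fromkeys` keeps the first occurrence of each chunk. Chunks found by earlier (usually more specific) clues therefore stay ahead of later duplicates.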

### Initialize Components

```python
# Initialize memory store
memory_store = MemoryStore()

# Load the PDF and memorize its full text
path = "../data/Understanding_Climate_Change.pdf"
loader = PyPDFLoader(path)
documents = loader.load()
document_text = '\n'.join([doc.page_content for doc in documents])
memory_store.memorize(document_text)

# Standard chunk index of the same PDF, searched by the retriever
chunks_vector_store = encode_pdf(path, chunk_size=1000, chunk_overlap=200)
```

```
 doc head: Understanding Climate Change  
Chapter 1: Introduction to Climate Change  
Climate change refers to significant, long -term changes in the global climate. The term 
"global climate" encompasses the planet's overall weather patterns, including temperature, 
precipitation, and wind patterns, over an e
Processing chunk 1...
memorize response head: ChatCompletion(id='chatcmpl-B6S77kcpaFTD9nTkdnHeyKYgNVrNx', choices=[Choice(finish_reason='stop', index=0, logprobs=None, message=ChatCompletionMessage(content="``` \nTopic: Understanding Climate Change \nDetails: Climate change refers to significant, long-term changes in the global climate caused mainly by human activities, such as burning fossil fuels and deforestation. The Earth's climate has changed historically, but modern changes, particularly rapid temperature increases, are linked to greenhouse gas emissions.\n\nTopic: Historical Context \nDetails: The Earth's climate has undergone numerous changes over the past 650,000 years,
```

### Usage Examples

```python
query_1 = "What are the impacts of climate change on biodiversity?"
query_2 = "Please summarize the climate change article"
query_3 = "Describe the social and economic influence of climate change."

for query in [query_1, query_2, query_3]:
    process_query(query, memory_store, chunks_vector_store)
```


```
Processing Query: What are the impacts of climate change on biodiversity?
********************
retrieval query all: ['Effects include melting ice caps, increased coastal erosion, and a rise in extreme weather events. These changes threaten ecosystems, human health, and global food security.', 'What changes in ecosystems are mentioned as a result of climate change?  ', 'How does climate change threaten global food security?  ', 'What specific effects of climate change are highlighted that may impact biodiversity?  ', 'In what ways do rising temperatures and extreme weather events affect natural habitats?  ']
********************
retrieval query no shorts: ['Effects include melting ice caps, increased coastal erosion, and a rise in extreme weather events. These changes threaten ecosystems, human health, and global food security.', 'What changes in ecosystems are mentioned as a result of climate change?  ', 'How does climate change threaten global food security?  ', 'What specific effect
```

## Comparison