# MemoRAG: Enhancing Retrieval-Augmented Generation with Memory Models

## Overview
MemoRAG is a Retrieval-Augmented Generation (RAG) framework that incorporates a memory model as an auxiliary step before the retrieval phase. In doing so, it bridges the gap in contextual understanding and reasoning that standard RAG techniques face when addressing queries with implicit or ambiguous information needs and unstructured external knowledge.

## Motivation
Standard RAG techniques rely heavily on lexical or semantic matching between the query and the knowledge base. While this approach works well for clear question answering tasks with structured knowledge, it often falls short when handling queries with implicit or ambiguous information (e.g., describing the relationships between main characters in a novel) or when the knowledge base is unstructured (e.g., fiction books). In such cases, lexical or semantic matching seldom produces the desired outputs.

## Key Components
1. **Memory**: A compressed representation of the database created by a long-context model, designed to handle and summarize extensive inputs efficiently.
2. **Retriever**: A standard RAG retrieval model responsible for selecting relevant context from the knowledge base to support the generator.
3. **Generator**: A generative language model that produces responses by combining the query with the retrieved context, similar to standard RAG setups.

## Method Details
### 1. Memory
- The memory module serves as an auxiliary component to enhance the retriever’s ability to identify better matches between queries and relevant parts of the database. It takes the original query and the database as inputs and produces staging answers — intermediate outputs like clues, surrogate queries, or key points — which the retriever uses instead of the original query.
- Long-term memory is constructed by running a long-context model, such as Qwen2-7B-Instruct or Mistral-7B-Instruct-v0.2, over the entire database. This process generates a compressed representation of the database through an attention mechanism.
- The compressed representation is stored as key-value pairs, facilitating efficient and accurate retrieval.
- Released memory models include memorag-qwen2-7b-inst and memorag-mistral-7b-inst, derived from Qwen2-7B-Instruct and Mistral-7B-Instruct-v0.2, respectively.

### 2. Retriever
- The retriever is a standard retrieval model, adapted to take processed queries (created by the memory module as staging answers) instead of the original query.
- It outputs the retrieved **context**, which serves as the basis for generating the final answer.


### 3. Generator
- The generator produces the final response by combining the retriever’s output (retrieved context) with the original query.
- MemoRAG ensures compatibility and consistency by using the memory module’s underlying model as the default generator.

## Benefits of the Approach
1. **Extended Scope of Queries:** MemoRAG's preprocessing capabilities enable it to handle complex and long-context tasks that conventional RAG methods struggle with.

2. **Improved Accuracy:** By simplifying and adjusting queries before retrieval, MemoRAG enhances performance over standard RAG methods.

3. **Flexibility:** Adapts to diverse tasks, datasets, and retrieval scenarios.

4. **Robustness:** Improved performance remains consistent across various generators, datasets, and query types.

5. **Efficiency**: The use of key-value compression reduces computational overhead.

## Conclusion
The memory module in MemoRAG significantly enhances comprehension of both the queries and the database, enabling more effective retrieval. Its ability to preprocess queries, generate staging answers, and leverage long-context memory models ensures high-quality responses, making MemoRAG a significant step forward in the evolution of retrieval-augmented generation.


<div style="text-align: center;">

<img src="../images/memo_rag.svg" alt="MemoRAG" style="width:100%; height:auto;">
</div>

## Implementation

### Imports

In [1]:
import os
from weakref import finalize

from appdirs import system
from dotenv import load_dotenv
# from typing import List
from langchain_community.embeddings import OpenAIEmbeddings
from langchain_experimental.graph_transformers.llm import system_prompt
# from langchain_community.vectorstores import FAISS
# from langchain_community.document_loaders import PyPDFLoader
from openai import OpenAI
from helper_functions import *
from semantic_text_splitter import TextSplitter
from itertools import chain



### OpenAI Setup

In [2]:
load_dotenv()
client = OpenAI(api_key=os.getenv('OPENAI_API_KEY'))

### Memory Module Classes

In [3]:
class MemoryStore:
    """The MemoryStore class is a realization of the Memory Module discussed in the paper.
    Its 'memorize' method is used to create a prompt cache mimicking a key-value compression of the original text (database).
    This cache can then be used by the 'create_retrieval_prompt' method for creating the processed query to be used later for retrieval"""

    def __init__(self):
        self.embeddings = OpenAIEmbeddings()
        self.store = None
        self.processed_count = 0

    def memorize(self,
                 document: str,
                 memory_prompt: str,
                 chunk_size: int = 4096):

        # self.reset()
        # batch_size = 10

        system_prompt = """\
            You are a highly capable assistant with expertise in reading comprehension
            and information extraction. You have been given an article, and you will be asked to perform tasks or answer questions based on that text.

            Carefully study the content provided. Use only the information you find within
            the text or your general knowledge (when relevant and consistent with the text).
            Do not fabricate details. If you do not have enough information to answer a question or complete a task, acknowledge this clearly.

            Communicate your answers or completed tasks in a concise, straightforward manner.
            """

        print(f"Forming prompt cache for the full context")
        _ = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": memory_prompt.format(document=document)}
            ]
        )

    def create_retrieval_query(self, query: str, context: str) -> str:
        """Generate staging answers y = Θ_mem(q, D | θ_mem)
        This should provide rough clues/outline to guide context retrieval"""

        memorag_span_prompt = """
            You are given a question related to the article. To answer it effectively, you need to recall specific details from the article. Your task is to identify and extract one or more specific clue texts from the article that are relevant to the question.

            ### Question: {question}
            ### Instructions:
            1. You have a general understanding of the article. Your task is to generate one or more specific clues that will help in searching for supporting evidence within the article.
            2. The clues are in the form of text spans that will assist in answering the question.
            3. Only output the clues. If there are multiple clues, separate them with a newline.
            """
        memorag_sur_prompt = """
            You are given a question related to the article. To answer it effectively, you need to recall specific details from the article. Your task is to generate precise clue questions that can help locate the necessary information.

            ### Question: {question}
            ### Instructions:
            1. You have a general understanding of the article. Your task is to generate one or more specific clues that will help in searching for supporting evidence within the article.
            2. The clues are in the form of precise surrogate questions that clarify the original question.
            3. Only output the clues. If there are multiple clues, separate them with a newline.
            """
        system_prompt = "" ## need to define this

        text_spans = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": memory_prompt.format(document=context) + memorag_span_prompt.format(question=query)}
            ]
        )
        surrogate_queries = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": memory_prompt.format(document=context) + memorag_sur_prompt.format(question=query)}
            ],
            temperature=0.5  # Allow for some creativity in generating clues
        )

        retrieval_query = text_spans.split("\n") + surrogate_queries.split("\n")
        retrieval_query = [q for q in retrieval_query if len(q.split()) > 3]
        retrieval_query.append(query)

        return retrieval_query


    # def __call__(
    #     self,
    #     query: str = None,
    #     context: str = None,
    #     task_type: str = "memorag",
    #     memory_prompt: str = None,
    #     prompt_template: str = None,
    #     max_new_tokens: int = 256
    # ):
    #
    #     ## OR
    #
    #     topk_scores, topk_indices = retriever.search(queries=retrieval_query)
    #     topk_indices = list(chain(*[topk_index.tolist() for topk_index in topk_indices]))
    #     topk_indices = sorted(set([x for x in topk_indices if x > -1]))
    #     retrieval_results = [retrieval_corpus[i].strip() for i in topk_indices]
    #     knowledge = "\n\n".join(retrieval_results)
    #     # final_answer = _generate_response("qa_gen", query, knowledge, prompt_template, max_new_tokens)

    def save_store(self, path: str):
        if self.store:
            self.store.save_local(path)

    def load_store(self, path: str):
        self.store = FAISS.load_local(path, self.embeddings, allow_dangerous_deserialization=True)

### Retrieval Function

In [4]:
def retrieve_context(retrieval_query: str, vectorstore) -> List[str]:
    """Retrieve relevant context using the staging answer and the database vectorstore.
    Implements c = Γ(y, D | γ) from the paper"""

    results = vectorstore.similarity_search(retrieval_query, k=5)
    contexts = [doc.page_content for doc in results]

    return contexts

### Generation Function

In [5]:
def generate_answer(query: str, contexts: List[str]) -> str:
    """Generate final answer y = Θ(q, c | θ)"""
    prompt = f"""Based on the provided context, answer the query.

Query: {query}

Retrieved Information:
{' '.join(contexts)}

Provide a clear and concise answer focusing only on the retrieved information.
"""

    response = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[
            {"role": "system", "content": "You are a knowledgeable assistant. Provide clear, concise answers."},
            {"role": "user", "content": prompt}
        ],
        max_tokens=500
    )

    return response.choices[0].message.content

### Query Processing Function

In [6]:
def process_query(query: str, memory_store, vectorstore):
    print("\nProcessing Query:", query)
    print("=" * 50)

    # y = Θ_mem(q, D | θ_mem)
    retrieval_query = memory_store.create_retrieval_query(query)
    print(f"Staging Answer:\n{retrieval_query}")

    # c = Γ(y, D | γ)
    contexts = retrieve_context(retrieval_query, vectorstore)
    print(f"Retrieved Context Example: {contexts[0]}")

    # y = Θ(q, c | θ)
    final_answer = generate_answer(query, contexts)
    print(f"\nFinal Answer: {final_answer}")

    return contexts, final_answer

In [None]:
memory_prompt = """You are provided with a long article. Read the article carefully. After reading, you will be asked to perform specific tasks based on the content of the article.

                    Now, the article begins:
                    - **Article Content:** {document}

                    The article ends here.

                    Next, follow the instructions provided to complete the tasks.
                """

### Initialize Components

In [7]:
# Initialize memory store
memory_store = MemoryStore()

# Load and process document
path = "../data/Understanding_Climate_Change.pdf"
loader = PyPDFLoader(path)
documents = loader.load()
document_text = '\n'.join([doc.page_content for doc in documents])
memory_store.memorize(memory_prompt, document_text)
## pipe = mem_store.call(query ...)
chunks_vector_store = encode_pdf(path, chunk_size=1000, chunk_overlap=200)

# retrieval_chunk_size = 2048
#         text_splitter_for_retrieval = TextSplitter.from_tiktoken_model(
#             "gpt-4o-mini", retrieval_chunk_size)
#         retrieval_corpus = text_splitter_for_retrieval.chunks(document)

Processing chunk 1...
Processed 1 chunks so far


### Usage Examples

In [8]:
query_1 = "What are the impacts of climate change on biodiversity?"
query_2 = "Please summarize the climate change article"
query_3 = "Describe the social and economic influence of climate change."

for query in [query_1, query_2, query_3]:
    process_query(query, memory_store, chunks_vector_store)


Processing Query: What are the impacts of climate change on biodiversity?
Staging Answer:
I. Introduction
    A. Definition of climate change
    B. Overview of its causes and impacts

II. Effects of Climate Change
    A. Rising temperatures
    B. Heatwaves
    C. Changing seasons
    D. Melting ice and rising sea levels
    E. Glacial retreat
    F. Coastal erosion
    G. Extreme weather events

III. Impact on Biodiversity
    A. Deforestation in tropical rainforests
    B. Global carbon cycles
    C. Impacts on biodiversity

IV. Modern Scientific Observations
    A. Rapid increases in global temperatures
    B. Rising sea levels
    C. Extreme weather events
    D. Human-induced contributions to climate change

V. Agriculture's Contribution to Climate Change
    A. Livestock emissions
    B. Rice cultivation
    C. Use of synthetic fertilizers
    D. Release of greenhouse gases

VI. Conclusion
    A. Summary of key points regarding the impacts of climate change on biodiversity
Retr

## Comparison