# MnemosyneRAG: A Lightweight RAG Implementation with Advanced Caching

A fast and efficient RAG system built with ChromaDB, featuring multi-level caching and asynchronous operations.

Features:
- Two-level caching (memory for embeddings, ChromaDB for responses)
- Asynchronous background caching
- Thread-safe operations
- Response times < 50ms on cache hits

## Setup and Dependencies
First, let's import all required libraries and set up our environment.

````
pip install openai chromadb pydantic python-dotenv tenacity
````

Create a `.env` file with the content below, replacing the placeholder values with your own settings.

```
OPENAI_API_KEY=your-openai-api-key
CLIENT_ID=1
CLIENT_SECRET=2
ANONYMIZED_TELEMETRY=False
```

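You need an existing ChromaDB database in `./chroma_db` containing a `kb_knowledge` collection that holds the knowledge-base data. A second collection, `query_cache`, serves as a temporary response cache; it is created automatically if it does not exist.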

In [None]:
import os
from typing import List, Dict, Optional
from pydantic import BaseModel, Field
from openai import OpenAI
import chromadb
import threading
from tenacity import retry, wait_random_exponential, stop_after_attempt
from functools import lru_cache
import json
import uuid
from dotenv import load_dotenv
import hashlib
from concurrent.futures import ThreadPoolExecutor
import asyncio

# Load environment variables: read key=value pairs from a local .env file
# (see the setup section) into the process environment so the OpenAI client
# created below can pick up OPENAI_API_KEY.
load_dotenv()

## Configuration
Define our configuration settings

In [None]:
class Config:
    """Static settings for the RAG pipeline, read as class attributes."""

    # --- ChromaDB storage ---------------------------------------------------
    CHROMA_PATH: str = "./chroma_db"            # PersistentClient directory
    COLLECTION_NAME: str = "kb_knowledge"       # knowledge-base collection (must exist)
    CACHE_COLLECTION_NAME: str = "query_cache"  # response cache collection

    # --- Caching ------------------------------------------------------------
    CACHE_SIZE: int = 1000                      # lru_cache size for get_cached_embedding
    SIMILARITY_THRESHOLD: float = 0.9           # not referenced by the visible code

    # --- Retrieval ----------------------------------------------------------
    DEFAULT_N_RESULTS: int = 3                  # not referenced by the visible code

## Initialize Services
Set up our OpenAI and ChromaDB clients

In [None]:
def init_services():
    """Create the OpenAI client and open the ChromaDB collections used by the pipeline.

    Returns:
        A tuple ``(openai_client, kb_collection, cache_collection)``. The
        knowledge-base collection is opened with ``get_collection`` and so must
        already exist; the cache collection is created if it is missing.
    """
    # OpenAI() takes its API key from the environment populated by load_dotenv().
    llm = OpenAI()

    store = chromadb.PersistentClient(path=Config.CHROMA_PATH)
    knowledge = store.get_collection(name=Config.COLLECTION_NAME)
    response_cache = store.get_or_create_collection(name=Config.CACHE_COLLECTION_NAME)

    return llm, knowledge, response_cache

# Module-wide clients shared by every function below.
openai_client, kb_collection, cache_collection = init_services()

## Data Models
Define our Pydantic models for type safety

In [None]:
class UserQuery(BaseModel):
    """A user's question, required and limited to 1-1000 characters."""

    question: str = Field(default=..., min_length=1, max_length=1000)
    
class CacheResponse(BaseModel):
    """Answer payload returned by process_query and stored in the response cache.

    The nested lists mirror ChromaDB query results, which hold one inner list
    per query embedding.
    """
    # Generated LLM answer text.
    response: str
    # Retrieved knowledge-base documents (kb_results["documents"]).
    documents: List[List[str]]
    # Metadata dicts matching `documents` (kb_results["metadatas"]).
    metadata: List[List[Dict]]
    # Reference URLs built from each document's "number" metadata field.
    links: List[str]

## Embedding Functions
Functions for handling embeddings with caching

In [None]:
@lru_cache(maxsize=Config.CACHE_SIZE)
def get_cached_embedding(text: str) -> tuple:
    """Return the embedding of ``text`` as an immutable tuple, memoized in-process.

    Up to ``Config.CACHE_SIZE`` distinct texts are remembered (LRU eviction).
    A tuple is returned so the shared cached value cannot be mutated by callers.
    """
    return tuple(get_embedding(text))

@retry(
    wait=wait_random_exponential(min=1, max=20),
    stop=stop_after_attempt(6),
    # Re-raise the last real error instead of wrapping it in tenacity.RetryError,
    # so the caller's message says why embedding failed.
    reraise=True,
)
def get_embedding(text: str, model: str = "text-embedding-3-small") -> List[float]:
    """Return the embedding vector for ``text`` from the OpenAI embeddings API.

    The call is attempted up to 6 times with randomized exponential back-off
    (between 1 and 20 seconds) before giving up.

    Args:
        text: The text to embed.
        model: Embedding model name. Vectors stored in the response cache are
            only comparable when produced by the same model.

    Returns:
        The embedding as a list of floats.

    Raises:
        Exception: "Embedding generation failed: ..." chained to the
            underlying client error, after the final attempt.
    """
    try:
        return openai_client.embeddings.create(
            input=[text],
            model=model
        ).data[0].embedding
    except Exception as e:
        raise Exception(f"Embedding generation failed: {str(e)}") from e

## Caching Implementation

In [None]:
class EmbeddingCache:
    """Thread-safe, size-bounded in-memory map from cache key to embedding.

    Eviction is least-recently-used: both ``get`` hits and ``set`` mark a key as
    most recent, and the stalest entry is dropped once ``max_size`` entries are
    held. A plain ``dict`` serves as the ordered store (it preserves insertion
    order); popping and re-inserting a key moves it to the most-recent end.
    """

    def __init__(self, max_size=1000):
        self.cache = {}
        self.max_size = max_size
        self._lock = threading.Lock()

    def get(self, key: str) -> Optional[List[float]]:
        """Return the embedding stored under ``key`` (marking it recently used), or None."""
        with self._lock:
            if key not in self.cache:
                return None
            value = self.cache.pop(key)
            self.cache[key] = value  # re-insert at the end = most recently used
            return value

    def set(self, key: str, value: List[float]):
        """Store ``value`` under ``key``, evicting least-recently-used entries if full.

        Overwriting an existing key never evicts a different entry. With
        ``max_size <= 0`` nothing is stored.
        """
        with self._lock:
            if self.max_size <= 0:
                return
            # Remove any previous copy so the key lands at the most-recent end
            # and does not count against the capacity check below.
            self.cache.pop(key, None)
            while len(self.cache) >= self.max_size:
                # The first key in insertion order is the least recently used.
                self.cache.pop(next(iter(self.cache)))
            self.cache[key] = value

# Initialize global cache and thread pool: the embedding cache is shared by the
# cache lookup/store functions, and the pool runs response caching in the
# background.
embedding_cache = EmbeddingCache()
thread_pool = ThreadPoolExecutor(max_workers=4)

## Cache Operations

In [None]:
def cache_response_background(query: str, response: CacheResponse):
    """Store ``response`` in the ChromaDB response cache under ``query``.

    Intended to run off the request path (process_query schedules it on
    ``thread_pool``). Failures are printed, not propagated, so a caching error
    never affects the answer already returned to the user.

    The record id is the MD5 hex digest of the query text (the same key used for
    ``embedding_cache``). The record is written with ``upsert``, so repeated
    cache misses for the identical query overwrite one entry instead of piling
    up duplicates. List fields are JSON-encoded because Chroma metadata values
    must be scalars.
    """
    try:
        cache_key = hashlib.md5(query.encode()).hexdigest()
        embedding = embedding_cache.get(cache_key)
        if embedding is None:
            embedding = get_embedding(query)
            embedding_cache.set(cache_key, embedding)

        cache_collection.upsert(
            ids=[cache_key],
            documents=[query],
            embeddings=[embedding],
            metadatas=[{
                "response": response.response,
                "documents": json.dumps(response.documents),
                "metadata": json.dumps(response.metadata),
                "links": json.dumps(response.links)
            }]
        )
        print(f"Successfully cached response with ID: {cache_key}")
    except Exception as e:
        print(f"Background caching failed: {str(e)}")

async def check_cache(query: str, max_distance: float = 0.1) -> Optional[CacheResponse]:
    """Return a cached answer for a query close enough to ``query``, else None.

    The query is embedded (reusing ``embedding_cache`` when possible), and the
    nearest record in ``cache_collection`` is fetched. It counts as a hit only
    when its distance is at most ``max_distance``. The default of 0.1 was the
    previously hard-coded cutoff; the metric is whatever the collection uses.

    The embedding and ChromaDB calls are blocking, so they run in the default
    executor rather than stalling the event loop. Any failure is printed and
    treated as a cache miss.
    """
    def _lookup() -> Optional[CacheResponse]:
        # Synchronous part of the lookup, executed in a worker thread.
        cache_key = hashlib.md5(query.encode()).hexdigest()
        embedding = embedding_cache.get(cache_key)
        if embedding is None:
            embedding = get_embedding(query)
            embedding_cache.set(cache_key, embedding)

        results = cache_collection.query(
            query_embeddings=[embedding],
            n_results=1
        )

        # Each result field holds one inner list per query embedding. An empty
        # cache yields [[]], which is truthy, so the inner list must be checked.
        if not results["ids"] or not results["ids"][0]:
            return None

        distance = results["distances"][0][0]
        metadata = results["metadatas"][0][0]
        if distance > max_distance or not metadata:
            return None

        return CacheResponse(
            response=metadata["response"],
            documents=json.loads(metadata["documents"]),
            metadata=json.loads(metadata["metadata"]),
            links=json.loads(metadata["links"])
        )

    try:
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, _lookup)
    except Exception as e:
        print(f"Cache check failed: {str(e)}")
        return None

## Main Query Processing

In [None]:
async def process_query(query: str, n_results: int = 1) -> CacheResponse:
    """Answer ``query`` from the response cache, or via retrieval + LLM on a miss.

    On a cache miss, the top ``n_results`` knowledge-base documents become the
    context for a ``gpt-4o-mini`` chat completion. The new answer is returned
    immediately and written to the response cache in the background on
    ``thread_pool``. ``Config.DEFAULT_N_RESULTS`` is not applied here; the
    default of 1 keeps the original behaviour.

    Raises:
        Exception: "Failed to process question: ..." chained to the underlying
            error, including the ValueError raised when the knowledge base
            returns no documents.
    """
    try:
        # Check cache first
        cached_response = await check_cache(query)
        if cached_response is not None:
            print("Cache hit! Returning cached response")
            return cached_response

        print("Cache miss, generating new response")

        loop = asyncio.get_running_loop()
        # ChromaDB and the OpenAI client are synchronous; run them in a worker
        # thread so the event loop is not blocked during the network calls.
        result = await loop.run_in_executor(None, _generate_response, query, n_results)

        # Fire-and-forget: cache_response_background reports its own errors.
        loop.run_in_executor(
            thread_pool,
            cache_response_background,
            query,
            result
        )

        return result

    except Exception as e:
        raise Exception(f"Failed to process question: {str(e)}") from e


def _generate_response(query: str, n_results: int) -> CacheResponse:
    """Retrieve knowledge-base context for ``query`` and generate an LLM answer (blocking)."""
    kb_results = kb_collection.query(
        query_texts=[query],
        n_results=n_results
    )

    if not kb_results["documents"] or not kb_results["documents"][0]:
        raise ValueError("No relevant documents found")

    # Generate context and prepare messages
    context = "\n\n".join(kb_results["documents"][0])
    messages = [
        {
            "role": "system",
            "content": "You are a helpful assistant. If a user's request is unclear, request clarification."
        },
        {
            "role": "user",
            "content": f"Context: {context}\n\nQuestion: {query}"
        }
    ]

    response = openai_client.chat.completions.create(
        model="gpt-4o-mini",
        messages=messages
    )

    # Documents whose metadata lacks a "number" field get no link instead of
    # failing the whole request with a KeyError.
    kb_metadatas = kb_results["metadatas"]
    links = [
        f"https://example.link/{meta['number']}"
        for meta in (kb_metadatas[0] if kb_metadatas else [])
        if meta and "number" in meta
    ]

    return CacheResponse(
        # message.content is Optional in the OpenAI SDK; CacheResponse requires a str.
        response=response.choices[0].message.content or "",
        documents=kb_results["documents"],
        metadata=kb_metadatas,
        links=links
    )

## Example Usage
Test the RAG system with a sample query

In [None]:
async def test_system():
    query = "What is machine learning?"
    try:
        result = await process_query(query)
        print("Response:", result.response)
        print("\nDocuments:", result.documents)
        print("\nLinks:", result.links)
    except Exception as e:
        print(f"Error: {str(e)}")

# Run the test
await test_system()

## Cache Management
Functions to manage the cache

In [None]:
async def clear_cache(cache_id: Optional[str] = None):
    """Delete one response-cache entry by id, or every entry when ``cache_id`` is None.

    Errors are printed rather than raised.
    """
    try:
        # Compare against None explicitly: a falsy id such as "" must not be
        # mistaken for "clear everything".
        if cache_id is not None:
            cache_collection.delete(ids=[cache_id])
            print(f"Cache cleared for ID: {cache_id}")
            return

        # include=[] fetches only the ids instead of loading every stored
        # document and metadata payload just to delete them.
        all_ids = cache_collection.get(include=[])["ids"]
        if not all_ids:
            print("Cache is already empty")
            return

        cache_collection.delete(ids=all_ids)
        print("Complete cache cleared successfully")
    except Exception as e:
        print(f"Cache clearing failed: {str(e)}")

## Additional Example: Multiple Queries
Test the system with multiple queries to demonstrate caching

In [None]:
async def test_multiple_queries():
    queries = [
        "What is machine learning?",
        "Explain artificial intelligence",
        "What is machine learning?"  # Repeated query to test cache
    ]
    
    for query in queries:
        print(f"\nProcessing query: {query}")
        try:
            result = await process_query(query)
            print("Response:", result.response)
        except Exception as e:
            print(f"Error: {str(e)}")
        
# Run multiple queries test
await test_multiple_queries()

## Performance Analysis
Measure response times with and without cache

In [None]:
import time

async def measure_performance():
    query = "What is machine learning?"
    
    # First query (no cache)
    start_time = time.time()
    result1 = await process_query(query)
    first_query_time = time.time() - start_time
    
    # Second query (should hit cache)
    start_time = time.time()
    result2 = await process_query(query)
    second_query_time = time.time() - start_time
    
    print(f"First query (no cache): {first_query_time:.2f} seconds")
    print(f"Second query (cache hit): {second_query_time:.2f} seconds")

# Run performance test
await measure_performance()