In [None]:
!pip install langchain chromadb sentence-transformers langchain_groq numpy scikit-learn langchain_chroma langchain_huggingface  langchain-community

In [1]:
from google.colab import userdata
import os
GROQ_API_KEY = userdata.get('GROQ_API_KEY')
os.environ['GROQ_API_KEY'] = GROQ_API_KEY

In [2]:
# Initialize multi-tier cache dictionaries
embedding_cache = {}    # Tier 1: Embedding cache
search_cache = {}       # Tier 2: Search results cache
response_cache = {}     # Tier 3: Generated response cache
cache_metadata = {}     # TTL and hit count tracking

In [3]:
from langchain_huggingface import HuggingFaceEmbeddings
embedding_model = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-MiniLM-L6-v2"
)

In [4]:
from langchain_groq import ChatGroq
llm = ChatGroq(model="llama-3.3-70b-versatile")

In [5]:
import time
import hashlib

def get_cache_key(text):
    """Generate consistent cache key from text"""
    return hashlib.md5(text.encode()).hexdigest()

def is_cache_valid(cache_key, ttl_seconds=3600):
    """Check if cache entry is still valid based on TTL"""
    if cache_key not in cache_metadata:
        return False
    timestamp = cache_metadata[cache_key]['timestamp']
    return (time.time() - timestamp) < ttl_seconds


def update_cache_metadata(cache_key):
    """Update cache metadata with timestamp and hit count"""
    if cache_key not in cache_metadata:
        cache_metadata[cache_key] = {'hits': 0, 'timestamp': time.time()}
    cache_metadata[cache_key]['hits'] += 1
    cache_metadata[cache_key]['timestamp'] = time.time()

In [6]:
def embed_with_cache(text):
    """Embed text with caching layer"""
    cache_key = get_cache_key(text)

    if cache_key in embedding_cache and is_cache_valid(cache_key):
        update_cache_metadata(cache_key)
        return embedding_cache[cache_key]

    embedding = embedding_model.embed_query(text)
    embedding_cache[cache_key] = embedding
    update_cache_metadata(cache_key)
    return embedding

In [7]:
def retrieve_with_cache(query, retriever):
    """Retrieve documents with search cache"""
    cache_key = get_cache_key(query)

    if cache_key in search_cache and is_cache_valid(cache_key):
        update_cache_metadata(cache_key)
        return search_cache[cache_key]

    # Perform new search
    results = retriever.get_relevant_documents(query)
    search_cache[cache_key] = results
    update_cache_metadata(cache_key)
    return results

In [8]:
from langchain.chains import RetrievalQA

def generate_with_cache(query, retriever):
    """Generate response with full pipeline caching"""
    cache_key = get_cache_key(query)

    if cache_key in response_cache and is_cache_valid(cache_key):
        update_cache_metadata(cache_key)
        return response_cache[cache_key]

    # Generate new response
    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever
    )
    response = qa_chain.invoke(query)
    response_cache[cache_key] = response
    update_cache_metadata(cache_key)
    return response

In [9]:
from langchain.document_loaders import WebBaseLoader
from langchain.vectorstores import Chroma

url = "https://en.wikipedia.org/wiki/Retrieval-augmented_generation"
loader = WebBaseLoader(url)
docs = loader.load()

texts = []
for doc in docs:
    content = doc.page_content
    chunk_size = 1000
    chunks = [content[i:i+chunk_size] for i in range(0, len(content), chunk_size)]
    texts.extend(chunks)

print(f"Loaded {len(texts)} text chunks")

vectorstore = Chroma.from_texts(
    texts=texts,
    embedding=embedding_model,
    collection_name="rag_cache_demo"
)

retriever = vectorstore.as_retriever(search_kwargs={"k": 3})



Loaded 27 text chunks


In [11]:
queries = [
    "What is Retrieval-Augmented Generation?",
    "How does RAG work?",
    "What are the benefits of RAG?",
    "What is Retrieval-Augmented Generation?"  # Duplicate to test cache
]

print("=== RAG with Caching Demo ===\n")

for i, query in enumerate(queries, 1):
    print(f"Query {i}: {query}")

    # Time the retrieval
    start_time = time.time()
    docs = retrieve_with_cache(query, retriever)
    retrieval_time = time.time() - start_time

    # Time the generation
    start_time = time.time()
    response = generate_with_cache(query, retriever)
    generation_time = time.time() - start_time

    print(f"Retrieved {len(docs)} documents in {retrieval_time:.3f}s")
    print(f"Generated response in {generation_time:.3f}s")
    print(f"Response: {response['result'][:200]}...\n") # Access the 'result' key before slicing

=== RAG with Caching Demo ===

Query 1: What is Retrieval-Augmented Generation?
Retrieved 3 documents in 0.000s
Generated response in 0.000s
Response: Retrieval-Augmented Generation (RAG) is a technology used in artificial intelligence, particularly in large language models. It is designed to improve the accuracy and reliability of generated text by...

Query 2: How does RAG work?
Retrieved 3 documents in 0.023s
Generated response in 0.961s
Response: RAG (not explicitly defined in the text, but presumably "Retrieval-Augmented Generation") works by incorporating information retrieval before generating responses. Here's a step-by-step explanation:

...

Query 3: What are the benefits of RAG?
Retrieved 3 documents in 0.020s
Generated response in 0.714s
Response: The benefits of RAG include:

1. Improved accuracy of large language models (LLMs) by incorporating information retrieval before generating responses.
2. Reduced need for frequent model retraining, wh...

Query 4: What is Retrieva

In [13]:
def clear_all_caches():
    """Clear all cache tiers"""
    embedding_cache.clear()
    search_cache.clear()
    response_cache.clear()
    cache_metadata.clear()
    print("All caches cleared")

def get_cache_stats():
    """Display cache statistics"""
    stats = {
        'embedding_cache_size': len(embedding_cache),
        'search_cache_size': len(search_cache),
        'response_cache_size': len(response_cache),
        'total_hits': sum(meta['hits'] for meta in cache_metadata.values())
    }
    return stats

def expire_old_cache(max_age_seconds=3600):
    """Remove expired cache entries"""
    current_time = time.time()
    expired_keys = []

    for key, meta in cache_metadata.items():
        if (current_time - meta['timestamp']) > max_age_seconds:
            expired_keys.append(key)

    for key in expired_keys:
        embedding_cache.pop(key, None)
        search_cache.pop(key, None)
        response_cache.pop(key, None)
        cache_metadata.pop(key, None)

    print(f"Expired {len(expired_keys)} cache entries")

In [14]:
print("=== Cache Statistics ===")
stats = get_cache_stats()
for key, value in stats.items():
    print(f"{key}: {value}")

print("\n=== Cache Hit Details ===")
for key, meta in cache_metadata.items():
    print(f"Key: {key[:8]}... | Hits: {meta['hits']} | Age: {time.time() - meta['timestamp']:.1f}s")

=== Cache Statistics ===
embedding_cache_size: 0
search_cache_size: 3
response_cache_size: 3
total_hits: 10

=== Cache Hit Details ===
Key: c2842ace... | Hits: 6 | Age: 50.6s
Key: 28f28084... | Hits: 2 | Age: 51.3s
Key: 083ab3f7... | Hits: 2 | Age: 50.6s


In [16]:
# Test cache performance with repeated queries
print("=== Cache Performance Test ===")

test_query = "Explain the main components of RAG"

# First run (no cache)
start = time.time()
response1 = generate_with_cache(test_query, retriever)
first_run_time = time.time() - start

# Second run (with cache)
start = time.time()
response2 = generate_with_cache(test_query, retriever)
cached_run_time = time.time() - start

print(f"First run: {first_run_time:.3f}s")
print(f"Cached run: {cached_run_time:.3f}s")
print(f"Speed improvement: {first_run_time/cached_run_time:.1f}x faster")
print(f"Same response: {response1 == response2}")

# Cache management demo
print(f"\nBefore cleanup: {get_cache_stats()}")
expire_old_cache(max_age_seconds=1)  # Expire very recent for demo
print(f"After cleanup: {get_cache_stats()}")

=== Cache Performance Test ===
First run: 0.000s
Cached run: 0.000s
Speed improvement: 2.4x faster
Same response: True

Before cleanup: {'embedding_cache_size': 0, 'search_cache_size': 0, 'response_cache_size': 1, 'total_hits': 4}
Expired 0 cache entries
After cleanup: {'embedding_cache_size': 0, 'search_cache_size': 0, 'response_cache_size': 1, 'total_hits': 4}
