In [1]:
# step 1: Install necessary libraries
%pip install -q chromadb
%pip install -q sentence-transformers


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.0.1[0m[39;49m -> [0m[32;49m25.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython3.11 -m pip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.0.1[0m[39;49m -> [0m[32;49m25.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython3.11 -m pip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.


In [2]:
# Step 2: Import necessary libraries
import re
import uuid
import chromadb
from sentence_transformers import SentenceTransformer, CrossEncoder

COLLECTION_NAME = "semantic_cache"
client = chromadb.Client()

# Clean start when running all cells
try:
    client.delete_collection(COLLECTION_NAME)
    print("✓ Cleared existing semantic_cache collection")
except Exception:
    print("✓ No existing collection to clear")

collection = client.get_or_create_collection(
    name=COLLECTION_NAME,
    metadata={"hnsw:space": "cosine"}
)

def _new_id():
    return str(uuid.uuid4())


  from .autonotebook import tqdm as notebook_tqdm


✓ No existing collection to clear


In [3]:
# Step 3: Cache Entry Structure
class SemanticCache:
    def __init__(self, embedder_model='all-MiniLM-L6-v2', collection_ref=None):
        """Initialize embedding model and reuse the shared ChromaDB collection."""
        self.embedder = SentenceTransformer(embedder_model)
        self.collection = collection_ref if collection_ref is not None else collection

    def add(self, query, soln_path):
        """Add query-solution path pair to cache"""
        embedding = self.embedder.encode(query).tolist()
        self.collection.add(
            embeddings=[embedding],
            documents=[query],
            metadatas=[{'query': query, 'soln_path': soln_path}],
            ids=[_new_id()]
        )

    def search(self, query, threshold=0.75):
        """Search for similar cached query"""
        embedding = self.embedder.encode(query).tolist()
        results = self.collection.query(query_embeddings=[embedding], n_results=1)
        
        if results.get('distances') and results['distances'][0]:
            score = 1 - results['distances'][0][0]
            if score >= threshold:
                metadata = results['metadatas'][0][0]
                return {
                    'soln_path': metadata['soln_path'],
                    'score': score,
                    'cached_query': metadata.get('query')
                }
        return None

# Test it
cache = SemanticCache()
cache.add("What is the capital of France?", "lookup_fact('France', 'capital')")
cache.add("How do I reset my password?", "get_help_article('password_reset')")
cache.add("What are your business hours?", "get_business_info('hours')")

test_queries = [
    "What's the capital of France?",
    "Password reset instructions",
    "When are you open?",
    "What's the weather today?"
]

for q in test_queries:
    # Get the raw similarity score even if below threshold
    embedding = cache.embedder.encode(q).tolist()
    results = cache.collection.query(query_embeddings=[embedding], n_results=1)
    
    if results.get('distances') and results['distances'][0]:
        score = 1 - results['distances'][0][0]
        
        # Now check against threshold
        result = cache.search(q)
        if result:
            print(f"✅ '{q}' → '{result['soln_path']}' (score: {result['score']:.2f})")
        else:
            print(f"❌ '{q}' → Match below threshold (score: {score:.2f})")
    else:
        print(f"❌ '{q}' → No match (no cached queries)")

✅ 'What's the capital of France?' → 'lookup_fact('France', 'capital')' (score: 0.99)
✅ 'Password reset instructions' → 'get_help_article('password_reset')' (score: 0.80)
❌ 'When are you open?' → Match below threshold (score: 0.51)
❌ 'What's the weather today?' → Match below threshold (score: 0.27)


In [4]:
# Step 4: Semantic Boundaries
class MaskedSemanticCache(SemanticCache):
    def mask_entities(self, text):
        """Replace specific entities with placeholders"""
        text = re.sub(r'\$[\d,]+', '[AMOUNT]', text)           # Money amounts
        text = re.sub(r'\b[A-Z]{2,5}\b', '[TICKER]', text)     # Tickers
        text = re.sub(r'\b20\d{2}\b', '[YEAR]', text)          # Years
        text = re.sub(r'\d+(\.\d+)?%', '[PERCENT]', text)      # Percentages
        text = re.sub(r'\S+@\S+', '[EMAIL]', text)             # Emails
        return text

    def add(self, query, response):
        """Add with entity masking"""
        masked_query = self.mask_entities(query)
        embedding = self.embedder.encode(masked_query).tolist()
        self.collection.add(
            embeddings=[embedding],
            documents=[masked_query],
            metadatas=[{
                'original_query': query,
                'masked_query': masked_query,
                'response': response
            }],
            ids=[_new_id()]
        )

    def search(self, query, threshold=0.75):
        """Search using masked query"""
        masked_query = self.mask_entities(query)
        embedding = self.embedder.encode(masked_query).tolist()
        results = self.collection.query(
            query_embeddings=[embedding],
            n_results=1
        )
        if results.get('distances') and results['distances'][0]:
            score = 1 - results['distances'][0][0]
            if score >= threshold:
                metadata = results['metadatas'][0][0]
                return {
                    'response': metadata['response'],
                    'score': score,
                    'cached_query': metadata.get('original_query'),
                    'masked_query': metadata.get('masked_query')
                }
        return None

# Test entity masking
cache = MaskedSemanticCache()
cache.add("What was AAPL stock price in 2023?", "Use stock_price_tool")
cache.add("My budget is $5000", "Use budget_tool")

print("Testing entity masking:")
result = cache.search("What was TSLA stock price in 2024?")
if result:
    print(f"✅ Matched despite different ticker and year!")
    print(f"   Original: {result['cached_query']}")
    print(f"   Masked: {result['masked_query']}")
    print(f"   Response: {result['response']}")


Testing entity masking:
✅ Matched despite different ticker and year!
   Original: What was AAPL stock price in 2023?
   Masked: What was [TICKER] stock price in [YEAR]?
   Response: Use stock_price_tool


In [5]:
# Step 5: Cross-Encoder Verification
class CrossEncoderSemanticCache(MaskedSemanticCache):
    def __init__(self, embedder_model='all-MiniLM-L6-v2', collection_ref=None):
        super().__init__(embedder_model=embedder_model, collection_ref=collection_ref)
        self.verifier = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')

    def search_with_verification(self, query, vector_threshold=0.7, verify_threshold=3.5):
        """Two-stage search: vector similarity + verification"""
        masked_query = self.mask_entities(query)
        embedding = self.embedder.encode(masked_query).tolist()
        results = self.collection.query(query_embeddings=[embedding], n_results=3)
        if not (results.get('distances') and results['distances'][0]):
            return None

        best_match, best_score = None, 0.0
        for i, distance in enumerate(results['distances'][0]):
            vector_score = 1 - distance
            if vector_score < vector_threshold:
                continue

            metadata = results['metadatas'][0][i]
            verify_score = float(self.verifier.predict(
                [[query, metadata.get('original_query', metadata.get('query', ''))]]
            )[0])
            if verify_score > best_score and verify_score >= verify_threshold:
                best_score = verify_score
                best_match = {
                    'response': metadata['response'],
                    'vector_score': vector_score,
                    'verify_score': verify_score,
                    'cached_query': metadata.get('original_query', metadata.get('query'))
                }
        return best_match

# Test verification
cache = CrossEncoderSemanticCache()
cache.add("What is my checking account balance?", "checking_balance_tool")
cache.add("What is my savings account balance?", "savings_balance_tool")
cache.add("What is my credit card balance?", "credit_balance_tool")

queries = [
    "What's my checking balance",
    "What's my savings account balance?",
    "Credit card balance"
]

print("Testing with cross-encoder verification:")
for q in queries:
    result = cache.search_with_verification(q)
    if result:
        print(f"✅ '{q}' → '{result['response']}'")
        print(f"   Vector: {result['vector_score']:.2f}, Verified: {result['verify_score']:.2f}")
    else:
        print(f"❌ '{q}' → No verified match found")


Testing with cross-encoder verification:
✅ 'What's my checking balance' → 'checking_balance_tool'
   Vector: 0.89, Verified: 4.82
✅ 'What's my savings account balance?' → 'savings_balance_tool'
   Vector: 0.99, Verified: 6.35
✅ 'Credit card balance' → 'credit_balance_tool'
   Vector: 0.88, Verified: 4.01


In [6]:
# Step 6: Adaptive Thresholds
class AdaptiveSemanticCache(CrossEncoderSemanticCache):
    def __init__(self, model_name='all-MiniLM-L6-v2', collection_ref=None):
        super().__init__(embedder_model=model_name, collection_ref=collection_ref)
        self.model_name = model_name
        self.model_thresholds = {
            'all-MiniLM-L6-v2': 0.75,
            'all-mpnet-base-v2': 0.80,
            'all-distilroberta-v1': 0.70
        }

    def get_threshold(self, match_type='normal'):
        base = self.model_thresholds.get(self.model_name, 0.75)
        adjustments = {
            'exact': base + 0.15,
            'normal': base,
            'fuzzy': base - 0.10,
            'exploratory': base - 0.20
        }
        return adjustments.get(match_type, base)

    def adaptive_search(self, query, match_type='normal'):
        threshold = self.get_threshold(match_type)
        verify_threshold = 0.9 if match_type == 'exact' else 0.85
        return self.search_with_verification(
            query,
            vector_threshold=threshold,
            verify_threshold=verify_threshold
        )

# Test adaptive thresholds with fallback demonstration
cache = AdaptiveSemanticCache()
cache.add("What is the annual revenue?", "revenue_tool")
cache.add("Show me customer demographics", "demographics_tool")

test_cases = [
    ("yearly revenue", "Strong match"),
    ("customer demographic", "Weaker match")
]

for query, description in test_cases:
    print(f"\nQuery: '{query}' ({description})")
    for match_type in ['exact', 'normal', 'fuzzy']:
        result = cache.adaptive_search(query, match_type)
        threshold = cache.get_threshold(match_type)
        if result:
            print(f"  {match_type.upper()} (threshold {threshold:.2f}): ✅ Found match")
        else:
            print(f"  {match_type.upper()} (threshold {threshold:.2f}): ❌ No match")



Query: 'yearly revenue' (Strong match)
  EXACT (threshold 0.90): ✅ Found match
  NORMAL (threshold 0.75): ✅ Found match
  FUZZY (threshold 0.65): ✅ Found match

Query: 'customer demographic' (Weaker match)
  EXACT (threshold 0.90): ❌ No match
  NORMAL (threshold 0.75): ✅ Found match
  FUZZY (threshold 0.65): ✅ Found match


In [9]:
# Step 7: Auto-Population and Statistics
class SmartSemanticCache(AdaptiveSemanticCache):
    def __init__(self, model_name='all-MiniLM-L6-v2', collection_ref=None):
        super().__init__(model_name=model_name, collection_ref=collection_ref)
        self.stats = {'hits': 0, 'misses': 0, 'auto_added': 0}

    def query_with_fallback(self, query, fallback_fn=None, match_type='normal'):
        """Try cache first, fallback to function if miss"""
        result = self.adaptive_search(query, match_type)
        if result:
            self.stats['hits'] += 1
            return result['response'], 'cache'

        self.stats['misses'] += 1
        if fallback_fn:
            response = fallback_fn(query)
            self.add(query, response)
            self.stats['auto_added'] += 1
            return response, 'computed'

        return None, 'miss'

    def print_stats(self):
        total = self.stats['hits'] + self.stats['misses']
        if total > 0:
            hit_rate = self.stats['hits'] / total * 100
            print("Cache Stats:")
            print(f"  Hits: {self.stats['hits']} ({hit_rate:.1f}%)")
            print(f"  Misses: {self.stats['misses']}")
            print(f"  Auto-added: {self.stats['auto_added']}")

# Mock agent function
def mock_agent(query):
    """Simulate an expensive agent call"""
    if 'balance' in query.lower():
        return 'balance_tool'
    elif 'transaction' in query.lower():
        return 'transaction_tool'
    else:
        return 'general_tool'

# Test with fallback
cache = SmartSemanticCache()
queries = [
    "What is my account balance?",
    "Show me my account balance",
    "Account balance please",
    "Recent transactions",
    "Show my transactions",
]

print("Testing with auto-population:")
for q in queries:
    response, source = cache.query_with_fallback(q, mock_agent)
    print(f"'{q}' → {response} ({source})")

print()
cache.print_stats()


Testing with auto-population:
'What is my account balance?' → checking_balance_tool (cache)
'Show me my account balance' → balance_tool (cache)
'Account balance please' → balance_tool (cache)
'Recent transactions' → transaction_tool (cache)
'Show my transactions' → transaction_tool (cache)

Cache Stats:
  Hits: 5 (100.0%)
  Misses: 0
  Auto-added: 0


In [10]:
# Step 8 (Bonus step!): Fallback to Planning Agent
class CacheWithFallback:
    def __init__(self, cache_impl=None):
        # Reuse the verified cache (or inject a smarter one)
        self.cache = cache_impl if cache_impl is not None else CrossEncoderSemanticCache()
        self.stats = {'hits': 0, 'misses': 0}

    def process_query(self, query):
        """Try cache first, then fallback to agent"""
        result = self.cache.search_with_verification(query)
        if result:
            self.stats['hits'] += 1
            print(f"✓ Cache hit: {result['response']}")
            return {**result, 'from_cache': True}

        self.stats['misses'] += 1
        print("✗ Cache miss - calling planning agent...")

        solution = self.mock_planning_agent(query)

        if solution:
            self.cache.add(query, solution)
            print(f"→ Added to cache: {solution}")

        return {'solution_path': solution, 'from_cache': False}

    def mock_planning_agent(self, query):
        """Simulate a planning agent (replace with real agent)"""
        q = query.lower()
        if 'balance' in q:
            return 'balance_tool'
        elif 'transaction' in q:
            return 'transaction_tool'
        elif 'stock' in q:
            return 'stock_tool'
        else:
            return 'general_tool'

    def print_stats(self):
        total = self.stats['hits'] + self.stats['misses']
        hit_rate = (self.stats['hits'] / total) if total > 0 else 0
        print(f"\nStats: {self.stats['hits']} hits, {self.stats['misses']} misses")
        print(f"Hit rate: {hit_rate:.1%}")

# Test the full system
system = CacheWithFallback()

queries = [
    "What is my account balance?",
    "Show me my account balance",
    "What are my recent transactions?",
    "Display my transactions",
    "What is my balance?",
]

for q in queries:
    print(f"\nQuery: {q}")
    system.process_query(q)

system.print_stats()



Query: What is my account balance?
✓ Cache hit: checking_balance_tool

Query: Show me my account balance
✓ Cache hit: balance_tool

Query: What are my recent transactions?
✓ Cache hit: transaction_tool

Query: Display my transactions
✓ Cache hit: transaction_tool

Query: What is my balance?
✓ Cache hit: balance_tool

Stats: 5 hits, 0 misses
Hit rate: 100.0%
