In [7]:
import time
import uuid

import chromadb
from sentence_transformers import SentenceTransformer

class VectorStore:
    def __init__(self, collection_name="chat_memory"):
        self.client = chromadb.Client()
        self.collection = self.client.create_collection(name=collection_name, get_or_create=True)
        # Load a small, fast embedding model (approx 80MB)
        self.embedder = SentenceTransformer('all-MiniLM-L6-v2')

    def search(self, query: str, k=10):
        # 1. Convert query to vector
        query_vec = self.embedder.encode(query).tolist()
        # 2. Search ChromaDB
        results = self.collection.query(query_embeddings=[query_vec], n_results=k)
        # 3. Unpack ugly Chroma format into a clean list
        return [
            {'text': results['documents'][0][i], 'id': results['ids'][0][i]} 
            for i in range(len(results['documents'][0]))
        ]

    def add(self, text: str, metadata: dict | None = None):
        if metadata is None:
            metadata = {}

        metadata.setdefault("created_at", time.time())

        vector = self.embedder.encode(text).tolist()

        self.collection.add(
            ids=[str(uuid.uuid4())],
            documents=[text],
            embeddings=[vector],
            metadatas=[metadata]
        )


In [8]:
from rank_bm25 import BM25Okapi

class KeywordIndex:
    """Sparse keyword memory: a BM25 index over an in-memory list of texts."""

    def __init__(self):
        # Raw documents, in insertion order.
        self.corpus = []
        # BM25 model over `corpus`; stays None until the first `add`.
        self.bm25 = None

    def add(self, text: str):
        """Append `text` to the corpus and rebuild the BM25 model from scratch."""
        self.corpus.append(text)
        # Prototype-grade: the whole corpus is re-tokenized on every insert.
        # A production system would use a real index (Lucene/Elastic).
        tokens_per_doc = []
        for document in self.corpus:
            tokens_per_doc.append(document.lower().split())
        self.bm25 = BM25Okapi(tokens_per_doc)

    def search(self, query: str, k=10):
        """Return the top `k` corpus texts for `query`, or [] if nothing is indexed."""
        if self.bm25:
            # Same tokenization as in `add`: lowercase + whitespace split.
            query_tokens = query.lower().split()
            return self.bm25.get_top_n(query_tokens, self.corpus, n=k)
        return []

In [9]:
from sentence_transformers import CrossEncoder

class Reranker:
    """Re-orders candidate documents by cross-encoder relevance to a query."""

    def __init__(self):
        # Cross-encoder that scores a (query, document) pair jointly.
        self.model = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')

    def rank(self, query: str, docs: list, top_k=3):
        """Return the `top_k` highest-scoring entries of `docs`, best first.

        An empty `docs` yields [] without calling the model.
        """
        if docs:
            # One [query, doc] pair per candidate, in candidate order.
            candidate_pairs = []
            for doc in docs:
                candidate_pairs.append([query, doc])
            relevance = self.model.predict(candidate_pairs)

            # Stable descending sort on the score only, so tied candidates
            # keep their incoming order.
            scored = list(zip(relevance, docs))
            scored.sort(key=lambda pair: pair[0], reverse=True)
            return [doc for _, doc in scored[:top_k]]
        return []

In [10]:
import time
import llama_cpp
from collections import deque
from typing import List, Dict, Any, Tuple

class LLMWithVectorMemory:
    """
    Chat wrapper that uses Hybrid RAG (Vector + Keyword + Reranking)
    to retrieve long-term context, while keeping a tiny short-term
    buffer for immediate flow.
    """

    # --- Static Configuration ---
    CONTEXT_WINDOW        = 4096
    MAX_GENERATION_TOKENS = 1024
    SAFETY_MARGIN         = 32

    # We keep a tiny short-term buffer so the model doesn't forget
    # what was said literally 1 second ago (RAG is for long-term)
    SHORT_TERM_K = 2

    # How many chunks to fetch from each index before fusion
    RETRIEVAL_CANDIDATES = 10
    # How many chunks to actually show the LLM after reranking
    FINAL_TOP_K = 3

    SYSTEM_PROMPT = (
        "You are an AI assistant with access to a long-term memory. "
        "A section labeled 'RELEVANT CONTEXT FROM MEMORY' may be provided below.\n"

        "INSTRUCTIONS:\n"
        "1. Use the context to answer the user's question accurately.\n"
        "2. If the user in the context provides a fact (e.g., 'My name is X', 'My job is Y'), "
        "TRUST THE USER, even if the assistant in the context previously denied knowing it.\n"
        "3. If the answer is not in the context, rely on your general knowledge, or take it as a new information if it's a fact presented by the user.\n"
        "4. User might provide you new information or update existing one."
        "5. DO NOT mention the 'RELEVANT CONTEXT FROM MEMORY' words explicitly."
        "6. Be concise."
    )

    def __init__(
            self,
            model_path: str = "/home/ubuntu/ai-engineering/models/qwen2.5-7b-instruct-q5_k_m-00001-of-00002.gguf"
    ):
        # --- 1. Llama Initialization ---
        self.llm = llama_cpp.Llama(
            model_path=model_path,
            n_gpu_layers=-1,
            n_ctx=self.CONTEXT_WINDOW,
            verbose=False
        )

        # --- 2. Memory Components Initialization ---
        print("Initializing Memory Components... (This may take a moment)")
        self.vector_store = VectorStore(collection_name="chat_memory")
        self.keyword_index = KeywordIndex()
        self.reranker = Reranker()

        # --- 3. State ---
        # Short-term buffer (Deque)
        self.short_term_history: deque = deque(maxlen=self.SHORT_TERM_K * 2)

        self.generation_params = {
            "temperature": 0.6, # Lower temp for RAG to reduce hallucination
            "top_p": 0.9,
            "stop": ["<|im_end|>", "<|endoftext|>"],
            "max_tokens": self.MAX_GENERATION_TOKENS,
        }

    # --- Public API ---

    def answer(self, user_text: str) -> str:
        """
        1. Retrieve relevant context (Hybrid + Rerank).
        2. Build Prompt (System + Context + Short History + User)
        3. Generate.
        4. Save to Memory.
        """

        # 1. Retrieve Context
        # We retrieve based on the user's CURRENT input
        retrieved_context = self._retrieve_context(user_text)

        # 2. Build Prompt
        messages = self._build_rag_prompt(user_text, retrieved_context)

        # 3. Call Model
        reply = self.llm.create_chat_completion(
            messages=messages,
            **self.generation_params
        )["choices"][0]["message"]["content"]

        # 4. Update State (Short-term & Long-term)
        self._update_memory(user_text, reply)

        return reply
    
    def print_debug(self, last_context: List[str]):
        """Helper to see what the RAG actually found."""
        print("\n--- RAG DEBUG: What the model saw ---")
        if not last_context:
            print("(No relevant memories found)")
        else:
            for i, ctx in enumerate(last_context):
                print(f"[{i+1}] {ctx[:100]}...") # Print first 100 chars
        print("---------------------------------------\n")
    
    # --- Internal Orchestration ---

    def _retrieve_context(self, query: str) -> List[str]:
        """
        Executes the Hybrid Search + Reranking Pipeline.
        """
        # A. Parallel Retrieval
        # Get dense results (Vector)
        # Note: VectorStore.search returns dicts with 'text', 'id', etc.
        vector_results = self.vector_store.search(query, k=self.RETRIEVAL_CANDIDATES)
        vector_texts = [r['text'] for r in vector_results]

        # Get sparse results (Keyword)
        # Note: KeywordIndex.search returns plain strings
        keyword_texts = self.keyword_index.search(query, k=self.RETRIEVAL_CANDIDATES)

        # B. Fusion (Simple Set Deduplication)
        all_candidates = list(set(vector_texts + keyword_texts))

        if not all_candidates:
            return []
        
        # C. Reranking
        # The heavy lifting: re-sort candidates by relevance to the specific query
        top_docs = self.reranker.rank(query, all_candidates, top_k=self.FINAL_TOP_K)

        return top_docs
    
    def _build_rag_prompt(self, user_text: str, context_chunks: List[str]) -> List[Dict]:
        """
        Constructs the chat format:
        System (with Context) -> Short Term History -> Current User
        """

        # Format context into a string
        context_str = ""
        if context_chunks:
            context_str = "\nRELEVANT CONTEXT FROM MEMORY:\n"
            for chunk in context_chunks:
                context_str += f"- {chunk}\n"
        
        # Inject context into System Prompt
        # This is 'In-Context Learning'
        system_content = f"{self.SYSTEM_PROMPT}\n{context_str}"

        messages = [{"role": "system", "content": system_content}]

        # Add Short-Term History (Sliding Window)
        # This ensures the model can handle "What about the second one?" references
        messages.extend(list(self.short_term_history))

        # Add Current User
        messages.append({"role": "user", "content": user_text})

        return messages

    def _update_memory(self, user_text: str, assistant_reply: str):
        """
        Saves the interaction to:
        1. RAM (Short-term sliding window)
        2. ChromaDB (Vector)
        3. BM25 (Keyword)
        """
        # Update Short Term
        self.short_term_history.append({"role": "user", "content": user_text})
        self.short_term_history.append({"role": "assistant", "content": assistant_reply})

        # Save to Long Term Memory
        # Strategy: Save the PAIR. Saving just user text loses context.
        # Saving just assistant text loses the prompt.
        memory_text = f"User: {user_text}\nAssistant: {assistant_reply}"

        # Metadata for filtering later (e.g., by time)
        metadata = {"timestamp": time.time(), "type": "conversation_turn"}

        # 1. Vector Write
        self.vector_store.add(memory_text, metadata)

        # 2. Keyword Write (Triggering the expensive rebuild)
        self.keyword_index.add(memory_text)

In [11]:
# Build the chat bot. The constructor loads the llama_cpp model and creates
# the vector store, keyword index and reranker.
bot = LLMWithVectorMemory()

# First turn: a handful of specific facts for the bot to remember.
setup = """
Here are some things to remember:
- My deployment bucket is s3://proj-847-staging-west
- The build flag is --env=qa-cluster-3  
- My SSH alias is devbox-7b
- The port I always use is 9473
- My teammate's code review tag is @chen-review-squad
"""
# answer() generates the reply and then saves this turn to the short-term
# buffer and to both long-term indexes.
print(bot.answer(setup))

llama_context: n_ctx_per_seq (4096) < n_ctx_train (131072) -- the full capacity of the model will not be utilized


Initializing Memory Components... (This may take a moment)
Sure, here's a summary of the information you provided:

- Deployment bucket: s3://proj-847-staging-west
- Build flag: --env=qa-cluster-3
- SSH alias: devbox-7b
- Port: 9473
- Teammate's code review tag: @chen-review-squad

Is there anything specific you need help with regarding this information?


In [None]:
# Questions unrelated to the facts given in the setup turn. Each
# bot.answer() call below also stores its turn in the bot's memory.
distractors = [
    "How do I center a div?",
    "What is the capital of France?",
    "Explain the difference between TCP and UDP.",
    "Write a haiku about coding.",
    "What is 12 * 12?"
]
for d in distractors:
    # The question is printed before bot.answer(d) is evaluated.
    print(f"User: {d}")
    print(f"Bot: {bot.answer(d)}")

User: How do I center a div?
