# In [None]:
import time
from typing import List
# Assuming other necessary imports like retry, types, etc. are present

class DocumentEmbedding(EmbeddingFunction):
    """
    Base class for embedding and retrieving documents using ChromaDB and GenAI embeddings.

    Documents are embedded with ``genai_client.models.embed_content`` and stored
    in the ChromaDB collection returned by ``get_or_create_collection``.

    Attributes:
        id_field: Metadata key whose value (converted to ``str``) becomes the
            ChromaDB id of each document. Defaults to ``"id"``; override it as a
            class attribute in a subclass or pass ``id_field=...`` to ``__init__``.
    """

    # Default metadata key used to build ChromaDB ids. Kept at class level so a
    # subclass that declares its own ``id_field`` is not overwritten by __init__.
    id_field = "id"

    def __init__(self, chroma_client, genai_client, model_name="models/text-embedding-004", collection_name="embeddings", id_field=None):
        """
        Args:
            chroma_client: Client exposing ``get_or_create_collection(name=...)``.
            genai_client: Client exposing ``models.embed_content(...)``.
            model_name: Embedding model identifier passed to ``embed_content``.
            collection_name: Name of the collection to fetch or create.
            id_field: Optional metadata key used for document ids. When ``None``
                the class-level ``id_field`` is used.
        """
        self.chroma_client = chroma_client
        self.genai_client = genai_client
        self.collection = self.chroma_client.get_or_create_collection(name=collection_name)
        self.model_name = model_name
        if id_field is not None:
            self.id_field = id_field

    def _embed_batch(self, texts: List[str]) -> List[List[float]]:
        """
        Helper function to embed a BATCH of texts at once.
        Returns a list of embeddings (list of list of floats).
        """
        # One request carries the whole batch; task type marks these as documents.
        response = self.genai_client.models.embed_content(
            model=self.model_name,
            contents=texts,
            config=types.EmbedContentConfig(task_type="retrieval_document")
        )

        # Extract the values from each embedding object in the response
        return [emb.values for emb in response.embeddings]

    # NOTE(review): the retry wraps the whole method, so a retriable failure in a
    # late batch restarts embedding from the first batch — confirm this is intended.
    @retry.Retry(predicate=is_retriable, timeout=1000)
    def add_document_embeddings(self, documents, metadata, batch_size=100):
        """
        Embeds a list of documents in BATCHES and stores them in ChromaDB.

        Args:
            documents: Sequence of document texts.
            metadata: Sequence of dicts, one per document; each must contain
                the ``self.id_field`` key.
            batch_size: Number of documents sent per embedding request.

        Raises:
            ValueError: If ``documents`` and ``metadata`` differ in length.
            KeyError: If a metadata entry lacks the ``self.id_field`` key.
        """
        total_docs = len(documents)
        if total_docs != len(metadata):
            raise ValueError(
                f"documents ({total_docs}) and metadata ({len(metadata)}) must have the same length"
            )
        if total_docs == 0:
            # Nothing to embed or store; skip the API and the collection call.
            print("No documents to embed.")
            return

        # Build ids first so a missing key fails before any embedding request is made.
        ids = [str(meta[self.id_field]) for meta in metadata]

        all_embeddings = []
        print(f"Starting embedding for {total_docs} documents...")

        for i in range(0, total_docs, batch_size):
            # 1. Slice the documents into a batch (e.g., 0 to 100)
            batch_docs = documents[i : i + batch_size]

            # 2. Get embeddings for this batch (1 API call per batch)
            print(f"Processing batch {i} to {i + len(batch_docs)}...")
            try:
                all_embeddings.extend(self._embed_batch(batch_docs))
            except Exception as e:
                print(f"Error processing batch starting at index {i}: {e}")
                # Re-raise with the original traceback; partial results are discarded.
                raise

            # Pause between batches only; no need to wait after the last one.
            if i + batch_size < total_docs:
                time.sleep(0.5)

        # 3. Add everything to ChromaDB at once
        self.collection.add(
            documents=documents,
            metadatas=metadata,
            embeddings=all_embeddings,
            ids=ids
        )
        print("All documents added successfully.")

    def query_embedding(self, query, n_results=5):
        """
        Embeds the input query (single string) and searches ChromaDB.

        Args:
            query: Query text.
            n_results: Maximum number of matches requested from the collection.

        Returns:
            The result of ``collection.query``, or ``[]`` when the returned
            embedding is empty.
        """
        # The query is wrapped in a list and embedded with the "retrieval_query"
        # task type (documents use "retrieval_document").
        response = self.genai_client.models.embed_content(
            model=self.model_name,
            contents=[query],
            config=types.EmbedContentConfig(task_type="retrieval_query")
        )

        embedded_query = response.embeddings[0].values
        if not embedded_query:
            return []

        return self.collection.query(
            query_embeddings=[embedded_query],
            n_results=n_results
        )

# In [None]:
import sqlite3
import faiss
import numpy as np
import os

class LightWeightRAG:
    """
    Minimal retrieval store: SQLite holds the document text, FAISS holds the vectors.

    Each row's auto-incremented SQLite ``id`` is reused as the vector id in the
    FAISS index, which is how search hits are mapped back to text.

    Note: the FAISS index is created empty in ``__init__`` on every
    instantiation, while the SQLite file at ``db_path`` persists. Rows written
    by an earlier instance are therefore not searchable from a new one.
    """

    def __init__(self, db_path="rag_data.db", dimension=768):
        """
        Args:
            db_path: Path of the SQLite database file.
            dimension: Length of every embedding vector stored in the index.
        """
        self.db_path = db_path
        self.dimension = dimension

        # 1. Setup SQLite (The Content Store)
        self.conn = sqlite3.connect(self.db_path, check_same_thread=False)
        self.cursor = self.conn.cursor()
        self._init_db()

        # 2. Setup FAISS (The Vector Index)
        # We use IndexIDMap so we can assign specific SQLite IDs to vectors
        self.index = faiss.IndexIDMap(faiss.IndexFlatL2(self.dimension))

    def _init_db(self):
        """Creates the table if it doesn't exist."""
        self.cursor.execute("""
            CREATE TABLE IF NOT EXISTS abstracts (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                paper_id TEXT,
                title TEXT,
                date TEXT,
                content TEXT
            )
        """)
        self.conn.commit()

    def add_documents(self, documents, embeddings):
        """
        Store documents in SQLite and their vectors in FAISS under the same ids.

        Args:
            documents: List of dicts [{'paper_id': '...', 'title': '...', 'content': '...'}, ...]
            embeddings: List of list of floats (from Google GenAI), one per document.

        Raises:
            ValueError: If the number of embeddings differs from the number of
                documents, or the vectors are not ``self.dimension`` wide.
        """
        # FAISS requires contiguous float32 arrays; convert before touching the
        # database so malformed input is rejected without side effects.
        vectors_np = np.ascontiguousarray(embeddings, dtype='float32')

        if len(documents) != len(vectors_np):
            raise ValueError(
                f"Got {len(documents)} documents but {len(vectors_np)} embeddings"
            )
        if len(documents) == 0:
            return
        if vectors_np.ndim != 2 or vectors_np.shape[1] != self.dimension:
            raise ValueError(
                f"Embeddings must have shape (n, {self.dimension}), got {vectors_np.shape}"
            )

        try:
            # Step A: Insert text into SQLite to get unique IDs
            ids_list = []
            for doc in documents:
                self.cursor.execute("""
                    INSERT INTO abstracts (paper_id, title, date, content)
                    VALUES (?, ?, ?, ?)
                """, (doc.get('paper_id'), doc.get('title'), doc.get('date'), doc.get('content')))

                # Capture the auto-incremented ID generated by SQLite
                ids_list.append(self.cursor.lastrowid)

            # Step B: Add vectors to FAISS using the SAME IDs
            ids_np = np.array(ids_list).astype('int64')
            self.index.add_with_ids(vectors_np, ids_np)
        except Exception:
            # Keep the two stores consistent: drop the uncommitted rows when
            # either the inserts or the index update fails.
            self.conn.rollback()
            raise

        # Commit only after the vectors are in the index.
        self.conn.commit()
        print(f"Successfully stored {len(documents)} documents.")

    def _fetch_ranked(self, ids):
        """Return (paper_id, title, content) rows for ``ids``, in the order of ``ids``."""
        ordered_ids = list(dict.fromkeys(ids))  # drop duplicates, keep ranking
        if not ordered_ids:
            return []

        # Construct dynamic SQL query: "... WHERE id IN (?,?,?)"
        placeholders = ','.join('?' * len(ordered_ids))
        query = f"SELECT id, paper_id, title, content FROM abstracts WHERE id IN ({placeholders})"
        self.cursor.execute(query, ordered_ids)

        # SQL IN does not preserve the requested order, so index rows by id
        # and re-emit them following the ranking.
        rows_by_id = {row[0]: row[1:] for row in self.cursor.fetchall()}
        return [rows_by_id[i] for i in ordered_ids if i in rows_by_id]

    def search(self, query_embedding, k=3):
        """
        Find the ``k`` stored documents closest to ``query_embedding``.

        Args:
            query_embedding: List of floats (single query vector)
            k: Number of results to return

        Returns:
            List of ``(paper_id, title, content)`` tuples ordered as returned by
            the FAISS search (nearest first); empty when nothing matches or
            ``k`` is not positive.
        """
        if k <= 0:
            return []

        # Step A: Search FAISS for the closest vectors
        query_np = np.asarray(query_embedding, dtype='float32').reshape(1, -1)
        distances, indices = self.index.search(query_np, k)

        # 'indices' contains the IDs of the matches (e.g., [5, 23, 101]);
        # FAISS returns -1 if fewer than k results exist.
        valid_ids = [int(i) for i in indices[0] if i != -1]

        # Step B: Fetch the actual text from SQLite using those IDs
        return self._fetch_ranked(valid_ids)

    def close(self):
        """Close the SQLite connection."""
        self.conn.close()

# --- Usage Example ---
if __name__ == "__main__":
    # Demo: store two mock documents with random vectors, then run one search.
    rag = LightWeightRAG(db_path="my_papers.db")

    # try/finally guarantees the SQLite connection is closed even if a step fails.
    try:
        # 1. Mock Data (Your Google GenAI results go here)
        mock_docs = [
            {"paper_id": "A1", "title": "Deep Learning", "date": "2023", "content": "Abstract about DL..."},
            {"paper_id": "B2", "title": "Quantum Comp", "date": "2024", "content": "Abstract about QC..."}
        ]
        # Mock embeddings: one random vector per document, sized to the index dimension.
        mock_embeddings = [np.random.rand(rag.dimension) for _ in mock_docs]

        # 2. Add to DB
        rag.add_documents(mock_docs, mock_embeddings)

        # 3. Search
        # (In real life, this comes from embed_content("search query"))
        mock_query = np.random.rand(rag.dimension)
        results = rag.search(mock_query, k=1)

        print("Retrieved Result:", results)
    finally:
        rag.close()