### Imports and Path setup

In [1]:
# %pip install -q chromadb python-dotenv  # the `dotenv` module comes from the python-dotenv package

In [2]:
from pathlib import Path
import chromadb
import pickle
import os
from dotenv import load_dotenv
# Pull environment variables from a local .env file, if one exists.
load_dotenv()

# --- Input / output files ---------------------------------------------------
# Pickled semantic chunks produced by the chunking step.
file_path = "../Chunking/Chunk_files/harry_potter_chunks_semantic.pkl"
# Output file for multi-query RAG results.
multiquery_rag_output_path = "../RAG Results/multiquery_rag_results.txt"

# --- Vector store -------------------------------------------------------------
Relative_Database_path = "./chroma_Data_with_BERT_embeddings"
# Resolved against the current working directory.
Absolute_Database_path = Path(Relative_Database_path).resolve()
# Fixed collection name; the indexing cell deletes and recreates this collection.
collection_name = "HP_Chunks_BERT_Embeddings_collection"

# API key setup (currently disabled):
# os.environ["GOOGLE_API_KEY"] = os.environ.get("GEMINI_API_KEY")


### Chroma Setup and Chunk Loading
Sets up the persistent ChromaDB client and loads the previously computed chunks.

In [3]:
# Open the on-disk Chroma store at the resolved database path.
client = chromadb.PersistentClient(path=Absolute_Database_path)
print(f"[INFO] ChromaDB client initialized at: {Absolute_Database_path}")

# Report which collections are already stored there.
existing_collections = client.list_collections()
existing_names = [c.name for c in existing_collections]
print(f"Existing collections: {existing_names}")

[INFO] ChromaDB client initialized at: /home/tanish/ANLP_Proj/RAG_for_research_papers/VectorDB/chroma_Data_with_BERT_embeddings
Existing collections: ['HP_Chunks_BERT_Embeddings_collection']


In [4]:

# The chunks were produced and pickled by the chunking step, so no PDF parsing
# or text splitting is needed here.
# SECURITY NOTE: pickle.load can execute arbitrary code -- only load chunk
# files generated by this project, never untrusted ones.
loaded_docs = []

try:
    with open(file_path, "rb") as f:  # pickle requires binary mode
        loaded_docs = pickle.load(f)
except FileNotFoundError:
    print(f"Error: The file '{file_path}' was not found.")
    raise
except Exception as e:
    print(f"Error loading file: {e}")
    raise

# Fail fast: later cells delete the existing Chroma collection and rebuild it
# from these chunks, so carrying on with nothing loaded would wipe the index.
if not loaded_docs:
    raise ValueError(f"No chunks were loaded from '{file_path}'.")

print(f"Successfully loaded {len(loaded_docs)} chunks from '{file_path}'.")

# Inspect one chunk to verify that its metadata survived the round-trip.
print("\nHere is the metadata of a loaded chunk:")
print(loaded_docs[0].metadata)

Successfully loaded 4014 chunks from '../Chunking/Chunk_files/harry_potter_chunks_semantic.pkl'.

Here is the metadata of a loaded chunk:
{'source': '../harrypotter.pdf', 'page_number': 14, 'c': 'semantic', 'ischunk': True}


### Set up Embedding Function
Uses a custom pre-trained BERT encoder (MoE feed-forward layers, mask-aware mean pooling) to generate embeddings. The model is loaded from `../Encoder/saved_bert_encoder_moe_pooling`.

#### Recreate BERT Model 

In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import json
from typing import List, Union
import numpy as np

# Device configuration
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")

# Load the vocabulary saved next to the trained encoder.
MODEL_DIR = "../Encoder/saved_bert_encoder_moe_pooling"
# JSON is UTF-8 by specification. Pin the encoding so the vocab loads the same
# way on every platform; the default text encoding is locale-dependent.
with open(f"{MODEL_DIR}/vocab.json", "r", encoding="utf-8") as f:
    vocab_data = json.load(f)
stoi = vocab_data["stoi"]  # token -> id mapping used by the tokenizer
itos = vocab_data["itos"]  # presumably the id -> token list; its length is the vocab size

vocab_size = len(itos)
print(f"Loaded vocab with {vocab_size} tokens")

# Special tokens
PAD_TOKEN = "[PAD]"
CLS_TOKEN = "[CLS]"
SEP_TOKEN = "[SEP]"
MASK_TOKEN = "[MASK]"
UNK_TOKEN = "[UNK]"
SPECIAL_TOKENS = [PAD_TOKEN, CLS_TOKEN, SEP_TOKEN, MASK_TOKEN, UNK_TOKEN]

# Model configuration (must match the configuration the checkpoint was trained with)
HIDDEN_SIZE = 768
NUM_LAYERS = 12
NUM_HEADS = 12
FFN_DIM = 3072
DROPOUT = 0.1
MAX_SEQ_LEN = 512  # tokenizer cap; must not exceed MAX_POSITION_EMBEDDINGS
MAX_POSITION_EMBEDDINGS = 512  # size of the saved model's position-embedding table

# -------------------------
# Recreate Model Architecture
# -------------------------

class MoE(nn.Module):
    """Top-k routed mixture-of-experts feed-forward block.

    A linear router scores each token against ``num_experts`` two-layer GELU
    MLPs. The ``k`` highest-scoring experts are mixed using a softmax taken over
    just those ``k`` scores.

    The submodule names ``experts`` and ``router`` must stay unchanged so that
    the saved state_dict keys still match.
    """

    def __init__(self, hidden_size, ffn_dim, num_experts=5, k=2, noise_std=1.0):
        super().__init__()
        self.hidden_size = hidden_size
        self.ffn_dim = ffn_dim
        self.num_experts = num_experts
        self.k = k
        # Std-dev of Gaussian noise added to the router logits, in training mode only.
        self.noise_std = noise_std

        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(hidden_size, ffn_dim),
                nn.GELU(),
                nn.Linear(ffn_dim, hidden_size)
            ) for _ in range(num_experts)
        ])

        self.router = nn.Linear(hidden_size, num_experts)

    def forward(self, x, mask=None):
        """Route every token of ``x`` to its top-k experts.

        Args:
            x: hidden states of shape (B, S, H).
            mask: accepted for interface compatibility but currently unused.
                Padded positions are therefore routed too, and they also count
                toward ``aux_loss``.

        Returns:
            ``(out, aux_loss)``. ``out`` has the same shape as ``x``.
            ``aux_loss`` is the scalar load-balancing term
            ``num_experts * sum_e (mean router probability of expert e) ** 2``.
        """
        B, S, _ = x.size()
        logits = self.router(x)  # (B, S, E)

        # Load-balancing loss, computed from the clean (un-noised) router distribution.
        probs_all = F.softmax(logits, dim=-1)
        importance = probs_all.sum(dim=(0, 1))  # (E,)
        aux_loss = self.num_experts * (importance / float(B * S)).pow(2).sum()

        # Noisy top-k gating: exploration noise is applied only while training.
        if self.training:
            logits = logits + torch.randn_like(logits) * self.noise_std

        topk_vals, topk_idx = torch.topk(logits, self.k, dim=-1)  # (B, S, k)
        topk_weights = F.softmax(topk_vals, dim=-1)

        # Dense (B, S, E) gate: the top-k weights are scattered into zeros, so
        # experts that were not selected contribute nothing. The gate dtype
        # follows the weights, which keeps scatter's dtypes consistent.
        gating = torch.zeros(
            B, S, self.num_experts, device=x.device, dtype=topk_weights.dtype
        ).scatter(-1, topk_idx, topk_weights)

        # Every expert runs densely on all tokens; the gate then mixes the results.
        expert_stack = torch.stack([expert(x) for expert in self.experts], dim=2)  # (B, S, E, H)
        out = torch.einsum('bse,bseh->bsh', gating, expert_stack)
        return out, aux_loss

class TransformerEncoderLayer(nn.Module):
    """Post-LayerNorm encoder layer: self-attention followed by an MoE feed-forward.

    Each sublayer is wrapped as ``LayerNorm(x + Dropout(sublayer(x)))``.
    """

    def __init__(self, hidden_size, num_heads, ffn_dim, dropout=0.1, moe_experts=5, moe_k=2):
        super().__init__()
        # Submodules keep their original creation order and names, so parameter
        # initialisation and state_dict keys are unchanged.
        self.self_attn = nn.MultiheadAttention(hidden_size, num_heads, dropout=dropout, batch_first=True)
        self.ln1 = nn.LayerNorm(hidden_size)
        self.ln2 = nn.LayerNorm(hidden_size)
        self.ffn_moe = MoE(hidden_size, ffn_dim, num_experts=moe_experts, k=moe_k)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        """Return ``(hidden_states, moe_aux_loss)``; ``mask`` is nonzero for real tokens and 0 for padding."""
        # MultiheadAttention ignores key positions where key_padding_mask is True.
        pad_positions = mask.eq(0)
        attended = self.self_attn(x, x, x, key_padding_mask=pad_positions)[0]
        hidden = self.ln1(x + self.dropout(attended))

        moe_out, moe_aux = self.ffn_moe(hidden, mask)
        hidden = self.ln2(hidden + self.dropout(moe_out))
        return hidden, moe_aux

class BertEncoderModel(nn.Module):
    """BERT-style encoder whose feed-forward sublayers are MoE blocks.

    Attribute names must match the training code so the saved state_dict loads
    with matching keys. ``nsp_classifier`` and ``mlm_bias`` are not used in this
    notebook; they are presumably pre-training heads, kept for key compatibility.

    NOTE(review): ``pad_token_id`` defaults to 0. If the caller does not pass it,
    this assumes ``[PAD]`` has id 0 in the vocab -- confirm against vocab.json.
    """

    def __init__(self, vocab_size, hidden_size=HIDDEN_SIZE, num_layers=NUM_LAYERS, num_heads=NUM_HEADS,
                 ffn_dim=FFN_DIM, max_position_embeddings=512, pad_token_id=0, moe_experts=5, moe_k=2):
        super().__init__()
        self.pad_token_id = pad_token_id
        self.hidden_size = hidden_size
        self.token_embeddings = nn.Embedding(vocab_size, hidden_size, padding_idx=pad_token_id)
        self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_size)
        self.segment_embeddings = nn.Embedding(2, hidden_size)
        self.emb_ln = nn.LayerNorm(hidden_size)
        # Use the shared DROPOUT constant, like the encoder layers below.
        self.emb_dropout = nn.Dropout(DROPOUT)
        self.layers = nn.ModuleList([
            TransformerEncoderLayer(hidden_size, num_heads, ffn_dim, dropout=DROPOUT,
                                    moe_experts=moe_experts, moe_k=moe_k)
            for _ in range(num_layers)
        ])
        self.nsp_classifier = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, 2)
        )
        self.mlm_bias = nn.Parameter(torch.zeros(vocab_size))

    def encode(self, ids, tt=None, mask=None):
        """Run the embedding layer and the encoder stack.

        Args:
            ids: token ids, shape (B, S).
            tt: segment (token-type) ids; defaults to all zeros.
            mask: nonzero for real tokens, 0 for padding; defaults to
                ``ids != pad_token_id``.

        Returns:
            ``(hidden_states, total_aux)``: hidden states of shape (B, S, H)
            and the MoE auxiliary losses summed over all layers.

        Raises:
            ValueError: if S exceeds the size of the position-embedding table.
        """
        seq_len = ids.size(1)
        max_positions = self.position_embeddings.num_embeddings
        if seq_len > max_positions:
            # Fail with a readable error instead of an out-of-range embedding
            # lookup, which on CUDA surfaces as an opaque device-side assert.
            raise ValueError(
                f"Sequence length {seq_len} exceeds max_position_embeddings={max_positions}; "
                "truncate the input first."
            )
        if tt is None:
            tt = torch.zeros_like(ids)
        if mask is None:
            mask = (ids != self.pad_token_id).long()
        pos = torch.arange(seq_len, device=ids.device).unsqueeze(0)
        x = self.token_embeddings(ids) + self.position_embeddings(pos) + self.segment_embeddings(tt)
        x = self.emb_dropout(self.emb_ln(x))
        total_aux = 0.0
        for layer in self.layers:
            x, aux = layer(x, mask)
            total_aux = total_aux + aux
        return x, total_aux

    def get_pooled_embeddings(self, ids, mask=None, exclude_special=True, normalize=True):
        """Mask-aware mean pooling of the final hidden states, one vector per row.

        Args:
            ids: token ids, shape (B, S).
            mask: nonzero for real tokens, 0 for padding; defaults to
                ``ids != pad_token_id``.
            exclude_special: leave special tokens out of the average.
                NOTE(review): every id below ``len(SPECIAL_TOKENS)`` is treated
                as special, which assumes the vocab stores the special tokens
                at the lowest ids -- confirm against vocab.json.
            normalize: L2-normalise the pooled vectors.

        Returns:
            Tensor of shape (B, H). A row with no poolable tokens comes out as
            all zeros.
        """
        if mask is None:
            mask = (ids != self.pad_token_id).long()
        seq_out, _ = self.encode(ids, tt=None, mask=mask)

        mask_float = mask.unsqueeze(-1).to(seq_out.dtype)

        if exclude_special:
            special_flags = (ids < len(SPECIAL_TOKENS)).to(seq_out.dtype)
            mask_float = mask_float * (1.0 - special_flags.unsqueeze(-1))

        summed = (seq_out * mask_float).sum(dim=1)
        # The clamp avoids division by zero when every token is masked out.
        denom = mask_float.sum(dim=1).clamp(min=1e-9)
        pooled = summed / denom

        if normalize:
            pooled = F.normalize(pooled, p=2, dim=1)

        return pooled

# Build the architecture with the training hyper-parameters, then load the saved weights.
print("Loading BERT model...")
model = BertEncoderModel(vocab_size, max_position_embeddings=MAX_POSITION_EMBEDDINGS, moe_experts=5, moe_k=2)
# The checkpoint is a plain state_dict, so weights_only=True is sufficient.
# It restricts unpickling to tensors and containers, so the file cannot run
# arbitrary code while loading.
model.load_state_dict(
    torch.load(f"{MODEL_DIR}/bert_encoder_moe_pooling.pt", map_location=DEVICE, weights_only=True)
)
model.to(DEVICE)
model.eval()  # disables dropout and the MoE router noise
print("Model loaded successfully!")

# Simple tokenizer class
class SimpleTokenizer:
    """Whitespace tokenizer over a fixed vocabulary.

    Text is split on whitespace and each word is looked up in ``stoi``; unknown
    words map to ``[UNK]``. The result is wrapped as ``[CLS] ... [SEP]``. There
    is no lower-casing or punctuation splitting, so ``"sentence."`` and
    ``"sentence"`` are separate lookups.
    """

    def __init__(self, stoi, itos, max_length=512):
        self.stoi = stoi
        self.itos = itos
        # Hard cap on ids per sequence, including [CLS] and [SEP].
        self.max_length = max_length
        self.pad_token_id = stoi[PAD_TOKEN]
        self.cls_token_id = stoi[CLS_TOKEN]
        self.sep_token_id = stoi[SEP_TOKEN]
        self.unk_token_id = stoi[UNK_TOKEN]

    def tokenize(self, text: str, max_length: Union[int, None] = None) -> List[int]:
        """Convert ``text`` to ``[CLS] + word ids + [SEP]``.

        Args:
            text: input string.
            max_length: optional tighter limit on the returned length. It can
                lower ``self.max_length`` but never raise it.

        Returns:
            List of token ids. Words are dropped from the end so that both
            ``[CLS]`` and ``[SEP]`` always fit (for limits of at least 2).
        """
        limit = self.max_length if max_length is None else min(self.max_length, max_length)
        words = text.split()[:max(limit - 2, 0)]  # reserve two slots for CLS/SEP
        ids = [self.stoi.get(word, self.unk_token_id) for word in words]
        return [self.cls_token_id] + ids + [self.sep_token_id]

    def __call__(self, texts: List[str], padding=True, max_length=None):
        """Tokenize a batch of texts.

        Args:
            texts: list of input strings.
            padding: if True, right-pad every sequence with ``[PAD]`` to the
                longest one and also return ``attention_mask`` (1 = real
                token, 0 = padding). If False, sequences are returned as-is,
                which only works when they all have the same length.
            max_length: optional per-call cap, defaulting to ``self.max_length``.
                Truncation keeps the trailing ``[SEP]``.

        Returns:
            A dict with ``input_ids`` (LongTensor of shape (B, L)) and, when
            padding, an ``attention_mask`` of the same shape.
        """
        if max_length is None:
            max_length = self.max_length

        # Truncate inside tokenize() so [SEP] survives a shorter per-call limit.
        all_ids = [self.tokenize(text, max_length=max_length) for text in texts]

        if not padding:
            # torch.tensor raises ValueError if the sequences are ragged.
            return {
                'input_ids': torch.tensor(all_ids, dtype=torch.long)
            }

        if not all_ids:
            # Empty batch: return well-formed (0, 0) tensors instead of failing in max().
            empty = torch.zeros((0, 0), dtype=torch.long)
            return {'input_ids': empty, 'attention_mask': empty.clone()}

        # The final clamp keeps results within max_length even for degenerate limits (< 2).
        max_len = min(max(len(ids) for ids in all_ids), max_length)
        padded_ids = []
        attention_masks = []
        for ids in all_ids:
            ids = ids[:max_len]
            pad_len = max_len - len(ids)
            padded_ids.append(ids + [self.pad_token_id] * pad_len)
            attention_masks.append([1] * len(ids) + [0] * pad_len)

        return {
            'input_ids': torch.tensor(padded_ids, dtype=torch.long),
            'attention_mask': torch.tensor(attention_masks, dtype=torch.long)
        }

# Whitespace tokenizer over the loaded vocab, capped at the model's sequence length.
tokenizer = SimpleTokenizer(stoi=stoi, itos=itos, max_length=MAX_SEQ_LEN)
print("Tokenizer initialized!")

# Custom embedding function for ChromaDB
class MyBERTEmbeddingFunction:
    """ChromaDB embedding function backed by the custom BERT encoder.

    The instance is called on documents and ``embed_query`` is called on query
    texts. Both return one L2-normalised vector (a list of floats) per text.
    """

    def __init__(self, model, tokenizer, device, batch_size=16):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.batch_size = batch_size  # texts per forward pass

    def _embed_texts(self, texts: Union[str, List[str]]) -> List[List[float]]:
        """Embed a string or a list of strings, ``self.batch_size`` texts at a time."""
        if isinstance(texts, str):
            texts = [texts]

        all_embeddings = []
        for i in range(0, len(texts), self.batch_size):
            batch_texts = texts[i:i + self.batch_size]

            encoded = self.tokenizer(batch_texts, padding=True, max_length=MAX_SEQ_LEN)
            input_ids = encoded['input_ids'].to(self.device)
            attention_mask = encoded['attention_mask'].to(self.device)

            # Inference only: skip autograd bookkeeping.
            with torch.no_grad():
                embeddings = self.model.get_pooled_embeddings(
                    input_ids,
                    mask=attention_mask,
                    exclude_special=True,
                    normalize=True
                )

            # Hand ChromaDB plain Python lists of floats.
            all_embeddings.extend(embeddings.cpu().numpy().tolist())

        return all_embeddings

    def __call__(self, input: List[str]) -> List[List[float]]:
        """Embed documents.

        Keep the parameter name ``input``: ChromaDB's embedding-function
        protocol uses it.
        """
        return self._embed_texts(input)

    def embed_query(self, input=None, **kwargs) -> List[List[float]]:
        """Embed query text(s) and return one vector per query.

        Args:
            input: a single string or a list of strings. A single string is
                treated as a one-element list.
            **kwargs: accepted for compatibility and ignored.

        Returns:
            ``List[List[float]]``, the format ChromaDB expects.

        Raises:
            ValueError: if ``input`` is not provided.
        """
        # A keyword call ``embed_query(input=...)`` always binds to the named
        # parameter, so there is no need to look for 'input' in **kwargs.
        if input is None:
            raise ValueError("No input provided to embed_query")
        return self._embed_texts(input)

# Summary of the loaded encoder configuration.
print(
    "\n[SUCCESS] BERT model and tokenizer ready!",
    f"Embedding dimension: {HIDDEN_SIZE}",
    f"Max sequence length: {MAX_SEQ_LEN}",
    sep="\n",
)


Using device: cuda
Loaded vocab with 45706 tokens
Loading BERT model...
Model loaded successfully!
Tokenizer initialized!

[SUCCESS] BERT model and tokenizer ready!
Embedding dimension: 768
Max sequence length: 512
Model loaded successfully!
Tokenizer initialized!

[SUCCESS] BERT model and tokenizer ready!
Embedding dimension: 768
Max sequence length: 512


In [6]:
# Wrap the loaded model and tokenizer in a ChromaDB-compatible embedding function.
embedding_fn = MyBERTEmbeddingFunction(model, tokenizer, DEVICE, batch_size=16)

# Smoke-test on two short sentences before indexing thousands of chunks.
print("Testing embedding function...")
test_texts = ["This is a test sentence.", "Another example text."]
test_embeddings = embedding_fn(test_texts)
print(f"Generated {len(test_embeddings)} embeddings")
print(f"Embedding shape: {len(test_embeddings[0])} dimensions")
print(f"First embedding (first 5 values): {test_embeddings[0][:5]}")

# Check the invariants the collection relies on before announcing success:
# one vector per text, each with HIDDEN_SIZE dimensions.
assert len(test_embeddings) == len(test_texts), "expected one embedding per input text"
assert all(len(vec) == HIDDEN_SIZE for vec in test_embeddings), "unexpected embedding dimension"
print("\n[SUCCESS] Embedding function ready for ChromaDB!")


Testing embedding function...
Generated 2 embeddings
Embedding shape: 768 dimensions
First embedding (first 5 values): [0.03787987679243088, -0.01927832141518593, 0.007840707898139954, 0.030228307470679283, 0.009925552643835545]

[SUCCESS] Embedding function ready for ChromaDB!
Generated 2 embeddings
Embedding shape: 768 dimensions
First embedding (first 5 values): [0.03787987679243088, -0.01927832141518593, 0.007840707898139954, 0.030228307470679283, 0.009925552643835545]

[SUCCESS] Embedding function ready for ChromaDB!


### Create Collection with BERT Embeddings

In [7]:
from datetime import datetime

# Safety check: the next step deletes the existing collection, so confirm there
# are chunks to rebuild it from before destroying anything.
if not loaded_docs:
    raise RuntimeError(
        f"No chunks loaded; refusing to delete collection '{collection_name}'."
    )

# FORCE DELETE the collection if it exists, so the index is rebuilt from scratch.
try:
    client.delete_collection(name=collection_name)
    print(f"[INFO] Deleted existing collection '{collection_name}'")
except Exception as e:
    # Usually this just means the collection did not exist yet. Include the
    # error so that any other failure is visible instead of misreported.
    print(f"[INFO] No existing collection named '{collection_name}' to delete. ({type(e).__name__}: {e})")

# Create a FRESH collection that embeds documents and queries with the BERT function.
collection = client.create_collection(
    name=collection_name,
    embedding_function=embedding_fn,
    metadata={
        "description": "Harry Potter Chunks with custom BERT embeddings (MoE + Mask-aware pooling)",
        "created": str(datetime.now()),
        "model": "Custom BERT with MoE",
        "embedding_dim": HIDDEN_SIZE
    }
)

print(f"[SUCCESS] Fresh collection '{collection_name}' created successfully")
print(f"Current count in collection: {collection.count()}")


[INFO] Deleted existing collection 'HP_Chunks_BERT_Embeddings_collection'
[SUCCESS] Fresh collection 'HP_Chunks_BERT_Embeddings_collection' created successfully
Current count in collection: 0


### Add Documents to Collection
Prepare all chunks and add them to the collection; ChromaDB computes their embeddings with the BERT embedding function.

In [8]:
# Split the chunk objects into the parallel lists that collection.add() expects.
# Each id is derived from the chunk's position in loaded_docs.
documents = [doc.page_content for doc in loaded_docs]
metadatas = [doc.metadata for doc in loaded_docs]
ids = [f"hp_chunk_{position}" for position in range(len(loaded_docs))]

print(f"[INFO] Prepared {len(documents)} documents for embedding")
print(f"Sample document: {documents[0][:100]}...")
print(f"Sample metadata: {metadatas[0]}")


[INFO] Prepared 4014 documents for embedding
Sample document: . yes, that would be it. The traffic moved on and a few minutes
later, Mr. Dursley arrived in the Gr...
Sample metadata: {'source': '../harrypotter.pdf', 'page_number': 14, 'c': 'semantic', 'ischunk': True}


In [9]:
# Insert in fixed-size batches; ChromaDB calls embedding_fn on each batch's
# documents to compute their vectors.
batch_size = 500
total_batches = -(-len(documents) // batch_size)  # ceiling division

print(f"[INFO] Adding documents in {total_batches} batches...")

for batch_num, start in enumerate(range(0, len(documents), batch_size), start=1):
    stop = start + batch_size
    batch_docs = documents[start:stop]
    collection.add(
        documents=batch_docs,
        metadatas=metadatas[start:stop],
        ids=ids[start:stop],
    )
    print(f"  Batch {batch_num}/{total_batches} added ({len(batch_docs)} documents)")

print(f"\n[SUCCESS] All documents added!")
print(f"Total documents in collection: {collection.count()}")


[INFO] Adding documents in 9 batches...
  Batch 1/9 added (500 documents)
  Batch 1/9 added (500 documents)
  Batch 2/9 added (500 documents)
  Batch 2/9 added (500 documents)
  Batch 3/9 added (500 documents)
  Batch 3/9 added (500 documents)
  Batch 4/9 added (500 documents)
  Batch 4/9 added (500 documents)
  Batch 5/9 added (500 documents)
  Batch 5/9 added (500 documents)
  Batch 6/9 added (500 documents)
  Batch 6/9 added (500 documents)
  Batch 7/9 added (500 documents)
  Batch 7/9 added (500 documents)
  Batch 8/9 added (500 documents)
  Batch 8/9 added (500 documents)
  Batch 9/9 added (14 documents)

[SUCCESS] All documents added!
Total documents in collection: 4014
  Batch 9/9 added (14 documents)

[SUCCESS] All documents added!
Total documents in collection: 4014


### Test the Collection
Query the collection to verify embeddings are working correctly

In [None]:
# Smoke-test retrieval with a simple question.
test_query = "Who is Harry Potter?"
n_results = 5  # number of nearest neighbours to retrieve

print(f"Test Query: '{test_query}'")
print("\nSearching with BERT embeddings...")

results = collection.query(
    query_texts=[test_query],
    n_results=n_results,
)

hits = list(zip(results['documents'][0], results['distances'][0]))
# Only report success if the query actually returned neighbours.
assert hits, "Query returned no results"

# The header reflects how many hits came back, which can be fewer than requested.
print(f"\nTop {len(hits)} Results:")
for rank, (doc, distance) in enumerate(hits, start=1):
    print(f"\n{rank}. Distance: {distance:.4f}")
    print(f"   Text: {doc[:150]}...")

# NOTE: this confirms the query pipeline runs end to end; it does not check
# relevance -- inspect the retrieved chunks above for that.
print("\n[SUCCESS] Collection is working correctly with BERT embeddings!")


Test Query: 'Who is Harry Potter?'

Searching with BERT embeddings...

Top 5 Results:

1. Distance: 1.5522
   Text: “Oh come on,” he said impatiently, “we need partners, we’re going to look
really stupid if we haven’t got any, everyone else has . . .”
“I can’t come ...

2. Distance: 1.5566
   Text: . see Dumbledore . . . my fault . . . all
my fault . . . Bertha . . . dead . . . all my fault . . . my son . . . my fault . . . tell
Dumbledore . . . ...

3. Distance: 1.5804
   Text: “Fine,” said Harry stiffly. “Oh, don’t lie, Harry,” she said impatiently. “Ron and Ginny say you’ve
been hiding from everyone since you got back from ...

4. Distance: 1.5882
   Text: “Why can’t they help?”
“What?”
“They can help.” He dropped his voice and said, so that none of them could
hear but Hermione, who stood between them, “...

5. Distance: 1.5886
   Text: . I don’t believe it . . . he crept up behind
me. . . . I heard him, I turned around, and he had his wand on me. . . .”
Cedric got up. He was still 

: 