# Re-Embedding Chunks in ChromaDB with a Fine-Tuned SentenceTransformer
<img src="https://raw.githubusercontent.com/CNUClasses/CPSC471/master/content/lectures/week13/stage2.png" alt="standard" style="max-height:300px;  margin:10px 0; vertical-align:middle;">

This notebook demonstrates how to:

1. **Load a fine-tuned embedding model** from a local directory:
   - `./models/sentence-transformers/msmarco-distilbert-cos-v5`
2. **Connect to an existing ChromaDB persistent database**:
   - `PERSIST_DIR = "../week12/rag_chroma"`
   - Existing collection: `cnu_rag_lab`
3. **Re-embed all documents (chunks) from the existing collection** using the fine-tuned model.
4. **Create a new ChromaDB collection** with updated embeddings:
   - New collection name: `cnu_rag_mnrl_lab`

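This is useful when you have improved your embedding model (e.g., by fine-tuning it with
MultipleNegativesRankingLoss, MNRL) and want your retrieval index (Chroma) to use the new
embeddings without changing the underlying text or metadata. Vectors produced by different
models are not comparable, so old and new embeddings cannot be mixed in one collection:
every chunk has to be re-embedded.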


In [1]:
# ============================================================================
# (Optional) Install Dependencies
# ============================================================================
# Uncomment and run this cell if you do NOT already have these libraries
# installed in your Python environment.
#
# - chromadb: for the vector database
# - sentence-transformers: for the embedding model
# ----------------------------------------------------------------------------

# !pip install -U chromadb sentence-transformers

In [2]:
# ============================================================================
# Imports and Configuration
# ============================================================================
# This section imports all required libraries and defines configuration
# variables such as model path, Chroma persistence directory, and collection
# names.
# ----------------------------------------------------------------------------

import os
# Expose only GPU index 1 to this process. This must be set before torch
# initializes CUDA, which is why it comes before `import torch`.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
from typing import List, Dict

import utils as ut  # local helper module; provides encode_documents()
import torch
from sentence_transformers import SentenceTransformer

import chromadb

# Path to the fine-tuned SentenceTransformer model.
# This should be a directory containing the saved model files.
FINETUNED_MODEL_PATH = "./models/sentence-transformers/msmarco-distilbert-cos-v5"

# ChromaDB persistence directory (where the DB is stored on disk)
PERSIST_DIR = "../week12/rag_chroma"

# Name of the existing collection with old embeddings
SOURCE_COLLECTION_NAME = "cnu_rag_lab"

# Name of the new collection where we will store re-embedded chunks
TARGET_COLLECTION_NAME = "cnu_rag_mnrl_lab"

# Device configuration: use GPU if available, otherwise CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)




Using device: cuda


In [3]:
# ============================================================================
# Load the Fine-Tuned Embedding Model
# ============================================================================
# We now load the fine-tuned SentenceTransformer model from the specified
# local directory. This model will be used to compute new embeddings for
# all documents in the ChromaDB collection.
# ----------------------------------------------------------------------------

# Load the model from the local path
model = SentenceTransformer(FINETUNED_MODEL_PATH)
model.to(device)

# Quick sanity check: print out the model's max sequence length and embedding dimension
print("Model loaded from:", FINETUNED_MODEL_PATH)
print("Max sequence length:", model.get_max_seq_length())
print("Embedding dimension:", model.get_sentence_embedding_dimension())


Model loaded from: ./models/sentence-transformers/msmarco-distilbert-cos-v5
Max sequence length: 384
Embedding dimension: 768


In [4]:
# ============================================================================
# Connect to ChromaDB and Access the Source Collection
# ============================================================================
# We connect to the persistent ChromaDB instance stored in PERSIST_DIR.
# Then we open the existing collection that contains the original embeddings
# and chunks.
# ----------------------------------------------------------------------------

# Create a ChromaDB client backed by the persistent directory
client = chromadb.PersistentClient(path=PERSIST_DIR)
print("Connected to ChromaDB with persist_directory =", PERSIST_DIR)

# Open the source collection, which already exists and holds the original embeddings
source_collection = client.get_collection(SOURCE_COLLECTION_NAME)
print("Loaded source collection:", source_collection.name)


Connected to ChromaDB with persist_directory = ../week12/rag_chroma
Loaded source collection: cnu_rag_lab


In [5]:
# ============================================================================
# Fetch All Documents (Chunks) from the Source Collection
# ============================================================================
# `Collection.get()` accepts `limit` and `offset`, so we can page through a
# large collection instead of loading everything in a single call. Here, we:
#
#   1. Ask the collection how many items it holds.
#   2. Fetch them in batches using limit/offset.
#   3. Aggregate all 'ids', 'documents', and 'metadatas' in memory.
#
# We deliberately skip the existing embeddings, since we will recompute them.
# ----------------------------------------------------------------------------

# Collection.count() returns the number of items stored in the collection.
# (The result of .get() has no "count" key, so it cannot be used for this.)
total_count = source_collection.count()

print("Total number of items in source collection:", total_count)

# Define batch size for pagination
BATCH_SIZE = 1000

all_ids = []
all_documents = []
all_metadatas = []

offset = 0

while offset < total_count:
    # Fetch a batch of documents and metadatas (embeddings are not needed)
    batch = source_collection.get(
        limit=BATCH_SIZE,
        offset=offset,
        include=["documents", "metadatas"],
    )

    ids = batch["ids"]
    if not ids:
        # Stop instead of looping forever if the collection shrank mid-run
        break
    docs = batch["documents"]
    metas = batch["metadatas"]

    all_ids.extend(ids)
    all_documents.extend(docs)
    all_metadatas.extend(metas)

    offset += len(ids)
    print(f"Fetched {len(ids)} items (offset now {offset})")

print("Total fetched IDs:", len(all_ids))
print("Total fetched documents:", len(all_documents))
print("Total fetched metadatas:", len(all_metadatas))

# Quick sanity check
if len(all_ids) != total_count:
    print("Warning: fetched item count does not match expected total_count.")


Total number of items in source collection: 1256
Fetched 1000 items (offset now 1000)
Fetched 256 items (offset now 1256)
Total fetched IDs: 1256
Total fetched documents: 1256
Total fetched metadatas: 1256
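The limit/offset loop above can be exercised without a database. The sketch below is a stdlib-only stand-in: `get_page` is a hypothetical callable that plays the role of `source_collection.get(limit=..., offset=..., include=[...])`, and the fake store simply mirrors the 1256-item run shown above. It also shows why the empty-page guard matters: if the expected total is larger than what the store actually holds, the loop still terminates.

```python
from typing import Any, Callable, Dict, List


def fetch_all(
    get_page: Callable[[int, int], Dict[str, List[Any]]],
    total: int,
    page_size: int = 1000,
) -> Dict[str, List[Any]]:
    """Page through a store with limit/offset until `total` items are read.

    `get_page(limit, offset)` is a hypothetical stand-in for a Chroma
    `collection.get(...)` call; it must return parallel "ids",
    "documents", and "metadatas" lists.
    """
    out: Dict[str, List[Any]] = {"ids": [], "documents": [], "metadatas": []}
    offset = 0
    while offset < total:
        page = get_page(page_size, offset)
        if not page["ids"]:
            # Empty page: the expected total was too high, so stop here
            break
        for key in out:
            out[key].extend(page[key])
        offset += len(page["ids"])
    return out


# In-memory fake store with 1256 (id, document, metadata) rows
store = [(f"id{i}", f"chunk {i}", {"i": i}) for i in range(1256)]


def fake_get(limit: int, offset: int) -> Dict[str, List[Any]]:
    rows = store[offset:offset + limit]
    return {
        "ids": [r[0] for r in rows],
        "documents": [r[1] for r in rows],
        "metadatas": [r[2] for r in rows],
    }


result = fetch_all(fake_get, total=len(store))
print(len(result["ids"]))
```

Passing `total=5000` to the same function still returns 1256 items instead of spinning forever, which is the behavior the `if not ids: break` guard gives the notebook loop.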


In [6]:
# # ============================================================================
# # Helper Function: Encode Documents in Batches (reference copy)
# # ============================================================================
# # This cell is commented out: the notebook calls `ut.encode_documents` from the
# # local `utils` module instead. This copy shows the batching logic that helper
# # follows (its log lines match the output of the next cell). It encodes texts
# # with the SentenceTransformer model in batches, which avoids running out of
# # memory when there are many documents.
# # ----------------------------------------------------------------------------

# from math import ceil

# def encode_documents(
#     texts: List[str],
#     model: SentenceTransformer,
#     batch_size: int = 64,
# ) -> List[List[float]]:
#     """Encode a list of documents into embeddings using the given model.

#     This function:
#       - Splits the input texts into batches.
#       - Uses `model.encode` with `convert_to_numpy=True` and `normalize_embeddings=True`
#         for cosine similarity–friendly vectors.
#       - Converts the resulting numpy arrays into Python lists (for ChromaDB).

#     Args:
#         texts: List of string documents to encode.
#         model: A SentenceTransformer model.
#         batch_size: Number of documents to encode per batch.

#     Returns:
#         A list of embeddings, where each embedding is a list[float].
#     """
#     all_embeddings: List[List[float]] = []
#     num_texts = len(texts)
#     num_batches = ceil(num_texts / batch_size)

#     for i in range(num_batches):
#         start = i * batch_size
#         end = min(start + batch_size, num_texts)
#         batch_texts = texts[start:end]

#         # Compute embeddings for this batch
#         # - convert_to_numpy=True returns a numpy array
#         # - normalize_embeddings=True L2-normalizes the embeddings
#         batch_embeddings = model.encode(
#             batch_texts,
#             batch_size=batch_size,
#             convert_to_numpy=True,
#             normalize_embeddings=True,
#             show_progress_bar=False,
#         )

#         # Convert each embedding vector to a plain Python list
#         for emb in batch_embeddings:
#             all_embeddings.append(emb.tolist())

#         # Optional logging for long runs
#         if (i + 1) % 10 == 0 or (i + 1) == num_batches:
#             print(f"Encoded batch {i + 1}/{num_batches} (docs {start}–{end})")

#     return all_embeddings
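To make the batching arithmetic and the effect of `normalize_embeddings=True` concrete, here is a stdlib-only sketch. It is not the `utils` implementation: `embed_batch` is a hypothetical callable standing in for `model.encode`, and normalization is done by hand. With 1256 texts and `batch_size=64` it makes 20 calls, the last one with 40 texts, which matches the "batch 20/20 (docs 1216–1256)" log of the next cell.

```python
import math
from typing import Callable, List, Sequence


def l2_normalize(vec: Sequence[float]) -> List[float]:
    """Scale a vector to unit L2 norm, as normalize_embeddings=True does."""
    norm = math.sqrt(sum(x * x for x in vec))
    if norm == 0.0:
        return list(vec)  # an all-zero vector cannot be normalized; keep it
    return [x / norm for x in vec]


def encode_in_batches(
    texts: List[str],
    embed_batch: Callable[[List[str]], List[List[float]]],
    batch_size: int = 64,
) -> List[List[float]]:
    """Embed texts in fixed-size slices and L2-normalize every vector."""
    out: List[List[float]] = []
    for start in range(0, len(texts), batch_size):
        batch = texts[start:start + batch_size]
        out.extend(l2_normalize(v) for v in embed_batch(batch))
    return out


# Toy embedder (hypothetical): records each batch size it receives
calls: List[int] = []


def toy_embed(batch: List[str]) -> List[List[float]]:
    calls.append(len(batch))
    return [[float(len(t)), 1.0] for t in batch]


texts = [f"chunk {i}" for i in range(1256)]
vecs = encode_in_batches(texts, toy_embed, batch_size=64)
print(len(calls), calls[-1], len(vecs))
```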


In [7]:
# ============================================================================
# Encode All Documents with the Fine-Tuned Model
# ============================================================================
# We now call `ut.encode_documents` (from the local `utils` module) to compute
# new embeddings for every document in the source collection with the
# fine-tuned model.
# ----------------------------------------------------------------------------

print("Starting re-embedding of all documents...")

new_embeddings = ut.encode_documents(
    texts=all_documents,
    model=model,
    batch_size=64,
)

print("Re-embedding complete.")
print("Number of embeddings:", len(new_embeddings))


Starting re-embedding of all documents...
Encoded batch 10/20 (docs 576–640)
Encoded batch 20/20 (docs 1216–1256)
Re-embedding complete.
Number of embeddings: 1256
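Before writing to Chroma, it can be worth checking that the new vectors line up with the fetched chunks. The helper below is a hypothetical, stdlib-only sketch: it verifies the row count, the dimension (768 for this model, per the earlier output), and that each vector is unit length. In the notebook it would be called roughly as `check_embeddings(new_embeddings, len(all_ids), model.get_sentence_embedding_dimension())`; the toy vectors below only demonstrate the checks.

```python
import math
from typing import List, Sequence


def check_embeddings(
    embeddings: Sequence[Sequence[float]],
    expected_count: int,
    expected_dim: int,
    tol: float = 1e-3,
) -> List[str]:
    """Return a list of problems; an empty list means the vectors look sane."""
    problems: List[str] = []
    if len(embeddings) != expected_count:
        problems.append(f"count {len(embeddings)} != {expected_count}")
    for i, vec in enumerate(embeddings):
        if len(vec) != expected_dim:
            problems.append(f"row {i}: dim {len(vec)} != {expected_dim}")
            continue
        norm = math.sqrt(sum(x * x for x in vec))
        if abs(norm - 1.0) > tol:
            problems.append(f"row {i}: norm {norm:.4f} is not ~1")
    return problems


good = [[0.6, 0.8], [1.0, 0.0]]
bad = [[3.0, 4.0]]
print(check_embeddings(good, expected_count=2, expected_dim=2))
print(check_embeddings(bad, expected_count=2, expected_dim=2))
```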


In [8]:
# ============================================================================
# Create the Target Collection and Insert Re-Embedded Documents
# ============================================================================
# We create a new Chroma collection named `cnu_rag_mnrl_lab` and insert:
#   - the original IDs
#   - the original documents (chunks)
#   - the original metadatas
#   - the **new** embeddings computed by the fine-tuned model
#
# If a collection with this name already exists, we delete it first to avoid
# mixing old data.
# ----------------------------------------------------------------------------

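# If the target collection already exists, delete it for a clean re-index.
# Depending on the Chroma version, list_collections() yields Collection
# objects or plain name strings, so handle both.
existing_collections = [getattr(c, "name", c) for c in client.list_collections()]
if TARGET_COLLECTION_NAME in existing_collections:
    print(f"Target collection '{TARGET_COLLECTION_NAME}' already exists. Deleting it first...")
    client.delete_collection(TARGET_COLLECTION_NAME)

# Create the new target collection. No metadata is passed, so Chroma uses its
# default "l2" distance; because the embeddings are L2-normalized, L2 ranks
# results in the same order as cosine similarity.
target_collection = client.create_collection(name=TARGET_COLLECTION_NAME)
print("Created target collection:", target_collection.name)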

# Insert the documents, metadatas, and new embeddings into the target collection
target_collection.add(
    ids=all_ids,
    documents=all_documents,
    metadatas=all_metadatas,
    embeddings=new_embeddings,
)

print("Inserted", len(all_ids), "items into target collection:", TARGET_COLLECTION_NAME)


Target collection 'cnu_rag_mnrl_lab' already exists. Deleting it first...
Created target collection: cnu_rag_mnrl_lab


Inserted 1256 items into target collection: cnu_rag_mnrl_lab
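A single `add()` call worked here because 1256 items is small, but Chroma may reject one call that exceeds its maximum batch size. A minimal sketch of batched insertion follows, assuming a batch size of 500 chosen for illustration (not a value taken from Chroma). `add_fn` is meant to be something like `target_collection.add`; the recording fake below stands in for it so the slicing can be checked without a database.

```python
from typing import Any, Callable, Dict, List


def add_in_batches(
    add_fn: Callable[..., None],
    ids: List[str],
    documents: List[str],
    metadatas: List[Dict[str, Any]],
    embeddings: List[List[float]],
    batch_size: int = 500,  # assumed limit for illustration
) -> int:
    """Call add_fn on aligned slices of the four parallel lists.

    Returns the number of add_fn calls made.
    """
    n = len(ids)
    if not (len(documents) == len(metadatas) == len(embeddings) == n):
        raise ValueError("ids, documents, metadatas and embeddings must align")
    calls = 0
    for start in range(0, n, batch_size):
        end = start + batch_size
        add_fn(
            ids=ids[start:end],
            documents=documents[start:end],
            metadatas=metadatas[start:end],
            embeddings=embeddings[start:end],
        )
        calls += 1
    return calls


# Recording fake standing in for target_collection.add
received: List[str] = []


def fake_add(ids, documents, metadatas, embeddings) -> None:
    received.extend(ids)


ids = [f"id{i}" for i in range(1256)]
n_calls = add_in_batches(
    fake_add,
    ids,
    ["doc"] * 1256,
    [{} for _ in range(1256)],
    [[0.0] for _ in range(1256)],
)
print(n_calls, len(received))
```

Slicing all four lists with the same `start:end` keeps each ID paired with its own chunk, metadata, and vector, which is the property the single `add()` call relied on.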


In [9]:
# ============================================================================
# Persist Changes to Disk
# ============================================================================
# `chromadb.PersistentClient` writes every change (including the new collection
# and its embeddings) to PERSIST_DIR automatically. An explicit
# `client.persist()` call was only needed with pre-0.4 Chroma clients and is
# not part of the current client API, so it stays commented out here.
# ----------------------------------------------------------------------------

# Legacy API (Chroma < 0.4); not needed with PersistentClient
# client.persist()
# print("ChromaDB changes persisted to disk at:", PERSIST_DIR)
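When querying the new collection, embed the query with the same fine-tuned model and pass the vector explicitly, e.g. `target_collection.query(query_embeddings=model.encode([question], normalize_embeddings=True).tolist(), n_results=5)`. Passing `query_texts` instead would use the collection's default embedding function, not the fine-tuned model. The stdlib sketch below uses made-up 3-dimensional vectors to illustrate why the default L2 distance is acceptable here: for unit vectors, the squared L2 distance equals 2 − 2·cos, so sorting by L2 gives the same order as sorting by cosine similarity.

```python
import math
from typing import Dict, List, Sequence


def unit(v: Sequence[float]) -> List[float]:
    """Return v scaled to unit L2 norm."""
    n = math.sqrt(sum(x * x for x in v))
    return [x / n for x in v]


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    """Dot product; equals cosine similarity because inputs are unit length."""
    return sum(x * y for x, y in zip(a, b))


def sq_l2(a: Sequence[float], b: Sequence[float]) -> float:
    """Squared Euclidean distance."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


# Made-up document and query vectors, normalized like the real embeddings
docs: Dict[str, List[float]] = {
    "a": unit([1.0, 0.2, 0.0]),
    "b": unit([0.1, 1.0, 0.3]),
    "c": unit([0.7, 0.7, 0.1]),
}
query = unit([0.9, 0.5, 0.0])

by_cosine = sorted(docs, key=lambda k: cosine(query, docs[k]), reverse=True)
by_l2 = sorted(docs, key=lambda k: sq_l2(query, docs[k]))
print(by_cosine, by_l2)

for vec in docs.values():
    # ||q - d||^2 == 2 - 2 * cos(q, d) when q and d are unit vectors
    assert math.isclose(sq_l2(query, vec), 2 - 2 * cosine(query, vec), abs_tol=1e-12)
```

The distances Chroma reports will be L2 values rather than cosine scores, but the ranking of retrieved chunks is the same.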
