In [7]:
!pip install sentence-transformers pymongo tqdm -q

In [8]:
import os
from datetime import datetime, timezone
import sys
import time
import math
from tqdm.auto import tqdm
import numpy as np
from sentence_transformers import SentenceTransformer
import pymongo
import torch

In [None]:
# MongoDB connection. Credentials must come from the environment (MONGO_URI);
# the template below is only a non-working placeholder and must never be edited
# to contain a real user name or password.
_MONGO_URI_TEMPLATE = "mongodb+srv://<USER>:<PASSWORD>@cluster0.dz0qh3l.mongodb.net/?retryWrites=true&w=majority&appName=Cluster0"
MONGO_URI = os.environ.get("MONGO_URI") or _MONGO_URI_TEMPLATE
MONGO_DB = os.environ.get("MONGO_DB") or "medicrew"
MONGO_COLLECTION = os.environ.get("MONGO_COLLECTION") or "paper_chunks"

# Surface the misconfiguration here instead of as an opaque connection error
# in a later cell.
if "<USER>" in MONGO_URI or "<PASSWORD>" in MONGO_URI:
    print(
        "WARNING: MONGO_URI is not set (or still contains the <USER>/<PASSWORD> "
        "placeholders). Export MONGO_URI with real credentials before connecting.",
        file=sys.stderr,
    )

# Model (any sentence-transformers model name; overridable via EMBED_MODEL)
MODEL_NAME = os.environ.get("EMBED_MODEL") or "neuml/pubmedbert-base-embeddings"

# Processing parameters
BATCH_SIZE = 64  # texts per encode() forward pass (a GPU may allow a higher value)
FETCH_BATCH = 500  # how many Mongo docs to fetch per loop iteration
MAX_TO_PROCESS = None  # set e.g. 1000 to test quickly, or None to process all

# Safety / options
NORMALIZE = True  # forwarded to encode() as normalize_embeddings
UPDATE_EMBED_FLAG = True  # set False to only compute embeddings without updating DB

In [14]:
# Run on the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Device: {device}")

Device: cuda


In [17]:
# %%
# Open the MongoDB connection and report how much work is pending
client = pymongo.MongoClient(MONGO_URI)
db = client.get_database(MONGO_DB)
collection = db.get_collection(MONGO_COLLECTION)

# Documents still waiting for an embedding are the ones with embedded=False
to_process_count = collection.count_documents({"embedded": False})
print(f"Documents flagged for embedding (embedded=False): {to_process_count:,}")
if MAX_TO_PROCESS:
    # A truthy MAX_TO_PROCESS caps the run performed by the main loop
    print(f"MAX_TO_PROCESS limit is on: {MAX_TO_PROCESS}")

Documents flagged for embedding (embedded=False): 33,094


In [18]:
# %%
# Load the sentence-transformers model
print(f"Loading embedding model: {MODEL_NAME} ...")
embed_model = SentenceTransformer(MODEL_NAME, device=device)

# Cap the model's max sequence length at 512 tokens (never raise it).
# NOTE(review): 512 is presumably the position limit of the BERT-style backbone;
# confirm before pointing EMBED_MODEL at a different architecture.
# A missing/None max_seq_length is treated as "no cap configured" and set to 512;
# previously min(None, 512) raised a TypeError that was silently swallowed.
try:
    current_max_len = getattr(embed_model, 'max_seq_length', None)
    embed_model.max_seq_length = 512 if current_max_len is None else min(current_max_len, 512)
except (AttributeError, TypeError, ValueError) as exc:
    # Best effort: keep going with the model's own setting, but say so.
    print(f"Warning: could not cap max_seq_length at 512: {exc!r}")
print("Model loaded. Embedding dimension:", embed_model.get_sentence_embedding_dimension())

Loading embedding model: neuml/pubmedbert-base-embeddings ...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


modules.json:   0%|          | 0.00/229 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/123 [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/667 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/438M [00:00<?, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

added_tokens.json:   0%|          | 0.00/74.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

Model loaded. Embedding dimension: 768


In [19]:
# %%
# Helper: embedding function (wraps SentenceTransformer.encode)
def embed_texts(texts, batch_size=BATCH_SIZE, normalize=NORMALIZE):
    """Embed a list of strings with the notebook-level ``embed_model``.

    Args:
        texts: list of str to embed.
        batch_size: number of texts per forward pass (forwarded to ``encode``).
        normalize: forwarded to ``encode`` as ``normalize_embeddings``.

    Returns:
        numpy array of shape (n, dim), one row per input text. Empty input
        yields an array of shape (0, dim) so callers can always ``np.vstack``
        or index the result as 2-D.
    """
    if len(texts) == 0:
        # encode() on an empty list does not give a 2-D result, so build the
        # (0, dim) array explicitly.
        # float32 is assumed to match encode()'s usual output dtype - TODO confirm.
        dim = embed_model.get_sentence_embedding_dimension() or 0
        return np.zeros((0, dim), dtype=np.float32)

    return embed_model.encode(
        texts,
        batch_size=batch_size,
        show_progress_bar=False,
        convert_to_numpy=True,
        normalize_embeddings=normalize,
    )

# Helper: safe conversion to python list for Mongo insertion
def as_list(vec):
    """Convert an embedding vector to a list of native Python values.

    Args:
        vec: a numpy array, any other object exposing ``tolist()`` (e.g. a
            tensor), or an iterable of numbers (list, tuple, generator).

    Returns:
        A Python ``list``. Array-likes are converted with ``tolist()``; a 0-d
        result is wrapped in a one-element list. For other iterables, elements
        that expose ``item()`` (e.g. numpy scalars) are unwrapped to plain
        Python numbers so the result can be stored in MongoDB.
    """
    to_list = getattr(vec, "tolist", None)
    if callable(to_list):
        converted = to_list()
        # 0-d arrays/scalars return a bare number from tolist(); keep the list contract.
        return converted if isinstance(converted, list) else [converted]
    # Plain iterables may still carry numpy scalar elements; unwrap those.
    return [item.item() if hasattr(item, "item") else item for item in vec]

In [20]:
# %%
# Main processing loop - fetch batches from Mongo, embed, update.
#
# Documents are paged in ascending `_id` order (keyset pagination): each batch
# asks for un-embedded docs with `_id` greater than the last one already seen.
# Every document is therefore visited at most once per run, so the loop always
# terminates - including dry runs (UPDATE_EMBED_FLAG=False) and batches whose
# writes fail. Docs whose write failed keep embedded=False and are picked up by
# the next run.
# NOTE(review): `$gt` only matches `_id` values of the same BSON type as the
# cursor position; this assumes all `_id`s in the collection share one type.
processed = 0
failed_writes = 0
last_id = None
start_time = time.time()

limit_total = MAX_TO_PROCESS if MAX_TO_PROCESS else None
progress_total = min(limit_total, to_process_count) if limit_total else to_process_count
fallback_batch = max(1, BATCH_SIZE // 4)  # smaller encode batch used after a failure

with tqdm(total=progress_total, desc="Embedding docs") as pbar:
    while True:
        # Respect MAX_TO_PROCESS by never fetching more than what is left.
        fetch_size = FETCH_BATCH
        if limit_total:
            remaining = limit_total - processed
            if remaining <= 0:
                break
            fetch_size = min(FETCH_BATCH, remaining)

        # Short-lived cursor per page to avoid long-running cursors.
        query = {'embedded': False}
        if last_id is not None:
            query['_id'] = {'$gt': last_id}
        docs = list(
            collection.find(query, {'_id': 1, 'text': 1})
            .sort('_id', pymongo.ASCENDING)
            .limit(fetch_size)
        )
        if not docs:
            break
        last_id = docs[-1]['_id']

        ids = [d['_id'] for d in docs]
        # A missing or null `text` is embedded as the empty string.
        texts = ['' if d.get('text') is None else str(d['text']) for d in docs]

        try:
            embeddings = embed_texts(texts, batch_size=BATCH_SIZE)
        except Exception as e:
            print("Embedding error - retrying in smaller batches:", e)
            if torch.cuda.is_available():
                # Release cached GPU memory before retrying with smaller batches.
                torch.cuda.empty_cache()
            embeddings = np.vstack([
                embed_texts(texts[i:i + fallback_batch], batch_size=fallback_batch)
                for i in range(0, len(texts), fallback_batch)
            ])

        # zip() below would silently drop documents on a length mismatch.
        if len(embeddings) != len(ids):
            raise RuntimeError(
                f"Got {len(embeddings)} embeddings for {len(ids)} documents"
            )

        if UPDATE_EMBED_FLAG:
            now = datetime.now(timezone.utc)
            # Keep (filter, update) pairs so the item-wise fallback does not have
            # to read private attributes of pymongo.UpdateOne.
            updates = [
                (
                    {'_id': doc_id},
                    {'$set': {'embedding': as_list(emb), 'embedded': True, 'embedded_at': now}},
                )
                for doc_id, emb in zip(ids, embeddings)
            ]
            try:
                collection.bulk_write(
                    [pymongo.UpdateOne(flt, upd) for flt, upd in updates],
                    ordered=False,
                )
            except pymongo.errors.PyMongoError as e:
                print("Bulk write error:", e)
                # Fallback: retry item-wise (the $set updates are idempotent).
                for flt, upd in updates:
                    try:
                        collection.update_one(flt, upd)
                    except pymongo.errors.PyMongoError as e2:
                        failed_writes += 1
                        print("Single update failed for", flt, e2)

        processed += len(docs)
        pbar.update(len(docs))

print(f"Done. Processed {processed} docs in {time.time()-start_time:.1f}s")
if failed_writes:
    print(f"{failed_writes} document update(s) failed; re-run the notebook to retry them.")

Embedding docs:   0%|          | 0/33094 [00:00<?, ?it/s]

Done. Processed 33094 docs in 1457.0s


In [21]:
# %%
# Quick sanity check: sample a few embedded docs and show their vector lengths
sample = list(collection.find({'embedded': True}, {'_id':1, 'embedding':1}).limit(5))
print("Sample embedded docs:", len(sample))
# All sampled vectors are expected to share one length (the model's dimension).
print("Embedding lengths in sample:", sorted({len(d.get('embedding') or []) for d in sample}))

# %%
# Save or export: optionally write some stats to a local CSV for analysis.
# NOTE: time_s is measured when this cell runs, so it includes any idle time
# between the main loop finishing and this cell being executed.
try:
    import pandas as pd
except ImportError:
    # pandas is optional here; skip the export but say so instead of hiding it.
    print('pandas is not installed - skipping embed_run_stats.csv export')
else:
    stats = {
        'processed': processed,
        'model': MODEL_NAME,
        'time_s': time.time()-start_time
    }
    try:
        pd.DataFrame([stats]).to_csv('embed_run_stats.csv', index=False)
    except OSError as e:
        print('Could not write embed_run_stats.csv:', e)
    else:
        print('Saved run stats to embed_run_stats.csv')

# %%
# END
print('Notebook finished. Review the collection for "embedded" flags and embeddings length.')


Sample embedded docs: 5
Saved run stats to embed_run_stats.csv
Notebook finished. Review the collection for "embedded" flags and embeddings length.
