In [32]:
# Make sure this notebook is running on the GPU
import torch
from tqdm import tqdm
from model_for_db import DocTower
import chromadb

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")



Using device: cuda


In [37]:
# this tokenizer should just return the length of the document
def simple_tokenizer(text):
    # split on whitespace
    words = text.split()
    len_words = len(words)
    return len_words

In [39]:
# Embed the documents
def embed_docs(model, documents, batch_size, device, collection=None, start_id=0):

    """
    Embed a list of texts using a pre-trained model.
    
    Args:
        model: The pre-trained doc tower model to use for embedding.
        documents: The list of texts to embed.
    """

    model.to(device)
    model.eval()

    total_documents = len(documents)
    # Create progress bar for the total number of batches
    total_batches = (total_documents + batch_size - 1) // batch_size
    progress_bar = tqdm(range(total_batches), desc="Processing batches")

    with torch.no_grad():
        for i in range(0, total_documents, batch_size):

            # Get batch of data
            batch_passages = documents[i:i + batch_size] 
            
            # Calculate lengths for current batch only
            batch_lengths = [simple_tokenizer(doc) for doc in batch_passages]

            doc_embeds = model(batch_passages, batch_lengths)
            batch_embeddings = doc_embeds.cpu().numpy()
            
            batch_ids = [str(start_id + i + j) for j in range(len(batch_passages))]
            
            
            print('Documents embedded. Storing to ChromaDB...')
            # Add batch directly to ChromaDB collection
            collection.add(
                documents=batch_passages,
                ids=batch_ids,
                embeddings=batch_embeddings
            )
            
            # Update progress bar
            progress_bar.update()
            progress_bar.set_postfix({"Processed": f"{i+batch_size}/{len(documents)} passages"})

    print('Documents embedded and stored in ChromaDB.')
    return 
        
    


In [40]:
# Load the list of documents
# Prepare the document dataset
# Get the dataset from cocoritz
# Combine the positive and negative passages into a single documents dataset
from datasets import load_dataset
df_sn = load_dataset("cocoritzy/week_2_triplet_dataset_soft_negatives")
df_sn = df_sn["train"].to_pandas()
df_sn.head()

# Create a list of documents from all values in the positive and negative columns 
print(len(df_sn['positive_passage'].tolist()))
print(len(df_sn['negative_passage'].tolist()))
all_passages = df_sn['positive_passage'].tolist() + df_sn['negative_passage'].tolist()
print(len(all_passages))

79704
79704
159408


### USAGE

In [41]:
# Load the complete statedict
state_dict = torch.load("two_tower_model_GRU_padding.pt", map_location=device)

# Extract only the DocTower parameters
doc_tower_state = state_dict['docTower']

model = DocTower()
model.load_state_dict(doc_tower_state)
model.eval()

# Initialize ChromaDB client
client = chromadb.PersistentClient(path="./chroma_db")
# Create or access a collection
collection_name = 'marco_sn_documents'
collection = client.get_or_create_collection(name=collection_name)

# Determine how many documents are already in the collection
existing_docs_count = collection.count()
print(f"Collection already contains {existing_docs_count} documents")


Collection already contains 0 documents


In [42]:
# Embed only new documents or all (depending on your use case)
embed_docs(model, all_passages, batch_size=256, device=device, collection=collection, start_id=existing_docs_count)



TypeError: 'descending' is an invalid keyword argument for sort()