In [1]:
from pprint import pprint
from typing import List
from FlagEmbedding import BGEM3FlagModel
from llama_index.core.base.embeddings.base import Embedding
from llama_index.core.embeddings import BaseEmbedding
import numpy as np
model = BGEM3FlagModel("BAAI/bge-m3", use_fp16=True)

model.encode("Hello World", batch_size=12, max_length=8192)["dense_vecs"]

  from .autonotebook import tqdm as notebook_tqdm
Fetching 30 files: 100%|██████████| 30/30 [00:00<00:00, 23436.23it/s]
You're using a XLMRobertaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


array([-0.05627,  0.02858, -0.01721, ...,  0.02463, -0.0355 ,  0.0144 ],
      dtype=float16)

In [2]:
from llama_index.core.node_parser.text import SentenceSplitter
text_splitter = SentenceSplitter(chunk_size=128, chunk_overlap=16)

In [3]:
import datasets
from tqdm import tqdm

# https://huggingface.co/datasets/hotpotqa/hotpot_qa?row=16
# https://arxiv.org/pdf/1606.05250
dataset = datasets.load_dataset("rajpurkar/squad")

# Split the context into chunks
full_chunks = []
for i in tqdm(range(len(dataset["train"]))):
    row = dataset["train"][i]
    chunks = text_splitter.split_text(row["context"])
    full_chunks.extend([
        {
            "id": f"{row['id']}_{i}",
            "chunk": chunks[i],
        }
        for i in range(len(chunks))
    ])

print("Number of chunks:", len(full_chunks))
print("Original length:", len(dataset["train"]))
full_chunks[:10]


100%|██████████| 87599/87599 [00:19<00:00, 4510.40it/s]

Number of chunks: 158985
Original length: 87599





[{'id': '5733be284776f41900661182_0',
  'chunk': 'Architecturally, the school has a Catholic character. Atop the Main Building\'s gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend "Venite Ad Me Omnes". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection.'},
 {'id': '5733be284776f41900661182_1',
  'chunk': 'It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary.'},
 {'id': '5733be284776f4190066117f_0',
  'chunk': 'Architecturally, the school has a Catholic character. Atop the Main Building\'s gold dome is a golden statue of the Virgin Mary. Immediately in fro

In [4]:
# Get the embeddings
embeddings = model.encode([chunk["chunk"] for chunk in full_chunks], batch_size=12, max_length=8192)["dense_vecs"]

print(len(embeddings))
embeddings[:10]

Inference Embeddings: 100%|██████████| 13249/13249 [19:13<00:00, 11.48it/s]


158985


array([[ 0.03357 , -0.001778, -0.0349  , ...,  0.03693 , -0.0222  ,
        -0.000498],
       [ 0.007965,  0.01884 , -0.0699  , ...,  0.03506 , -0.03018 ,
        -0.02544 ],
       [ 0.03357 , -0.001778, -0.0349  , ...,  0.03693 , -0.0222  ,
        -0.000498],
       ...,
       [ 0.007965,  0.01884 , -0.0699  , ...,  0.03506 , -0.03018 ,
        -0.02544 ],
       [ 0.03357 , -0.001778, -0.0349  , ...,  0.03693 , -0.0222  ,
        -0.000498],
       [ 0.007965,  0.01884 , -0.0699  , ...,  0.03506 , -0.03018 ,
        -0.02544 ]], dtype=float16)

In [5]:
# Save embeddings to file as a checkpoint
np.save("squad_embeddings.npy", embeddings)


In [6]:
from chromadb import Documents, EmbeddingFunction, Embeddings
class MyEmbeddingFunction(EmbeddingFunction):
    def __call__(self, input: Documents) -> Embeddings:
        embeddings = model.encode(input, batch_size=12, max_length=8192)["dense_vecs"]
        return embeddings


In [7]:
from chromadb import Client

chroma_client = Client()
collection = chroma_client.create_collection(name="squad2", embedding_function=MyEmbeddingFunction())
batch_size = 1000
for i in tqdm(range(0, len(embeddings), batch_size)):
    end_idx = min(i + batch_size, len(embeddings))  # Ensure we don't go past the end
    collection.add(
        embeddings=embeddings[i:end_idx],
        documents=[chunk["chunk"] for chunk in full_chunks[i:end_idx]],
        ids=[chunk["id"] for chunk in full_chunks[i:end_idx]]
    )

100%|██████████| 159/159 [01:24<00:00,  1.89it/s]


In [8]:
# Search collection
collection.query(
    query_texts=["What is the capital of France?"],
    n_results=10
)

{'ids': [['56dcddb066d3e219004dab4b_0',
   '56dcddb066d3e219004dab47_0',
   '56dcddb066d3e219004dab4a_0',
   '56dcddb066d3e219004dab49_0',
   '56dcddb066d3e219004dab48_0',
   '570d7002b3d812140066d934_1',
   '570d7002b3d812140066d931_1',
   '570d7002b3d812140066d932_1',
   '570d7002b3d812140066d935_1',
   '570d7002b3d812140066d933_1']],
 'embeddings': None,
 'documents': [["The area north of the Congo River came under French sovereignty in 1880 as a result of Pierre de Brazza's treaty with Makoko of the Bateke. This Congo Colony became known first as French Congo, then as Middle Congo in 1903. In 1908, France organized French Equatorial Africa (AEF), comprising Middle Congo, Gabon, Chad, and Oubangui-Chari (the modern Central African Republic). The French designated Brazzaville as the federal capital. Economic development during the first 50 years of colonial rule in Congo centered on natural-resource extraction.",
   "The area north of the Congo River came under French sovereignty in 

In [9]:
# Calculate precision for one example
# By checking document ids
precision = 0
i = 10
row = dataset["train"][i]
query = row["question"]
results = collection.query(query_texts=[query], n_results=10)
ids = results["ids"][0]
ground_truth = row["id"]
original_ids = [id.split("_")[0] for id in ids]
precision = sum([1 for id in original_ids if id == ground_truth]) / len(original_ids)
pprint(query)
pprint(results)
pprint(ground_truth)
pprint(original_ids)
pprint(precision)
# This doesn't work because there are duplicate contexts in the dataset


'Where is the headquarters of the Congregation of the Holy Cross?'
{'data': None,
 'distances': [[0.8609922528266907,
                0.8609922528266907,
                0.8609922528266907,
                0.8609922528266907,
                0.8609922528266907,
                0.9223301410675049,
                0.9223301410675049,
                0.9223301410675049,
                0.9223301410675049,
                0.9223301410675049]],
 'documents': [['The university is the major seat of the Congregation of Holy '
                'Cross (albeit not its official headquarters, which are in '
                'Rome). Its main seminary, Moreau Seminary, is located on the '
                'campus across St. Joseph lake from the Main Building. Old '
                'College, the oldest building on campus and located near the '
                'shore of St. Mary lake, houses undergraduate seminarians. '
                'Retired priests and brothers reside in Fatima House (a former '
     

In [10]:
# Let's check if the substring is in the document
i = 100
row = dataset["train"][i]
query = row["question"]
answer = row["answers"]["text"][0]
results = collection.query(query_texts=[query], n_results=10)
precision = sum([1 for result in results["documents"][0] if answer in result]) / len(results["documents"][0])
pprint(query)
pprint("answer: " + answer)
pprint(results)
pprint("precision: " + str(precision))


'In what year did the team lead by Knute Rockne win the Rose Bowl?'
'answer: 1925'
{'data': None,
 'distances': [[1.2871770858764648,
                1.2871770858764648,
                1.2871770858764648,
                1.2871770858764648,
                1.2871770858764648,
                1.2901089191436768,
                1.2901089191436768,
                1.2901089191436768,
                1.2901089191436768,
                1.2923610210418701]],
 'documents': [['By the early 1980s, Queen were one of the biggest stadium '
                "rock bands in the world. Their performance at 1985's Live Aid "
                'is ranked among the greatest in rock history by various music '
                'publications, with a 2005 industry poll ranking it the best. '
                'In 1991, Mercury died of bronchopneumonia, a complication of '
                'AIDS, and Deacon retired in 1997. Since then, May and Taylor '
                'have occasionally performed together, includ

In [11]:
precisions = [] 
for i in tqdm(range(len(dataset["train"]))):
    row = dataset["train"][i]
    query = row["question"]
    answer = row["answers"]["text"][0]
    results = collection.query(query_texts=[query], n_results=10)
    precision = sum([1 for result in results["documents"][0] if answer in result]) / len(results["documents"][0])
    precisions.append(precision)

np.mean(precisions)

 15%|█▍        | 12769/87599 [19:18<1:53:09, 11.02it/s]


KeyboardInterrupt: 

In [47]:
# Super slow let's try a few things:
# 1. query in batches
# 2. Update the precision count via a set
# Batch process queries
batch_size = 32  # Adjust based on your memory constraints
precisions = []

for i in tqdm(range(0, len(dataset["train"]), batch_size)):
    # Get batch of questions and answers
    batch_slice = slice(i, min(i + batch_size, len(dataset["train"])))
    batch_rows = dataset["train"][batch_slice]
    batch_queries = batch_rows["question"]
    batch_answers = [ans["text"][0] for ans in batch_rows["answers"]]
    
    # Batch query ChromaDB
    results = collection.query(
        query_texts=batch_queries,
        n_results=10
    )
    
    # Calculate precision for each query in batch
    for query_idx, (answer, documents) in enumerate(zip(batch_answers, results["documents"])):
        # Convert to set for faster lookup
        answer_matches = sum(1 for doc in documents if answer in doc)
        precision = answer_matches / len(documents)
        precisions.append(precision)

mean_precision = np.mean(precisions)

100%|██████████| 2738/2738 [06:36<00:00,  6.90it/s]


In [48]:
mean_precision
# Answer they got was 0.1246

np.float64(0.23560885398235137)

In [49]:
batch_size = 32  # Adjust based on your memory constraints
precisions = []

split = "validation"

for i in tqdm(range(0, len(dataset[split]), batch_size)):
    # Get batch of questions and answers
    batch_slice = slice(i, min(i + batch_size, len(dataset[split])))
    batch_rows = dataset[split][batch_slice]
    batch_queries = batch_rows["question"]
    batch_answers = [ans["text"][0] for ans in batch_rows["answers"]]
    
    # Batch query ChromaDB
    results = collection.query(
        query_texts=batch_queries,
        n_results=10
    )
    
    # Calculate precision for each query in batch
    for query_idx, (answer, documents) in enumerate(zip(batch_answers, results["documents"])):
        # Convert to set for faster lookup
        answer_matches = sum(1 for doc in documents if answer in doc)
        precision = answer_matches / len(documents)
        precisions.append(precision)

mean_precision = np.mean(precisions)
mean_precision

100%|██████████| 331/331 [00:44<00:00,  7.43it/s]


In [50]:
mean_precision

np.float64(0.026887417218543045)

In [52]:
from datasets import concatenate_datasets
batch_size = 32  # Adjust based on your memory constraints
precisions = []

joined_dataset = concatenate_datasets([dataset["train"], dataset["validation"]])

for i in tqdm(range(0, len(joined_dataset), batch_size)):
    # Get batch of questions and answers
    batch_slice = slice(i, min(i + batch_size, len(joined_dataset)))
    batch_rows = joined_dataset[batch_slice]
    batch_queries = batch_rows["question"]
    batch_answers = [ans["text"][0] for ans in batch_rows["answers"]]
    
    # Batch query ChromaDB
    results = collection.query(
        query_texts=batch_queries,
        n_results=10
    )
    
    # Calculate precision for each query in batch
    for query_idx, (answer, documents) in enumerate(zip(batch_answers, results["documents"])):
        # Convert to set for faster lookup
        answer_matches = sum(1 for doc in documents if answer in doc)
        precision = answer_matches / len(documents)
        precisions.append(precision)

mean_precision = np.mean(precisions)
mean_precision
# They got 0.1246

100%|██████████| 3068/3068 [07:11<00:00,  7.11it/s]


np.float64(0.2131344925587507)

In [61]:
# Let's try whether the document_id matches

from datasets import concatenate_datasets
batch_size = 32  # Adjust based on your memory constraints
precisions = []

joined_dataset = dataset["train"]
for i in tqdm(range(0, len(joined_dataset), batch_size)):
    # Get batch of questions and answers
    batch_slice = slice(i, min(i + batch_size, len(joined_dataset)))
    batch_rows = joined_dataset[batch_slice]
    print(batch_rows[0])
    batch_queries = batch_rows["question"]
    batch_ground_truth_ids = [row["id"] for row in batch_rows]
    
    # Batch query ChromaDB
    results = collection.query(
        query_texts=batch_queries,
        n_results=10
    )
    
    # Calculate precision for each query in batch
    for query_idx, (ground_truth_id, ids) in enumerate(zip(batch_ground_truth_ids, results["ids"])):
        original_ids = [id.split("_")[0] for id in ids]
        precision = sum([1 for id in original_ids if id == ground_truth_id]) / len(original_ids)
        precisions.append(precision)

mean_precision = np.mean(precisions)
mean_precision

  0%|          | 0/2738 [00:00<?, ?it/s]


KeyError: 0