In [3]:
from pprint import pprint
from typing import List
from FlagEmbedding import BGEM3FlagModel
from llama_index.core.base.embeddings.base import Embedding
from llama_index.core.embeddings import BaseEmbedding
import numpy as np
model = BGEM3FlagModel("BAAI/bge-m3", use_fp16=True)

model.encode("Hello World", batch_size=12, max_length=8192)["dense_vecs"]

Fetching 30 files: 100%|██████████| 30/30 [00:00<00:00, 15992.52it/s]
You're using a XLMRobertaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


array([-0.05627,  0.02858, -0.01721, ...,  0.02463, -0.0355 ,  0.0144 ],
      dtype=float16)

In [4]:
from llama_index.core.node_parser.text import SentenceSplitter
text_splitter = SentenceSplitter(chunk_size=128, chunk_overlap=0)

In [8]:
import datasets
from tqdm import tqdm

# https://huggingface.co/datasets/hotpotqa/hotpot_qa?row=16
# https://arxiv.org/pdf/1606.05250
dataset = datasets.load_dataset("rajpurkar/squad")

# Split the context into chunks
full_chunks = []
for i in tqdm(range(len(dataset["train"]))):
    row = dataset["train"][i]
    chunks = text_splitter.split_text(row["context"])
    full_chunks.extend([
        {
            "id": f"{row['id']}_{i}",
            "chunk": chunks[i],
        }
        for i in range(len(chunks))
    ])

print("Number of chunks:", len(full_chunks))
print("Original length:", len(dataset["train"]))
full_chunks[:10]


100%|██████████| 87599/87599 [00:20<00:00, 4358.91it/s]

Number of chunks: 158543
Original length: 87599





[{'id': '5733be284776f41900661182_0',
  'chunk': 'Architecturally, the school has a Catholic character. Atop the Main Building\'s gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend "Venite Ad Me Omnes". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection.'},
 {'id': '5733be284776f41900661182_1',
  'chunk': 'It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary.'},
 {'id': '5733be284776f4190066117f_0',
  'chunk': 'Architecturally, the school has a Catholic character. Atop the Main Building\'s gold dome is a golden statue of the Virgin Mary. Immediately in fro

In [10]:
# Get the embeddings
embeddings = model.encode([chunk["chunk"] for chunk in full_chunks], batch_size=12, max_length=8192)["dense_vecs"]

print(len(embeddings))
embeddings[:10]

Inference Embeddings: 100%|██████████| 13212/13212 [22:27<00:00,  9.81it/s]  


158543


array([[ 0.03357 , -0.001778, -0.0349  , ...,  0.03693 , -0.0222  ,
        -0.000498],
       [ 0.007965,  0.01884 , -0.0699  , ...,  0.03506 , -0.03018 ,
        -0.02544 ],
       [ 0.03357 , -0.001778, -0.0349  , ...,  0.03693 , -0.0222  ,
        -0.000498],
       ...,
       [ 0.007965,  0.01884 , -0.0699  , ...,  0.03506 , -0.03018 ,
        -0.02544 ],
       [ 0.03357 , -0.001778, -0.0349  , ...,  0.03693 , -0.0222  ,
        -0.000498],
       [ 0.007965,  0.01884 , -0.0699  , ...,  0.03506 , -0.03018 ,
        -0.02544 ]], dtype=float16)

In [11]:
# Save embeddings to file as a checkpoint
np.save("squad_embeddings.npy", embeddings)


In [17]:
from chromadb import Documents, EmbeddingFunction, Embeddings
class MyEmbeddingFunction(EmbeddingFunction):
    def __call__(self, input: Documents) -> Embeddings:
        embeddings = model.encode(input, batch_size=12, max_length=8192)["dense_vecs"]
        return embeddings


In [18]:
from chromadb import Client

chroma_client = Client()
collection = chroma_client.create_collection(name="squad2", embedding_function=MyEmbeddingFunction())
batch_size = 1000
for i in tqdm(range(0, len(embeddings), batch_size)):
    end_idx = min(i + batch_size, len(embeddings))  # Ensure we don't go past the end
    collection.add(
        embeddings=embeddings[i:end_idx],
        documents=[chunk["chunk"] for chunk in full_chunks[i:end_idx]],
        ids=[chunk["id"] for chunk in full_chunks[i:end_idx]]
    )

UniqueConstraintError: Collection squad1 already exists

In [16]:
# Search collection
collection.query(
    query_texts=["What is the capital of France?"],
    n_results=10
)

/Users/aidand/.cache/chroma/onnx_models/all-MiniLM-L6-v2/onnx.tar.gz:   4%|▍         | 3.56M/79.3M [00:07<02:42, 489kiB/s] 


KeyboardInterrupt: 