In [69]:
from FlagEmbedding import BGEM3FlagModel
from tqdm.notebook import tqdm
from dotenv import load_dotenv
import os
# Pull settings from a local .env file into the process environment; the
# WEAVIATE_* variables read via os.getenv further down are expected to come from here.
load_dotenv()

# Instantiate the BGE-M3 embedding model from the "BAAI/bge-m3" checkpoint.
# use_fp16=True presumably selects half-precision inference (faster, slightly
# less precise) -- confirm against the FlagEmbedding documentation.
model = BGEM3FlagModel("BAAI/bge-m3", use_fp16=True)

Fetching 30 files:   0%|          | 0/30 [00:00<?, ?it/s]

In [70]:
import weaviate
import weaviate.classes as wvc
# Open a TLS-secured connection (HTTP + gRPC) to the Weaviate instance.
# Every endpoint setting and the API key are read from environment variables;
# an unset *_PORT variable makes int(None) raise TypeError right here.
client = weaviate.connect_to_custom(
    http_host=os.getenv("WEAVIATE_HTTP_HOST"),
    http_port=int(os.getenv("WEAVIATE_HTTP_PORT")),
    http_secure=True,
    grpc_host=os.getenv("WEAVIATE_GRPC_HOST"),
    grpc_port=int(os.getenv("WEAVIATE_GRPC_PORT")),
    grpc_secure=True,
    auth_credentials=weaviate.auth.AuthApiKey(
        api_key=os.getenv("WEAVIATE_API_KEY"),
    ),
)

In [74]:
# DESTRUCTIVE: removes the "hc4_filtered_bge_m3" collection from the connected
# Weaviate instance so the next cell can recreate it with a fresh schema.
# Do not re-run casually -- presumably all stored objects are lost with it.
client.collections.delete("hc4_filtered_bge_m3")

In [75]:
# Create the "hc4_filtered_bge_m3" collection.
# Any failure (e.g. the collection already exists) is printed rather than
# raised, in which case `documents` is left unassigned by this cell.
try:
    documents = client.collections.create(
        name="hc4_filtered_bge_m3",
        # Two named vectors configured with the "none" vectorizer, i.e. the
        # dense embeddings are supplied by this notebook at insert time.
        vectorizer_config=[
            wvc.config.Configure.NamedVectors.none(name=vector_name)
            for vector_name in ("title_dense", "text_dense")
        ],
        # Sparse and ColBERT outputs are stored as BLOB properties; the
        # remaining fields are the document id and its raw text metadata.
        properties=[
            wvc.config.Property(name=prop_name, data_type=prop_type)
            for prop_name, prop_type in (
                ("doc_id", wvc.config.DataType.UUID),
                ("title_sparse", wvc.config.DataType.BLOB),
                ("document_sparse", wvc.config.DataType.BLOB),
                ("title_colbert", wvc.config.DataType.BLOB),
                ("document_colbert", wvc.config.DataType.BLOB),
                ("title", wvc.config.DataType.TEXT),
                ("text", wvc.config.DataType.TEXT),
                ("url", wvc.config.DataType.TEXT),
            )
        ],
    )
except Exception as err:
    print(err)
    

In [55]:
import pandas as pd
from explore.funcs import load_datasets

# Load the datasets for the three language codes; the result is used as a
# mapping whose values expose docs_iter().
datasets = load_datasets(["zh", "ru", "fa"])

# Build one DataFrame per dataset and stack them into a single frame.
# No ignore_index is passed, so index labels repeat across datasets;
# positional access later requires a reset_index first.
docs = pd.concat(
    [pd.DataFrame(corpus.docs_iter()) for corpus in datasets.values()]
)

# Total number of documents across all loaded datasets.
len(docs)

1876367

In [6]:
import base64
import pickle  # imported here so this cell works on a fresh kernel (Run All)


def to_blob(obj):
    """Serialize ``obj`` into a base64 text string suitable for a BLOB property.

    The object is pickled and the resulting bytes are base64-encoded and
    decoded to a ``str``.

    Args:
        obj: Any picklable Python object (e.g. a dict of lexical weights or
            an array of ColBERT vectors).

    Returns:
        str: The base64 text of the pickled object. Recover the object with
        ``pickle.loads(base64.b64decode(blob))``.

    Raises:
        pickle.PicklingError (or TypeError/AttributeError): if ``obj`` cannot
        be pickled.

    Note:
        Unpickling executes arbitrary code; only decode blobs that were
        written by trusted code.
    """
    return base64.b64encode(pickle.dumps(obj)).decode("utf-8")

In [56]:
import pickle

# Split the corpus into encoding batches of 10,000 rows each.
batches = [docs[i:i+10000] for i in range(0, len(docs), 10000)]

# NOTE(review): this targets "neuclir_bge_m3", not the "hc4_filtered_bge_m3"
# collection created earlier in the notebook -- confirm this is intended.
zh = client.collections.get("neuclir_bge_m3")

# Running total of objects the server rejected, across all encoding batches.
failed_total = 0

# The context manager closes the progress bar even if the run is interrupted.
with tqdm(total=len(docs)) as outer_progress:
    for batch in batches:
        # Re-index to 0..len(batch)-1 so row.Index lines up with the position
        # of each row's embedding in the model outputs below.
        batch = batch.reset_index(drop=True)

        # Only titles are embedded; the document text is stored as a plain
        # property and no text_* vectors or blobs are written by this loop.
        title_embeddings = model.encode(
            batch["title"].to_list(),
            return_dense=True,
            return_sparse=True,
            return_colbert_vecs=True,
        )
        title_sparse_blobs = [to_blob(x) for x in title_embeddings["lexical_weights"]]
        title_colbert_blobs = [to_blob(x) for x in title_embeddings["colbert_vecs"]]

        # Send objects in requests of 200 with up to 10 requests in flight.
        with zh.batch.fixed_size(batch_size=200, concurrent_requests=10) as b:
            for row in batch.itertuples(index=True):
                b.add_object(
                    properties={
                        "doc_id": row.doc_id,
                        "title": row.title,
                        "text": row.text,
                        "url": row.url,
                        "title_sparse": title_sparse_blobs[row.Index],
                        "title_colbert": title_colbert_blobs[row.Index],
                    },
                    vector={
                        "title_dense": title_embeddings["dense_vecs"][row.Index],
                    },
                    # The document id doubles as the object UUID.
                    uuid=row.doc_id,
                )
                outer_progress.update(1)

        # Batch errors are not raised by add_object; they are collected on the
        # batch wrapper and reset when the next batch context starts, so they
        # have to be inspected after every context exit.
        failed = zh.batch.failed_objects
        if failed:
            failed_total += len(failed)
            print(f"{len(failed)} objects failed to import in this batch; first failure: {failed[0]}")

if failed_total:
    print(f"Import finished with {failed_total} failed objects in total")

  0%|          | 0/1876367 [00:00<?, ?it/s]



Inference Embeddings:   0%|          | 0/40 [00:00<?, ?it/s][A[A

Inference Embeddings:   2%|▎         | 1/40 [00:01<01:09,  1.77s/it][A[A

Inference Embeddings:   5%|▌         | 2/40 [00:02<00:45,  1.20s/it][A[A

Inference Embeddings:   8%|▊         | 3/40 [00:03<00:37,  1.01s/it][A[A

Inference Embeddings:  10%|█         | 4/40 [00:04<00:32,  1.11it/s][A[A

Inference Embeddings:  12%|█▎        | 5/40 [00:04<00:29,  1.20it/s][A[A

Inference Embeddings:  15%|█▌        | 6/40 [00:05<00:26,  1.27it/s][A[A

Inference Embeddings:  18%|█▊        | 7/40 [00:06<00:25,  1.28it/s][A[A

Inference Embeddings:  20%|██        | 8/40 [00:06<00:23,  1.34it/s][A[A

Inference Embeddings:  22%|██▎       | 9/40 [00:07<00:22,  1.38it/s][A[A

Inference Embeddings:  25%|██▌       | 10/40 [00:08<00:21,  1.42it/s][A[A

Inference Embeddings:  28%|██▊       | 11/40 [00:08<00:20,  1.44it/s][A[A

Inference Embeddings:  30%|███       | 12/40 [00:09<00:19,  1.45it/s][A[A

Inference Embed

KeyboardInterrupt: 

In [65]:

# Sanity check: fetch the "neuclir_bge_m3" collection handle again and ask the
# server how many objects it currently holds.
zh = client.collections.get("neuclir_bge_m3")
# total_count=True requests the overall object count in the aggregation result.
aggregation = zh.aggregate.over_all(total_count=True)
print(aggregation.total_count)

2194
