In [27]:
import getpass
import json
import os
from typing import Dict, Iterable, List

import pandas as pd
from datasets import Dataset, load_dataset
from dotenv import load_dotenv
from FlagEmbedding import BGEM3FlagModel
from qdrant_client import QdrantClient, models
from qdrant_sparse_tools import convert_sparse_vector
from tokenizers import Tokenizer
from tqdm.auto import tqdm

# Load variables from a local .env file (if present) into the process environment;
# the Qdrant URL and API key are read from the environment further down.
load_dotenv()

True

In [21]:
load_dotenv()

# Dataset name used for the local qrels / queries paths under ../data.
canonical_dataset_name = "scifact"
# Hugging Face dataset (under the `nirantk/` namespace) that already carries sparse vectors.
dataset_name = "scifact-bge-m3-sparse-vectors"
# Sparse-vector column in that dataset; also used as the Qdrant sparse vector name.
col_name = "bge_m3_sparse_vector"
collection_name = "{}-{}".format(dataset_name, col_name)

In [4]:
# Connection settings come from environment variables; either is None when unset.
client = QdrantClient(
    url=os.environ.get("QDRANT_URL"),
    api_key=os.environ.get("QDRANT_API_KEY"),
)

In [5]:
def is_empty(client: QdrantClient, collection_name: str) -> bool:
    """Return True when the named collection reports a point count of exactly zero.

    Any error raised by ``client.get_collection`` (e.g. for a missing collection)
    propagates unchanged. A ``points_count`` of None compares unequal to 0, so it
    yields False.
    """
    info = client.get_collection(collection_name)
    return info.points_count == 0

In [6]:
# WARNING: destructive — drops the collection unconditionally so that the
# create/upload cells below rebuild it from scratch.
client.delete_collection(collection_name)

False

In [7]:
# Download the corpus split of the pre-embedded dataset from the Hugging Face Hub.
ds = load_dataset(path="nirantk/" + dataset_name, split="corpus")

Downloading readme:   0%|          | 0.00/432 [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/13.1M [00:00<?, ?B/s]

Generating corpus split:   0%|          | 0/5183 [00:00<?, ? examples/s]

In [8]:
# Inspect the schema; the sparse-vector column is stored as a JSON string (decoded below).
ds

Dataset({
    features: ['_id', 'title', 'text', 'bge_m3_sparse_vector'],
    num_rows: 5183
})

In [13]:
# Materialise the HF Dataset as a list of row dicts. Guarded so that re-running
# this cell does not fail once `ds` is already a list.
if isinstance(ds, Dataset):
    ds = ds.to_list()

In [14]:
# Decode the JSON-encoded sparse vectors in place. Values that are already decoded
# are skipped, so re-running this cell does not call json.loads on a dict.
for element in tqdm(ds):
    if isinstance(element[col_name], str):
        element[col_name] = json.loads(element[col_name])

  0%|          | 0/5183 [00:00<?, ?it/s]

In [19]:
def read_data(dataset_name: str):
    """Download the ``corpus`` split of ``nirantk/<dataset_name>`` as a list of row dicts.

    Prints the dataset's feature schema as a side effect.
    """
    corpus = load_dataset(f"nirantk/{dataset_name}", split="corpus")
    print("Columns: ", corpus.features)
    return corpus.to_list()

def to_points(ds: Iterable[Dict]) -> Iterable[models.PointStruct]:
    """Yield one Qdrant point per corpus row.

    Each row must provide ``_id`` (a numeric string, used as the integer point id),
    ``title``, ``text`` and the sparse-vector column named by the module-level
    ``col_name``. That column may hold either the raw JSON string (as returned by
    ``read_data``) or an already-decoded value, so rows parsed in place earlier
    can be passed as well.

    Args:
        ds: Iterable of row dicts, e.g. the list returned by ``read_data``.

    Yields:
        ``models.PointStruct`` with the sparse vector stored under ``col_name``
        and text, title and the original string id as payload.
    """
    for element in tqdm(ds):
        raw_vector = element[col_name]
        # Accept both the JSON-encoded string and an already-decoded value.
        sparse = json.loads(raw_vector) if isinstance(raw_vector, str) else raw_vector
        yield models.PointStruct(
            id=int(element["_id"]),
            vector={col_name: convert_sparse_vector(sparse)},
            payload={
                "text": element["text"],
                "title": element["title"],
                "id": element["_id"],
            },
        )


# Create a sparse-only collection (empty dense vectors_config) unless one with
# this name already exists; the sparse index is configured with on_disk=False.
if not client.collection_exists(collection_name=collection_name):
    client.create_collection(
        collection_name,
        sparse_vectors_config={
            col_name: models.SparseVectorParams(
                index=models.SparseIndexParams(on_disk=False),
            ),
        },
        vectors_config={},
    )


def batch_iterator(iterable, batch_size=128):
    """
    Iterates over an iterable in batches of a given size.

    Sliceable sequences (list, tuple, str, range, ...) are cut by slicing, so each
    batch has the same type as the input (e.g. a ``range`` yields ``range``
    batches). Any other iterable (generators, iterators, sets, ...) is consumed
    lazily and batched into lists.

    Args:
        iterable: An iterable object.
        batch_size: The size of each batch; must be >= 1.

    Yields:
        A batch of at most ``batch_size`` items from the iterable; only the last
        batch may be shorter.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1 (raised on first iteration).
    """
    from collections.abc import Mapping
    from itertools import islice

    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    is_sliceable = (
        hasattr(iterable, "__len__")
        and hasattr(iterable, "__getitem__")
        and not isinstance(iterable, Mapping)
    )
    if is_sliceable:
        length = len(iterable)
        for start in range(0, length, batch_size):
            # Slicing clamps the end index to the sequence length on its own.
            yield iterable[start : start + batch_size]
        return

    # Generic path: pull up to batch_size items at a time from a single iterator.
    iterator = iter(iterable)
    while True:
        batch = list(islice(iterator, batch_size))
        if not batch:
            return
        yield batch


# Example usage: a batch size larger than the input yields one batch holding
# everything; for a range input that batch is itself a range (printed as range(0, 10)).

for batch in batch_iterator(range(10), 12):
    print(batch)

# Run ONCE to upload data, only when the collection is empty.
# `ds` is rebound here to freshly downloaded rows whose sparse vectors are still
# JSON strings, which to_points decodes.
if is_empty(client, collection_name):
    ds = read_data(dataset_name)
    points = to_points(ds)  # lazy generator, iterated by upload_points below
    print("Uploading data")
    client.upload_points(
        collection_name=collection_name,
        points=points,
    )

range(0, 10)
Columns:  {'_id': Value(dtype='string', id=None), 'title': Value(dtype='string', id=None), 'text': Value(dtype='string', id=None), 'bge_m3_sparse_vector': Value(dtype='string', id=None)}
Uploading data


  0%|          | 0/5183 [00:00<?, ?it/s]

## Queries

In [22]:
# Test-split qrels; query ids are cast to int so they can be matched against int query ids below.
test = pd.read_csv(
    f"../data/{canonical_dataset_name}/qrels/test.tsv", sep="\t"
).astype({"query-id": int})

In [23]:
# Show the relevance judgments for a single example query. Only the last
# expression of a cell is displayed, so this is the one result shown here.
test.loc[test["query-id"] == 873]

Unnamed: 0,query-id,corpus-id,score
213,873,1180972,1
214,873,19307912,1
215,873,27393799,1
216,873,29025270,1
217,873,3315558,1


In [24]:
with open(f"../data/{canonical_dataset_name}/queries.jsonl", encoding="utf-8") as f:
    queries = [json.loads(line) for line in f]

# Only keep the test set queries. Build the id set once so each membership check
# is O(1), instead of rebuilding and scanning a list for every query.
test_query_ids = set(test["query-id"].tolist())
queries = [q for q in queries if int(q["_id"]) in test_query_ids]
len(queries)

300

In [29]:
# Peek at one query record: its "_id", the query "text" and "metadata".
queries[0]

{'_id': '1',
 'text': '0-dimensional biomaterials show inductive properties.',
 'metadata': {}}

## Create query vectors

In [28]:
# use_fp16=True (half precision) speeds up encoding at the cost of a slight
# drop in embedding quality.
model = BGEM3FlagModel("BAAI/bge-m3", use_fp16=True)

def get_sparse_vector(batch: List[str]):
    """Encode ``batch`` with the global BGE-M3 ``model`` and return only the sparse output.

    Dense and ColBERT outputs are switched off; the return value is the model's
    ``"lexical_weights"`` entry, which is consumed below as one token-id -> weight
    mapping per input text.
    """
    encoded = model.encode(
        batch,
        return_dense=False,
        return_sparse=True,
        return_colbert_vecs=False,
    )
    return encoded["lexical_weights"]

Fetching 30 files:   0%|          | 0/30 [00:00<?, ?it/s]

In [30]:
# Encode the text of every test query in a single call.
query_vectors = get_sparse_vector(batch=[query["text"] for query in queries])

Inference Embeddings: 100%|██████████| 25/25 [00:08<00:00,  2.88it/s]


In [31]:
# Convert each token-id -> weight mapping into a Qdrant SparseVector. Cast
# explicitly rather than relying on implicit pydantic coercion of dict views:
# the keys are presumably string token ids and the weights may be numpy
# scalars — TODO confirm against the FlagEmbedding version in use.
query_vectors = [
    models.SparseVector(
        indices=[int(token_id) for token_id in weights.keys()],
        values=[float(weight) for weight in weights.values()],
    )
    for weights in query_vectors
]

In [32]:
# Spot-check one converted query vector (token indices and their weights).
query_vectors[2]

SparseVector(indices=[106, 139217, 23, 17274, 765, 1563, 33176, 15853, 683, 40523, 2481, 5], values=[0.1339111328125, 0.2822265625, 0.13134765625, 0.25927734375, 0.11083984375, 0.1455078125, 0.19921875, 0.2161865234375, 0.2015380859375, 0.2188720703125, 0.11700439453125, 0.0173797607421875])

In [33]:
limit = 10  # number of hits retrieved per query
results = []
for qv in tqdm(query_vectors):
    try:
        hits = client.search(
            collection_name=collection_name,
            query_vector=models.NamedSparseVector(vector=qv, name=col_name),
            limit=limit,
            with_payload=True,
        )
    except Exception as e:
        # Keep going so `results` stays aligned with `queries`; a failure is stored as None.
        print(e)
        print(qv)
        hits = None
    results.append(hits)

  0%|          | 0/300 [00:00<?, ?it/s]

In [34]:
# Flatten the per-query hits into parallel columns of a run file. Queries whose
# search failed (stored as None by the search loop) contribute no rows instead of
# raising TypeError here.
query_ids, doc_ids, ranks, scores = [], [], [], []
for query, result in zip(queries, results):
    if result is None:
        continue
    query_id = query["_id"]
    for rank, hit in enumerate(result):
        query_ids.append(query_id)
        doc_ids.append(str(hit.id))
        ranks.append(rank)  # 0-based rank
        scores.append(hit.score)

run = {
    "query": [int(q) for q in query_ids],
    "q0": len(query_ids) * ["q0"],
    "docid": doc_ids,
    "rank": ranks,
    "score": scores,
    # NOTE(review): labelled "splade" although this run uses BGE-M3 lexical weights;
    # kept as-is in case downstream evaluation keys on it — confirm.
    "system": len(query_ids) * ["splade"],
}

with open("bge-m3-lexical.run.json", "w") as f:
    json.dump(run, f, indent=2)