In [2]:
import json
import os
from typing import Dict, Iterable, Optional

from datasets import load_dataset
from dotenv import load_dotenv
from qdrant_client import QdrantClient, models
from tokenizers import Tokenizer
from tqdm.auto import tqdm

In [4]:
# Load environment variables (QDRANT_URL / QDRANT_API_KEY are read in the next cell) from a .env file.
load_dotenv()

dataset_name = "scifact"
# "spalde" (sic): this must match the column name in the sparse-vector dataset
# (see the printed dataset features further down), so do not correct the spelling.
col_name = "spalde-v3-lexical"
collection_name = "-".join((dataset_name, col_name))

In [8]:
# Qdrant connection settings come from the environment (populated by load_dotenv()).
qdrant_url = os.environ.get("QDRANT_URL")
qdrant_api_key = os.environ.get("QDRANT_API_KEY")
client = QdrantClient(url=qdrant_url, api_key=qdrant_api_key)

In [43]:
def is_empty(client: QdrantClient, collection_name: str) -> bool:
    """Return True when ``collection_name`` currently holds no points.

    Used as the guard for the one-time upload below. An exact ``count`` is used
    instead of ``get_collection(...).points_count``. qdrant-client documents that
    field as approximate and types it ``Optional[int]``. When it is ``None``,
    ``points_count == 0`` evaluates to False and the upload would be skipped
    silently.

    Args:
        client: Connected Qdrant client.
        collection_name: Name of an existing collection.

    Returns:
        True if the collection contains zero points, False otherwise.
    """
    return client.count(collection_name=collection_name, exact=True).count == 0

In [23]:
# Uncomment and run to drop the collection. The create cell below recreates it
# when missing, and the upload cell re-uploads when it is empty.
# client.delete_collection(collection_name)

In [26]:
# Sparse-only collection: no dense vectors, a single named sparse vector
# "splade" with on_disk=False for its index. Creation is skipped when the
# collection already exists, so re-running this cell is safe.
splade_sparse_params = models.SparseVectorParams(
    index=models.SparseIndexParams(on_disk=False)
)
if not client.collection_exists(collection_name):
    client.create_collection(
        collection_name=collection_name,
        vectors_config={},
        sparse_vectors_config={"splade": splade_sparse_params},
    )

In [27]:
def convert_sparse_vector(sparse_vector: Dict) -> models.SparseVector:
    """Turn a ``{token_id: weight}`` mapping into a Qdrant ``SparseVector``.

    Keys may be strings (as produced by ``json.loads``) and are cast to ``int``.
    Weights are passed through unchanged. Indices and values keep the mapping's
    iteration order, so they stay paired position by position.
    """
    return models.SparseVector(
        indices=[int(key) for key in sparse_vector],
        values=list(sparse_vector.values()),
    )

In [34]:
def read_data(
    dataset_name: str, vector_column: Optional[str] = None
) -> Iterable[models.PointStruct]:
    """Stream the pre-computed sparse-vector dataset as Qdrant points.

    Args:
        dataset_name: Dataset name; rows are read from the Hugging Face repo
            ``nirantk/{dataset_name}-sparse-vectors`` (``train`` split).
        vector_column: Column holding the JSON-encoded sparse vector, an object
            whose keys are integer-like token ids (they are cast with ``int``).
            Defaults to the notebook-level ``col_name`` when omitted.

    Yields:
        One ``PointStruct`` per row. Its id is ``int(row["_id"])``, its sparse
        vector is stored under the name ``"splade"``, and its payload holds
        text, title and the original string ``_id``.
    """
    column = col_name if vector_column is None else vector_column
    ds = load_dataset(f"nirantk/{dataset_name}-sparse-vectors", split="train")
    print("Columns: ", ds.features)
    # Iterate the dataset row by row (each row is a dict) instead of
    # materialising every row up front with ``to_list()``.
    for element in ds:
        yield models.PointStruct(
            id=int(element["_id"]),
            vector={"splade": convert_sparse_vector(json.loads(element[column]))},
            payload={
                "text": element["text"],
                "title": element["title"],
                "id": element["_id"],
            },
        )


# Run ONCE to upload data, only when collection is empty
if is_empty(client, collection_name):
    client.upload_points(
        collection_name=collection_name, points=tqdm(read_data(dataset_name))
    )

0it [00:00, ?it/s]

Columns:  {'spalde-v3-lexical': Value(dtype='string', id=None), '_id': Value(dtype='string', id=None), 'text': Value(dtype='string', id=None), 'title': Value(dtype='string', id=None)}


## Queries

In [36]:
# BeIR publishes queries under a separate "queries" config; take its "queries" split.
# NOTE(review): this is every query in the dataset (1109 per the search progress
# bar), not just a test split. Confirm the evaluation filters against the qrels.
beir_queries = load_dataset(f"BeIR/{dataset_name}", "queries")
queries = beir_queries["queries"]

In [52]:
tokenizer = Tokenizer.from_pretrained("nirantk/splade-v3-lexical")
# One list of *distinct* token ids per query text. Repeats collapse because every
# token gets the same weight in the query vectors built next.
# NOTE(review): encode() runs with its default special-token handling; ids
# 101/102 in the sample vector look like BERT's [CLS]/[SEP]. Confirm they
# belong in the query.
tokens = [list(set(tokenizer.encode(text).ids)) for text in queries["text"]]

In [53]:
# Binary bag-of-words queries: each distinct token id gets weight 1.0.
query_vectors = [
    models.SparseVector(indices=token_ids, values=[1.0] * len(token_ids))
    for token_ids in tokens
]

In [54]:
# Sanity check: show one query's sparse vector (deduplicated token ids, every weight 1.0).
query_vectors[2]

SparseVector(indices=[8583, 6802, 3609, 5022, 11441, 9153, 22471, 1997, 2389, 2007, 3164, 101, 102, 4456, 2024, 1003, 2030, 1011, 18804, 1012, 1015], values=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])

In [55]:
limit = 10  # top-k hits requested per query
# One entry per query vector, in the same order as `queries`; None marks a failed search.
results = []
for qv in tqdm(query_vectors):
    try:
        hits = client.search(
            collection_name=collection_name,
            query_vector=models.NamedSparseVector(name="splade", vector=qv),
            with_payload=True,
            limit=limit,
        )
    except Exception as err:
        # Log and keep going so that one bad query does not abort the whole run.
        print(err)
        print(qv)
        hits = None
    results.append(hits)

  0%|          | 0/1109 [00:00<?, ?it/s]

In [65]:
# `results` must have one slot per query (None marks a failed search). A shorter
# list means the search cell was interrupted, and zip() would silently drop the
# remaining queries.
assert len(results) == len(queries), (
    f"{len(results)} search results for {len(queries)} queries; re-run the search cell"
)

# Flatten per-query hits into parallel columns, one row per (query, retrieved doc).
query_ids, doc_ids, ranks, scores = [], [], [], []
for query, result in zip(queries, results):
    if result is None:
        # The search for this query failed (its error was printed by the search
        # cell). Emit no rows for it instead of crashing on len(None).
        continue
    query_id = str(query["_id"])
    for rank, point in enumerate(result):
        query_ids.append(query_id)
        # Stringify point ids so doc ids have the same type as the query ids.
        doc_ids.append(str(point.id))
        # 0-based, as before. NOTE(review): TREC-style runs usually number ranks from 1.
        ranks.append(rank)
        scores.append(point.score)

# Column-oriented run file (presumably consumed as a TREC-style run: query, Q0, docid, rank, score, tag).
run = {
    "query": query_ids,
    "q0": len(query_ids) * ["q0"],
    "docid": doc_ids,
    "rank": ranks,
    "score": scores,
    "system": len(query_ids) * ["splade"],
}

with open("splade-v3-lexical.run.json", "w") as f:
    json.dump(run, f)