In [120]:
import json
import os
from typing import Dict, Iterable

import pandas as pd
from datasets import load_dataset
from dotenv import load_dotenv
from qdrant_client import QdrantClient, models
from qdrant_sparse_tools import convert_sparse_vector
from tokenizers import Tokenizer
from tqdm.auto import tqdm

In [4]:
# Load environment variables (QDRANT_URL / QDRANT_API_KEY are read via os.getenv
# when the client is created) — presumably from a local .env file.
load_dotenv()

dataset_name = "scifact"
# NOTE: "spalde" (sic) is intentional here: it matches the column name in the
# Hugging Face dataset schema, so do not "fix" the spelling.
col_name = "spalde-v3-lexical"
# The Qdrant collection name is derived from the dataset and the vector column.
collection_name = f"{dataset_name}-{col_name}"

In [8]:
# Connection settings come from the environment; os.getenv yields None for any
# variable that is not set.
client = QdrantClient(
    url=os.getenv("QDRANT_URL"),
    api_key=os.getenv("QDRANT_API_KEY"),
)

In [43]:
def is_empty(client: QdrantClient, collection_name: str) -> bool:
    """Return True when the collection's reported ``points_count`` equals 0.

    Args:
        client: Client used to fetch the collection info.
        collection_name: Name of the collection to inspect.
    """
    info = client.get_collection(collection_name)
    return info.points_count == 0

In [23]:
# Uncomment to drop the collection and start over (destructive):
# client.delete_collection(collection_name)

In [26]:
# Create the collection on first run only. It has no dense vectors
# (vectors_config is empty) and a single sparse vector named "splade".
if not client.collection_exists(collection_name):
    splade_params = models.SparseVectorParams(
        index=models.SparseIndexParams(on_disk=False)
    )
    client.create_collection(
        collection_name=collection_name,
        vectors_config={},
        sparse_vectors_config={"splade": splade_params},
    )

In [34]:
def read_data(dataset_name: str) -> Iterable[models.PointStruct]:
    """Yield one Qdrant point per row of the pre-computed sparse-vector dataset.

    Loads the ``train`` split of ``nirantk/{dataset_name}-sparse-vectors``,
    prints its column schema, and converts rows lazily as the generator is
    consumed. The sparse vector is parsed from the JSON string stored in the
    column named by the module-level ``col_name`` and is attached under the
    ``"splade"`` vector name.

    Args:
        dataset_name: Dataset prefix used to build the Hugging Face repo id.

    Yields:
        A ``models.PointStruct`` whose id is the integer form of the row's
        ``_id`` and whose payload carries ``text``, ``title`` and the original
        (string) ``_id`` under the key ``id``.
    """
    dataset = load_dataset(f"nirantk/{dataset_name}-sparse-vectors", split="train")
    print("Columns: ", dataset.features)
    for row in dataset.to_list():
        point_id = int(row["_id"])
        sparse_vector = convert_sparse_vector(json.loads(row[col_name]))
        payload = {
            "text": row["text"],
            "title": row["title"],
            "id": row["_id"],
        }
        yield models.PointStruct(
            id=point_id,
            vector={"splade": sparse_vector},
            payload=payload,
        )


# One-off ingestion: points are uploaded only while the collection is still
# empty, so re-running this cell after a successful upload does nothing.
if is_empty(client, collection_name):
    points = tqdm(read_data(dataset_name))
    client.upload_points(collection_name=collection_name, points=points)

0it [00:00, ?it/s]

Columns:  {'spalde-v3-lexical': Value(dtype='string', id=None), '_id': Value(dtype='string', id=None), 'text': Value(dtype='string', id=None), 'title': Value(dtype='string', id=None)}


## Queries

In [97]:
# Relevance judgements (qrels) for the test split; query ids are cast to int
# so they can be compared against integer query ids later on.
qrels_path = f"../data/{dataset_name}/qrels/test.tsv"
test = pd.read_csv(qrels_path, sep="\t").astype({"query-id": int})

In [109]:
# Sanity check: show every relevance judgement recorded for query 873.
# Only the last expression of a cell is displayed, so the bare
# `value_counts()` call that used to precede this line had no visible effect.
test[test["query-id"] == 873]

Unnamed: 0,query-id,corpus-id,score
213,873,1180972,1
214,873,19307912,1
215,873,27393799,1
216,873,29025270,1
217,873,3315558,1


In [112]:
# Build the set of test-split query ids once: membership checks are then O(1)
# instead of rebuilding and scanning a list for every query.
test_query_ids = set(test["query-id"].tolist())

# JSONL: one JSON object per line; blank lines (e.g. a trailing newline) are skipped.
with open(f"../data/{dataset_name}/queries.jsonl", encoding="utf-8") as f:
    queries = [json.loads(line) for line in f if line.strip()]

# Only keep the test set queries
queries = [q for q in queries if int(q["_id"]) in test_query_ids]
len(queries)

300

In [114]:
# Encode each query text and keep every distinct token id once.
tokenizer = Tokenizer.from_pretrained("nirantk/splade-v3-lexical")
tokens = [list(set(tokenizer.encode(q["text"]).ids)) for q in queries]

In [115]:
# Binary query vectors: every distinct token id present in the query gets weight 1.0.
query_vectors = [
    models.SparseVector(indices=token_ids, values=[1.0] * len(token_ids))
    for token_ids in tokens
]

In [116]:
# Spot-check one query vector (indices are token ids, values are all 1.0).
query_vectors[2]

SparseVector(indices=[28032, 13433, 101, 102, 19470, 1999, 2031, 2866, 7730, 1012, 1013, 1015, 2456, 2361, 10975], values=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])

In [107]:
# Retrieve the top `limit` documents for every query vector. `results` stays
# index-aligned with `query_vectors`: a failed search is logged and recorded
# as None rather than aborting the whole run.
limit = 10
results = []
for query_vector in tqdm(query_vectors):
    try:
        hits = client.search(
            collection_name=collection_name,
            query_vector=models.NamedSparseVector(name="splade", vector=query_vector),
            with_payload=True,
            limit=limit,
        )
    except Exception as err:
        print(err)
        print(query_vector)
        results.append(None)
    else:
        results.append(hits)

  0%|          | 0/300 [00:00<?, ?it/s]

In [118]:
# Flatten the per-query hits into parallel columns (one entry per retrieved
# document) and dump them as a JSON run file.

# `queries` and `results` are paired positionally; zip() would silently
# truncate if they were produced from different kernel states.
if len(queries) != len(results):
    raise ValueError(
        f"queries ({len(queries)}) and results ({len(results)}) are out of sync; "
        "re-run the search cell"
    )

query_ids, doc_ids, ranks, scores = [], [], [], []
for query, result in zip(queries, results):
    if result is None:
        # The search loop stores None for a query whose search raised;
        # such a query contributes no rows instead of crashing here.
        continue
    query_id = query["_id"]
    # Ranks are 0-based positions within each query's result list.
    for rank, hit in enumerate(result):
        query_ids.append(query_id)
        doc_ids.append(str(hit.id))
        ranks.append(rank)
        scores.append(hit.score)

run = {
    "query": [int(q) for q in query_ids],
    "q0": len(query_ids) * ["q0"],
    "docid": doc_ids,
    "rank": ranks,
    "score": scores,
    "system": len(query_ids) * ["splade"],
}

with open("splade-v3-lexical.run.json", "w") as f:
    json.dump(run, f, indent=2)