In [1]:
import json
import os
from typing import Dict, Iterable

from datasets import load_dataset
from dotenv import load_dotenv
from qdrant_client import QdrantClient, models
from tqdm.auto import tqdm

In [2]:
load_dotenv()
dataset_name = "scifact"
col_name = "spalde-v3-lexical"
collection_name = f"{dataset_name}-{col_name}"

In [3]:
client = QdrantClient(url=os.getenv("QDRANT_URL"), api_key=os.getenv("QDRANT_API_KEY"))

client.get_collections()

CollectionsResponse(collections=[CollectionDescription(name='scifact-spalde-v3-lexical')])

In [4]:
client.recreate_collection(
    collection_name=collection_name,
    vectors_config={},
    sparse_vectors_config={
        "splade": models.SparseVectorParams(
            index=models.SparseIndexParams(on_disk=False)
        )
    },
)

  client.recreate_collection(


True

In [5]:
def convert_sparse_vector(sparse_vector: Dict) -> models.SparseVector:
    indices = []
    values = []

    for idx, value in sparse_vector.items():
        indices.append(int(idx))
        values.append(value)

    return models.SparseVector(indices=indices, values=values)

In [6]:
def read_data(dataset_name: str) -> Iterable[models.PointStruct]:
    ds = load_dataset(f"nirantk/{dataset_name}-sparse-vectors", split="train")
    print("Columns: ", ds.features)
    ds = ds.to_list()
    for element in ds:
        yield models.PointStruct(
            id=int(element["_id"]),
            vector={"splade": convert_sparse_vector(json.loads(element[col_name]))},
            payload={
                "text": element["text"],
                "title": element["title"],
            },
        )


client.upload_points(
    collection_name=collection_name, points=tqdm(read_data(dataset_name))
)

0it [00:00, ?it/s]

Columns:  {'spalde-v3-lexical': Value(dtype='string', id=None), '_id': Value(dtype='string', id=None), 'text': Value(dtype='string', id=None), 'title': Value(dtype='string', id=None)}


## Queries

In [7]:
queries = load_dataset(f"BeIR/{dataset_name}", "queries")["queries"]

In [10]:
from tokenizers import Tokenizer

tokenizer = Tokenizer.from_pretrained("nirantk/splade-v3-lexical")

tokenizer.json:   0%|          | 0.00/712k [00:00<?, ?B/s]

In [15]:
tokenizer.encode("hello world")

Encoding(num_tokens=4, attributes=[ids, type_ids, tokens, offsets, attention_mask, special_tokens_mask, overflowing])

In [14]:
tokenizer.encode(queries["text"][:5])

TypeError: TextInputSequence must be str