In [None]:
import json
import os
from typing import Dict, Iterable, Optional

import pandas as pd
from datasets import load_dataset
from dotenv import load_dotenv
from qdrant_client import QdrantClient, models
from qdrant_sparse_tools import convert_sparse_vector
from remap_tokens import (
    aggregate_weights,
    filter_pair_tokens,
    reconstruct_bpe,
    # rescore_vector,
    stem_pair_tokens,
)
from tokenizers import Tokenizer
from tqdm.auto import tqdm

In [None]:
# Pull QDRANT_URL / QDRANT_API_KEY (and any other settings) from a local .env file.
load_dotenv()

# Source data and the model whose tokenizer produced the sparse vectors.
dataset_name = "scifact"
source_model = "nirantk/splade-v3-lexical"
# NOTE(review): "spalde" looks like a transposition of "splade", but this string must
# match the dataset's actual column name — check ds.features before "fixing" it.
source_col_name = "spalde-v3-lexical"

# Target vector column and the Qdrant collection name derived from it.
col_name = "splade-snowball"
collection_name = "-".join((dataset_name, col_name))

In [None]:
# Fetch the precomputed sparse vectors for the selected dataset from the Hub.
ds = load_dataset(path=f"nirantk/{dataset_name}-sparse-vectors", split="train")
# Peek at the first row's raw source vector (decoded with json.loads further down).
ds[0][source_col_name]

In [None]:
# Decode every row's JSON-encoded sparse vector into a dict.
source_sparse_vectors = list(map(json.loads, ds[source_col_name]))

In [None]:
# Load the source model's tokenizer and invert its vocab (token -> id) into id -> token.
tokenizer = Tokenizer.from_pretrained(source_model)
reverse_voc = dict((token_id, token) for token, token_id in tokenizer.get_vocab().items())

In [None]:
# For each document, replace the (string-encoded) token ids with token strings and keep
# the weights in the same key order, so tokens[i] pairs with weights[i].
raw_vectors = [
    {
        "tokens": [reverse_voc[int(token_id)] for token_id in sv],
        "weights": [*sv.values()],
    }
    for sv in source_sparse_vectors
]

## Recombine and Retokenize

In [None]:
# Debug walk-through of the remapping pipeline: run it on the first document only and
# print every intermediate stage. The accumulators below are read by the scoring code
# further down in this cell.
max_token_weight = {}  # token -> largest aggregated weight observed
num_tokens = {}  # token -> number of entries in the aggregated output for that token
total_tokens = 0  # total number of entries in the aggregated output

for doc in raw_vectors[:1]:
    print("tokens:\t", doc['tokens'])

    reconstructed = reconstruct_bpe(enumerate(doc["tokens"]))
    print("reconstructed:\t", reconstructed)

    filtered_reconstructed = filter_pair_tokens(reconstructed)
    print("filtered:\t", filtered_reconstructed)

    stemmed_reconstructed = stem_pair_tokens(filtered_reconstructed)
    print("stemmed:\t", stemmed_reconstructed)

    weighed_reconstructed = aggregate_weights(stemmed_reconstructed, doc["weights"])
    print("weighed:\t", weighed_reconstructed)

    total_tokens += len(weighed_reconstructed)

    for reconstructed_token, score in weighed_reconstructed:
        running_max = max_token_weight.get(reconstructed_token, 0)
        max_token_weight[reconstructed_token] = max(running_max, score)
        num_tokens[reconstructed_token] = 1 + num_tokens.get(reconstructed_token, 0)

    print()

# FIXME(review): this block looks like per-document scoring lifted out of a loop over
# documents (it writes one line per document and increments `num_docs`), but it runs
# once, after the loop above, on whatever that loop accumulated. As written it cannot
# run on a fresh kernel: `rescore_vector` is commented out in the imports cell, and
# `calc_tf`, `out_file`, `total_tokens_overall` and `num_docs` are not defined in any
# visible cell.

# Leftover reference code (snowball tokenisation path), kept for context:
# tokens = stem_list_tokens(filter_list_tokens(snowball_tokenize(text)))
# total_tokens = len(tokens)
# num_tokens = Counter(tokens)

# Output sparse vector being built: token -> final weight.
sparse_vector = {}

# NOTE(review): `rescore_vector` is not imported (its import line is commented out),
# so this raises NameError as written.
token_score = rescore_vector(max_token_weight)

for token, token_count in num_tokens.items():
    score = token_score[token]
    # Rescored weight, plus one for each additional occurrence beyond the first.
    tf = score + token_count - 1
    # tf = token_count
    # NOTE(review): `calc_tf` is not defined in any visible cell; it receives the document's
    # `total_tokens`, so presumably it length-normalises tf — confirm.
    sparse_vector[token] = calc_tf(tf, total_tokens)

# NOTE(review): `out_file` is never opened in any visible cell.
out_file.write(json.dumps(sparse_vector) + "\n")

# NOTE(review): neither counter is initialised in any visible cell; they look like
# corpus-level statistics (e.g. for an average document length) — confirm.
total_tokens_overall += total_tokens
num_docs += 1

## Upload to Qdrant

In [None]:
# Connect to Qdrant with credentials taken from the environment (populated by load_dotenv).
client = QdrantClient(
    url=os.environ.get("QDRANT_URL"),
    api_key=os.environ.get("QDRANT_API_KEY"),
)


def is_empty(client: "QdrantClient", collection_name: str) -> bool:
    """Return True when the Qdrant collection currently holds no points.

    Uses an exact ``count`` query instead of ``get_collection(...).points_count``.
    That field is an approximate statistic and may be ``None``; ``None == 0`` is
    False, so an empty collection would be reported as non-empty and the guarded
    upload would be silently skipped.

    Args:
        client: Connected Qdrant client.
        collection_name: Name of an existing collection.

    Returns:
        True if the exact point count is 0, otherwise False.

    Raises:
        Whatever the client raises when the collection does not exist (same as before).
    """
    return client.count(collection_name=collection_name, exact=True).count == 0


# client.delete_collection(collection_name)

In [None]:
# Create a sparse-only collection (no dense vectors) with a single sparse vector named
# "splade"; on_disk=False keeps its index in RAM. Skipped if the collection exists.
collection_missing = not client.collection_exists(collection_name)
if collection_missing:
    sparse_config = {
        "splade": models.SparseVectorParams(
            index=models.SparseIndexParams(on_disk=False),
        ),
    }
    client.create_collection(
        collection_name=collection_name,
        vectors_config={},
        sparse_vectors_config=sparse_config,
    )

In [None]:
def read_data(
    dataset_name: str, vector_col: Optional[str] = None
) -> Iterable[models.PointStruct]:
    """Yield one Qdrant point per row of ``nirantk/<dataset_name>-sparse-vectors``.

    Args:
        dataset_name: Dataset name used to build the Hub repo id (e.g. "scifact").
        vector_col: Column holding the JSON-encoded sparse vector to upload. When
            omitted, the notebook-level ``col_name`` is used, looked up at call time.
            NOTE(review): ``col_name`` is "splade-snowball", while the remapping cell
            writes its output to a local file rather than to this dataset — confirm the
            Hub dataset actually has that column.

    Yields:
        ``PointStruct`` with ``int(_id)`` as the point id, the converted sparse
        vector under the name ``"splade"``, and ``text``/``title``/``id`` payload.
    """
    if vector_col is None:
        vector_col = col_name
    ds = load_dataset(f"nirantk/{dataset_name}-sparse-vectors", split="train")
    print("Columns: ", ds.features)
    # Iterate the dataset row by row (each row is a plain dict) instead of first
    # materialising the whole table as a Python list with ``to_list()``.
    for element in ds:
        yield models.PointStruct(
            id=int(element["_id"]),
            vector={"splade": convert_sparse_vector(json.loads(element[vector_col]))},
            payload={
                "text": element["text"],
                "title": element["title"],
                "id": element["_id"],
            },
        )


# Upload only into an empty collection, so re-running this cell does not re-upload.
if is_empty(client, collection_name):
    point_stream = tqdm(read_data(dataset_name))
    client.upload_points(collection_name=collection_name, points=point_stream)