In [11]:
import json
import os
from collections import Counter
from typing import Dict, Iterable

import numpy as np
import pandas as pd
from datasets import Dataset, load_dataset
from dotenv import load_dotenv
from qdrant_client import QdrantClient, models
from qdrant_sparse_tools import convert_sparse_vector
from remap_tokens import (
    aggregate_weights,
    calc_tf,
    filter_list_tokens,
    filter_pair_tokens,
    reconstruct_bpe,
    rescore_vector,
    snowball_tokenize,
    stem_list_tokens,
    stem_pair_tokens,
)
from tokenizers import Tokenizer
from tqdm.auto import tqdm

In [12]:
load_dotenv()

canonical_dataset_name = "scifact"
dataset_name = "scifact-bge-m3-sparse-vectors"
source_model = "nirantk/splade-v3-lexical"
source_col_name = "spalde-v3-lexical"
col_name = "splade-snowball-rescore-large"
collection_name = f"{dataset_name}-{col_name}"

In [13]:
ds = load_dataset(f"nirantk/{dataset_name}-sparse-vectors", split="train")
ds[source_col_name][0]

'{"1009": 0.20660123229026794, "1011": 0.036437395960092545, "1013": 0.05786345899105072, "1014": 0.5618454217910767, "1017": 0.12642613053321838, "1018": 0.06154331564903259, "1020": 0.00875066313892603, "1022": 0.0029254043474793434, "1050": 0.13071982562541962, "1052": 0.4093637466430664, "1059": 0.4673144817352295, "2011": 0.2652147710323334, "2019": 0.5381823778152466, "2029": 0.03925827145576477, "2064": 0.2539152204990387, "2076": 0.04394211992621422, "2077": 0.07884006202220917, "2084": 0.1510767638683319, "2093": 0.29473719000816345, "2109": 0.12470348179340363, "2132": 0.02887592278420925, "2141": 0.38374659419059753, "2181": 0.04394211992621422, "2184": 0.08782484382390976, "2213": 0.650757372379303, "2220": 0.6022219061851501, "2240": 0.5410289764404297, "2243": 0.7701805830001831, "2287": 0.401551216840744, "2300": 0.5506471395492554, "2317": 1.1381834745407104, "2321": 0.29619070887565613, "2335": 0.07341018319129944, "2336": 0.2637155055999756, "2367": 0.0029254043474793

In [14]:
source_sparse_vectors = [json.loads(x) for x in ds[source_col_name]]
tokenizer = Tokenizer.from_pretrained(source_model)
reverse_voc = {v: k for k, v in tokenizer.get_vocab().items()}

In [15]:
raw_vectors = []
for sv in source_sparse_vectors:
    raw_vectors.append(
        {
            "tokens": [reverse_voc[int(key)] for key in sv.keys()],
            "weights": list(sv.values()),
        }
    )

## Recombine and Retokenize

In [16]:
# len(raw_vectors.pop(2)["tokens"])

In [17]:
def retokenize_sparse_vector(text: str, source_sparse_vector: Dict[str, float], tokenizer: Tokenizer):
    total_tokens_overall = 0
    num_docs = 0
    max_token_weight, num_tokens, total_tokens = {}, {}, 0

    sequential_tokens = tokenizer.encode(text).tokens
    reconstructed = reconstruct_bpe(enumerate(sequential_tokens))

    # print("reconstructed:\t", reconstructed)

    filtered_reconstructed = filter_pair_tokens(reconstructed)

    # print("filtered:\t", filtered_reconstructed)

    stemmed_reconstructed = stem_pair_tokens(filtered_reconstructed)

    # print("stemmed:\t", stemmed_reconstructed)
    # print("weights:\t", source_sparse_vector["weights"])
    weighed_reconstructed = aggregate_weights(
        stemmed_reconstructed, source_sparse_vector["weights"]
    )

    # print("weighted:\t", weighed_reconstructed)

    total_tokens += len(weighed_reconstructed)
    max_token_weight, num_tokens = {}, {}
    for reconstructed_token, score in weighed_reconstructed:
        max_token_weight[reconstructed_token] = max(
            max_token_weight.get(reconstructed_token, 0), score
        )
        num_tokens[reconstructed_token] = num_tokens.get(reconstructed_token, 0) + 1

    # print()
    # tokens = stem_list_tokens(filter_list_tokens(snowball_tokenize(text)))
    # total_tokens = len(tokens)
    # num_tokens = Counter(tokens)
    reweighted_sparse_vector = {}
    token_score = rescore_vector(max_token_weight)
    # print("token_score:\t", token_score)
    for token, token_count in num_tokens.items():
        score = token_score.get(token)
        tf = score + token_count - 1
        reweighted_sparse_vector[token] = calc_tf(tf, total_tokens)


    total_tokens_overall += total_tokens
    num_docs += 1
    # print(len(reweighted_sparse_vector))
    # print("reweighted_sparse_vector:\t", reweighted_sparse_vector)
    if not len(reweighted_sparse_vector) <= 1.2 * len(source_sparse_vector["tokens"]):
        print(reweighted_sparse_vector)
        print(source_sparse_vector)
        print(len(reweighted_sparse_vector), len(source_sparse_vector["tokens"]))
        raise ValueError("Something went wrong")
    return reweighted_sparse_vector


reweighted_sparse_vectors = []
for source_sparse_vector, text in tqdm(
    zip(raw_vectors, ds["text"]), total=len(raw_vectors)
):  
    reweighted_sparse_vector = retokenize_sparse_vector(source_sparse_vector=source_sparse_vector, text=text, tokenizer=tokenizer)
    # print(len(source_sparse_vectors))
    reweighted_sparse_vectors.append(reweighted_sparse_vector)
    # print(len(reweighted_sparse_vector))

  0%|          | 0/5183 [00:00<?, ?it/s]

In [18]:
# Find length of each sparse vector
vector_lengths = [len(sv) for sv in reweighted_sparse_vectors]

# Percentile of the lengths
np.percentile(vector_lengths, [10, 50, 90])

array([38., 45., 53.])

In [19]:
# len(reweighted_sparse_vectors), reweighted_sparse_vectors[0]

## Upload to Qdrant

In [20]:
client = QdrantClient(url=os.getenv("QDRANT_URL"), api_key=os.getenv("QDRANT_API_KEY"))
client.collectio

def is_empty(client: QdrantClient, collection_name: str) -> bool:
    return client.get_collection(collection_name).points_count == 0


# client.delete_collection(collection_name)

AttributeError: 'QdrantClient' object has no attribute 'collectio'

In [None]:
def reset_collection(client: QdrantClient, collection_name: str):
    if client.collection_exists(collection_name):
        client.delete_collection(collection_name)
    client.create_collection(
        collection_name=collection_name,
        vectors_config={},
        sparse_vectors_config={
            col_name: models.SparseVectorParams(
                index=models.SparseIndexParams(on_disk=False)
            )
        },
    )

In [None]:
# Make a vocab of all keys in the reweighted sparse vectors
vocab = set()
for sv in reweighted_sparse_vectors:
    vocab.update(sv.keys())

In [None]:
len(vocab)

In [None]:
# Convert this into a vocab object with each string having an id
vocab = {word: i for i, word in enumerate(vocab)}
invert_vocab = {i: word for word, i in vocab.items()}

In [None]:
# Recompute the reweighted sparse vectors with the new vocab
id_reweighted_sparse_vectors = []
for sv in tqdm(reweighted_sparse_vectors):
    new_sv = {}
    for word, weight in sv.items():
        new_sv[vocab[word]] = weight
    id_reweighted_sparse_vectors.append(new_sv)

In [None]:
def batched(iterable: Iterable, n: int = 1) -> Iterable:
    """Yield successive n-sized chunks from iterable."""
    for i in range(0, len(iterable), n):
        yield iterable[i : i + n]

In [None]:
def make_points(
    reweighted_sparse_vectors: Dict, ds: Dataset
) -> Iterable[models.PointStruct]:
    points = []
    for sv, element in tqdm(zip(reweighted_sparse_vectors, ds)):
        points.append(
            models.PointStruct(
                id=int(element["_id"]),
                vector={col_name: convert_sparse_vector(sv)},
                payload={
                    "text": element["text"],
                    "title": element["title"],
                    "id": element["_id"],
                },
            )
        )
    return points


# next(read_data(id_reweighted_sparse_vectors, ds))
reset_collection(client, collection_name)
points = make_points(id_reweighted_sparse_vectors, ds)
# Run ONCE to upload data, only when collection is empty
for batch in tqdm(batched(points, 100)):
    try:
        client.upload_points(collection_name=collection_name, points=batch)
    except Exception as e:
        print(e)
        pass

## Queries

In [None]:
test = pd.read_csv(f"../data/{dataset_name}/qrels/test.tsv", sep="\t")
test["query-id"] = test["query-id"].astype(int)

with open(f"../data/{dataset_name}/queries.jsonl") as f:
    queries = [json.loads(line) for line in f]

# Only keep the test set queries
queries = [q for q in queries if int(q["_id"]) in list(test["query-id"])]
len(queries)

In [None]:
tokenizer = Tokenizer.from_pretrained("nirantk/splade-v3-lexical")
tokens = [tokenizer.encode(q["text"]).tokens for q in queries]
tokens = [list(set(t)) for t in tokens]
# tokens = [list(set(t.ids)) for t in tokens]

In [None]:
idx = 50
tokens[idx]

In [None]:
# assign weight to all tokens and create a query vector with tokens and weights as keys
query_vectors = []
for token in tokens:
    query_vector = {}
    query_vector["tokens"] = token
    query_vector["weights"] = [1] * len(token)
    query_vectors.append(query_vector)

In [None]:
# Retokenize all the query tokens
reweighted_query_tokens = []
for qv, text in tqdm(zip(query_vectors, [q["text"] for q in queries])):
    # print(text)
    # print(qv)
    reweighted_query_tokens.append(retokenize_sparse_vector(source_sparse_vector=qv, text=text, tokenizer=tokenizer))

In [None]:
reweighted_query_tokens[idx+1]

In [None]:
np.percentile([len(t) for t in reweighted_query_tokens], [10, 50, 90])

In [None]:
vocab

In [None]:
# Map the keys back to the original vocab with integer ids
id_reweighted_query_tokens = []
for qv in tqdm(reweighted_query_tokens):
    new_qv = {}
    for word, weight in qv.items():
        try:
            new_qv[vocab[word]] = weight    
        except KeyError:
            print(word)
            continue
    id_reweighted_query_tokens.append(new_qv)

In [None]:
qdrant_query_vectors = [
    models.SparseVector(
        indices=qv.keys(),
        values=qv.values(),
    )
    for qv in id_reweighted_query_tokens
]

In [None]:
qdrant_query_vectors[idx]

In [None]:
limit = 10
results = []
for qv in tqdm(qdrant_query_vectors):
    try:
        result = client.search(
            collection_name=collection_name,
            query_vector=models.NamedSparseVector(name=col_name, vector=qv),
            with_payload=True,
            limit=limit,
        )
        results.append(result)
    except Exception as e:
        print(e)
        print(qv)
        results.append(None)

In [None]:
query_ids, doc_ids, ranks, scores  = [], [], [], []
for query, result in zip(queries, results):
    query_id = query["_id"]
    result_ids = [str(r.id) for r in result]
    result_scores = [r.score for r in result]
    result_ranks = list(range(len(result)))
    query_ids.extend(len(result) * [query_id])
    doc_ids.extend(result_ids)
    ranks.extend(result_ranks)
    scores.extend(result_scores)
    # print(f"query: {query_id}")
    # print(f"docid: {result_ids}")
    # print(f"rank: {result_ranks}")
    # print(f"score: {result_scores}")

run = {
    "query": [int(q) for q in query_ids],
    "q0": len(query_ids) * ["q0"],
    "docid": doc_ids,
    "rank": ranks,
    "score": scores,
    "system": len(query_ids) * ["splade"],
}

with open("lexical-retokenize-rescore.run.json", "w") as f:
    json.dump(run, f, indent=2)