In [None]:
%load_ext autoreload
%autoreload 2

import json
import os
from collections import Counter
from typing import Dict, Iterable

import numpy as np
import pandas as pd
from datasets import Dataset, load_dataset
from dotenv import load_dotenv
from qdrant_client import QdrantClient, models
from qdrant_sparse_tools import convert_sparse_vector
from retokenize import (
    aggregate_weights,
    process_text,
    reconstruct_bpe,
    stem_words,
)

from loguru import logger
from tokenizers import Tokenizer
from tqdm.auto import tqdm
from transformers import AutoTokenizer

In [None]:
# Load variables from a .env file into the process environment; QDRANT_URL and
# QDRANT_API_KEY are read with os.getenv when the Qdrant client is created.
load_dotenv()

# NOTE(review): canonical_dataset_name is not referenced anywhere else in the
# visible notebook — presumably the upstream benchmark name; confirm whether the
# ../data/... paths in the Queries section were meant to use it.
canonical_dataset_name = "scifact"
# Dataset name under the "nirantk/" namespace; also used to build the ../data/ paths.
dataset_name = "scifact-bge-m3-sparse-vectors"
# Dataset column holding the JSON-encoded sparse vectors; reused as the name of
# the sparse vector inside the Qdrant collection.
col_name = "bge_m3_sparse_vector"
# Qdrant collection that this notebook resets and fills.
collection_name = f"{dataset_name}-{col_name}-retok"
# Model whose tokenizer vocabulary is used to decode sparse-vector token ids.
model_name = "BAAI/bge-m3"

In [None]:
# Load the "corpus" split of the pre-computed sparse-vector dataset and peek at
# the first stored vector.
ds = load_dataset("nirantk/" + dataset_name, split="corpus")
ds[col_name][0]

In [None]:
# Every entry of the column is a JSON string; decode each one.
sparse_vectors = list(map(json.loads, ds[col_name]))

### Change schema to Qdrant Sparse Vector

1. Create vocab and reverse vocab from Tokenizer corresponding to the model
2. Invert the integer index to token string
3. Make lists from the keys (decoded tokens) and values (weights) of each sparse vector dictionary

In [None]:
# Tokenizer of the embedding model; its vocabulary maps token string -> integer id.
tokenizer = AutoTokenizer.from_pretrained(model_name)
vocab = tokenizer.get_vocab()
# Flip the mapping so integer ids can be turned back into token strings.
reverse_vocab = dict(zip(vocab.values(), vocab.keys()))

In [None]:
# One {"tokens": [...], "weights": [...]} record per document. The keys of each
# decoded vector are cast to int before being looked up in `reverse_vocab`.
raw_vectors = [
    {
        "tokens": [reverse_vocab[int(token_id)] for token_id in vector],
        "weights": list(vector.values()),
    }
    for vector in sparse_vectors
]

In [None]:
# Sanity check: the first document's text next to its first ten decoded tokens
# and the matching weights.
ds["text"][0], raw_vectors[0]["tokens"][0:10], raw_vectors[0]["weights"][0:10]

## Recombine and Retokenize

In [None]:
# Retokenize every document vector: `process_text` receives the decoded tokens
# and their weights, with `aggregate_weights` passed as the aggregation function.
# NOTE(review): the following cells treat each result as a dict keyed by term
# (they call `.keys()` / `.items()` on it) — confirm that this is what
# `process_text` returns.
reweighted_sparse_vectors = [
    process_text(
        bpe_tokens=source_sparse_vector["tokens"],
        weights=source_sparse_vector["weights"],
        aggregate_fn=aggregate_weights,
    )
    for source_sparse_vector in tqdm(raw_vectors, desc="retokenizing")
]

In [None]:
# Number of entries in each retokenized document vector ...
vector_lengths = list(map(len, reweighted_sparse_vectors))

# ... summarised by the 10th, 50th and 90th percentiles
np.percentile(vector_lengths, q=[10, 50, 90])

In [None]:
# len(reweighted_sparse_vectors), reweighted_sparse_vectors[0]

## Upload to Qdrant

In [None]:
# Connect to Qdrant; the endpoint and API key are read from the environment
# (load_dotenv() is called earlier in the notebook).
client = QdrantClient(
    url=os.getenv("QDRANT_URL"),
    api_key=os.getenv("QDRANT_API_KEY"),
)


def is_empty(client: QdrantClient, collection_name: str) -> bool:
    """Report whether the named collection currently holds zero points.

    Args:
        client: Qdrant client used for the lookup.
        collection_name: Name of the collection to inspect.

    Returns:
        True when the collection's ``points_count`` equals 0, False otherwise.
    """
    collection_info = client.get_collection(collection_name)
    return collection_info.points_count == 0


# client.delete_collection(collection_name)

In [None]:
def reset_collection(client: QdrantClient, collection_name: str):
    """Drop the collection if it already exists, then create it afresh.

    The new collection is created with an empty dense-vector config and a single
    sparse vector named after the module-level ``col_name``, whose index is
    configured with ``on_disk=False``.

    Args:
        client: Qdrant client to operate on.
        collection_name: Name of the collection to (re)create.
    """
    already_there = client.collection_exists(collection_name)
    if already_there:
        client.delete_collection(collection_name)

    sparse_params = models.SparseVectorParams(
        index=models.SparseIndexParams(on_disk=False)
    )
    client.create_collection(
        collection_name=collection_name,
        vectors_config={},
        sparse_vectors_config={col_name: sparse_params},
    )

In [None]:
# Collect every distinct key that appears in any retokenized document vector
vocab = {term for vector in reweighted_sparse_vectors for term in vector.keys()}

In [None]:
# Number of distinct keys collected from the retokenized document vectors
len(vocab)

In [None]:
# Give every term an integer id. The terms are sorted first so the ids are the
# same on every run: iterating a set of strings directly can produce a different
# order after each kernel restart (string hash randomisation), which would make
# ids from an earlier session (e.g. points already uploaded) disagree with this one.
# NOTE(review): sorting assumes all terms are mutually comparable (e.g. all str).
vocab = {word: i for i, word in enumerate(sorted(vocab))}
# Reverse lookup: integer id -> term
invert_vocab = {i: word for word, i in vocab.items()}

In [None]:
# Re-key every document vector from its terms to the integer ids held in `vocab`
id_reweighted_sparse_vectors = [
    {vocab[term]: term_weight for term, term_weight in vector.items()}
    for vector in tqdm(reweighted_sparse_vectors)
]

In [None]:
def batched(iterable: Iterable, n: int = 1) -> Iterable:
    """Yield successive chunks of at most ``n`` items from ``iterable``.

    Inputs that support ``len()`` are sliced, so every chunk has the same type
    as the input (a list yields lists, a str yields strs). Any other iterable,
    such as a generator, is consumed lazily and its chunks are yielded as lists.

    Args:
        iterable: Items to split into chunks.
        n: Maximum chunk size; must be at least 1.

    Yields:
        Consecutive chunks; the last one may hold fewer than ``n`` items.

    Raises:
        ValueError: If ``n`` is smaller than 1. Because this is a generator,
            the error surfaces when iteration starts, not at call time.
    """
    if n < 1:
        # A negative size used to produce no chunks at all, silently dropping data
        raise ValueError(f"n must be at least 1, got {n}")

    try:
        total = len(iterable)
    except TypeError:
        total = None  # no len(): a generator or other one-shot iterator

    if total is not None:
        for start in range(0, total, n):
            yield iterable[start : start + n]
        return

    chunk = []
    for item in iterable:
        chunk.append(item)
        if len(chunk) == n:
            yield chunk
            chunk = []
    if chunk:
        yield chunk

In [None]:
def make_points(
    reweighted_sparse_vectors: Iterable[Dict], ds: Dataset
) -> Iterable[models.PointStruct]:
    """Build one Qdrant point per document from its sparse vector and dataset row.

    Args:
        reweighted_sparse_vectors: Sized sequence of sparse vectors, one per row
            of ``ds`` and in the same order. Each one is handed to
            ``convert_sparse_vector``.
        ds: Dataset whose rows provide the ``_id``, ``text`` and ``title`` fields.

    Returns:
        A list of ``models.PointStruct``. The point id is ``int(row["_id"])``;
        the payload carries ``text``, ``title`` and the original ``_id`` under
        the key ``"id"``. The vector is stored under the module-level ``col_name``.

    Raises:
        ValueError: If the number of vectors differs from the number of rows.
    """
    n_vectors, n_rows = len(reweighted_sparse_vectors), len(ds)
    if n_vectors != n_rows:
        # zip() would silently stop at the shorter input and drop documents
        raise ValueError(
            f"Got {n_vectors} sparse vectors for {n_rows} dataset rows"
        )

    points = []
    for sv, element in tqdm(
        zip(reweighted_sparse_vectors, ds), total=n_rows, desc="building points"
    ):
        points.append(
            models.PointStruct(
                id=int(element["_id"]),
                vector={col_name: convert_sparse_vector(sv)},
                payload={
                    "text": element["text"],
                    "title": element["title"],
                    "id": element["_id"],
                },
            )
        )
    return points


# Start from an empty collection, then build and upload all points.
# NOTE: reset_collection() drops any existing collection with this name, so
# re-running this cell replaces whatever was uploaded before.
reset_collection(client, collection_name)
points = make_points(id_reweighted_sparse_vectors, ds)

upload_batch_size = 100
failed_batches = 0
for batch in tqdm(
    batched(points, upload_batch_size),
    total=-(-len(points) // upload_batch_size),  # ceiling division
    desc="uploading",
):
    try:
        client.upload_points(collection_name=collection_name, points=batch)
    except Exception:
        # Best effort: log the failure with its traceback and carry on with
        # the remaining batches.
        logger.exception("Failed to upload a batch of {} points", len(batch))
        failed_batches += 1

if failed_batches:
    logger.warning(
        "{} upload batch(es) failed; the collection is incomplete", failed_batches
    )

## Queries

In [None]:
# Relevance judgements for the test split.
# NOTE(review): both paths are built from `dataset_name`, while
# `canonical_dataset_name` is defined but never used — confirm which directory
# actually holds the qrels and queries files.
test = pd.read_csv(f"../data/{dataset_name}/qrels/test.tsv", sep="\t")
test["query-id"] = test["query-id"].astype(int)

# One JSON object per line; blank lines (e.g. a trailing newline) are skipped
with open(f"../data/{dataset_name}/queries.jsonl", encoding="utf-8") as f:
    queries = [json.loads(line) for line in f if line.strip()]

# Only keep the test set queries. The ids are put into a set once so that each
# membership check is a constant-time lookup.
test_query_ids = set(test["query-id"].tolist())
queries = [q for q in queries if int(q["_id"]) in test_query_ids]
len(queries)

In [None]:
# Tokenize each query and keep only its distinct token strings.
# Note that this rebinds `tokenizer`, which earlier held the model's AutoTokenizer.
tokenizer = Tokenizer.from_pretrained("nirantk/splade-v3-lexical")
tokens = [list(set(tokenizer.encode(q["text"]).tokens)) for q in queries]

In [None]:
# Index of the query used for spot checks in this and later cells
idx = 50
tokens[idx]

In [None]:
# Every query token gets the same weight of 1; each query becomes a record with
# "tokens" and "weights" keys.
query_vectors = [
    {"tokens": token_list, "weights": [1] * len(token_list)}
    for token_list in tokens
]

In [None]:
# Retokenize every query vector with `process_text` (imported from `retokenize`).
# NOTE(review): this assumes `process_text` accepts the tokens produced by the
# query tokenizer and returns a term -> weight dict, which is what the id-mapping
# cell further down iterates over with `.items()` — confirm.
reweighted_query_tokens = [
    process_text(
        bpe_tokens=qv["tokens"],
        weights=qv["weights"],
        aggregate_fn=aggregate_weights,
    )
    for qv in tqdm(query_vectors, desc="retokenizing queries")
]

In [None]:
# Inspect the retokenized query at position idx + 1
reweighted_query_tokens[idx + 1]

In [None]:
# 10th / 50th / 90th percentile of the number of entries per retokenized query
np.percentile([len(t) for t in reweighted_query_tokens], [10, 50, 90])

In [None]:
# Show the size and a small sample rather than dumping the whole term -> id mapping
len(vocab), list(vocab.items())[:10]

In [None]:
# Re-key each query vector with the integer ids from `vocab`. Terms that are
# missing from `vocab` are printed and left out of the query.
id_reweighted_query_tokens = []
for query_weights in tqdm(reweighted_query_tokens):
    mapped = {}
    for term, term_weight in query_weights.items():
        if term not in vocab:
            print(term)
            continue
        mapped[vocab[term]] = term_weight
    id_reweighted_query_tokens.append(mapped)

In [None]:
# Wrap each id -> weight mapping in a Qdrant SparseVector. The dict views are
# materialised as lists so the model receives plain list values; keys and values
# of the same dict iterate in matching order, so indices and values stay aligned.
qdrant_query_vectors = [
    models.SparseVector(
        indices=list(qv.keys()),
        values=list(qv.values()),
    )
    for qv in id_reweighted_query_tokens
]

In [None]:
# Inspect the Qdrant sparse vector built for the spot-check query
qdrant_query_vectors[idx]

In [None]:
# Sparse search, at most `limit` hits per query. A query whose search raises is
# printed and recorded as None, so `results` keeps one entry per query vector.
limit = 10
results = []
for query_vector in tqdm(qdrant_query_vectors):
    try:
        hits = client.search(
            collection_name=collection_name,
            query_vector=models.NamedSparseVector(name=col_name, vector=query_vector),
            with_payload=True,
            limit=limit,
        )
    except Exception as e:
        print(e)
        print(query_vector)
        results.append(None)
    else:
        results.append(hits)

In [None]:
# Flatten the per-query hits into parallel columns for the run file
query_ids, doc_ids, ranks, scores = [], [], [], []
for query, result in zip(queries, results):
    if result is None:
        # The search for this query failed and was recorded as None;
        # it contributes no rows instead of crashing the export.
        continue
    query_id = query["_id"]
    for rank, hit in enumerate(result):  # ranks are 0-based
        query_ids.append(query_id)
        doc_ids.append(str(hit.id))
        ranks.append(rank)
        scores.append(hit.score)

run = {
    "query": [int(q) for q in query_ids],
    "q0": len(query_ids) * ["q0"],
    "docid": doc_ids,
    "rank": ranks,
    "score": scores,
    "system": len(query_ids) * ["splade"],
}

with open("lexical-retokenize-rescore.run.json", "w") as f:
    json.dump(run, f, indent=2)