In [10]:
%load_ext autoreload
%autoreload 2

import json
import os
import sys
from collections import Counter
from typing import Dict, Iterable, List

import numpy as np
import pandas as pd
from datasets import Dataset, load_dataset
from dotenv import load_dotenv
from FlagEmbedding import BGEM3FlagModel
from qdrant_client import QdrantClient, models
from qdrant_sparse_tools import convert_sparse_vector
from remap_tokens import (
    aggregate_weights,
    calc_tf,
    filter_list_tokens,
    filter_pair_tokens,
    reconstruct_sentence_piece,
    rescore_vector,
    snowball_tokenize,
    stem_list_tokens,
    stem_pair_tokens,
)
from tokenizers import Tokenizer
from tqdm.auto import tqdm
from transformers import AutoModel, AutoTokenizer

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [12]:
# Load variables from a local .env file so the os.getenv lookups further down
# (QDRANT_URL, QDRANT_API_KEY) can find them.
load_dotenv()

# Model and dataset identifiers.
model_name = "BAAI/bge-m3"
canonical_dataset_name = "scifact"
dataset_name = "scifact-bge-m3-sparse-vectors"

# The same name is used for the source dataset column and the sparse vector field.
col_name = "bge_m3_sparse_vector"
source_col_name = col_name

# Target collection name, derived from the dataset and column names.
collection_name = f"{dataset_name}-{col_name}-retok"

In [6]:
# Load the "corpus" split of the precomputed sparse-vector dataset.
ds = load_dataset(f"nirantk/{dataset_name}", split="corpus")

# Peek at the stored sparse vector of the first row.
ds[0][col_name]

'{"39176": 0.1639404296875, "21094": 0.033599853515625, "159958": 0.1788330078125, "119856": 0.1939697265625, "35011": 0.1964111328125, "26866": 0.2216796875, "70": 0.011077880859375, "168698": 0.161865234375, "14135": 0.04254150390625, "78574": 0.1883544921875, "831": 0.051239013671875, "52490": 0.16845703125, "8231": 0.067626953125, "70760": 0.1358642578125, "34754": 0.1903076171875, "136": 0.01042938232421875, "16750": 0.024810791015625, "23": 0.01120758056640625, "123309": 0.1346435546875, "164462": 0.1981201171875, "13315": 0.131591796875, "44954": 0.168701171875, "45755": 0.1553955078125, "92105": 0.1864013671875, "9": 0.01116943359375, "165598": 0.1431884765625, "297": 0.010650634765625, "214706": 0.0733642578125, "3332": 0.016510009765625, "191": 0.01358795166015625, "7154": 0.00965118408203125, "86898": 0.06939697265625, "177": 0.0108184814453125, "594": 0.03509521484375, "16625": 0.197265625, "16": 0.0110626220703125, "944": 0.052734375, "3956": 0.0084228515625, "1492": 0.152

In [7]:
# The column holds JSON strings; decode every row into a dict.
sparse_vectors = list(map(json.loads, ds[source_col_name]))


In [29]:
tokenizer = AutoTokenizer.from_pretrained(model_name)

# token string -> token id, as reported by the tokenizer
vocab = tokenizer.get_vocab()

# token id -> token string
reverse_vocab = dict(zip(vocab.values(), vocab.keys()))

## Recombine and Retokenize

In [15]:
# Turn each {token id string: weight} mapping into parallel "tokens"/"weights"
# lists, resolving ids to token strings through reverse_vocab.
raw_vectors = [
    {
        "tokens": [reverse_vocab[int(token_id)] for token_id in sv],
        "weights": list(sv.values()),
    }
    for sv in sparse_vectors
]

In [38]:
def retokenize_sparse_vector(
    text: str,
    source_sparse_vector: Dict[str, List],
    tokenizer: "Tokenizer",
    max_growth: float = 1.2,
):
    """Re-express a token-level sparse vector over reconstructed, stemmed tokens.

    Pipeline (the helpers come from ``remap_tokens``; their exact semantics
    are not visible here, so the descriptions of them are by name only):

    1. ``tokenizer.tokenize(text)`` produces the sequential token pieces.
    2. ``reconstruct_sentence_piece`` -> ``filter_pair_tokens`` ->
       ``stem_pair_tokens`` post-process those pieces.
    3. ``aggregate_weights`` combines the result with the source weights and
       yields ``(token, score)`` pairs.
    4. Per distinct token, the maximum score and the occurrence count are
       collected; the maxima are passed through ``rescore_vector``.
    5. Each token gets ``calc_tf(score + count - 1, total_pairs)`` as its
       final weight, where ``total_pairs`` is the number of pairs from step 3.

    Args:
        text: Raw text the source vector was computed from.
        source_sparse_vector: Mapping with a ``"weights"`` list (handed to
            ``aggregate_weights``) and a ``"tokens"`` list (only its length is
            used, for the size sanity check).
        tokenizer: Any object exposing ``tokenize(text)``. The notebook passes
            the object returned by ``AutoTokenizer.from_pretrained``.
        max_growth: Upper bound on ``len(result) / len(source tokens)``.
            Defaults to the previously hard-coded 1.2.

    Returns:
        Dict mapping each retokenized token string to its new weight.

    Raises:
        ValueError: If the result has more than ``max_growth`` times as many
            entries as the source vector has tokens.
        KeyError: If ``rescore_vector`` does not return a score for a token
            it was given.
    """
    sequential_tokens = tokenizer.tokenize(text)
    reconstructed = reconstruct_sentence_piece(sequential_tokens)
    filtered_reconstructed = filter_pair_tokens(reconstructed)
    stemmed_reconstructed = stem_pair_tokens(filtered_reconstructed)
    weighed_reconstructed = aggregate_weights(
        stemmed_reconstructed, source_sparse_vector["weights"]
    )

    total_tokens = len(weighed_reconstructed)

    # Per distinct token: highest score seen and number of occurrences.
    max_token_weight, num_tokens = {}, {}
    for reconstructed_token, score in weighed_reconstructed:
        max_token_weight[reconstructed_token] = max(
            max_token_weight.get(reconstructed_token, 0), score
        )
        num_tokens[reconstructed_token] = num_tokens.get(reconstructed_token, 0) + 1

    token_score = rescore_vector(max_token_weight)

    reweighted_sparse_vector = {}
    for token, token_count in num_tokens.items():
        # Explicit lookup: a missing score is a bug in the rescoring step and
        # should surface as a KeyError naming the token, not as None + int.
        tf = token_score[token] + token_count - 1
        reweighted_sparse_vector[token] = calc_tf(tf, total_tokens)

    # Sanity check: retokenization should not produce many more entries than
    # the source vector had tokens.
    source_size = len(source_sparse_vector["tokens"])
    if len(reweighted_sparse_vector) > max_growth * source_size:
        print(reweighted_sparse_vector)
        print(source_sparse_vector)
        raise ValueError(
            f"Retokenized vector has {len(reweighted_sparse_vector)} entries, "
            f"more than {max_growth}x the {source_size} source tokens"
        )
    return reweighted_sparse_vector


# Retokenize every corpus document, pairing each raw vector with its text.
reweighted_sparse_vectors = [
    retokenize_sparse_vector(
        text=text, source_sparse_vector=source_sparse_vector, tokenizer=tokenizer
    )
    for source_sparse_vector, text in tqdm(
        zip(raw_vectors, ds["text"]), total=len(raw_vectors)
    )
]

  0%|          | 0/5183 [00:00<?, ?it/s]

In [39]:
# Number of entries in each retokenized document vector
vector_lengths = list(map(len, reweighted_sparse_vectors))

# 10th / 50th / 90th percentile of those sizes
np.percentile(vector_lengths, [10, 50, 90])

array([30., 54., 85.])

In [19]:
# len(reweighted_sparse_vectors), reweighted_sparse_vectors[0]

## Upload to Qdrant

In [40]:
# Connection details come from the QDRANT_URL / QDRANT_API_KEY environment variables.
client = QdrantClient(
    url=os.getenv("QDRANT_URL"),
    api_key=os.getenv("QDRANT_API_KEY"),
)


def is_empty(client: QdrantClient, collection_name: str) -> bool:
    """Return True when the named collection reports a point count of zero."""
    collection_info = client.get_collection(collection_name)
    return collection_info.points_count == 0


# client.delete_collection(collection_name)

In [41]:
def reset_collection(client: QdrantClient, collection_name: str):
    """Delete ``collection_name`` if it exists, then create it afresh.

    The new collection has an empty dense-vector config and a single sparse
    vector field named after the notebook-level ``col_name``, with its index
    configured as ``on_disk=False``.
    """
    already_there = client.collection_exists(collection_name)
    if already_there:
        client.delete_collection(collection_name)

    sparse_index = models.SparseIndexParams(on_disk=False)
    sparse_fields = {col_name: models.SparseVectorParams(index=sparse_index)}
    client.create_collection(
        collection_name=collection_name,
        vectors_config={},
        sparse_vectors_config=sparse_fields,
    )

In [42]:
# Collect every distinct key that appears in any retokenized document vector
vocab = {token for sv in reweighted_sparse_vectors for token in sv}

In [43]:
# Number of distinct keys collected into the vocabulary
len(vocab)

6002

In [44]:
# Give every vocabulary string an integer id.
# The strings are sorted first so the mapping is deterministic: iterating a
# set of str follows hash order, which changes between interpreter sessions
# (hash randomisation). Unsorted ids computed in one kernel would not match
# ids already uploaded to the collection from another.
vocab = {word: i for i, word in enumerate(sorted(vocab))}
invert_vocab = {i: word for word, i in vocab.items()}

In [45]:
# Re-key every document vector from token strings to their integer vocab ids
id_reweighted_sparse_vectors = [
    {vocab[word]: weight for word, weight in sv.items()}
    for sv in tqdm(reweighted_sparse_vectors)
]

  0%|          | 0/5183 [00:00<?, ?it/s]

In [46]:
def batched(iterable: Iterable, n: int = 1) -> Iterable:
    """Yield successive chunks of at most ``n`` items from ``iterable``.

    Sized, indexable inputs (lists, strings, ...) are sliced, so each chunk
    has the type that slicing the input produces. Any other iterable
    (generators, iterators, sets) is consumed lazily and each chunk is a list.

    Args:
        iterable: Source of items.
        n: Maximum chunk size; must be at least 1.

    Raises:
        ValueError: If ``n`` is smaller than 1. As this is a generator, the
            error is raised when iteration starts.
    """
    from itertools import islice  # local import keeps this cell self-contained

    if n < 1:
        raise ValueError(f"batch size n must be >= 1, got {n}")

    if hasattr(iterable, "__len__") and hasattr(iterable, "__getitem__"):
        # Slicing path: preserves the input's own slice type.
        for start in range(0, len(iterable), n):
            yield iterable[start : start + n]
        return

    # Generic path for inputs that cannot be sliced.
    iterator = iter(iterable)
    while True:
        chunk = list(islice(iterator, n))
        if not chunk:
            return
        yield chunk

In [47]:
def make_points(
    reweighted_sparse_vectors: List[Dict[int, float]], ds: Dataset
) -> List[models.PointStruct]:
    """Build one ``PointStruct`` per (sparse vector, dataset row) pair.

    Vectors and rows are paired by position. Each point uses the row's
    ``"_id"`` converted to ``int`` as its id, stores the vector (after
    ``convert_sparse_vector``) under the notebook-level ``col_name``, and
    carries the row's ``text``, ``title`` and original ``_id`` as payload.

    Args:
        reweighted_sparse_vectors: One id-keyed sparse vector per row.
        ds: Rows providing ``"_id"``, ``"text"`` and ``"title"``.

    Returns:
        List of points in row order.

    Raises:
        ValueError: If both inputs are sized and their lengths differ.
            ``zip`` would otherwise silently drop the surplus.
    """
    try:
        n_vectors, n_rows = len(reweighted_sparse_vectors), len(ds)
    except TypeError:
        # At least one input is unsized (e.g. a generator): no check, no total.
        total = None
    else:
        if n_vectors != n_rows:
            raise ValueError(
                f"Got {n_vectors} sparse vectors for {n_rows} dataset rows"
            )
        total = n_rows

    return [
        models.PointStruct(
            id=int(element["_id"]),
            vector={col_name: convert_sparse_vector(sv)},
            payload={
                "text": element["text"],
                "title": element["title"],
                "id": element["_id"],
            },
        )
        # total lets tqdm show real progress instead of a bare counter.
        for sv, element in tqdm(zip(reweighted_sparse_vectors, ds), total=total)
    ]


# Rebuild the collection from scratch on every run, then upload all points.
reset_collection(client, collection_name)
points = make_points(id_reweighted_sparse_vectors, ds)

upload_batch_size = 100
n_batches = -(-len(points) // upload_batch_size)  # ceiling division
failed_batches = []
for batch_no, batch in enumerate(
    tqdm(batched(points, upload_batch_size), total=n_batches)
):
    try:
        client.upload_points(collection_name=collection_name, points=batch)
    except Exception as e:
        # Best effort: keep uploading the remaining batches, but record the
        # failure so it is not lost in the output.
        print(f"batch {batch_no} failed: {e}")
        failed_batches.append(batch_no)

if failed_batches:
    print(
        f"WARNING: {len(failed_batches)} of {n_batches} batches failed to upload: "
        f"{failed_batches}"
    )

0it [00:00, ?it/s]

0it [00:00, ?it/s]

## Queries

In [49]:
# Relevance judgements for the test split; query ids are normalised to int.
test = pd.read_csv(f"../data/{canonical_dataset_name}/qrels/test.tsv", sep="\t")
test["query-id"] = test["query-id"].astype(int)

# One JSON object per line; blank lines are skipped.
with open(f"../data/{canonical_dataset_name}/queries.jsonl", encoding="utf-8") as f:
    queries = [json.loads(line) for line in f if line.strip()]

# Only keep the test set queries. The id set is built once so each
# membership check is a hash lookup rather than a rebuilt list scan.
test_query_ids = set(test["query-id"].tolist())
queries = [q for q in queries if int(q["_id"]) in test_query_ids]
len(queries)

300

In [50]:
# Load the encoder from the same model_name the tokenizer was loaded from, so
# the two cannot drift apart.
# Setting use_fp16 to True speeds up computation with a slight performance degradation
model = BGEM3FlagModel(model_name, use_fp16=True)

def get_sparse_vector(batch: List[str]):
    """Encode ``batch`` with the notebook-level ``model`` and return its sparse output.

    Dense and ColBERT outputs are switched off; only the value stored under
    the ``"lexical_weights"`` key of the encoder output is returned.
    """
    encode_options = {
        "return_dense": False,
        "return_sparse": True,
        "return_colbert_vecs": False,
    }
    encoded = model.encode(batch, **encode_options)
    return encoded["lexical_weights"]

Fetching 30 files:   0%|          | 0/30 [00:00<?, ?it/s]

In [51]:
%%time
# Sparse-encode the text of every retained query in one call.
raw_query_vectors = get_sparse_vector([q["text"] for q in queries])

Inference Embeddings: 100%|██████████| 25/25 [00:02<00:00,  9.85it/s]

CPU times: user 1.61 s, sys: 290 ms, total: 1.9 s
Wall time: 2.54 s





In [52]:
# Position of the query used for spot checks in this notebook.
idx = 50
raw_query_vectors[idx]

defaultdict(int,
            {'3980': 0.1531,
             '25388': 0.1818,
             '7': 0.03915,
             '57913': 0.147,
             '34080': 0.0839,
             '3038': 0.06354,
             '112': 0.002756,
             '8': 0.09863,
             '180220': 0.1117,
             '1409': 0.04636,
             '11044': 0.131,
             '199334': 0.2042,
             '23417': 0.2339,
             '40715': 0.2086,
             '450': 0.03105,
             '351': 0.0846,
             '3284': 0.1382,
             '10484': 0.08466})

In [53]:
def make_float_query_vectors(raw_query_vectors: List[Dict[str, float]]) -> List[Dict[str, List[float]]]:
    """Split each mapping into parallel ``"tokens"`` / ``"weights"`` lists.

    ``"tokens"`` holds the mapping's keys in iteration order; ``"weights"``
    holds the matching values, each converted with ``float()``.
    """
    return [
        {
            "tokens": list(sv.keys()),
            "weights": [float(weight) for weight in sv.values()],
        }
        for sv in raw_query_vectors
    ]

# Reshape the encoder output into the "tokens"/"weights" layout that
# retokenize_sparse_vector reads.
float_query_vectors = make_float_query_vectors(raw_query_vectors)

In [59]:
# Retokenize every query vector, pairing it with the query text by position.
query_texts = [q["text"] for q in queries]
reweighted_query_vectors = [
    retokenize_sparse_vector(
        text=text, source_sparse_vector=source_query_vector, tokenizer=tokenizer
    )
    for source_query_vector, text in tqdm(zip(float_query_vectors, query_texts))
]

0it [00:00, ?it/s]

In [60]:
# 10th / 50th / 90th percentile of the number of entries per retokenized query vector
np.percentile(list(map(len, reweighted_query_vectors)), [10, 50, 90])

array([ 3.,  6., 12.])

In [62]:
# Inspect the first retokenized query vector
reweighted_query_vectors[0]

{'dimension': 1.60583970427897,
 'materi': 1.711582750511634,
 'tive': 1.8413300733527267}

In [63]:
# Re-key the query vectors with the integer ids of the document vocabulary.
# Words missing from that vocabulary are printed and left out.
id_reweighted_query_tokens = []
for qv in tqdm(reweighted_query_vectors):
    known = {}
    for word, weight in qv.items():
        if word not in vocab:
            print(word)
            continue
        known[vocab[word]] = weight
    id_reweighted_query_tokens.append(known)

  0%|          | 0/300 [00:00<?, ?it/s]

/2000
ordin
▁difficil
▁surviv
▁surviv
▁year
▁perform
mmel
uer
▁2.
▁system
▁diseas
stes
tiv


In [64]:
# Wrap each id-keyed query vector in a SparseVector.
# The dict views are converted to lists explicitly rather than relying on the
# model to coerce dict_keys / dict_values into its list fields.
qdrant_query_vectors = [
    models.SparseVector(
        indices=list(qv.keys()),
        values=list(qv.values()),
    )
    for qv in id_reweighted_query_tokens
]

In [65]:
# Inspect the SparseVector built for the spot-check query
qdrant_query_vectors[idx]

SparseVector(indices=[2994, 4544, 2118, 2841, 4280, 1177, 1339, 243, 4701, 4436, 3322], values=[1.6690328878382839, 1.8074913885123627, 1.1815423365369104, 1.30626862811058, 1.379653324012354, 1.4625501566243453, 1.557535616122145, 1.1279511510106308, 1.2406559308186282, 1.0791091549313259, 1.0343905775447468])

In [66]:
limit = 10


def _search_or_none(query_vector):
    """Search the collection with one sparse query vector.

    Returns the search result, or None after printing the exception and the
    offending vector if the search raises.
    """
    try:
        return client.search(
            collection_name=collection_name,
            query_vector=models.NamedSparseVector(name=col_name, vector=query_vector),
            with_payload=True,
            limit=limit,
        )
    except Exception as e:
        print(e)
        print(query_vector)
        return None


# One entry per query, in query order; None marks a failed search.
results = [_search_or_none(qv) for qv in tqdm(qdrant_query_vectors)]

  0%|          | 0/300 [00:00<?, ?it/s]

In [67]:
# Flatten the per-query search results into parallel columns.
query_ids, doc_ids, ranks, scores = [], [], [], []
skipped_queries = []
for query, result in zip(queries, results):
    query_id = query["_id"]
    if result is None:
        # The search for this query failed and was recorded as None;
        # iterating it would raise TypeError, so the query is left out.
        skipped_queries.append(query_id)
        continue
    result_ids = [str(r.id) for r in result]
    result_scores = [r.score for r in result]
    result_ranks = list(range(len(result)))  # ranks start at 0
    query_ids.extend(len(result) * [query_id])
    doc_ids.extend(result_ids)
    ranks.extend(result_ranks)
    scores.extend(result_scores)

if skipped_queries:
    print(f"WARNING: no results for {len(skipped_queries)} queries: {skipped_queries}")

run = {
    "query": [int(q) for q in query_ids],
    "q0": len(query_ids) * ["q0"],
    "docid": doc_ids,
    "rank": ranks,
    "score": scores,
    "system": len(query_ids) * ["bge-m3"],
}

with open("bge-m3-sentence-piece-pair-rescore.run.json", "w") as f:
    json.dump(run, f, indent=2)