In [28]:
%load_ext autoreload
%autoreload 2

In [29]:
import getpass
import json
import os
from typing import Dict, Iterable, List

import pandas as pd
from datasets import Dataset, load_dataset
from dotenv import load_dotenv
from FlagEmbedding import BGEM3FlagModel
from transformers import AutoTokenizer
from qdrant_client import QdrantClient, models
from qdrant_sparse_tools import convert_sparse_vector
from tokenizers import Tokenizer
from tqdm.auto import tqdm

In [11]:
load_dotenv()  # loads QDRANT_URL / QDRANT_API_KEY (read in the client cell) from a local .env
# When True, the collection named below is dropped so the upload cell re-creates and refills it.
FORCE_DELETE = False
# Local dataset folder name, used for the ../data/<name>/ qrels and queries paths.
canonical_dataset_name = "scifact"
# Hugging Face dataset (under nirantk/) holding precomputed BGE-M3 sparse vectors.
dataset_name = "scifact-bge-m3-sparse-vectors"
# Column holding the sparse vector; also used as the Qdrant named sparse vector.
col_name = "bge_m3_sparse_vector"
collection_name = f"{dataset_name}-{col_name}"

In [3]:
# Connection settings come from environment variables (see load_dotenv() in the config cell).
client = QdrantClient(url=os.getenv("QDRANT_URL"), api_key=os.getenv("QDRANT_API_KEY"))

In [4]:
def is_empty(client: QdrantClient, collection_name: str) -> bool:
    """Return True when ``collection_name`` currently holds zero points."""
    info = client.get_collection(collection_name)
    return info.points_count == 0

In [5]:
# Optionally drop the collection; the create/upload cell below then re-creates and refills it.
if FORCE_DELETE:
    client.delete_collection(collection_name)

True

In [6]:
# Corpus split; the sparse-vector column is stored as a JSON string (parsed below).
ds_raw = load_dataset(f"nirantk/{dataset_name}", split="corpus")

In [7]:
# Schema and row count of the corpus split.
ds_raw

Dataset({
    features: ['_id', 'title', 'text', 'bge_m3_sparse_vector'],
    num_rows: 5183
})

In [8]:
# Plain list of row dicts; the next cell parses the vector column in place.
ds = ds_raw.to_list()

In [9]:
# Parse the JSON-encoded sparse vectors into {token_id_str: weight} dicts.
# The isinstance guard keeps this cell re-runnable: a second run would
# otherwise call json.loads on an already-parsed dict and raise TypeError.
for element in tqdm(ds):
    value = element[col_name]
    if isinstance(value, str):
        element[col_name] = json.loads(value)

raw_vectors = [element[col_name] for element in ds]
raw_vectors[0]

  0%|          | 0/5183 [00:00<?, ?it/s]

{'39176': 0.1639404296875,
 '21094': 0.033599853515625,
 '159958': 0.1788330078125,
 '119856': 0.1939697265625,
 '35011': 0.1964111328125,
 '26866': 0.2216796875,
 '70': 0.011077880859375,
 '168698': 0.161865234375,
 '14135': 0.04254150390625,
 '78574': 0.1883544921875,
 '831': 0.051239013671875,
 '52490': 0.16845703125,
 '8231': 0.067626953125,
 '70760': 0.1358642578125,
 '34754': 0.1903076171875,
 '136': 0.01042938232421875,
 '16750': 0.024810791015625,
 '23': 0.01120758056640625,
 '123309': 0.1346435546875,
 '164462': 0.1981201171875,
 '13315': 0.131591796875,
 '44954': 0.168701171875,
 '45755': 0.1553955078125,
 '92105': 0.1864013671875,
 '9': 0.01116943359375,
 '165598': 0.1431884765625,
 '297': 0.010650634765625,
 '214706': 0.0733642578125,
 '3332': 0.016510009765625,
 '191': 0.01358795166015625,
 '7154': 0.00965118408203125,
 '86898': 0.06939697265625,
 '177': 0.0108184814453125,
 '594': 0.03509521484375,
 '16625': 0.197265625,
 '16': 0.0110626220703125,
 '944': 0.052734375,
 '3

In [10]:
def read_data(dataset_name: str):
    """Fetch the ``corpus`` split of ``nirantk/<dataset_name>`` as a list of row dicts.

    The feature schema is printed so the column types can be checked at a glance.
    """
    hf_split = load_dataset(f"nirantk/{dataset_name}", split="corpus")
    print("Columns: ", hf_split.features)
    return hf_split.to_list()

def to_points(ds: Iterable[dict]) -> Iterable[models.PointStruct]:
    """Yield one Qdrant point per corpus row.

    Args:
        ds: Rows with ``_id``, ``title``, ``text`` and a ``col_name`` sparse
            vector. The vector may be the raw JSON string (as returned by
            ``read_data``) or an already-parsed dict.

    Yields:
        ``PointStruct`` keyed by the integer document id, with the sparse
        vector stored under the ``col_name`` named vector.
    """
    for element in tqdm(ds):
        raw_vector = element[col_name]
        # Accept both the JSON-string form and a pre-parsed dict.
        sparse = json.loads(raw_vector) if isinstance(raw_vector, str) else raw_vector
        yield models.PointStruct(
            id=int(element["_id"]),
            vector={col_name: convert_sparse_vector(sparse)},
            payload={
                "text": element["text"],
                "title": element["title"],
                "id": element["_id"],
            },
        )


# Create the sparse-only collection on first run; later runs reuse the existing one.
if not client.collection_exists(collection_name):
    sparse_config = {
        col_name: models.SparseVectorParams(
            index=models.SparseIndexParams(on_disk=False),
        )
    }
    client.create_collection(
        collection_name=collection_name,
        vectors_config={},
        sparse_vectors_config=sparse_config,
    )


def batch_iterator(iterable, batch_size=128):
    """Yield consecutive slices of ``iterable`` holding at most ``batch_size`` items.

    ``iterable`` must support ``len()`` and slicing (list, tuple, range, ...).
    The final slice may be shorter than ``batch_size``.

    Args:
        iterable: A sized, sliceable sequence.
        batch_size: Maximum number of items per slice.

    Yields:
        Successive slices of the input, in order.
    """
    # Slicing past the end is clamped by Python, so no explicit min() is needed.
    yield from (
        iterable[start : start + batch_size]
        for start in range(0, len(iterable), batch_size)
    )


# Run ONCE to upload data, only when the collection is empty.
if is_empty(client, collection_name):
    # Use a separate name so the parsed `ds` from earlier cells is not
    # silently replaced by these raw (JSON-string) rows.
    corpus_rows = read_data(dataset_name)
    print("Uploading data")
    client.upload_points(
        collection_name=collection_name,
        points=to_points(corpus_rows),
    )

range(0, 10)
Columns:  {'_id': Value(dtype='string', id=None), 'title': Value(dtype='string', id=None), 'text': Value(dtype='string', id=None), 'bge_m3_sparse_vector': Value(dtype='string', id=None)}
Uploading data


  0%|          | 0/5183 [00:00<?, ?it/s]

## Queries

In [12]:
# Test-split relevance judgements; query ids are cast to int so they can be compared with int(q["_id"]).
test = pd.read_csv(f"../data/{canonical_dataset_name}/qrels/test.tsv", sep="\t")
test["query-id"] = test["query-id"].astype(int)

In [13]:
# Spot-check the relevance judgements for a single query id.
test[test["query-id"] == 873]

Unnamed: 0,query-id,corpus-id,score
213,873,1180972,1
214,873,19307912,1
215,873,27393799,1
216,873,29025270,1
217,873,3315558,1


In [14]:
with open(f"../data/{canonical_dataset_name}/queries.jsonl", encoding="utf-8") as f:
    queries = [json.loads(line) for line in f if line.strip()]

# Only keep the test set queries. Build the id set once so each membership
# check is O(1) instead of rebuilding a list for every query.
test_query_ids = set(test["query-id"])
queries = [q for q in queries if int(q["_id"]) in test_query_ids]
len(queries)

300

In [15]:
# Peek at one filtered test query.
queries[0]

{'_id': '1',
 'text': '0-dimensional biomaterials show inductive properties.',
 'metadata': {}}

## Create query vectors

In [16]:
# use_fp16=True speeds up encoding at the cost of a slight quality drop.
model = BGEM3FlagModel("BAAI/bge-m3", use_fp16=True)


def get_sparse_vector(batch: List[str]):
    """Encode ``batch`` with BGE-M3 and return only its sparse (lexical) weights.

    Dense and ColBERT outputs are disabled; only the ``lexical_weights`` entry is returned.
    """
    encoded = model.encode(
        batch,
        return_dense=False,
        return_sparse=True,
        return_colbert_vecs=False,
    )
    return encoded["lexical_weights"]

Fetching 30 files:   0%|          | 0/30 [00:00<?, ?it/s]

In [17]:
# BGE-M3 lexical weights for every test query (one mapping per query, same order as `queries`).
query_vectors = get_sparse_vector([q["text"] for q in queries])

Inference Embeddings: 100%|██████████| 25/25 [00:04<00:00,  5.65it/s]


In [18]:
# Wrap each {token_id: weight} mapping as a Qdrant SparseVector.
# NOTE: this rebinds `query_vectors`, so re-running the cell alone will fail.
query_vectors = [
    models.SparseVector(indices=list(weights.keys()), values=list(weights.values()))
    for weights in query_vectors
]

In [19]:
# Spot-check one converted query vector.
query_vectors[2]

SparseVector(indices=[106, 139217, 23, 17274, 765, 1563, 33176, 15853, 683, 40523, 2481, 5], values=[0.1339111328125, 0.2822265625, 0.13134765625, 0.25927734375, 0.11083984375, 0.1455078125, 0.19921875, 0.2161865234375, 0.2015380859375, 0.2188720703125, 0.11700439453125, 0.0173797607421875])

In [20]:
limit = 10
results = []
for qv in tqdm(query_vectors):
    try:
        named_query = models.NamedSparseVector(name=col_name, vector=qv)
        hits = client.search(
            collection_name=collection_name,
            query_vector=named_query,
            with_payload=True,
            limit=limit,
        )
    except Exception as e:
        # Keep `results` aligned with `queries`: a failed search becomes None.
        print(e)
        print(qv)
        hits = None
    results.append(hits)

  0%|          | 0/300 [00:00<?, ?it/s]

In [21]:
# Flatten search hits into a TREC-like run: one row per retrieved document, 0-based ranks.
query_ids, doc_ids, ranks, scores = [], [], [], []
for query, result in zip(queries, results):
    if result is None:
        # The search for this query failed in the previous cell; skip it
        # rather than crashing on iteration over None.
        continue
    query_id = query["_id"]
    for rank, hit in enumerate(result):
        query_ids.append(query_id)
        doc_ids.append(str(hit.id))
        ranks.append(rank)
        scores.append(hit.score)

run = {
    "query": [int(q) for q in query_ids],
    "q0": len(query_ids) * ["q0"],
    "docid": doc_ids,
    "rank": ranks,
    "score": scores,
    # NOTE(review): the "splade" system label looks like a leftover; confirm before renaming.
    "system": len(query_ids) * ["splade"],
}

with open("bge-m3-lexical.run.json", "w") as f:
    json.dump(run, f, indent=2)

## Retokenize, re-index, and store the run

In [22]:
import json
import os
from collections import Counter
from typing import Dict, Iterable

import numpy as np
import pandas as pd
from datasets import Dataset, load_dataset
from dotenv import load_dotenv
from qdrant_client import QdrantClient, models
from qdrant_sparse_tools import convert_sparse_vector
from remap_tokens import (
    aggregate_weights,
    calc_tf,
    filter_list_tokens,
    filter_pair_tokens,
    reconstruct_bpe,
    rescore_vector,
    snowball_tokenize,
    stem_list_tokens,
    stem_pair_tokens,
)
from tokenizers import Tokenizer
from tqdm.auto import tqdm

In [23]:
# Tokenizer from the BAAI/bge-m3 checkpoint, plus an id -> token-string lookup over its vocab.
tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-m3")
reverse_vocab = {v: k for k, v in tokenizer.get_vocab().items()}

In [24]:
# Turn each raw {token_id_str: weight} dict into parallel "weights" / "tokens" lists.
corpus_sparse_vectors = [
    {
        "weights": list(raw.values()),
        "tokens": [reverse_vocab[int(token_id)] for token_id in raw],
    }
    for raw in raw_vectors
]

In [35]:
def retokenize_sparse_vector(text: str, source_sparse_vector: Dict[str, list], tokenizer: Tokenizer):
    """Re-express a BGE-M3 sparse vector over filtered, stemmed whole words.

    Pipeline: tokenize ``text``, merge sub-word pieces (``reconstruct_bpe``),
    filter and stem the merged words, attach a weight to each word
    (``aggregate_weights``), then rescore with ``rescore_vector`` / ``calc_tf``.

    Args:
        text: Raw text the source vector was computed from.
        source_sparse_vector: ``{"tokens": [...], "weights": [...]}``, parallel
            lists of token strings and their sparse weights.
        tokenizer: Tokenizer whose ``encode`` returns token ids that the
            module-level ``reverse_vocab`` maps back to token strings (in this
            notebook, the transformers tokenizer for BAAI/bge-m3).

    Returns:
        Dict mapping each stemmed word to its rescored weight.
    """
    token_ids = tokenizer.encode(text)
    sequential_tokens = [reverse_vocab[t] for t in token_ids]

    # source_sparse_vector["weights"] is ordered like the sparse vector's
    # (unique) tokens, not like the sequence positions produced by `enumerate`
    # below, which include special tokens and repeats. Re-index the weights by
    # position; tokens absent from the sparse vector (e.g. special tokens) get
    # 0.0. Assumes aggregate_weights looks weights up by those positions —
    # TODO confirm against remap_tokens.
    weight_by_token = dict(
        zip(source_sparse_vector["tokens"], source_sparse_vector["weights"])
    )
    positional_weights = [weight_by_token.get(tok, 0.0) for tok in sequential_tokens]

    reconstructed = reconstruct_bpe(enumerate(sequential_tokens))
    filtered_reconstructed = filter_pair_tokens(reconstructed)
    stemmed_reconstructed = stem_pair_tokens(filtered_reconstructed)
    weighted_reconstructed = aggregate_weights(stemmed_reconstructed, positional_weights)

    # Per word: the strongest weight seen and how many times it occurs.
    total_tokens = len(weighted_reconstructed)
    max_token_weight: Dict[str, float] = {}
    num_tokens: Dict[str, int] = {}
    for word, score in weighted_reconstructed:
        max_token_weight[word] = max(max_token_weight.get(word, 0), score)
        num_tokens[word] = num_tokens.get(word, 0) + 1

    token_score = rescore_vector(max_token_weight)
    reweighted_sparse_vector = {}
    for word, count in num_tokens.items():
        # Add extra occurrences on top of the rescored weight, then pass through
        # calc_tf together with the word count of this text.
        tf = token_score[word] + count - 1
        reweighted_sparse_vector[word] = calc_tf(tf, total_tokens)
    return reweighted_sparse_vector


# One reweighted vector per corpus row; vectors and texts are paired by position.
assert len(corpus_sparse_vectors) == len(ds_raw), "sparse vectors and texts are misaligned"
reweighted_sparse_vectors = []
for source_sparse_vector, text in tqdm(
    zip(corpus_sparse_vectors, ds_raw["text"]), total=len(corpus_sparse_vectors)
):
    reweighted_sparse_vectors.append(
        retokenize_sparse_vector(
            source_sparse_vector=source_sparse_vector, text=text, tokenizer=tokenizer
        )
    )
    # print(len(reweighted_sparse_vector))

  0%|          | 0/5183 [00:00<?, ?it/s]

In [36]:
# Find length (number of distinct words, i.e. dict keys) of each reweighted sparse vector
vector_lengths = [len(sv) for sv in reweighted_sparse_vectors]

# 10th / 50th / 90th percentile of those lengths
np.percentile(vector_lengths, [10, 50, 90])

array([107., 158., 221.])

In [37]:
def reset_collection(client: QdrantClient, collection_name: str):
    """Drop ``collection_name`` if it exists, then recreate it empty.

    The new collection has no dense vectors and a single sparse vector named
    ``col_name`` whose index is not stored on disk (``on_disk=False``).
    """
    if client.collection_exists(collection_name):
        client.delete_collection(collection_name)
    sparse_params = models.SparseVectorParams(
        index=models.SparseIndexParams(on_disk=False)
    )
    client.create_collection(
        collection_name=collection_name,
        vectors_config={},
        sparse_vectors_config={col_name: sparse_params},
    )

In [39]:
# Make a vocab of every word that appears in the reweighted corpus vectors.
corpus_words = set()
for sv in reweighted_sparse_vectors:
    corpus_words.update(sv.keys())

# Assign ids in sorted order. Iteration order of a set of strings depends on the
# per-process hash seed, so unsorted ids would change between kernel sessions.
vocab = {word: i for i, word in enumerate(sorted(corpus_words))}
invert_vocab = {i: word for word, i in vocab.items()}

In [40]:
# Swap each word key for its integer id from `vocab`.
id_reweighted_sparse_vectors = [
    {vocab[word]: weight for word, weight in sv.items()}
    for sv in tqdm(reweighted_sparse_vectors)
]

  0%|          | 0/5183 [00:00<?, ?it/s]

In [41]:
def batched(iterable: Iterable, n: int = 1) -> Iterable:
    """Yield consecutive slices of length ``n``; the last slice may be shorter.

    ``iterable`` must support ``len()`` and slicing.
    """
    offsets = range(0, len(iterable), n)
    for offset in offsets:
        yield iterable[offset : offset + n]

In [42]:

def make_points(
    reweighted_sparse_vectors: List[Dict[int, float]], ds: List[dict]
) -> List[models.PointStruct]:
    """Pair each id-keyed sparse vector with its corpus row and build Qdrant points.

    Args:
        reweighted_sparse_vectors: One ``{term_id: weight}`` dict per corpus row,
            in the same order as ``ds``.
        ds: Corpus rows with ``_id``, ``title`` and ``text``.

    Returns:
        A list of ``PointStruct``, one per row.

    Raises:
        ValueError: if the inputs differ in length. ``zip`` would otherwise
            silently drop the tail and leave documents unindexed.
    """
    if len(reweighted_sparse_vectors) != len(ds):
        raise ValueError(
            f"{len(reweighted_sparse_vectors)} vectors vs {len(ds)} corpus rows"
        )
    return [
        models.PointStruct(
            id=int(element["_id"]),
            vector={col_name: convert_sparse_vector(sv)},
            payload={
                "text": element["text"],
                "title": element["title"],
                "id": element["_id"],
            },
        )
        # total= gives tqdm a real progress bar (zip has no len()).
        for sv, element in tqdm(zip(reweighted_sparse_vectors, ds), total=len(ds))
    ]

# Build the name from its base parts so re-running this cell does not keep
# appending "-retok" to `collection_name`.
collection_name = f"{dataset_name}-{col_name}-retok"
# The collection is dropped and recreated, so every run re-uploads from scratch.
reset_collection(client, collection_name)
points = make_points(id_reweighted_sparse_vectors, ds)

UPLOAD_BATCH_SIZE = 100
n_batches = (len(points) + UPLOAD_BATCH_SIZE - 1) // UPLOAD_BATCH_SIZE
failed_batches = []
for batch_no, batch in enumerate(tqdm(batched(points, UPLOAD_BATCH_SIZE), total=n_batches)):
    try:
        client.upload_points(collection_name=collection_name, points=batch)
    except Exception as e:
        # Best effort: keep going, but record failures so a partial index is visible.
        print(f"Batch {batch_no} failed: {e}")
        failed_batches.append(batch_no)

if failed_batches:
    print(f"{len(failed_batches)} of {n_batches} batch(es) failed to upload: {failed_batches}")

0it [00:00, ?it/s]

0it [00:00, ?it/s]

In [47]:
# Raw BGE-M3 weights of the first query (parallel to its `.indices`).
query_vectors[0].values

[0.2288818359375,
 0.046051025390625,
 0.2142333984375,
 0.17333984375,
 0.271728515625,
 0.10894775390625,
 0.169677734375,
 0.264404296875,
 0.1953125,
 0.209716796875,
 0.060150146484375]

In [50]:
# Re-express each query SparseVector in the {"weights": [...], "tokens": [...]}
# layout that retokenize_sparse_vector expects.
wv_keyed_query_vectors = [
    {
        "weights": qv.values,
        "tokens": [reverse_vocab[token_id] for token_id in qv.indices],
    }
    for qv in query_vectors
]

In [52]:
# First query in the {"weights", "tokens"} layout.
wv_keyed_query_vectors[0]

{'weights': [0.2288818359375,
  0.046051025390625,
  0.2142333984375,
  0.17333984375,
  0.271728515625,
  0.10894775390625,
  0.169677734375,
  0.264404296875,
  0.1953125,
  0.209716796875,
  0.060150146484375],
 'tokens': ['▁0',
  '-',
  'dimensional',
  '▁bio',
  'material',
  's',
  '▁show',
  '▁induc',
  'tive',
  '▁properties',
  '.']}

In [53]:
# Retokenize every query with the same pipeline used for the corpus.
query_texts = [q["text"] for q in queries]
assert len(wv_keyed_query_vectors) == len(query_texts), "query vectors and texts are misaligned"
reweighted_query_tokens = [
    retokenize_sparse_vector(source_sparse_vector=qv, text=text, tokenizer=tokenizer)
    # total= gives tqdm a real progress bar (zip has no len()).
    for qv, text in tqdm(zip(wv_keyed_query_vectors, query_texts), total=len(query_texts))
]

0it [00:00, ?it/s]

In [54]:
# Map each query word to its corpus vocab id. Words never seen in the corpus
# cannot match any document, so they are printed and dropped.
id_reweighted_query_tokens = []
for query_weights in tqdm(reweighted_query_tokens):
    id_keyed = {}
    for word, weight in query_weights.items():
        if word not in vocab:
            print(word)
            continue
        id_keyed[vocab[word]] = weight
    id_reweighted_query_tokens.append(id_keyed)

  0%|          | 0/300 [00:00<?, ?it/s]

/2000
ordin
▁alb
▁galli
▁gab
mmel
uer
▁adept
▁pola
▁mata
stes
▁tira
tiv
▁casu


In [55]:
# Wrap each id-keyed query dict as a Qdrant SparseVector.
qdrant_query_vectors = [
    models.SparseVector(indices=list(id_weights), values=list(id_weights.values()))
    for id_weights in id_reweighted_query_tokens
]

In [56]:
limit = 10
results = []
for qv in tqdm(qdrant_query_vectors):
    try:
        named_query = models.NamedSparseVector(name=col_name, vector=qv)
        hits = client.search(
            collection_name=collection_name,
            query_vector=named_query,
            with_payload=True,
            limit=limit,
        )
    except Exception as e:
        # Keep `results` aligned with `queries`: a failed search becomes None.
        print(e)
        print(qv)
        hits = None
    results.append(hits)

  0%|          | 0/300 [00:00<?, ?it/s]

In [57]:
# Flatten search hits into a TREC-like run: one row per retrieved document, 0-based ranks.
query_ids, doc_ids, ranks, scores = [], [], [], []
for query, result in zip(queries, results):
    if result is None:
        # The search for this query failed in the previous cell; skip it
        # rather than crashing on iteration over None.
        continue
    query_id = query["_id"]
    for rank, hit in enumerate(result):
        query_ids.append(query_id)
        doc_ids.append(str(hit.id))
        ranks.append(rank)
        scores.append(hit.score)

run = {
    "query": [int(q) for q in query_ids],
    "q0": len(query_ids) * ["q0"],
    "docid": doc_ids,
    "rank": ranks,
    "score": scores,
    # NOTE(review): the "splade" system label looks like a leftover; confirm before renaming.
    "system": len(query_ids) * ["splade"],
}

with open("bge-m3-retokenize-rescore.run.json", "w") as f:
    json.dump(run, f, indent=2)