In [17]:
import json

import pandas as pd
from datasets import load_dataset
from evaluate import load
from dotenv import load_dotenv
from pathlib import Path


# Read settings from a .env file into the process environment (python-dotenv).
load_dotenv()
# Name of the local data directory; used below to build ../../data/<name>/qrels/test.tsv.
canonical_dataset_name = "scifact"
# Dataset repository name; passed to load_dataset below as "nirantk/<name>".
dataset_name = "scifact-bge-m3-sparse-vectors"

In [2]:
# Fetch only the "corpus" split of the dataset repository and show its summary.
corpus_repo = f"nirantk/{dataset_name}"
ds = load_dataset(corpus_repo, split="corpus")
print(ds)

Dataset({
    features: ['_id', 'title', 'text', 'bge_m3_sparse_vector'],
    num_rows: 5183
})


In [3]:
# Load the "trec_eval" metric via evaluate.load; its compute() is called later
# in this notebook with a run dict (predictions) and a qrel dict (references).
trec_eval = load("trec_eval")

## Example Qrels and Runs

In [4]:
# Toy qrel: one judged document (doc_1, relevance 2) for query 0.
qrel = dict(query=[0], q0=["q0"], docid=["doc_1"], rel=[2])

# Toy run: two retrieved documents for query 0, one list entry per retrieved row.
run = dict(
    query=[0, 0],
    q0=["q0"] * 2,
    docid=["doc_2", "doc_1"],
    rank=list(range(2)),
    score=[1.5, 1.2],
    system=["test"] * 2,
)

# Smoke-test the metric on the toy data.
results = trec_eval.compute(references=[qrel], predictions=[run])

  selection = selection[~selection["rel"].isnull()].groupby("query").first().copy()


## Load reference Qrels from test.tsv

In [11]:
# Read the tab-separated test qrels for the configured dataset.
df = pd.read_csv(f"../../data/{canonical_dataset_name}/qrels/test.tsv", sep="\t")
df.head()

## Convert to qrel
# Column-oriented dict: query ids as int, document ids as str, one entry per qrels row.
qrel = {
    "query": list(map(int, df["query-id"].tolist())),
    "q0": ["q0" for _ in range(len(df))],
    "docid": list(map(str, df["corpus-id"].tolist())),
    "rel": df["score"].tolist(),
}

In [12]:
def validate_data(predictions, references):
    """Validate the column-oriented run and qrel dicts used in this notebook.

    Each dict maps a field name to a sequence holding one entry per row.

    Args:
        predictions: Run dict; must contain "query" (int), "q0" (str),
            "docid" (str), "rank" (int), "score" (float) and "system" (str).
        references: Qrel dict; must contain "query" (int), "q0" (str),
            "docid" (str) and "rel" (int).

    Returns:
        A ``(pred_validation, ref_validation)`` tuple. Each element is the
        string "Valid", or a message describing the first problem found in
        the corresponding dict.
    """
    # Expected fields and element types for predictions and references
    expected_pred_keys = {
        "query": int,
        "q0": str,
        "docid": str,
        "rank": int,
        "score": float,
        "system": str,
    }
    expected_ref_keys = {"query": int, "q0": str, "docid": str, "rel": int}

    def check_record(record, expected_keys):
        """Return "Valid" or a message for the first problem found in `record`."""
        for key, expected_type in expected_keys.items():
            if key not in record:
                return f"Missing key: {key}"
            # Report the element that actually failed the check. Reporting
            # record[key][0] is misleading when the first element is fine and
            # a later one is not.
            for index, item in enumerate(record[key]):
                if not isinstance(item, expected_type):
                    return (
                        f"Incorrect type for key {key}, expected {expected_type}, "
                        f"got {type(item)} at index {index}"
                    )

        # Every field (including any extra ones) must have the same number of rows
        length = len(record[next(iter(record))])  # length of the first field
        if not all(len(value) == length for value in record.values()):
            return "Inconsistent lengths among fields"

        return "Valid"

    # Validate predictions and references independently
    pred_validation = check_record(predictions, expected_pred_keys)
    ref_validation = check_record(references, expected_ref_keys)

    return pred_validation, ref_validation

[PosixPath('bge_m3_retoken_reconstruct_sentence_piece_rescore_False.json')]

In [21]:
# Sorted so the report order is stable; glob order depends on the filesystem.
prediction_files = sorted(Path(".").glob("*.json"))
prediction_files

for file in prediction_files:
    with open(file, encoding="utf-8") as f:
        run = json.load(f)

    # validate_data returns a (predictions status, references status) pair
    validation = validate_data(run, qrel)

    print(f"File: {file}")
    # Surface validation problems instead of discarding them; the metric is
    # still computed so that existing results keep being reported.
    for label, status in zip(("predictions", "references"), validation):
        if status != "Valid":
            print(f"  validation warning ({label}): {status}")
    results = trec_eval.compute(predictions=[run], references=[qrel])
    print(results["NDCG@10"])

File: bge_m3_retoken_reconstruct_sentence_piece_rescore_False.json
0.01660019702639439
File: bge_m3_retoken_reconstruct_bpe_rescore_False.json


  selection = selection[~selection["rel"].isnull()].groupby("query").first().copy()
  selection = selection[~selection["rel"].isnull()].groupby("query").first().copy()


0.05675311925775384
File: bge_m3_retoken_reconstruct_bpe_rescore_True.json
0.5042811658445359
File: bge_m3_retoken_reconstruct_sentence_piece_rescore_True.json
0.3720435190652301


  selection = selection[~selection["rel"].isnull()].groupby("query").first().copy()
  selection = selection[~selection["rel"].isnull()].groupby("query").first().copy()
