In [5]:
import json

import pandas as pd
from datasets import load_dataset
from evaluate import load
from dotenv import load_dotenv

# Read environment variables from a local .env file before any Hub access.
load_dotenv()

# Short dataset name; also the directory under ../data that holds the qrels.
canonical_dataset_name = "scifact"
# Hub dataset (loaded as nirantk/<dataset_name>) with the corpus and its BGE-M3 sparse vectors.
dataset_name = f"{canonical_dataset_name}-bge-m3-sparse-vectors"

In [6]:
# Corpus split from the Hub: `_id`, `title`, `text` and a `bge_m3_sparse_vector`
# column per document. Left as the last expression so Jupyter renders it
# (rich display) instead of printing to stdout.
ds = load_dataset(f"nirantk/{dataset_name}", split="corpus")
ds

Dataset({
    features: ['_id', 'title', 'text', 'bge_m3_sparse_vector'],
    num_rows: 5183
})


In [7]:
# Load the `trec_eval` metric from Hugging Face `evaluate`; its
# `compute(predictions=[run], references=[qrel])` scores the runs below.
trec_eval = load("trec_eval")

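## Example qrels and runs

A minimal qrel (relevance judgments) and run (ranked results) in the column-oriented dict format passed to `trec_eval.compute`: one list per field, one entry per row.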

In [8]:
# Toy inputs in the column-oriented dict format `trec_eval.compute` takes:
# one list per field, one entry per row. They are named `example_*` so they
# cannot be confused with the real `qrel` / `run` / `results` built in later
# cells (e.g. a failed real computation leaving toy metrics in `results`).
example_qrel = {
    "query": [0],
    "q0": ["q0"],
    "docid": ["doc_1"],  # the only judged (relevant) document for query 0
    "rel": [2],
}
example_run = {
    "query": [0, 0],
    "q0": ["q0", "q0"],
    "docid": ["doc_2", "doc_1"],  # the relevant document has the lower score, so it ranks second
    "rank": [0, 1],
    "score": [1.5, 1.2],
    "system": ["test", "test"],
}
example_results = trec_eval.compute(predictions=[example_run], references=[example_qrel])
example_results

  selection = selection[~selection["rel"].isnull()].groupby("query").first().copy()


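## Load reference qrels from `qrels/test.tsv`

The SciFact test-split relevance judgments are converted into the same qrel dict format used in the example above.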

In [9]:
# Tab-separated qrels with `query-id`, `corpus-id` and `score` columns.
# `corpus-id` is read as str so document IDs are compared verbatim with the
# run's docids; parsing them as integers would silently drop leading zeros.
df = pd.read_csv(
    f"../data/{canonical_dataset_name}/qrels/test.tsv",
    sep="\t",
    dtype={"corpus-id": str},
)

# Convert to the column-oriented qrel dict passed to `trec_eval.compute(references=...)`.
# `.tolist()` yields plain Python ints/strs, matching the types `validate_data` expects.
qrel = {
    "query": df["query-id"].astype(int).tolist(),
    "q0": ["q0"] * len(df),
    "docid": df["corpus-id"].tolist(),
    "rel": df["score"].astype(int).tolist(),
}

# Last expression, so the preview is actually displayed.
df.head()

In [10]:
def validate_data(predictions, references):
    """Check that a run dict and a qrel dict have the expected fields and types.

    Both arguments are column-oriented dicts mapping a field name to a sequence
    of values, one entry per row.

    Args:
        predictions: Run dict with ``query`` (int), ``q0`` (str), ``docid`` (str),
            ``rank`` (int), ``score`` (float) and ``system`` (str) columns.
        references: Qrel dict with ``query`` (int), ``q0`` (str), ``docid`` (str)
            and ``rel`` (int) columns.

    Returns:
        A ``(pred_validation, ref_validation)`` tuple of strings, each either
        ``"Valid"`` or a message describing the first problem found.

    Note:
        Types are checked with ``isinstance``, so ``bool`` values pass as ``int``
        while ``int`` values do *not* pass as ``float``.
    """
    # Expected fields and element types for predictions (run) and references (qrels).
    expected_pred_keys = {
        'query': int, 'q0': str, 'docid': str, 'rank': int, 'score': float, 'system': str
    }
    expected_ref_keys = {
        'query': int, 'q0': str, 'docid': str, 'rel': int
    }

    def check_record(record, expected_keys):
        """Return "Valid" or a description of the first problem found in ``record``."""
        for key, expected_type in expected_keys.items():
            if key not in record:
                return f"Missing key: {key}"
            values = record[key]
            # A field must be a column of values: a scalar would raise TypeError when
            # iterated, and a bare string would be checked character by character.
            if isinstance(values, (str, bytes)) or not hasattr(values, "__len__"):
                return f"Incorrect container for key {key}, expected a list, got {type(values)}"
            # Report the type of the first offending item, not of item 0
            # (which may itself be perfectly valid).
            for index, item in enumerate(values):
                if not isinstance(item, expected_type):
                    return (
                        f"Incorrect type for key {key}, expected {expected_type}, "
                        f"got {type(item)} at index {index}"
                    )

        # Every field (including any extra ones) must have the same number of rows.
        length = len(record[next(iter(record))])
        if not all(len(value) == length for value in record.values()):
            return "Inconsistent lengths among fields"

        return "Valid"

    pred_validation = check_record(predictions, expected_pred_keys)
    ref_validation = check_record(references, expected_ref_keys)

    return pred_validation, ref_validation

In [11]:
# Score the run stored in bge-m3-lexical.run.json against the test qrels.
# JSON is UTF-8 by spec; state it rather than relying on the locale default.
with open("bge-m3-lexical.run.json", encoding="utf-8") as f:
    run = json.load(f)

# Non-blocking sanity check (replaces the previously commented-out call):
# report, but do not stop on, field/type mismatches found by `validate_data`.
run_status, qrel_status = validate_data(run, qrel)
if (run_status, qrel_status) != ("Valid", "Valid"):
    print(f"validate_data: run={run_status!r}, qrels={qrel_status!r}")

results = trec_eval.compute(predictions=[run], references=[qrel])

  selection = selection[~selection["rel"].isnull()].groupby("query").first().copy()


In [9]:
# Display the trec_eval metrics dict held in `results`.
# NOTE(review): this cell's execution count (9) is lower than that of the cell
# computing `results` (11), so the saved output may come from an earlier kernel
# session; re-run top to bottom to confirm.
# In the saved output the run returns 10 docs per query (num_ret 3000 / num_q 300),
# so NDCG@k is identical and P@k merely shrinks for every k > 10.
# NOTE(review): runid reads 'splade' (presumably the run file's `system` field)
# although this is the bge-m3-lexical run; bpref is 0.0, possibly because the
# qrels carry no judged non-relevant documents. Confirm both.
results

{'runid': 'splade',
 'num_ret': 3000,
 'num_rel': 339,
 'num_rel_ret': 274,
 'num_q': 300,
 'map': 0.6478440476190476,
 'gm_map': 0.09698013276165883,
 'bpref': 0.0,
 'Rprec': 0.5674444444444445,
 'recip_rank': 0.657978835978836,
 'P@5': 0.16266666666666665,
 'P@10': 0.09133333333333332,
 'P@15': 0.06088888888888889,
 'P@20': 0.04566666666666666,
 'P@30': 0.030444444444444444,
 'P@100': 0.009133333333333334,
 'P@200': 0.004566666666666667,
 'P@500': 0.0018266666666666668,
 'P@1000': 0.0009133333333333334,
 'NDCG@5': 0.6640296439959622,
 'NDCG@10': 0.6917459236155028,
 'NDCG@15': 0.6917459236155028,
 'NDCG@20': 0.6917459236155028,
 'NDCG@30': 0.6917459236155028,
 'NDCG@100': 0.6917459236155028,
 'NDCG@200': 0.6917459236155028,
 'NDCG@500': 0.6917459236155028,
 'NDCG@1000': 0.6917459236155028}

In [14]:
# Score the run stored in lexical-retokenize-rescore.run.json against the test qrels.
# JSON is UTF-8 by spec; state it rather than relying on the locale default.
with open("lexical-retokenize-rescore.run.json", encoding="utf-8") as f:
    run = json.load(f)

# Non-blocking sanity check (replaces the previously commented-out call):
# report, but do not stop on, field/type mismatches found by `validate_data`.
run_status, qrel_status = validate_data(run, qrel)
if (run_status, qrel_status) != ("Valid", "Valid"):
    print(f"validate_data: run={run_status!r}, qrels={qrel_status!r}")

results = trec_eval.compute(predictions=[run], references=[qrel])

  selection = selection[~selection["rel"].isnull()].groupby("query").first().copy()


In [15]:
# Display the trec_eval metrics dict for lexical-retokenize-rescore.run.json.
# In the saved output the run returns 10 docs per query (num_ret 3000 / num_q 300),
# so NDCG@k is identical and P@k merely shrinks for every k > 10.
# NOTE(review): runid again reads 'splade' (presumably the run file's `system`
# field) and bpref is 0.0; confirm the run label and whether the qrels include
# judged non-relevant documents.
results

{'runid': 'splade',
 'num_ret': 3000,
 'num_rel': 339,
 'num_rel_ret': 203,
 'num_q': 300,
 'map': 0.3955436507936508,
 'gm_map': 0.009728080496644581,
 'bpref': 0.0,
 'Rprec': 0.30149999999999993,
 'recip_rank': 0.40745105820105815,
 'P@5': 0.11133333333333335,
 'P@10': 0.06766666666666665,
 'P@15': 0.04511111111111111,
 'P@20': 0.033833333333333326,
 'P@30': 0.022555555555555554,
 'P@100': 0.006766666666666667,
 'P@200': 0.0033833333333333337,
 'P@500': 0.0013533333333333333,
 'P@1000': 0.0006766666666666667,
 'NDCG@5': 0.41786554234196616,
 'NDCG@10': 0.4518226303875962,
 'NDCG@15': 0.4518226303875962,
 'NDCG@20': 0.4518226303875962,
 'NDCG@30': 0.4518226303875962,
 'NDCG@100': 0.4518226303875962,
 'NDCG@200': 0.4518226303875962,
 'NDCG@500': 0.4518226303875962,
 'NDCG@1000': 0.4518226303875962}