In [2]:
import json

import pandas as pd
from datasets import load_dataset
from evaluate import load
from dotenv import load_dotenv

# Dataset identifier, reused below for the HF Hub dataset id and the local qrels path.
dataset_name = "scifact"

# Populate the process environment from a local .env file, if one is present.
load_dotenv()

In [3]:
# Presumably the corpus with precomputed sparse vectors for `dataset_name`
# (train split). NOTE(review): `ds` is not used by any later visible cell.
ds = load_dataset(
    path=f"nirantk/{dataset_name}-sparse-vectors",
    split="train",
)
print(ds)

Dataset({
    features: ['spalde-v3-lexical', '_id', 'text', 'title'],
    num_rows: 5183
})


In [4]:
# Load the trec_eval metric via the Hugging Face `evaluate` library's `load`.
trec_eval = load("trec_eval")

## Example qrels and run (sanity check of the `trec_eval` input format)

In [5]:
# Toy input: one query whose only relevant doc (doc_1, rel=2) is retrieved second.
# NOTE: `qrel`, `run` and `results` are reassigned by later cells.
qrel = dict(
    query=[0],
    q0=["q0"],
    docid=["doc_1"],
    rel=[2],
)
run = dict(
    query=[0] * 2,
    q0=["q0"] * 2,
    docid=["doc_2", "doc_1"],
    rank=[0, 1],
    score=[1.5, 1.2],
    system=["test"] * 2,
)
results = trec_eval.compute(predictions=[run], references=[qrel])

  selection = selection[~selection["rel"].isnull()].groupby("query").first().copy()


## Load reference qrels from `../data/<dataset>/qrels/test.tsv`

In [6]:
# Read ids as strings: letting pandas infer them as integers and converting back
# with str() would silently rewrite any non-canonical id (e.g. leading zeros).
df = pd.read_csv(
    f"../data/{dataset_name}/qrels/test.tsv",
    sep="\t",
    dtype={"query-id": str, "corpus-id": str},
)

# Convert to qrel: the columnar dict layout passed to trec_eval.compute
# (same keys as the example qrel above).
qrel = {
    "query": [int(q) for q in df["query-id"]],
    "q0": ["q0"] * len(df),
    "docid": df["corpus-id"].tolist(),
    "rel": df["score"].astype(int).tolist(),
}

# Display the raw qrels as the cell's last expression; a mid-cell df.head()
# is evaluated and discarded, so it never rendered.
df.head()

In [7]:
def validate_data(predictions, references):
    """Check that a trec_eval run dict and qrel dict have the expected layout.

    Both arguments are dicts mapping a field name to a list of per-row values,
    i.e. the columnar format passed to ``trec_eval.compute``.

    Args:
        predictions: run dict with keys query, q0, docid, rank, score, system.
        references: qrel dict with keys query, q0, docid, rel.

    Returns:
        A ``(pred_validation, ref_validation)`` tuple. Each entry is ``"Valid"``
        or a message describing the first problem found in that dict.
    """
    # Define expected fields and types for predictions and references
    expected_pred_keys = {
        'query': int, 'q0': str, 'docid': str, 'rank': int, 'score': float, 'system': str
    }
    expected_ref_keys = {
        'query': int, 'q0': str, 'docid': str, 'rel': int
    }

    def check_record(record, expected_keys):
        """Return "Valid" or a description of the first problem in ``record``."""
        for key, expected_type in expected_keys.items():
            if key not in record:
                return f"Missing key: {key}"
            values = record[key]
            # A scalar would crash the iteration below, and a string would be
            # checked character by character; report either as a layout error.
            if isinstance(values, (str, bytes)) or not hasattr(values, "__len__"):
                return f"Expected a list for key {key}, got {type(values)}"
            for index, item in enumerate(values):
                if not isinstance(item, expected_type):
                    # Report the offending element's type, not the first element's.
                    return (
                        f"Incorrect type for key {key}, expected {expected_type}, "
                        f"got {type(item)} at index {index}"
                    )

        # Every column must describe the same number of rows.
        lengths = {len(value) for value in record.values()}
        if len(lengths) > 1:
            return "Inconsistent lengths among fields"

        return "Valid"

    # Validate predictions and references
    pred_validation = check_record(predictions, expected_pred_keys)
    ref_validation = check_record(references, expected_ref_keys)

    return pred_validation, ref_validation

In [8]:
# JSON is UTF-8 by spec; don't depend on the platform default encoding.
with open("splade-v3-lexical.run.json", encoding="utf-8") as f:
    run = json.load(f)

# trec_eval.compute does not necessarily reject a badly shaped run (the
# lexical-retokenize run below scores num_q=3 without raising), so surface
# layout problems explicitly instead of leaving the check commented out.
run_check, qrel_check = validate_data(run, qrel)
if (run_check, qrel_check) != ("Valid", "Valid"):
    print(f"validate_data: run -> {run_check}; qrel -> {qrel_check}")

results = trec_eval.compute(predictions=[run], references=[qrel])

  selection = selection[~selection["rel"].isnull()].groupby("query").first().copy()


In [9]:
# Metrics for the splade-v3-lexical run scored in the previous cell.
results

{'runid': 'splade',
 'num_ret': 3000,
 'num_rel': 339,
 'num_rel_ret': 274,
 'num_q': 300,
 'map': 0.6478440476190476,
 'gm_map': 0.09698013276165883,
 'bpref': 0.0,
 'Rprec': 0.5674444444444445,
 'recip_rank': 0.657978835978836,
 'P@5': 0.16266666666666665,
 'P@10': 0.09133333333333332,
 'P@15': 0.06088888888888889,
 'P@20': 0.04566666666666666,
 'P@30': 0.030444444444444444,
 'P@100': 0.009133333333333334,
 'P@200': 0.004566666666666667,
 'P@500': 0.0018266666666666668,
 'P@1000': 0.0009133333333333334,
 'NDCG@5': 0.6640296439959622,
 'NDCG@10': 0.6917459236155028,
 'NDCG@15': 0.6917459236155028,
 'NDCG@20': 0.6917459236155028,
 'NDCG@30': 0.6917459236155028,
 'NDCG@100': 0.6917459236155028,
 'NDCG@200': 0.6917459236155028,
 'NDCG@500': 0.6917459236155028,
 'NDCG@1000': 0.6917459236155028}

In [10]:
# JSON is UTF-8 by spec; don't depend on the platform default encoding.
with open("lexical-retokenize.run.json", encoding="utf-8") as f:
    run = json.load(f)

# trec_eval.compute does not necessarily reject a badly shaped run, so surface
# layout problems explicitly instead of leaving the check commented out.
run_check, qrel_check = validate_data(run, qrel)
if (run_check, qrel_check) != ("Valid", "Valid"):
    print(f"validate_data: run -> {run_check}; qrel -> {qrel_check}")

results = trec_eval.compute(predictions=[run], references=[qrel])

  selection = selection[~selection["rel"].isnull()].groupby("query").first().copy()


In [11]:
# Metrics for the lexical-retokenize run scored in the previous cell.
# NOTE(review): num_q=3 / num_ret=3 here, versus num_q=300 / num_ret=3000 for
# the splade run, suggests this run file is incomplete or malformed; check it
# with validate_data before trusting these all-zero metrics.
results

{'runid': 'splade',
 'num_ret': 3,
 'num_rel': 339,
 'num_rel_ret': 0,
 'num_q': 3,
 'map': 0.0,
 'gm_map': nan,
 'bpref': 0.0,
 'Rprec': 0.0,
 'recip_rank': 0.0,
 'P@5': 0.0,
 'P@10': 0.0,
 'P@15': 0.0,
 'P@20': 0.0,
 'P@30': 0.0,
 'P@100': 0.0,
 'P@200': 0.0,
 'P@500': 0.0,
 'P@1000': 0.0,
 'NDCG@5': 0.0,
 'NDCG@10': 0.0,
 'NDCG@15': 0.0,
 'NDCG@20': 0.0,
 'NDCG@30': 0.0,
 'NDCG@100': 0.0,
 'NDCG@200': 0.0,
 'NDCG@500': 0.0,
 'NDCG@1000': 0.0}