In [1]:
import numpy as np
import pandas as pd

# Root folder of the synthetic dataset; the three tables are CSV files in it.
DATA_PATH = "../data/nexkey_synthetic_dataset_v1"

# Read the tables in a fixed order: queries, properties, interactions.
queries, properties, interactions = (
    pd.read_csv(csv_path)
    for csv_path in (
        f"{DATA_PATH}/queries.csv",
        f"{DATA_PATH}/properties.csv",
        f"{DATA_PATH}/interactions.csv",
    )
)

print(queries.shape, properties.shape, interactions.shape)

(30000, 16) (15000, 27) (480000, 4)


In [2]:
# Split at the *query* level (not the interaction level): each query id ends up
# in exactly one of the three id sets below, which later cells use to filter
# interactions.
# The legacy RandomState with seed 7 is kept so the split stays reproducible
# and identical to previous runs.
rng = np.random.RandomState(7)

all_qids = queries["query_id"].unique()
rng.shuffle(all_qids)

n = len(all_qids)
# ~80% / ~10% / ~10% (int() truncates, so test absorbs any rounding remainder).
train_end, val_end = int(0.80 * n), int(0.90 * n)
train_qids = set(all_qids[:train_end])
val_qids   = set(all_qids[train_end:val_end])
test_qids  = set(all_qids[val_end:])

# Invariants: the partitions are pairwise disjoint and together cover every
# unique query id, so no query leaks across splits and none is lost.
assert train_qids.isdisjoint(val_qids), "train/val query ids overlap"
assert train_qids.isdisjoint(test_qids), "train/test query ids overlap"
assert val_qids.isdisjoint(test_qids), "val/test query ids overlap"
assert len(train_qids) + len(val_qids) + len(test_qids) == n, "split does not cover all query ids"

def split_interactions(df, train_ids=None, val_ids=None, test_ids=None):
    """Partition interaction rows into train/val/test according to their query_id.

    Args:
        df: frame with a ``query_id`` column (other columns are carried along).
        train_ids, val_ids, test_ids: collections (set or list-like) of query
            ids belonging to each split. Any argument left as ``None`` falls
            back to the notebook-level ``train_qids`` / ``val_qids`` /
            ``test_qids``, so the original zero-argument call is unchanged.

    Returns:
        (train, val, test): independent copies of the rows whose query_id is in
        the respective collection. Rows whose query_id is in none of the
        collections are dropped.
    """
    # Resolve defaults lazily so explicit ids never touch the notebook globals.
    if train_ids is None:
        train_ids = train_qids
    if val_ids is None:
        val_ids = val_qids
    if test_ids is None:
        test_ids = test_qids

    qid = df["query_id"]
    # .copy() so later edits to a split never write back into `df`.
    train = df[qid.isin(train_ids)].copy()
    val   = df[qid.isin(val_ids)].copy()
    test  = df[qid.isin(test_ids)].copy()
    return train, val, test

train_int, val_int, test_int = split_interactions(interactions)

# Report how many interaction rows landed in each partition.
print("Interactions:")
for split_label, split_frame in (
    ("train:", train_int),
    ("val  :", val_int),
    ("test :", test_int),
):
    print(split_label, split_frame.shape)

Interactions:
train: (384000, 4)
val  : (48000, 4)
test : (48000, 4)


In [3]:
def build_ground_truth(df, rel_threshold=2):
    """Map each query id to the set of property ids judged relevant for it.

    A row counts as relevant when its ``relevance`` value is at least
    ``rel_threshold`` (inclusive). Queries with no relevant row are absent
    from the result; they are not mapped to an empty set.

    Args:
        df: frame with ``query_id``, ``property_id`` and ``relevance`` columns.
        rel_threshold: inclusive lower bound on ``relevance``.

    Returns:
        dict of query_id -> set of property_id, keyed in groupby (sorted) order.
    """
    relevant_rows = df.loc[df["relevance"] >= rel_threshold]
    grouped_pids = relevant_rows.groupby("query_id")["property_id"]
    return {query_id: set(pids) for query_id, pids in grouped_pids}

# relevance >= 2 counts as relevant; queries with no such row are omitted, so
# the counts below are the number of evaluable queries per split.
gt_val, gt_test = (
    build_ground_truth(split_frame, rel_threshold=2)
    for split_frame in (val_int, test_int)
)

print("Val queries with >=1 relevant:", len(gt_val))
print("Test queries with >=1 relevant:", len(gt_test))

Val queries with >=1 relevant: 3000
Test queries with >=1 relevant: 3000


In [4]:
def recall_at_k(ranked_pids, relevant_set, k):
    """Fraction of a query's relevant properties retrieved in the top-k.

    Standard recall@k = |top-k ∩ relevant| / |relevant|. This is deliberately
    *not* HitRate@k (1.0 on any single hit), which overstates recall whenever a
    query has more than one relevant item.

    Args:
        ranked_pids: ranked sequence of property ids (best first); anything
            that supports slicing, e.g. a list or numpy array.
        relevant_set: iterable of relevant property ids for the query.
        k: rank cutoff; only ``ranked_pids[:k]`` is considered.

    Returns:
        float in [0, 1]. Duplicate ids in the ranking are counted once.
        Returns 0.0 when ``relevant_set`` is empty (recall is undefined there).
    """
    relevant = set(relevant_set)
    if not relevant:
        return 0.0
    hits = relevant.intersection(ranked_pids[:k])
    return len(hits) / len(relevant)

def dcg_at_k(ranked_pids, relevant_set, k):
    """Binary-relevance discounted cumulative gain over the top-k ranks.

    Each id in ``ranked_pids[:k]`` that is in ``relevant_set`` contributes
    ``1 / log2(rank + 1)``, with ranks starting at 1. Non-relevant ids
    contribute nothing. Assumes the ranking has no duplicates: a repeated
    relevant id is credited at every position where it appears.
    """
    return sum(
        (
            1.0 / np.log2(rank + 1)
            for rank, pid in enumerate(ranked_pids[:k], start=1)
            if pid in relevant_set
        ),
        0.0,
    )

def ndcg_at_k(ranked_pids, relevant_set, k):
    """Normalized DCG@k with binary relevance.

    The DCG of the ranking (via ``dcg_at_k``) is divided by the ideal DCG. The
    ideal DCG is what you get by placing ``min(len(relevant_set), k)`` relevant
    items at ranks 1..m. Returns 0.0 when that ideal is zero, i.e. when
    nothing is relevant or ``k <= 0``.
    """
    achieved_dcg = dcg_at_k(ranked_pids, relevant_set, k)

    # Best achievable ranking: as many relevant items as fit under the cutoff,
    # all packed at the top.
    m = min(len(relevant_set), k)
    best_dcg = sum(1.0 / np.log2(rank + 1) for rank in range(1, m + 1))

    if best_dcg > 0:
        return achieved_dcg / best_dcg
    return 0.0