In [1]:
from pathlib import Path

import numpy as np
import pandas as pd
from recbole.quick_start import load_data_and_model
from recbole.utils.case_study import full_sort_topk
from tqdm import tqdm


# Cell purpose: score every user that can be evaluated on the test split with the
# saved model and export each user's top-K recommended items to a CSV file.

# Tunable settings for this cell.
TOP_K = 10
OUTPUT_CSV = Path('outputs/all_user_top10.csv')

# Load the trained model together with its config, dataset and data loaders.
config, model, dataset, train_data, valid_data, test_data = load_data_and_model(
    model_file='../checkpoint_saved/ml-1m/BPR-Jun-08-2025_14-26-19.pth'
)

# Step 1: Enumerate all internal user IDs.
all_uids = list(range(dataset.user_num))

# Step 2: Keep only users whose `uid2history_item` entry in the test loader is set.
# NOTE(review): the intent is to drop users with no interactions in the test set;
# presumably the loader leaves the entry as None for such users — confirm against RecBole.
valid_uids = [uid for uid in tqdm(all_uids) if test_data.uid2history_item[uid] is not None]

# Step 3: Convert to a NumPy array (the form handed to full_sort_topk below).
uid_series = np.array(valid_uids)

# Step 4: Run full_sort_topk to get the top-K scores and internal item IDs per user.
topk_scores, topk_index = full_sort_topk(uid_series, model, test_data, k=TOP_K, device=config['device'])

# Step 5: Convert internal IDs to external tokens.
# Move the index tensor to the CPU once instead of once per row, and map all
# user IDs in a single lookup instead of one call per user.
topk_rows = topk_index.cpu().tolist()
external_item_lists = [dataset.id2token(dataset.iid_field, row) for row in topk_rows]
external_user_list = dataset.id2token(dataset.uid_field, uid_series).tolist()

# Step 6: Save as DataFrame (one row per user, items joined into a comma-separated string).
df = pd.DataFrame({
    'user_id': external_user_list,
    'topk_items': [','.join(items) for items in external_item_lists]
})
display(df.head())

# to_csv does not create missing parent directories, so make sure it exists first.
OUTPUT_CSV.parent.mkdir(parents=True, exist_ok=True)
df.to_csv(OUTPUT_CSV, index=False)
print("save all_user_top10 successfully")

08 Jun 18:42    INFO  
General Hyper Parameters:
gpu_id = 0
use_gpu = True
seed = 42
state = INFO
reproducibility = True
data_path = datasets/ml-1m
checkpoint_dir = ../checkpoint_saved/ml-1m/
show_progress = True
save_dataset = True
dataset_save_path = None
save_dataloaders = True
dataloaders_save_path = None
log_wandb = True

Training Hyper Parameters:
epochs = 100
train_batch_size = 1024
learner = adam
learning_rate = 0.0005
train_neg_sample_args = {'distribution': 'uniform', 'sample_num': 1, 'alpha': 1.0, 'dynamic': False, 'candidate_num': 0}
eval_step = 1
stopping_step = 10
clip_grad_norm = None
weight_decay = 0.0
loss_decimal_place = 4

Evaluation Hyper Parameters:
eval_args = {'split': {'RS': [0.8, 0.1, 0.1]}, 'order': 'RO', 'group_by': 'user', 'mode': {'valid': 'full', 'test': 'full'}}
repeatable = True
metrics = ['Recall', 'Precision', 'Hit', 'NDCG', 'ItemCoverage', 'AveragePopularity', 'GiniIndex', 'ShannonEntropy', 'TailPercentage']
topk = [10]
valid_metric = NDCG@10
valid_me

Unnamed: 0,user_id,topk_items
0,1,59534364919158820813114318594
1,2,20281183590341852716103183496081393
2,3,121026011961198127048035615802716110
3,4,119626012101198858121412404805412028
4,5,299726922908285823332599295922323952318


save all_user_top10 successfully


In [3]:
import pandas as pd
import numpy as np
from collections import defaultdict

# Step 1: Load the exported recommendations and build a user -> gender lookup.
topk_df = pd.read_csv('outputs/all_user_top10.csv')
user_df = pd.read_csv('../datasets/atomic_datasets/ml-1m/ml-1m.user', sep='\t')
user2gender = {
    user: gender
    for user, gender in zip(user_df['user_id:token'], user_df['gender:token'])
}

# Step 2: Load the test-split ground truth, keeping positive feedback only
# (rows whose label equals 1.0), and collect each user's positive items into a set.
test_df = pd.read_csv('../datasets/split_datasets/ml-1m/ml-1m.test.inter', sep='\t')
positive_mask = test_df['label:float'] == 1.0
test_df = test_df[positive_mask]
items_per_user = test_df.groupby('user_id:token')['item_id:token']
user2ground_truth = items_per_user.agg(set).to_dict()

# Step 3: NDCG@10 computation function
def ndcg_at_k(preds, true_items, k=10):
    """Compute binary-relevance NDCG over the first ``k`` predictions.

    Args:
        preds: Ranked sequence of predicted items, best first. Only the
            first ``k`` entries are considered.
        true_items: Sized container of relevant items, tested with ``in``.
        k: Cut-off rank.

    Returns:
        DCG divided by the ideal DCG, where the ideal ranking places
        ``min(len(true_items), k)`` relevant items at the top. Returns 0.0
        when the ideal DCG is not positive (e.g. ``true_items`` is empty).
    """
    # A hit at 0-based position i contributes 1 / log2(i + 2); enumerating from 2
    # yields that discount argument directly.
    gain = 0.0
    for discount_arg, candidate in enumerate(preds[:k], start=2):
        if candidate in true_items:
            gain += 1.0 / np.log2(discount_arg)

    n_ideal = min(len(true_items), k)
    ideal_gain = sum(1.0 / np.log2(discount_arg) for discount_arg in range(2, n_ideal + 2))

    if not ideal_gain > 0:
        return 0.0
    return gain / ideal_gain

# Step 4: Compute per-user NDCG@10 and bucket the scores by gender.
ndcg_male, ndcg_female, ndcg_all = [], [], []

for uid, raw_items in zip(topk_df['user_id'], topk_df['topk_items']):
    true_items = user2ground_truth.get(uid, set())
    gender = user2gender.get(uid, None)

    # Skip users without a known gender or without any positive test item.
    if gender not in ('M', 'F') or not true_items:
        continue

    # The CSV stores each ranking as one comma-joined string of item tokens.
    # Split it explicitly instead of eval()-ing file content: eval executes
    # arbitrary code, fails on tokens such as "0595", and returns a bare scalar
    # (not a sequence) for a single-item row.
    if pd.isna(raw_items):
        pred_items = []
    else:
        pred_items = [tok.strip() for tok in str(raw_items).split(',') if tok.strip()]

    # Compare tokens as strings so matching does not depend on the dtype pandas
    # inferred for the item-id column of the test file.
    true_tokens = {str(item) for item in true_items}

    ndcg = ndcg_at_k(pred_items, true_tokens, k=10)
    ndcg_all.append(ndcg)
    if gender == 'M':
        ndcg_male.append(ndcg)
    else:
        ndcg_female.append(ndcg)

# Step 5: Report the mean NDCG@10 per group.
# An empty group reports nan directly, avoiding np.mean's empty-slice RuntimeWarning.
mean_all = np.mean(ndcg_all) if ndcg_all else float('nan')
mean_male = np.mean(ndcg_male) if ndcg_male else float('nan')
mean_female = np.mean(ndcg_female) if ndcg_female else float('nan')

print(f'NDCG@10 (All):    {mean_all:.4f} over {len(ndcg_all)} users')
print(f'NDCG@10 (Male):   {mean_male:.4f} over {len(ndcg_male)} users')
print(f'NDCG@10 (Female): {mean_female:.4f} over {len(ndcg_female)} users')

NDCG@10 (All):    0.0756 over 6012 users
NDCG@10 (Male):   0.0770 over 4312 users
NDCG@10 (Female): 0.0721 over 1700 users
