In [1]:
import torch
import yaml
from tqdm import tqdm
import numpy as np
from data import BPRSampleGenerator, SeqBPRDataset
from torch.utils.data import DataLoader
from model import SeqLearn

In [2]:
with open("/graduation_design/bpr/config/bpr.yaml", 'r', encoding='utf-8') as f:
    args = yaml.unsafe_load(f)
args

{'base_path': '/graduation_design/',
 'topk': 10,
 'data': {'train_valid_split': 0.95,
  'maxlen': 30,
  'name': 'ml-1m',
  'sep': '::',
  'item_path': '/graduation_design/data/ml-1m/movies.dat',
  'item_emb_path': '/graduation_design/data/ml-1m/item_embeddings.npy',
  'path': '/graduation_design/data/ml-1m/ratings.dat',
  'num_negatives': 1,
  'user_threshold': 10,
  'item_threshold': 10,
  'rating_threshold': 2,
  'base_model': ['acf', 'fdsa', 'harnn', 'caser', 'pfmc', 'sasrec', 'anam'],
  'base_model_path': '/graduation_design/base_model_results/ml-1m'},
 'model': {'lr': 0.001,
  'type': 'SASEM',
  'lamda': 1e-05,
  'hidden_dim': 32,
  'device': 'cuda:0',
  'optimizer': 'AdamOptimizer',
  'tradeoff': 2,
  'div_module': 'cov',
  'pretrain_llm': 'bert-base-uncased'},
 'epoch': 1,
 'batch_size': 512}

In [9]:
acf = np.load(args['data']['base_model_path'] + f"/acf.npy")

FileNotFoundError: [Errno 2] No such file or directory: '/graduation_design/base_model_results/ml-1m/acf.npy'

In [11]:
# 创建数据生成器
generator = BPRSampleGenerator(args['data'])
seq_samples = generator.generate_seq_samples(
    seq_len=args['data']['maxlen'],
    num_negatives=args['data']['num_negatives']
)

# 创建数据集
dataset = SeqBPRDataset(seq_samples, args['model']['device'])
train_size = int(args['data']['train_valid_split'] * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])

# 创建数据加载器
train_loader = DataLoader(train_dataset, batch_size=args['batch_size'], shuffle=True)

>>>> 交互索引范围: 最小值 = 0, 最大值 = 834448
>>>> 数据加载完成: 834449 条交互, 6033 个用户, 3123 个物品
>>>> 基模型的预测结果加载完成: (834449, 7, 102)
>>>> 构建了 6033 个用户的历史交互序列


>>>> 生成序列样本: 100%|██████████| 6033/6033 [03:35<00:00, 27.99it/s]


>>>> 生成了 828416 个序列感知BPR样本对


In [5]:
test_loader = DataLoader(test_dataset, batch_size=args['batch_size'], shuffle=False, drop_last=True)

In [12]:
model = np.load(args['data']['base_model_path'] + f"/sasrec.npy")

with torch.no_grad():
    ndcg_scores = []
    for batch in tqdm(test_loader, desc="计算NDCG@10..."):
        users, user_seq, pos_items, neg_items, base_model_preds = batch
        
        # 遍历batch中的每个样本
        for i in range(len(users)):
            # 获取交互索引
            user_id = users[i].item()
            pos_item_id = pos_items[i].item()
            interaction_idx = generator.get_interaction_index(user_id, pos_item_id)
            if interaction_idx is None:
                continue

            # 获取模型推荐的top物品列表
            top_items = model[interaction_idx][2:2+args['topk']]

            # 获取用户的实际交互物品
            true_items = generator.user_interacted_items[user_id]

            # 计算DCG
            dcg = 0
            for j, item_idx in enumerate(top_items):
                if item_idx in true_items:
                    dcg += 1 / np.log2(j + 2)
            
            # 计算IDCG
            idcg = 0
            for j in range(min(len(true_items), args['topk'])):
                idcg += 1 / np.log2(j + 2)

            # 计算NDCG
            ndcg = dcg / idcg if idcg > 0 else 0
            ndcg_scores.append(ndcg)

np.mean(ndcg_scores)

计算NDCG@10...: 100%|██████████| 41421/41421 [00:15<00:00, 2602.59it/s]


0.22138667819592134

## 集成模型预测结果

In [14]:
test_loader = DataLoader(test_dataset, batch_size=args['batch_size'], shuffle=False, drop_last=True)

model = SeqLearn(args['model'], args['data'], 6033, generator.n_item)
# 加载checkpoint
ckpt = torch.load(f"/graduation_design/bpr/ckpt/bpr_epoch1_batch60.pth")

# 过滤掉不需要加载的层
filtered_ckpt = {k: v for k, v in ckpt.items() if not k.startswith('item_tower.cex')}

# 加载过滤后的权重
model.load_state_dict(filtered_ckpt, strict=False)

model.eval()
with torch.no_grad():
    ndcg_scores = []
    for batch in tqdm(test_loader, desc="计算测试集NDCG"):
        users, user_seq, pos_items, neg_items, base_model_preds = batch

        # 获取所有物品的预测分数
        all_items = torch.arange(len(generator.item_to_id)).to(args['model']['device'])
        all_scores = model.predict(users, user_seq, all_items, base_model_preds)

        _, indices = torch.topk(all_scores, args['topk'])

        for user_idx in range(len(users)):
            # 获取用户的实际交互物品
            true_items = generator.user_interacted_items[users[user_idx].item()]

            # 计算DCG
            dcg = 0
            for i, item_idx in enumerate(indices[user_idx]):
                if item_idx.item() in true_items:
                    dcg += 1 / np.log2(i + 2)

            # 计算IDCG
            idcg = 0
            for i in range(min(len(true_items), args['topk'])):
                idcg += 1 / np.log2(i + 2)

            # 计算NDCG
            ndcg = dcg / idcg if idcg > 0 else 0
            ndcg_scores.append(ndcg)

np.mean(ndcg_scores)

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


>>>> 加载预计算的物品嵌入...


计算测试集NDCG:   0%|          | 0/80 [00:00<?, ?it/s]

user_emb time: 8.0108642578125e-05
user_interaction time: 1.891808032989502
preference time: 0.0018796920776367188
base_model_focus_llm time: 0.0003285408020019531
each_model_emb time: 6.985664367675781e-05
basemodel_emb time: 3.314018249511719e-05
wgts time: 5.936622619628906e-05


计算测试集NDCG:   1%|▏         | 1/80 [00:02<03:18,  2.51s/it]

item_emb time: 0.4733603000640869
weighted_basemodel_emb time: 0.0003726482391357422
weighted_basemodel_emb_expanded time: 4.5299530029296875e-06
item_emb_expanded time: 1.3828277587890625e-05
concat time: 3.1948089599609375e-05
scores time: 0.00019097328186035156
user_emb time: 6.008148193359375e-05
user_interaction time: 1.8978071212768555
preference time: 0.0018105506896972656
base_model_focus_llm time: 0.0003025531768798828
each_model_emb time: 6.508827209472656e-05
basemodel_emb time: 3.314018249511719e-05
wgts time: 6.198883056640625e-05


计算测试集NDCG:   2%|▎         | 2/80 [00:05<03:16,  2.52s/it]

item_emb time: 0.47461700439453125
weighted_basemodel_emb time: 0.0004582405090332031
weighted_basemodel_emb_expanded time: 4.291534423828125e-06
item_emb_expanded time: 1.5974044799804688e-05
concat time: 3.123283386230469e-05
scores time: 0.00021028518676757812
user_emb time: 6.413459777832031e-05
user_interaction time: 1.892831563949585
preference time: 0.0018486976623535156
base_model_focus_llm time: 0.0003120899200439453
each_model_emb time: 6.508827209472656e-05
basemodel_emb time: 3.147125244140625e-05
wgts time: 5.984306335449219e-05


计算测试集NDCG:   4%|▍         | 3/80 [00:07<03:13,  2.52s/it]

item_emb time: 0.47347211837768555
weighted_basemodel_emb time: 0.0003848075866699219
weighted_basemodel_emb_expanded time: 4.0531158447265625e-06
item_emb_expanded time: 1.33514404296875e-05
concat time: 3.075599670410156e-05
scores time: 0.00018095970153808594
user_emb time: 7.534027099609375e-05


计算测试集NDCG:   4%|▍         | 3/80 [00:08<03:39,  2.85s/it]


KeyboardInterrupt: 