一个基于自注意力机制(Self-Attention)的序列推荐模型实现(SASRec)。项目包含数据预处理/缓存、模型训练与验证、指标评估与模型保存的完整流程,可在多数据集上复现并对比结果。
├─ run_sasrec.py # 训练/验证/测试主脚本
├─ models/
│ └─ sasrec.py # SASRec 模型(Transformer 编码 + 打分)
├─ layers/ # 基础层组件
│ ├─ MultiHeadAttention.py
│ ├─ FeedForward.py
│ └─ LayerNorm.py
├─ util/
│ ├─ evaluate.py # Recall/NDCG 等指标
│ └─ util.py # 通用工具(目录/模型保存、seed、设备)
├─ dataset.py # SeqRecDataset/AugSeqRecDataset 与缓存
├─ Instruments/ # 示例数据集(同名数据集放在同级目录)
│ ├─ Instruments.inter.json # 交互数据(用户→物品序列)
│ ├─ Instruments.item.json # 可选物品元数据
│ ├─ cache/SeqRecDataset/* # 预处理缓存 .npy
│ └─ raw/*.json.gz # 原始压缩文件(可选)
└─ save/ # 训练产物保存根目录
- Python ≥ 3.8
- PyTorch ≥ 1.12(建议使用 CUDA 环境)
- 主要依赖:numpy, pandas, tqdm, scipy
示例安装(可按需调整版本):
pip install torch numpy pandas tqdm scipy每个数据集对应一个同名目录,包含至少以下文件:
<DATASET>/<DATASET>.inter.json:用户到物品的交互序列,JSON 格式,如:
{
"0": [12, 7, 35, 9],
"1": [3, 5, 6]
}说明与约定:
- 物品 ID 从 0 开始;内部会加 1,以便保留 0 作为 padding。
- 会自动生成缓存文件:
<DATASET>/cache/SeqRecDataset/{seed}-{max_len}-{mode}.npy。
项目已附带示例数据集目录 Instruments/,可直接运行进行验证。
在项目根目录运行(Windows PowerShell 与 Linux/Mac 仅参数相同):
python run_sasrec.py --device cuda:0 --dataset Instruments --data_path ./ --epochs 50首轮运行会根据 max_len/seed/mode 预处理并缓存数据,后续复用缓存以加速加载。
- 数据与设备
--data_path:数据根路径(默认./)--dataset:数据集名称,可选Health|Phones|Prime|Sports|Instruments(默认Instruments)--device:如cuda:0或cpu
- 模型结构
--d_model:嵌入维度(默认 128)--n_heads:注意力头数(默认 1)--num_layers:Transformer 层数(默认 2)--max_len:序列最大长度(默认 20)
- 正则与 Dropout
--emb_dropout、--attn_dropout、--ffn_dropout--wd:权重衰减(默认 5e-1)
- 训练与调度
--epochs、--train_batch_size/--valid_batch_size/--test_batch_size--lr:学习率(默认 5e-4)--scheduler_type:cosine|linear|none(默认none)--warmup_ratio:预热步数占比(默认 0.01)
- 验证/早停
--eval_step:验证间隔(默认 1 epoch)--eval_metric:Recall@5|NDCG@5|loss(默认Recall@5)--early_stop_step:无提升的耐心步数(默认 20)
- 保存
--save_root_path:保存根目录(默认./save/)--params_in_model_save_title:参与命名的超参键名列表
查看全部参数:
python run_sasrec.py -h- 每次运行会在
save/<DATASET>/下保存最优模型(根据--eval_metric选择最大化或最小化)。 - 文件名包含若干核心超参,便于对比试验。
- 在验证阶段会同步打印测试集的 Recall/NDCG@K(由
--metrics --topk控制)。
- 数据集:
SeqRecDataset- 训练:从每个用户序列切片构造
(his_seqs, next_item)样本 - 验证/测试:使用最后一次/整体切片,保证评估一致性
- 自动缓存到
.npy,再次加载时直读缓存
- 训练:从每个用户序列切片构造
- 模型:
SASRec- 物品嵌入 + 位置嵌入 → 多层 Transformer(带因果 mask)
- 取最后一个有效位置向量与所有物品嵌入做矩阵乘,得到打分
- 损失:交叉熵(对所有物品进行分类)
- 评估:
util/evaluate.py- 提供
Recall@K,NDCG@K等指标(测试时进行加权聚合)
- 提供
可参考 run_sasrec.sh 中的网格搜索注释,覆盖不同学习率/权重衰减/数据集组合。将命令复制到终端执行即可(Linux 可用 nohup 后台运行)。
- CUDA 不可用:将
--device设为cpu或安装匹配 CUDA 版本的 PyTorch。 - 文件找不到:确保数据放置为
<data_path>/<dataset>/<dataset>.inter.json的结构。 - 评估为 0:请确认
inter.json中每个用户至少有 2 条交互(验证需要)。
- SASRec: Self-Attentive Sequential Recommendation (WWW 2018)