Skip to content

BUAA-SASRec/SASRec

Repository files navigation

SASRec 序列推荐(PyTorch)

一个基于自注意力机制(Self-Attention)的序列推荐模型实现(SASRec)。项目包含数据预处理/缓存、模型训练与验证、指标评估与模型保存的完整流程,可在多数据集上复现并对比结果。

目录结构

├─ run_sasrec.py          # 训练/验证/测试主脚本
├─ models/
│  └─ sasrec.py           # SASRec 模型(Transformer 编码 + 打分)
├─ layers/                # 基础层组件
│  ├─ MultiHeadAttention.py
│  ├─ FeedForward.py
│  └─ LayerNorm.py
├─ util/
│  ├─ evaluate.py         # Recall/NDCG 等指标
│  └─ util.py             # 通用工具(目录/模型保存、seed、设备)
├─ dataset.py             # SeqRecDataset/AugSeqRecDataset 与缓存
├─ Instruments/           # 示例数据集(同名数据集放在同级目录)
│  ├─ Instruments.inter.json         # 交互数据(用户→物品序列)
│  ├─ Instruments.item.json          # 可选物品元数据
│  ├─ cache/SeqRecDataset/*          # 预处理缓存 .npy
│  └─ raw/*.json.gz                  # 原始压缩文件(可选)
└─ save/                 # 训练产物保存根目录

环境依赖

  • Python ≥ 3.8
  • PyTorch ≥ 1.12(建议使用 CUDA 环境)
  • 主要依赖:numpy, pandas, tqdm, scipy

示例安装(可按需调整版本):

pip install torch numpy pandas tqdm scipy

数据准备

每个数据集对应一个同名目录,包含至少以下文件:

  • <DATASET>/<DATASET>.inter.json:用户到物品的交互序列,JSON 格式,如:
{
  "0": [12, 7, 35, 9],
  "1": [3, 5, 6]
}

说明与约定:

  • 物品 ID 从 0 开始;内部会加 1,以便保留 0 作为 padding。
  • 会自动生成缓存文件:<DATASET>/cache/SeqRecDataset/{seed}-{max_len}-{mode}.npy

项目已附带示例数据集目录 Instruments/,可直接运行进行验证。

快速开始

在项目根目录运行(Windows PowerShell 与 Linux/Mac 仅参数相同):

python run_sasrec.py --device cuda:0 --dataset Instruments --data_path ./ --epochs 50

首轮运行会根据 max_len/seed/mode 预处理并缓存数据,后续复用缓存以加速加载。

常用参数(节选)

  • 数据与设备
    • --data_path:数据根路径(默认 ./
    • --dataset:数据集名称,可选 Health|Phones|Prime|Sports|Instruments(默认 Instruments
    • --device:如 cuda:0cpu
  • 模型结构
    • --d_model:嵌入维度(默认 128)
    • --n_heads:注意力头数(默认 1)
    • --num_layers:Transformer 层数(默认 2)
    • --max_len:序列最大长度(默认 20)
  • 正则与 Dropout
    • --emb_dropout--attn_dropout--ffn_dropout
    • --wd:权重衰减(默认 5e-1)
  • 训练与调度
    • --epochs--train_batch_size/--valid_batch_size/--test_batch_size
    • --lr:学习率(默认 5e-4)
    • --scheduler_typecosine|linear|none(默认 none
    • --warmup_ratio:预热步数占比(默认 0.01)
  • 验证/早停
    • --eval_step:验证间隔(默认 1 epoch)
    • --eval_metricRecall@5|NDCG@5|loss(默认 Recall@5
    • --early_stop_step:无提升的耐心步数(默认 20)
  • 保存
    • --save_root_path:保存根目录(默认 ./save/
    • --params_in_model_save_title:参与命名的超参键名列表

查看全部参数:

python run_sasrec.py -h

训练日志与模型保存

  • 每次运行会在 save/<DATASET>/ 下保存最优模型(根据 --eval_metric 选择最大化或最小化)。
  • 文件名包含若干核心超参,便于对比试验。
  • 在验证阶段会同步打印测试集的 Recall/NDCG@K(由 --metrics --topk 控制)。

数据管线与模型概览

  • 数据集:SeqRecDataset
    • 训练:从每个用户序列切片构造 (his_seqs, next_item) 样本
    • 验证/测试:使用最后一次/整体切片,保证评估一致性
    • 自动缓存到 .npy,再次加载时直读缓存
  • 模型:SASRec
    • 物品嵌入 + 位置嵌入 → 多层 Transformer(带因果 mask)
    • 取最后一个有效位置向量与所有物品嵌入做矩阵乘,得到打分
    • 损失:交叉熵(对所有物品进行分类)
  • 评估:util/evaluate.py
    • 提供 Recall@K, NDCG@K 等指标(测试时进行加权聚合)

复现实验(示例)

可参考 run_sasrec.sh 中的网格搜索注释,覆盖不同学习率/权重衰减/数据集组合。将命令复制到终端执行即可(Linux 可用 nohup 后台运行)。

常见问题

  • CUDA 不可用:将 --device 设为 cpu 或安装匹配 CUDA 版本的 PyTorch。
  • 文件找不到:确保数据放置为 <data_path>/<dataset>/<dataset>.inter.json 的结构。
  • 评估为 0:请确认 inter.json 中每个用户至少有 2 条交互(验证需要)。

参考

  • SASRec: Self-Attentive Sequential Recommendation (WWW 2018)

About

基于自注意力机制(Self-Attention)的序列推荐模型实现(SASRec)

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

 
 
 

Contributors