基于 CASIA-B 轮廓数据集的步态识别示例项目,参考 GaitSet 等常见 silhouette-based 方法:先对单帧轮廓序列做卷积编码,再通过集合池化(Mean/Max + Temporal Pyramid Pooling)获得时序无关的表示,最终完成身份分类。
data/GaitDatasetB-silh/
├── 001/ # 被试编号(001-124)
│ ├── nm-01/ # 行走状态:正常(nm,01-06)、背包(bg,01-02)、大衣(cl,01-02)
│ │ ├── 000/ # 拍摄角度:000, 018, ..., 180
│ │ │ ├── 001-nm-01-000-001.png
│ │ │ └── ... (<=150帧)
│ └── ...
└── ...
默认从 data/GaitDatasetB-silh 自动发现 001-124 的 subject 文件夹,可通过配置文件自定义训练/验证划分。
config/config.yaml # 训练超参与数据配置
src/data/casia_b.py # CASIA-B 数据集读取、划分工具
src/model/gaitset.py # Silhouette 编码器 + Set Pooling + 分类头
src/train.py # 训练/验证循环、日志与 checkpoint
src/inference.py # 单样本推理脚本
src/evaluate.py # 验证集评估脚本 (Accuracy, Recall, F1)
src/utility/unzip_dataset.py # 数据解压小工具
README.md # 使用说明
docs/architecture.svg # 模型结构示意图
python -m venv .venv && source .venv/bin/activate # 推荐
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121 # 按需选择CUDA/CPU版本
pip install pyyaml pillow scikit-learn tensorboard若需批量解压数据,可运行
python -m src.utility.unzip_dataset --data-dir data/raw_archives。
- Frame Encoder:3 层 3x3 Conv+BN+ReLU + 2x2 下采样,得到每帧 128 维特征。
- Set Pooling:全局 Mean/Max + Temporal Pyramid Pooling(1,2,4 bin),对任意长度序列生成固定维度表示。
- Classifier:Dropout + 全连接输出 124 类 softmax。
config/config.yaml 示例:
experiment:
seed: 42
output_dir: runs/casia-b-baseline
data:
root: data/GaitDatasetB-silh
train_ratio: 0.85
split_seed: 7
frames_per_clip: 30
min_frames: 8
sampling_strategy: uniform
batch_size: 8
num_workers: 4
model:
in_channels: 1
frame_feature_dims: [32, 64, 128]
pyramid_bins: [1, 2, 4]
dropout: 0.3
optim:
lr: 0.0003
weight_decay: 0.0001
epochs: 50
label_smoothing: 0.0- 若提供
data.train_subjects/data.val_subjects(列表形式,如[1,2,3]),将覆盖train_ratio自动划分。 frames_per_clip控制序列截取长度(不足会重复采样,超出按均匀/随机策略抽取)。
python -m src.train --config config/config.yaml日志包含每个 epoch 的 loss/acc,runs/casia-b-baseline/ 下会保存:
latest.pt:最近一次 checkpointbest.pt:验证精度最佳模型history.json:训练/验证曲线数据tensorboard/:TensorBoard 日志
tensorboard --logdir runs/casia-b-baseline/tensorboard然后在浏览器访问 http://localhost:6006。
对单个包含轮廓序列的文件夹进行推理:
python -m src.inference \
--config config/config.yaml \
--checkpoint runs/casia-b-baseline/best.pt \
--input_dir data/GaitDatasetB-silh/001/nm-01/090脚本将输出预测的类别索引(对应 Subject ID)。
在验证集上计算详细指标(Accuracy, Recall, F1-Score):
python -m src.evaluate \
--config config/config.yaml \
--checkpoint runs/casia-b-baseline/best.pt输出示例:
Classification Report:
precision recall f1-score support
0 0.95 0.90 0.92 20
...
accuracy 0.92 200
macro avg 0.91 0.90 0.90 200
weighted avg 0.92 0.92 0.92 200
Evaluation Results:
Accuracy: 0.9200
Recall: 0.9000
F1 Score: 0.9000
- 修改
model.frame_feature_dims可以加深/加宽卷积骨干。 - 调整
data.sampling_strategy为random以提升数据增广效果。 - 若需要额外评估脚本,可在
src/中添加eval.py复用相同数据与模型模块。