Skip to content

Ximi-GitHub/RF2C

Repository files navigation

RF2C (Random Forest to C)

EMG 信号特征提取与随机森林嵌入式 C 代码生成工具链

Python scikit-learn Treelite

简介工作流概览脚本详细用法快速开始

## 简介

RF2C 是一个专为嵌入式设备(MCU)设计的完整机器学习工作流。它能够将采集到的多通道(默认 8 通道)表面肌电(sEMG)信号提取为 MAV(平均绝对值)特征,对特征数据进行数据集划分、随机森林超参数网格搜索,最终将训练好的模型导出为可以直接部署到微控制器上的纯 C 代码(基于 Treelite)。

核心特性

  • 🌟 自适应特征提取:自动标定静息阈值,支持多通道投票、时间约束过滤的严格活跃窗口提取。
  • 📊 数据集管理:提供稳定的一键式训练/测试集随机划分工具。
  • 🔍 自动化调参:内置交叉验证与网格搜索,支持按模型准确率和体积排序评估。
  • 🚀 嵌入式友好:自动分析树结构、Flash/RAM 占用估算,一键生成 model_data.h 和预测 C 代码。

┌─────────────────┐    ┌─────────────────┐    ┌─────────────────┐    ┌─────────────────┐
│  原始 EMG 数据  │ ──>│ 特征提取 (MAV)  │ ──>│  划分数据集     │ ──>│ 网格搜索寻优  │
│  (.csv 文件)    │    │ emg_mav_cli.py  │    │ split_dataset...│    │ grid_search...  │
└─────────────────┘    └─────────────────┘    └─────────────────┘    └────────┬────────┘
                                                                              │ (最优参数)
                                                                              ▼
                                                                     ┌─────────────────┐
                                                                     │ 训练并导出 C    │
                                                                     │ train_export... │
                                                                     └─────────────────┘

1. emg_mav_cli.py (特征提取)

功能:对原始的 8 通道 EMG 信号计算滑动窗口 MAV(平均绝对值)特征。自动利用静息文件(label=0)计算阈值,并在真实手势文件上根据阈值、通道投票及时间长度约束,严格提取有效动作窗口。

主要参数

  • --input-dir: 原始数据 CSV 文件或目录(需符合命名规范 label_round_gesturedata_time.csv)。
  • --window-length: 滑动窗口长度(采样点)。
  • --step: 滑动步长(采样点)。
  • --output-dir: 结果输出路径(生成特征 CSV、阈值 JSON 及可视化图表)。
  • --mav-scale, --vote-channels, --min-active-sec, --max-gap-sec, --shrink-sec: 有效窗口判定的高阶调优参数。

示例

python emg_mav_cli.py --input-dir ./raw_data --window-length 200 --step 50 --output-dir ./features

2. split_dataset_cli.py (数据集划分)

功能:将特征提取输出的 CSV 文件按指定比例随机划分为训练集(Train)和测试集(Test)。

主要参数

  • --input-file: 特征提取步骤生成的 CSV 文件路径。
  • --test-ratio: 测试集所占比例(默认 0.2)。
  • --output-dir: 划分后文件的保存目录。
  • --seed: 随机种子,保证每次划分可复现(默认 42)。

示例

python split_dataset_cli.py --input-file ./features/extracted_features.csv --test-ratio 0.2 --output-dir ./dataset

3. grid_search_rf.py (超参数网格搜索)

功能:在划分好的训练集上运行 K-Fold 交叉验证,尝试多种随机森林超参数组合,寻找综合准确率最高的模型配置。

主要参数

  • --train-csv, --test-csv: 训练集和测试集路径。
  • --output-dir: 搜索结果输出目录。
  • --n-estimators, --max-depth, --min-samples-split, --min-samples-leaf, --max-features: 随机森林搜索空间(支持逗号分隔如 100,200,300,或范围如 10:30:10)。

示例

python grid_search_rf.py --train-csv ./dataset/train_dataset.csv --test-csv ./dataset/test_dataset.csv --output-dir ./grid_search_out --n-estimators "50,100,200" --max-depth "10,20,None"

4. train_export_rf.py (训练与导出)

功能:根据网格搜索选出的最佳参数(或手动指定的参数),训练最终的随机森林模型,并将其转换为嵌入式 C 代码(Treelite),同时估算在 MCU 上的 Flash 和 RAM 资源占用。

主要参数

  • --train-csv, --test-csv: 数据集路径。
  • --params-json: grid_search_rf.py 生成的 gridsearch_summary.json 文件路径。
  • --output-dir: C 代码及模型评估报告的输出目录。
  • --quantize: (可选)启用 INT8 量化以进一步压缩模型体积。

示例

python train_export_rf.py --train-csv ./dataset/train_dataset.csv --test-csv ./dataset/test_dataset.csv --params-json ./grid_search_out/gridsearch_summary.json --output-dir ./c_model_export

以下是走通完整数据流的最简示例:

# 1. 提取特征 (滑动窗口200, 步长20)
python emg_mav_cli.py --input-dir .\data\dataset1 --window-length 200 --step 20 --sampling-rate 1000 --mav-scale 1 --vote-channels 3 --min-active-sec 1 --max-gap-sec 0.1 --shrink-sec 0.2 --output-dir .\data\dataset1\output_v5 --workers 8 --use-existing-threshold

# 2. 划分数据集 (80% 训练, 20% 测试)
python split_dataset_cli.py --input-file .\data\dataset1\extracted_features.csv --test-ratio 0.2 --output-dir .\data\dataset1\output_v5\split_output

# 3. 寻找最优超参数
C:\Users\zyx\AppData\Local\conda\conda\envs\IPC_project\python.exe RF2C/grid_search_rf.py --train-csv .\data\dataset1\split_output\train_dataset.csv --test-csv .\data\dataset1\split_output\test_dataset.csv --output-dir .\data\dataset1\output_v5\grid_search_output --n-estimators 5:20:1 --max-depth 10:30:1

# 4. 训练模型并导出 C 代码
C:\Users\zyx\AppData\Local\conda\conda\envs\IPC_project\python.exe train_export_rf.py --train-csv .\data\dataset1\split_output\train_dataset.csv --test-csv .\data\dataset1\split_output\test_dataset.csv --output-dir .\data\dataset1\output_v5\c_model_export --n-estimators 13 --max-depth 16 --min-samples-split 2 --min-samples-leaf 1 --max-features sqrt

About

专为 MCU 设计的 sEMG 机器学习工具链:提供多通道肌电 MAV 特征提取、随机森林模型寻优及一键导出纯 C 代码功能。

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

 
 
 

Contributors

Languages