一个基于 EfficientNetV2 的植物分类项目,支持 SupCon 预训练 + 分类器微调的双阶段训练流程,并内置 L2‑SP 正则化、Mixup/CutMix/随机擦除等增强策略,以及推理阶段的 TTA(Test‑Time Augmentation)。
特点
- 双阶段训练:
SupCon主干预训练 → 分类器微调。 - 高质量正则化:
L2‑SP(阶段2单锚点;可选阶段3双锚点)。 - 数据增强:随机翻转、对比度/亮度/饱和度、
Random Erasing、Mixup、CutMix。 - 采样策略:
normal / weighted / hybrid,可配置难例类别与权重。 - 推理增强:内置多种
TTA策略(标准/增强/翻转/裁剪)。 - 训练日志与曲线输出:CSV 与 PNG 自动保存。
项目结构
code/:训练与工具脚本train_supcon_base.py:双阶段(SupCon + 分类器)训练主脚本train_refactored.py:集成 TTA 评估的训练脚本train_kfold.py:K‑Fold 训练model.py:EfficientNetV2 模型构建(efficientnetv2_s/m/l)L2_SP_regularization.py:L2‑SP 正则化实现(阶段2/阶段3)supcon_utils.py:SupCon 损失与辅助工具tta_utils.py:TTA 策略集合与接口- 其它:
dataset_process.py、utils.py、可视化相关
model/config.json:项目配置(数据、模型、训练、增强、路径等)label_mapping.json:类别映射
output/:训练日志与曲线(CSV、PNG)model/weight/:权重输出目录
环境与安装
- Python 3.8+,TensorFlow 2.x(见
requirements.txt) - 安装依赖:
pip install -r requirements.txt
数据准备
- 将训练图片放在
data_dir指定目录,并在csv_file指定标签文件。 config.json中data段:label_column: 例如english_namefilename_column: 例如filenameimg_height/img_width: 输入尺寸(如 400x400)
- 验证划分通过
validation_split控制。
配置说明(model/config.json 关键项)
pathsdata_dir、csv_file:训练数据与标签路径model_output_dir、log_output_dir:模型与日志输出目录checkpoint_name、final_model_name、training_log_name、training_curve_name:输出文件命名supcon_checkpoint_name:SupCon 训练的完整模型检查点文件名(不用于分类器初始化)supcon_backbone_path:SupCon 主干权重文件(仅 Backbone;用于分类器初始化,层名可匹配)stage2_checkpoint_path:阶段3双锚点的 Stage2 权重(可选)
modelname:efficientnetv2_s/m/lpretrained: 是否使用 ImageNet 预训练权重(为分类头提供更好初始化)use_your_pretrained_weights_path: 可选自定义权重路径(与模型层名需匹配)freeze_backbone: 是否全程冻结主干frezze_epoch: 前 N 个 epoch 冻结主干,之后自动解冻(优先于freeze_backbone)
trainingbase_batch_size、epochs、supcon_epochsuse_mixed_precision: 是否启用混合精度
augmentationrandom_erasing、mixup、cutmix可独立启用与配置
optimizer.learning_rate_scheduleinitial_lr、min_lr、warmup_epochs、cycle_epochs、classifier_lr_multiplier
Fine‑tuningstage2.lambda_supcon、regularize_backbone_onlystage2.pretrained_weights_path: 分类器初始化权重(为空时回退使用supcon_backbone_path)stage3.lambda_stage2: 双锚点第二锚强度(仅阶段3使用)
训练流程
-
方式一:标准双阶段训练(推荐)
python code/train_supcon_base.py- 行为:
- 加载 ImageNet 预训练(为分类头打底)
- 加载
supcon_backbone_path(覆盖主干;层名匹配) - 应用 L2‑SP(阶段2单锚点):以当前模型权重为锚,正则化主干(可跳过 BN)
- 按
frezze_epoch动态冻结前 N 个 epoch,随后自动解冻并继续联合训练 - 输出检查点、最终权重、训练日志与曲线
-
方式二:训练后自动进行 TTA 评估
python code/train_refactored.py- 训练完成后创建原始验证数据集并进行 TTA 对比评估
- 若
data.image_size未配置,脚本会自动回退使用img_height/img_width(已兼容处理)
-
K‑Fold 训练(如需)
python code/train_kfold.py
推理(含 TTA)
- 基础用法:
python code/predict.py <test_dir> <out_csv>
- 启用 TTA:
python code/predict.py <test_dir> <out_csv> --tta standard--tta选项:none / flip / crop / standard / enhanced- 权重自动搜索于
model/及其子目录;也可将最终权重放在model/weight/下
输出与日志
- 训练日志:
output/<training_log_name>(CSV) - 训练曲线:
output/<training_curve_name>(PNG) - 最终权重:
model/weight/<final_model_name> - SupCon 主干权重:
model/weight/<supcon_backbone_path>(文件名来自paths.supcon_backbone_path)
重要实践与注意事项
- 分类器初始化请使用
supcon_backbone_path(主干权重,层名与分类器一致)。supcon_checkpoint_name为完整模型检查点,层名前缀不同,直接用于分类器会出现by_name匹配不充分、收敛变慢。
- 阶段2(单锚点 L2‑SP)默认以当前模型权重为锚,无需提供
stage2_weights_path。- 仅当进入阶段3(双锚点)时,才提供阶段2检查点路径。
- 动态冻结(
model.frezze_epoch)优先于model.freeze_backbone:- 前 N 个 epoch 冻结主干(含 BN),到第 N 个 epoch 自动解冻并重新
compile,优化器与 LR 调度连续。
- 前 N 个 epoch 冻结主干(含 BN),到第 N 个 epoch 自动解冻并重新
- BN 正则化跳过:
L2_SP_regularization.py支持skip_batchnorm=True,对 BN 的gamma/beta/均值/方差不施加 L2‑SP。
常见问题
- 找不到权重文件:
- 确认
model/weight/下存在目标*.h5,或在predict.py中自定义权重路径。
- 确认
- 层名不匹配导致加载不完全:
- 分类器阶段请确保使用主干权重(
supcon_backbone_path),避免完整检查点层名前缀不一致。
- 分类器阶段请确保使用主干权重(
image_size键不存在:- 已在
train_refactored.py中兼容处理,自动回退使用img_height/img_width。
- 已在
示例配置片段
"paths": {
"data_dir": ".../train",
"csv_file": ".../train_labels.csv",
"model_output_dir": "model/weight",
"log_output_dir": "output",
"supcon_backbone_path": "supcon_backbone_sampler_weighted-weights.h5"
},
"data": {"img_height": 400, "img_width": 400, "validation_split": 0.2},
"model": {"name": "efficientnetv2_m", "pretrained": true, "freeze_backbone": false, "frezze_epoch": 0},
"Fine-tuning": {"stage2": {"lambda_supcon": 0.001, "regularize_backbone_only": true}}
开始使用
- 配置好
model/config.json(至少设置数据与输出路径)。 - 安装依赖:
pip install -r requirements.txt - 训练:
python code/train_supcon_base.py - 查看日志与曲线:
output/目录。 - 推理:
python code/predict.py <test_dir> <out_csv> --tta standard
如需进一步细化日志(例如打印加载到的具体层/参数数量),我可以为加载与正则化过程补充详细统计输出。
基于 EfficientNetV2-M 的深度学习植物分类系统
├── code/ # 源代码
│ ├── train_refactored.py # 训练脚本 ⭐
│ ├── model.py # 模型定义
│ └── utils.py # 工具函数
├── model/ # 模型输出
│ ├── config.json # 配置文件 ⭐
│ └── weight/ # 模型权重
├── output/ # 训练日志
└── check.py # 配置检查
pip install tensorflow pandas scikit-learn matplotlib pillow编辑 model/config.json:
{
"paths": {
"data_dir": "你的数据目录",
"csv_file": "你的CSV文件"
}
}cd code
python train_refactored.py"training": { "epochs": 10 }"training": { "base_batch_size": 4 }"augmentation": { "enabled": false }训练后在 output/ 目录查看:
- 📈
training_curve.png- 训练曲线 - 📊
training_log.csv- 训练数据 - 📝
training_*.log- 详细日志
Q: 显存不足?
"training": { "base_batch_size": 4 }Q: 找不到配置文件?
cd code # 确保在 code 目录运行Q: 训练太慢?
- 检查 GPU 是否启用
- 增加 batch_size(如果显存允许)