基于Tensorflow实现CenterNet
├──net:存放centernet的基本网络结构代码,包括resnet、hourglass以及loss
├──checkpoint:存放模型训练的checkpoint
├──train_dataset:训练数据集
├──Annotation:保存目标的标记
├──ImageSets:记录训练集、测试集和验证集
├──JPEGImages:图像
├──train.txt:记录训练集,不同于ImageSets里的train.txt,该文件保存了训练集图像的位置,目标信息
├──val.txt:同train.txt
├──test.txt:同train.txt
└──train.names:类别名称
├──data:保存基础数据,包括类别名称的文件、字体等
├──evaluate:对模型进行评估代码
├──logs:训练日志
├──models:训练保存的checkpoint所在的文件夹
├──utils:一些基础方法:如dataloader、callbacks、fit等
├──config.py:配置文件
├──inference.py:推理文件
└──train.py:训练文件
在config.py
中设置好训练集路径以及配置好训练的参数之后,执行train.py
文件即可开始训练。
- 在
config.py
配置好你的推理参数,然后运行inference.py
,查看推理结果 - 在evaluate文件夹里面有cal_map.py文件对模型的性能进行评估。
- 20221223:增加了pytorch分支,基于pytorch实现CenterNet