-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
85 lines (73 loc) · 3.26 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
# -*- coding: UTF-8 -*-
'''
@Project :CNN_LSTM
@File :train.py
@IDE :PyCharm
@Author :XinYi Huang
'''
import os
import torch
from torch.nn import functional as F
from yolo import YOLO
from _utils.generate import Generator
from configure import config as cfg
if __name__ == '__main__':
Yolo = YOLO(input_shape=cfg.input_size,
anchors=cfg.anchors,
classes_names=cfg.class_names,
learning_rate=cfg.learning_rate,
score_thresh=cfg.score,
iou_thresh=cfg.iou,
max_boxes=cfg.max_boxes,
letterbox=cfg.letterbox,
weight_decay=cfg.weight_decay,
resume_train=cfg.remain_train,
ckpt_path=cfg.ckpt_path + "\\模型文件")
data_gen = Generator(annotation_path=cfg.annotation_path,
input_size=cfg.input_size,
batch_size=cfg.batch_size,
train_split=cfg.train_split,
anchors=cfg.anchors,
num_classes=cfg.class_names.__len__())
train_gen = data_gen.generate(training=True)
validate_gen = data_gen.generate(training=False)
for epoch in range(cfg.Epoches):
for i in range(data_gen.get_train_len()):
sources, targets = next(train_gen)
Yolo.train(sources, targets)
if not (i + 1) % cfg.per_sample_interval:
Yolo.generate_sample(sources, i+1)
print('Epoch{:0>3d} '
'train loss is {:.3f} '
'train acc is {:.3f} '
'train conf acc is {:.3f} '
'train f1 score is {:.3f}'.format(epoch+1,
Yolo.train_loss / (i + 1),
Yolo.train_acc / (i + 1) * 100,
Yolo.train_conf_acc / (i + 1) * 100,
Yolo.train_f1_score / (i + 1) * 100))
torch.save({'state_dict': Yolo.model.state_dict(),
'loss': Yolo.train_loss / (i + 1),
'acc': Yolo.train_acc / (i + 1) * 100},
cfg.ckpt_path + '\\Epoch{:0>3d}_train_loss{:.3f}_train_acc{:.3f}.pth.tar'.format(
epoch + 1, Yolo.train_loss / (i + 1), Yolo.train_acc / (i + 1) * 100))
Yolo.train_loss = 0
Yolo.train_acc = 0
Yolo.train_conf_acc = 0
Yolo.train_f1_score = 0
for i in range(data_gen.get_val_len()):
sources, targets = next(validate_gen)
Yolo.validate(sources, targets)
print('Epoch{:0>3d} '
'validate loss is {:.3f} '
'validate acc is {:.3f} '
'validate conf acc is {:.3f} '
'validate f1 score is {:.3f}'.format(epoch+1,
Yolo.val_loss / (i + 1),
Yolo.val_acc / (i + 1) * 100,
Yolo.val_conf_acc / (i + 1) * 100,
Yolo.val_f1_score / (i + 1) * 100))
Yolo.val_loss = 0
Yolo.val_acc = 0
Yolo.val_conf_acc = 0
Yolo.val_f1_score = 0