-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
217 lines (176 loc) · 9.39 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
from __future__ import division
from nets.frcnn import get_model
from nets.frcnn_training import cls_loss,smooth_l1,Generator,get_img_output_length,class_loss_cls,class_loss_regr
from utils.config import Config
from utils.utils import BBoxUtility
from utils.roi_helpers import calc_iou
from keras.utils import generic_utils
from keras.callbacks import TensorBoard, ModelCheckpoint, ReduceLROnPlateau, EarlyStopping
import keras
import numpy as np
import time
import tensorflow as tf
from utils.anchors import get_anchors
def write_log(callback, names, logs, batch_no):
    """Write scalar values to the TensorBoard log via a TF1-style FileWriter.

    Args:
        callback: a keras.callbacks.TensorBoard instance (after set_model),
            exposing a ``.writer`` with ``add_summary``/``flush``.
        names: iterable of scalar tag names.
        logs: iterable of scalar values, aligned with ``names``.
        batch_no: global step used as the x-axis in TensorBoard.

    NOTE: uses the TF 1.x ``tf.Summary`` protobuf API; this does not exist in
    TF 2.x without ``tf.compat.v1``.
    """
    # NOTE(review): source indentation was lost upstream; add_summary/flush are
    # reconstructed inside the loop (flushing per scalar) — confirm against the
    # original. Hoisting flush() out of the loop would be cheaper.
    for name, value in zip(names, logs):
        summary = tf.Summary()
        summary_value = summary.value.add()
        summary_value.simple_value = value
        summary_value.tag = name
        callback.writer.add_summary(summary, batch_no)
        callback.writer.flush()
#----------------------------------------------------#
if __name__ == "__main__":
    # Alternating two-stage Faster R-CNN training: train the RPN on a batch,
    # decode its proposals, build classifier targets from them via IoU
    # matching, then train the detector head on the sampled RoIs.
    # NOTE(review): indentation was stripped by the upstream scrape; the block
    # structure below is a faithful reconstruction — verify against the repo.
    config = Config()
    NUM_CLASSES = 3
    # Train for 100 epochs.
    EPOCH = 100
    # Start at 1e-4; halved every 10 epochs (see the compile calls below).
    Learning_rate = 1e-4
    bbox_util = BBoxUtility(overlap_threshold=config.rpn_max_overlap,ignore_threshold=config.rpn_min_overlap)
    annotation_path = r"coin_data\train_flip.txt"
    model_rpn, model_classifier,model_all = get_model(config,NUM_CLASSES)
    base_net_weights = "model_data/voc_weights.h5"
    model_all.summary()
    # Load pretrained backbone weights into both sub-models by layer name.
    model_rpn.load_weights(base_net_weights,by_name=True)
    model_classifier.load_weights(base_net_weights,by_name=True)
    with open(annotation_path) as f:
        lines = f.readlines()
    # Fixed seed for a reproducible shuffle, then restore nondeterminism.
    np.random.seed(10101)
    np.random.shuffle(lines)
    np.random.seed(None)
    # Steps per epoch = number of annotation lines (one image per step).
    EPOCH_LENGTH = len(lines)
    gen = Generator(bbox_util, lines, NUM_CLASSES, solid=True)
    rpn_train = gen.generate()
    log_dir = "logs_flip"
    # Training setup: TensorBoard logging through the manual write_log helper.
    logging = TensorBoard(log_dir=log_dir)
    callback = logging
    callback.set_model(model_all)
    model_rpn.compile(loss={
        'regression'    : smooth_l1(),
        'classification': cls_loss()
    },optimizer=keras.optimizers.Adam(lr=Learning_rate)
    )
    model_classifier.compile(loss=[
        class_loss_cls,
        class_loss_regr(NUM_CLASSES-1)
    ],
        metrics={'dense_class_{}'.format(NUM_CLASSES): 'accuracy'},optimizer=keras.optimizers.Adam(lr=Learning_rate)
    )
    # model_all is only compiled so save_weights works; 'mae' is never trained.
    model_all.compile(optimizer='sgd', loss='mae')
    # Initialize counters and per-epoch loss buffers.
    iter_num = 0
    train_step = 0
    # Columns: rpn_cls, rpn_regr, detector_cls, detector_regr, detector_acc.
    losses = np.zeros((EPOCH_LENGTH, 5))
    rpn_accuracy_rpn_monitor = []
    rpn_accuracy_for_epoch = []
    start_time = time.time()
    # Best (lowest) total loss seen so far.
    # NOTE(review): np.Inf was removed in NumPy 2.0; np.inf is the safe spelling.
    best_loss = np.Inf
    print('Starting training')
    for i in range(EPOCH):
        # Halve the learning rate every 10 epochs by recompiling both models.
        if i % 10 == 0 and i != 0:
            model_rpn.compile(loss={
                'regression'    : smooth_l1(),
                'classification': cls_loss()
            },optimizer=keras.optimizers.Adam(lr=Learning_rate/2)
            )
            model_classifier.compile(loss=[
                class_loss_cls,
                class_loss_regr(NUM_CLASSES-1)
            ],
                metrics={'dense_class_{}'.format(NUM_CLASSES): 'accuracy'},optimizer=keras.optimizers.Adam(lr=Learning_rate/2)
            )
            Learning_rate = Learning_rate/2
            print("Learning rate decrease to " + str(Learning_rate))
        progbar = generic_utils.Progbar(EPOCH_LENGTH)
        print('Epoch {}/{}'.format(i + 1, EPOCH))
        # The generator is infinite; the inner loop ends via the break below.
        for iteration, batch in enumerate(rpn_train):
            if len(rpn_accuracy_rpn_monitor) == EPOCH_LENGTH and config.verbose:
                mean_overlapping_bboxes = float(sum(rpn_accuracy_rpn_monitor))/len(rpn_accuracy_rpn_monitor)
                rpn_accuracy_rpn_monitor = []
                print('Average number of overlapping bounding boxes from RPN = {} for {} previous iterations'.format(mean_overlapping_bboxes, EPOCH_LENGTH))
                if mean_overlapping_bboxes == 0:
                    print('RPN is not producing bounding boxes that overlap the ground truth boxes. Check RPN settings or keep training.')
            # X: image batch, Y: RPN targets, boxes: ground-truth boxes.
            X, Y, boxes = batch[0],batch[1],batch[2]
            # Stage 1: one gradient step on the RPN.
            loss_rpn = model_rpn.train_on_batch(X,Y)
            write_log(callback, ['rpn_cls_loss', 'rpn_reg_loss'], loss_rpn, train_step)
            P_rpn = model_rpn.predict_on_batch(X)
            height,width,_ = np.shape(X[0])
            anchors = get_anchors(get_img_output_length(width,height),width,height)
            # Decode RPN predictions into proposal boxes.
            results = bbox_util.detection_out(P_rpn,anchors,1, confidence_threshold=0)
            # Drop the label/confidence columns, keep box coordinates.
            R = results[0][:, 2:]
            # Match proposals to ground truth to build classifier targets.
            X2, Y1, Y2, IouS = calc_iou(R, config, boxes[0], width, height, NUM_CLASSES)
            if X2 is None:
                # No usable proposals this step; record zero overlap and skip.
                rpn_accuracy_rpn_monitor.append(0)
                rpn_accuracy_for_epoch.append(0)
                continue
            # Last one-hot column marks background: 1 = negative RoI, 0 = positive.
            neg_samples = np.where(Y1[0, :, -1] == 1)
            pos_samples = np.where(Y1[0, :, -1] == 0)
            if len(neg_samples) > 0:
                neg_samples = neg_samples[0]
            else:
                neg_samples = []
            if len(pos_samples) > 0:
                pos_samples = pos_samples[0]
            else:
                pos_samples = []
            rpn_accuracy_rpn_monitor.append(len(pos_samples))
            rpn_accuracy_for_epoch.append((len(pos_samples)))
            if len(neg_samples)==0:
                continue
            # Sample up to num_rois//2 positives; fill the rest with negatives.
            if len(pos_samples) < config.num_rois//2:
                selected_pos_samples = pos_samples.tolist()
            else:
                selected_pos_samples = np.random.choice(pos_samples, config.num_rois//2, replace=False).tolist()
            # NOTE(review): bare except — intended to catch ValueError when there
            # are fewer negatives than requested; should be `except ValueError:`.
            try:
                selected_neg_samples = np.random.choice(neg_samples, config.num_rois - len(selected_pos_samples), replace=False).tolist()
            except:
                selected_neg_samples = np.random.choice(neg_samples, config.num_rois - len(selected_pos_samples), replace=True).tolist()
            sel_samples = selected_pos_samples + selected_neg_samples
            # Stage 2: one gradient step on the detector head with sampled RoIs.
            loss_class = model_classifier.train_on_batch([X, X2[:, sel_samples, :]], [Y1[:, sel_samples, :], Y2[:, sel_samples, :]])
            write_log(callback, ['detection_cls_loss', 'detection_reg_loss', 'detection_acc'], loss_class, train_step)
            losses[iter_num, 0] = loss_rpn[1]
            losses[iter_num, 1] = loss_rpn[2]
            losses[iter_num, 2] = loss_class[1]
            losses[iter_num, 3] = loss_class[2]
            losses[iter_num, 4] = loss_class[3]
            train_step += 1
            iter_num += 1
            progbar.update(iter_num, [('rpn_cls', np.mean(losses[:iter_num, 0])), ('rpn_regr', np.mean(losses[:iter_num, 1])),
                                      ('detector_cls', np.mean(losses[:iter_num, 2])), ('detector_regr', np.mean(losses[:iter_num, 3]))])
            # End of epoch: summarize, log, checkpoint, and break to next epoch.
            if iter_num == EPOCH_LENGTH:
                loss_rpn_cls = np.mean(losses[:, 0])
                loss_rpn_regr = np.mean(losses[:, 1])
                loss_class_cls = np.mean(losses[:, 2])
                loss_class_regr = np.mean(losses[:, 3])
                class_acc = np.mean(losses[:, 4])
                mean_overlapping_bboxes = float(sum(rpn_accuracy_for_epoch)) / len(rpn_accuracy_for_epoch)
                rpn_accuracy_for_epoch = []
                if config.verbose:
                    print('Mean number of bounding boxes from RPN overlapping ground truth boxes: {}'.format(mean_overlapping_bboxes))
                    print('Classifier accuracy for bounding boxes from RPN: {}'.format(class_acc))
                    print('Loss RPN classifier: {}'.format(loss_rpn_cls))
                    print('Loss RPN regression: {}'.format(loss_rpn_regr))
                    print('Loss Detector classifier: {}'.format(loss_class_cls))
                    print('Loss Detector regression: {}'.format(loss_class_regr))
                    print('Elapsed time: {}'.format(time.time() - start_time))
                curr_loss = loss_rpn_cls + loss_rpn_regr + loss_class_cls + loss_class_regr
                iter_num = 0
                # NOTE(review): start_time is reset BEFORE the write_log below,
                # so the logged 'Elapsed_time' is always ~0 — move the reset
                # after the write_log call (or capture elapsed first) to fix.
                start_time = time.time()
                write_log(callback,
                          ['Elapsed_time', 'mean_overlapping_bboxes', 'mean_rpn_cls_loss', 'mean_rpn_reg_loss',
                           'mean_detection_cls_loss', 'mean_detection_reg_loss', 'mean_detection_acc', 'total_loss'],
                          [time.time() - start_time, mean_overlapping_bboxes, loss_rpn_cls, loss_rpn_regr,
                           loss_class_cls, loss_class_regr, class_acc, curr_loss],i)
                if config.verbose:
                    print('The best loss is {}. The current loss is {}. Saving weights'.format(best_loss,curr_loss))
                if curr_loss < best_loss:
                    best_loss = curr_loss
                # NOTE(review): reconstructed at epoch level, i.e. weights are
                # saved every epoch regardless of best_loss (matching the
                # unconditional 'Saving weights' message above) — confirm.
                model_all.save_weights(log_dir+"/epoch{:03d}-loss{:.3f}-rpn{:.3f}-roi{:.3f}".format(i,curr_loss,loss_rpn_cls+loss_rpn_regr,loss_class_cls+loss_class_regr)+".h5")
                break