@@ -182,39 +182,40 @@ def train_model(data_loaders, model, criterion, optimizer, lr_scheduler, num_epo
182
182
print ('remian_negative_list: %d' % (len (remain_negative_list )))
183
183
# 如果剩余的负样本集小于96个,那么结束hard negative mining
184
184
if len (remain_negative_list ) > batch_negative :
185
- remain_dataset = CustomHardNegativeMiningDataset (remain_negative_list , jpeg_images , transform = transform )
186
- remain_data_loader = DataLoader (remain_dataset , batch_size = batch_total , num_workers = 8 , drop_last = True )
185
+ with torch .set_grad_enabled (False ):
186
+ remain_dataset = CustomHardNegativeMiningDataset (remain_negative_list , jpeg_images , transform = transform )
187
+ remain_data_loader = DataLoader (remain_dataset , batch_size = batch_total , num_workers = 8 , drop_last = True )
187
188
188
- # 获取训练数据集的负样本集
189
- negative_list = train_dataset .get_negatives ()
190
- res_negative_list = list ()
191
- # Iterate over data.
192
- for inputs , labels , cache_dicts in remain_data_loader :
193
- inputs = inputs .to (device )
194
- labels = labels .to (device )
189
+ # 获取训练数据集的负样本集
190
+ negative_list = train_dataset .get_negatives ()
191
+ res_negative_list = list ()
192
+ # Iterate over data.
193
+ for inputs , labels , cache_dicts in remain_data_loader :
194
+ inputs = inputs .to (device )
195
+ labels = labels .to (device )
195
196
196
- # zero the parameter gradients
197
- optimizer .zero_grad ()
197
+ # zero the parameter gradients
198
+ optimizer .zero_grad ()
199
+
200
+ outputs = model (inputs )
201
+ # print(outputs.shape)
202
+ _ , preds = torch .max (outputs , 1 )
198
203
199
- outputs = model (inputs )
200
- # print(outputs.shape)
201
- _ , preds = torch .max (outputs , 1 )
202
-
203
- hard_negative_list , easy_neagtive_list = add_hard_negatives (preds .cpu ().numpy (), cache_dicts )
204
-
205
- negative_list .extend (hard_negative_list )
206
- res_negative_list .extend (easy_neagtive_list )
207
-
208
- # 训练完成后,重置负样本,进行hard negatives mining
209
- train_dataset .set_negative_list (negative_list )
210
- tmp_sampler = CustomBatchSampler (train_dataset .get_positive_num (), train_dataset .get_negative_num (),
211
- batch_positive , batch_negative )
212
- data_loaders ['train' ] = DataLoader (train_dataset , batch_size = batch_total , sampler = tmp_sampler ,
213
- num_workers = 8 , drop_last = True )
214
- # 重置数据集大小
215
- data_sizes ['train' ] = len (tmp_sampler )
216
- # 保存剩余的负样本集
217
- data_loaders ['remain' ] = res_negative_list
204
+ hard_negative_list , easy_neagtive_list = add_hard_negatives (preds .cpu ().numpy (), cache_dicts )
205
+
206
+ negative_list .extend (hard_negative_list )
207
+ res_negative_list .extend (easy_neagtive_list )
208
+
209
+ # 训练完成后,重置负样本,进行hard negatives mining
210
+ train_dataset .set_negative_list (negative_list )
211
+ tmp_sampler = CustomBatchSampler (train_dataset .get_positive_num (), train_dataset .get_negative_num (),
212
+ batch_positive , batch_negative )
213
+ data_loaders ['train' ] = DataLoader (train_dataset , batch_size = batch_total , sampler = tmp_sampler ,
214
+ num_workers = 8 , drop_last = True )
215
+ # 重置数据集大小
216
+ data_sizes ['train' ] = len (tmp_sampler )
217
+ # 保存剩余的负样本集
218
+ data_loaders ['remain' ] = res_negative_list
218
219
219
220
# 每训练一轮就保存
220
221
save_model (model , 'models/linear_svm_alexnet_car_%d.pth' % epoch )
0 commit comments