Skip to content
This repository was archived by the owner on Nov 2, 2024. It is now read-only.

Commit 6c4821b

Browse files
committed Mar 19, 2020
perf(svm): 测试remain数据集时禁止梯度计算
1 parent 622b6ae commit 6c4821b

File tree

1 file changed

+31
-30
lines changed

1 file changed

+31
-30
lines changed
 

‎py/linear_svm.py

+31-30
Original file line numberDiff line numberDiff line change
@@ -182,39 +182,40 @@ def train_model(data_loaders, model, criterion, optimizer, lr_scheduler, num_epo
182182
print('remian_negative_list: %d' % (len(remain_negative_list)))
183183
# 如果剩余的负样本集小于96个,那么结束hard negative mining
184184
if len(remain_negative_list) > batch_negative:
185-
remain_dataset = CustomHardNegativeMiningDataset(remain_negative_list, jpeg_images, transform=transform)
186-
remain_data_loader = DataLoader(remain_dataset, batch_size=batch_total, num_workers=8, drop_last=True)
185+
with torch.set_grad_enabled(False):
186+
remain_dataset = CustomHardNegativeMiningDataset(remain_negative_list, jpeg_images, transform=transform)
187+
remain_data_loader = DataLoader(remain_dataset, batch_size=batch_total, num_workers=8, drop_last=True)
187188

188-
# 获取训练数据集的负样本集
189-
negative_list = train_dataset.get_negatives()
190-
res_negative_list = list()
191-
# Iterate over data.
192-
for inputs, labels, cache_dicts in remain_data_loader:
193-
inputs = inputs.to(device)
194-
labels = labels.to(device)
189+
# 获取训练数据集的负样本集
190+
negative_list = train_dataset.get_negatives()
191+
res_negative_list = list()
192+
# Iterate over data.
193+
for inputs, labels, cache_dicts in remain_data_loader:
194+
inputs = inputs.to(device)
195+
labels = labels.to(device)
195196

196-
# zero the parameter gradients
197-
optimizer.zero_grad()
197+
# zero the parameter gradients
198+
optimizer.zero_grad()
199+
200+
outputs = model(inputs)
201+
# print(outputs.shape)
202+
_, preds = torch.max(outputs, 1)
198203

199-
outputs = model(inputs)
200-
# print(outputs.shape)
201-
_, preds = torch.max(outputs, 1)
202-
203-
hard_negative_list, easy_neagtive_list = add_hard_negatives(preds.cpu().numpy(), cache_dicts)
204-
205-
negative_list.extend(hard_negative_list)
206-
res_negative_list.extend(easy_neagtive_list)
207-
208-
# 训练完成后,重置负样本,进行hard negatives mining
209-
train_dataset.set_negative_list(negative_list)
210-
tmp_sampler = CustomBatchSampler(train_dataset.get_positive_num(), train_dataset.get_negative_num(),
211-
batch_positive, batch_negative)
212-
data_loaders['train'] = DataLoader(train_dataset, batch_size=batch_total, sampler=tmp_sampler,
213-
num_workers=8, drop_last=True)
214-
# 重置数据集大小
215-
data_sizes['train'] = len(tmp_sampler)
216-
# 保存剩余的负样本集
217-
data_loaders['remain'] = res_negative_list
204+
hard_negative_list, easy_neagtive_list = add_hard_negatives(preds.cpu().numpy(), cache_dicts)
205+
206+
negative_list.extend(hard_negative_list)
207+
res_negative_list.extend(easy_neagtive_list)
208+
209+
# 训练完成后,重置负样本,进行hard negatives mining
210+
train_dataset.set_negative_list(negative_list)
211+
tmp_sampler = CustomBatchSampler(train_dataset.get_positive_num(), train_dataset.get_negative_num(),
212+
batch_positive, batch_negative)
213+
data_loaders['train'] = DataLoader(train_dataset, batch_size=batch_total, sampler=tmp_sampler,
214+
num_workers=8, drop_last=True)
215+
# 重置数据集大小
216+
data_sizes['train'] = len(tmp_sampler)
217+
# 保存剩余的负样本集
218+
data_loaders['remain'] = res_negative_list
218219

219220
# 每训练一轮就保存
220221
save_model(model, 'models/linear_svm_alexnet_car_%d.pth' % epoch)

0 commit comments

Comments (0)
Failed to load comments.