Skip to content
This repository was archived by the owner on Nov 2, 2024. It is now read-only.

Commit 4b46801

Browse files
committedApr 4, 2020
perf(bbox_regression): 使用Adam;调整targets精度为float32
1 parent 31b7401 commit 4b46801

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed
 

‎py/bbox_regression.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,9 @@ def train_model(data_loader, feature_model, model, criterion, optimizer, lr_sche
4949

5050
# Iterate over data.
5151
for inputs, targets in data_loader:
52+
print(targets.dtype)
5253
inputs = inputs.to(device)
53-
targets = targets.to(device)
54+
targets = targets.float().to(device)
5455

5556
features = feature_model.features(inputs)
5657
features = torch.flatten(features, 1)
@@ -113,7 +114,8 @@ def get_model(device=None):
113114
model.to(device)
114115

115116
criterion = nn.MSELoss()
116-
optimizer = optim.SGD(model.parameters(), lr=1e-3, momentum=0.9, weight_decay=1e-3)
117+
# optimizer = optim.SGD(model.parameters(), lr=1e-3, momentum=0.9, weight_decay=1e-3)
118+
optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-3)
117119
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)
118120

119121
loss_list = train_model(data_loader, feature_model, model, criterion, optimizer, lr_scheduler, device=device,

0 commit comments

Comments
 (0)
Failed to load comments.