display loss in train function
Luke.taek committed Feb 28, 2018
1 parent 30b1c47 commit 40f76ad
Showing 2 changed files with 19 additions and 18 deletions.
loss.py: 16 changes (4 additions, 12 deletions)
@@ -1,8 +1,3 @@
'''
Reference :
https://github.com/kuangliu/pytorch-retinanet/blob/master/loss.py
'''

from __future__ import print_function

import torch
@@ -33,7 +28,6 @@ def focal_loss2d(self, x, y):
t = one_hot_embedding(y.data.cpu(), 2)
else:
t = one_hot_embedding(y.data, 2)
# t = t[:,1:] # exclude background

if self.using_gpu is True:
t = Variable(t).cuda()
@@ -135,11 +129,9 @@ def forward(self, loc_preds, loc_targets, cls_preds, cls_targets, mask_preds, mask_targets):
masked_cls_preds = cls_preds[mask].view(-1,self.num_classes)
cls_loss = self.focal_loss(masked_cls_preds, cls_targets[pos_neg])

# mask_loss = self.ce_loss(mask_preds, mask_targets)
mask_loss = self.focal_loss2d(mask_preds, mask_targets)
mask_loss = self.ce_loss(mask_preds, mask_targets)
# mask_loss = self.focal_loss2d(mask_preds, mask_targets)


print('loc_loss: %.3f | cls_loss: %.3f | mask_loss: %.3f' %
(loc_loss.data[0]/num_pos, cls_loss.data[0]/num_pos, mask_loss.data[0]), end=' | ')
loss = ((loc_loss+cls_loss)/num_pos) + (mask_loss)
return loss, num_pos
# loss = ((loc_loss+cls_loss)/num_pos) + (mask_loss)
return loc_loss, cls_loss, mask_loss, num_pos
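For context: the classification loss in this file follows the kuangliu/pytorch-retinanet implementation linked in the docstring that this commit removes. Below is a minimal sketch, in that repo's style, of the one_hot_embedding helper used above and of a binary focal loss; the alpha=0.25, gamma=2 defaults and the binary_cross_entropy_with_logits formulation are assumptions based on that reference, and the 2D mask variant focal_loss2d is not reconstructed here since the diff does not show it.

import torch
import torch.nn.functional as F

def one_hot_embedding(labels, num_classes):
    '''Embed class labels into one-hot form: [N] -> [N, num_classes].'''
    y = torch.eye(num_classes)  # identity matrix, [D, D]
    return y[labels]            # each label indexes one row

def focal_loss(x, t, alpha=0.25, gamma=2):
    '''Sketch of a binary focal loss.
    x: [N, D] logits; t: [N, D] one-hot targets from one_hot_embedding.
    '''
    p = x.sigmoid()
    pt = p*t + (1-p)*(1-t)         # pt = p where t == 1, else 1 - p
    w = alpha*t + (1-alpha)*(1-t)  # class-balancing weight
    w = w * (1-pt).pow(gamma)      # down-weight easy examples
    return F.binary_cross_entropy_with_logits(x, t, w, size_average=False)

Note that size_average=False sums rather than averages, which is why the caller divides loc_loss and cls_loss by the number of matched anchors.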
train.py: 21 changes (15 additions, 6 deletions)
@@ -50,7 +50,7 @@

validset = ListDataset(root="../valid", gt_extension=".txt",
labelmap_path="class_label_map.xlsx", is_train=False, transform=transform, input_image_size=512,
num_crops=1, original_img_size=512)
num_crops=5, original_img_size=512)
validloader = torch.utils.data.DataLoader(validset, batch_size=batch_size, shuffle=False, num_workers=8, collate_fn=validset.collate_fn)

print("lr : " + str(lr))
@@ -130,14 +130,17 @@ def train(epoch):

optimizer.zero_grad()
loc_preds, cls_preds, mask_preds = net(inputs)
loss, num_matched_anchors = criterion(loc_preds, loc_targets, cls_preds, cls_targets, mask_preds, mask_targets)
loc_loss, cls_loss, mask_loss, num_matched_anchors = \
criterion(loc_preds, loc_targets, cls_preds, cls_targets, mask_preds, mask_targets)
loss = ((loc_loss + cls_loss) / num_matched_anchors) + mask_loss
loss.backward()
optimizer.step()

train_loss += loss.data[0]
avg_matched_anchor += float(num_matched_anchors)
print('epoch: %3d | iter: %4d | train_loss: %.3f | avg_loss: %.3f | avg_num. matched: %d'
% (epoch, batch_idx, loss.data[0], train_loss/(batch_idx+1), avg_matched_anchor/(batch_idx+1)))
print('epoch: %3d | iter: %4d | loc_loss: %.3f | cls_loss: %.3f | mask_loss: %.3f | train_loss: %.3f | avg_loss: %.3f | avg_num. matched: %d'
% (epoch, batch_idx, loc_loss.data[0] / num_matched_anchors, cls_loss.data[0] / num_matched_anchors,
mask_loss.data[0], loss.data[0], train_loss / (batch_idx + 1), avg_matched_anchor / (batch_idx + 1)))

# Validation
def valid(epoch):
@@ -158,11 +161,17 @@ def valid(epoch):
mask_targets = Variable(mask_targets)

loc_preds, cls_preds, mask_preds = net(inputs)
loss, _ = criterion(loc_preds, loc_targets, cls_preds, cls_targets, mask_preds, mask_targets)
loc_loss, cls_loss, mask_loss, num_matched_anchors = \
criterion(loc_preds, loc_targets, cls_preds, cls_targets, mask_preds, mask_targets)
loss = ((loc_loss + cls_loss) / num_matched_anchors) + mask_loss
valid_loss += loss.data[0]
print('valid_loss: %.3f | avg_loss: %.3f' % (loss.data[0], valid_loss/(batch_idx+1)))
print('loc_loss: %.3f | cls_loss: %.3f | valid_loss: %.3f | avg_loss: %.3f'
% (loc_loss.data[0] / num_matched_anchors, cls_loss.data[0] / num_matched_anchors,
loss.data[0], valid_loss / (batch_idx + 1)))

# Save checkpoint
# Every checkpoint is stored so that we can analyze how training is going.
# The model is selected by the lowest validation error.
valid_loss /= len(validloader)
print('Saving..')
state = {
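The checkpoint dictionary itself is truncated in the diff. For orientation only, here is a hypothetical sketch of the save-every-epoch, keep-the-best pattern the comments above describe; the dictionary keys, file paths, and best_loss bookkeeping are illustrative assumptions, not the repo's actual code:

# Hypothetical sketch; keys and paths are assumed, not taken from the diff.
state = {
    'net': net.state_dict(),
    'valid_loss': valid_loss,
    'epoch': epoch,
}
torch.save(state, './checkpoint/ckpt_%03d.pth' % epoch)  # keep every epoch for later analysis
if valid_loss < best_loss:  # select the model with the lowest validation error
    best_loss = valid_loss
    torch.save(state, './checkpoint/ckpt_best.pth')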
