diff --git a/loss.py b/loss.py
index 559f7b0..7d0a7be 100644
--- a/loss.py
+++ b/loss.py
@@ -1,8 +1,3 @@
-'''
-Reference :
-    https://github.com/kuangliu/pytorch-retinanet/blob/master/loss.py
-'''
-
 from __future__ import print_function
 
 import torch
@@ -33,7 +28,6 @@ def focal_loss2d(self, x, y):
             t = one_hot_embedding(y.data.cpu(), 2)
         else:
             t = one_hot_embedding(y.data, 2)
-        # t = t[:,1:]  # exclude background
 
         if self.using_gpu is True:
             t = Variable(t).cuda()
@@ -135,11 +129,9 @@ def forward(self, loc_preds, loc_targets, cls_preds, cls_targets, mask_preds, ma
         masked_cls_preds = cls_preds[mask].view(-1,self.num_classes)
         cls_loss = self.focal_loss(masked_cls_preds, cls_targets[pos_neg])
 
-        # mask_loss = self.ce_loss(mask_preds, mask_targets)
-        mask_loss = self.focal_loss2d(mask_preds, mask_targets)
+        mask_loss = self.ce_loss(mask_preds, mask_targets)
+        # mask_loss = self.focal_loss2d(mask_preds, mask_targets)
 
-        print('loc_loss: %.3f | cls_loss: %.3f | mask_loss: %.3f' %
-              (loc_loss.data[0]/num_pos, cls_loss.data[0]/num_pos, mask_loss.data[0]), end=' | ')
-        loss = ((loc_loss+cls_loss)/num_pos) + (mask_loss)
-        return loss, num_pos
+        # loss = ((loc_loss+cls_loss)/num_pos) + (mask_loss)
+        return loc_loss, cls_loss, mask_loss, num_pos
 
diff --git a/train.py b/train.py
index 2ecbb00..a13448c 100644
--- a/train.py
+++ b/train.py
@@ -50,7 +50,7 @@
 validset = ListDataset(root="../valid", gt_extension=".txt",
                        labelmap_path="class_label_map.xlsx", is_train=False,
                        transform=transform, input_image_size=512,
-                       num_crops=1, original_img_size=512)
+                       num_crops=5, original_img_size=512)
 validloader = torch.utils.data.DataLoader(validset, batch_size=batch_size, shuffle=False, num_workers=8, collate_fn=validset.collate_fn)
 
 print("lr : " + str(lr))
@@ -130,14 +130,17 @@ def train(epoch):
         optimizer.zero_grad()
         loc_preds, cls_preds, mask_preds = net(inputs)
-        loss, num_matched_anchors = criterion(loc_preds, loc_targets, cls_preds, cls_targets, mask_preds, mask_targets)
+        loc_loss, cls_loss, mask_loss, num_matched_anchors = \
+            criterion(loc_preds, loc_targets, cls_preds, cls_targets, mask_preds, mask_targets)
+        loss = ((loc_loss + cls_loss) / num_matched_anchors) + mask_loss
         loss.backward()
         optimizer.step()
 
         train_loss += loss.data[0]
         avg_matched_anchor += float(num_matched_anchors)
-        print('epoch: %3d | iter: %4d | train_loss: %.3f | avg_loss: %.3f | avg_num. matched: %d'
-              % (epoch, batch_idx, loss.data[0], train_loss/(batch_idx+1), avg_matched_anchor/(batch_idx+1)))
+        print('epoch: %3d | iter: %4d | loc_loss: %.3f | cls_loss: %.3f | mask_loss: %.3f | train_loss: %.3f | avg_loss: %.3f | avg_num. matched: %d'
+              % (epoch, batch_idx, loc_loss.data[0] / num_matched_anchors, cls_loss.data[0] / num_matched_anchors,
+                 mask_loss.data[0], loss.data[0], train_loss / (batch_idx + 1), avg_matched_anchor / (batch_idx + 1)))
 
 
 # Test
 def valid(epoch):
@@ -158,11 +161,17 @@ def valid(epoch):
         mask_targets = Variable(mask_targets)
 
         loc_preds, cls_preds, mask_preds = net(inputs)
-        loss, _ = criterion(loc_preds, loc_targets, cls_preds, cls_targets, mask_preds, mask_targets)
+        loc_loss, cls_loss, mask_loss, num_matched_anchors = \
+            criterion(loc_preds, loc_targets, cls_preds, cls_targets, mask_preds, mask_targets)
+        loss = ((loc_loss + cls_loss) / num_matched_anchors) + mask_loss
         valid_loss += loss.data[0]
-        print('valid_loss: %.3f | avg_loss: %.3f' % (loss.data[0], valid_loss/(batch_idx+1)))
+        print('loc_loss: %.3f | cls_loss: %.3f | valid_loss: %.3f | avg_loss: %.3f'
+              % (loc_loss.data[0] / num_matched_anchors, cls_loss.data[0] / num_matched_anchors,
+                 loss.data[0], valid_loss / (batch_idx + 1)))
 
     # Save checkpoint
+    # Every checkpoint is stored, so we can analyze how training is going.
+    # The model is selected by the lowest validation error.
     valid_loss /= len(validloader)
     print('Saving..')
     state = {
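Note on the mask branch: the loss.py change above switches mask_loss from focal_loss2d to ce_loss. Below is a minimal sketch of a pixel-wise cross-entropy mask loss that is consistent with the mask_loss = self.ce_loss(mask_preds, mask_targets) call, assuming mask_preds are (N, 2, H, W) logits and mask_targets are (N, H, W) integer class maps; the shapes and the nn.CrossEntropyLoss choice are assumptions, not necessarily this repository's actual ce_loss.

# Sketch only: assumed shapes and criterion, not the repository's actual ce_loss.
import torch
import torch.nn as nn

ce_loss = nn.CrossEntropyLoss()  # averages over every pixel in the batch

def mask_ce_loss(mask_preds, mask_targets):
    # mask_preds:   (N, 2, H, W) raw logits from the mask head (assumed)
    # mask_targets: (N, H, W) integer class indices in {0, 1} (assumed)
    # CrossEntropyLoss accepts spatial inputs directly, so the result is already
    # a per-pixel average and is added to the anchor-normalized loc/cls terms
    # without dividing by num_matched_anchors.
    return ce_loss(mask_preds, mask_targets.long())

# Example usage with random tensors:
if __name__ == '__main__':
    preds = torch.randn(2, 2, 64, 64)
    targets = torch.randint(0, 2, (2, 64, 64))
    print(mask_ce_loss(preds, targets).item())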