Skip to content

Commit 4deab05

Browse files
author
Tete Xiao
committed
fix bugs
1 parent 915c893 commit 4deab05

File tree

1 file changed

+12
-16
lines changed

1 file changed

+12
-16
lines changed

models/models.py

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -11,23 +11,14 @@ def __init__(self):
1111
super(SegmentationModuleBase, self).__init__()
1212

1313
@staticmethod
14-
def pixel_acc(pred, label, nr_layers=4):
15-
acc_sum, pixel_sum = 0, 0
16-
for i in range(nr_layers):
17-
_, preds = torch.max(pred[i], dim=1)
18-
valid = (label[i] >= 0).long()
19-
acc_sum += torch.sum(valid * (preds == label[i]).long())
20-
pixel_sum += torch.sum(valid)
14+
def pixel_acc(pred, label):
15+
_, preds = torch.max(pred, dim=1)
16+
valid = (label >= 0).long()
17+
acc_sum = torch.sum(valid * (preds == label).long())
18+
pixel_sum = torch.sum(valid)
2119
acc = acc_sum.float() / (pixel_sum.float() + 1e-10)
2220
return acc
2321

24-
@staticmethod
25-
def pixel_loss(pred, label, crit, nr_layers=4):
26-
loss = 0
27-
for i in range(nr_layers):
28-
loss += crit(pred[i], label[i])
29-
return loss
30-
3122

3223
class SegmentationModule(SegmentationModuleBase):
3324
def __init__(self, net_enc, net_dec, crit, deep_sup_scale=None):
@@ -45,12 +36,17 @@ def forward(self, feed_dict, *, segSize=None):
4536
else:
4637
pred = self.decoder(self.encoder(feed_dict['img_data'], return_feature_maps=True))
4738

48-
loss = self.pixel_loss(pred, feed_dict['seg_label'], self.crit, self.nr_layers)
39+
# all maps resize to batch size
40+
seg_label = feed_dict['seg_label']
41+
seg_label = seg_label.view(-1, seg_label.size(2), seg_label.size(3)) # (b, h, w)
42+
pred = torch.cat(pred, dim=0).view(-1, pred[0].size(1), seg_label.size(1), seg_label.size(2)) # (b, c, h, w)
43+
44+
loss = self.crit(pred, seg_label)
4945
if self.deep_sup_scale is not None:
5046
loss_deepsup = self.crit(pred_deepsup, feed_dict['seg_label'])
5147
loss = loss + loss_deepsup * self.deep_sup_scale
5248

53-
acc = self.pixel_acc(pred, feed_dict['seg_label'])
49+
acc = self.pixel_acc(pred, seg_label)
5450
return loss, acc
5551
else: # inference
5652
pred = self.decoder(self.encoder(feed_dict['img_data'], return_feature_maps=True), segSize=segSize)

0 commit comments

Comments
 (0)