@@ -11,23 +11,14 @@ def __init__(self):
1111 super (SegmentationModuleBase , self ).__init__ ()
1212
1313 @staticmethod
14- def pixel_acc (pred , label , nr_layers = 4 ):
15- acc_sum , pixel_sum = 0 , 0
16- for i in range (nr_layers ):
17- _ , preds = torch .max (pred [i ], dim = 1 )
18- valid = (label [i ] >= 0 ).long ()
19- acc_sum += torch .sum (valid * (preds == label [i ]).long ())
20- pixel_sum += torch .sum (valid )
14+ def pixel_acc (pred , label ):
15+ _ , preds = torch .max (pred , dim = 1 )
16+ valid = (label >= 0 ).long ()
17+ acc_sum = torch .sum (valid * (preds == label ).long ())
18+ pixel_sum = torch .sum (valid )
2119 acc = acc_sum .float () / (pixel_sum .float () + 1e-10 )
2220 return acc
2321
24- @staticmethod
25- def pixel_loss (pred , label , crit , nr_layers = 4 ):
26- loss = 0
27- for i in range (nr_layers ):
28- loss += crit (pred [i ], label [i ])
29- return loss
30-
3122
3223class SegmentationModule (SegmentationModuleBase ):
3324 def __init__ (self , net_enc , net_dec , crit , deep_sup_scale = None ):
@@ -45,12 +36,17 @@ def forward(self, feed_dict, *, segSize=None):
4536 else :
4637 pred = self .decoder (self .encoder (feed_dict ['img_data' ], return_feature_maps = True ))
4738
48- loss = self .pixel_loss (pred , feed_dict ['seg_label' ], self .crit , self .nr_layers )
39+ # all maps resize to batch size
40+ seg_label = feed_dict ['seg_label' ]
41+ seg_label = seg_label .view (- 1 , seg_label .size (2 ), seg_label .size (3 )) # (b, h, w)
42+ pred = torch .cat (pred , dim = 0 ).view (- 1 , pred [0 ].size (1 ), seg_label .size (1 ), seg_label .size (2 )) # (b, c, h, w)
43+
44+ loss = self .crit (pred , seg_label )
4945 if self .deep_sup_scale is not None :
5046 loss_deepsup = self .crit (pred_deepsup , feed_dict ['seg_label' ])
5147 loss = loss + loss_deepsup * self .deep_sup_scale
5248
53- acc = self .pixel_acc (pred , feed_dict [ ' seg_label' ] )
49+ acc = self .pixel_acc (pred , seg_label )
5450 return loss , acc
5551 else : # inference
5652 pred = self .decoder (self .encoder (feed_dict ['img_data' ], return_feature_maps = True ), segSize = segSize )
0 commit comments