-
Notifications
You must be signed in to change notification settings - Fork 470
/
l1_loss.py
41 lines (35 loc) · 1.33 KB
/
l1_loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import torch
import torch.nn as nn
class MaskL1Loss(nn.Module):
    """Masked mean absolute error between channel 0 of ``pred`` and ``gt``."""

    def __init__(self):
        super().__init__()

    def forward(self, pred: torch.Tensor, gt, mask):
        """Average ``|pred[:, 0] - gt|`` over the pixels weighted by ``mask``.

        Args:
            pred: prediction; only channel 0 (``pred[:, 0]``) is used.
                Presumably (N, 1, H, W) -- TODO confirm against callers.
            gt: target, broadcast-compatible with ``pred[:, 0]``.
            mask: per-pixel weights, broadcast-compatible with ``gt``.

        Returns:
            ``(loss, {"l1_loss": loss})``. When ``mask.sum()`` is 0, ``loss``
            is that zero sum itself, which is computed from ``mask`` only and
            therefore carries no gradient path back to ``pred``.
        """
        denom = mask.sum()
        if denom.item() != 0:
            weighted_err = torch.abs(pred[:, 0] - gt) * mask
            mean_err = weighted_err.sum() / denom
            return mean_err, dict(l1_loss=mean_err)
        # Empty mask: nothing to average over.
        return denom, dict(l1_loss=denom)
class BalanceL1Loss(nn.Module):
    """L1 loss with hard-negative mining.

    Positive pixels (where ``mask`` is 1) are averaged on their own. Among
    negative pixels, only the ``negative_ratio * positive_count`` largest
    errors are kept and averaged. The result is the sum of the two means.
    """

    def __init__(self, negative_ratio=3.):
        super(BalanceL1Loss, self).__init__()
        # Maximum number of mined negatives per positive pixel.
        self.negative_ratio = negative_ratio

    def forward(self, pred: torch.Tensor, gt, mask):
        '''
        Args:
            pred: (N, 1, H, W).
            gt: (N, H, W).
            mask: (N, H, W); presumably a 0/1 map with 1 marking positive
                pixels. A boolean tensor is also accepted.

        Returns:
            ``(positive_loss + negative_loss,
            dict(l1_loss=positive_loss, nge_l1_loss=negative_loss))``.
            An empty positive or negative set contributes 0 rather than NaN.
        '''
        loss = torch.abs(pred[:, 0] - gt)
        if mask.dtype == torch.bool:
            # ``1 - mask`` is not defined for bool tensors.
            mask = mask.to(loss.dtype)
        positive = loss * mask
        negative = loss * (1 - mask)
        positive_count = int(mask.sum())
        negative_count = min(
            int((1 - mask).sum()),
            int(positive_count * self.negative_ratio))
        # Hard-negative mining: keep only the largest negative errors.
        negative_loss, _ = torch.topk(negative.reshape(-1), negative_count)
        # Clamp denominators: with no positives (or no negatives) the count is
        # 0, and the unclamped division would be 0 / 0 = NaN.
        negative_loss = negative_loss.sum() / max(negative_count, 1)
        positive_loss = positive.sum() / max(positive_count, 1)
        # NOTE: "nge_l1_loss" (sic) is kept as-is; callers may read this key.
        return positive_loss + negative_loss, \
            dict(l1_loss=positive_loss, nge_l1_loss=negative_loss)