updating how losses are handled
achaiah authored and achaiah committed Aug 22, 2019
1 parent 5f50976 commit 51166d0
Showing 1 changed file with 59 additions and 28 deletions.
87 changes: 59 additions & 28 deletions pywick/losses.py
@@ -42,8 +42,8 @@
 N_CLASSES = 1


-class StableBCELoss(torch.nn.modules.Module):
-    def __init__(self):
+class StableBCELoss(nn.Module):
+    def __init__(self, **kwargs):
         super(StableBCELoss, self).__init__()

     def forward(self, input, target):
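
The **kwargs added throughout this commit lets every loss class silently absorb keyword arguments it does not use, so all of the losses can be constructed uniformly from one shared config. A minimal sketch of that pattern (the cfg dict is illustrative, not part of this commit):

    from pywick.losses import SoftDiceLoss, ThresholdedL1Loss

    # One shared config: each loss keeps the keys it knows and **kwargs
    # swallows the rest, so the same dict works for every loss class.
    cfg = {'smooth': 1.0, 'threshold': 0.5}
    dice = SoftDiceLoss(**cfg)      # uses 'smooth', ignores 'threshold'
    l1 = ThresholdedL1Loss(**cfg)   # uses 'threshold', ignores 'smooth'
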
@@ -369,7 +369,7 @@ def dice_coeff_hard_np(y_true, y_pred):
 # Source: https://github.com/doodledood/carvana-image-masking-challenge/blob/master/losses.py
 # TODO Replace this with nn.BCEWithLogitsLoss??
 class BCELoss2d(nn.Module):
-    def __init__(self, weight=None, size_average=True):
+    def __init__(self, weight=None, size_average=True, **kwargs):
         super(BCELoss2d, self).__init__()
         self.bce_loss = nn.BCELoss(weight, size_average)

@@ -381,7 +381,7 @@ def forward(self, logits, targets):


 class SoftDiceLoss(nn.Module):
-    def __init__(self, smooth=1.0):
+    def __init__(self, smooth=1.0, **kwargs):
         super(SoftDiceLoss, self).__init__()
         self.smooth = smooth

@@ -424,7 +424,7 @@ def forward(self, logits, targets):


 class ThresholdedL1Loss(nn.Module):
-    def __init__(self, threshold=0.5):
+    def __init__(self, threshold=0.5, **kwargs):
         super(ThresholdedL1Loss, self).__init__()
         self.threshold = threshold

@@ -458,7 +458,7 @@ class BCEDiceFocalLoss(nn.Module):
     :param size_average: (bool, optional) By default, the losses are averaged over each loss element in the batch.
     :param weights: (list(), default = [1,1,1]) Optional weighing (0.0-1.0) of the losses in order of [bce, dice, focal]
     '''
-    def __init__(self, focal_param, weights=[1.0,1.0,1.0]):
+    def __init__(self, focal_param, weights=[1.0,1.0,1.0], **kwargs):
         super(BCEDiceFocalLoss, self).__init__()
         self.bce = nn.BCEWithLogitsLoss(weight=None, size_average=None, reduce=None, reduction='mean', pos_weight=None)
         self.dice = SoftDiceLoss()
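
A quick construction sketch for the signature above (the argument values here are illustrative only):

    from pywick.losses import BCEDiceFocalLoss

    # Weigh down the BCE term relative to the dice and focal terms.
    criterion = BCEDiceFocalLoss(focal_param=2.0, weights=[0.5, 1.0, 1.0])
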
@@ -481,7 +481,7 @@ def forward(self, logits, targets):


 class WeightedBCELoss2d(nn.Module):
-    def __init__(self):
+    def __init__(self, **kwargs):
         super(WeightedBCELoss2d, self).__init__()

     def forward(self, logits, labels, weights):
@@ -494,7 +494,7 @@ def forward(self, logits, labels, weights):


 class WeightedSoftDiceLoss(nn.Module):
-    def __init__(self):
+    def __init__(self, **kwargs):
         super(WeightedSoftDiceLoss, self).__init__()

     def forward(self, logits, labels, weights):
@@ -512,12 +512,17 @@ def forward(self, logits, labels, weights):


 class BCEDicePenalizeBorderLoss(nn.Module):
-    def __init__(self, kernel_size=21):
+    def __init__(self, kernel_size=21, **kwargs):
         super(BCEDicePenalizeBorderLoss, self).__init__()
         self.bce = WeightedBCELoss2d()
         self.dice = WeightedSoftDiceLoss()
         self.kernel_size = kernel_size

+    def to(self, device):
+        super().to(device=device)
+        self.bce.to(device=device)
+        self.dice.to(device=device)
+
     def forward(self, logits, labels):
         a = F.avg_pool2d(labels, kernel_size=self.kernel_size, padding=self.kernel_size // 2, stride=1)
         ind = a.ge(0.01) * a.le(0.99)
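
The new to(device) override forwards the device move to the wrapped member losses explicitly. Usage stays the same as for any other module (a sketch, assuming a CUDA device is available):

    import torch
    from pywick.losses import BCEDicePenalizeBorderLoss

    criterion = BCEDicePenalizeBorderLoss(kernel_size=21)
    criterion.to(torch.device('cuda'))  # also moves the internal BCE and dice losses
    # Note: as written above the override has no return value, so call it as a
    # statement rather than relying on the usual `module = module.to(device)` idiom.
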
@@ -690,7 +695,7 @@ class BinaryFocalLoss(nn.Module):
     gamma = 0 is equivalent to BinaryCrossEntropy Loss
     '''
-    def __init__(self, gamma=1.333, eps=1e-6, alpha=1.0):
+    def __init__(self, gamma=1.333, eps=1e-6, alpha=1.0, **kwargs):
         super().__init__()
         self.gamma = gamma
         self.eps = eps
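
For reference, the standard binary focal-loss formulation a class like this implements, as a standalone sketch (this is the textbook definition, not this file's exact forward pass):

    import torch

    def binary_focal(logits, targets, gamma=1.333, eps=1e-6, alpha=1.0):
        # p_t is the predicted probability of the true class; the
        # (1 - p_t)**gamma factor vanishes as gamma -> 0, recovering plain BCE.
        probs = torch.sigmoid(logits).clamp(eps, 1.0 - eps)
        p_t = torch.where(targets > 0.5, probs, 1.0 - probs)
        return (-alpha * (1.0 - p_t) ** gamma * torch.log(p_t)).mean()
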
@@ -707,7 +712,7 @@ def forward(self, inputs, targets):
 # Source: https://github.com/atlab/attorch/blob/master/attorch/losses.py
 # License: MIT
 class PoissonLoss(nn.Module):
-    def __init__(self, bias=1e-12):
+    def __init__(self, bias=1e-12, **kwargs):
         super().__init__()
         self.bias = bias

@@ -718,7 +723,7 @@ def forward(self, output, target):


 class PoissonLoss3d(nn.Module):
-    def __init__(self, bias=1e-12):
+    def __init__(self, bias=1e-12, **kwargs):
         super().__init__()
         self.bias = bias

@@ -730,7 +735,7 @@ def forward(self, output, target):


 class L1Loss3d(nn.Module):
-    def __init__(self, bias=1e-12):
+    def __init__(self, bias=1e-12, **kwargs):
         super().__init__()
         self.bias = bias

@@ -757,7 +762,7 @@ class BCEWithLogitsViewLoss(nn.BCEWithLogitsLoss):
     '''
     Silly wrapper of nn.BCEWithLogitsLoss because BCEWithLogitsLoss only takes a 1-D array
     '''
-    def __init__(self, weight=None, size_average=True):
+    def __init__(self, weight=None, size_average=True, **kwargs):
         super().__init__(weight=weight, size_average=size_average)

     def forward(self, input, target):
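
Per the docstring, the wrapper only flattens its arguments before delegating; the intent reduces to this (a plain-PyTorch sketch, not the class's exact body):

    import torch
    import torch.nn as nn

    input = torch.randn(4, 1, 8, 8)  # N x C x H x W logits
    target = torch.rand(4, 1, 8, 8)
    loss = nn.BCEWithLogitsLoss()(input.view(-1), target.view(-1))
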
@@ -844,9 +849,9 @@ def soft_multiclass_dice_loss(y_true, y_pred, epsilon=1e-6):


 class mIoULoss(nn.Module):
-    def __init__(self, weight=None, size_average=True, n_classes=2):
+    def __init__(self, weight=None, size_average=True, num_classes=2, **kwargs):
         super(mIoULoss, self).__init__()
-        self.classes = n_classes
+        self.classes = num_classes

     def forward(self, inputs, target_oneHot):
         # inputs => N x Classes x H x W
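
Note that the rename from n_classes to num_classes is a silent breaking change: because **kwargs now swallows unknown keywords, an old-style call no longer raises. Sketch:

    from pywick.losses import mIoULoss

    criterion = mIoULoss(num_classes=5)   # new-style construction
    # Caution: mIoULoss(n_classes=5) no longer raises a TypeError; 'n_classes'
    # lands in **kwargs and the loss silently keeps the default of 2 classes.
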
@@ -876,20 +881,20 @@ def forward(self, inputs, target_oneHot):
 # ====================== #
 # Source: https://github.com/snakers4/mnasnet-pytorch/blob/master/src/models/semseg_loss.py
 # Combination Loss from BCE and Dice
-class ComboSemsegLoss(nn.Module):
+class ComboBCEDiceLoss(nn.Module):
     """
     Combination BinaryCrossEntropy (BCE) and Dice Loss with an optional running mean and loss weighing.
     """

-    def __init__(self, use_running_mean=False, bce_weight=1, dice_weight=1, eps=1e-6, gamma=0.9, combined_loss_only=False):
+    def __init__(self, use_running_mean=False, bce_weight=1, dice_weight=1, eps=1e-6, gamma=0.9, combined_loss_only=True, **kwargs):
         """
         :param use_running_mean: - bool (default: False) Whether to accumulate a running mean and add it to the loss with (1-gamma)
         :param bce_weight: - float (default: 1.0) Weight multiplier for the BCE loss (relative to dice)
         :param dice_weight: - float (default: 1.0) Weight multiplier for the Dice loss (relative to BCE)
         :param eps: -
         :param gamma:
-        :param combined_loss_only:
+        :param combined_loss_only: - bool (default: True) whether to return a single combined loss or three separate losses
         """

         super().__init__()
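
With the rename to ComboBCEDiceLoss and the new combined_loss_only=True default, the loss now returns a single combined scalar unless explicitly asked for the three separate values (a construction sketch):

    from pywick.losses import ComboBCEDiceLoss

    criterion = ComboBCEDiceLoss()  # returns one combined BCE+dice scalar
    criterion_all = ComboBCEDiceLoss(combined_loss_only=False)  # returns three losses
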
@@ -914,6 +919,10 @@ def __init__(self, use_running_mean=False, bce_weight=1, dice_weight=1, eps=1e-6
         self.register_buffer('running_dice_loss', torch.zeros(1))
         self.reset_parameters()

+    def to(self, device):
+        super().to(device=device)
+        self.bce_logits_loss.to(device=device)
+
     def reset_parameters(self):
         self.running_bce_loss.zero_()
         self.running_dice_loss.zero_()
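
The register_buffer / reset_parameters pair above backs the use_running_mean option: each loss term keeps an exponential running mean updated with gamma, per the docstring. A standalone sketch of that update rule:

    import torch

    gamma = 0.9
    running_bce_loss = torch.zeros(1)

    def ema_update(running, current, gamma=gamma):
        # Keep gamma of the history, blend in (1 - gamma) of the new value.
        return gamma * running + (1.0 - gamma) * current

    running_bce_loss = ema_update(running_bce_loss, torch.tensor([0.7]))
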
@@ -965,7 +974,8 @@ def __init__(self,
                  eps=1e-6,
                  gamma=0.9,
                  use_weight_mask=False,
-                 combined_loss_only=False
+                 combined_loss_only=False,
+                 **kwargs
                  ):
         super().__init__()

@@ -987,6 +997,10 @@ def __init__(self,
         self.register_buffer('running_dice_loss', torch.zeros(1))
         self.reset_parameters()

+    def to(self, device):
+        super().to(device=device)
+        self.nll_loss.to(device=device)
+
     def reset_parameters(self):
         self.running_bce_loss.zero_()
         self.running_dice_loss.zero_()
@@ -1054,7 +1068,7 @@ class OhemCrossEntropy2d(nn.Module):
     OHEM description: http://www.erogol.com/online-hard-example-mining-pytorch/
     """
-    def __init__(self, ignore_label=-1, thresh=0.7, min_kept=100000, use_weight=True):
+    def __init__(self, ignore_label=-1, thresh=0.7, min_kept=100000, use_weight=True, **kwargs):
         super(OhemCrossEntropy2d, self).__init__()
         self.ignore_label = ignore_label
         self.thresh = float(thresh)
@@ -1069,6 +1083,10 @@ def __init__(self, ignore_label=-1, thresh=0.7, min_kept=100000, use_weight=True
             print("w/o class balance")
             self.criterion = torch.nn.CrossEntropyLoss(ignore_index=ignore_label)

+    def to(self, device):
+        super().to(device=device)
+        self.criterion.to(device=device)
+
     def forward(self, predict, target, weight=None):
         """
         Args:
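
For intuition on the constructor arguments: OHEM keeps only hard pixels, i.e. those whose predicted probability for the true class falls below thresh, while always retaining at least min_kept of the hardest ones. A standalone sketch of that selection (not this class's exact implementation):

    import torch

    def ohem_select(true_class_probs, thresh=0.7, min_kept=100000):
        # Sort pixels from hardest (lowest probability) to easiest.
        sorted_probs, order = true_class_probs.flatten().sort()
        keep = sorted_probs < thresh                # hard pixels only...
        keep[:min(min_kept, keep.numel())] = True   # ...but never fewer than min_kept
        return order[keep]                          # indices of the retained pixels
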
@@ -1189,6 +1207,10 @@ def __init__(self, aux=False, aux_weight=0.4, weight=None, ignore_index=-1, **kw
         self.aux_weight = aux_weight
         self.bceloss = nn.BCELoss(weight)

+    def to(self, device):
+        super().to(device=device)
+        self.bceloss.to(device=device)
+
     def _aux_forward(self, *inputs, **kwargs):
         *preds, target = tuple(inputs)

@@ -1226,6 +1248,10 @@ def __init__(self, se_loss=False, se_weight=0.2, nclass=-1,
         self.aux_weight = aux_weight
         self.bceloss = nn.BCELoss(weight)

+    def to(self, device):
+        super().to(device=device)
+        self.bceloss.to(device=device)
+
     def forward(self, *inputs):
         if not self.se_loss and not self.aux:
             return super(OHEMSegmentationLosses, self).forward(*inputs)
@@ -1435,7 +1461,7 @@ class FocalBinaryTverskyLoss(MultiTverskyLoss):
     add focal index -> loss=(1-T_index)**(1/gamma)
     """

-    def __init__(self, alpha=0.5, beta=0.7, gamma=1.0, reduction='mean'):
+    def __init__(self, alpha=0.5, beta=0.7, gamma=1.0, reduction='mean', **kwargs):
         """
         :param alpha (Tensor, float, optional): controls the penalty for false positives.
         :param beta (Tensor, float, optional): controls the penalty for false negatives.
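
For reference, the Tversky index that the docstring's focal transform is applied to, written out as a standalone sketch (smooth is an illustrative stabilizer, not a parameter of this class):

    import torch

    def focal_tversky(probs, targets, alpha=0.5, beta=0.7, gamma=1.0, smooth=1e-6):
        tp = (probs * targets).sum()
        fp = (probs * (1 - targets)).sum()
        fn = ((1 - probs) * targets).sum()
        t_index = (tp + smooth) / (tp + alpha * fp + beta * fn + smooth)
        return (1 - t_index) ** (1.0 / gamma)  # loss = (1 - T_index)**(1/gamma)
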
@@ -1464,7 +1490,7 @@ def lovasz_grad(gt_sorted):


 class LovaszSoftmax(nn.Module):
-    def __init__(self, reduction='mean'):
+    def __init__(self, reduction='mean', **kwargs):
         super(LovaszSoftmax, self).__init__()
         self.reduction = reduction

@@ -1521,13 +1547,15 @@ class ActiveContourLoss(nn.Module):
     Params:
         :param len_w: (float, default=1.0) - The multiplier to use when adding boundary loss.
         :param reg_w: (float, default=1.0) - The multiplier to use when adding region loss.
+        :param apply_log: (bool, default=True) - Whether to transform the loss into log space
     """

-    def __init__(self, len_w=1.0, reg_w=1.0):
+    def __init__(self, len_w=1., reg_w=1., apply_log=True, **kwargs):
         super(ActiveContourLoss, self).__init__()
         self.len_w = len_w
         self.reg_w = reg_w
         self.epsilon = 1e-8  # a parameter to avoid square root = zero issues
+        self.apply_log = apply_log

     def forward(self, logits, target):
         image_size = logits.size(3)
@@ -1583,9 +1611,12 @@ def forward(self, logits, target):
         probs_diff = (probs[:, 0, :, :] - target[:, 0, :, :]).abs()  # subtract mask from probs giving us the errors
         error_out = (probs_diff * target[:, 0, :, :])  # multiply mask by error, giving us the error terms inside the mask.

-        loss = self.len_w * length_loss + self.reg_w * (error_in.sum() + error_out.sum())
+        if self.apply_log:
+            loss = torch.log(length_loss) + torch.log(error_in.sum() + error_out.sum())
+        else:
+            loss = self.len_w * length_loss + self.reg_w * (error_in.sum() + error_out.sum())

-        return loss
+        return torch.clamp(loss, min=0.0)  # make sure we don't return negative values
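
So with the default apply_log=True the two terms are now combined in log space rather than linearly weighted (len_w and reg_w only apply to the linear path), and the final clamp guards against the negative values the logs can produce. Construction sketch:

    from pywick.losses import ActiveContourLoss

    criterion = ActiveContourLoss()                # log-space combination (new default)
    criterion_linear = ActiveContourLoss(apply_log=False, len_w=1.0, reg_w=1.0)
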


# ===================== #
@@ -1712,7 +1743,7 @@ def get_tp_fp_fn(net_output, gt, axes=None, mask=None, square=False):


 class BDLoss(nn.Module):
-    def __init__(self):
+    def __init__(self, **kwargs):
         """
         compute boundary loss
         only compute the loss of foreground
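
For context, boundary loss (Kervadec et al., MIDL 2019) scores the predicted foreground probabilities against a precomputed signed-distance map of the ground-truth boundary; a standalone sketch of that idea (not necessarily this class's exact code):

    import torch

    def boundary_loss(fg_probs, signed_dist_map):
        # fg_probs: predicted probabilities for the foreground classes only.
        # signed_dist_map: precomputed signed distance to the GT boundary
        # (negative inside the object, positive outside).
        return (fg_probs * signed_dist_map).mean()
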
@@ -1764,7 +1795,7 @@ class TverskyLoss(nn.Module):
     [1]: https://arxiv.org/abs/1706.05721
     """

-    def __init__(self, alpha, beta, eps=1e-7):
+    def __init__(self, alpha, beta, eps=1e-7, **kwargs):
         super(TverskyLoss, self).__init__()

         self.alpha = alpha
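
Per the linked reference, alpha penalizes false positives and beta penalizes false negatives; with alpha = beta = 0.5 the Tversky index reduces to the Dice coefficient. Construction sketch:

    from pywick.losses import TverskyLoss

    # Emphasize recall by penalizing false negatives more heavily.
    criterion = TverskyLoss(alpha=0.3, beta=0.7)
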
