diff --git a/README.md b/README.md
index 482cf5a..bb091b7 100644
--- a/README.md
+++ b/README.md
@@ -28,6 +28,13 @@ Hey, [check this out](https://pywick.readthedocs.io/en/latest/), we now
have [docs](https://pywick.readthedocs.io/en/latest/)! They're still a
work in progress though so apologies for anything that's broken.
+## What's New (highlights)
+- **Aug. 1, 2019**
+ - New segmentation NNs: BiSeNet, DANet, DenseASPP, DUNet, OCNet, PSANet
+ - New Loss Functions: Focal Tversky Loss, OHEM CrossEntropy Loss, various combination losses
+ - Major restructuring and standardization of NN models and loading functionality
+ - General bug fixes and code improvements
+
## Install
`pip install pywick`
diff --git a/docs/source/README.md b/docs/source/README.md
index 482cf5a..bb091b7 100644
--- a/docs/source/README.md
+++ b/docs/source/README.md
@@ -28,6 +28,13 @@ Hey, [check this out](https://pywick.readthedocs.io/en/latest/), we now
have [docs](https://pywick.readthedocs.io/en/latest/)! They're still a
work in progress though so apologies for anything that's broken.
+## What's New (highlights)
+- **Aug. 1, 2019**
+ - New segmentation NNs: BiSeNet, DANet, DenseASPP, DUNet, OCNet, PSANet
+ - New Loss Functions: Focal Tversky Loss, OHEM CrossEntropy Loss, various combination losses
+ - Major restructuring and standardization of NN models and loading functionality
+ - General bug fixes and code improvements
+
## Install
`pip install pywick`
diff --git a/pywick/losses.py b/pywick/losses.py
index 2dfde1e..aca0223 100644
--- a/pywick/losses.py
+++ b/pywick/losses.py
@@ -12,6 +12,14 @@
Make sure to read the documentation and notes (in the code) for each loss to understand how it is applied.
`Read this blog post `_
+
+Note:
+ ```
+ Logit is the vector of raw (non-normalized) predictions that a classification model generates, which is ordinarily then passed to a normalization function.
+ If the model is solving a multi-class classification problem, logits typically become an input to the softmax function. The softmax function then generates
+ a vector of (normalized) probabilities with one value for each possible class.
+ ```
+For example, BCEWithLogitsLoss is a BCE that accepts raw logits in (-inf, inf) and internally applies torch.sigmoid to map them into [0, 1] before computing BCE.
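+
+A minimal illustrative sketch of that equivalence (not part of the pywick API):
+ ```
+ import torch
+ import torch.nn as nn
+
+ logits = torch.randn(4)              # raw scores in (-inf, inf)
+ targets = torch.empty(4).random_(2)  # binary labels in {0, 1}
+
+ loss_a = nn.BCEWithLogitsLoss()(logits, targets)
+ loss_b = nn.BCELoss()(torch.sigmoid(logits), targets)
+ assert torch.allclose(loss_a, loss_b)
+ ```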
"""
## Various loss calculation functions ##
@@ -280,9 +288,11 @@ def lovasz_single(logit, label, prox=False, max_steps=20, debug={}):
loss = lovasz_binary(margins, target, prox, max_steps, debug=debug)
return loss
-# WARNING THIS IS VERY SLOW FOR SOME REASON!!
-def dice_coefficient(logit, label, isCuda = True):
+
+def dice_coefficient(logit, label, isCuda=True):
'''
+ WARNING THIS IS VERY SLOW FOR SOME REASON!!
+
:param logit: calculated guess (expects torch.Tensor)
:param label: truth label (expects torch.Tensor)
:return: dice coefficient
@@ -368,8 +378,9 @@ def forward(self, logits, targets):
class SoftDiceLoss(nn.Module):
- def __init__(self, weight=None, size_average=True):
+ def __init__(self, smooth=1.0):
super(SoftDiceLoss, self).__init__()
+ self.smooth = smooth
def forward(self, logits, targets):
#print('logits: {}, targets: {}'.format(logits.size(), targets.size()))
@@ -379,9 +390,9 @@ def forward(self, logits, targets):
m2 = targets.view(num, -1)
intersection = (m1 * m2)
- smooth = 1.
+ # smooth = 1.
- score = 2. * (intersection.sum(1) + smooth) / (m1.sum(1) + m2.sum(1) + smooth)
+ score = 2. * (intersection.sum(1) + self.smooth) / (m1.sum(1) + m2.sum(1) + self.smooth)
score = 1 - score.sum() / num
return score
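+
+# Example (illustrative): SoftDiceLoss(smooth=1e-5) selects a smaller smoothing
+# constant; the default of 1.0 matches the previously hard-coded value.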
@@ -428,7 +439,7 @@ def forward(self, logits, targets):
class BCEDiceTL1Loss(nn.Module):
def __init__(self, threshold=0.5):
super(BCEDiceTL1Loss, self).__init__()
- self.bce = BCELoss2d()
+ self.bce = nn.BCEWithLogitsLoss(reduction='mean')
self.dice = SoftDiceLoss()
self.tl1 = ThresholdedL1Loss(threshold=threshold)
@@ -438,19 +449,22 @@ def forward(self, logits, targets):
class BCEDiceFocalLoss(nn.Module):
'''
- :param l: l-parameter for FocalLoss
- :param weight_of_focal: How to weigh the focal loss (between 0 - 1)
+ :param focal_param: (float) gamma for the FocalLoss component; gamma > 0 reduces the relative loss for
+ well-classified examples (p > 0.5), putting more focus on hard, misclassified examples
+ :param weights: (list, default = [1.0, 1.0, 1.0]) Optional weighing (0.0-1.0) of the losses in order of [bce, dice, focal]
'''
- def __init__(self, l=0.5, weight_of_focal=1.):
+ def __init__(self, focal_param, weights=[1.0,1.0,1.0]):
super(BCEDiceFocalLoss, self).__init__()
- # self.bce = BCELoss2d()
- # self.dice = SoftDiceLoss()
- self.dice = BCELoss2d()
- self.focal = FocalLoss(l=l)
- self.weight_of_focal = weight_of_focal
+ self.bce = nn.BCEWithLogitsLoss(reduction='mean')
+ self.dice = SoftDiceLoss()
+ self.focal = FocalLoss(l=focal_param)
+ self.weights = weights
def forward(self, logits, targets):
- return self.dice(logits, targets) + self.weight_of_focal * self.focal(logits, targets)
+ logits = logits.squeeze()
+ return self.weights[0] * self.bce(logits, targets) + self.weights[1] * self.dice(logits, targets) + self.weights[2] * self.focal(logits.unsqueeze(1), targets.unsqueeze(1))
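+
+# Example (illustrative; the focal_param value is arbitrary): weigh BCE and dice fully, halve the focal term:
+#   criterion = BCEDiceFocalLoss(focal_param=2.0, weights=[1.0, 1.0, 0.5])
+#   loss = criterion(logits, targets)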
class BCEDiceLoss(nn.Module):
@@ -469,8 +483,8 @@ def __init__(self):
def forward(self, logits, labels, weights):
w = weights.view(-1)
- z = logits.view (-1)
- t = labels.view (-1)
+ z = logits.view(-1)
+ t = labels.view(-1)
loss = w*z.clamp(min=0) - w*z*t + w*torch.log(1 + torch.exp(-z.abs()))
loss = loss.sum()/w.sum()
return loss
@@ -577,20 +591,16 @@ def forward(self, logit, target):
# alpha = alpha * (1 - self.alpha)
# alpha = alpha.scatter_(1, target.long(), self.alpha)
epsilon = 1e-10
- alpha = self.alpha
- if alpha.device != input.device:
- alpha = alpha.to(input.device)
+ alpha = self.alpha.to(logit.device)
idx = target.cpu().long()
one_hot_key = torch.FloatTensor(target.size(0), self.num_class).zero_()
one_hot_key = one_hot_key.scatter_(1, idx, 1)
- if one_hot_key.device != logit.device:
- one_hot_key = one_hot_key.to(logit.device)
+ one_hot_key = one_hot_key.to(logit.device)
if self.smooth:
- one_hot_key = torch.clamp(
- one_hot_key, self.smooth / (self.num_class - 1), 1.0 - self.smooth)
+ one_hot_key = torch.clamp(one_hot_key, self.smooth / (self.num_class - 1), 1.0 - self.smooth)
pt = (one_hot_key * logit).sum(1) + epsilon
logpt = pt.log()
@@ -636,8 +646,16 @@ def __init__(self, class_num, alpha=None, gamma=2, size_average=True):
def forward(self, inputs, targets): # variables
P = F.softmax(inputs)
- b,c,h,w = inputs.size()
- class_mask = Variable(torch.zeros([b,c+1,h,w]).cuda())
+ if len(inputs.size()) == 3:
+ torch_out = torch.zeros(inputs.size())
+ else:
+ b,c,h,w = inputs.size()
+ torch_out = torch.zeros([b,c+1,h,w])
+
+ if inputs.is_cuda:
+ torch_out = torch_out.cuda()
+
+ class_mask = Variable(torch_out)
class_mask.scatter_(1, targets.long(), 1.)
class_mask = class_mask[:,:-1,:,:]
@@ -659,30 +677,23 @@ def forward(self, inputs, targets): # variables
# -------- #
# -------- #
-# Source: https://becominghuman.ai/investigating-focal-and-dice-loss-for-the-kaggle-2018-data-science-bowl-65fb9af4f36c
-class BinaryFocalLoss3(nn.Module):
+# Source: https://discuss.pytorch.org/t/is-this-a-correct-implementation-for-focal-loss-in-pytorch/43327/4
+class BinaryFocalLoss(nn.Module):
'''
+ Implementation of binary focal loss. For multi-class focal loss use one of the other implementations.
+
gamma = 0 is equivalent to BinaryCrossEntropy Loss
'''
- def __init__(self, gamma=0.5):
+ def __init__(self, gamma=1.333, eps=1e-6, alpha=1.0):
super().__init__()
self.gamma = gamma
+ self.eps = eps
+ self.alpha = alpha
- def forward(self, input, target):
- input = input.squeeze()
- target = target.squeeze()
- # Inspired by the implementation of binary_cross_entropy_with_logits
- if not (target.size() == input.size()):
- raise ValueError("Target size ({}) must be the same as input size ({})".format(target.size(), input.size()))
-
- max_val = (-input).clamp(min=0)
- loss = input - input * target + max_val + ((-max_val).exp() + (-input - max_val).exp()).log()
-
- # This formula gives us the log sigmoid of 1-p if y is 0 and of p if y is 1
- invprobs = F.logsigmoid(-input * (target * 2 - 1))
- loss = (invprobs * self.gamma).exp() * loss
-
- return loss.mean()
+ def forward(self, inputs, targets):
+ BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
+ pt = torch.exp(-BCE_loss) # prevents nans when probability 0
+ F_loss = self.alpha * (1 - pt) ** self.gamma * BCE_loss
+ return F_loss.mean()
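+
+# Sanity check (illustrative): with gamma=0 and the default alpha=1.0 this reduces to plain BCE:
+#   bfl = BinaryFocalLoss(gamma=0.0)
+#   assert torch.allclose(bfl(logits, targets), F.binary_cross_entropy_with_logits(logits, targets))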
# -------- #
# ==== Additional Losses === #
@@ -863,7 +874,11 @@ class ComboSemsegLoss(nn.Module):
def __init__(self, use_running_mean=False, bce_weight=1, dice_weight=1, eps=1e-10, gamma=0.9, combined_loss_only=False):
super().__init__()
- self.nll_loss = nn.BCEWithLogitsLoss()
+ '''
+ Note: BCEWithLogitsLoss combines a sigmoid layer and BCE in a single, numerically
+ stable operation, so raw logits (not probabilities) must be passed in!
+ '''
+ self.bce_logits_loss = nn.BCEWithLogitsLoss()
self.dice_weight = dice_weight
self.bce_weight = bce_weight
@@ -885,14 +900,15 @@ def reset_parameters(self):
self.running_dice_loss.zero_()
def forward(self, outputs, targets):
- # inputs and targets are assumed to be BxCxWxH
+ # inputs and targets are assumed to be BxCxWxH (batch, channels, width, height)
+ outputs = outputs.squeeze() # necessary in case we're dealing with binary segmentation (channel dim of 1)
assert len(outputs.shape) == len(targets.shape)
# assert that B, W and H are the same
assert outputs.size(-0) == targets.size(-0)
assert outputs.size(-1) == targets.size(-1)
assert outputs.size(-2) == targets.size(-2)
- bce_loss = self.nll_loss(outputs, targets)
+ bce_loss = self.bce_logits_loss(outputs, targets)
dice_target = (targets == 1).float()
dice_output = F.sigmoid(outputs)
@@ -1228,7 +1244,7 @@ def _get_batch_label_vector(target, nclass):
# Source: https://github.com/Hsuxu/Loss_ToolBox-PyTorch/blob/master/TverskyLoss/binarytverskyloss.py (MIT)
-class FocalBinaryTverskyLoss(Function):
+class FocalBinaryTverskyFunc(Function):
"""
Focal Tversky Loss as defined in `this paper `_
@@ -1344,7 +1360,7 @@ class MultiTverskyLoss(nn.Module):
:param weights (Tensor, optional): a manual rescaling weight given to each class. If given, it has to be a Tensor of size `C`
"""
- def __init__(self, alpha=0.5, beta=0.5, gamma=1.0, weights=None):
+ def __init__(self, alpha=0.5, beta=0.5, gamma=1.0, reduction='mean', weights=None):
"""
:param alpha (Tensor, float, optional): controls the penalty for false positives.
:param beta (Tensor, float, optional): controls the penalty for false negative.
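+ :param reduction (str, optional): return mode for the per-class Tversky losses ('mean' by default)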
@@ -1356,10 +1372,10 @@ def __init__(self, alpha=0.5, beta=0.5, gamma=1.0, weights=None):
self.alpha = alpha
self.beta = beta
self.gamma = gamma
+ self.reduction = reduction
self.weights = weights
def forward(self, inputs, targets):
-
num_class = inputs.size(1)
weight_losses = 0.0
if self.weights is not None:
@@ -1372,10 +1388,41 @@ def forward(self, inputs, targets):
input_idx = input_slices[idx]
input_idx = torch.cat((1 - input_idx, input_idx), dim=1)
target_idx = (targets == idx) * 1
- loss_func = FocalBinaryTverskyLoss(self.alpha, self.beta, self.gamma)
+ loss_func = FocalBinaryTverskyFunc(self.alpha, self.beta, self.gamma, self.reduction)
loss_idx = loss_func(input_idx, target_idx)
weight_losses+=loss_idx * weights[idx]
# loss = torch.Tensor(weight_losses)
# loss = loss.to(inputs.device)
# loss = torch.sum(loss)
return weight_losses
+
+
+class FocalBinaryTverskyLoss(MultiTverskyLoss):
+ """
+ Binary version of Focal Tversky Loss as defined in `this paper `_
+
+ `Authors' implementation `_ in Keras.
+
+ Params:
+ :param alpha: controls the penalty for false positives.
+ :param beta: controls the penalty for false negatives.
+ :param gamma: focal coefficient, typically in the range [1, 3]
+ :param reduction: return mode ('mean' by default)
+
+ Notes:
+ alpha = beta = 0.5 => dice coeff
+ alpha = beta = 1 => tanimoto coeff
+ alpha + beta = 1 => F beta coeff
+ add focal index -> loss=(1-T_index)**(1/gamma)
+ """
+
+ def __init__(self, alpha=0.5, beta=0.7, gamma=1.33333, reduction='mean'):
+ """
+ :param alpha (Tensor, float, optional): controls the penalty for false positives.
+ :param beta (Tensor, float, optional): controls the penalty for false negatives.
+ :param gamma (Tensor, float, optional): focal coefficient
+ """
+ super().__init__(alpha, beta, gamma, reduction)
+
+ def forward(self, inputs, targets):
+ return super().forward(inputs, targets.unsqueeze(1))
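+
+# Example (illustrative sketch):
+#   criterion = FocalBinaryTverskyLoss()          # defaults: alpha=0.5, beta=0.7, gamma=1.33333
+#   loss = criterion(predictions, binary_targets) # targets get a channel dim added internally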
diff --git a/pywick/modules/module_trainer.py b/pywick/modules/module_trainer.py
index 2f692ef..61b43eb 100644
--- a/pywick/modules/module_trainer.py
+++ b/pywick/modules/module_trainer.py
@@ -338,6 +338,7 @@ def fit(self,
self._optimizer.zero_grad()
output_batch = fit_forward_fn(input_batch)
loss = fit_loss_fn(output_batch, target_batch)
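+ # Fail fast: a NaN loss means training has diverged and further steps are pointless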
+ assert not math.isnan(loss), 'Loss is NaN!'
loss.backward()
self._optimizer.step()
# ---------------------------------------------
@@ -445,7 +446,9 @@ def fit_loader(self,
# ---------------------------------------------
self._optimizer.zero_grad()
output_batch = fit_forward_fn(input_batch)
+
loss = fit_loss_fn(output_batch, target_batch)
+ assert not math.isnan(loss), 'Loss is NaN!'
loss.backward()
self._optimizer.step()
# ---------------------------------------------
@@ -601,6 +604,7 @@ def evaluate(self,
self._optimizer.zero_grad()
output_batch = eval_forward_fn(input_batch)
loss = eval_loss_fn(output_batch, target_batch)
+ assert not math.isnan(loss), 'Loss is NaN!'
if conditions_container:
cond_logs = conditions_container(CondType.POST, epoch_num=None, batch_num=batch_idx, net=self.model, input_batch=input_batch, output_batch=output_batch, target_batch=target_batch)
@@ -657,6 +661,7 @@ def evaluate_loader(self, loader, eval_helper_name=None, verbose=1):
self._optimizer.zero_grad()
output_batch = eval_forward_fn(input_batch)
loss = eval_loss_fn(output_batch, target_batch)
+ assert not math.isnan(loss), 'Loss is NaN!'
if conditions_container:
cond_logs = conditions_container(CondType.POST, epoch_num=None, batch_num=batch_idx, net=self.model, input_batch=input_batch, output_batch=output_batch, target_batch=target_batch)