Updates and bug fixes

achaiah authored and achaiah committed Aug 1, 2019
1 parent 353f600 commit e79bab3
Showing 4 changed files with 117 additions and 51 deletions.
7 changes: 7 additions & 0 deletions README.md
@@ -28,6 +28,13 @@ Hey, [check this out](https://pywick.readthedocs.io/en/latest/), we now
have [docs](https://pywick.readthedocs.io/en/latest/)! They're still a
work in progress though so apologies for anything that's broken.

## What's New (highlights)
- **Aug. 1, 2019**
- New segmentation NNs: BiSeNet, DANet, DenseASPP, DUNet, OCNet, PSANet
- New Loss Functions: Focal Tversky Loss, OHEM CrossEntropy Loss, various combination losses
- Major restructuring and standardization of NN models and loading functionality
- General bug fixes and code improvements

## Install
`pip install pywick`

7 changes: 7 additions & 0 deletions docs/source/README.md
@@ -28,6 +28,13 @@ Hey, [check this out](https://pywick.readthedocs.io/en/latest/), we now
have [docs](https://pywick.readthedocs.io/en/latest/)! They're still a
work in progress though so apologies for anything that's broken.

## What's New (highlights)
- **Aug. 1, 2019**
- New segmentation NNs: BiSeNet, DANet, DenseASPP, DUNet, OCNet, PSANet
- New Loss Functions: Focal Tversky Loss, OHEM CrossEntropy Loss, various combination losses
- Major restructuring and standardization of NN models and loading functionality
- General bug fixes and code improvements

## Install
`pip install pywick`

149 changes: 98 additions & 51 deletions pywick/losses.py
@@ -12,6 +12,14 @@
Make sure to read the documentation and notes (in the code) for each loss to understand how it is applied.
`Read this blog post <https://gombru.github.io/2018/05/23/cross_entropy_loss/>`_
Note:
```
Logit is the vector of raw (non-normalized) predictions that a classification model generates, which is ordinarily then passed to a normalization function.
If the model is solving a multi-class classification problem, logits typically become an input to the softmax function. The softmax function then generates
a vector of (normalized) probabilities with one value for each possible class.
```
For example, BCEWithLogitsLoss is a BCE that accepts raw logits in (-inf, inf) and automatically applies torch.sigmoid to map them into [0, 1] space.
"""

## Various loss calculation functions ##
@@ -280,9 +288,11 @@ def lovasz_single(logit, label, prox=False, max_steps=20, debug={}):
loss = lovasz_binary(margins, target, prox, max_steps, debug=debug)
return loss

# WARNING THIS IS VERY SLOW FOR SOME REASON!!
def dice_coefficient(logit, label, isCuda = True):

def dice_coefficient(logit, label, isCuda=True):
'''
WARNING THIS IS VERY SLOW FOR SOME REASON!!
:param logit: calculated guess (expects torch.Tensor)
:param label: truth label (expects torch.Tensor)
:return: dice coefficient
@@ -368,8 +378,9 @@ def forward(self, logits, targets):


class SoftDiceLoss(nn.Module):
def __init__(self, weight=None, size_average=True):
def __init__(self, smooth=1.0):
super(SoftDiceLoss, self).__init__()
self.smooth = smooth

def forward(self, logits, targets):
#print('logits: {}, targets: {}'.format(logits.size(), targets.size()))
@@ -379,9 +390,9 @@ def forward(self, logits, targets):
m2 = targets.view(num, -1)
intersection = (m1 * m2)

smooth = 1.
# smooth = 1.

score = 2. * (intersection.sum(1) + smooth) / (m1.sum(1) + m2.sum(1) + smooth)
score = 2. * (intersection.sum(1) + self.smooth) / (m1.sum(1) + m2.sum(1) + self.smooth)
score = 1 - score.sum() / num
return score
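A short usage sketch for the refactored `SoftDiceLoss`, which now takes the smoothing term as a constructor argument. The shapes below are illustrative assumptions; since the forward applies no normalization itself, sigmoid-normalized predictions are passed in here:

```python
import torch
from pywick.losses import SoftDiceLoss

criterion = SoftDiceLoss(smooth=1.0)  # smoothing is now configurable per instance

preds = torch.sigmoid(torch.randn(4, 1, 64, 64))      # probabilities in [0, 1]
masks = torch.randint(0, 2, (4, 1, 64, 64)).float()   # binary ground-truth masks

loss = criterion(preds, masks)  # scalar: 1 - mean smoothed dice score over the batch
```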

@@ -428,7 +439,7 @@ def forward(self, logits, targets):
class BCEDiceTL1Loss(nn.Module):
def __init__(self, threshold=0.5):
super(BCEDiceTL1Loss, self).__init__()
self.bce = BCELoss2d()
self.bce = nn.BCEWithLogitsLoss(weight=None, size_average=None, reduce=None, reduction='mean', pos_weight=None)
self.dice = SoftDiceLoss()
self.tl1 = ThresholdedL1Loss(threshold=threshold)

@@ -438,19 +449,22 @@ def forward(self, logits, targets):

class BCEDiceFocalLoss(nn.Module):
'''
:param l: l-parameter for FocalLoss
:param weight_of_focal: How to weigh the focal loss (between 0 - 1)
:param num_classes: number of classes
:param gamma: (float,double) gamma > 0 reduces the relative loss for well-classified examples (p>0.5) putting more
focus on hard misclassified example
:param size_average: (bool, optional) By default, the losses are averaged over each loss element in the batch.
:param weights: (list(), default = [1,1,1]) Optional weighing (0.0-1.0) of the losses in order of [bce, dice, focal]
'''
def __init__(self, l=0.5, weight_of_focal=1.):
def __init__(self, focal_param, weights=[1.0,1.0,1.0]):
super(BCEDiceFocalLoss, self).__init__()
# self.bce = BCELoss2d()
# self.dice = SoftDiceLoss()
self.dice = BCELoss2d()
self.focal = FocalLoss(l=l)
self.weight_of_focal = weight_of_focal
self.bce = nn.BCEWithLogitsLoss(weight=None, size_average=None, reduce=None, reduction='mean', pos_weight=None)
self.dice = SoftDiceLoss()
self.focal = FocalLoss(l=focal_param)
self.weights = weights

def forward(self, logits, targets):
return self.dice(logits, targets) + self.weight_of_focal * self.focal(logits, targets)
logits = logits.squeeze()
return self.weights[0] * self.bce(logits, targets) + self.weights[1] * self.dice(logits, targets) + self.weights[2] * self.focal(logits.unsqueeze(1), targets.unsqueeze(1))
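A brief usage sketch for the reworked `BCEDiceFocalLoss`. The weights follow the docstring order `[bce, dice, focal]`; the input shapes are assumptions for binary segmentation and may need adjusting to match the underlying `FocalLoss` implementation:

```python
import torch
from pywick.losses import BCEDiceFocalLoss

# focal_param is forwarded to FocalLoss; weights scale the [bce, dice, focal] terms.
criterion = BCEDiceFocalLoss(focal_param=0.5, weights=[1.0, 1.0, 0.5])

logits = torch.randn(8, 1, 128, 128)                  # raw model outputs
targets = torch.randint(0, 2, (8, 128, 128)).float()  # binary masks

loss = criterion(logits, targets)  # weights[0]*BCE + weights[1]*soft dice + weights[2]*focal
```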


class BCEDiceLoss(nn.Module):
@@ -469,8 +483,8 @@ def __init__(self):

def forward(self, logits, labels, weights):
w = weights.view(-1)
z = logits.view (-1)
t = labels.view (-1)
z = logits.view(-1)
t = labels.view(-1)
loss = w*z.clamp(min=0) - w*z*t + w*torch.log(1 + torch.exp(-z.abs()))
loss = loss.sum()/w.sum()
return loss
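The expression above is the numerically stable form of binary cross-entropy on logits, `max(z, 0) - z*t + log(1 + exp(-|z|))`, weighted per element and normalized by the weight sum. A small sketch (not part of the commit) checking it against PyTorch's built-in:

```python
import torch
import torch.nn.functional as F

z = torch.randn(10)                    # logits
t = torch.randint(0, 2, (10,)).float() # binary labels
w = torch.rand(10) + 0.5               # per-element weights

manual = (w * z.clamp(min=0) - w * z * t + w * torch.log(1 + torch.exp(-z.abs()))).sum() / w.sum()
builtin = F.binary_cross_entropy_with_logits(z, t, weight=w, reduction='sum') / w.sum()

assert torch.allclose(manual, builtin, atol=1e-6)
```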
@@ -577,20 +591,16 @@ def forward(self, logit, target):
# alpha = alpha * (1 - self.alpha)
# alpha = alpha.scatter_(1, target.long(), self.alpha)
epsilon = 1e-10
alpha = self.alpha
if alpha.device != input.device:
alpha = alpha.to(input.device)
alpha = self.alpha.to(logit.device)

idx = target.cpu().long()

one_hot_key = torch.FloatTensor(target.size(0), self.num_class).zero_()
one_hot_key = one_hot_key.scatter_(1, idx, 1)
if one_hot_key.device != logit.device:
one_hot_key = one_hot_key.to(logit.device)
one_hot_key = one_hot_key.to(logit.device)

if self.smooth:
one_hot_key = torch.clamp(
one_hot_key, self.smooth / (self.num_class - 1), 1.0 - self.smooth)
one_hot_key = torch.clamp(one_hot_key, self.smooth / (self.num_class - 1), 1.0 - self.smooth)
pt = (one_hot_key * logit).sum(1) + epsilon
logpt = pt.log()
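For intuition on the label-smoothing clamp above: with `num_class` classes and smoothing factor `s`, the zeros of the one-hot target are raised to `s / (num_class - 1)` and the one is lowered to `1 - s`. A toy sketch with illustrative values:

```python
import torch

num_class, smooth = 4, 0.1
one_hot = torch.tensor([[0., 1., 0., 0.]])

smoothed = torch.clamp(one_hot, smooth / (num_class - 1), 1.0 - smooth)
# tensor([[0.0333, 0.9000, 0.0333, 0.0333]])
```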

@@ -636,8 +646,16 @@ def __init__(self, class_num, alpha=None, gamma=2, size_average=True):
def forward(self, inputs, targets): # variables
P = F.softmax(inputs)

b,c,h,w = inputs.size()
class_mask = Variable(torch.zeros([b,c+1,h,w]).cuda())
if len(inputs.size()) == 3:
torch_out = torch.zeros(inputs.size())
else:
b,c,h,w = inputs.size()
torch_out = torch.zeros([b,c+1,h,w])

if inputs.is_cuda:
torch_out = torch_out.cuda()

class_mask = Variable(torch_out)
class_mask.scatter_(1, targets.long(), 1.)
class_mask = class_mask[:,:-1,:,:]

@@ -659,30 +677,23 @@ def forward(self, inputs, targets): # variables
# -------- #

# -------- #
# Source: https://becominghuman.ai/investigating-focal-and-dice-loss-for-the-kaggle-2018-data-science-bowl-65fb9af4f36c
class BinaryFocalLoss3(nn.Module):
# Source: https://discuss.pytorch.org/t/is-this-a-correct-implementation-for-focal-loss-in-pytorch/43327/4
class BinaryFocalLoss(nn.Module):
'''
Implementation of binary focal loss. For multi-class focal loss use one of the other implementations.
gamma = 0 is equivalent to BinaryCrossEntropy Loss
'''
def __init__(self, gamma=0.5):
def __init__(self, gamma=1.333, eps=1e-6, alpha=1.0):
super().__init__()
self.gamma = gamma
self.eps = eps

def forward(self, input, target):
input = input.squeeze()
target = target.squeeze()
# Inspired by the implementation of binary_cross_entropy_with_logits
if not (target.size() == input.size()):
raise ValueError("Target size ({}) must be the same as input size ({})".format(target.size(), input.size()))

max_val = (-input).clamp(min=0)
loss = input - input * target + max_val + ((-max_val).exp() + (-input - max_val).exp()).log()

# This formula gives us the log sigmoid of 1-p if y is 0 and of p if y is 1
invprobs = F.logsigmoid(-input * (target * 2 - 1))
loss = (invprobs * self.gamma).exp() * loss

return loss.mean()
def forward(self, inputs, targets):
BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
pt = torch.exp(-BCE_loss) # prevents nans when probability 0
F_loss = self.alpha * (1 - pt) ** self.gamma * BCE_loss
return F_loss.mean()
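The rewritten `BinaryFocalLoss` recovers the probability of the true class from the per-element BCE via `pt = exp(-BCE)` and then down-weights easy examples by `(1 - pt) ** gamma`. A standalone sketch of the same idea using functional calls (shapes and values are arbitrary):

```python
import torch
import torch.nn.functional as F

logits = torch.randn(16)
targets = torch.randint(0, 2, (16,)).float()
gamma, alpha = 1.333, 1.0

bce = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')
pt = torch.exp(-bce)                               # probability assigned to the true class
focal_loss = (alpha * (1 - pt) ** gamma * bce).mean()
```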
# -------- #

# ==== Additional Losses === #
@@ -863,7 +874,11 @@ class ComboSemsegLoss(nn.Module):
def __init__(self, use_running_mean=False, bce_weight=1, dice_weight=1, eps=1e-10, gamma=0.9, combined_loss_only=False):
super().__init__()

self.nll_loss = nn.BCEWithLogitsLoss()
'''
Note: BCEWithLogitsLoss already performs a torch.sigmoid(pred)
before applying BCE!
'''
self.bce_logits_loss = nn.BCEWithLogitsLoss()

self.dice_weight = dice_weight
self.bce_weight = bce_weight
@@ -885,14 +900,15 @@ def reset_parameters(self):
self.running_dice_loss.zero_()

def forward(self, outputs, targets):
# inputs and targets are assumed to be BxCxWxH
# inputs and targets are assumed to be BxCxWxH (batch, color, width, height)
outputs = outputs.squeeze() # necessary in case we're dealing with binary segmentation (color dim of 1)
assert len(outputs.shape) == len(targets.shape)
# assert that B, W and H are the same
assert outputs.size(-0) == targets.size(-0)
assert outputs.size(-1) == targets.size(-1)
assert outputs.size(-2) == targets.size(-2)

bce_loss = self.nll_loss(outputs, targets)
bce_loss = self.bce_logits_loss(outputs, targets)

dice_target = (targets == 1).float()
dice_output = F.sigmoid(outputs)
@@ -1228,7 +1244,7 @@ def _get_batch_label_vector(target, nclass):


# Source: https://github.com/Hsuxu/Loss_ToolBox-PyTorch/blob/master/TverskyLoss/binarytverskyloss.py (MIT)
class FocalBinaryTverskyLoss(Function):
class FocalBinaryTverskyFunc(Function):
"""
Focal Tversky Loss as defined in `this paper <https://arxiv.org/abs/1810.07842>`_
@@ -1344,7 +1360,7 @@ class MultiTverskyLoss(nn.Module):
:param weights (Tensor, optional): a manual rescaling weight given to each class. If given, it has to be a Tensor of size `C`
"""

def __init__(self, alpha=0.5, beta=0.5, gamma=1.0, weights=None):
def __init__(self, alpha=0.5, beta=0.5, gamma=1.0, reduction='mean', weights=None):
"""
:param alpha (Tensor, float, optional): controls the penalty for false positives.
:param beta (Tensor, float, optional): controls the penalty for false negative.
@@ -1356,10 +1372,10 @@ def __init__(self, alpha=0.5, beta=0.5, gamma=1.0, weights=None):
self.alpha = alpha
self.beta = beta
self.gamma = gamma
self.reduction = reduction
self.weights = weights

def forward(self, inputs, targets):

num_class = inputs.size(1)
weight_losses = 0.0
if self.weights is not None:
@@ -1372,10 +1388,41 @@ def forward(self, inputs, targets):
input_idx = input_slices[idx]
input_idx = torch.cat((1 - input_idx, input_idx), dim=1)
target_idx = (targets == idx) * 1
loss_func = FocalBinaryTverskyLoss(self.alpha, self.beta, self.gamma)
loss_func = FocalBinaryTverskyFunc(self.alpha, self.beta, self.gamma, self.reduction)
loss_idx = loss_func(input_idx, target_idx)
weight_losses+=loss_idx * weights[idx]
# loss = torch.Tensor(weight_losses)
# loss = loss.to(inputs.device)
# loss = torch.sum(loss)
return weight_losses


class FocalBinaryTverskyLoss(MultiTverskyLoss):
"""
Binary version of Focal Tversky Loss as defined in `this paper <https://arxiv.org/abs/1810.07842>`_
`Authors' implementation <https://github.com/nabsabraham/focal-tversky-unet>`_ in Keras.
Params:
:param alpha: controls the penalty for false positives.
:param beta: penalty for false negative.
:param gamma : focal coefficient range[1,3]
:param reduction: return mode
Notes:
alpha = beta = 0.5 => dice coeff
alpha = beta = 1 => tanimoto coeff
alpha + beta = 1 => F beta coeff
add focal index -> loss=(1-T_index)**(1/gamma)
"""

def __init__(self, alpha=0.5, beta=0.7, gamma=1.33333, reduction='mean'):
"""
:param alpha (Tensor, float, optional): controls the penalty for false positives.
:param beta (Tensor, float, optional): controls the penalty for false negative.
:param gamma (Tensor, float, optional): focal coefficient
"""
super().__init__(alpha, beta, gamma, reduction)

def forward(self, inputs, targets):
return super().forward(inputs, targets.unsqueeze(1))
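For reference, the binary Tversky index these classes build on is `TI = TP / (TP + alpha*FP + beta*FN)`, with the focal variant returning `(1 - TI) ** (1/gamma)`. A small sketch computing it directly from a probability map (independent of the autograd `Function` above; the `eps` term is added here only to avoid division by zero):

```python
import torch

def tversky_index(probs, target, alpha=0.5, beta=0.7, eps=1e-6):
    # probs: predicted probabilities in [0, 1]; target: binary mask in {0, 1}
    tp = (probs * target).sum()
    fp = (probs * (1 - target)).sum()
    fn = ((1 - probs) * target).sum()
    return (tp + eps) / (tp + alpha * fp + beta * fn + eps)

probs = torch.tensor([0.9, 0.8, 0.2, 0.1])
target = torch.tensor([1., 1., 0., 0.])

gamma = 1.33333
loss = (1 - tversky_index(probs, target)) ** (1 / gamma)
```

With `alpha = beta = 0.5` the index reduces to the dice coefficient, and `alpha = beta = 1` gives the Tanimoto coefficient, matching the notes in the docstring.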
5 changes: 5 additions & 0 deletions pywick/modules/module_trainer.py
@@ -338,6 +338,7 @@ def fit(self,
self._optimizer.zero_grad()
output_batch = fit_forward_fn(input_batch)
loss = fit_loss_fn(output_batch, target_batch)
assert not math.isnan(loss), 'Assertion failed: Loss is NaN.'
loss.backward()
self._optimizer.step()
# ---------------------------------------------
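The new assertions fail fast if the loss has gone NaN before it is propagated through `backward()`. A minimal self-contained sketch of the same guard pattern (model, data and optimizer are toy placeholders, not the trainer's own objects):

```python
import math
import torch
import torch.nn as nn

model = nn.Linear(4, 1)
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

inputs = torch.randn(8, 4)
targets = torch.randint(0, 2, (8, 1)).float()

optimizer.zero_grad()
loss = criterion(model(inputs), targets)
# Catch a NaN loss (bad inputs, exploding learning rate, ...) before backward().
assert not math.isnan(loss.item()), 'Loss is NaN.'
loss.backward()
optimizer.step()
```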
@@ -445,7 +446,9 @@ def fit_loader(self,
# ---------------------------------------------
self._optimizer.zero_grad()
output_batch = fit_forward_fn(input_batch)

loss = fit_loss_fn(output_batch, target_batch)
assert not math.isnan(loss), 'Assertion failed: Loss is NaN.'
loss.backward()
self._optimizer.step()
# ---------------------------------------------
@@ -601,6 +604,7 @@ def evaluate(self,
self._optimizer.zero_grad()
output_batch = eval_forward_fn(input_batch)
loss = eval_loss_fn(output_batch, target_batch)
assert not math.isnan(loss), 'Assertion failed: Loss is NaN.'

if conditions_container:
cond_logs = conditions_container(CondType.POST, epoch_num=None, batch_num=batch_idx, net=self.model, input_batch=input_batch, output_batch=output_batch, target_batch=target_batch)
@@ -657,6 +661,7 @@ def evaluate_loader(self, loader, eval_helper_name=None, verbose=1):
self._optimizer.zero_grad()
output_batch = eval_forward_fn(input_batch)
loss = eval_loss_fn(output_batch, target_batch)
assert not math.isnan(loss), 'Assertion failed: Loss is NaN.'

if conditions_container:
cond_logs = conditions_container(CondType.POST, epoch_num=None, batch_num=batch_idx, net=self.model, input_batch=input_batch, output_batch=output_batch, target_batch=target_batch)
