diff --git a/README.md b/README.md
index 482cf5a..bb091b7 100644
--- a/README.md
+++ b/README.md
@@ -28,6 +28,13 @@
 Hey, [check this out](https://pywick.readthedocs.io/en/latest/), we now have [docs](https://pywick.readthedocs.io/en/latest/)! They're still a work in progress though so apologies for anything that's broken.
 
+## What's New (highlights)
+- **Aug. 1, 2019**
+    - New segmentation NNs: BiSeNet, DANet, DenseASPP, DUNet, OCNet, PSANet
+    - New Loss Functions: Focal Tversky Loss, OHEM CrossEntropy Loss, various combination losses
+    - Major restructuring and standardization of NN models and loading functionality
+    - General bug fixes and code improvements
+
 ## Install
 
 `pip install pywick`
diff --git a/docs/source/README.md b/docs/source/README.md
index 482cf5a..bb091b7 100644
--- a/docs/source/README.md
+++ b/docs/source/README.md
@@ -28,6 +28,13 @@
 Hey, [check this out](https://pywick.readthedocs.io/en/latest/), we now have [docs](https://pywick.readthedocs.io/en/latest/)! They're still a work in progress though so apologies for anything that's broken.
 
+## What's New (highlights)
+- **Aug. 1, 2019**
+    - New segmentation NNs: BiSeNet, DANet, DenseASPP, DUNet, OCNet, PSANet
+    - New Loss Functions: Focal Tversky Loss, OHEM CrossEntropy Loss, various combination losses
+    - Major restructuring and standardization of NN models and loading functionality
+    - General bug fixes and code improvements
+
 ## Install
 
 `pip install pywick`
diff --git a/pywick/losses.py b/pywick/losses.py
index 2dfde1e..aca0223 100644
--- a/pywick/losses.py
+++ b/pywick/losses.py
@@ -12,6 +12,14 @@
 Make sure to read the documentation and notes (in the code) for each loss to understand how it is applied.
 
 `Read this blog post `_
+
+Note:
+    ```
+    Logits are the vector of raw (non-normalized) predictions that a classification model generates, and they are ordinarily passed on to a normalization function.
+    If the model is solving a multi-class classification problem, logits typically become an input to the softmax function. The softmax function then generates
+    a vector of (normalized) probabilities with one value for each possible class.
+    ```
+For example, BCEWithLogitsLoss is a BCE that accepts raw logits in (-inf, inf) and automatically applies torch.sigmoid to map them into [0, 1] probability space.
 """
 
 ## Various loss calculation functions ##
@@ -280,9 +288,11 @@ def lovasz_single(logit, label, prox=False, max_steps=20, debug={}):
     loss = lovasz_binary(margins, target, prox, max_steps, debug=debug)
     return loss
 
-# WARNING THIS IS VERY SLOW FOR SOME REASON!!
-def dice_coefficient(logit, label, isCuda = True):
+
+def dice_coefficient(logit, label, isCuda=True):
     '''
+    WARNING THIS IS VERY SLOW FOR SOME REASON!!
+
     :param logit: calculated guess (expects torch.Tensor)
     :param label: truth label (expects torch.Tensor)
     :return: dice coefficient
@@ -368,8 +378,9 @@ def forward(self, logits, targets):
 
 
 class SoftDiceLoss(nn.Module):
-    def __init__(self, weight=None, size_average=True):
+    def __init__(self, smooth=1.0):
         super(SoftDiceLoss, self).__init__()
+        self.smooth = smooth
 
     def forward(self, logits, targets):
         #print('logits: {}, targets: {}'.format(logits.size(), targets.size()))
@@ -379,9 +390,9 @@ def forward(self, logits, targets):
         m2 = targets.view(num, -1)
         intersection = (m1 * m2)
 
-        smooth = 1.
+        # smooth = 1.
 
-        score = 2. * (intersection.sum(1) + smooth) / (m1.sum(1) + m2.sum(1) + smooth)
+        score = 2. * (intersection.sum(1) + self.smooth) / (m1.sum(1) + m2.sum(1) + self.smooth)
         score = 1 - score.sum() / num
         return score
 
@@ -428,7 +439,7 @@ def forward(self, logits, targets):
 
 class BCEDiceTL1Loss(nn.Module):
     def __init__(self, threshold=0.5):
         super(BCEDiceTL1Loss, self).__init__()
-        self.bce = BCELoss2d()
+        self.bce = nn.BCEWithLogitsLoss(weight=None, size_average=None, reduce=None, reduction='mean', pos_weight=None)
         self.dice = SoftDiceLoss()
         self.tl1 = ThresholdedL1Loss(threshold=threshold)
@@ -438,19 +449,22 @@ def forward(self, logits, targets):
 
 class BCEDiceFocalLoss(nn.Module):
     '''
-    :param l: l-parameter for FocalLoss
-    :param weight_of_focal: How to weigh the focal loss (between 0 - 1)
+    :param focal_param: (float) gamma value for the focal loss component;
+        gamma > 0 reduces the relative loss for well-classified examples (p > 0.5), putting more
+        focus on hard, misclassified examples
+    :param weights: (list(), default = [1.0, 1.0, 1.0]) Optional weighting (0.0-1.0)
+        of the losses in order of [bce, dice, focal]
     '''
-    def __init__(self, l=0.5, weight_of_focal=1.):
+    def __init__(self, focal_param, weights=[1.0,1.0,1.0]):
         super(BCEDiceFocalLoss, self).__init__()
-        # self.bce = BCELoss2d()
-        # self.dice = SoftDiceLoss()
-        self.dice = BCELoss2d()
-        self.focal = FocalLoss(l=l)
-        self.weight_of_focal = weight_of_focal
+        self.bce = nn.BCEWithLogitsLoss(weight=None, size_average=None, reduce=None, reduction='mean', pos_weight=None)
+        self.dice = SoftDiceLoss()
+        self.focal = FocalLoss(l=focal_param)
+        self.weights = weights
 
     def forward(self, logits, targets):
-        return self.dice(logits, targets) + self.weight_of_focal * self.focal(logits, targets)
+        logits = logits.squeeze()
+        return self.weights[0] * self.bce(logits, targets) + self.weights[1] * self.dice(logits, targets) + self.weights[2] * self.focal(logits.unsqueeze(1), targets.unsqueeze(1))
 
 
 class BCEDiceLoss(nn.Module):
@@ -469,8 +483,8 @@ def __init__(self):
 
     def forward(self, logits, labels, weights):
         w = weights.view(-1)
-        z = logits.view (-1)
-        t = labels.view (-1)
+        z = logits.view(-1)
+        t = labels.view(-1)
         loss = w*z.clamp(min=0) - w*z*t + w*torch.log(1 + torch.exp(-z.abs()))
         loss = loss.sum()/w.sum()
         return loss
@@ -577,20 +591,16 @@ def forward(self, logit, target):
         # alpha = alpha * (1 - self.alpha)
         # alpha = alpha.scatter_(1, target.long(), self.alpha)
         epsilon = 1e-10
-        alpha = self.alpha
-        if alpha.device != input.device:
-            alpha = alpha.to(input.device)
+        alpha = self.alpha.to(logit.device)
 
         idx = target.cpu().long()
 
         one_hot_key = torch.FloatTensor(target.size(0), self.num_class).zero_()
         one_hot_key = one_hot_key.scatter_(1, idx, 1)
-        if one_hot_key.device != logit.device:
-            one_hot_key = one_hot_key.to(logit.device)
+        one_hot_key = one_hot_key.to(logit.device)
 
         if self.smooth:
-            one_hot_key = torch.clamp(
-                one_hot_key, self.smooth / (self.num_class - 1), 1.0 - self.smooth)
+            one_hot_key = torch.clamp(one_hot_key, self.smooth / (self.num_class - 1), 1.0 - self.smooth)
 
         pt = (one_hot_key * logit).sum(1) + epsilon
         logpt = pt.log()
@@ -636,8 +646,16 @@ def __init__(self, class_num, alpha=None, gamma=2, size_average=True):
 
     def forward(self, inputs, targets): # variables
         P = F.softmax(inputs)
-        b,c,h,w = inputs.size()
-        class_mask = Variable(torch.zeros([b,c+1,h,w]).cuda())
+        if len(inputs.size()) == 3:
+            torch_out = torch.zeros(inputs.size())
+        else:
+            b,c,h,w = inputs.size()
+            torch_out = torch.zeros([b,c+1,h,w])
+
+        if inputs.is_cuda:
+            torch_out = torch_out.cuda()
+
+        class_mask = Variable(torch_out)
         class_mask.scatter_(1, targets.long(), 1.)
         class_mask = class_mask[:,:-1,:,:]
 
@@ -659,30 +677,24 @@ def forward(self, inputs, targets): # variables
 # -------- #
 
 # -------- #
-# Source: https://becominghuman.ai/investigating-focal-and-dice-loss-for-the-kaggle-2018-data-science-bowl-65fb9af4f36c
-class BinaryFocalLoss3(nn.Module):
+# Source: https://discuss.pytorch.org/t/is-this-a-correct-implementation-for-focal-loss-in-pytorch/43327/4
+class BinaryFocalLoss(nn.Module):
     '''
+    Implementation of binary focal loss. For multi-class focal loss use one of the other implementations.
+    gamma = 0 is equivalent to BinaryCrossEntropy Loss
     '''
-    def __init__(self, gamma=0.5):
+    def __init__(self, gamma=1.333, eps=1e-6, alpha=1.0):
         super().__init__()
         self.gamma = gamma
+        self.eps = eps
+        self.alpha = alpha
 
-    def forward(self, input, target):
-        input = input.squeeze()
-        target = target.squeeze()
-        # Inspired by the implementation of binary_cross_entropy_with_logits
-        if not (target.size() == input.size()):
-            raise ValueError("Target size ({}) must be the same as input size ({})".format(target.size(), input.size()))
-
-        max_val = (-input).clamp(min=0)
-        loss = input - input * target + max_val + ((-max_val).exp() + (-input - max_val).exp()).log()
-
-        # This formula gives us the log sigmoid of 1-p if y is 0 and of p if y is 1
-        invprobs = F.logsigmoid(-input * (target * 2 - 1))
-        loss = (invprobs * self.gamma).exp() * loss
-
-        return loss.mean()
+    def forward(self, inputs, targets):
+        BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
+        pt = torch.exp(-BCE_loss)  # prevents nans when probability 0
+        F_loss = self.alpha * (1 - pt) ** self.gamma * BCE_loss
+        return F_loss.mean()
 
 # -------- #
 
 # ==== Additional Losses === #
@@ -863,7 +874,11 @@ class ComboSemsegLoss(nn.Module):
     def __init__(self, use_running_mean=False, bce_weight=1, dice_weight=1, eps=1e-10, gamma=0.9, combined_loss_only=False):
         super().__init__()
-        self.nll_loss = nn.BCEWithLogitsLoss()
+        '''
+        Note: BCEWithLogitsLoss already performs a torch.sigmoid(pred)
+        before applying BCE!
+        '''
+        self.bce_logits_loss = nn.BCEWithLogitsLoss()
 
         self.dice_weight = dice_weight
         self.bce_weight = bce_weight
@@ -885,14 +900,15 @@ def reset_parameters(self):
         self.running_dice_loss.zero_()
 
     def forward(self, outputs, targets):
-        # inputs and targets are assumed to be BxCxWxH
+        # inputs and targets are assumed to be BxCxWxH (batch, channels, width, height)
+        outputs = outputs.squeeze()    # necessary in case we're dealing with binary segmentation (channel dim of 1)
         assert len(outputs.shape) == len(targets.shape)
         # assert that B, W and H are the same
         assert outputs.size(-0) == targets.size(-0)
         assert outputs.size(-1) == targets.size(-1)
         assert outputs.size(-2) == targets.size(-2)
 
-        bce_loss = self.nll_loss(outputs, targets)
+        bce_loss = self.bce_logits_loss(outputs, targets)
 
         dice_target = (targets == 1).float()
         dice_output = F.sigmoid(outputs)
@@ -1228,7 +1244,7 @@ def _get_batch_label_vector(target, nclass):
 
 
 # Source: https://github.com/Hsuxu/Loss_ToolBox-PyTorch/blob/master/TverskyLoss/binarytverskyloss.py (MIT)
-class FocalBinaryTverskyLoss(Function):
+class FocalBinaryTverskyFunc(Function):
 
     """
     Focal Tversky Loss as defined in `this paper `_
@@ -1344,7 +1360,7 @@ class MultiTverskyLoss(nn.Module):
     :param weights (Tensor, optional): a manual rescaling weight given to each class. If given, it has to be a Tensor of size `C`
     """
 
-    def __init__(self, alpha=0.5, beta=0.5, gamma=1.0, weights=None):
+    def __init__(self, alpha=0.5, beta=0.5, gamma=1.0, reduction='mean', weights=None):
         """
         :param alpha (Tensor, float, optional): controls the penalty for false positives.
         :param beta (Tensor, float, optional): controls the penalty for false negative.
@@ -1356,10 +1372,10 @@ def __init__(self, alpha=0.5, beta=0.5, gamma=1.0, weights=None):
         self.alpha = alpha
         self.beta = beta
         self.gamma = gamma
+        self.reduction = reduction
         self.weights = weights
 
     def forward(self, inputs, targets):
-        num_class = inputs.size(1)
 
         weight_losses = 0.0
         if self.weights is not None:
@@ -1372,10 +1388,41 @@ def forward(self, inputs, targets):
             input_idx = input_slices[idx]
             input_idx = torch.cat((1 - input_idx, input_idx), dim=1)
             target_idx = (targets == idx) * 1
-            loss_func = FocalBinaryTverskyLoss(self.alpha, self.beta, self.gamma)
+            loss_func = FocalBinaryTverskyFunc(self.alpha, self.beta, self.gamma, self.reduction)
             loss_idx = loss_func(input_idx, target_idx)
             weight_losses+=loss_idx * weights[idx]
         # loss = torch.Tensor(weight_losses)
         # loss = loss.to(inputs.device)
         # loss = torch.sum(loss)
         return weight_losses
+
+
+class FocalBinaryTverskyLoss(MultiTverskyLoss):
+    """
+    Binary version of Focal Tversky Loss as defined in `this paper `_
+
+    `Authors' implementation `_ in Keras.
+
+    Params:
+        :param alpha: controls the penalty for false positives.
+        :param beta: penalty for false negative.
+        :param gamma: focal coefficient, range [1, 3]
+        :param reduction: return mode
+
+    Notes:
+        alpha = beta = 0.5 => dice coeff
+        alpha = beta = 1   => tanimoto coeff
+        alpha + beta = 1   => F beta coeff
+        add focal index -> loss = (1 - T_index)**(1/gamma)
+
+    """
+
+    def __init__(self, alpha=0.5, beta=0.7, gamma=1.33333, reduction='mean'):
+        """
+        :param alpha (Tensor, float, optional): controls the penalty for false positives.
+        :param beta (Tensor, float, optional): controls the penalty for false negative.
+        :param gamma (Tensor, float, optional): focal coefficient
+        """
+        super().__init__(alpha, beta, gamma, reduction)
+
+    def forward(self, inputs, targets):
+        return super().forward(inputs, targets.unsqueeze(1))
diff --git a/pywick/modules/module_trainer.py b/pywick/modules/module_trainer.py
index 2f692ef..61b43eb 100644
--- a/pywick/modules/module_trainer.py
+++ b/pywick/modules/module_trainer.py
@@ -338,6 +338,7 @@ def fit(self,
                 self._optimizer.zero_grad()
                 output_batch = fit_forward_fn(input_batch)
                 loss = fit_loss_fn(output_batch, target_batch)
+                assert not math.isnan(loss), 'Loss is NaN.'
                 loss.backward()
                 self._optimizer.step()
                 # ---------------------------------------------
@@ -445,7 +446,9 @@ def fit_loader(self,
                 # ---------------------------------------------
                 self._optimizer.zero_grad()
                 output_batch = fit_forward_fn(input_batch)
+
                 loss = fit_loss_fn(output_batch, target_batch)
+                assert not math.isnan(loss), 'Loss is NaN.'
                 loss.backward()
                 self._optimizer.step()
                 # ---------------------------------------------
@@ -601,6 +604,7 @@ def evaluate(self,
                 self._optimizer.zero_grad()
                 output_batch = eval_forward_fn(input_batch)
                 loss = eval_loss_fn(output_batch, target_batch)
+                assert not math.isnan(loss), 'Loss is NaN.'
                 if conditions_container:
                     cond_logs = conditions_container(CondType.POST, epoch_num=None, batch_num=batch_idx, net=self.model,
                                                      input_batch=input_batch, output_batch=output_batch, target_batch=target_batch)
@@ -657,6 +661,7 @@ def evaluate_loader(self, loader, eval_helper_name=None, verbose=1):
                 self._optimizer.zero_grad()
                 output_batch = eval_forward_fn(input_batch)
                 loss = eval_loss_fn(output_batch, target_batch)
+                assert not math.isnan(loss), 'Loss is NaN.'
                 if conditions_container:
                     cond_logs = conditions_container(CondType.POST, epoch_num=None, batch_num=batch_idx, net=self.model,
                                                      input_batch=input_batch, output_batch=output_batch, target_batch=target_batch)
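
The note added to the `pywick/losses.py` docstring describes logits and `BCEWithLogitsLoss`. As a minimal standalone sketch (stock PyTorch only, not part of the diff; shapes and values are arbitrary), the snippet below shows the equivalence that note relies on: passing raw logits to `nn.BCEWithLogitsLoss` matches applying `torch.sigmoid` first and then `nn.BCELoss`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(4, 1, 8, 8)                      # raw predictions in (-inf, inf)
targets = torch.randint(0, 2, (4, 1, 8, 8)).float()   # binary ground truth in {0, 1}

# BCEWithLogitsLoss applies the sigmoid internally (in a numerically stable way).
loss_with_logits = nn.BCEWithLogitsLoss()(logits, targets)

# Equivalent: normalize to [0, 1] with sigmoid first, then use plain BCELoss.
loss_manual = nn.BCELoss()(torch.sigmoid(logits), targets)

print(loss_with_logits.item(), loss_manual.item())    # equal up to floating-point error
```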
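The new `BinaryFocalLoss.forward` in the diff reduces to a few lines of stock PyTorch. The standalone restatement below (no pywick import; the function name `binary_focal` is only for illustration) also demonstrates the docstring's claim that `gamma = 0` recovers plain BCE-with-logits.

```python
import torch
import torch.nn.functional as F

def binary_focal(inputs, targets, gamma=1.333, alpha=1.0):
    # Same computation as the BinaryFocalLoss.forward added in the diff.
    bce = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
    pt = torch.exp(-bce)                     # probability assigned to the true class
    return (alpha * (1 - pt) ** gamma * bce).mean()

logits = torch.randn(8, 1, 32, 32)
labels = torch.randint(0, 2, (8, 1, 32, 32)).float()

# gamma = 0 makes the focal term (1 - pt)**gamma equal to 1, so the result is plain BCE.
print(binary_focal(logits, labels, gamma=0.0).item())
print(F.binary_cross_entropy_with_logits(logits, labels).item())
```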
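The `Notes` in the new `FocalBinaryTverskyLoss` docstring (alpha = beta = 0.5 gives the Dice coefficient, etc.) follow from the Tversky index itself. This toy sketch (a simplified index on hard binary masks, not the pywick class) makes that relationship concrete.

```python
import torch

def tversky_index(pred, target, alpha=0.5, beta=0.5, eps=1e-6):
    # TP / (TP + alpha * FP + beta * FN) on flattened binary masks.
    pred, target = pred.float().view(-1), target.float().view(-1)
    tp = (pred * target).sum()
    fp = (pred * (1 - target)).sum()
    fn = ((1 - pred) * target).sum()
    return (tp + eps) / (tp + alpha * fp + beta * fn + eps)

pred = torch.randint(0, 2, (1, 64, 64))
target = torch.randint(0, 2, (1, 64, 64))

dice = 2 * (pred * target).sum() / (pred.sum() + target.sum())  # classic Dice coefficient
print(tversky_index(pred, target, alpha=0.5, beta=0.5).item())  # matches Dice (up to eps)
print(dice.item())
```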