# 03_loss_function.ipynb

Implement the loss function of SSD,
including Hard Negative Mining.

###  This uses utils.match

In [1]:
# import package and functions
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils.match import match

### Implement loss function - MultiBoxLoss class

In [2]:
class MultiBoxLoss(nn.Module):
    """Loss function of SSD: Smooth L1 localisation loss plus cross-entropy
    confidence loss with Hard Negative Mining (HNM)."""

    def __init__(self, jaccard_thresh=0.5, neg_pos=3, device='cpu'):
        """
        Parameters
        ----------
        jaccard_thresh : float
            Jaccard-overlap threshold forwarded to ``match``.
        neg_pos : int
            Number of negative (background) DBoxes kept per positive DBox
            during Hard Negative Mining.
        device : str or torch.device
            Device on which the target tensors are created.
        """
        super().__init__()
        self.jaccard_thresh = jaccard_thresh  # threshold used by `match`
        self.negpos_ratio = neg_pos  # neg:pos ratio for Hard Negative Mining
        self.device = device  # CPU or GPU

    def forward(self, predictions, targets):
        """
        Compute the localisation and confidence losses.

        Parameters
        ----------
        predictions : tuple
            ``(loc, conf, dbox_list)`` as produced by the SSD network in
            training mode: ``loc`` is [num_batch, num_dbox, 4], ``conf`` is
            [num_batch, num_dbox, num_classes] and ``dbox_list`` holds the
            default boxes (expected to be [num_dbox, 4]).

        targets : sequence of Tensors, one per image, each [num_objs, 5]
            The last column is the label index; the preceding columns are
            the box coordinates ([xmin, ymin, xmax, ymax, label_ind]).

        Returns
        -------
        loss_l : Tensor
            Smooth L1 loss over the positive DBoxes, divided by the number
            of positive DBoxes.
        loss_c : Tensor
            Cross-entropy loss over the positive DBoxes and the mined
            negative DBoxes, divided by the number of positive DBoxes.

        Notes
        -----
        When the batch contains no positive DBox the divisor is clamped to 1,
        so both losses are 0 instead of NaN.
        """
        loc_data, conf_data, dbox_list = predictions

        num_batch = loc_data.size(0)     # mini-batch size
        num_dbox = loc_data.size(1)      # number of default boxes
        num_classes = conf_data.size(2)  # number of classes

        # Per-DBox targets, filled in place by `match` below.
        # conf_t_label: label of the matched ground-truth box (0 = background)
        # loc_t: regression target towards the matched ground-truth box
        conf_t_label = torch.zeros(
            num_batch, num_dbox, dtype=torch.long, device=self.device)
        loc_t = torch.zeros(num_batch, num_dbox, 4, device=self.device)

        # Loop-invariant inputs of `match`.
        dbox = dbox_list.to(self.device)
        variance = [0.1, 0.2]  # used to encode the DBox -> BBox offsets

        for idx in range(num_batch):
            truths = targets[idx][:, :-1].to(self.device)  # ground-truth boxes
            labels = targets[idx][:, -1].to(self.device)   # ground-truth labels

            # `match` writes row `idx` of loc_t and conf_t_label; DBoxes whose
            # best overlap is below the threshold are expected to be labelled
            # 0 (background).
            match(self.jaccard_thresh, truths, dbox,
                  variance, labels, loc_t, conf_t_label, idx)

        # ----------
        # Localisation loss (Smooth L1) over positive DBoxes only
        # ----------
        pos_mask = conf_t_label > 0  # [num_batch, num_dbox]

        # Indexing with the 2-D mask keeps the trailing coordinate axis,
        # giving [num_pos, 4] tensors directly.
        loc_p = loc_data[pos_mask]
        loc_t_pos = loc_t[pos_mask]
        loss_l = F.smooth_l1_loss(loc_p, loc_t_pos, reduction='sum')

        # ----------
        # Hard Negative Mining: choose which background DBoxes to keep
        # ----------
        num_pos = pos_mask.long().sum(1, keepdim=True)  # positives per image

        # The per-box loss is needed only for ranking, so it is computed
        # outside the autograd graph.
        with torch.no_grad():
            rank_loss = F.cross_entropy(
                conf_data.view(-1, num_classes), conf_t_label.view(-1),
                reduction='none').view(num_batch, num_dbox)
            # Positives must not compete for the negative slots.
            rank_loss = rank_loss.masked_fill(pos_mask, 0)

            # idx_rank[b, i] is the rank of DBox i in image b when sorted by
            # descending loss (0 = hardest).
            _, loss_idx = rank_loss.sort(1, descending=True)
            _, idx_rank = loss_idx.sort(1)

        # Keep `negpos_ratio` negatives per positive, capped at num_dbox.
        num_neg = torch.clamp(num_pos * self.negpos_ratio, max=num_dbox)
        neg_mask = idx_rank < num_neg  # broadcasts over the DBox axis

        # ----------
        # Confidence loss (cross-entropy) over positives + mined negatives
        # ----------
        selected = pos_mask | neg_mask
        conf_hnm = conf_data[selected]             # [num_pos + num_neg, classes]
        conf_t_label_hnm = conf_t_label[selected]  # [num_pos + num_neg]
        loss_c = F.cross_entropy(conf_hnm, conf_t_label_hnm, reduction='sum')

        # Normalise by the number of positive DBoxes; clamp so that a batch
        # without any positive yields 0 rather than 0/0 = NaN.
        num_matched = num_pos.sum().clamp(min=1)
        loss_l = loss_l / num_matched
        loss_c = loss_c / num_matched

        return loss_l, loss_c
