In [18]:
import numpy as np
import torch
import torch.nn as nn

In [51]:
def calc_iou(a, b):
    """Return the pairwise intersection-over-union matrix of two box sets.

    Boxes are given as rows ``[x_min, y_min, x_max, y_max]``. Either argument
    may hold the predicted boxes and the other the ground-truth boxes.

    Args:
        a: (n, 4) tensor of boxes.
        b: (m, 4) tensor of boxes.

    Returns:
        (n, m) tensor whose (i, j) entry is IoU(a[i], b[j]). Boxes that do not
        overlap get 0; the union is clamped to 1e-8 so the division is safe.
    """
    # Turn a's coordinate columns into (n, 1) so they broadcast against b's (m,).
    a_x1, a_y1, a_x2, a_y2 = (a[:, col].unsqueeze(1) for col in range(4))
    b_x1, b_y1, b_x2, b_y2 = (b[:, col] for col in range(4))

    # Disjoint boxes give a negative overlap extent; clamp it to 0.
    overlap_w = (torch.min(a_x2, b_x2) - torch.max(a_x1, b_x1)).clamp(min=0)
    overlap_h = (torch.min(a_y2, b_y2) - torch.max(a_y1, b_y1)).clamp(min=0)
    inter = overlap_w * overlap_h  # (n, m)

    area_a = (a_x2 - a_x1) * (a_y2 - a_y1)  # (n, 1)
    area_b = (b_x2 - b_x1) * (b_y2 - b_y1)  # (m,)
    union = (area_a + area_b - inter).clamp(min=1e-8)

    return inter / union

# p = batch size
# q = number of anchors (= number of predicted bounding boxes) per image
# r = number of classes
# t = number of ground-truth bounding boxes per image (padded; rows with class_id == -1 contain no object)
# s = number of ground-truth bounding boxes in a particular image that contain an object (class_id != -1)
# u = number of anchors (predicted bounding boxes) in a particular image whose IoU with at least one ground-truth bounding box is greater than or equal to 0.5

In [59]:
class FocalLoss(nn.Module):
    """Focal loss for classification plus smooth-L1 loss for box regression.

    Each anchor is matched to the ground-truth box it overlaps most:
    IoU >= 0.5 makes it positive, IoU < 0.4 negative, and anything in
    between is ignored by the classification loss.
    """

    def forward(self, classifications, regressions, anchors, annotations):
        """Compute batch-averaged classification and regression losses.

        Notation: p = batch size, q = anchors per image, r = classes,
        t = (padded) ground-truth boxes per image.

        Args:
            classifications: (p, q, r) tensor of per-class scores. They are
                clamped to [1e-4, 1 - 1e-4] and passed through log(), so they
                are presumably sigmoid probabilities.
            regressions: (p, q, 4) tensor of predicted box offsets, in the same
                encoding as the regression targets built in
                ``_box_regression_loss``: (dx, dy, dw, dh) divided by
                [0.1, 0.1, 0.2, 0.2].
            anchors: (p, q, 4) tensor of anchors as [x_min, y_min, x_max, y_max].
                Only ``anchors[0]`` is used, so all images are assumed to share
                the same anchors.
            annotations: (p, t, 5) tensor of ground-truth boxes as
                [x_min, y_min, x_max, y_max, class_id]. Rows with
                class_id == -1 are padding and are dropped.

        Returns:
            Tuple of two 1-element tensors: the mean classification loss and
            the mean regression loss over the batch.
        """
        alpha = 0.25  # weight of positive targets; negatives get 1 - alpha
        gamma = 2.0   # focusing parameter: down-weights easy examples
        batch_size = classifications.shape[0]
        # Build auxiliary tensors on the inputs' device. Checking
        # torch.cuda.is_available() breaks CPU inputs on a CUDA machine.
        classification_losses = []  # one scalar per image
        regression_losses = []      # one scalar per image

        anchor = anchors[0, :, :]  # (q, 4), shared by all images
        anchor_widths = anchor[:, 2] - anchor[:, 0]
        anchor_heights = anchor[:, 3] - anchor[:, 1]
        anchor_ctr_x = anchor[:, 0] + 0.5 * anchor_widths
        anchor_ctr_y = anchor[:, 1] + 0.5 * anchor_heights

        for j in range(batch_size):
            classification = classifications[j, :, :]  # (q, r)
            regression = regressions[j, :, :]           # (q, 4)

            # Drop padding rows (class_id == -1): (t, 5) -> (s, 5).
            bbox_annotation = annotations[j, :, :]
            bbox_annotation = bbox_annotation[bbox_annotation[:, 4] != -1]

            # Keep log(p) and log(1 - p) finite.
            classification = torch.clamp(classification, 1e-4, 1.0 - 1e-4)

            if bbox_annotation.shape[0] == 0:
                # No objects: every score is a negative, so the loss is
                # -(1 - alpha) * p^gamma * log(1 - p) summed over all entries.
                focal_weight = (1.0 - alpha) * torch.pow(classification, gamma)
                bce = -torch.log(1.0 - classification)
                classification_losses.append((focal_weight * bce).sum())
                # Nothing to regress towards.
                regression_losses.append(torch.tensor(0.0, device=regressions.device))
                continue

            IoU = calc_iou(anchor, bbox_annotation[:, :4])  # (q, s)
            IoU_max, IoU_argmax = torch.max(IoU, dim=1)       # best gt per anchor

            # Per-(anchor, class) targets: 1 = positive for that class,
            # 0 = negative, -1 = ignored (best IoU in [0.4, 0.5)).
            targets = torch.full_like(classification, -1.0)
            targets[torch.lt(IoU_max, 0.4), :] = 0

            positive_indices = torch.ge(IoU_max, 0.5)  # (q,) bool
            num_positive_anchors = positive_indices.sum()

            # Row i is the gt box (with its class id) that best matches anchor i.
            assigned_annotations = bbox_annotation[IoU_argmax, :]  # (q, 5)

            targets[positive_indices, :] = 0
            targets[positive_indices, assigned_annotations[positive_indices, 4].long()] = 1

            is_positive = torch.eq(targets, 1.0)
            alpha_factor = torch.full_like(targets, alpha)
            alpha_factor = torch.where(is_positive, alpha_factor, 1.0 - alpha_factor)
            # (1 - p)^gamma for positives, p^gamma otherwise.
            focal_weight = torch.where(is_positive, 1.0 - classification, classification)
            focal_weight = alpha_factor * torch.pow(focal_weight, gamma)

            bce = -(targets * torch.log(classification)
                    + (1.0 - targets) * torch.log(1.0 - classification))
            cls_loss = focal_weight * bce  # (q, r)
            # Ignored entries (target -1) contribute nothing.
            cls_loss = torch.where(torch.ne(targets, -1.0), cls_loss, torch.zeros_like(cls_loss))

            # Normalise by the number of positive anchors (at least 1 to avoid /0).
            classification_losses.append(
                cls_loss.sum() / torch.clamp(num_positive_anchors.float(), min=1.0))

            regression_losses.append(self._box_regression_loss(
                regression, positive_indices, assigned_annotations,
                anchor_widths, anchor_heights, anchor_ctr_x, anchor_ctr_y))

        # Stack the per-image scalars and average over the batch. keepdim=True
        # yields 1-element tensors rather than 0-dim scalars.
        return (torch.stack(classification_losses).mean(dim=0, keepdim=True),
                torch.stack(regression_losses).mean(dim=0, keepdim=True))

    @staticmethod
    def _box_regression_loss(regression, positive_indices, assigned_annotations,
                             anchor_widths, anchor_heights, anchor_ctr_x, anchor_ctr_y):
        """Return the smooth-L1 loss of positive anchors' predicted offsets.

        Args:
            regression: (q, 4) predicted offsets for one image.
            positive_indices: (q,) bool mask of anchors with IoU >= 0.5.
            assigned_annotations: (q, 5) best-matching gt box per anchor.
            anchor_widths, anchor_heights, anchor_ctr_x, anchor_ctr_y: (q,)
                anchor geometry.

        Returns:
            0-dim tensor: the mean loss over the (u, 4) positive entries,
            or 0 when no anchor is positive.
        """
        if not positive_indices.any():
            return torch.tensor(0.0, device=regression.device)

        gt = assigned_annotations[positive_indices, :]  # (u, 5)
        aw = anchor_widths[positive_indices]
        ah = anchor_heights[positive_indices]
        ax = anchor_ctr_x[positive_indices]
        ay = anchor_ctr_y[positive_indices]

        gt_widths = gt[:, 2] - gt[:, 0]
        gt_heights = gt[:, 3] - gt[:, 1]
        gt_ctr_x = gt[:, 0] + 0.5 * gt_widths
        gt_ctr_y = gt[:, 1] + 0.5 * gt_heights
        # Clamp after computing centres so degenerate gt boxes never give log(0).
        gt_widths = torch.clamp(gt_widths, min=1)
        gt_heights = torch.clamp(gt_heights, min=1)

        targets = torch.stack((
            (gt_ctr_x - ax) / aw,
            (gt_ctr_y - ay) / ah,
            torch.log(gt_widths / aw),
            torch.log(gt_heights / ah),
        )).t()  # (u, 4)
        # Scale the encoded targets by 1 / [0.1, 0.1, 0.2, 0.2].
        targets = targets / targets.new_tensor([[0.1, 0.1, 0.2, 0.2]])

        diff = torch.abs(targets - regression[positive_indices, :])
        # Smooth-L1 with beta = 1/9: 4.5 * d^2 when d <= 1/9, otherwise d - 1/18.
        loss = torch.where(
            torch.le(diff, 1.0 / 9.0),
            0.5 * 9.0 * torch.pow(diff, 2),
            diff - 0.5 / 9.0,
        )
        return loss.mean()

In [36]:
class DiceLoss(nn.Module):
    """Per-anchor soft Dice loss for classification plus smooth-L1 box loss.

    The regression term matches FocalLoss: anchors whose best IoU with a
    ground-truth box is >= 0.5 are regressed towards that box.
    """

    def forward(self, classifications, regressions, anchors, annotations):
        """Compute batch-averaged Dice classification and regression losses.

        Args:
            classifications: (p, q, r) per-class scores, clamped to
                [1e-4, 1 - 1e-4] before use.
            regressions: (p, q, 4) predicted offsets, in the same encoding as
                the targets built in ``_box_regression_loss``.
            anchors: (p, q, 4) anchors as [x_min, y_min, x_max, y_max]. Only
                ``anchors[0]`` is used (all images share the same anchors).
            annotations: (p, t, 5) gt boxes as
                [x_min, y_min, x_max, y_max, class_id]. class_id == -1 rows
                are padding and are dropped.

        Returns:
            Tuple of two 1-element tensors: the mean classification loss and
            the mean regression loss over the batch.
        """
        batch_size = classifications.shape[0]
        device = classifications.device  # keep every loss term on the inputs' device
        classification_losses = []
        regression_losses = []

        anchor = anchors[0, :, :]  # (q, 4)
        anchor_widths = anchor[:, 2] - anchor[:, 0]
        anchor_heights = anchor[:, 3] - anchor[:, 1]
        anchor_ctr_x = anchor[:, 0] + 0.5 * anchor_widths
        anchor_ctr_y = anchor[:, 1] + 0.5 * anchor_heights

        for j in range(batch_size):
            classification = classifications[j, :, :]  # (q, r)
            regression = regressions[j, :, :]           # (q, 4)

            bbox_annotation = annotations[j, :, :]
            bbox_annotation = bbox_annotation[bbox_annotation[:, 4] != -1]  # (s, 5)
            classification = torch.clamp(classification, 1e-4, 1.0 - 1e-4)

            if bbox_annotation.shape[0] == 0:
                # NOTE(review): these constants carry no gradient. A regression
                # loss of 1 here differs from the 0 used below when no anchor is
                # positive; confirm this is intended.
                regression_losses.append(torch.tensor(1.0, device=device))
                classification_losses.append(torch.tensor(1.0, device=device))
                continue

            IoU = calc_iou(anchor, bbox_annotation[:, :4])  # (q, s)
            IoU_max, IoU_argmax = torch.max(IoU, dim=1)
            positive_indices = torch.ge(IoU_max, 0.5)
            assigned_annotations = bbox_annotation[IoU_argmax, :]  # (q, 5)

            # One-hot label per anchor: the class of its best-overlapping gt box.
            # NOTE(review): this applies to every anchor, including those with
            # IoU 0, so background anchors are not treated as negatives; confirm.
            true_labels = torch.zeros_like(classification)
            anchor_idx = torch.arange(true_labels.shape[0], device=true_labels.device)
            true_labels[anchor_idx, assigned_annotations[:, 4].long()] = 1

            # Per-anchor Dice: 1 - 2 * <p, y> / (sum p + sum y).
            numerator = torch.sum(classification * true_labels, dim=1)
            denominator = torch.sum(classification, dim=1) + torch.sum(true_labels, dim=1)
            dice_loss = 1 - 2 * numerator / denominator
            classification_losses.append(dice_loss.mean())

            regression_losses.append(self._box_regression_loss(
                regression, positive_indices, assigned_annotations,
                anchor_widths, anchor_heights, anchor_ctr_x, anchor_ctr_y))

        return (torch.stack(classification_losses).mean(dim=0, keepdim=True),
                torch.stack(regression_losses).mean(dim=0, keepdim=True))

    @staticmethod
    def _box_regression_loss(regression, positive_indices, assigned_annotations,
                             anchor_widths, anchor_heights, anchor_ctr_x, anchor_ctr_y):
        """Return the smooth-L1 loss of positive anchors' predicted offsets.

        Returns a 0-dim tensor: the mean over positive anchors, or 0 when no
        anchor is positive.
        """
        if not positive_indices.any():
            return torch.tensor(0.0, device=regression.device)

        gt = assigned_annotations[positive_indices, :]  # (u, 5)
        aw = anchor_widths[positive_indices]
        ah = anchor_heights[positive_indices]
        ax = anchor_ctr_x[positive_indices]
        ay = anchor_ctr_y[positive_indices]

        gt_widths = gt[:, 2] - gt[:, 0]
        gt_heights = gt[:, 3] - gt[:, 1]
        gt_ctr_x = gt[:, 0] + 0.5 * gt_widths
        gt_ctr_y = gt[:, 1] + 0.5 * gt_heights
        # Clamp after computing centres so degenerate gt boxes never give log(0).
        gt_widths = torch.clamp(gt_widths, min=1)
        gt_heights = torch.clamp(gt_heights, min=1)

        targets = torch.stack((
            (gt_ctr_x - ax) / aw,
            (gt_ctr_y - ay) / ah,
            torch.log(gt_widths / aw),
            torch.log(gt_heights / ah),
        )).t()  # (u, 4)
        targets = targets / targets.new_tensor([[0.1, 0.1, 0.2, 0.2]])

        diff = torch.abs(targets - regression[positive_indices, :])
        # Smooth-L1 with beta = 1/9.
        loss = torch.where(
            torch.le(diff, 1.0 / 9.0),
            0.5 * 9.0 * torch.pow(diff, 2),
            diff - 0.5 / 9.0,
        )
        return loss.mean()

In [44]:
class DiceLoss2(nn.Module):
    """Vectorised per-anchor soft Dice loss plus smooth-L1 box loss.

    Anchors whose best IoU with a ground-truth box is >= 0.5 are regressed
    towards that box. The regression term is the same as in FocalLoss.
    """

    def forward(self, classifications, regressions, anchors, annotations):
        """Compute batch-averaged Dice classification and regression losses.

        Args:
            classifications: (p, q, r) per-class scores, clamped to
                [1e-4, 1 - 1e-4] before use.
            regressions: (p, q, 4) predicted offsets, in the same encoding as
                the targets built in ``_box_regression_loss``.
            anchors: (p, q, 4) anchors as [x_min, y_min, x_max, y_max]. Only
                ``anchors[0]`` is used (all images share the same anchors).
            annotations: (p, t, 5) gt boxes as
                [x_min, y_min, x_max, y_max, class_id]. class_id == -1 rows
                are padding and are dropped.

        Returns:
            Tuple of two 1-element tensors: the mean classification loss and
            the mean regression loss over the batch.
        """
        batch_size = classifications.shape[0]
        # Use the inputs' device. Moving true_labels to CUDA whenever CUDA merely
        # exists crashed CPU inputs with a device mismatch.
        device = classifications.device
        classification_losses = []
        regression_losses = []

        anchor = anchors[0, :, :]  # (q, 4)
        anchor_widths = anchor[:, 2] - anchor[:, 0]
        anchor_heights = anchor[:, 3] - anchor[:, 1]
        anchor_ctr_x = anchor[:, 0] + 0.5 * anchor_widths
        anchor_ctr_y = anchor[:, 1] + 0.5 * anchor_heights

        for j in range(batch_size):
            classification = classifications[j, :, :]  # (q, r)
            regression = regressions[j, :, :]           # (q, 4)

            bbox_annotation = annotations[j, :, :]
            bbox_annotation = bbox_annotation[bbox_annotation[:, 4] != -1]  # (s, 5)
            classification = torch.clamp(classification, 1e-4, 1.0 - 1e-4)

            if bbox_annotation.shape[0] == 0:
                # NOTE(review): these constants carry no gradient. A regression
                # loss of 1 here differs from the 0 used below when no anchor is
                # positive; confirm this is intended.
                regression_losses.append(torch.tensor(1.0, device=device))
                classification_losses.append(torch.tensor(1.0, device=device))
                continue

            IoU = calc_iou(anchor, bbox_annotation[:, :4])  # (q, s)
            IoU_max, IoU_argmax = torch.max(IoU, dim=1)
            positive_indices = torch.ge(IoU_max, 0.5)
            assigned_annotations = bbox_annotation[IoU_argmax, :]  # (q, 5)

            # One-hot label per anchor: the class of its best-overlapping gt box.
            # NOTE(review): this applies to every anchor, including those with
            # IoU 0, so background anchors are not treated as negatives; confirm.
            true_labels = torch.zeros_like(classification)
            anchor_idx = torch.arange(true_labels.shape[0], device=true_labels.device)
            true_labels[anchor_idx, assigned_annotations[:, 4].long()] = 1

            # Per-anchor Dice: 1 - 2 * <p, y> / (sum p + sum y).
            numerator = torch.sum(classification * true_labels, dim=1)
            denominator = torch.sum(classification, dim=1) + torch.sum(true_labels, dim=1)
            dice_loss = 1 - 2 * numerator / denominator
            classification_losses.append(dice_loss.mean())

            regression_losses.append(self._box_regression_loss(
                regression, positive_indices, assigned_annotations,
                anchor_widths, anchor_heights, anchor_ctr_x, anchor_ctr_y))

        return (torch.stack(classification_losses).mean(dim=0, keepdim=True),
                torch.stack(regression_losses).mean(dim=0, keepdim=True))

    @staticmethod
    def _box_regression_loss(regression, positive_indices, assigned_annotations,
                             anchor_widths, anchor_heights, anchor_ctr_x, anchor_ctr_y):
        """Return the smooth-L1 loss of positive anchors' predicted offsets.

        Returns a 0-dim tensor: the mean over positive anchors, or 0 when no
        anchor is positive.
        """
        if not positive_indices.any():
            return torch.tensor(0.0, device=regression.device)

        gt = assigned_annotations[positive_indices, :]  # (u, 5)
        aw = anchor_widths[positive_indices]
        ah = anchor_heights[positive_indices]
        ax = anchor_ctr_x[positive_indices]
        ay = anchor_ctr_y[positive_indices]

        gt_widths = gt[:, 2] - gt[:, 0]
        gt_heights = gt[:, 3] - gt[:, 1]
        gt_ctr_x = gt[:, 0] + 0.5 * gt_widths
        gt_ctr_y = gt[:, 1] + 0.5 * gt_heights
        # Clamp after computing centres so degenerate gt boxes never give log(0).
        gt_widths = torch.clamp(gt_widths, min=1)
        gt_heights = torch.clamp(gt_heights, min=1)

        targets = torch.stack((
            (gt_ctr_x - ax) / aw,
            (gt_ctr_y - ay) / ah,
            torch.log(gt_widths / aw),
            torch.log(gt_heights / ah),
        )).t()  # (u, 4)
        targets = targets / targets.new_tensor([[0.1, 0.1, 0.2, 0.2]])

        diff = torch.abs(targets - regression[positive_indices, :])
        # Smooth-L1 with beta = 1/9.
        loss = torch.where(
            torch.le(diff, 1.0 / 9.0),
            0.5 * 9.0 * torch.pow(diff, 2),
            diff - 0.5 / 9.0,
        )
        return loss.mean()

In [None]:
class NIoULoss(nn.Module):
    """Dice-style per-anchor classification loss with parameter N, plus
    smooth-L1 box loss.

    The regression term is the same as in FocalLoss. See the NOTE(review) on
    the classification formula in ``forward``.
    """

    def forward(self, classifications, regressions, anchors, annotations):
        """Compute batch-averaged classification and regression losses.

        Args:
            classifications: (p, q, r) per-class scores, clamped to
                [1e-4, 1 - 1e-4] before use.
            regressions: (p, q, 4) predicted offsets, in the same encoding as
                the targets built in ``_box_regression_loss``.
            anchors: (p, q, 4) anchors as [x_min, y_min, x_max, y_max]. Only
                ``anchors[0]`` is used (all images share the same anchors).
            annotations: (p, t, 5) gt boxes as
                [x_min, y_min, x_max, y_max, class_id]. class_id == -1 rows
                are padding and are dropped.

        Returns:
            Tuple of two 1-element tensors: the mean classification loss and
            the mean regression loss over the batch.
        """
        batch_size = classifications.shape[0]
        device = classifications.device  # keep every loss term on the inputs' device
        classification_losses = []
        regression_losses = []

        anchor = anchors[0, :, :]  # (q, 4)
        anchor_widths = anchor[:, 2] - anchor[:, 0]
        anchor_heights = anchor[:, 3] - anchor[:, 1]
        anchor_ctr_x = anchor[:, 0] + 0.5 * anchor_widths
        anchor_ctr_y = anchor[:, 1] + 0.5 * anchor_heights

        for j in range(batch_size):
            classification = classifications[j, :, :]  # (q, r)
            regression = regressions[j, :, :]           # (q, 4)

            bbox_annotation = annotations[j, :, :]
            bbox_annotation = bbox_annotation[bbox_annotation[:, 4] != -1]  # (s, 5)
            classification = torch.clamp(classification, 1e-4, 1.0 - 1e-4)

            if bbox_annotation.shape[0] == 0:
                # NOTE(review): constant losses carry no gradient; confirm the
                # regression value of 1 (vs 0 below) is intended.
                regression_losses.append(torch.tensor(1.0, device=device))
                classification_losses.append(torch.tensor(1.0, device=device))
                continue

            IoU = calc_iou(anchor, bbox_annotation[:, :4])  # (q, s)
            IoU_max, IoU_argmax = torch.max(IoU, dim=1)
            positive_indices = torch.ge(IoU_max, 0.5)
            assigned_annotations = bbox_annotation[IoU_argmax, :]  # (q, 5)

            # One-hot label per anchor: the class of its best-overlapping gt box
            # (applies to every anchor, including those with IoU 0).
            true_labels = torch.zeros_like(classification)
            anchor_idx = torch.arange(true_labels.shape[0], device=true_labels.device)
            true_labels[anchor_idx, assigned_annotations[:, 4].long()] = 1

            numerator = torch.sum(classification * true_labels, dim=1)
            denominator = torch.sum(classification, dim=1) + torch.sum(true_labels, dim=1)
            N = 2
            # NOTE(review): `numerator` cancels out of this expression, so it
            # reduces to 1 - (N + 1) / (N * denominator). It ignores how much
            # probability lands on the true class. The `*` in the divisor is
            # probably a typo (e.g. `denominator + (N - 1) * numerator`, which
            # is 0 for a perfect one-hot prediction and equals Dice at N = 1).
            # The current value is kept until the intended formula is confirmed.
            dice_loss = 1 - (N + 1) * numerator / (denominator * N * numerator)
            classification_losses.append(dice_loss.mean())

            regression_losses.append(self._box_regression_loss(
                regression, positive_indices, assigned_annotations,
                anchor_widths, anchor_heights, anchor_ctr_x, anchor_ctr_y))

        return (torch.stack(classification_losses).mean(dim=0, keepdim=True),
                torch.stack(regression_losses).mean(dim=0, keepdim=True))

    @staticmethod
    def _box_regression_loss(regression, positive_indices, assigned_annotations,
                             anchor_widths, anchor_heights, anchor_ctr_x, anchor_ctr_y):
        """Return the smooth-L1 loss of positive anchors' predicted offsets.

        Returns a 0-dim tensor: the mean over positive anchors, or 0 when no
        anchor is positive.
        """
        if not positive_indices.any():
            return torch.tensor(0.0, device=regression.device)

        gt = assigned_annotations[positive_indices, :]  # (u, 5)
        aw = anchor_widths[positive_indices]
        ah = anchor_heights[positive_indices]
        ax = anchor_ctr_x[positive_indices]
        ay = anchor_ctr_y[positive_indices]

        gt_widths = gt[:, 2] - gt[:, 0]
        gt_heights = gt[:, 3] - gt[:, 1]
        gt_ctr_x = gt[:, 0] + 0.5 * gt_widths
        gt_ctr_y = gt[:, 1] + 0.5 * gt_heights
        # Clamp after computing centres so degenerate gt boxes never give log(0).
        gt_widths = torch.clamp(gt_widths, min=1)
        gt_heights = torch.clamp(gt_heights, min=1)

        targets = torch.stack((
            (gt_ctr_x - ax) / aw,
            (gt_ctr_y - ay) / ah,
            torch.log(gt_widths / aw),
            torch.log(gt_heights / ah),
        )).t()  # (u, 4)
        targets = targets / targets.new_tensor([[0.1, 0.1, 0.2, 0.2]])

        diff = torch.abs(targets - regression[positive_indices, :])
        # Smooth-L1 with beta = 1/9.
        loss = torch.where(
            torch.le(diff, 1.0 / 9.0),
            0.5 * 9.0 * torch.pow(diff, 2),
            diff - 0.5 / 9.0,
        )
        return loss.mean()