In [83]:
import torch
import torch.nn as nn
from torchvision.ops import generalized_box_iou
import torch.nn.functional as F
from scipy.optimize import linear_sum_assignment
from typing import Tuple

if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [84]:
def cxcywh_to_xyxy(boxes):
    cx, cy, w, h = boxes.unbind(-1)
    x1 = cx - 0.5 * w
    y1 = cy - 0.5 * h
    x2 = cx + 0.5 * w
    y2 = cy + 0.5 * h
    return torch.stack((x1, y1, x2, y2), dim=-1)

In [85]:
class HungarianMatcher(nn.Module):
    def __init__(self, cost_class: float = 1, cost_bbox: float = 1, cost_giou: float = 2):
        super(HungarianMatcher, self).__init__()

        self.cost_class = cost_class
        self.cost_bbox = cost_bbox
        self.cost_giou = cost_giou

    def _get_index_map(self, targets):
        batch_idx = torch.cat([torch.full((len(t['labels']),), i, dtype=torch.int64)
                               for i, t in enumerate(targets)])
        gt_idx = torch.cat([torch.arange(len(t['labels']), dtype=torch.int64)
                            for t in targets])
        return batch_idx, gt_idx

    #Manhattan distance between predicted and GTBoxes

    def _bbox_distance(self, pred_boxes, targets):
        boxes = torch.cat([t['boxes'] for t in targets])
        pred_boxes = pred_boxes.view(-1, 4)
        cost = torch.cdist(pred_boxes, boxes, p=1)
        return cost

    def _giou_loss(self, pred_boxes, targets):
        all_cost = []

        for i, t in enumerate(targets):
            preds = pred_boxes[i]
            gts = t['boxes']

            if len(gts) == 0:
                all_cost.append(torch.zeros(preds.size(0), device=preds.device))
                continue
            pred_xyxy = cxcywh_to_xyxy(preds)
            targetxyxy = cxcywh_to_xyxy(gts)

            giou = generalized_box_iou(pred_xyxy, targetxyxy)
            best_giou_per_pred = giou.max(dim=-1)[0]
            all_cost.append(-best_giou_per_pred)
        return all_cost

    @torch.no_grad()
    def forward(self, outputs, targets):
        bs, num_queries = outputs['pred_logits'].shape[:2]
        indices = []

        for i in range(bs):
            # --- extract predictions for this image ---
            out_prob = outputs['pred_logits'][i].softmax(-1)  # [N, C]
            out_bbox = outputs['pred_boxes'][i]  # [N, 4]

            tgt_ids = targets[i]['labels']  # [K]
            tgt_bbox = targets[i]['boxes']  # [K, 4]

            # --- compute costs ---
            cost_class = -out_prob[:, tgt_ids]  # [N, K]
            cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1)  # [N, K]
            cost_giou = -generalized_box_iou(
            cxcywh_to_xyxy(out_bbox), cxcywh_to_xyxy(tgt_bbox)
            )  # [N, K]

            # --- total cost ---
            C = (self.cost_class * cost_class +
                 self.cost_bbox * cost_bbox +
                 self.cost_giou * cost_giou)

            i_idx, j_idx = linear_sum_assignment(C.cpu())
            indices.append((
                torch.as_tensor(i_idx, dtype=torch.int64),
                torch.as_tensor(j_idx, dtype=torch.int64)
            ))

        return indices



In [93]:
#Understand the Set Criterion
class SetCriterion(nn.Module):
    def __init__(self, matcher, weight_dict, eos_coef=0.1, losses=None):
        super(SetCriterion, self).__init__()
        if losses is None:
            losses = ['labels', 'boxes']
        self.matcher = matcher
        self.weight_dict = weight_dict
        self.eos_coef = eos_coef
        self.losses = losses

    def _get_src_batch_indices(self, indices):
        batch_idx = torch.cat(
            [torch.full((len(src),), i, dtype=torch.int64)
             for i, (src, _) in enumerate(indices)]
        )
        src_idx = torch.cat(
            [src for (src, _) in indices]
        )
        return batch_idx, src_idx

    def loss_labels(self, outputs, targets, indices, num_boxes):
        """Compute classification loss for matched predictions only"""
        batch_idx, src_idx = self._get_src_batch_indices(indices)
        pred_logits = outputs['pred_logits'][batch_idx, src_idx]
        #Ground truth label for the same box
        target_classes = torch.cat([t['labels'][J] for t, (_, J) in zip(targets, indices)])
        #Compute loss function
        loss_ce = F.cross_entropy(pred_logits, target_classes, reduction='none')
        weights = torch.ones_like(target_classes, dtype=torch.float)
        loss = (loss_ce * weights).sum() / num_boxes
        return {'loss_ce': loss}

    def loss_boxes(self, outputs, targets, indices, num_boxes):
        """Compute DETR box regress loss"""
        batch_idx, src_idx =self._get_src_batch_indices(indices)
        src_boxes = outputs['pred_boxes'][batch_idx, src_idx]

        #Get each box ground truth
        target_boxes = torch.cat(
            [t['boxes'][J] for t, (_, J) in zip(targets, indices)],
        )

         # L1 loss   numeric distance between boxes
        loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none')
        loss_bbox = loss_bbox.sum()/ num_boxes

        #GIOU Loss (Overlapping Boxes)
        src_xyxy = cxcywh_to_xyxy(src_boxes)
        target_xyxy = cxcywh_to_xyxy(target_boxes)

        giou = generalized_box_iou(src_xyxy, target_xyxy)
        giou = torch.nan_to_num(giou, nan= 0.0, posinf=0.0, neginf=-1.0)
        loss_giou = (1 - torch.diag(giou)).sum() / num_boxes

        return {'loss_box':loss_bbox, 'loss_giou': loss_giou}


    def forward(self, outputs, targets ):
        """Compute the total DETR LOSS"""
        indices = self.matcher(outputs, targets)
        num_boxes = sum(len(t['labels']) for t in targets)
        num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device = next(iter(outputs.values())).device)
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            torch.distributed.all_reduce(num_boxes)
            num_boxes = num_boxes / torch.distributed.get_world_size()
        num_boxes = max(num_boxes.item(), 1.0)

        losses ={}

        for loss in self.losses:
            if loss == 'labels':
                losses.update(self.loss_labels(outputs, targets, indices, num_boxes))
            elif loss == 'boxes':
                losses.update(self.loss_boxes(outputs, targets, indices, num_boxes))
        total_loss = sum(self.weight_dict[k] * v for k, v in losses.items() if k in self.weight_dict)
        return total_loss, losses

In [94]:
outputs = {
    "pred_logits": torch.randn(2, 5, 91),  # 2 images, 5 queries each
    "pred_boxes": torch.rand(2, 5, 4)
}
targets = [
    {"labels": torch.tensor([3, 5, 7]), "boxes": torch.rand(3, 4)},
    {"labels": torch.tensor([4, 8]), "boxes": torch.rand(2, 4)}
]
matcher = HungarianMatcher()
indices = matcher(outputs, targets)
print(indices)


[(tensor([2, 3, 4]), tensor([0, 2, 1])), (tensor([1, 4]), tensor([1, 0]))]


In [95]:
weight_dict = {'loss_ce': 1.0, 'loss_bbox': 5.0, 'loss_giou': 2.0}
outputs = {
    "pred_logits": torch.randn(2, 5, 91),
    "pred_boxes": torch.rand(2, 5, 4)
}
targets = [
    {"labels": torch.tensor([3, 5, 7]), "boxes": torch.rand(3, 4)},
    {"labels": torch.tensor([4, 8]), "boxes": torch.rand(2, 4)}
]

matcher = HungarianMatcher()
criterion = SetCriterion(matcher, weight_dict, eos_coef=0.1, losses=['labels', 'boxes'])

total_loss, loss_dict = criterion(outputs, targets)
print("Matched indices:", matcher(outputs, targets))
print("Total loss:", total_loss)
print("Loss breakdown:", loss_dict)


Matched indices: [(tensor([0, 1, 3]), tensor([0, 1, 2])), (tensor([3, 4]), tensor([1, 0]))]
Total loss: tensor(7.8351)
Loss breakdown: {'loss_ce': tensor(5.7149), 'loss_box': tensor(0.6879), 'loss_giou': tensor(1.0601)}
