In [3]:
import torch
from scipy.optimize import linear_sum_assignment

In [14]:
def _label_cost_features(labels):
    """Return ``labels`` as a floating-point 2-D tensor usable by ``torch.cdist``.

    ``torch.cdist`` needs at least 2-D floating-point inputs, so a 1-D label
    vector of length N is viewed as (N, 1) and integer labels are cast to float.
    The caller's tensor is never modified.
    """
    if labels.dim() == 1:
        labels = labels.unsqueeze(1)
    if not labels.is_floating_point():
        labels = labels.float()
    return labels


def hungarian(true, pred):
    """Match predictions to ground-truth targets with the Hungarian algorithm.

    The cost of pairing prediction ``i`` with target ``j`` is the L1 distance
    between their boxes plus the L1 distance between their labels.

    Args:
        true: dict with ``'boxes'`` (2-D tensor, one row per target) and
            ``'labels'`` (1-D, or 2-D with one row per target).
        pred: dict with the same keys, one row per prediction. The number of
            predictions may differ from the number of targets.

    Returns:
        A tuple ``(matched_pred, matched_true)`` of dicts with keys ``'boxes'``
        and ``'labels'``. Row ``k`` of ``matched_pred`` is paired with row ``k``
        of ``matched_true``; each holds ``min(num_pred, num_true)`` rows.
        Labels are returned with the same trailing shape and dtype they were
        passed in with.
    """
    true_boxes = true['boxes']
    true_labels = true['labels']
    pred_boxes = pred['boxes']
    pred_labels = pred['labels']

    # The assignment step is not differentiable and the cost matrix is detached
    # before it reaches SciPy, so skip autograd bookkeeping while building it.
    with torch.no_grad():
        # Matching cost between predicted and true boxes
        cost_boxes = torch.cdist(pred_boxes, true_boxes, p=1)

        # Matching cost between predicted and true labels
        cost_labels = torch.cdist(
            _label_cost_features(pred_labels),
            _label_cost_features(true_labels),
            p=1,
        )

        # Combined cost: rows index predictions, columns index targets
        cost = cost_boxes + cost_labels

    # Solve the linear assignment problem (SciPy works on CPU NumPy arrays)
    row_ind, col_ind = linear_sum_assignment(cost.cpu().detach().numpy())

    # Index with tensors that live on the same device as the data they select
    pred_idx = torch.as_tensor(row_ind, dtype=torch.long, device=pred_boxes.device)
    true_idx = torch.as_tensor(col_ind, dtype=torch.long, device=true_boxes.device)

    # Matched predictions; indexing keeps them attached to the autograd graph
    matched_pred = {'boxes': pred_boxes[pred_idx], 'labels': pred_labels[pred_idx]}

    # Matched ground-truth targets, in the same pair order
    matched_true = {'boxes': true_boxes[true_idx], 'labels': true_labels[true_idx]}

    return matched_pred, matched_true

In [25]:
# Toy example: three ground-truth boxes and four predictions, two of which
# sit on the first ground-truth box and differ only in their label score.
true = {
    'boxes': torch.tensor([
        [0.1, 0.1, 0.2, 0.2],
        [0.3, 0.3, 0.4, 0.4],
        [0.5, 0.5, 0.6, 0.6],
    ]),
    # unsqueeze(1) turns the label vector into an (N, 1) column for torch.cdist
    'labels': torch.tensor([1.0, 0.0, 0.0]).unsqueeze(1),
}
preds = {
    'boxes': torch.tensor([
        [0.1, 0.1, 0.2, 0.2],
        [0.1, 0.1, 0.2, 0.2],
        [0.3, 0.3, 0.4, 0.4],
        [0.5, 0.5, 0.6, 0.6],
    ]),
    'labels': torch.tensor([1.0, 0.0, 0.75, 0.0]).unsqueeze(1),
}
# Last expression of the cell: displays the (matched predictions, matched targets) pair
hungarian(true, preds)

({'boxes': tensor([[0.1000, 0.1000, 0.2000, 0.2000],
          [0.3000, 0.3000, 0.4000, 0.4000],
          [0.5000, 0.5000, 0.6000, 0.6000]]),
  'labels': tensor([[1.0000],
          [0.7500],
          [0.0000]])},
 {'boxes': tensor([[0.1000, 0.1000, 0.2000, 0.2000],
          [0.3000, 0.3000, 0.4000, 0.4000],
          [0.5000, 0.5000, 0.6000, 0.6000]]),
  'labels': tensor([[1.],
          [0.],
          [0.]])})

In [21]:
# Compute L1, IoU and classification losses on Hungarian-matched pairs with torch
def hungarian_loss(true, pred, l1_weight=1.0, iou_weight=1.0, classification_weight=1.0):
    """Weighted sum of box L1, box IoU and classification losses.

    Predictions are first matched one-to-one with targets by ``hungarian``;
    every loss term is then evaluated on the matched pairs only.

    Args:
        true: dict with ``'boxes'`` and ``'labels'`` tensors for the targets.
        pred: dict with ``'boxes'`` and ``'labels'`` tensors for the predictions.
        l1_weight: multiplier for the mean absolute box error.
        iou_weight: multiplier for the mean ``1 - IoU`` over matched pairs.
        classification_weight: multiplier for the binary cross-entropy term.

    Returns:
        A scalar tensor with the weighted total loss.
    """
    matched_pred, matched_true = hungarian(true, pred)
    pred_boxes = matched_pred['boxes']
    true_boxes = matched_true['boxes']

    l1 = torch.nn.functional.l1_loss(pred_boxes, true_boxes)

    # IoU of each matched pair, computed element-wise. This is the diagonal of
    # the pairwise IoU matrix, without building the full N x N matrix and
    # without needing torchvision. Boxes are read as (x1, y1, x2, y2) corners.
    inter_top_left = torch.max(pred_boxes[:, :2], true_boxes[:, :2])
    inter_bottom_right = torch.min(pred_boxes[:, 2:], true_boxes[:, 2:])
    inter_wh = (inter_bottom_right - inter_top_left).clamp(min=0)
    inter_area = inter_wh[:, 0] * inter_wh[:, 1]
    pred_area = (pred_boxes[:, 2] - pred_boxes[:, 0]) * (pred_boxes[:, 3] - pred_boxes[:, 1])
    true_area = (true_boxes[:, 2] - true_boxes[:, 0]) * (true_boxes[:, 3] - true_boxes[:, 1])
    paired_iou = inter_area / (pred_area + true_area - inter_area)
    iou = torch.mean(1 - paired_iou)

    # NOTE(review): this loss applies a sigmoid internally, so the predicted
    # labels are treated as raw logits. Confirm the model does not already
    # emit probabilities in [0, 1].
    classification = torch.nn.functional.binary_cross_entropy_with_logits(
        matched_pred['labels'], matched_true['labels']
    )

    return l1_weight * l1 + iou_weight * iou + classification_weight * classification


torch.Size([3, 1])