In [1]:
import torch
from scipy.optimize import linear_sum_assignment

In [2]:
# Dummy box data: 500 predicted rows and 125 target rows, 4 values each.
out_box = torch.ones(5 * 100, 4)
# Scale the first 30 predicted rows to 5.0 (the slice is a view, so this
# updates out_box itself); those rows then equal the target rows exactly.
out_box[:30].mul_(5.0)
target_box = torch.full((5 * 25, 4), 5.0)

In [3]:
# Dummy classification data: 500 rows of 2-class probabilities and 125 targets,
# all of class 1.
# A dedicated seeded generator keeps the scores identical across kernel
# restarts without altering the global RNG state.
_cls_rng = torch.Generator().manual_seed(0)
out_cls = torch.randn((5*100, 2), generator=_cls_rng).softmax(-1)
# target_cls is used as an index tensor when building the class cost; int64 is
# the index dtype accepted by every PyTorch release (some older releases
# reject int32 index tensors).
target_cls = torch.ones((5*25), dtype=torch.long)
out_cls.shape

torch.Size([500, 2])

In [4]:
# Sanity check: display the shape of the predicted-box tensor.
out_box.shape

torch.Size([500, 4])

In [5]:
# Pairwise p=1 (L1) distance between every predicted row and every target row;
# entry [i, j] is the distance from out_box[i] to target_box[j].
cost_bbox = torch.cdist(out_box, target_box, p=1)
cost_bbox.shape

torch.Size([500, 125])

In [6]:
# Class cost: column j holds the negated score that each prediction assigns
# to the class of target j, i.e. -out_cls[:, target_cls[j]].
cost_class = torch.neg(out_cls[:, target_cls])
cost_class.shape

torch.Size([500, 125])

In [7]:
# Both cost terms must share a shape before they can be summed elementwise.
print(f"cost_box: {cost_bbox.shape}\ncost_class: {cost_class.shape}")

cost_box: torch.Size([500, 125])
cost_class: torch.Size([500, 125])


In [8]:
# Total assignment cost: elementwise sum of the class and box terms
# (one row per prediction, one column per target).
C = cost_class + cost_bbox
C.shape

torch.Size([500, 125])

In [9]:
# Sum the two cost terms and fold the 500 rows into 5 groups of 100 rows each;
# the last axis still spans all target columns.
C = (cost_class + cost_bbox).view(5, 100, -1)
C.shape

torch.Size([5, 100, 125])

In [10]:
# Scratch check: len() of a tensor is the size of its first dimension.
len(torch.randn((5,4)))

5

In [11]:
# Five equal chunks of 25 target columns each.
sizes = [25] * 5
# Chunk k of the last axis is paired with row-group k of C, and each resulting
# 2-D cost matrix is solved as an independent assignment problem.
indices = [
    linear_sum_assignment(chunk[k])
    for k, chunk in enumerate(torch.split(C, sizes, dim=-1))
]

In [12]:
# Turn each (row indices, column indices) pair from the solver into a tuple of
# int64 tensors.
out = [
    tuple(torch.as_tensor(idx, dtype=torch.int64) for idx in pair)
    for pair in indices
]

In [13]:
# Inspect the matching for the first group: (matched row indices, matched column indices).
out[0]

(tensor([ 0,  1,  2,  3,  4,  5,  8, 10, 11, 12, 13, 14, 15, 16, 17, 18, 20, 21,
         22, 23, 24, 25, 26, 27, 28]),
 tensor([ 4,  0, 23, 21,  9, 19,  3, 13, 24,  5,  7, 17, 22,  2,  8, 15,  6, 11,
         12, 10, 18,  1, 20, 14, 16]))

In [14]:
# For every matched row, record the position k of the group it belongs to, so
# that (batch_idx, row index) addresses one row of a grouped tensor.
batch_idx = torch.cat([rows.new_full(rows.shape, k) for k, (rows, _) in enumerate(out)])
print(f'{batch_idx}\n{batch_idx.shape}')

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
        3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
        4, 4, 4, 4, 4])
torch.Size([125])


In [20]:
# Scratch: reproduce one iteration of the batch_idx comprehension by hand,
# using the first group's matched row indices.
i = 0
src = out[0][0]
print(f'{src.shape}\n{src}')

torch.Size([25])
tensor([ 0,  1,  2,  3,  4,  5,  8, 10, 11, 12, 13, 14, 15, 16, 17, 18, 20, 21,
        22, 23, 24, 25, 26, 27, 28])


In [21]:
# Scratch: a tensor with the same shape and dtype as src, filled with i.
torch.full_like(src, i)

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0])

In [15]:
# Matched row indices from every group, concatenated in group order.
src_idx = torch.cat(tuple(pair[0] for pair in out))
print(f'{src_idx}\n{src_idx.shape}')

tensor([ 0,  1,  2,  3,  4,  5,  8, 10, 11, 12, 13, 14, 15, 16, 17, 18, 20, 21,
        22, 23, 24, 25, 26, 27, 28,  0,  2,  3, 12, 15, 17, 20, 22, 25, 36, 37,
        38, 41, 50, 52, 60, 62, 63, 69, 71, 73, 74, 78, 86, 94,  3,  6, 20, 21,
        23, 27, 30, 35, 39, 46, 52, 53, 54, 56, 58, 62, 66, 67, 74, 75, 76, 79,
        89, 93, 95,  1,  6,  7, 16, 17, 18, 21, 24, 35, 41, 44, 48, 49, 51, 52,
        53, 55, 60, 72, 75, 76, 80, 87, 92, 95,  1,  4, 11, 12, 17, 18, 25, 29,
        36, 40, 44, 48, 51, 64, 68, 71, 72, 73, 75, 78, 84, 86, 89, 90, 96])
torch.Size([125])


In [35]:
# Random stand-ins for grouped predictions (5 x 100 x 4) and grouped targets
# (5 x 25 x 4). They are drawn from a seeded local generator so that values
# derived from them are reproducible across kernel restarts, without touching
# the global RNG state.
_bbox_rng = torch.Generator().manual_seed(0)
out_bbox = torch.randn((5, 100, 4), generator=_bbox_rng)
t_bbox = torch.randn((5, 25, 4), generator=_bbox_rng)

In [36]:
# Gather the matched prediction rows using paired (group, row) index tensors.
src_bbox = out_bbox[batch_idx, src_idx]
# Reorder each group's targets by its matched column indices, then concatenate
# the groups along the first axis so rows line up with src_bbox.
target_bbox = torch.cat(
    [boxes[cols] for boxes, (_, cols) in zip(t_bbox, out)],
    dim=0,
)

In [33]:
# One row per matched prediction.
src_bbox.shape

torch.Size([125, 4])

In [38]:
# Must equal src_bbox.shape for the elementwise loss to line up.
target_bbox.shape

torch.Size([125, 4])

In [39]:
# Elementwise L1 loss; reduction='none' keeps one value per element rather than
# collapsing to a scalar.
loss = torch.nn.functional.l1_loss(src_bbox, target_bbox, reduction='none')

In [43]:
# Total L1 loss divided by a fixed normalizer.
# NOTE(review): 25 equals the per-group target count used in `sizes`, while
# `loss` has 125 rows (all five groups); confirm whether 25 or 125 is intended.
loss.sum() / 25

tensor(23.5530)