In [2]:
import torch

In [3]:
def gather_feature(fmap, index, mask=None, use_transform=False):
    if use_transform:
        # change a (N, C, H, W) tenor to (N, HxW, C) shape
        batch, channel = fmap.shape[:2]
        fmap = fmap.view(batch, channel, -1).permute((0, 2, 1)).contiguous()

    dim = fmap.size(-1)
    index = index.unsqueeze(len(index.shape)).expand(*index.shape, dim)
    fmap = fmap.gather(dim=1, index=index)
    if mask is not None:
        # this part is not called in Res18 dcn COCO
        mask = mask.unsqueeze(2).expand_as(fmap)
        fmap = fmap[mask]
        fmap = fmap.reshape(-1, dim)
    return fmap

def topKscoresPerBatch(fmaps, K=40):
    batch, channels, height, width = fmaps.shape

    # first we want the top K per fmaps
    flattened_hmaps = fmaps.reshape(batch, channels, -1)
    topk_scores_per_cls, topk_indices = torch.topk(flattened_hmaps, K)

    # computing x and y in (h, w) space
    topk_indices = topk_indices % (height * width)
    topk_x = (topk_indices / width).int().float()  
    topk_y = (topk_indices % width).int().float()

    # now we want the topk all classes merged, for each batch separatly
    flattened_cls = topk_scores_per_cls.reshape(batch, -1)
    topk_scores, indices = torch.topk(flattened_cls, K)

    topk_cls = (indices / K).int() # compute wich cls the topk belong

    # updating indices, x and y by matching indices and previous top40 for each hmaps
    topk_indices = gather_feature(topk_indices.view(batch, -1, 1), indices).reshape(batch, K)
    topk_x = gather_feature(topk_x.reshape(batch, -1, 1), indices).reshape(batch, K)
    topk_y = gather_feature(topk_y.reshape(batch, -1, 1), indices).reshape(batch, K)

    return topk_scores, topk_indices, topk_cls, topk_y, topk_x
# so this function is returning the top 40 activations all classes merged for every batch independantly. they are broke down into:
# the top 40 score sorted in descending order
# the top 40 indices in the (4, 53, 4096) list
# the top 40 score's x and y in the original features maps

In [4]:
fmap_test = torch.rand(4, 53, 64, 64)

topk_score, topk_inds, topk_clses, topk_ys, topk_xs = topKscoresPerBatch(fmap_test, K=40)

print(f'topk indices:\n{topk_inds[:, 0:6]}')


print('\n(x, y) in original heat maps features maps:')
for x, y in zip(topk_xs[0, 0:6], topk_ys[0, 0:6]):
    print(f'({x}, {y})')

topk indices:
tensor([[1683, 1601, 2587,  167, 1570, 1128],
        [3333, 3333, 2121, 1009, 1341, 2368],
        [2292, 3478, 1897, 3110, 2540, 1980],
        [1018, 3235,  897, 2463,  491, 3238]])

(x, y) in original heat maps features maps:
(26.0, 19.0)
(25.0, 1.0)
(40.0, 27.0)
(2.0, 39.0)
(24.0, 34.0)
(17.0, 40.0)


In [5]:
def gather_feature2(fmap, index, mask=None, use_transform=False):
    if use_transform:
        # change a (N, C, H, W) tenor to (N, HxW, C) shape
        batch, channel = fmap.shape[:2]
        fmap = fmap.view(batch, channel, -1).permute((0, 2, 1)).contiguous()
    dim = fmap.size(-1)

    index = index.unsqueeze(len(index.shape)).expand(*index.shape, dim)
    fmap = fmap.gather(dim=1, index=index)
    return fmap

pred_offset = torch.rand(4, 2, 64, 64)

pred_offset_gathered = gather_feature2(pred_offset, topk_inds, use_transform=True)

print(pred_offset_gathered.shape, topk_inds.shape)



torch.Size([4, 40, 2]) torch.Size([4, 40])


In [6]:
test = torch.rand(4, 40)
test = test.unsqueeze(len(test.shape))

print(test.expand(4, 40, 2).shape)


torch.Size([4, 40, 2])


In [10]:
# Create a tensor with size (2, 3)
tensor = torch.tensor([1, 2, 3, 4])

print(tensor)
print(torch.sort(tensor, descending=True))

tensor([1, 2, 3, 4])
torch.return_types.sort(
values=tensor([4, 3, 2, 1]),
indices=tensor([3, 2, 1, 0]))
