In [7]:
import torch

# Random stand-ins for the real inputs: a 2 x 256 x 7 tensor (rows 0/1 are
# compared as x/y below) and a 4 x 256 x 7 tensor (rows 0..3 are used as the
# lower-x, lower-y, upper-x, upper-y bounds below).
kps_duplicate = torch.rand((2, 256, 7))
box_duplicate = torch.rand((4, 256, 7))

# Element-wise bound checks; each mask is a 256 x 7 boolean tensor.
xmin = kps_duplicate[0] >= box_duplicate[0]
ymin = kps_duplicate[1] >= box_duplicate[1]
xmax = kps_duplicate[0] <= box_duplicate[2]
ymax = kps_duplicate[1] <= box_duplicate[3]

# True only where all four checks hold; transposed to 7 x 256.
nbr_onehot = (xmin & ymin & xmax & ymax).t()
n_neighbours = nbr_onehot.sum(dim=1)    # count along the 256-axis -> shape [7]
n_points = nbr_onehot.sum(dim=0)        # count along the 7-axis   -> shape [256]

In [6]:
import torch
import torch.nn as nn
import numpy as np

# Side length of the square feature map (height == width) used by the code in this cell.
feature_size = 16

def softmax_with_temperature(x, beta, d = 1):
    r'''Numerically stable softmax of ``x / beta`` along dimension ``d``.

    Taken from SFNet: Learning Object-aware Semantic Flow (Lee et al.).

    Args:
        x: input tensor.
        beta: temperature; the max-shifted input is divided by it before
            exponentiation.
        d: dimension along which to normalise (default 1).

    Returns:
        Tensor with the same shape as ``x`` whose entries sum to 1 along ``d``.
    '''
    # Shifting by the per-slice maximum keeps every exponent <= 0 (for
    # beta > 0), so exp() cannot overflow; the shift cancels in the ratio.
    peak = x.max(dim=d, keepdim=True)[0]
    weights = ((x - peak) / beta).exp()
    return weights / weights.sum(dim=d, keepdim=True)

def soft_argmax(corr, beta=0.02):
    r'''Soft-argmax over a correlation volume.

    Taken from SFNet: Learning Object-aware Semantic Flow (Lee et al.).

    Args:
        corr: 4-D tensor of shape [B, h*w, h, w]. Dimension 1 is softmax-ed
            and then unflattened into an (h, w) grid, so its size must be
            exactly ``h * w``.
        beta: temperature forwarded to ``softmax_with_temperature``.

    Returns:
        Tuple ``(grid_x, grid_y)``, each of shape [B, 1, h, w]. For every
        position of the trailing (h, w) dims they hold the probability-weighted
        mean of the x / y coordinate over the unflattened grid, with
        coordinates normalised to [-1, 1].

    Raises:
        RuntimeError: from ``view`` when dim 1 of ``corr`` is not ``h * w``.
    '''
    b, _, h, w = corr.size()

    corr = softmax_with_temperature(corr, beta=beta, d=1)
    # Use the explicit batch size rather than -1, so a wrong channel count
    # fails here instead of being silently folded into the batch dimension.
    corr = corr.view(b, h, w, h, w)

    # Plain constant tensors sized from the input itself and placed on its
    # device. Wrapping them in nn.Parameter would turn gradients on (its
    # default is requires_grad=True) and create fresh leaves on every call.
    x_normal = torch.linspace(-1, 1, w, dtype=torch.float, device=corr.device).view(1, w, 1, 1)
    y_normal = torch.linspace(-1, 1, h, dtype=torch.float, device=corr.device).view(1, h, 1, 1)

    grid_x = corr.sum(dim=1, keepdim=False)  # marginalise grid rows -> [B, w, h, w]
    grid_x = (grid_x * x_normal).sum(dim=1, keepdim=True)  # [B, 1, h, w]

    grid_y = corr.sum(dim=2, keepdim=False)  # marginalise grid columns -> [B, h, h, w]
    grid_y = (grid_y * y_normal).sum(dim=1, keepdim=True)  # [B, 1, h, w]

    return grid_x, grid_y

# Demo input of shape [4, feature_size**2, feature_size, feature_size].
# The channel count is derived from feature_size (instead of a literal 256)
# because soft_argmax unflattens dim 1 into a grid with the same (h, w) as the
# trailing two dims, so the sizes must stay in step.
x = torch.rand((4, feature_size * feature_size, feature_size, feature_size))

soft_argmax(x)