In [1]:
import torch

In [52]:
def batch_euclidean_dist(x, y):
    assert len(x.size()) == 3
    assert len(y.size()) == 3
    assert x.size(0) == y.size(0)
    assert x.size(-1) == y.size(-1)
    N, m, d = x.size()
    N, n, d = y.size()
    #经过计算后xx与yy维度都是[N, m, n]
    xx = torch.pow(x, 2).sum(-1, keepdim=True).expand(N, m, n)
    yy = torch.pow(y, 2).sum(-1, keepdim=True).expand(N, n, m).permute(0, 2, 1)
    dist = xx + yy
    dist.baddbmm_(1, -2, x, y.permute(0, 2, 1))
    dist = dist.clamp(min=1e-12).sqrt()  # for numerical stability
    return dist

def shortest_dist(dist_mat):
    m, n = dist_mat.size()[:2]
    dist = [[0 for _ in range(n)] for _ in range(m)]
    
    for i in range(m):
        for j in range(n):
            if (i == 0) and (j == 0):
                dist[i][j] = dist_mat[i, j]
            elif (i == 0) and (j > 0):
                dist[i][j] = dist[i][j - 1] + dist_mat[i, j]
            elif (i > 0) and (j == 0):
                dist[i][j] = dist[i - 1][j] + dist_mat[i, j]
            else:
                dist[i][j] = torch.min(dist[i - 1][j], dist[i][j - 1]) + dist_mat[i, j]
    dist = dist[-1][-1]
    print(dist.size())
    return dist

def batch_local_dist(x, y):
    assert len(x.size()) == 3
    assert len(y.size()) == 3
    assert x.size(0) == y.size(0)
    assert x.size(-1) == y.size(-1)
    # shape [N, m, n]
    dist_mat = batch_euclidean_dist(x, y)
    #dist_mat = (torch.exp(dist_mat) - 1.) / (torch.exp(dist_mat) + 1.)
    # shape [N]
    dist = shortest_dist(dist_mat.permute(1, 2, 0))
    return dist

In [53]:
if __name__=='__main__':
    x=torch.randn(16,128,64)
    y=torch.randn(16,32,64)
    local_dist=batch_local_dist(x,y)
    print(local_dist)

[tensor([10.6438, 10.6118, 10.8766, 11.0050, 11.6466, 10.9231,  9.6785, 12.2424,
        10.9268,  9.7698, 10.5586,  9.7058, 13.8538, 10.3351, 12.8878, 11.7036]), tensor([20.7903, 21.4806, 22.4499, 21.1098, 21.8715, 22.1457, 21.3849, 25.2456,
        21.0856, 21.1540, 23.3751, 20.7678, 27.1834, 20.5455, 26.6030, 23.7849]), tensor([30.9046, 33.4866, 32.1727, 31.1313, 33.4771, 33.2062, 31.0123, 36.4199,
        32.5139, 33.0223, 33.4546, 32.1222, 38.5284, 31.5332, 40.4708, 35.1577]), tensor([42.0829, 44.6508, 42.7554, 41.6507, 45.6726, 44.7948, 42.3001, 49.2751,
        41.9955, 43.2551, 43.6701, 43.6411, 50.7652, 42.4090, 52.2902, 46.3386]), tensor([52.0616, 55.6587, 54.2051, 52.2132, 56.4159, 56.3671, 53.2002, 61.7319,
        52.3730, 55.2717, 53.6641, 53.6339, 63.5646, 52.0466, 63.0945, 58.4711]), tensor([61.7067, 67.3384, 65.6981, 62.1535, 67.6708, 67.3262, 64.0027, 73.2907,
        64.1994, 66.4494, 63.4368, 65.3999, 75.9325, 64.4950, 75.6971, 71.6683]), tensor([72.5116, 77.6823, 7