In [4]:
import numpy as np
import tensorflow as tf
from scipy.spatial.distance import cdist

In [51]:
a = tf.cast(tf.random.uniform((2, 5, 5), 0, 10, dtype=tf.int32), dtype=tf.float32)
b = tf.cast(tf.random.uniform((2, 5, 5), 0, 10, dtype=tf.int32), dtype=tf.float32)
a, b

(<tf.Tensor: shape=(2, 5, 5), dtype=float32, numpy=
 array([[[3., 1., 9., 8., 0.],
         [9., 4., 8., 1., 6.],
         [3., 2., 0., 3., 5.],
         [7., 4., 3., 8., 3.],
         [2., 9., 7., 2., 5.]],
 
        [[5., 1., 3., 7., 4.],
         [5., 9., 8., 7., 7.],
         [7., 0., 1., 9., 4.],
         [4., 7., 6., 8., 6.],
         [0., 4., 1., 6., 3.]]], dtype=float32)>,
 <tf.Tensor: shape=(2, 5, 5), dtype=float32, numpy=
 array([[[3., 4., 3., 9., 8.],
         [2., 0., 3., 9., 2.],
         [6., 6., 9., 5., 2.],
         [7., 3., 9., 3., 9.],
         [4., 3., 3., 1., 6.]],
 
        [[3., 8., 2., 8., 2.],
         [8., 7., 6., 5., 0.],
         [5., 2., 6., 9., 2.],
         [3., 9., 3., 2., 9.],
         [3., 3., 4., 9., 5.]]], dtype=float32)>)

In [52]:
def pairwise_dist (A, B):  
    """
    Computes pairwise distances between each elements of A and each elements of B.
    Args:
        A,    [b,m,d] matrix
        B,    [b,n,d] matrix
    Returns:
        D,    [b,m,n] matrix of pairwise distances
    """
    # squared norms of each row in A and B
    na = tf.reduce_sum(tf.square(A), 2)
    nb = tf.reduce_sum(tf.square(B), 2)

    # na as a row and nb as a co"lumn vectors
    na = tf.reshape(na, (tf.shape(na)[0], -1, 1))
    nb = tf.reshape(nb, (tf.shape(nb)[0], 1, -1))

    # return pairwise euclidead difference matrix
    D = tf.sqrt(tf.maximum(na - 2*tf.matmul(A, B, False, True) + nb, 0.0))
    return D

In [53]:
distances = pairwise_dist(a, b)
x = tf.cast(tf.tile(tf.reshape(tf.range(5), (1, 5, 1)), (1, 1, 5)), dtype=tf.float32)
y = tf.transpose(x, (0, 2, 1))
x, y

(<tf.Tensor: shape=(1, 5, 5), dtype=float32, numpy=
 array([[[0., 0., 0., 0., 0.],
         [1., 1., 1., 1., 1.],
         [2., 2., 2., 2., 2.],
         [3., 3., 3., 3., 3.],
         [4., 4., 4., 4., 4.]]], dtype=float32)>,
 <tf.Tensor: shape=(1, 5, 5), dtype=float32, numpy=
 array([[[0., 1., 2., 3., 4.],
         [0., 1., 2., 3., 4.],
         [0., 1., 2., 3., 4.],
         [0., 1., 2., 3., 4.],
         [0., 1., 2., 3., 4.]]], dtype=float32)>)

In [55]:
distance_tuples = tf.reshape(tf.concat((
    tf.expand_dims(distances, axis=3),
    tf.repeat(tf.expand_dims(x, axis=3), len(distances), axis=0),
    tf.repeat(tf.expand_dims(tf.transpose(x, (0, 2, 1)), axis=3), len(distances), axis=0)
), axis=3), (len(distances), -1, 3))
distance_tuples

<tf.Tensor: shape=(2, 25, 3), dtype=float32, numpy=
array([[[10.488088 ,  0.       ,  0.       ],
        [ 6.5574384,  0.       ,  1.       ],
        [ 6.8556542,  0.       ,  2.       ],
        [11.224972 ,  0.       ,  3.       ],
        [11.224972 ,  0.       ,  4.       ],
        [11.357817 ,  1.       ,  0.       ],
        [13.038404 ,  1.       ,  1.       ],
        [ 6.7823296,  1.       ,  2.       ],
        [ 4.3588986,  1.       ,  3.       ],
        [ 7.1414285,  1.       ,  4.       ],
        [ 7.6157727,  2.       ,  0.       ],
        [ 7.6811457,  2.       ,  1.       ],
        [10.908711 ,  2.       ,  2.       ],
        [10.677078 ,  2.       ,  3.       ],
        [ 4.       ,  2.       ,  4.       ],
        [ 6.4807405,  3.       ,  0.       ],
        [ 6.5574384,  3.       ,  1.       ],
        [ 7.1414285,  3.       ,  2.       ],
        [ 9.899494 ,  3.       ,  3.       ],
        [ 8.246211 ,  3.       ,  4.       ],
        [10.       ,  4.    

In [56]:
tf.argsort(distance_tuples[:,:,0], axis=1)

<tf.Tensor: shape=(2, 25), dtype=int32, numpy=
array([[14,  8, 15,  1, 16,  7,  2, 22,  9, 17, 10, 24, 11, 19, 23, 18,
        20,  0, 13, 12,  3,  4,  5, 21,  6],
       [ 4,  2, 19, 20, 24, 15, 14, 12, 17,  8,  0, 18, 16,  9, 22,  5,
         6,  1,  7, 10, 23, 11, 21,  3, 13]], dtype=int32)>

In [64]:
distance_tuples = tf.gather(distance_tuples, tf.argsort(distance_tuples[:,:,0], stable=True, axis=1), batch_dims=1)
distance_tuples

<tf.Tensor: shape=(2, 25, 3), dtype=float32, numpy=
array([[[ 4.       ,  2.       ,  4.       ],
        [ 4.3588986,  1.       ,  3.       ],
        [ 6.4807405,  3.       ,  0.       ],
        [ 6.5574384,  0.       ,  1.       ],
        [ 6.5574384,  3.       ,  1.       ],
        [ 6.7823296,  1.       ,  2.       ],
        [ 6.8556542,  0.       ,  2.       ],
        [ 6.8556542,  4.       ,  2.       ],
        [ 7.1414285,  1.       ,  4.       ],
        [ 7.1414285,  3.       ,  2.       ],
        [ 7.6157727,  2.       ,  0.       ],
        [ 7.6157727,  4.       ,  4.       ],
        [ 7.6811457,  2.       ,  1.       ],
        [ 8.246211 ,  3.       ,  4.       ],
        [ 9.055385 ,  4.       ,  3.       ],
        [ 9.899494 ,  3.       ,  3.       ],
        [10.       ,  4.       ,  0.       ],
        [10.488088 ,  0.       ,  0.       ],
        [10.677078 ,  2.       ,  3.       ],
        [10.908711 ,  2.       ,  2.       ],
        [11.224972 ,  0.    

In [34]:
tf.gather(distance_tuples, tf.argsort(distance_tuples[:,:,0], axis=1))

<tf.Tensor: shape=(1, 25, 25, 3), dtype=float32, numpy=
array([[[[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         ...,
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]],

        [[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         ...,
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]],

        [[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         ...,
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]],

        ...,

        [[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         ...,
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]],

        [[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         ...,
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]],

        [[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         ...,
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]]]], dtype

In [None]:
def emd():
def pairwise_dist(A, B):  
"""
Computes pairwise distances between each elements of A and each elements of B.
Args:
    A,    [b,m,d] matrix
    B,    [b,n,d] matrix
Returns:
    D,    [b,m,n] matrix of pairwise distances
"""
# squared norms of each row in A and B
na = tf.reduce_sum(tf.square(A), 2)
nb = tf.reduce_sum(tf.square(B), 2)

# na as a row and nb as a co"lumn vectors
na = tf.reshape(na, [tf.shape(na)[0], -1, 1])
nb = tf.reshape(nb, [tf.shape(nb)[0], 1, -1])

# return pairwise euclidead difference matrix
D = tf.sqrt(tf.maximum(na - 2*tf.matmul(A, B, False, True) + nb, 0.0))
return D

In [67]:
x = tf.cast(tf.tile(tf.reshape(tf.range(5), (1, 5, 1)), (2, 1, 5)), dtype=tf.float32)
# tf.transpose(x, (0, 2, 1))
x

<tf.Tensor: shape=(2, 5, 5), dtype=float32, numpy=
array([[[0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1.],
        [2., 2., 2., 2., 2.],
        [3., 3., 3., 3., 3.],
        [4., 4., 4., 4., 4.]],

       [[0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1.],
        [2., 2., 2., 2., 2.],
        [3., 3., 3., 3., 3.],
        [4., 4., 4., 4., 4.]]], dtype=float32)>

In [51]:
tf.transpose(x, (0, 2, 1))

<tf.Tensor: shape=(2, 5, 5), dtype=float32, numpy=
array([[[0., 1., 2., 3., 4.],
        [0., 1., 2., 3., 4.],
        [0., 1., 2., 3., 4.],
        [0., 1., 2., 3., 4.],
        [0., 1., 2., 3., 4.]],

       [[0., 1., 2., 3., 4.],
        [0., 1., 2., 3., 4.],
        [0., 1., 2., 3., 4.],
        [0., 1., 2., 3., 4.],
        [0., 1., 2., 3., 4.]]], dtype=float32)>

In [None]:
x = tf.cast(tf.tile(tf.reshape(tf.range(5), (1, 5, 1)), (2, 1, 5)), dtype=tf.float32)

In [54]:
pairs = tf.reshape(tf.concat((
    tf.expand_dims(distances, -1),
    tf.expand_dims(x, -1),
    tf.expand_dims(tf.transpose(x, (0, 2, 1)), -1),
), -1), (2, -1, 3))
pairs

<tf.Tensor: shape=(2, 25, 3), dtype=float32, numpy=
array([[[1.5659201 , 0.        , 0.        ],
        [8.760424  , 0.        , 1.        ],
        [2.3126996 , 0.        , 2.        ],
        [3.5797644 , 0.        , 3.        ],
        [5.9270597 , 0.        , 4.        ],
        [2.0533168 , 1.        , 0.        ],
        [2.8767252 , 1.        , 1.        ],
        [4.6224523 , 1.        , 2.        ],
        [6.0320663 , 1.        , 3.        ],
        [9.790033  , 1.        , 4.        ],
        [0.8180368 , 2.        , 0.        ],
        [0.84044576, 2.        , 1.        ],
        [7.638676  , 2.        , 2.        ],
        [1.2282348 , 2.        , 3.        ],
        [0.8541298 , 2.        , 4.        ],
        [4.454981  , 3.        , 0.        ],
        [2.3922575 , 3.        , 1.        ],
        [6.586767  , 3.        , 2.        ],
        [2.5384295 , 3.        , 3.        ],
        [3.2339394 , 3.        , 4.        ],
        [2.776121  , 4.     

In [78]:
a = np.random.normal((2, 5, 5))
b = np.random.normal((2, 5, 5))
cdist(a, b)

ValueError: XA must be a 2-dimensional array.

In [81]:
tf.repeat(tf.range(10), 10)

<tf.Tensor: shape=(100,), dtype=int32, numpy=
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4,
       4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6,
       6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8,
       8, 8, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9], dtype=int32)>

In [122]:
class FastEmd(tf.keras.losses.Loss):
    def __init__(self, loss_fn=tf.keras.losses.mean_squared_error, **kwargs):
        super().__init__(**kwargs)
        self.loss_fn = loss_fn
    
    def _loss_np(self, y_true, y_pred):
        # num_batches = y_true.shape[0]
        sorted_indices = np.zeros(tf.shape(y_true)[:2])
        for batch_index, (a, b) in enumerate(zip(y_pred, y_true)):
            visited_a, visited_b = set(), set()
            distance = cdist(a, b).flatten()
            indices = np.argsort(distance)
            for i in indices:
                from_index, to_index = i // 5, i % 5
                if from_index in visited_a or to_index in visited_b:
                    continue
                sorted_indices[batch_index][len(visited_a)] += from_index
                visited_a.add(from_index)
                visited_b.add(to_index)
        return sorted_indices
    
    def call(self, y_true, y_pred, sample_weight=None):
        indices = tf.py_function(self._loss_np, (y_true, y_pred), tf.int32)
        y_pred = tf.gather(y_pred, indices, batch_dims=1)
        try:
            return self.loss_fn(y_true, y_pred, sample_weight)
        except:
            return self.loss_fn(y_true, y_pred)
        

In [123]:
loss_fn = FastEmd()
loss_fn(tf.random.normal((1, 5, 5)), tf.random.normal((1, 5, 5)))

[ 2 19 17 23 15  0 21 16  4 24  1 10 20  3 18  9  8 14 22  5 12  7  6 11
 13]


<tf.Tensor: shape=(), dtype=float32, numpy=1.9545559>

In [86]:
def earth_mover_np(distance_tuples):
    num_batches = distance_tuples.shape[0]
    indices = np.argsort(distance_tuples[:, :, 0], axis=1)
    total_distance = np.zeros(num_batches)
    for i, group in enumerate(indices):
        visited_from = set()
        visited_to = set()
        for j in group:
            distance, from_index, to_index = pairs[i,j].numpy()
            if from_index in visited_from or from_index in visited_to:
                continue
            visited_from.add(from_index)
            visited_from.add(to_index)
            total_distance[i] += distance
    return total_distance

In [71]:
import string

def pdist(a, b):
    """Pairwise Euclidean distances between vectors contained at the back of tensors.

    Uses expansion: (x - y)^T (x - y) = x^Tx - 2x^Ty + y^Ty 

    :param arr: (..., N, d) tensor
    :returns: (..., N, N) tensor of pairwise distances between vectors in the second-to-last dim.
    :rtype: tf.Tensor

    """
    shape = tuple(a.get_shape().as_list())
    rank_ = len(shape)
    N, d = shape[-2:]

    # Build a prefix from the array without the indices we'll use later.
    pref = string.ascii_lowercase[:rank_ - 2]

    # Outer product of points (..., N, N)
    xxT = tf.einsum('{0}ni,{0}mi->{0}nm'.format(pref), a, b)

    # Inner product of points. (..., N)
    xTx = tf.einsum('{0}ni,{0}ni->{0}n'.format(pref), a, b)

    # (..., N, N) inner products tiled.
    xTx_tile = tf.tile(xTx[..., None], (1,) * (rank_ - 1) + (N,))

    # Build the permuter. (sigh, no tf.swapaxes yet)
    permute = list(range(rank_))
    permute[-2], permute[-1] = permute[-1], permute[-2]

    # dists = (x^Tx - 2x^Ty + y^Tx)^(1/2). Note the axis swapping is necessary to 'pair' x^Tx and y^Ty
    return tf.sqrt(xTx_tile - 2 * xxT + tf.transpose(xTx_tile, permute))

In [73]:
a = tf.random.normal((2, 5, 5))
b = tf.random.normal((2, 5, 5))
pdist(a, b)

<tf.Tensor: shape=(2, 5, 5), dtype=float32, numpy=
array([[[0.0000000e+00,           nan,           nan, 3.4305346e+00,
                   nan],
        [1.6163690e+00, 0.0000000e+00,           nan, 4.4081535e+00,
                   nan],
        [          nan, 8.1776357e-01,           nan, 1.0487121e+00,
                   nan],
        [          nan,           nan, 3.0535650e+00, 0.0000000e+00,
         2.7163122e+00],
        [          nan,           nan,           nan,           nan,
         0.0000000e+00]],

       [[0.0000000e+00, 2.4481874e+00, 1.8555861e+00,           nan,
         2.9149864e+00],
        [          nan,           nan, 1.7503656e+00,           nan,
                   nan],
        [          nan, 2.5983584e+00, 6.9053395e-04, 2.0450380e+00,
         2.9445617e+00],
        [          nan,           nan, 1.3463805e+00,           nan,
         1.2723845e+00],
        [          nan, 1.1409109e+00, 2.5396392e+00,           nan,
         4.8828125e-04]]], dtype

In [65]:
def loss_fn(y_true, y_pred):
    batch_size = tf.shape(y_true)[0]
    set_size = tf.shape(y_true)[1]
    embed_dim = tf.shape(y_true)[2]
    # compute distance
    y = tf.cast(tf.tile(tf.reshape(tf.range(set_size), (1, set_size, 1)), (batch_size, 1, set_size)), dtype=tf.float32)
    x = tf.transpose(y, (0, 2, 1))
    pairs = tf.reshape(tf.concat((
        tf.expand_dims(distances, -1),
        tf.expand_dims(tf.transpose(x, (0, 2, 1)), -1),
        tf.expand_dims(x, -1)
    ), -1), (2, -1, 3))
    return tf.numpy_function(earth_mover_np, [pairs], tf.float32)

<tf.Tensor: shape=(2,), dtype=float64, numpy=array([ 8.03973198, 10.04707932])>

In [64]:
earth_mover_np(pairs.numpy())

array([ 8.03973198, 10.04707932])

In [56]:
indices = np.argsort(pairs[:,:,0], axis=1)
indices

array([[10, 11, 14, 13,  0,  5,  2, 16, 18, 20,  6, 19, 23,  3, 15,  7,
         4,  8, 17, 21, 12, 22,  1, 24,  9],
       [10, 24, 19, 23, 22,  1, 12, 18, 16,  4, 17,  9, 20,  5,  0, 21,
         2,  3, 15,  7, 13, 11,  8, 14,  6]])

In [59]:
total_distance = np.zeros((len(indices)))
for i, group in enumerate(indices):
    visited = set()
    for j in group:
        distance, from_index, to_index = pairs[i,j].numpy()
        if from_index in visited:
            continue
        visited.add(from_index)
        total_distance[i] += distance

In [61]:
total_distance

array([ 9.60565209, 13.14692855])

In [60]:
x = tf.tile(tf.reshape(tf.range(5), (1, 1, 5)), (1, 5, 1))
x

<tf.Tensor: shape=(1, 5, 5), dtype=int32, numpy=
array([[[0, 1, 2, 3, 4],
        [0, 1, 2, 3, 4],
        [0, 1, 2, 3, 4],
        [0, 1, 2, 3, 4],
        [0, 1, 2, 3, 4]]], dtype=int32)>

In [8]:
tf.expand_dims(distances, -1)

<tf.Tensor: shape=(5, 5, 1), dtype=float32, numpy=
array([[[9.010284  ],
        [8.313045  ],
        [7.7843547 ],
        [3.760599  ],
        [8.943699  ]],

       [[4.1826525 ],
        [7.7477193 ],
        [6.540223  ],
        [1.439954  ],
        [8.184491  ]],

       [[3.7914038 ],
        [9.168671  ],
        [4.7567797 ],
        [8.745359  ],
        [6.7313995 ]],

       [[0.3000903 ],
        [1.6757405 ],
        [0.25322914],
        [2.5159967 ],
        [3.6392224 ]],

       [[1.4888036 ],
        [1.345042  ],
        [2.8000808 ],
        [0.6844461 ],
        [7.1058025 ]]], dtype=float32)>

In [5]:
150*150

22500