In [2]:
import jax 
import jax.numpy as jnp
from scipy.spatial.distance import cdist

# jax.vmap(fun, in_axes, out_axes)
# in_axes: for each positional argument, the axis to map over
#   (None = do not map that argument; it is passed whole to every call).
#   e.g., (0, 1) means map over the first axis of the first argument and the
#   second axis of the second argument.
# out_axes: where the mapped axis is placed in the output.

# vector product: x, y are (n,) vectors -> scalar
def vv(x, y):
    """Return the dot product of the vectors `x` and `y` via `jnp.vdot`."""
    return jnp.vdot(x, y)

# matrix-vector product: x is (n, m) matrix, y is (m,) vector -> (n,) vector
# Every row of x is dotted with the same, unmapped y.
mv = jax.vmap(vv, (0, None), 0)

# matrix-matrix product: x is (n, m), y is (m, k) -> (n, k)
# Every column of y is sent through mv; the results become output columns.
mm = jax.vmap(mv, (None, 1), 1)

x = jnp.array([[1., 2., 3.], [4., 5., 6.]])
y = jnp.array([1., 2., 3.])

print(x.shape, y.shape)
print()

# Only the last expression of a cell is displayed, so a bare `mm(x, x.T)` in
# the middle of the cell would be computed and thrown away. Check the
# vmap-built products against the built-in matmul instead.
assert jnp.allclose(mv(x, y), x @ y)
assert jnp.allclose(mm(x, x.T), x @ x.T)

# row-wise dot products: both arguments are mapped over axis 0 -> (n,)
mv1 = jax.vmap(vv, (0, 0), 0)
mv1(x, x)



2023-05-05 23:41:36.497269: W external/xla/xla/service/platform_util.cc:198] unable to create StreamExecutor for CUDA:0: failed initializing StreamExecutor for CUDA device ordinal 0: INTERNAL: failed call to cuDevicePrimaryCtxRetain: CUDA_ERROR_OUT_OF_MEMORY: out of memory; total memory reported: 25420759040
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


(2, 3) (3,)



Array([14., 77.], dtype=float32)

Two ways to implement `cdist` in JAX: with broadcasting, and with nested `jax.vmap`:

In [6]:
# Sample inputs: X holds three 2-D points, Y holds two 2-D points.
X = jnp.array(
    [
        [1.0, 2.0],
        [3.0, 4.0],
        [5.0, 6.0],
    ]
)
Y = jnp.array(
    [
        [7.0, 8.0],
        [9.0, 10.0],
    ]
)

# Reference result from SciPy, used to compare the JAX versions against.
dist = cdist(X, Y, metric="euclidean")
print(dist)

def cdist_jax(X, Y):
    """Pairwise Euclidean distances between the points in X and in Y.

    Entry (i, j) of the returned matrix is the Euclidean distance between
    the i-th point in X and the j-th point in Y.
    """
    # Insert a length-1 axis after X's first axis so that subtracting Y
    # broadcasts over every (point of X, point of Y) pair.
    pairwise_diff = X[:, None] - Y

    # Squared Euclidean distance: sum the squared coordinate differences
    # over the last axis.
    sq_norms = jnp.square(pairwise_diff).sum(axis=-1)

    # The square root turns squared distances into distances.
    return jnp.sqrt(sq_norms)

# JIT-compile the broadcasting implementation, then evaluate it on the same
# inputs that were passed to SciPy's cdist.
cdist_jax_jit = jax.jit(cdist_jax)
dist_1 = cdist_jax_jit(X, Y)
print(dist_1)


[[ 8.48528137 11.3137085 ]
 [ 5.65685425  8.48528137]
 [ 2.82842712  5.65685425]]
[[ 8.485281 11.313708]
 [ 5.656854  8.485281]
 [ 2.828427  5.656854]]


In [None]:
# Euclidean distance between a single pair of points.
def euclidean_distance(x, y):
    """Return the Euclidean distance between the points `x` and `y`."""
    delta = x - y
    squared_norm = jnp.sum(jnp.square(delta))
    return jnp.sqrt(squared_norm)

# Build the pairwise-distance function from two nested vmaps over
# euclidean_distance:
#   inner vmap, in_axes=(None, 0): keep x fixed and map over axis 0 of Y;
#   outer vmap, in_axes=(0, None): map over axis 0 of X, passing Y whole.
# The result holds the distance for every (row of X, row of Y) pair.
# NOTE(review): this rebinds the name `cdist_jax`, replacing the broadcasting
# implementation defined in the previous cell - consider a distinct name.
cdist_jax = jax.vmap(
    jax.vmap(euclidean_distance, in_axes=(None, 0)),
    in_axes=(0, None),
)

distances = cdist_jax(X, Y)
print(distances)
print(X.shape, Y.shape)

[[ 8.485281 11.313708]
 [ 5.656854  8.485281]
 [ 2.828427  5.656854]]
(3, 2) (2, 2)
