## vmap

`vmap` automatically vectorizes a function that operates on arrays. It transforms a function written for a single example into one that takes a batch of inputs and returns a batch of outputs. It does not rewrite loops in your code; instead, it traces the function and adds a batch dimension to each array operation, so you can drop the explicit Python `for` loop over examples.
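
As a minimal sketch of what batching means here, the snippet below compares an explicit Python loop with `vmap`. The function `normalize` and the array `batch` are illustrative names, not part of the original notebook.

```python
import jax.numpy as jnp
from jax import vmap

def normalize(v):
    # written for a single vector of shape (d,)
    return v / jnp.linalg.norm(v)

batch = jnp.arange(1.0, 13.0).reshape(4, 3)  # 4 vectors of length 3

# explicit Python loop over the leading axis, one vector at a time
looped = jnp.stack([normalize(v) for v in batch])

# the same computation with the loop replaced by vmap
batched = vmap(normalize)(batch)

print(batched.shape)
print(jnp.allclose(looped, batched))
```

Both versions should agree up to floating-point tolerance; the `vmap` version avoids the per-example Python overhead.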

`in_axes` specifies, for each positional argument, which axis to map over. Think of that axis as the batch dimension; use `None` for an argument that should not be mapped.

Output dimension analysis:
- applying `vmap` with a given `in_axes` treats that axis as the batch dimension: it is removed from the input before the function sees it. `vmap` returns a function that takes a batch of inputs and returns a batch of outputs. By default (`out_axes=0`), the batch dimension is the first dimension of the output array.
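
A small sketch of where the batch dimension ends up, assuming an elementwise function (`scale` is an illustrative name). With the default `out_axes=0` the mapped axis moves to the front; passing `out_axes=1` puts it back at position 1.

```python
import jax.numpy as jnp
from jax import vmap

def scale(v):
    # v is one slice of the input, with the mapped axis removed
    return 2.0 * v

x = jnp.ones((10, 3, 4))

# map over axis 1 (size 3); the batch axis comes first by default
default_out = vmap(scale, in_axes=1)(x)

# map over axis 1 and place the batch axis at position 1 of the output
same_place = vmap(scale, in_axes=1, out_axes=1)(x)

print(default_out.shape)  # (3, 10, 4)
print(same_place.shape)   # (10, 3, 4)
```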

Check the following examples.
 

In [32]:
import jax 
import jax.numpy as jnp
from jax import vmap 

In [42]:
def f(x):
    return x**2

x = 2*jnp.ones((10,3,4))

# apply vmap to f on axis 0 of x
print(vmap(f, in_axes=(0,))(x).shape)

# apply vmap to f on axis 1 of x -> now the output shape is (3, 10, 4), with original axis 1 as axis 0 due to batching.
print(vmap(f, in_axes=(1,))(x).shape)

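# apply vmap twice on input x, both times on axis 0
# the outer vmap maps over axis 0 (size 10), the inner one over axis 0 of the remaining (3, 4) slice, so f sees shape (4,)
# the output shape is (10, 3, 4), the same as a single vmap on axis 0, because f is elementwise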
print(vmap(vmap(f, in_axes=(0,)), in_axes=(0,))(x).shape)

# apply vmap twice on input x: outer on axis 1, inner on axis 0
# note: f itself is written once; each vmap only adds one batch axis
# the outer vmap maps over axis 1 (size 3), so the inner function sees slices of shape (10, 4)
# the inner vmap maps over axis 0 of that slice (size 10), so f sees shape (4,)
# the final output shape is (3, 10, 4)
print(vmap(vmap(f, in_axes=(0,)), in_axes=(1,))(x).shape)


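# apply vmap twice on input x: outer on axis 2, inner on axis 1
# the outer vmap maps over axis 2 (size 4), so the inner function sees slices of shape (10, 3)
# the inner vmap maps over axis 1 of that slice (size 3), so f sees shape (10,) and the inner output is (3, 10)
# the final output shape is (4, 3, 10)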
print(vmap(vmap(f, in_axes=(1,)), in_axes=(2,))(x).shape)

# apply vmap twice on input x: outer on axis 1, inner on axis 2
# raises an error,
# because the (10, 4) slice left after the outer vmap has no axis 2.
print(vmap(vmap(f, in_axes=(2,)), in_axes=(1,))(x).shape)

(10, 3, 4)
(3, 10, 4)
(10, 3, 4)
(3, 10, 4)
(4, 3, 10)


ValueError: vmap was requested to map its argument along axis 2, which implies that its rank should be at least 3, but is only 2 (its shape is (10, 4))

In [47]:
# Now let's try a function that changes the shape of its input:
# f1 sums over axis 0 of the given array, which reduces the number of dimensions by 1
def f1(x):
    return jnp.sum(x, axis=0)

# apply vmap to f1 on axis 0 of x
print("from shape: ", x[0,:,:].shape, "->", f1(x[0,:,:]).shape)
print("from shape: ", x.shape, "->", vmap(f1, in_axes=(0,))(x).shape)

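# apply vmap to f1 on axis 1 of x
# note: the next line still slices along axis 0; with in_axes=1, f1 actually sees slices x[:, i, :] of shape (10, 4)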
print("from shape: ", x[1,:,:].shape, "->", f1(x[1,:,:]).shape)
print("from shape: ", x.shape, "->", vmap(f1, in_axes=(1,))(x).shape)

# apply vmap twice on axis 0 of x
# the outer vmap maps over axis 0 (size 10), so the inner function sees slices of shape (3, 4)
# the inner vmap maps over axis 0 of that slice (size 3), so f1 sees shape (4,) and returns a scalar
# the final output shape is (10, 3)
print("from shape: ", x.shape, "->", vmap(vmap(f1, in_axes=(0,)), in_axes=(0,))(x).shape)

# apply vmap twice on x: outer on axis 1, inner on axis 0
# the outer vmap maps over axis 1 (size 3), so the inner function sees slices of shape (10, 4)
# the inner vmap maps over axis 0 of that slice (size 10), so f1 sees shape (4,) and the inner output is (10,)
# the final output shape is (3, 10)
print(vmap(vmap(f1, in_axes=(0,)), in_axes=(1,))(x).shape)



from shape:  (3, 4) -> (4,)
from shape:  (10, 3, 4) -> (10, 4)
from shape:  (3, 4) -> (4,)
from shape:  (10, 3, 4) -> (3, 4)
from shape:  (10, 3, 4) -> (10, 3)
(3, 10)
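
To check values as well as shapes, here is a minimal sketch that compares the last nested `vmap` against a plain reduction. It assumes a non-constant input (`jnp.arange`) so that a wrong axis order would show up in the comparison; `nested` and `expected` are illustrative names.

```python
import jax.numpy as jnp
from jax import vmap

def f1(v):
    # sum over the first axis of whatever slice is passed in
    return jnp.sum(v, axis=0)

x = jnp.arange(120.0).reshape(10, 3, 4)

# outer vmap over axis 1 (index j), inner vmap over axis 0 (index i):
# f1 sees x[i, j, :] and the result is stored at position [j, i]
nested = vmap(vmap(f1, in_axes=0), in_axes=1)(x)

# the same quantity written as a reduction over the last axis, then transposed
expected = x.sum(axis=2).T

print(nested.shape)
print(jnp.allclose(nested, expected))
```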


## cdist

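`cdist` computes the pairwise distances between two sets of points: for inputs of shape (m, d) and (n, d), it returns an (m, n) matrix. The first cell below is a warm-up that builds matrix products out of nested `vmap` calls.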

In [22]:

from scipy.spatial.distance import cdist

# vmap(func, in_axes, out_axes)
# in_axes specifies which axis of each input argument to map over
# e.g., (0, 1) maps over the first axis of the first argument and the second axis of the second argument; None means the argument is not mapped
# out_axes specifies where the mapped axis is placed in the output

# vector dot product: x, y are (m,) vectors -> scalar
vv = lambda x, y: jnp.vdot(x,y)

# matrix-vector product: x is an (n, m) matrix, y is an (m,) vector -> (n,) vector
mv = vmap(vv, (0, None), 0)

# matrix-matrix product: x is (n, m), y is (m, k) -> (n, k)
# maps over the columns of y (axis 1) and places the results along axis 1 of the output
mm = vmap(mv, (None, 1), 1)

x = jnp.array([[1.,2.,3.],[4.,5.,6.]])
y = jnp.array([1.,2.,3.])

print(x.shape, y.shape)
print()
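# (2, 3) times (3, 2) -> (2, 2); not displayed, since a notebook cell shows only its last expression
mm(x,x.T)

# row-wise dot product: both arguments are mapped over axis 0 -> (n,) vector
mv1 = vmap(vv, (0, 0), 0)
mv1(x, x)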



(2, 3) (3,)



Array([14., 77.], dtype=float32)
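
As a quick check on the nested construction, the sketch below rebuilds `mm` and compares it with the built-in `@` operator. The variable name `product` is illustrative.

```python
import jax.numpy as jnp
from jax import vmap

# dot product of two vectors
vv = lambda a, b: jnp.vdot(a, b)
# map over the rows of the first argument -> matrix-vector product
mv = vmap(vv, (0, None), 0)
# map over the columns of the second argument -> matrix-matrix product
mm = vmap(mv, (None, 1), 1)

x = jnp.array([[1., 2., 3.], [4., 5., 6.]])

product = mm(x, x.T)
print(product)
print(jnp.allclose(product, x @ x.T))
```

In practice `x @ y` or `jnp.matmul` is the natural choice; the nested `vmap` version is mainly useful for seeing how `in_axes` and `out_axes` compose.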

Two ways to implement `cdist` in JAX:

In [23]:
X = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
Y = jnp.array([[7.0, 8.0], [9.0, 10.0]])

dist = cdist(X, Y, metric="euclidean")
print(dist)

def cdist_jax(X, Y):
    # Compute the squared Euclidean distances between all pairs of points in X and Y.
    # X[:, jnp.newaxis] has shape (m, 1, d) and Y has shape (n, d), so the difference broadcasts to (m, n, d).
    # Summing over the last axis gives a matrix whose (i, j)-th entry is the squared
    # Euclidean distance between the i-th point in X and the j-th point in Y.
    squared_distances = jnp.sum(jnp.square(X[:, jnp.newaxis] - Y), axis=-1)
    
    # Take the square root to obtain the Euclidean distances.
    distances = jnp.sqrt(squared_distances)
    
    return distances

dist_1 = jax.jit(cdist_jax)(X,Y)
print(dist_1)


[[ 8.48528137 11.3137085 ]
 [ 5.65685425  8.48528137]
 [ 2.82842712  5.65685425]]
[[ 8.485281 11.313708]
 [ 5.656854  8.485281]
 [ 2.828427  5.656854]]


In [24]:
# Define a function that computes the Euclidean distance between two points.
def euclidean_distance(x, y):
    return jnp.sqrt(jnp.sum((x - y) ** 2))


# use vmap to vectorize
# X: (m, d), Y: (n, d), f(X, Y): (m, n)
# the outer vmap maps over axis 0 of X (size m) and leaves Y unmapped
# the inner vmap maps over axis 0 of Y (size n) and leaves the single point from X unmapped
cdist_jax = vmap(vmap(euclidean_distance, in_axes=(None, 0)), in_axes=(0, None))

distances = cdist_jax(X, Y)
print(distances)
print(X.shape, Y.shape)

[[ 8.485281 11.313708]
 [ 5.656854  8.485281]
 [ 2.828427  5.656854]]
(3, 2) (2, 2)
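
One advantage of the `vmap` formulation is that the point-to-point function can be swapped out. The sketch below wraps the nested `vmap` in a helper and uses it with the Manhattan (city-block) distance; `pairwise` and `manhattan_distance` are hypothetical names introduced for illustration, not part of JAX or SciPy.

```python
import jax
import jax.numpy as jnp
from jax import vmap

def pairwise(metric):
    # hypothetical helper: lift a point-to-point metric to (m, d), (n, d) -> (m, n)
    return vmap(vmap(metric, in_axes=(None, 0)), in_axes=(0, None))

def manhattan_distance(p, q):
    # sum of absolute coordinate differences between two points
    return jnp.sum(jnp.abs(p - q))

X = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
Y = jnp.array([[7.0, 8.0], [9.0, 10.0]])

cdist_manhattan = jax.jit(pairwise(manhattan_distance))
D = cdist_manhattan(X, Y)
print(D)
```

The same result can be obtained by broadcasting, `jnp.sum(jnp.abs(X[:, jnp.newaxis] - Y), axis=-1)`, which is a convenient cross-check.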


## Broadcasting

- shapes are compared dimension by dimension, from the last dimension to the first
- if one array has fewer dimensions than the other, 1s are prepended to its shape until the number of dimensions matches
- two dimensions are compatible if:
    - they are equal, or
    - one of them is 1
- if any pair of dimensions is not compatible, an error is raised (`ValueError` in NumPy)
- a dimension of length 1 is stretched (virtually repeated) to match the other array's length along that dimension
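
The rules above can be sketched as a small pure-Python function that predicts the broadcast shape of two arrays. `broadcast_shape` is a hypothetical helper written for illustration; it is not the library's implementation.

```python
import jax.numpy as jnp

def broadcast_shape(shape_a, shape_b):
    # hypothetical helper: apply the broadcasting rules to two shapes
    ndim = max(len(shape_a), len(shape_b))
    # prepend 1s to the shorter shape so both have the same length
    a = (1,) * (ndim - len(shape_a)) + tuple(shape_a)
    b = (1,) * (ndim - len(shape_b)) + tuple(shape_b)
    out = []
    for da, db in zip(a, b):
        if da == db or db == 1:
            out.append(da)
        elif da == 1:
            out.append(db)
        else:
            raise ValueError(f"incompatible shapes {shape_a} and {shape_b}")
    return tuple(out)

predicted = broadcast_shape((3, 1, 2), (2, 2))
# compare with what an actual array operation produces
actual = (jnp.zeros((3, 1, 2)) - jnp.zeros((2, 2))).shape
print(predicted, actual)

try:
    broadcast_shape((4,), (3,))
except ValueError as err:
    print("error:", err)
```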

In [25]:
a = jnp.array([0.0, 10.0, 20.0, 30.0])
b = jnp.array([1.0, 2.0, 3.0])

print(a.shape, b.shape)

# a + b fails: shapes (4,) and (3,) are not compatible
# use newaxis to reshape a to (4, 1), then add -> (4, 3)
print(a[:, jnp.newaxis] + b)

# use newaxis to make (3,1) for b and then add together -> (3, 4)
print(a + b[:, jnp.newaxis])

def add(x,y):
    return x+y

res = vmap(add, in_axes=(0, None))(a,b)
print(res.shape)

# broadcasting through a 3-D intermediate: (3, 1, 2) - (2, 2) -> (3, 2, 2)
print(X.shape, Y.shape)
print(X[:, jnp.newaxis,:].shape, Y.shape) 

print((X[:, jnp.newaxis] - Y).shape)



(4,) (3,)
[[ 1.  2.  3.]
 [11. 12. 13.]
 [21. 22. 23.]
 [31. 32. 33.]]
[[ 1. 11. 21. 31.]
 [ 2. 12. 22. 32.]
 [ 3. 13. 23. 33.]]
(4, 3)
(3, 2) (2, 2)
(3, 1, 2) (2, 2)
(3, 2, 2)
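
The `vmap` call and the `newaxis` broadcasts in the cell above compute the same values. A minimal sketch of that equivalence, with illustrative variable names:

```python
import jax.numpy as jnp
from jax import vmap

a = jnp.array([0.0, 10.0, 20.0, 30.0])
b = jnp.array([1.0, 2.0, 3.0])

def add(x, y):
    return x + y

# map over a, keep b whole: each scalar a[i] is added to all of b -> (4, 3)
via_vmap = vmap(add, in_axes=(0, None))(a, b)
via_broadcast = a[:, jnp.newaxis] + b

# map over b, keep a whole: each scalar b[j] is added to all of a -> (3, 4)
via_vmap_t = vmap(add, in_axes=(None, 0))(a, b)
via_broadcast_t = a + b[:, jnp.newaxis]

print(jnp.array_equal(via_vmap, via_broadcast))
print(jnp.array_equal(via_vmap_t, via_broadcast_t))
```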
