In [1]:
import jax
import jax.numpy as jnp

# vmap

In [11]:
def dot(x: jax.Array, y: jax.Array) -> jax.Array:
    """Return the dot product of ``x`` and ``y`` as computed by ``jnp.vdot``.

    ``jnp.vdot`` flattens both arguments before multiplying, so the result is
    a 0-d array regardless of how many dimensions the inputs have. Wrap this
    function in ``jax.vmap`` to obtain one dot product per row instead.
    """
    flattened_product = jnp.vdot(x, y)
    return flattened_product

## Vector–vector dot product

(n,) . (n,) => ()

In [15]:
# Two length-3 vectors reduce to a 0-d array, so the reported shape is ().
dot(jnp.ones(shape=(3,)), jnp.ones(shape=(3,))).shape

()

## Batched (list of vectors) dot product

(b,n) . (b,n) => (b,)

### 🚨📢⚠️ Pitfall

In [18]:
# Pass two (10, 3) batches straight to `dot`, hoping for ten dot products.
res = dot(jnp.ones((10, 3)), jnp.ones((10, 3)))
res

Array(30., dtype=float32)

In [20]:
res.shape #! this is a scalar: the shape is () rather than (10,)
#! Why? vdot first flattens both inputs and then takes the dot product
#* (10,3) . (10,3) ==> (30,) . (30,) ==> ()

()

In [24]:
# Abstract (10, 3) float32 placeholder: enough to trace `dot` without real data.
input_shape = jax.ShapeDtypeStruct((10, 3), jnp.float32)
jax.make_jaxpr(dot)(input_shape, input_shape)

{ lambda ; a:f32[10,3] b:f32[10,3]. let
    c:f32[30] = reshape[dimensions=None new_sizes=(30,)] a
    d:f32[30] = reshape[dimensions=None new_sizes=(30,)] b
    e:f32[] = dot_general[
      dimension_numbers=(([0], [0]), ([], []))
      preferred_element_type=float32
    ] c d
  in (e,) }

### Manual vectorization

In [30]:
def dot_vectorize(x: jax.Array, y: jax.Array) -> jax.Array:
    """Row-wise dot product of two 2-D arrays, vectorized by hand.

    The einsum spec ``"ij,ij->i"`` multiplies matching entries and sums over
    the second axis ``j`` while keeping the first axis ``i``, so two inputs of
    shape ``(b, n)`` give an output of shape ``(b,)``.
    """
    rowwise_spec = "ij,ij->i"
    return jnp.einsum(rowwise_spec, x, y)

In [32]:
# Two (10, 3) batches of ones through the hand-written einsum version.
res = dot_vectorize(jnp.ones((10, 3)), jnp.ones((10, 3)))
res

Array([3., 3., 3., 3., 3., 3., 3., 3., 3., 3.], dtype=float32)

In [34]:
# One dot product per row of the batch.
res.shape

(10,)

### Automatic vectorization with `jax.vmap`

In [35]:
# Let vmap add the batch axis: `dot` itself is still written for single vectors.
dot_vmapped = jax.vmap(dot)
res = dot_vmapped(jnp.ones((10, 3)), jnp.ones((10, 3)))
res

Array([3., 3., 3., 3., 3., 3., 3., 3., 3., 3.], dtype=float32)

In [36]:
# One dot product per row of the batch.
res.shape

(10,)

### Speed comparison

In [41]:
# NOTE: the timed statement also builds both (10, 3) input arrays on every
# loop, so the figure covers allocation as well as the einsum call itself.
%timeit dot_vectorize(jnp.ones(shape=(10,3)),jnp.ones(shape=(10,3))).block_until_ready()

221 µs ± 2.56 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [42]:
# NOTE: the timed statement also builds both (10, 3) input arrays on every
# loop, so the figure covers allocation as well as the vmapped call itself.
%timeit dot_vmapped(jnp.ones(shape=(10,3)),jnp.ones(shape=(10,3))).block_until_ready()

438 µs ± 622 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [43]:
# Wrap both variants in jax.jit so the following cells time the compiled versions.
dot_vectorize_jit, dot_vmapped_jit = jax.jit(dot_vectorize), jax.jit(dot_vmapped)

In [44]:
# NOTE: the timed statement also builds both (10, 3) input arrays on every
# loop, so the figure covers allocation as well as the jitted einsum call.
%timeit dot_vectorize_jit(jnp.ones(shape=(10,3)),jnp.ones(shape=(10,3))).block_until_ready()

140 µs ± 3.44 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [45]:
# NOTE: the timed statement also builds both (10, 3) input arrays on every
# loop, so the figure covers allocation as well as the jitted vmapped call.
%timeit dot_vmapped_jit(jnp.ones(shape=(10,3)),jnp.ones(shape=(10,3))).block_until_ready()

139 µs ± 1.18 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


### Jaxpr (traced program) comparison

In [46]:
# Trace the hand-written einsum version to inspect the primitives it lowers to.
jax.make_jaxpr(dot_vectorize)(input_shape, input_shape)

{ lambda ; a:f32[10,3] b:f32[10,3]. let
    c:f32[10] = dot_general[
      dimension_numbers=(([1], [1]), ([0], [0]))
      preferred_element_type=float32
    ] a b
  in (c,) }

In [47]:
# Trace the vmapped version to inspect the primitives it lowers to.
jax.make_jaxpr(dot_vmapped)(input_shape, input_shape)

{ lambda ; a:f32[10,3] b:f32[10,3]. let
    c:f32[10] = dot_general[
      dimension_numbers=(([1], [1]), ([0], [0]))
      preferred_element_type=float32
    ] a b
  in (c,) }