In [1]:
import jax
import numpy as np
import jax.numpy as jnp
from jax import random, jit, lax, vmap, pmap, grad, value_and_grad

In [2]:
rng = np.random.default_rng()
key = jax.random.PRNGKey(0)

## Automatic Vectorization
Define the op for a "single" instance, and then use `vmap` to call it with batches of input. `vmap` is usually an order of magnitude faster than naive for-loop through the batch. It seems to me that automatic vectorization will use threads/warps on a single accelerator to do DDP with the input. It is not doing any clever vector math.

Update: I recdently saw a [Twitter thread](https://twitter.com/jakevdp/status/1612544608646606849?s=20&t=s8uBb3teX19T3ELQ07melA) that seemed to imply that `vmap` does indeed do some clever linalg. I can also verify this by examining the jaxpr of the vmapped function.

`jnp.dot` is defined for two 1D vectors of the same size. It will not work with 2D vectors.

In [16]:
x = jax.random.normal(key, shape=(100,), dtype=jnp.float32)
y = jax.random.normal(key, shape=(100,), dtype=jnp.float32)
jnp.dot(x, y)

Array(89.16817, dtype=float32)

In [17]:
batch_x = jax.random.normal(key, shape=(5,100), dtype=jnp.float32)
batch_y = jax.random.normal(key, shape=(5,100), dtype=jnp.float32)
try:
    jnp.dot(batch_x, batch_y)
except Exception as err:
    print(f"{type(err)}\n{err}")

<class 'TypeError'>
Incompatible shapes for dot: got (5, 100) and (5, 100).


But I can `vmap` it so that it will work with one extra dimension.

In [18]:
vdot = vmap(jnp.dot)
vdot(batch_x, batch_y)

Array([103.42397 , 116.75762 ,  97.10165 ,  85.551155,  93.10307 ],      dtype=float32)

And this is way faster than a naive for-loop implementation. The naive for-loop implementation cannot be jitted because its calculation depends on the size of the input.

In [19]:
def naive_vdot(batch_x, batch_y):
    m = len(batch_x)
    dots = []
    for idx in range(m):
        dot = jnp.dot(batch_x[idx], batch_y[idx])
        dots.append(dot)
    return jnp.array(dots)


In [20]:
%timeit naive_vdot(batch_x, batch_y)

1.78 ms ± 37 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [21]:
# Compare it with a non-jitted vmapped function
%timeit vdot(batch_x, batch_y)

150 µs ± 19.5 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [22]:
%timeit jit(vdot)(batch_x, batch_y)

111 µs ± 34.1 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [23]:
jax.make_jaxpr(vdot)(batch_x, batch_y)

{ [34m[22m[1mlambda [39m[22m[22m; a[35m:f32[5,100][39m b[35m:f32[5,100][39m. [34m[22m[1mlet
    [39m[22m[22mc[35m:f32[5][39m = dot_general[
      dimension_numbers=(((1,), (1,)), ((0,), (0,)))
      precision=None
      preferred_element_type=None
    ] a b
  [34m[22m[1min [39m[22m[22m(c,) }

In [25]:
jax.make_jaxpr(jnp.dot)(x, y)

{ [34m[22m[1mlambda [39m[22m[22m; a[35m:f32[100][39m b[35m:f32[100][39m. [34m[22m[1mlet
    [39m[22m[22mc[35m:f32[][39m = dot_general[
      dimension_numbers=(((0,), (0,)), ((), ()))
      precision=None
      preferred_element_type=None
    ] a b
  [34m[22m[1min [39m[22m[22m(c,) }

By default `vmap` assumes that the first dimension of the input tensors is the batch dimensions. We can also explicitly specify this using the `in_axes` argument. This is useful when some of the inputs are batched and others are not.

In [10]:
vdot2 = vmap(jnp.dot, in_axes=(None, 0))
vdot2(x, batch_y)

Array([ 4.997196  ,  0.39028692, -0.4550171 ,  5.45161   , -1.3241603 ],      dtype=float32)

Let me define a simple perceptron with this -

In [11]:
def affine(W, b, x):
    return W @ x + b

In [12]:
# input has 20 features and output has a dim of 10 (classes?)
W = jax.random.normal(key, shape=(10, 20), dtype=jnp.float32)
b = jax.random.normal(key, shape=(10,), dtype=jnp.float32)
x = jax.random.normal(key, shape=(20,), dtype=jnp.float32)
affine(W, b, x)

Array([-9.891567  ,  3.198782  , -5.058924  ,  5.849048  , -1.042989  ,
        6.441183  , -3.0672126 , -2.5102026 , -0.10067141,  0.2673887 ],      dtype=float32)

In [13]:
# now with a batch size of 3
batch_x = jax.random.normal(key, shape=(3, 20), dtype=jnp.float32)
try:
    affine(W, b, batch_x)
except Exception as err:
    print(f"{type(err)}\n{err}")

<class 'TypeError'>
dot_general requires contracting dimensions to have the same shape, got (20,) and (3,).


In [14]:
vmap(affine, in_axes=(None, None, 0))(W, b, batch_x)

Array([[ 2.4496658 , -6.5097513 ,  7.7877793 , -2.5308816 ,  4.5156784 ,
         8.680212  , -0.36133832, -8.290409  ,  5.073617  ,  6.4655223 ],
       [-1.755568  , -0.21376276,  2.8784447 ,  0.7835917 , -7.0984244 ,
        -4.86695   , -4.7300005 , -8.867712  ,  2.7317257 , -0.13192517],
       [-3.3044035 ,  2.625197  , -1.6779847 ,  2.484176  , -3.8615348 ,
         8.243641  ,  0.31611758,  6.9084396 , -0.9615234 , -4.7001905 ]],      dtype=float32)