In [None]:
import jax
import jax.numpy as jnp

## Manual Vectorization

In [None]:
# Toy inputs: an integer signal 0..4 and a 3-tap floating-point kernel.
w = jnp.array([2.0, 3.0, 4.0])
x = jnp.arange(0, 5)

In [None]:
def convolve(x, w):
  """Slide the kernel `w` across `x`, taking a dot product at each position.

  This is a "valid"-mode 1-D convolution in the deep-learning sense: the kernel
  is not flipped (strictly a cross-correlation), and only positions where `w`
  fits entirely inside `x` produce an output.

  Args:
    x: 1-D input array.
    w: 1-D kernel array of any length.

  Returns:
    1-D array of length max(len(x) - len(w) + 1, 0).
  """
  kernel_size = len(w)
  output = []
  # A plain Python loop over window positions. Shapes are static under
  # jax.jit / jax.vmap, so this loop is simply unrolled at trace time.
  for start in range(len(x) - kernel_size + 1):
    output.append(jnp.dot(x[start:start + kernel_size], w))
  return jnp.array(output)

In [None]:
# One unbatched call; with x = arange(5) and w = [2, 3, 4] this gives [11., 20., 29.].
convolve(x=x, w=w)

In [None]:
# Build a batch of two identical examples along a new leading (batch) axis.
xs = jnp.stack([x] * 2)
ws = jnp.stack([w] * 2)

In [None]:
# Show each batched input, preceded by its name.
for label, batch in (('xs', xs), ('ws', ws)):
  print(label)
  print(batch)

In [None]:
def manually_batched_convolve(xs, ws):
  """Apply `convolve` to each (xs[i], ws[i]) pair and stack the results.

  Correct, but it makes one `convolve` call per batch element from a
  Python loop.
  """
  per_example = [convolve(xs[row], ws[row]) for row in range(xs.shape[0])]
  return jnp.stack(per_example)

In [None]:
# This produces the correct result, but it is not very efficient: one `convolve`
# call per batch element.
manually_batched_convolve(xs=xs, ws=ws)

In [None]:
# Alternative: vectorize over the batch by hand. Each window position is computed
# for every row at once, instead of looping over the rows.
def manually_vectorized_convolve(xs, ws):
  """Convolve each row of `xs` with the matching row of `ws`, vectorized over the batch.

  "Valid"-mode, unflipped kernel: output position j of row b is
  sum(xs[b, j:j + k] * ws[b]).

  Args:
    xs: 2-D array of shape (batch, n).
    ws: array of shape (batch, k). A 1-D kernel of shape (k,) also works and is
      broadcast to every row.

  Returns:
    Array of shape (batch, n - k + 1), or (batch, 0) when k > n.
  """
  kernel_size = ws.shape[-1]
  output = []
  for start in range(xs.shape[-1] - kernel_size + 1):
    # (batch, k) windows times (batch, k) kernels, summed over k -> shape (batch,)
    output.append(jnp.sum(xs[:, start:start + kernel_size] * ws, axis=1))
  if not output:
    # jnp.stack([]) raises, so return an explicitly empty result instead.
    return jnp.zeros((xs.shape[0], 0), dtype=jnp.result_type(xs, ws))
  # Each list entry is one output position for the whole batch: stack as columns.
  return jnp.stack(output, axis=1)

# Same values as the per-example loop above, without a Python loop over the batch.
manually_vectorized_convolve(xs=xs, ws=ws)

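JAX provides a better way: `jax.vmap` transforms a function written for a single example into one that processes a whole batch, with no manual rewriting.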

## Automatic Vectorization

In [None]:
# jax.vmap maps `convolve` over axis 0 of every argument and stacks the outputs
# along axis 0. These are vmap's defaults, written out here for clarity.
# NOTE(review): the double underscore looks like a typo for `auto_batch_convolve`;
# the name is kept because later cells refer to it.
auto_batch__convolve = jax.vmap(convolve, in_axes=0, out_axes=0)

In [None]:
# Same result as manually_batched_convolve(xs, ws), with the batching done by jax.vmap.
auto_batch__convolve(xs, ws)

In [None]:
# The same inputs, laid out with the batch along axis 1 (columns) instead of axis 0.
xst = xs.T
wst = ws.T

# in_axes=1 maps over axis 1 of each input; out_axes=1 stacks the per-example
# results along axis 1, so the output is also batch-in-columns.
auto_batch_convolve_v2 = jax.vmap(convolve, in_axes=1, out_axes=1)

auto_batch_convolve_v2(xst, wst)

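`jax.vmap` also supports the case where only one of the arguments is batched. Pass `None` in `in_axes` for any argument that should be shared across the whole batch rather than mapped over.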


In [None]:
# Map over axis 0 of the first argument only. `None` means the single kernel `w`
# is reused (broadcast) for every example in the batch. A tuple is the documented
# form for positional in_axes.
batch_convolve_v3 = jax.vmap(convolve, in_axes=(0, None))

batch_convolve_v3(xs, w)

## Combining transformations

In [None]:
# Transformations compose: jit-compile the vmapped function. Compilation happens
# on the first call for each new combination of input shapes/dtypes.
jitted_batch_convolve = jax.jit(auto_batch__convolve)

In [None]:
# The first call traces and compiles; later calls with same-shaped inputs reuse the compiled code.
jitted_batch_convolve(xs, ws)