In [1]:
# Vectorization (jax.vmap) is another of JAX's function transformations.

In [2]:
import os
# Pin JAX to a single GPU. This must take effect before JAX initializes its
# GPU backend, so keep it above the `import jax` cell. `setdefault` respects a
# device selection already exported in the environment; the hard-coded '3'
# (specific to the original author's machine) is only a fallback.
os.environ.setdefault('CUDA_VISIBLE_DEVICES', '3')

In [3]:
import jax
import jax.numpy as jnp

## Manual vectorization

In [5]:
# Toy inputs: a length-5 integer signal and a 3-tap float kernel.
x = jnp.arange(0, 5)
w = jnp.asarray([2.0, 3.0, 4.0])

In [6]:
def convolve(x, w):
    """Valid-mode 1-D sliding dot product of `x` with the window `w`.

    Output element ``i`` is ``dot(x[i:i + len(w)], w)``, so the result has
    ``len(x) - len(w) + 1`` elements, or none if ``w`` is longer than ``x``.
    The kernel is not flipped (strictly speaking, a cross-correlation).

    Args:
        x: 1-D input signal.
        w: 1-D kernel; any length is accepted.

    Returns:
        1-D jax array holding one dot product per window position.
    """
    # Derive the window length from the kernel instead of hard-coding 3;
    # for 3-tap kernels the slices (and the traced jaxpr) are unchanged.
    window = len(w)
    output = []
    for i in range(len(x) - window + 1):
        output.append(jnp.dot(x[i:i + window], w))
    return jnp.array(output)

In [8]:
# Single (unbatched) example, passing arguments by name.
convolve(x=x, w=w)

Array([11., 20., 29.], dtype=float32)

In [9]:
# Suppose we want to apply this function to a batch of vectors, pairing each vector with its own weights from a batch of weights.

In [10]:
# Two identical rows per batch, to exercise the batched implementations.
xs, ws = jnp.stack([x] * 2), jnp.stack([w] * 2)

In [11]:
jnp.shape(xs)

(2, 5)

In [12]:
# The naive approach is to loop over the batch in Python.

In [13]:
def manually_batched_convolve(xs, ws):
    """Apply `convolve` to each (xs[i], ws[i]) pair with a Python loop.

    Args:
        xs: batch of input signals; row ``i`` is passed as ``x``.
        ws: batch of kernels; row ``i`` is paired with ``xs[i]``.

    Returns:
        The per-row results stacked along a new leading axis.

    Raises:
        ValueError: if ``xs`` and ``ws`` have different batch sizes.
    """
    if xs.shape[0] != ws.shape[0]:
        # JAX clamps out-of-bounds indices instead of raising, so without this
        # check a shorter `ws` would silently reuse its last row.
        raise ValueError(
            f"batch size mismatch: xs has {xs.shape[0]} rows, "
            f"ws has {ws.shape[0]}"
        )
    output = []
    for i in range(xs.shape[0]):
        output.append(convolve(xs[i], ws[i]))
    return jnp.stack(output)

# Loop-over-rows baseline on the example batch.
manually_batched_convolve(xs=xs, ws=ws)

Array([[11., 20., 29.],
       [11., 20., 29.]], dtype=float32)

In [14]:
def manually_vectorized_convolve(xs, ws):
    """Batched sliding dot product, vectorized over the batch axis by hand.

    Each window position is handled by one array expression covering every
    row at once, so the Python loop runs over window positions, not rows.

    Args:
        xs: 2-D array of shape (batch, n).
        ws: kernels of shape (batch, k). Ordinary broadcasting against
            ``xs[:, i:i + k]`` applies, so a single (k,) kernel also works.

    Returns:
        Array of shape (batch, n - k + 1).
    """
    # Window length comes from the kernel rather than being fixed at 3.
    window = ws.shape[-1]
    output = []
    for i in range(xs.shape[-1] - window + 1):
        # (batch, k) * ws -> sum over the window axis -> (batch,)
        output.append(jnp.sum(xs[:, i:i + window] * ws, axis=1))
    # Window positions become the trailing axis: (batch, n - k + 1).
    return jnp.stack(output, axis=1)

# Hand-vectorized version on the same example batch.
manually_vectorized_convolve(xs=xs, ws=ws)

Array([[11., 20., 29.],
       [11., 20., 29.]], dtype=float32)

## Automatic vectorization

`jax.vmap()` generates such a vectorized implementation of a function automatically.

In [15]:
# Map over the leading axis of every argument and of the result
# (these are vmap's defaults, spelled out for clarity).
auto_batch_convolve = jax.vmap(convolve, in_axes=0, out_axes=0)

In [16]:
# Row i of xs is paired with row i of ws, matching both manual versions above.
auto_batch_convolve(xs, ws)

Array([[11., 20., 29.],
       [11., 20., 29.]], dtype=float32)

In [17]:
# It does this by tracing the function, as jax.jit does, and then adding a batch axis to each input (axis 0 by default, configurable via in_axes).

In [18]:
# Inspect the traced program for the vmapped function (no static arguments).
jax.make_jaxpr(auto_batch_convolve, static_argnums=())(xs, ws)

{ lambda ; a:i32[2,5] b:f32[2,3]. let
    c:i32[2,3] = slice[limit_indices=(2, 3) start_indices=(0, 0) strides=None] a
    d:f32[2] = dot_general[
      dimension_numbers=(([1], [1]), ([0], [0]))
      preferred_element_type=float32
    ] c b
    e:i32[2,3] = slice[limit_indices=(2, 4) start_indices=(0, 1) strides=None] a
    f:f32[2] = dot_general[
      dimension_numbers=(([1], [1]), ([0], [0]))
      preferred_element_type=float32
    ] e b
    g:i32[2,3] = slice[limit_indices=(2, 5) start_indices=(0, 2) strides=None] a
    h:f32[2] = dot_general[
      dimension_numbers=(([1], [1]), ([0], [0]))
      preferred_element_type=float32
    ] g b
    i:f32[2,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1)
      sharding=None
    ] d
    j:f32[2,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1)
      sharding=None
    ] f
    k:f32[2,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1)
      sharding=None
    ] h
    l

In [19]:
# Unbatched trace, for comparison with the vmapped program above.
jax.make_jaxpr(convolve, static_argnums=())(x, w)

{ lambda ; a:i32[5] b:f32[3]. let
    c:i32[3] = slice[limit_indices=(3,) start_indices=(0,) strides=None] a
    d:f32[] = dot_general[
      dimension_numbers=(([0], [0]), ([], []))
      preferred_element_type=float32
    ] c b
    e:i32[3] = slice[limit_indices=(4,) start_indices=(1,) strides=None] a
    f:f32[] = dot_general[
      dimension_numbers=(([0], [0]), ([], []))
      preferred_element_type=float32
    ] e b
    g:i32[3] = slice[limit_indices=(5,) start_indices=(2,) strides=None] a
    h:f32[] = dot_general[
      dimension_numbers=(([0], [0]), ([], []))
      preferred_element_type=float32
    ] g b
    i:f32[1] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(1,)
      sharding=None
    ] d
    j:f32[1] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(1,)
      sharding=None
    ] f
    k:f32[1] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(1,)
      sharding=None
    ] h
    l:f32[3] = concatenate[dimension=0] i j k
  in (l,

In [20]:
# You can also convolve a single w with every row of xs by setting in_axes=[0, None]: 0 maps over the leading axis of the first argument (xs), while None passes the second argument (w) unbatched to every call.