**S02P03_tutorial_auto_vectorization_in_jax.ipynb**

Arz

2024 APR 10 (WED)

reference:
https://jax.readthedocs.io/en/latest/jax-101/03-vectorization.html

In [1]:
import numpy as np

In [2]:
import jax
import jax.numpy as jnp
from jax import lax
from jax import grad, jit, vmap
from jax import random

In [3]:
%xmode minimal

Exception reporting mode: Minimal


# manual vectorization

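Example: a 1-D convolution of two vectors. (`y` is not flipped and only positions where it fits entirely inside `x` are kept, so strictly this is a "valid"-mode cross-correlation; the tutorial calls it a convolution.)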

In [4]:
def convolve(x, y):
    """Slide the kernel ``y`` along ``x`` and take a dot product at each position.

    ``y`` is not flipped (so, strictly, this is a cross-correlation), and only
    positions where ``y`` fits entirely inside ``x`` are kept ("valid" mode).

    Args:
        x: 1-D array to slide over.
        y: 1-D kernel. Any length is accepted; the tutorial uses length 3.

    Returns:
        1-D array of length ``max(len(x) - len(y) + 1, 0)``.
    """
    window = len(y)
    # The trip count depends only on static shapes, so vmap/jit can still
    # trace through this Python-level loop.
    output = [jnp.dot(x[i:i + window], y) for i in range(len(x) - window + 1)]
    return jnp.array(output)

In [5]:
# One example pair: an integer signal x = [0, 1, 2, 3, 4] and a length-3 float kernel y.
x = jnp.arange(5)
y = jnp.array([2., 3., 4.])

In [6]:
# Single (unbatched) example.
convolve(x, y)

Array([11., 20., 29.], dtype=float32)

In [7]:
# A batch of 4 examples stacked along axis 0. The third x row has its interior
# replaced by [1, 0, 1]; the last y row uses a different kernel.
x_batch = jnp.stack([x, x, x.at[1:-1].set([1, 0, 1]), x])
y_batch = jnp.stack([y, y, y, jnp.array([7., 4., 2.])])

print(x_batch)
print(y_batch)

[[0 1 2 3 4]
 [0 1 2 3 4]
 [0 1 0 1 4]
 [0 1 2 3 4]]
[[2. 3. 4.]
 [2. 3. 4.]
 [2. 3. 4.]
 [7. 4. 2.]]


## naive looping

Simple, but not efficient: a Python loop calls `convolve` once per example.

In [8]:
def looping_batch_convolve(x_batch, y_batch):
    """Batched ``convolve`` via an explicit Python loop over the leading axis.

    Row ``i`` of ``x_batch`` is paired with row ``i`` of ``y_batch``, and
    ``convolve`` is called once per pair. The per-example results are stacked
    along a new leading axis.
    """
    n_examples = x_batch.shape[0]
    per_example = [convolve(x_batch[i], y_batch[i]) for i in range(n_examples)]
    return jnp.stack(per_example)

In [9]:
# Reference result for the batched versions that follow.
looping_batch_convolve(x_batch, y_batch)

Array([[11., 20., 29.],
       [11., 20., 29.],
       [ 3.,  6., 19.],
       [ 8., 21., 34.]], dtype=float32)

In [10]:
%timeit looping_batch_convolve(x_batch, y_batch)

26.1 ms ± 1.96 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


## mathy

Example: rewrite the computation with batched array operations (manual vectorization).

This is fast, but the code can be tricky to write and to understand.

In [11]:
def mathy_batch_convolve(x_batch, y_batch):
    """Batched convolution written by hand with array operations.

    For each valid window offset, the window of every row of ``x_batch`` is
    multiplied elementwise by ``y_batch`` and summed along the window. The
    per-offset results become the output columns.

    Args:
        x_batch: 2-D array, one signal per row.
        y_batch: 2-D array with one kernel per row (same number of rows as
            ``x_batch``), or a 1-D kernel that broadcasts over all rows.
            Any kernel length is accepted.

    Returns:
        2-D array of shape ``(batch, x_len - k + 1)``, where ``k`` is the kernel
        length. The second dimension is 0 when the kernel is longer than the signal.
    """
    window = y_batch.shape[-1]
    n_out = x_batch.shape[-1] - window + 1
    if n_out <= 0:
        # jnp.stack([]) would raise; return an empty (batch, 0) result instead.
        return jnp.zeros((x_batch.shape[0], 0), dtype=jnp.result_type(x_batch, y_batch))
    columns = [
        jnp.sum(x_batch[:, i:i + window] * y_batch, axis=1)
        for i in range(n_out)
    ]
    return jnp.stack(columns, axis=1)

In [12]:
# Should match the looping result above.
mathy_batch_convolve(x_batch, y_batch)

Array([[11., 20., 29.],
       [11., 20., 29.],
       [ 3.,  6., 19.],
       [ 8., 21., 34.]], dtype=float32)

In [13]:
%timeit mathy_batch_convolve(x_batch, y_batch)

The slowest run took 4.78 times longer than the fastest. This could mean that an intermediate result is being cached.
2.81 ms ± 1.6 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)


# automatic vectorization

In [14]:
# vmap turns the single-example convolve into a batched function. By default it
# maps axis 0 of every argument and puts the batch axis at axis 0 of the output.
auto_batch_convolve = vmap(convolve)

In [15]:
# Same result as the manual versions.
auto_batch_convolve(x_batch, y_batch)

Array([[11., 20., 29.],
       [11., 20., 29.],
       [ 3.,  6., 19.],
       [ 8., 21., 34.]], dtype=float32)

In [16]:
%timeit auto_batch_convolve(x_batch, y_batch)

12.6 ms ± 1.12 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)


## specify batch dimension

- input: **`in_axes`**, which axis of each argument is the batch axis
- output: **`out_axes`**, where the batch axis is placed in the output

In [17]:
# Map over axis 1 of both inputs and place the batch axis at position 1 of the output.
auto_batch_convolve = vmap(convolve, in_axes=(1, 1), out_axes=1)

auto_batch_convolve(x_batch.T, y_batch.T)

Array([[11., 11.,  3.,  8.],
       [20., 20.,  6., 21.],
       [29., 29., 19., 34.]], dtype=float32)

### case: inconsistent batch sizes across arguments

In [18]:
auto_batch_convolve = vmap(convolve, in_axes=1, out_axes=0)

# in_axes=1 maps axis 1 of *both* arguments:
#   x_batch.T: int32[5, 4]   -> mapped axis has size 4
#   y_batch:   float32[4, 3] -> mapped axis has size 3
# The mapped sizes disagree, so vmap raises. Catch it to show the message
# instead of leaving a failing (or commented-out) call in the notebook.
try:
    auto_batch_convolve(x_batch.T, y_batch)
except ValueError as err:
    print(f"{type(err).__name__}: {err}")

**fix**

Give `in_axes` one entry per positional argument (a list or a tuple): `in_axes=(axis to map for arg 1 (x_batch.T), axis to map for arg 2 (y_batch))`.

In [19]:
# One in_axes entry per positional argument: batch axis 1 for x, axis 0 for y.
auto_batch_convolve = vmap(convolve, in_axes=(1, 0), out_axes=0)

auto_batch_convolve(x_batch.T, y_batch)

Array([[11., 20., 29.],
       [11., 20., 29.],
       [ 3.,  6., 19.],
       [ 8., 21., 34.]], dtype=float32)

In [20]:
# Same input mapping, but place the batch axis at position 1 of the output.
auto_batch_convolve = vmap(convolve, in_axes=(1, 0), out_axes=1)

auto_batch_convolve(x_batch.T, y_batch)

Array([[11., 11.,  3.,  8.],
       [20., 20.,  6., 21.],
       [29., 29., 19., 34.]], dtype=float32)

## case: when only one of the arguments is batched

Use `None` in `in_axes` for the unbatched argument; the same value is reused for every example.

In [21]:
# None: y is not mapped, so the same kernel is used for every row of x_batch.
auto_batch_convolve = vmap(convolve, in_axes=(0, None))

auto_batch_convolve(x_batch, y)

Array([[11., 20., 29.],
       [11., 20., 29.],
       [ 3.,  6., 19.],
       [11., 20., 29.]], dtype=float32)

# combining transformations

`vmap` composes with other JAX transformations such as `jit`.

In [22]:
# Rebind to the default mapping (axis 0 in, axis 0 out) after the variants above.
auto_batch_convolve = vmap(convolve)

In [23]:
# jit wraps the vmapped function, so the whole batched computation is compiled.
auto_batch_convolve_jit = jit(auto_batch_convolve)

In [24]:
# The first call traces and compiles. Running it here also warms up the
# compilation cache before the timing cell.
auto_batch_convolve_jit(x_batch, y_batch)

Array([[11., 20., 29.],
       [11., 20., 29.],
       [ 3.,  6., 19.],
       [ 8., 21., 34.]], dtype=float32)

In [25]:
# block_until_ready() waits for the asynchronous result, so the timing covers the computation.
%timeit auto_batch_convolve_jit(x_batch, y_batch).block_until_ready()

163 µs ± 18.4 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
