# Vectorisation

One of the promises of JAX is to make vectorisation great again through syntactic sugar: decorators that describe which input axes are batched onto which output axes. The goal of this notebook is to show how this can be done in practice, and how it translates into low-level code.

## JAX imports

## Beginner
### Prerequisites
NumPy (some exposure to Numba is helpful)

### Imports

In [105]:
import inspect 

from jax import vmap, make_jaxpr
import jax.numpy as jnp

import numpy as np

### Example

To compare the vectorisation implementation of JAX to the NumPy one, let's take the following example:

In [106]:
def indexing_function(x, y):
    # Here x is a vector of floats and y is a single integer index
    return x[y]

We will use arrays of the following size for our tests:

In [107]:
N = 10

In [14]:
indexing_function(np.random.randn(N), np.random.randint(N))

-0.6750140445123347

How does it react to batched inputs?

In [59]:
B = 3

In [60]:
indexing_function(np.random.randn(B, N), np.random.randint(N, size=B))

IndexError: index 8 is out of bounds for axis 0 with size 3

In [61]:
indexing_function(np.random.randn(B, N), np.random.randint(N))

IndexError: index 7 is out of bounds for axis 0 with size 3

OK so we need to modify it.

In [63]:
def complicated_indexing_function(x, y):
    # Here x is a vector of floats, and y is a vector of ints
    return x[..., y]

In [64]:
complicated_indexing_function(np.random.randn(B, N), np.random.randint(N))

array([-0.16716791,  0.17872859,  0.03090031])

In [66]:
complicated_indexing_function(np.random.randn(B, N), np.random.randint(N, size=B))

array([[ 1.13307029, -0.37103888, -0.37103888],
       [-0.41673613,  0.26876945,  0.26876945],
       [-0.18368725, -0.30409671, -0.30409671]])

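Really not what we want! With `y` of shape `(B,)`, `x[..., y]` applies every index to every row and returns a `(B, B)` array, whereas we want one value per row, `x[i, y[i]]`.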
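For reference, the result we are after pairs row `i` of `x` with index `y[i]`. With plain NumPy this can be written by hand with advanced indexing, at the cost of exactly the kind of shape bookkeeping the vectorisation tools below are meant to hide. A minimal sketch (the helper name `manual_batched_indexing` is hypothetical, not part of any library):

```python
import numpy as np

def manual_batched_indexing(x, y):
    # x has shape (B, N), y has shape (B,): pick x[i, y[i]] for every row i
    batch_rows = np.arange(x.shape[0])
    return x[batch_rows, y]

B, N = 3, 10
x = np.random.randn(B, N)
y = np.random.randint(N, size=B)
result = manual_batched_indexing(x, y)

# The same thing with an explicit Python loop, row by row
expected = np.array([x[i][y[i]] for i in range(B)])
assert np.allclose(result, expected)
print(result.shape)  # (3,)
```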

### NumPy-style vectorisation

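Instead of trying to be smart, let's use NumPy's `np.vectorize`, with a signature declaring that `x` is a vector `(n)`, `y` is a scalar `()`, and the output is a scalar: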

In [76]:
np_vectorised_indexing_function = np.vectorize(indexing_function, signature="(n),()->()")

In [77]:
np_vectorised_indexing_function(np.random.randn(B, N), np.random.randint(N))

array([ 0.61877187, -1.09455421,  0.84442176])
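The same vectorised function should also handle the case that tripped up `complicated_indexing_function`: with one index per row (shape `(B,)`), the batch dimensions of `x` and `y` are broadcast together, so row `i` is paired with `y[i]`. A short sketch checking this against explicit advanced indexing:

```python
import numpy as np

def indexing_function(x, y):
    # x is a vector of floats, y is a single integer index
    return x[y]

B, N = 3, 10
np_vectorised_indexing_function = np.vectorize(indexing_function, signature="(n),()->()")

x = np.random.randn(B, N)
y_batch = np.random.randint(N, size=B)  # one index per row
out = np_vectorised_indexing_function(x, y_batch)

# Row i is paired with index y_batch[i]
assert out.shape == (B,)
assert np.allclose(out, x[np.arange(B), y_batch])
```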

And the JAX vectorisation:

In [78]:
jax_vectorised_indexing_function = jnp.vectorize(indexing_function, signature="(n),()->()")

In [79]:
jax_vectorised_indexing_function(np.random.randn(B, N), np.random.randint(N))

DeviceArray([-0.1660363, -0.239407 , -0.7275573], dtype=float32)

So what is the difference? Let's time both on a larger batch:

In [92]:
batch_input = np.random.randn(10000, N)
batch_index = np.random.randint(N, size=10000)

In [93]:
%timeit np_vectorised_indexing_function(batch_input, batch_index)

15.8 ms ± 23.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [95]:
jax_batch_input = jnp.asarray(batch_input)
jax_batch_index = jnp.asarray(batch_index)

In [97]:
%timeit jax_vectorised_indexing_function(jax_batch_input, jax_batch_index).block_until_ready()

2.22 ms ± 60 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


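Why is it faster? The main reason is that `np.vectorize` is provided for convenience rather than performance: as its documentation notes, it is essentially a Python `for` loop over the batch. `jnp.vectorize`, on the other hand, is built on `vmap`, which traces `indexing_function` once and rewrites it into batched primitives (here, essentially a single gather over all rows) that XLA executes without a Python-level loop.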
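For intuition, here is a simplified Python loop that produces the same result as `np_vectorised_indexing_function` on our batch. The helper name `loop_indexing` is hypothetical, and the sketch ignores the broadcasting and dtype handling that `np.vectorize` also performs; the point is that each row costs one Python-level function call, which is the overhead a `vmap`-based implementation avoids:

```python
import numpy as np

def indexing_function(x, y):
    return x[y]

def loop_indexing(batch_x, batch_y):
    # One Python-level call per row: this per-element overhead is what
    # makes a loop-based vectorisation slow for large batches
    return np.array([indexing_function(row, idx) for row, idx in zip(batch_x, batch_y)])

N = 10
batch_input = np.random.randn(10000, N)
batch_index = np.random.randint(N, size=10000)

np_vectorised = np.vectorize(indexing_function, signature="(n),()->()")
assert np.allclose(loop_indexing(batch_input, batch_index),
                   np_vectorised(batch_input, batch_index))
```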

### Vectorised map

Alternatively, one can use `vmap` directly: `jnp.vectorize` is a wrapper around it. Calling `vmap` yourself is useful, for example, when the batching dimension is not the leading one.

In [102]:
vmapped_indexing = vmap(indexing_function, in_axes=(1, 0))
# in_axes=(1, 0): x is batched along its second axis and y along its first,
# so there is no need to transpose or reshape the inputs by hand.

In [104]:
vmapped_indexing(np.random.randn(N, 3), np.random.randint(N, size=3))

DeviceArray([-0.02996093, -0.6977211 , -2.1147773 ], dtype=float32)
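To see what `in_axes` does, here is a sketch checking `vmap` against explicit NumPy indexing. With `in_axes=(0, None)`, `x` is mapped over its first axis and the same `y` is passed to every call, which should reproduce the `jnp.vectorize` example above; with `in_axes=(1, 0)`, call `i` receives column `x[:, i]` and index `y[i]`, so it returns `x[y[i], i]`:

```python
import numpy as np
import jax.numpy as jnp
from jax import vmap

def indexing_function(x, y):
    # x is a vector of floats, y is a single integer index
    return x[y]

B, N = 3, 10

# Case 1: batch along axis 0 of x, one index shared by all rows (None = not mapped)
x_rows = np.random.randn(B, N).astype(np.float32)
y_shared = np.random.randint(N)
shared = vmap(indexing_function, in_axes=(0, None))(x_rows, y_shared)
vectorised = jnp.vectorize(indexing_function, signature="(n),()->()")(x_rows, y_shared)
assert np.allclose(shared, vectorised)
assert np.allclose(shared, x_rows[:, y_shared])

# Case 2: batch along axis 1 of x and axis 0 of y
x_cols = np.random.randn(N, B).astype(np.float32)
y_batch = np.random.randint(N, size=B)
per_column = vmap(indexing_function, in_axes=(1, 0))(x_cols, y_batch)
assert np.allclose(per_column, x_cols[y_batch, np.arange(B)])
```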

### Questions:

#### Q1: 
Reimplement this manually vectorised function using `vmap`, and compare the generated code using `make_jaxpr`.

In [109]:
def just_a_function(x, y):
    a = x[..., 0] * y[..., 1]
    b = x[..., 1] * y[..., 0]
    return a + b

In [111]:
just_a_function(np.random.randn(4, 3, 2), np.random.randn(4, 3, 2))

array([[-0.13200268,  1.34672235,  3.11636931],
       [-2.57596326, -0.94320519,  1.48318021],
       [-0.82265851, -0.44076371,  0.46802494],
       [ 0.25318832, -0.65429168,  0.48984008]])

#### Q2:
Using `jnp.vectorize`, vectorise the following function with respect to the matrix `a`:
```python
def solve(a, b):
    return jnp.linalg.solve(a, b)
```

## Intermediate / Advanced
### Prerequisites
- Beginner vectorisation
- Beginner automatic differentiation
- Beginner loops (Advanced)

Now that we know how to vectorise, let's look at an example where vectorisation is not just a convenient wrapper but also a useful computational tool: we will vectorise the JVP call introduced in the automatic differentiation notebook.

### Imports

In [126]:
from functools import partial 

from jax import make_jaxpr, jvp, vmap
import jax.numpy as jnp
from jax.random import normal, PRNGKey

import numpy as np

### Example

Let's take a simple example:

In [115]:
def fun(x):
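    # Summing over axis 0 keeps this valid both for a vector and for a matrix of column vectors
    return jnp.sin(jnp.sum(x, axis=0))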

Say we want to compute its JVP against a number of random vectors:

In [124]:
def jvp_fun(x, key, d=100):
    n = x.shape[0]
    vectors = normal(key, shape=(n, d))
    return jvp(fun, (x,), (vectors,))

In [125]:
jvp_fun(jnp.array([0., 1.]), PRNGKey(42))

TypeError: Gather op must have one slice size for every input dimension; got: len(slice_sizes)=1, input_shape.rank=2

It doesn't work out of the box: `jvp` expects each tangent to have the same shape as its primal. Let's respect that by tiling `x` into `d` identical columns:

In [151]:
def jvp_fun(x, key, d=20):
    n = x.shape[0]
    vectors = normal(key, shape=(n, d))
    return jvp(fun, (jnp.repeat(x.reshape(-1, 1), d, 1),), (vectors,))[1]

In [152]:
jvp_fun(jnp.array([0., 1.]), PRNGKey(42))

DeviceArray([ 0.67665976,  0.29954007, -1.0683215 , -0.22646905,
              0.98523813,  1.1727225 ,  0.40628698,  0.43852165,
              0.5092319 , -0.10380521,  1.348889  , -0.7610773 ,
              0.14572985,  0.07581756,  0.70275015,  1.2393789 ,
             -0.09658202, -1.0514277 ,  0.57944995,  0.21703981],            dtype=float32)

OK, it's working, so what's the problem here? Because the primal point is tiled, `fun` is evaluated and relinearised at the same point $d$ times for no reason!

Uncomment and execute the line below to see this in the generated jaxpr:

In [157]:
# make_jaxpr(jvp_fun)(jnp.array([0., 1.]), PRNGKey(42))

We can avoid this redundant work by applying `vmap` over the tangent vectors only, keeping the primal point `x` fixed:

In [158]:
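def vmap_jvp_fun(x, key, d=20):
    n = x.shape[0]
    vectors = normal(key, shape=(n, d))
    # JVP of fun at the fixed point x along a single tangent vector
    local_fun = lambda vec: jvp(fun, (x,), (vec,))[1]
    # Map over the columns of `vectors`; x is closed over, so it is not batched
    return vmap(local_fun, in_axes=(1,))(vectors)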

In [159]:
vmap_jvp_fun(jnp.array([0., 1.]), PRNGKey(42))

DeviceArray([ 0.67665976,  0.29954007, -1.0683215 , -0.22646905,
              0.98523813,  1.1727225 ,  0.40628698,  0.43852165,
              0.5092319 , -0.10380521,  1.348889  , -0.7610773 ,
              0.14572985,  0.07581756,  0.70275015,  1.2393789 ,
             -0.09658202, -1.0514277 ,  0.57944995,  0.21703981],            dtype=float32)
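Since $f(x) = \sin\big(\sum_i x_i\big)$, its JVP along a tangent $v$ is $\cos\big(\sum_i x_i\big)\sum_i v_i$, so each entry above should equal $\cos(1)$ times the sum of the corresponding random column. A quick check, redefining the functions so the cell is self-contained and regenerating the tangents from the same key:

```python
import jax.numpy as jnp
from jax import jvp, vmap
from jax.random import normal, PRNGKey

def fun(x):
    # Same function as above, for a 1-D input
    return jnp.sin(jnp.sum(x))

def vmap_jvp_fun(x, key, d=20):
    n = x.shape[0]
    vectors = normal(key, shape=(n, d))
    local_fun = lambda vec: jvp(fun, (x,), (vec,))[1]
    return vmap(local_fun, in_axes=(1,))(vectors)

x = jnp.array([0., 1.])
result = vmap_jvp_fun(x, PRNGKey(42))

# Same key and shape give the same tangent draws
vectors = normal(PRNGKey(42), shape=(2, 20))
# Closed form of the JVP: cos(sum(x)) * sum(v), one value per column v
expected = jnp.cos(jnp.sum(x)) * jnp.sum(vectors, axis=0)
assert jnp.allclose(result, expected, atol=1e-5)
```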

Uncomment and execute the following line to compare its jaxpr with that of the naive manual approach:

In [161]:
# make_jaxpr(vmap_jvp_fun)(jnp.array([0., 1.]), PRNGKey(42))
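A related option, which is a different technique from the one used in this notebook, is `jax.linearize`: it evaluates `fun` and its linearisation once at `x` and returns a function that applies the JVP to any tangent. Mapping that function over the tangents with `vmap` should give the same numbers. A sketch, where `linearize_jvp_fun` is a hypothetical name:

```python
import jax
import jax.numpy as jnp
from jax import vmap
from jax.random import normal, PRNGKey

def fun(x):
    return jnp.sin(jnp.sum(x))

def linearize_jvp_fun(x, key, d=20):
    n = x.shape[0]
    vectors = normal(key, shape=(n, d))
    # Linearise once: f_jvp(v) gives the same value as jvp(fun, (x,), (v,))[1]
    _, f_jvp = jax.linearize(fun, x)
    # Reuse the linear map for each column of `vectors`
    return vmap(f_jvp, in_axes=1)(vectors)

x = jnp.array([0., 1.])
vectors = normal(PRNGKey(42), shape=(2, 20))
expected = jnp.cos(jnp.sum(x)) * jnp.sum(vectors, axis=0)
assert jnp.allclose(linearize_jvp_fun(x, PRNGKey(42)), expected, atol=1e-5)
```

Whether this is faster than the `vmap`-of-`jvp` version above depends on the function and backend; we have not measured it here.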

### Questions:

#### Q1:
Vectorise the following bubble sort algorithm using the method of your choice:
```python
def bubbleSort(arr): 
    n = len(arr) 
    res = np.copy(arr)
    for i in range(n-1): 
        for j in range(0, n-i-1): 
            if res[j] > res[j+1]: 
                res[j], res[j+1] = res[j+1], res[j]
    return res   
```