## vmap()

`vmap` is another key component of JAX. It lets you pass a batch of inputs to your function in one go instead of one at a time. Think of it like this: you have 10 arrays that you need to pass through a function. In typical Python fashion, you would write a loop. With `vmap`, you send them as a single batch, which lets JAX use the vector and matrix operations your accelerator (GPU, TPU, yada yada) is good at, and that can make your computation much faster.
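To make the loop-versus-`vmap` idea concrete, here is a minimal sketch with a toy function `square_sum` (a hypothetical example, not part of the tutorial's code) that only knows how to handle one vector. How much faster the vmapped version is in practice depends on your hardware and workload.

```python
import jax
import jax.numpy as jnp


def square_sum(v):
    # Written for ONE vector of shape (3,); returns a scalar.
    return jnp.sum(v ** 2)


batch = jnp.arange(12.0).reshape(4, 3)  # 4 vectors of length 3

# The "typical Python" way: loop over the batch and stack the results.
looped = jnp.stack([square_sum(v) for v in batch])

# The vmap way: turn square_sum into a function that takes the whole batch.
batched_square_sum = jax.vmap(square_sum)
vmapped = batched_square_sum(batch)

# Both contain the per-row values 5, 50, 149, 302.
print(looped)
print(vmapped)
```

Note that `square_sum` itself never changes; `vmap` adds the batch dimension for you.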

Let's consider a simple affine transformation over some arrays. If you don't know what an affine transformation is, this equation should jog your memory:

$$
y = xW^{T} + b
$$
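Spelling out the shapes, with $d_{\mathrm{in}}$ and $d_{\mathrm{out}}$ standing for the `in_dim` and `out_dim` used in the code, each output entry is a weighted sum of the inputs plus a bias:

$$
x \in \mathbb{R}^{d_{\mathrm{in}}},\quad W \in \mathbb{R}^{d_{\mathrm{out}} \times d_{\mathrm{in}}},\quad b \in \mathbb{R}^{d_{\mathrm{out}}},\qquad y_j = \sum_{i=1}^{d_{\mathrm{in}}} x_i W_{ji} + b_j,\quad j = 1, \dots, d_{\mathrm{out}}
$$

So with $d_{\mathrm{in}} = 10$ and $d_{\mathrm{out}} = 2$, a 10-dim $x$ becomes a 2-dim $y$.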

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(42)
key, *subkeys = jax.random.split(key, 3)  # a new main key + 2 subkeys
```

```python
@jax.jit
def affine(x, w, b):
    return x @ w.T + b
```

```python
# For this affine transformation, we map
# a 10-dim vector to a 2-dim vector
in_dim = 10
out_dim = 2

x = jax.random.normal(key, shape=(in_dim, ))

w = jax.random.normal(subkeys[0], shape=(out_dim, in_dim))

b = jax.random.normal(subkeys[1], shape=(out_dim, ))
```

```python
affine(x, w, b)
```

```
DeviceArray([-1.2992545,  6.513557 ], dtype=float32)
```

But that was over one `x`, `w` and `b`. What if we want to pass a batch as intended?

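```python
batch_size = 10

# Note: this reuses the keys from above, which is fine for a demo;
# in real code, split fresh keys instead of reusing them.
xs = jax.random.normal(key, shape=(batch_size, in_dim))
ws = jax.random.normal(subkeys[0], shape=(batch_size, out_dim, in_dim))
bs = jax.random.normal(subkeys[1], shape=(batch_size, out_dim))
```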

Now let's pass the whole batch straight into `affine`, $b$ and all, as a full-fledged affine transformation...

```python
affine(xs, ws, bs)
```

```
TypeError: dot_general requires contracting dimensions to have the same shape, got (10,) and (2,).
```
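The failure happens because `affine` was written for a single example: on the 3-D `ws`, `.T` reverses all three axes rather than transposing each weight matrix, so `@` ends up contracting a size-10 axis against a size-2 axis. One way around it without `vmap` is to rewrite the function so it knows about the batch axis, for example with `jnp.einsum`. A minimal sketch, using fresh random arrays and a hypothetical helper `batched_affine_by_hand` (neither is from the original notebook):

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
kx, kw, kb = jax.random.split(key, 3)

batch_size, in_dim, out_dim = 10, 10, 2
xs = jax.random.normal(kx, shape=(batch_size, in_dim))
ws = jax.random.normal(kw, shape=(batch_size, out_dim, in_dim))
bs = jax.random.normal(kb, shape=(batch_size, out_dim))


def batched_affine_by_hand(xs, ws, bs):
    # Index letters: n = batch, o = output dim, i = input dim.
    # For every batch element n: y[n, o] = sum_i xs[n, i] * ws[n, o, i] + bs[n, o]
    return jnp.einsum("ni,noi->no", xs, ws) + bs


ys = batched_affine_by_hand(xs, ws, bs)
print(ys.shape)  # (10, 2)
```

This works, but you have to rewrite the function by hand for each batching pattern, which is exactly the chore `vmap` takes off your plate.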

## Enter vmap:

```python
vmapped_affine = jax.vmap(affine)
vmapped_affine(xs, ws, bs)
```

```
DeviceArray([[ 2.3045597 ,  0.98218215],
             [ 1.6167651 , -2.1813567 ],
             [-3.882108  , -1.0218511 ],
             [-0.12499112, -0.6038313 ],
             [-2.508927  ,  0.8527329 ],
             [-0.86322856,  3.2404249 ],
             [-0.5951458 ,  1.4468247 ],
             [-0.06645101, -0.23431951],
             [-0.41957128, -7.0053816 ],
             [-1.878496  , -2.0947793 ]], dtype=float32)
```

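It worked! By default `vmap` uses `in_axes=0`, so it maps over the leading axis of every argument: row `i` of `xs` is paired with `ws[i]` and `bs[i]`, and the ten results are stacked into a `(10, 2)` array.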
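To convince yourself that `vmap` computes the same thing as the Python loop it replaces, here is a self-contained sketch. It re-creates `affine` and uses fresh random arrays from `jax.random.PRNGKey(0)`, so the numbers won't match the output above.

```python
import jax
import jax.numpy as jnp


@jax.jit
def affine(x, w, b):
    return x @ w.T + b


key = jax.random.PRNGKey(0)
kx, kw, kb = jax.random.split(key, 3)
batch_size, in_dim, out_dim = 10, 10, 2
xs = jax.random.normal(kx, shape=(batch_size, in_dim))
ws = jax.random.normal(kw, shape=(batch_size, out_dim, in_dim))
bs = jax.random.normal(kb, shape=(batch_size, out_dim))

# vmap with the default in_axes=0: element i of every argument is paired up.
vmapped = jax.vmap(affine)(xs, ws, bs)

# The same computation as an explicit Python loop, stacked afterwards.
looped = jnp.stack([affine(xs[i], ws[i], bs[i]) for i in range(batch_size)])

# Should be True, up to float32 rounding differences.
print(bool(jnp.allclose(vmapped, looped, rtol=1e-4, atol=1e-4)))
```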


## What if I want to vmap over only some of the arguments?

Now you might be yelling at me for having a separate `w` and `b` for each `x`, especially if you're used to affine transformations from machine learning or PyTorch's `nn.Linear` (it's called "linear", but with its bias it's really affine; a linear map is just an affine map with $b = 0$). Fret not. We can share a single `w` and `b` across all the `xs`. In fact, let's write a simple linear classifier that takes a vector (or, once vmapped, a batch of vectors) as input and outputs a 2-dim vector of class probabilities.

(I'm certainly not going to train this, just do the inference part. Besides, I don't have any labels :P)

```python
# using the same w, b values from before
params = (w, b)


@jax.jit
def model(params, x):
    w, b = params
    logits = x @ w.T + b

    return jax.nn.softmax(logits, axis=-1)  # convert to probability distribution
```
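Before batching anything, a quick sanity check on a single input vector can help. This self-contained sketch re-creates `model` with fresh random parameters (not the `w`, `b` from above) and checks that the output looks like a probability distribution.

```python
import jax
import jax.numpy as jnp


@jax.jit
def model(params, x):
    w, b = params
    logits = x @ w.T + b
    return jax.nn.softmax(logits, axis=-1)  # probabilities over the 2 classes


key = jax.random.PRNGKey(0)
kx, kw, kb = jax.random.split(key, 3)
in_dim, out_dim = 10, 2
params = (jax.random.normal(kw, (out_dim, in_dim)), jax.random.normal(kb, (out_dim,)))
x = jax.random.normal(kx, (in_dim,))

probs = model(params, x)  # one input vector -> 2 class probabilities
print(probs.shape)  # (2,)
print(probs.sum())  # approximately 1.0
```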

Okay, but we want a single set of `params` for every `x`, right? You can tell `vmap` not to map over `params`, so the same `w` and `b` are reused for every row of `xs`. How? Using `in_axes`:

```python
vmapped_model = jax.vmap(model, in_axes=(None, 0))
vmapped_model(params, xs)
```

```
DeviceArray([[6.6032819e-04, 9.9933964e-01],
             [7.1236588e-02, 9.2876339e-01],
             [2.5964549e-01, 7.4035454e-01],
             [9.0411073e-01, 9.5889255e-02],
             [4.9198368e-01, 5.0801635e-01],
             [9.8948145e-01, 1.0518541e-02],
             [8.3628774e-01, 1.6371234e-01],
             [9.8934537e-01, 1.0654581e-02],
             [9.8608810e-01, 1.3911908e-02],
             [4.2095959e-02, 9.5790404e-01]], dtype=float32)
```
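Here is a self-contained sketch (again with fresh random parameters, so the numbers won't match the output above) that shows two things. First, `None` applies to the whole `params` tuple, and you can also spell it out per element with `in_axes=((None, None), 0)`. Second, this particular model happens to work on a batch even without `vmap`, because `@` and `softmax(axis=-1)` already broadcast over a leading batch axis.

```python
import jax
import jax.numpy as jnp


@jax.jit
def model(params, x):
    w, b = params
    logits = x @ w.T + b
    return jax.nn.softmax(logits, axis=-1)


key = jax.random.PRNGKey(0)
kx, kw, kb = jax.random.split(key, 3)
batch_size, in_dim, out_dim = 10, 10, 2
params = (jax.random.normal(kw, (out_dim, in_dim)), jax.random.normal(kb, (out_dim,)))
xs = jax.random.normal(kx, (batch_size, in_dim))

# None for the whole params tuple: w and b are shared across the batch.
out_a = jax.vmap(model, in_axes=(None, 0))(params, xs)

# Equivalent: in_axes mirrors the (w, b) tuple structure of params.
out_b = jax.vmap(model, in_axes=((None, None), 0))(params, xs)

# This model already broadcasts over a leading batch axis, so calling
# it directly on xs gives the same answer.
out_c = model(params, xs)
```

The direct call only works because this model is simple; `vmap` earns its keep when a function was written for one example and doesn't broadcast cleanly, like `affine` with per-example `ws` earlier.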

`in_axes` lets you specify, for each positional argument of the function, whether (and along which axis) it should be "vmapped". For example, if your model took `w, b, x` as three separate arguments instead of `params, x`, you could have written:

```python
vmapped_model = jax.vmap(model, in_axes=(None, None, 0))
```
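A runnable version of that idea needs a model whose signature actually is `(w, b, x)`. The sketch below uses a hypothetical `model_wbx` for this, with fresh random arrays.

```python
import jax
import jax.numpy as jnp


@jax.jit
def model_wbx(w, b, x):
    # Hypothetical variant of `model` taking w and b as separate arguments.
    logits = x @ w.T + b
    return jax.nn.softmax(logits, axis=-1)


key = jax.random.PRNGKey(0)
kx, kw, kb = jax.random.split(key, 3)
w = jax.random.normal(kw, (2, 10))
b = jax.random.normal(kb, (2,))
xs = jax.random.normal(kx, (10, 10))

# One in_axes entry per positional argument: w -> None, b -> None, x -> 0.
vmapped_model_wbx = jax.vmap(model_wbx, in_axes=(None, None, 0))
probs = vmapped_model_wbx(w, b, xs)
print(probs.shape)  # (10, 2)
```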

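A value of `None` tells `vmap` not to map over that argument: the same value is passed unchanged to every call. Order matters, since the i-th entry of `in_axes` applies to the i-th positional argument, so you can't put `None` just anywhere. For every other argument, you give the axis along which that array should be vectorised. I've used 0 (the leading axis) here, but you can map over another axis when your batch is laid out differently. If you pass a single integer instead of a tuple, it applies to every argument.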
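As an example of mapping over something other than axis 0, here is a sketch where the batch is stored column-wise, so each column of the hypothetical `xs_cols` array is one input vector. By default (`out_axes=0`) `vmap` puts the mapped axis first in the output.

```python
import jax
import jax.numpy as jnp


def affine(x, w, b):
    return x @ w.T + b


key = jax.random.PRNGKey(0)
kx, kw, kb = jax.random.split(key, 3)
batch_size, in_dim, out_dim = 5, 10, 2
w = jax.random.normal(kw, (out_dim, in_dim))
b = jax.random.normal(kb, (out_dim,))

# The batch lives along axis 1: shape (in_dim, batch_size), one input per column.
xs_cols = jax.random.normal(kx, (in_dim, batch_size))

# Map x over its columns (axis 1) and share w and b across the batch.
ys = jax.vmap(affine, in_axes=(1, None, None))(xs_cols, w, b)
print(ys.shape)  # (5, 2): the mapped axis comes first in the output
```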