## vmap()

While training a model in a mini-batch setting, you need a way to apply the same computation to every example in a batch at once. In JAX, this can be done with vmap, which turns a function written for a single example into one that maps over a batch axis. Let's see this in detail with examples.

Say you need to get the dot product of two arrays, as in an affine transformation ($y = wx + b$; let's omit $b$ for now). For a single $x$ and $w$, you can just use the dot() function from numpy or jax.numpy.

In [1]:
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(42)

In [2]:
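# x: a single input vector of 100 integers drawn from [0, 10)
x = jax.random.randint(key, shape=(100,), minval=0, maxval=10)

# w: a single weight matrix of shape (1, 100)
newkey, subkey = jax.random.split(key)
w = jax.random.normal(newkey, shape=(1, 100))

# (1, 100) dot (100,) gives a result of shape (1,)
single_dot = jnp.dot(w, x)
single_dot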

DeviceArray([-8.661343], dtype=float32)

So here $x$ was a single vector of length 100, and $w$ was a single matrix of shape (1, 100). What if we have to take the dot product of 100 such pairs of $x$ and $w$, in other words a batch of size 100? This is where vmap comes in handy.

In [3]:
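# vmap maps jnp.dot over the leading (batch) axis of each argument
vmapped_dot = jax.vmap(jnp.dot)

BATCH_SIZE = 100

batched_x = jax.random.randint(key, shape=(BATCH_SIZE, 100), minval=0, maxval=10)
batched_w = jax.random.normal(newkey, shape=(BATCH_SIZE, 1, 100))

# one (1, 100) dot (100,) product per batch element, so the result has shape (BATCH_SIZE, 1)
batched_dot = vmapped_dot(batched_w, batched_x)
batched_dot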

DeviceArray([[-154.21112  ],
             [  49.17533  ],
             [  11.140234 ],
             [  37.236473 ],
             [  29.498665 ],
             [   9.617361 ],
             [  71.4579   ],
             [  78.10006  ],
             [ -45.04447  ],
             [  56.26046  ],
             [  64.63633  ],
             [  49.38059  ],
             [ -17.98463  ],
             [ -64.39987  ],
             [  38.20652  ],
             [  43.753876 ],
             [ -12.084675 ],
             [   4.5944977],
             [ -19.081417 ],
             [   3.98884  ],
             [ -52.261543 ],
             [  24.712055 ],
             [ -83.08613  ],
             [ -51.1763   ],
             [  15.695245 ],
             [ -75.23836  ],
             [  -4.31831  ],
             [  25.793459 ],
             [ -82.10106  ],
             [  53.141228 ],
             [  52.1035   ],
             [  54.45214  ],
             [  28.216433 ],
             [  14.025606 ],
             [
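
To make the batching concrete, here is a minimal sketch comparing vmap against an explicit Python loop over the batch axis. The small sizes and the names FEATURES, key_x, key_w, vmapped, and looped are illustrative assumptions, not part of the cells above.

```python
import jax
import jax.numpy as jnp

# Illustrative sizes (assumptions), kept small so the result is easy to inspect
BATCH_SIZE = 4
FEATURES = 3

key = jax.random.PRNGKey(0)
key_x, key_w = jax.random.split(key)

batched_x = jax.random.normal(key_x, shape=(BATCH_SIZE, FEATURES))
batched_w = jax.random.normal(key_w, shape=(BATCH_SIZE, 1, FEATURES))

# vmap: apply jnp.dot to each (w_i, x_i) pair along axis 0 of both inputs
vmapped = jax.vmap(jnp.dot)(batched_w, batched_x)

# Reference: the same computation written as an explicit loop, then stacked
looped = jnp.stack(
    [jnp.dot(batched_w[i], batched_x[i]) for i in range(BATCH_SIZE)]
)

print(vmapped.shape)  # (4, 1)
print(bool(jnp.allclose(vmapped, looped, atol=1e-5)))
```

Both versions should agree up to floating-point tolerance; the difference is that vmap expresses the loop as a single vectorized operation instead of BATCH_SIZE separate calls.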

If you want to add $b$ and make it a full-fledged affine transformation:

In [4]:
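# affine transformation for a single example: y = w . x + b
def y(w, x, b):
    return jnp.dot(w, x) + b

# one scalar bias per batch element, drawn from [0, 10)
batched_b = jax.random.randint(subkey, shape=(BATCH_SIZE,), minval=0, maxval=10)

# vmap maps y over the leading axis of all three arguments
res = jax.vmap(y)(batched_w, batched_x, batched_b)
res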


DeviceArray([[-147.21112  ],
             [  58.17533  ],
             [  15.140234 ],
             [  38.236473 ],
             [  37.498665 ],
             [  14.617361 ],
             [  74.4579   ],
             [  85.10006  ],
             [ -43.04447  ],
             [  59.26046  ],
             [  69.63633  ],
             [  53.38059  ],
             [ -11.984631 ],
             [ -56.39987  ],
             [  43.20652  ],
             [  48.753876 ],
             [  -3.0846748],
             [   8.594498 ],
             [ -16.081417 ],
             [   8.98884  ],
             [ -46.261543 ],
             [  29.712055 ],
             [ -75.08613  ],
             [ -51.1763   ],
             [  22.695244 ],
             [ -67.23836  ],
             [  -3.3183098],
             [  33.793457 ],
             [ -80.10106  ],
             [  55.141228 ],
             [  59.1035   ],
             [  57.45214  ],
             [  31.216433 ],
             [  19.025606 ],
             [
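
In the examples above, every batch element has its own $w$ and $b$. In a typical mini-batch setting the parameters are shared and only $x$ is batched. The sketch below shows one way this might be written with vmap's in_axes argument, where None marks an argument that is not mapped. The names affine, shared_w, shared_b, and FEATURES and the small sizes are assumptions made for illustration.

```python
import jax
import jax.numpy as jnp

def affine(w, x, b):
    # affine transformation for a single example
    return jnp.dot(w, x) + b

# Illustrative sizes (assumptions)
BATCH_SIZE, FEATURES = 4, 3

key = jax.random.PRNGKey(0)
key_w, key_x = jax.random.split(key)

shared_w = jax.random.normal(key_w, shape=(1, FEATURES))  # one w for the whole batch
shared_b = jnp.array(0.5)                                 # one scalar b for the whole batch
batched_x = jax.random.normal(key_x, shape=(BATCH_SIZE, FEATURES))

# in_axes=(None, 0, None): keep w and b fixed, map x over its axis 0
batched_affine = jax.vmap(affine, in_axes=(None, 0, None))
out = batched_affine(shared_w, batched_x, shared_b)

# Cross-check against the equivalent matrix product
expected = batched_x @ shared_w.T + shared_b

print(out.shape)  # (4, 1)
print(bool(jnp.allclose(out, expected, atol=1e-5)))
```

With the default in_axes=0, vmap would instead try to map over the leading axis of every argument, which is what the cells above rely on.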