<a href="https://colab.research.google.com/github/Peter-obi/JAX/blob/main/Vmap.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import jax
import jax.numpy as jnp

In [3]:
def dot(v1, v2):
  """Return the dot product of two vectors.

  Uses jnp.vdot, which flattens its inputs: passing two 2-D arrays yields
  a single scalar over all elements, not one value per row.
  """
  product = jnp.vdot(v1, v2)
  return product

In [4]:
dot(jnp.array([1., 1., 1.]), jnp.array([1., 2., -1]))  #calculates the dot product between 2 vectors: 1*1 + 1*2 + 1*(-1) = 2

Array(2., dtype=float32)

In [6]:
rng_key = jax.random.PRNGKey(42) #create a PRNG key from a fixed seed so the random data is reproducible

In [8]:
#draw a (20, 3) array of standard-normal samples: 20 vectors with 3 components each
vs = jax.random.normal(rng_key, shape=(20, 3))

In [32]:
#split the 20 vectors into two batches of 10: rows 0-9 and rows 10-19
v1s, v2s = vs[:10], vs[10:]

## Vectorization

In [11]:
#the inputs are two batches of vectors; confirm both have shape (10, 3) before vectorizing over them
v1s.shape, v2s.shape

((10, 3), (10, 3))

In [12]:
#jnp.vdot flattens both (10, 3) inputs, so this is one dot product over all 30 elements
dot(v1s, v2s) #one number - wrong answer: we want 10 values, one per pair of rows

Array(0.09785453, dtype=float32)

In [13]:
#Naive approach: call dot() once per pair of rows in a Python-level loop.
#The result is a Python list of scalar arrays instead of a single batched array.
[dot(v1s[i], v2s[i]) for i in range(len(v1s))]

[Array(-0.21668759, dtype=float32),
 Array(0.0147948, dtype=float32),
 Array(-0.7736949, dtype=float32),
 Array(-0.37052184, dtype=float32),
 Array(-0.19051453, dtype=float32),
 Array(-0.44745094, dtype=float32),
 Array(-1.2089032, dtype=float32),
 Array(1.1151277, dtype=float32),
 Array(0.02767059, dtype=float32),
 Array(2.1480343, dtype=float32)]

In [14]:
#manual vectorization
def dot_vectorized(v1s, v2s):
  """Row-wise dot products of two equally shaped 2-D arrays.

  The einsum multiplies matching entries and sums over the second index
  (j), leaving one value per row (i).
  """
  rowwise = jnp.einsum('ij, ij -> i', v1s, v2s)
  return rowwise

In [15]:
#one dot product per pair of rows, returned as a single array
dot_vectorized(v1s, v2s)

Array([-0.21668759,  0.0147948 , -0.7736949 , -0.37052184, -0.19051453,
       -0.44745094, -1.2089032 ,  1.1151277 ,  0.02767059,  2.1480343 ],      dtype=float32)

In [16]:
#automatic vectorization: jax.vmap wraps the single-pair dot() so that it maps over the leading axis of both arguments
dot_vmapped = jax.vmap(dot)

In [17]:
#gives the same values as dot_vectorized(v1s, v2s), without rewriting dot() by hand
dot_vmapped(v1s, v2s)

Array([-0.21668759,  0.0147948 , -0.7736949 , -0.37052184, -0.19051453,
       -0.44745094, -1.2089032 ,  1.1151277 ,  0.02767059,  2.1480343 ],      dtype=float32)

In [18]:
#baseline: Python loop with one dot() call per pair; block_until_ready() waits for each result because JAX dispatches work asynchronously
%timeit [dot(v1s[i], v2s[i]).block_until_ready() for i in range(v1s.shape[0])]

2.97 ms ± 411 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [19]:
#manually vectorized version (not jitted); block_until_ready() makes the timing include the actual computation
%timeit dot_vectorized(v1s, v2s).block_until_ready()

222 µs ± 5.64 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [20]:
%timeit dot_vmapped(v1s, v2s).block_until_ready

573 µs ± 19.4 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [21]:
#JIT both batched functions so the manual and the vmapped versions can be compared after compilation
dot_vectorized_jitted = jax.jit(dot_vectorized)
dot_vmapped_jitted = jax.jit(dot_vmapped)

In [23]:
#warmup: call each jitted function once so compilation happens here and not inside the timed runs
#(the trailing ; suppresses the cell output)
dot_vectorized_jitted(v1s, v2s);
dot_vmapped_jitted(v1s, v2s);

In [24]:
#jitted manual version; block_until_ready() makes the timing include the actual computation
%timeit dot_vectorized_jitted(v1s, v2s).block_until_ready()

11.7 µs ± 2.51 µs per loop (mean ± std. dev. of 7 runs, 100000 loops each)


In [25]:
%timeit dot_vmapped_jitted(v1s, v2s).block_until_ready

10.2 µs ± 1.8 µs per loop (mean ± std. dev. of 7 runs, 100000 loops each)


## Jaxpr for the dot products

In [26]:
#jaxpr of the single-pair function, traced on two example 3-vectors
jax.make_jaxpr(dot)(jnp.array([1., 1., 1.]), jnp.array([1., 1., -1]))

{ lambda ; a:f32[3] b:f32[3]. let
    c:f32[] = dot_general[
      dimension_numbers=(([0], [0]), ([], []))
      preferred_element_type=float32
    ] a b
  in (c,) }

In [27]:
#jaxpr of the manually vectorized function, traced on the two (10, 3) batches
jax.make_jaxpr(dot_vectorized)(v1s, v2s)

{ lambda ; a:f32[10,3] b:f32[10,3]. let
    c:f32[10] = dot_general[
      dimension_numbers=(([1], [1]), ([0], [0]))
      preferred_element_type=float32
    ] a b
  in (c,) }

In [28]:
#jaxpr of the vmapped function; dot_general is the primitive from jax.lax
jax.make_jaxpr(dot_vmapped)(v1s, v2s)

{ lambda ; a:f32[10,3] b:f32[10,3]. let
    c:f32[10] = dot_general[
      dimension_numbers=(([1], [1]), ([0], [0]))
      preferred_element_type=float32
    ] a b
  in (c,) }

## The in_axes parameter

In [33]:
def scaled_dot(v1, v2, koeff):
  """Return the dot product of v1 and v2 multiplied by koeff."""
  product = jnp.vdot(v1, v2)
  return koeff * product
v1s_ = v1s  #batch dimension stays on axis 0
v2s_ = jnp.transpose(v2s)  #transposed copy: the batch dimension is now axis 1
k = 1.0  #scaling by 1 leaves the dot products unchanged

In [34]:
#the batch of 10 sits on axis 0 of v1s_ but on axis 1 of v2s_
v1s_.shape, v2s_.shape

((10, 3), (3, 10))

In [35]:
#in_axes gives the batch axis for each argument: axis 0 for v1, axis 1 for v2, None = koeff is not mapped and is reused for every pair
scaled_dot_batched = jax.vmap(scaled_dot, in_axes=(0,1,None))

In [36]:
#with k = 1.0 this reproduces the plain batched dot products
scaled_dot_batched(v1s_, v2s_, k)

Array([-0.21668759,  0.0147948 , -0.7736949 , -0.37052184, -0.19051453,
       -0.44745094, -1.2089032 ,  1.1151277 ,  0.02767059,  2.1480343 ],      dtype=float32)

In [37]:
def scaled_dot(data, koeff):
  """Return koeff times the dot product of data['a'] and data['b'].

  Note: this replaces the earlier three-argument scaled_dot. Here the two
  vectors arrive in a dict and the scale factor is a separate scalar.
  """
  product = jnp.vdot(data['a'], data['b'])
  return koeff * product

In [39]:
#in_axes mirrors the argument structure: a batch axis per dict key, and None for the scalar that is not mapped
scaled_dot_batched = jax.vmap(scaled_dot, in_axes=({'a': 0, 'b': 1}, None))

In [40]:
#the dict passed in has the same keys as the in_axes dict; 'a' is batched on axis 0 and 'b' on axis 1
scaled_dot_batched({'a': v1s_, 'b':v2s_}, k)

Array([-0.21668759,  0.0147948 , -0.7736949 , -0.37052184, -0.19051453,
       -0.44745094, -1.2089032 ,  1.1151277 ,  0.02767059,  2.1480343 ],      dtype=float32)

In [41]:
def scale(v, koeff):
  """Return v multiplied by the factor koeff."""
  scaled = koeff * v
  return scaled

In [42]:
#in_axes: map over axis 0 of v and do not map koeff; out_axes=1 places the batch dimension on axis 1 of the output
scale_batched = jax.vmap(scale, in_axes=(0, None), out_axes=1)

In [43]:
#the input is (10, 3); because out_axes=1 the batch lands on axis 1, so the result is (3, 10)
scale_batched(v1s, 2.0)

Array([[-0.05660923,  0.30709183, -2.881758  ,  1.8203408 ,  2.8915725 ,
         1.819189  , -2.8970175 , -2.358762  , -0.48223934, -1.0695378 ],
       [ 0.9342637 , -0.24806564,  1.5117198 , -0.7689932 ,  2.1618133 ,
         1.1146923 ,  1.528375  , -3.8778367 ,  2.430255  ,  0.54135114],
       [ 0.5914059 ,  0.4338463 ,  1.0428193 ,  2.2796466 , -0.11258642,
         0.43811437, -0.48309395,  0.71252924, -2.7904441 ,  3.0802484 ]],      dtype=float32)