<a href="https://colab.research.google.com/github/Peter-obi/JAX/blob/main/Pmap.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [77]:
import os
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'
import jax
print(f"Device count: {jax.device_count()}")

Device count: 8


In [78]:
import jax.numpy as jnp

def dot(v1, v2):
  return jnp.vdot(v1, v2)

In [79]:
dot(jnp.array([1., 1., 1]), jnp.array([1., 2., -1]))

Array(2., dtype=float32)

In [80]:
from jax import random

rng_key = random.PRNGKey(42)

vs = random.normal(rng_key, shape=(20_000_000, 3)) #Generates a 2-dimensional array of random numbers

v1s = vs[:10_000_000, :] #split this array into two
v2s = vs[10_000_000:, :]

In [81]:
v1s.shape, v2s.shape

((10000000, 3), (10000000, 3))

In [82]:
dot_batched = jax.jit(jax.vmap(dot))

x_vmap = dot_batched(v1s, v2s)

x_vmap.shape

(10000000,)

In [83]:
dot_parallel = jax.pmap(dot)

x_pmap = dot_parallel(v1s, v2s)

ValueError: compiling computation that requires 10000000 logical devices, but only 8 XLA devices are available (num_replicas=10000000)

In [84]:
v1sp = v1s.reshape((8, v1s.shape[0]//8, v1s.shape[1])) #reshape into eight chunks
v2sp = v2s.reshape((8, v2s.shape[0]//8, v2s.shape[1]))

In [85]:
v1sp.shape

(8, 1250000, 3)

In [86]:
x_pmap = dot_parallel(v1sp, v2sp)
x_pmap.shape

(8,)

In [87]:
dot_parallel = jax.pmap(jax.vmap(dot))
x_pmap = dot_parallel(v1sp, v2sp)

In [88]:
x_pmap.shape

(8, 1250000)

In [89]:
type(x_pmap)

jaxlib._jax.ArrayImpl

In [90]:
x_pmap = x_pmap.reshape((x_pmap.shape[0] * x_pmap.shape[1])) #eliminates the mapping axis, this is a logical reshape to a flattened array
x_pmap.shape

(10000000,)

##Pmap = vmap

In [91]:
vs = random.normal(rng_key, shape=(16, 3)) #generate small arrays to fit on 8 devices

v1s = vs[:8, :]
v2s = vs[8:, :]

In [92]:
jax.vmap(dot)(v1s, v2s) #autovectorizes the function

Array([ 0.16187823,  0.21840437, -1.3397005 ,  0.17735794,  0.93532217,
       -0.85470366,  0.25243354,  0.81216115], dtype=float32)

In [93]:
jax.pmap(dot)(v1s, v2s) #parallelize the function

Array([ 0.16187824,  0.21840435, -1.3397005 ,  0.17735793,  0.9353221 ,
       -0.85470366,  0.25243354,  0.81216115], dtype=float32)

In [94]:
dot_v = jax.jit(jax.vmap(dot)) #compiles and warms up jitted version
x = dot_v(v1s, v2s)

In [95]:
dot_pjo = jax.jit(jax.pmap(dot)) #compiles and warms up pmap() with outer jit()
x = dot_pjo(v1s, v2s)



In [96]:
dot_pji = jax.pmap(jax.jit(dot)) #compile and warm up with inner jit
x = dot_pji(v1s, v2s)

In [97]:
dot_p = jax.pmap(dot) #compile and warm up without explicit jit
x = dot_p(v1s, v2s)

In [98]:
%timeit dot_v(v1s, v2s).block_until_ready()

11.3 µs ± 1.69 µs per loop (mean ± std. dev. of 7 runs, 100000 loops each)


In [99]:
%timeit dot_pjo(v1s, v2s).block_until_ready()

1.46 ms ± 614 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [100]:
%timeit dot_pji(v1s, v2s).block_until_ready()

858 µs ± 11.4 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [101]:
%timeit dot_p(v1s, v2s).block_until_ready()

1.07 ms ± 334 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


##In_axes parameters

In [102]:
vs = random.normal(rng_key, shape=(16, 3))
v1s = vs[:8, :]
v2s = vs[8:, :]

In [103]:
def dot(v1, v2):
  return jnp.vdot(v1, v2)

In [104]:
dot_pmapped = jax.pmap(dot, in_axes=(0, 0))

In [105]:
dot_pmapped(v1s, v2s)

Array([ 0.16187824,  0.21840435, -1.3397005 ,  0.17735793,  0.9353221 ,
       -0.85470366,  0.25243354,  0.81216115], dtype=float32)

In [106]:
v1s.T.shape, v2s.shape

((3, 8), (8, 3))

In [107]:
jax.pmap(dot, in_axes=(1, 0)) (v1s.T, v2s)

Array([ 0.16187824,  0.21840435, -1.3397005 ,  0.17735793,  0.9353221 ,
       -0.85470366,  0.25243354,  0.81216115], dtype=float32)

In [108]:
v1s.T.shape, v2s.T.shape

((3, 8), (3, 8))

In [109]:
def scaled_dot(v1, v2, koeff):
  return koeff*jnp.vdot(v1,v2)

In [110]:
v1s_ = v1s
v2s_ = v2s.T
k = 1.0

In [111]:
v1s_.shape, v2s_.shape

((8, 3), (3, 8))

In [112]:
jax.pmap(scaled_dot, in_axes=(0, 1, None)) (v1s_, v2s_, k) #none - just copy as it is, no operation needed

Array([ 0.16187824,  0.21840435, -1.3397005 ,  0.17735793,  0.9353221 ,
       -0.85470366,  0.25243354,  0.81216115], dtype=float32)

In [113]:
def scaled_dot(data, koeff):
  return koeff * jnp.vdot(data['a'], data['b'])

In [114]:
jax.pmap(scaled_dot, in_axes=({'a':0, 'b':1}, None)) ({'a':v1s_, 'b':v2s_}, k)

Array([ 0.16187824,  0.21840435, -1.3397005 ,  0.17735793,  0.9353221 ,
       -0.85470366,  0.25243354,  0.81216115], dtype=float32)

In [115]:
def scale(v, koeff):
  return koeff*v

In [116]:
scale_pmapped = jax.pmap(scale, in_axes=(0, None), out_axes=(1))

In [117]:
res = scale_pmapped(v1s, 2.0)

In [118]:
v1s.shape, res.shape

((8, 3), (3, 8))

##Large array example

In [119]:
from jax import random

rng_key = random.PRNGKey(42)

vs = random.normal(rng_key, shape=(20_000_000, 3)) #Generates a 2-dimensional array of random numbers

v1s = vs[:10_000_000, :].T #split this array into two
v2s = vs[10_000_000:, :].T

In [120]:
v1s.shape, v2s.shape #first dimension contains components of a vector; the 2nd dimension contains vectors

((3, 10000000), (3, 10000000))

In [121]:
v1sp = v1s.reshape((v1s.shape[0], 8, v1s.shape[1]//8)) #split 2nd dimension into 2 dimensions: groups and vectors
v2sp = v1s.reshape((v2s.shape[0], 8, v2s.shape[1]//8))
v1sp.shape, v2sp.shape

((3, 8, 1250000), (3, 8, 1250000))

In [122]:
dot_parallel = jax.pmap(jax.vmap(dot, in_axes=(1,1)), in_axes=(1,1)) #vmap works on the inner(1,1) which maps over the vectors a sit doesnt see th ebatch dimensions while pmap maps over the batch dimensions so outer (1,1)

In [123]:
x_pmap = dot_parallel(v1sp, v2sp)

In [124]:
x_pmap = x_pmap.reshape((x_pmap.shape[0] * x_pmap.shape[1])) #eliminates the group dimension
x_pmap.shape

(10000000,)

In [125]:
jax.numpy.all(x_pmap == x_vmap)

Array(False, dtype=bool)

##Collective ops

In [126]:
arr = jnp.array(range(8))

In [127]:
norm = jax.pmap(lambda x: x/jax.lax.psum(x, axis_name='p'), axis_name='p')

In [128]:
norm(arr)

Array([0.        , 0.03571429, 0.07142857, 0.10714286, 0.14285715,
       0.17857143, 0.21428572, 0.25      ], dtype=float32)

In [129]:
jnp.sum(norm(arr))

Array(1., dtype=float32)

In [130]:
arr = jnp.array(range(200))

In [131]:
arr= arr.reshape(8, 25) #reshape to number of xla devices
arr.shape

(8, 25)

In [132]:
norm = jax.pmap(lambda x: x/jax.lax.psum(jnp.sum(x), axis_name='p'), axis_name='p')

In [133]:
narr = norm(arr)
narr.shape

(8, 25)

In [134]:
jnp.sum(narr)

Array(0.99999994, dtype=float32)

In [135]:
arr = jnp.array(range(200))
arr= arr.reshape(8, 25)
arr.shape

(8, 25)

In [136]:
norm = jax.pmap(lambda x: x/jax.lax.psum(jnp.sum(x), axis_name='p', axis_index_groups=[[0,1], [2,3], [4,5], [6,7]]), axis_name = 'p') #provides indices for four groups

In [137]:
narr = norm(arr)
narr.shape

(8, 25)

In [138]:
jnp.sum(narr)

Array(3.9999998, dtype=float32)

In [139]:
jnp.sum(narr[:2]), jnp.sum(narr[2:4]), jnp.sum(narr[4:6]), jnp.sum(narr[6:])

(Array(1., dtype=float32),
 Array(0.99999994, dtype=float32),
 Array(1., dtype=float32),
 Array(1.0000001, dtype=float32))

##Mixing collective ops

In [140]:
arr = jnp.array(range(200))
arr= arr.reshape(8, 25)
arr.shape

(8, 25)

In [142]:
f = jax.pmap(jax.vmap(lambda x: x/jax.lax.pmax(x, axis_name='v')/jax.lax.pmax(x, axis_name='p'), axis_name='v'), axis_name='p') #divide pmax of devices/ pmax across devices

In [143]:
f(arr)

Array([[0.        , 0.00023674, 0.00047081, 0.00070225, 0.0009311 ,
        0.00115741, 0.00138122, 0.00160256, 0.00182149, 0.00203804,
        0.00225225, 0.00246416, 0.0026738 , 0.00288121, 0.00308642,
        0.00328947, 0.0034904 , 0.00368924, 0.00388601, 0.00408076,
        0.0042735 , 0.00446429, 0.00465313, 0.00484007, 0.00502513],
       [0.00291545, 0.00301484, 0.00311311, 0.00321027, 0.00330635,
        0.00340136, 0.00349532, 0.00358825, 0.00368016, 0.00377107,
        0.003861  , 0.00394997, 0.00403798, 0.00412505, 0.00421121,
        0.00429646, 0.00438081, 0.00446429, 0.0045469 , 0.00462866,
        0.00470958, 0.00478967, 0.00486895, 0.00494743, 0.00502513],
       [0.003861  , 0.00391585, 0.00397007, 0.00402369, 0.0040767 ,
        0.00412913, 0.00418098, 0.00423225, 0.00428297, 0.00433314,
        0.00438276, 0.00443185, 0.00448042, 0.00452846, 0.004576  ,
        0.00462304, 0.00466959, 0.00471565, 0.00476124, 0.00480635,
        0.00485101, 0.0048952 , 0.00493895, 0.