<a href="https://colab.research.google.com/github/Peter-obi/JAX/blob/main/Pmap.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import os
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'
import jax
print(f"Device count: {jax.device_count()}")

Device count: 8


In [2]:
import jax.numpy as jnp

def dot(v1, v2):
  return jnp.vdot(v1, v2)

In [3]:
dot(jnp.array([1., 1., 1]), jnp.array([1., 2., -1]))

Array(2., dtype=float32)

In [4]:
from jax import random

rng_key = random.PRNGKey(42)

vs = random.normal(rng_key, shape=(20_000_000, 3)) #Generates a 2-dimensional array of random numbers

v1s = vs[:10_000_000, :] #split this array into two
v2s = vs[10_000_000:, :]

In [5]:
v1s.shape, v2s.shape

((10000000, 3), (10000000, 3))

In [6]:
dot_batched = jax.jit(jax.vmap(dot))

x_vmap = dot_batched(v1s, v2s)

x_vmap.shape

(10000000,)

In [7]:
dot_parallel = jax.pmap(dot)

x_pmap = dot_parallel(v1s, v2s)

ValueError: compiling computation that requires 10000000 logical devices, but only 8 XLA devices are available (num_replicas=10000000)

In [8]:
v1sp = v1s.reshape((8, v1s.shape[0]//8, v1s.shape[1])) #reshape into eight chunks
v2sp = v2s.reshape((8, v2s.shape[0]//8, v2s.shape[1]))

In [9]:
v1sp.shape

(8, 1250000, 3)

In [10]:
x_pmap = dot_parallel(v1sp, v2sp)
x_pmap.shape

(8,)

In [11]:
dot_parallel = jax.pmap(jax.vmap(dot))
x_pmap = dot_parallel(v1sp, v2sp)

In [12]:
x_pmap.shape

(8, 1250000)

In [13]:
type(x_pmap)

jaxlib._jax.ArrayImpl

In [14]:
x_pmap = x_pmap.reshape((x_pmap.shape[0] * x_pmap.shape[1])) #eliminates the mapping axis, this is a logical reshape to a flattened array
x_pmap.shape

(10000000,)

##Pmap = vmap

In [15]:
vs = random.normal(rng_key, shape=(16, 3)) #generate small arrays to fit on 8 devices

v1s = vs[:8, :]
v2s = vs[8:, :]

In [16]:
jax.vmap(dot)(v1s, v2s) #autovectorizes the function

Array([ 0.16187823,  0.21840437, -1.3397005 ,  0.17735794,  0.93532217,
       -0.85470366,  0.25243354,  0.81216115], dtype=float32)

In [17]:
jax.pmap(dot)(v1s, v2s) #parallelize the function

Array([ 0.16187824,  0.21840435, -1.3397005 ,  0.17735793,  0.9353221 ,
       -0.85470366,  0.25243354,  0.81216115], dtype=float32)

In [18]:
dot_v = jax.jit(jax.vmap(dot)) #compiles and warms up jitted version
x = dot_v(v1s, v2s)

In [19]:
dot_pjo = jax.jit(jax.pmap(dot)) #compiles and warms up pmap() with outer jit()
x = dot_pjo(v1s, v2s)



In [20]:
dot_pji = jax.pmap(jax.jit(dot)) #compile and warm up with inner jit
x = dot_pji(v1s, v2s)

In [21]:
dot_p = jax.pmap(dot) #compile and warm up without explicit jit
x = dot_p(v1s, v2s)

In [22]:
%timeit dot_v(v1s, v2s).block_until_ready()

12.9 µs ± 5.05 µs per loop (mean ± std. dev. of 7 runs, 100000 loops each)


In [23]:
%timeit dot_pjo(v1s, v2s).block_until_ready()

1.49 ms ± 850 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [24]:
%timeit dot_pji(v1s, v2s).block_until_ready()

793 µs ± 16.7 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [25]:
%timeit dot_p(v1s, v2s).block_until_ready()

970 µs ± 331 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


##In_axes parameters

In [26]:
vs = random.normal(rng_key, shape=(16, 3))
v1s = vs[:8, :]
v2s = vs[8:, :]

In [27]:
def dot(v1, v2):
  return jnp.vdot(v1, v2)

In [28]:
dot_pmapped = jax.pmap(dot, in_axes=(0, 0))

In [29]:
dot_pmapped(v1s, v2s)

Array([ 0.16187824,  0.21840435, -1.3397005 ,  0.17735793,  0.9353221 ,
       -0.85470366,  0.25243354,  0.81216115], dtype=float32)

In [30]:
v1s.T.shape, v2s.shape

((3, 8), (8, 3))

In [31]:
jax.pmap(dot, in_axes=(1, 0)) (v1s.T, v2s)

Array([ 0.16187824,  0.21840435, -1.3397005 ,  0.17735793,  0.9353221 ,
       -0.85470366,  0.25243354,  0.81216115], dtype=float32)

In [32]:
v1s.T.shape, v2s.T.shape

((3, 8), (3, 8))

In [33]:
def scaled_dot(v1, v2, koeff):
  return koeff*jnp.vdot(v1,v2)

In [34]:
v1s_ = v1s
v2s_ = v2s.T
k = 1.0

In [35]:
v1s_.shape, v2s_.shape

((8, 3), (3, 8))

In [36]:
jax.pmap(scaled_dot, in_axes=(0, 1, None)) (v1s_, v2s_, k) #none - just copy as it is, no operation needed

Array([ 0.16187824,  0.21840435, -1.3397005 ,  0.17735793,  0.9353221 ,
       -0.85470366,  0.25243354,  0.81216115], dtype=float32)

In [37]:
def scaled_dot(data, koeff):
  return koeff * jnp.vdot(data['a'], data['b'])

In [38]:
jax.pmap(scaled_dot, in_axes=({'a':0, 'b':1}, None)) ({'a':v1s_, 'b':v2s_}, k)

Array([ 0.16187824,  0.21840435, -1.3397005 ,  0.17735793,  0.9353221 ,
       -0.85470366,  0.25243354,  0.81216115], dtype=float32)

In [39]:
def scale(v, koeff):
  return koeff*v

In [40]:
scale_pmapped = jax.pmap(scale, in_axes=(0, None), out_axes=(1))

In [41]:
res = scale_pmapped(v1s, 2.0)

In [42]:
v1s.shape, res.shape

((8, 3), (3, 8))

##Large array example

In [43]:
from jax import random

rng_key = random.PRNGKey(42)

vs = random.normal(rng_key, shape=(20_000_000, 3)) #Generates a 2-dimensional array of random numbers

v1s = vs[:10_000_000, :].T #split this array into two
v2s = vs[10_000_000:, :].T

In [44]:
v1s.shape, v2s.shape #first dimension contains components of a vector; the 2nd dimension contains vectors

((3, 10000000), (3, 10000000))

In [45]:
v1sp = v1s.reshape((v1s.shape[0], 8, v1s.shape[1]//8)) #split 2nd dimension into 2 dimensions: groups and vectors
v2sp = v1s.reshape((v2s.shape[0], 8, v2s.shape[1]//8))
v1sp.shape, v2sp.shape

((3, 8, 1250000), (3, 8, 1250000))

In [46]:
dot_parallel = jax.pmap(jax.vmap(dot, in_axes=(1,1)), in_axes=(1,1)) #vmap works on the inner(1,1) which maps over the vectors a sit doesnt see th ebatch dimensions while pmap maps over the batch dimensions so outer (1,1)

In [47]:
x_pmap = dot_parallel(v1sp, v2sp)

In [48]:
x_pmap = x_pmap.reshape((x_pmap.shape[0] * x_pmap.shape[1])) #eliminates the group dimension
x_pmap.shape

(10000000,)

In [49]:
jax.numpy.all(x_pmap == x_vmap)

Array(False, dtype=bool)

##Collective ops

In [51]:
arr = jnp.array(range(8))

In [52]:
norm = jax.pmap(lambda x: x/jax.lax.psum(x, axis_name='p'), axis_name='p')

In [53]:
norm(arr)

Array([0.        , 0.03571429, 0.07142857, 0.10714286, 0.14285715,
       0.17857143, 0.21428572, 0.25      ], dtype=float32)

In [54]:
jnp.sum(norm(arr))

Array(1., dtype=float32)

In [62]:
arr = jnp.array(range(200))

In [63]:
arr= arr.reshape(8, 25) #reshape to number of xla devices
arr.shape

(8, 25)

In [64]:
norm = jax.pmap(lambda x: x/jax.lax.psum(jnp.sum(x), axis_name='p'), axis_name='p')

In [65]:
narr = norm(arr)
narr.shape

(8, 25)

In [66]:
jnp.sum(narr)

Array(0.99999994, dtype=float32)