<a href="https://colab.research.google.com/github/Peter-obi/JAX/blob/main/Pmap.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import os

# Ask XLA to expose 8 virtual CPU devices so pmap has several devices to map over.
# XLA picks up XLA_FLAGS when its backend starts, so set it before importing jax.
# Append instead of assigning so flags already present in the environment are
# kept, and skip the append when the cell is re-run.
_HOST_DEVICES_FLAG = '--xla_force_host_platform_device_count=8'
_xla_flags = os.environ.get('XLA_FLAGS', '')
if _HOST_DEVICES_FLAG not in _xla_flags:
  os.environ['XLA_FLAGS'] = f"{_xla_flags} {_HOST_DEVICES_FLAG}".strip()

import jax
print(f"Device count: {jax.device_count()}")

Device count: 8


In [2]:
import jax.numpy as jnp

def dot(v1, v2):
  """Dot product of ``v1`` and ``v2``, computed with ``jnp.vdot``.

  ``jnp.vdot`` flattens both arguments before multiplying, so inputs with
  more than one dimension still collapse to a single scalar.
  """
  result = jnp.vdot(v1, v2)
  return result

In [3]:
# Sanity check on a single pair: 1*1 + 1*2 + 1*(-1) = 2.
dot(jnp.array([1., 1., 1]), jnp.array([1., 2., -1]))

Array(2., dtype=float32)

In [4]:
from jax import random

rng_key = random.PRNGKey(42)

# One draw of standard-normal values with shape (20_000_000, 3): 20M 3-vectors.
vs = random.normal(rng_key, shape=(20_000_000, 3))

# The first 10M rows are the left operands, the remaining 10M the right operands.
v1s, v2s = vs[:10_000_000], vs[10_000_000:]

In [5]:
# Each half should hold 10M vectors of length 3.
v1s.shape, v2s.shape

((10000000, 3), (10000000, 3))

In [6]:
# vmap maps dot over the leading axis of both inputs (one dot product per pair
# of rows); jit compiles the vectorized function.
dot_batched = jax.jit(jax.vmap(dot))

x_vmap = dot_batched(v1s, v2s)

# One scalar per pair of vectors.
x_vmap.shape

(10000000,)

In [11]:
dot_parallel = jax.pmap(dot)

# pmap places one slice of the leading axis on each device, so a leading axis
# of 10,000,000 asks for 10,000,000 devices. The failure is the point of this
# cell: catch it and print the message so that "Run All" keeps going.
try:
  x_pmap = dot_parallel(v1s, v2s)
except ValueError as err:
  print(f"{type(err).__name__}: {err}")

ValueError: compiling computation that requires 10000000 logical devices, but only 8 XLA devices are available (num_replicas=10000000)

In [12]:
# pmap places one slice of the leading axis on each device, so cut the batch
# into as many chunks as there are local devices instead of hard-coding 8.
# -1 lets reshape work out the per-device chunk length; the batch size must be
# divisible by the device count.
n_devices = jax.local_device_count()
v1sp = v1s.reshape((n_devices, -1, v1s.shape[1]))
v2sp = v2s.reshape((n_devices, -1, v2s.shape[1]))

In [13]:
# (chunks, vectors per chunk, vector length)
v1sp.shape

(8, 1250000, 3)

In [14]:
# dot_parallel is still pmap(dot) here: each device receives a whole chunk of
# vectors and jnp.vdot flattens it, so the result holds one scalar per chunk
# rather than one per pair of vectors.
x_pmap = dot_parallel(v1sp, v2sp)
x_pmap.shape

(8,)

In [15]:
# Rebind dot_parallel with vmap inside pmap: pmap spreads the chunks over the
# devices, and vmap computes one dot product per pair of vectors within a chunk.
dot_parallel = jax.pmap(jax.vmap(dot))
x_pmap = dot_parallel(v1sp, v2sp)

In [16]:
# Now (chunks, dot products per chunk).
x_pmap.shape

(8, 1250000)

In [17]:
# Check what kind of object pmap returned.
type(x_pmap)

jaxlib._jax.ArrayImpl

In [18]:
# Merge the mapped (device) axis back into the batch axis. reshape(-1) flattens
# whatever shape x_pmap currently has, so this cell can be re-run safely even
# though it rebinds x_pmap.
x_pmap = x_pmap.reshape(-1)
x_pmap.shape

(10000000,)

## Pmap = vmap

In [19]:
# 16 random 3-vectors: small enough that each of the 8 devices gets one pair.
# NOTE: rng_key was already used for the earlier draw; JAX recommends
# splitting a key rather than reusing it.
vs = random.normal(rng_key, shape=(16, 3))

# The first 8 rows are the left operands, the last 8 the right operands.
v1s, v2s = vs[:8], vs[8:]

In [20]:
jax.vmap(dot)(v1s, v2s) # vmap: vectorize dot over the leading axis (8 pairs)

Array([ 0.16187823,  0.21840437, -1.3397005 ,  0.17735794,  0.93532217,
       -0.85470366,  0.25243354,  0.81216115], dtype=float32)

In [21]:
jax.pmap(dot)(v1s, v2s) # pmap: run dot on 8 devices in parallel, one pair per device

Array([ 0.16187824,  0.21840435, -1.3397005 ,  0.17735793,  0.9353221 ,
       -0.85470366,  0.25243354,  0.81216115], dtype=float32)

In [22]:
dot_v = jax.jit(jax.vmap(dot)) # jit-compiled vmap version
x = dot_v(v1s, v2s) # warm-up call: compilation happens here, not inside the timing cell

In [23]:
dot_pjo = jax.jit(jax.pmap(dot)) # pmap wrapped in an outer jit
x = dot_pjo(v1s, v2s) # warm-up call: compilation happens here, not inside the timing cell

In [24]:
dot_pji = jax.pmap(jax.jit(dot)) # pmap around an inner jit
x = dot_pji(v1s, v2s) # warm-up call: compilation happens here, not inside the timing cell

In [25]:
dot_p = jax.pmap(dot) # plain pmap, no explicit jit (pmap compiles the function itself)
x = dot_p(v1s, v2s) # warm-up call: compilation happens here, not inside the timing cell

In [26]:
# Time jit(vmap(dot)). block_until_ready() waits for JAX's asynchronous dispatch
# to finish, so the measurement covers the computation itself.
%timeit dot_v(v1s, v2s).block_until_ready()

11.3 µs ± 2.27 µs per loop (mean ± std. dev. of 7 runs, 100000 loops each)


In [27]:
# Time jit(pmap(dot)); block_until_ready() makes the timing include the computation.
%timeit dot_pjo(v1s, v2s).block_until_ready()

984 µs ± 82.5 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [28]:
# Time pmap(jit(dot)); block_until_ready() makes the timing include the computation.
%timeit dot_pji(v1s, v2s).block_until_ready()

1.12 ms ± 374 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [29]:
# Time plain pmap(dot); block_until_ready() makes the timing include the computation.
%timeit dot_p(v1s, v2s).block_until_ready()

1.03 ms ± 353 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


## The in_axes parameter

In [30]:
vs = random.normal(rng_key, shape=(16, 3))
v1s = vs[:8, :]
v2s = vs[8:, :]