In [24]:
import jax.tools.colab_tpu
# Point JAX at the Colab TPU runtime; the cells below report platform
# `tpu` and list the 8 local TPU devices once this has run.
jax.tools.colab_tpu.setup_tpu()

In [25]:
from jax.lib import xla_bridge
# Report which platform JAX selected ("cpu", "gpu" or "tpu").
# jax.default_backend() is the public API for this and returns the same
# platform string as the internal xla_bridge.get_backend().platform.
print(jax.default_backend())

tpu


In [26]:
import jax
# List the devices attached to this process; the number of entries is the
# upper bound for the size of an axis mapped with pmap.
jax.local_devices()

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

In [27]:
import jax.numpy as jnp
def dot(v: jax.Array, u: jax.Array):
    """Return the dot product of ``v`` and ``u`` as computed by ``jnp.vdot``.

    For two 1-D vectors of equal length the result is a scalar array.
    ``jnp.vdot`` flattens higher-rank inputs before multiplying (the jaxpr
    printed later in this notebook shows the reshape), so a whole 2-D batch
    passed in directly collapses to a single scalar.
    """
    product = jnp.vdot(v, u)
    return product

In [28]:
# Sanity check: the dot product of two all-ones 3-vectors should be 3.
dot(jnp.ones(shape=(3,)),jnp.ones(shape=(3,)))

DeviceArray(3., dtype=float32)

# Compare vmap and pmap

In [29]:
from jax  import random
# Draw 20M random 3-vectors from a single seeded key, then split the batch
# in half: the first 10M rows become `v`, the last 10M rows become `u`.
arr = random.normal(random.PRNGKey(42), shape=(20_000_000, 3))
v, u = arr[:10_000_000], arr[10_000_000:]
v.shape, u.shape

((10000000, 3), (10000000, 3))

## vmap

* f      --> (v,)  . (u,) => ()
* f_vmap --> (b,v) . (b,u) => (b,)

In [30]:
# vmap lifts `dot` from a single pair of vectors to a whole batch:
# row i of `v` is paired with row i of `u`, giving one scalar per row.
dot_vmap = jax.vmap(dot)
res = dot_vmap(v, u)
res.shape

(10000000,)

## pmap

* SPMD[Single Program Multiple Data]
* f --> (v,) . (u,) ==> ()
* f_pmap --> (d,v) . (d,u) => (d,)
* `d` here is the number of devices
* the result type of `pmap` is `ShardedDeviceArray`, which logically appears to be a single array but physically lives across multiple devices
* if a non-`pmap` function is called on a `ShardedDeviceArray`, the data is first moved to a single device and then the function is called



In [31]:
# pmap assigns one slice of the leading axis to each device, so mapping
# 10_000_000 rows directly is expected to fail on a host with 8 devices.
dot_pmap = jax.pmap(dot)
try:
    res = dot_pmap(v, u)
except ValueError as e:
    # The failure is the point of this cell: print the message rather than
    # letting the traceback interrupt the notebook.
    print(f"\N{Cross mark}\N{Police Cars Revolving Light}{e}")

❌🚨compiling computation that requires 10000000 logical devices, but only 8 XLA devices are available (num_replicas=10000000, num_partitions=1)


* `pmap` maps each of the 10_000_000 rows to its own device, so this call requires 10_000_000 logical devices.
* An important distinction between `vmap()` and `pmap()`: for `pmap()` the mapped axis
size must be less than or equal to the number of local XLA devices available,
as returned by `jax.local_device_count()`; `vmap()` has no such limit.

In [32]:
# Add a leading axis whose length equals the local device count so that pmap
# can hand one (rows_per_device, 3) shard to each device.
v_parallel, u_parallel = (
    batch.reshape(jax.local_device_count(), -1, 3) for batch in (v, u)
)
v_parallel.shape, u_parallel.shape

((8, 1250000, 3), (8, 1250000, 3))

In [33]:
# With one shard per device the pmap call now succeeds; each device returns
# a single scalar for its whole shard, hence the (8,) result.
res = dot_pmap(v_parallel, u_parallel)
res.shape

(8,)

In [34]:
# Print the traced program to see what each device actually runs inside
# xla_pmap for a (1250000, 3) shard.
jax.make_jaxpr(dot_pmap)(v_parallel,u_parallel)

{ lambda ; a:f32[8,1250000,3] b:f32[8,1250000,3]. let
    c:f32[8] = xla_pmap[
      axis_name=<axis 0x78999cbe96c0>
      axis_size=8
      backend=None
      call_jaxpr={ lambda ; d:f32[1250000,3] e:f32[1250000,3]. let
          f:f32[3750000] = reshape[dimensions=None new_sizes=(3750000,)] d
          g:f32[3750000] = reshape[dimensions=None new_sizes=(3750000,)] e
          h:f32[] = dot_general[
            dimension_numbers=(((0,), (0,)), ((), ()))
            precision=None
            preferred_element_type=None
          ] f g
        in (h,) }
      devices=None
      donated_invars=(False, False)
      global_arg_shapes=(None, None)
      global_axis_size=None
      in_axes=(0, 0)
      name=dot
      out_axes=(0,)
    ] a b
  in (c,) }

```python
 lambda ; d:f32[1250000,3] e:f32[1250000,3]. let
          f:f32[3750000] = reshape[dimensions=None new_sizes=(3750000,)] d
          g:f32[3750000] = reshape[dimensions=None new_sizes=(3750000,)] e
```
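* The `reshape` ops come from `vdot`: it flattens each device's `(1250000, 3)` shard into a single vector of 3750000 elements, so every device returns just one scalar. To get one dot product per row instead, wrap `dot` in `vmap` inside the `pmap`.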

In [35]:
# pmap splits the leading axis across devices; the inner vmap then maps `dot`
# over the rows of each shard, giving one result per row on every device.
dot_parallel = jax.pmap(jax.vmap(dot))
res = dot_parallel(v_parallel, u_parallel)
res.shape

(8, 1250000)

In [36]:
# Check the container type that pmap returned for the per-device results.
type(res)

jax.interpreters.pxla._ShardedDeviceArray

In [38]:
%timeit xp = dot_parallel(v_parallel,u_parallel).block_until_ready()

51.5 ms ± 1.4 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [39]:
def dot(v1, v2):
    """Return ``jnp.vdot(v1, v2)``, the dot product of the two inputs.

    This re-defines the ``dot`` helper from earlier in the notebook with the
    same behavior; the small-input comparison cells below use this version.
    """
    product = jnp.vdot(v1, v2)
    return product


In [41]:
# Small inputs for the overhead comparison: 16 random 3-vectors split into
# two groups of 8, so the mapped axis matches the 8 local devices.
rng_key = random.PRNGKey(42)
vs = random.normal(rng_key, shape=(16, 3))
v1s, v2s = vs[:8], vs[8:]


In [42]:
# Reference result: one dot product per row pair, computed with vmap.
jax.vmap(dot)(v1s,v2s)

DeviceArray([ 0.51048726, -0.7174605 , -0.20105815, -0.26437205,
             -1.3696793 ,  2.744793  ,  1.7936493 , -1.1743435 ],            dtype=float32)

In [43]:
# Same computation with pmap: the mapped axis has 8 rows, one per local
# device, and the values should match the vmap result.
jax.pmap(dot)(v1s,v2s)

ShardedDeviceArray([ 0.51048726, -0.7174605 , -0.20105815, -0.26437205,
                    -1.3696793 ,  2.744793  ,  1.7936493 , -1.1743435 ],                   dtype=float32)

In [44]:
# Variant 1: jit(vmap(dot)), which vectorizes over rows and jit-compiles the result.
dot_v = jax.jit(jax.vmap(dot))
# First call, made before the timing cell (presumably a warm-up so that
# compilation is not part of the measurement).
x = dot_v(v1s, v2s)

In [45]:
# Variant 2: jit on the outside of pmap ("pjo"). pmap already compiles its
# function, so the outer jit only adds a wrapping xla_call; it is kept here
# to measure what that costs.
dot_pjo = jax.jit(jax.pmap(dot))
# First call, made before the timing cell (presumably a warm-up).
x = dot_pjo(v1s, v2s)



In [46]:
# Variant 3: jit on the inside of pmap ("pji"), so the per-device function is
# itself wrapped in jit.
dot_pji = jax.pmap(jax.jit(dot))
# First call, made before the timing cell (presumably a warm-up).
x = dot_pji(v1s, v2s)

In [47]:
# Variant 4: plain pmap with no extra jit, as the baseline for the pmap variants.
dot_p = jax.pmap(dot)
# First call, made before the timing cell (presumably a warm-up).
x = dot_p(v1s, v2s)

In [48]:
# Time jit(vmap(dot)); block_until_ready() makes the measurement wait for the
# result instead of returning as soon as the work is dispatched.
%timeit dot_v(v1s,v2s).block_until_ready()

2.33 ms ± 260 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [49]:
# Time jit(pmap(dot)), the variant with jit wrapped around pmap.
%timeit dot_pjo(v1s,v2s).block_until_ready()


The slowest run took 4.80 times longer than the fastest. This could mean that an intermediate result is being cached.
47.4 ms ± 22.7 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [53]:
# Time pmap(jit(dot)), the variant with jit applied inside pmap.
%timeit dot_pji(v1s,v2s).block_until_ready()

37.1 ms ± 924 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [54]:
# Time plain pmap(dot) as the baseline for the two jit+pmap combinations.
%timeit dot_p(v1s,v2s).block_until_ready()

24.5 ms ± 4.04 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [55]:
# jaxpr of jit(vmap(dot)): expect a single batched dot_general inside one xla_call.
jax.make_jaxpr(dot_v)(v1s,v2s)

{ lambda ; a:f32[8,3] b:f32[8,3]. let
    c:f32[8] = xla_call[
      call_jaxpr={ lambda ; d:f32[8,3] e:f32[8,3]. let
          f:f32[8] = dot_general[
            dimension_numbers=(((1,), (1,)), ((0,), (0,)))
            precision=None
            preferred_element_type=None
          ] d e
        in (f,) }
      name=dot
    ] a b
  in (c,) }

In [56]:
# jaxpr of jit(pmap(dot)): the xla_pmap is nested inside an outer xla_call.
jax.make_jaxpr(dot_pjo)(v1s,v2s)


{ lambda ; a:f32[8,3] b:f32[8,3]. let
    c:f32[8] = xla_call[
      call_jaxpr={ lambda ; d:f32[8,3] e:f32[8,3]. let
          f:f32[8] = xla_pmap[
            axis_name=<axis 0x78999cbebbe0>
            axis_size=8
            backend=None
            call_jaxpr={ lambda ; g:f32[3] h:f32[3]. let
                i:f32[] = dot_general[
                  dimension_numbers=(((0,), (0,)), ((), ()))
                  precision=None
                  preferred_element_type=None
                ] g h
              in (i,) }
            devices=None
            donated_invars=(False, False)
            global_arg_shapes=(None, None)
            global_axis_size=None
            in_axes=(0, 0)
            name=dot
            out_axes=(0,)
          ] d e
        in (f,) }
      name=dot
    ] a b
  in (c,) }

In [57]:
# jaxpr of pmap(jit(dot)): an xla_call is nested inside the xla_pmap body.
jax.make_jaxpr(dot_pji)(v1s,v2s)

{ lambda ; a:f32[8,3] b:f32[8,3]. let
    c:f32[8] = xla_pmap[
      axis_name=<axis 0x78999f292980>
      axis_size=8
      backend=None
      call_jaxpr={ lambda ; d:f32[3] e:f32[3]. let
          f:f32[] = xla_call[
            call_jaxpr={ lambda ; g:f32[3] h:f32[3]. let
                i:f32[] = dot_general[
                  dimension_numbers=(((0,), (0,)), ((), ()))
                  precision=None
                  preferred_element_type=None
                ] g h
              in (i,) }
            name=dot
          ] d e
        in (f,) }
      devices=None
      donated_invars=(False, False)
      global_arg_shapes=(None, None)
      global_axis_size=None
      in_axes=(0, 0)
      name=dot
      out_axes=(0,)
    ] a b
  in (c,) }

In [58]:
# jaxpr of plain pmap(dot): just the xla_pmap with a per-device dot_general.
jax.make_jaxpr(dot_p)(v1s,v2s)

{ lambda ; a:f32[8,3] b:f32[8,3]. let
    c:f32[8] = xla_pmap[
      axis_name=<axis 0x78999cbebbe0>
      axis_size=8
      backend=None
      call_jaxpr={ lambda ; d:f32[3] e:f32[3]. let
          f:f32[] = dot_general[
            dimension_numbers=(((0,), (0,)), ((), ()))
            precision=None
            preferred_element_type=None
          ] d e
        in (f,) }
      devices=None
      donated_invars=(False, False)
      global_arg_shapes=(None, None)
      global_axis_size=None
      in_axes=(0, 0)
      name=dot
      out_axes=(0,)
    ] a b
  in (c,) }