[Deep learning with JAX](https://github.com/che-shr-cat/JAX-in-Action)

In [1]:
import os

# Number of virtual CPU devices XLA should expose. XLA reads XLA_FLAGS when the
# backend is first initialized, so this must be set before any JAX computation.
N_HOST_DEVICES = 8
_device_flag = f"--xla_force_host_platform_device_count={N_HOST_DEVICES}"
_existing_flags = os.environ.get("XLA_FLAGS", "")
# Append instead of overwriting so flags already present in the environment
# survive; the membership check keeps re-runs of this cell from duplicating it.
if _device_flag not in _existing_flags.split():
    os.environ["XLA_FLAGS"] = f"{_existing_flags} {_device_flag}".strip()

import jax
from jax import random
from jax.lib import xla_bridge  # NOTE(review): no longer used here; jax.default_backend() is the public API
jax.config.update("jax_enable_x64", True)
jax.config.update("jax_platform_name", "cpu")
print("number of cores:", jax.local_device_count())
print("using: ", jax.default_backend())

number of cores: 8
using:  cpu


In [2]:
import jax.numpy as jnp


def dot(v1, v2):
    """Dot product of ``v1`` and ``v2``; a thin wrapper around ``jnp.dot``."""
    product = jnp.dot(v1, v2)
    return product


first_vector = jnp.array([1., 1., 1.])
second_vector = jnp.array([1., 2., -1])
dot(first_vector, second_vector)

Array(2., dtype=float64)

In [3]:
# Generate two large batches of 3-D vectors from a single random tensor:
# the first half becomes v1s, the second half v2s.
N_PAIRS = 10_000_000
rng_key = random.PRNGKey(42)
vs = random.normal(rng_key, (2 * N_PAIRS, 3))

v1s, v2s = vs[:N_PAIRS, :], vs[N_PAIRS:, :]
v1s.shape, v2s.shape

((10000000, 3), (10000000, 3))

In [4]:
# vmap maps `dot` over axis 0 of both inputs; jit compiles the batched version.
_dot_vmapped = jax.vmap(dot, in_axes=(0, 0))
dot_batched = jax.jit(_dot_vmapped)
x_vmap = dot_batched(v1s, v2s)
x_vmap.shape

(10000000,)

In [5]:
# pmap maps over the leading axis and needs one device per element of it.
# v1s has 10_000_000 rows but only jax.local_device_count() devices exist, so
# this call is expected to fail. Show the error instead of leaving the line
# commented out, so "Restart & Run All" still completes.
dot_parallel = jax.pmap(dot)
try:
    x_pmap = dot_parallel(v1s, v2s)
except ValueError as err:
    print(f"{type(err).__name__}: {err}")

In [6]:
# Split each batch into one group per device so pmap can place one group on
# every device: (n_vectors, 3) -> (n_device, n_vectors // n_device, 3).
# The device count is queried instead of hard-coding 8.
n_device = jax.local_device_count()
assert v1s.shape[0] % n_device == 0, "number of vectors must be divisible by the device count"
assert v2s.shape[0] % n_device == 0, "number of vectors must be divisible by the device count"
v1sp = v1s.reshape(n_device, v1s.shape[0] // n_device, v1s.shape[1])
v2sp = v2s.reshape(n_device, v2s.shape[0] // n_device, v2s.shape[1])
v1sp.shape, v2sp.shape

((8, 1250000, 3), (8, 1250000, 3))

In [13]:
# Intentionally left commented out: dot_parallel here is still jax.pmap(dot),
# so each device would receive a (1250000, 3) block, and jnp.dot of two such
# matrices has mismatched contraction dimensions (3 vs 1250000), i.e. a shape
# error. The following cell fixes this by nesting vmap inside pmap.
# x_pmap = dot_parallel(v1sp, v2sp)
# x_pmap.shape

In [39]:
# Nest vmap inside pmap: pmap spreads the leading (group) axis over devices,
# and vmap maps `dot` over the vectors within each device's group.
dot_parallel = jax.pmap(jax.vmap(dot))
x_pmap = dot_parallel(v1sp, v2sp)
print(x_pmap.shape, type(x_pmap))

# Flatten (groups, vectors_per_group) back into one vector of results.
x_pmap = x_pmap.reshape(-1)
print(x_pmap.shape)
same_results = jnp.allclose(x_vmap, x_pmap)
print("VMAP and PMAP have the same results:", same_results)

(8, 1250000) <class 'jaxlib.xla_extension.ArrayImpl'>
(10000000,)
VMAP and PMAP have the same results: True


In [40]:
# Measuring the difference between vmap() and pmap().
# Four variants that differ only in where jit is applied:
#   dot_v   - jit around vmap (single device, flat inputs)
#   dot_pji - jit applied to `dot` inside pmap(vmap(...))
#   dot_pjo - jit wrapped around the whole pmap(...)
#   dot_vj  - vmap applied to an already-jitted `dot`
# NOTE: JAX warns against jit-of-pmap (dot_pjo) because the outer jit does not
# keep the data sharded across devices; it is by far the slowest variant in the
# recorded timings.
dot_v = jax.jit(jax.vmap(dot))

dot_pji = jax.pmap(jax.vmap(jax.jit(dot)))
dot_pjo = jax.jit(jax.pmap(jax.vmap(dot)))
dot_vj = jax.vmap(jax.jit(dot))

# Warm-up: compile each variant once so %timeit measures execution only.
# JAX dispatches work asynchronously, so wait for each warm-up run to finish;
# otherwise pending work could spill into the first timed loop.
dot_pjo(v1sp, v2sp).block_until_ready()
dot_pji(v1sp, v2sp).block_until_ready()
dot_v(v1s, v2s).block_until_ready()
dot_vj(v1s, v2s).block_until_ready();



In [36]:
# Every timed call ends with block_until_ready() so the measurement includes the
# actual computation, not just the asynchronous dispatch.
# NOTE(review): this cell's execution count (In [36]) is lower than that of the
# cell defining these functions (In [40]); confirm it passes Restart & Run All.
# jit(vmap(dot)) on flat (N, 3) inputs
%timeit dot_v(v1s, v2s).block_until_ready()
# jit wrapped around pmap(vmap(dot)) on grouped (devices, N/devices, 3) inputs
%timeit dot_pjo(v1sp, v2sp).block_until_ready()
# pmap(vmap(jit(dot))) on grouped inputs
%timeit dot_pji(v1sp, v2sp).block_until_ready()
# vmap over jit(dot) on flat inputs
%timeit dot_vj(v1s, v2s).block_until_ready()

40.7 ms ± 757 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
648 ms ± 22.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
123 ms ± 1.24 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
40.7 ms ± 269 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [60]:
# Explicitly using the in_axes parameter
vs = random.normal(rng_key, (16, 3))
v1s, v2s = vs[:8, :], vs[8:, :]

print(v1s.shape, v2s.shape)

# NOTE: this redefines the earlier jnp.dot-based `dot` for all later cells.
def dot(v1, v2):
    """Dot product of two vectors computed with ``jnp.vdot``."""
    result = jnp.vdot(v1, v2)
    return result

# All three calls compute the same values: in_axes names the axis pmap splits
# across devices, so a transposed input is mapped over axis 1 instead of 0.
for axes, operands in (((0, 0), (v1s, v2s)),
                       ((1, 0), (v1s.T, v2s)),
                       ((1, 1), (v1s.T, v2s.T))):
    print(jax.pmap(dot, in_axes=axes)(*operands)[:3])

(8, 3) (8, 3)
[-0.74235083  0.64188436 -4.86064789]
[-0.74235083  0.64188436 -4.86064789]
[-0.74235083  0.64188436 -4.86064789]


In [58]:
# using the in_axes parameter for broadcasting
def scaled_dot(v1, v2, koeff):
    """Return ``koeff`` times the vdot of ``v1`` and ``v2``."""
    dot_value = jnp.vdot(v1, v2)
    return koeff * dot_value

# v1s_ keeps the (8, 3) layout; v2s_ is transposed to (3, 8), so it must be
# mapped over axis 1 later on.
v1s_, v2s_ = v1s, v2s.T
k = 1.0
v1s_.shape, v2s_.shape

((8, 3), (3, 8))

In this example the last parameter of the function, the scaling coefficient, is not split across devices: `in_axes=None` replicates the same value on every device.

In [63]:
# Map v1s_ over axis 0 and v2s_ over axis 1; None broadcasts k to every device.
scaled_dot_batched = jax.pmap(scaled_dot, in_axes=(0, 1, None))
batched_result = scaled_dot_batched(v1s_, v2s_, k)
print(batched_result[:3])

[-0.74235083  0.64188436 -4.86064789]


In [71]:
# using the in_axes parameter with a Python container
def scaled_dot(data, koeff):
    """Scaled vdot of ``data['a']`` and ``data['b']`` (dict-argument variant).

    NOTE: this redefines the three-argument ``scaled_dot`` from an earlier cell.
    """
    first, second = data['a'], data['b']
    return koeff * jnp.vdot(first, second)

# in_axes mirrors the argument structure: 'a' is mapped over axis 0, 'b' over
# axis 1, and koeff is broadcast.
scaled_dot_pmapped = jax.pmap(scaled_dot, in_axes=({'a': 0, 'b': 1}, None))
operands = {'a': v1s_, 'b': v2s_}
scaled_dot_pmapped(operands, k)[:3]

Array([-0.74235083,  0.64188436, -4.86064789], dtype=float64)

In [70]:
# using the out_axes parameter
def scale(v, koeff):
    """Multiply ``v`` by ``koeff``."""
    scaled = v * koeff
    return scaled

# out_axes=1 stacks the per-device results along axis 1, so the (8, 3) input
# comes back as a (3, 8) result.
scaled_pmapped = jax.pmap(scale, in_axes=(0, None), out_axes=1)
res = scaled_pmapped(v1s, 2.0)
expected = v1s*2.0
v1s.shape, res.shape, jnp.allclose(res.T, expected)

((8, 3), (3, 8), Array(True, dtype=bool))

In [74]:
# Large array example: regenerate the same 20M random vectors as before (same
# rng_key and shape), so results stay comparable with x_vmap. This time they are
# stored transposed, with the 3 vector components along axis 0: (3, N).
vs = random.normal(rng_key, (20_000_000, 3))
n_pairs = vs.shape[0] // 2  # was the literal 10_000000 (misleading digit grouping)
v1s = vs[:n_pairs, :].T
v2s = vs[n_pairs:, :].T
v1s.shape, v2s.shape

((3, 10000000), (3, 10000000))

In [82]:
n_device = jax.device_count()


def _split_columns_by_device(a):
    """Reshape (dims, n) into (dims, n_device, n // n_device); axis 1 becomes the group axis."""
    return a.reshape(a.shape[0], n_device, a.shape[1] // n_device)


v1sp = _split_columns_by_device(v1s)
v2sp = _split_columns_by_device(v2s)
v1sp.shape, v2sp.shape

((3, 8, 1250000), (3, 8, 1250000))

**A**: Asks vmap to map over the second dimension (vmap does not see the group dimension, which pmap has already consumed, so its second dimension is the one that indexes the individual vectors)  
**B**: Asks pmap to map over the second (group) dimension

In [83]:
# pmap splits axis 1 (the group axis) across devices (B); on each device, vmap
# then maps over axis 1 of the remaining (3, vectors) block (A).
inner_vmapped = jax.vmap(dot, in_axes=(1, 1))           # A
dot_parallel = jax.pmap(inner_vmapped, in_axes=(1, 1))  # B
x_pmap = dot_parallel(v1sp, v2sp)
print(x_pmap.shape)
x_pmap = x_pmap.ravel()
print(x_pmap.shape)
print(jnp.allclose(x_vmap, x_pmap))

(8, 1250000)
(10000000,)
True


[Parallel operators](https://jax.readthedocs.io/en/latest/jax.lax.html#parallel-operators)

In [91]:
# Using collective ops and axis_name: psum over the named axis 'p' adds up the
# values held by all devices. pmap places one element per device, so the array
# length must not exceed the device count (otherwise pmap raises an error).
# The length is therefore taken from the device count instead of hard-coding 8.
arr = jnp.arange(jax.local_device_count())
norm = jax.pmap(lambda x: x / jax.lax.psum(x, axis_name='p'), axis_name='p')(arr)
print(norm, jnp.sum(norm))

[0.         0.03571429 0.07142857 0.10714286 0.14285714 0.17857143
 0.21428571 0.25      ] 1.0


In [96]:
# Normalization example for a larger array: each device holds one row, sums it
# locally with jnp.sum, and psum combines the partial sums from all devices.
values = jnp.arange(200)
arr = values.reshape(n_device, values.shape[0] // n_device)
print(arr.shape)


def _normalize_row(x):
    total = jax.lax.psum(jnp.sum(x), axis_name='p')
    return x / total


norm = jax.pmap(_normalize_row, axis_name='p')
narr = norm(arr)
print(narr.shape, jnp.sum(narr))


(8, 25)
(8, 25) 1.0


In [99]:
# Normalization by groups: axis_index_groups restricts psum to pairs of
# neighbouring devices ([0, 1], [2, 3], ...), so each pair is normalized on its
# own and the values of every pair sum to 1. The groups are built from the
# device count instead of being hard-coded for exactly 8 devices.
assert n_device % 2 == 0, "pairwise groups need an even number of devices"
device_pairs = [[i, i + 1] for i in range(0, n_device, 2)]

arr = jnp.arange(200)
arr = arr.reshape(n_device, arr.shape[0] // n_device)

norm = jax.pmap(lambda x: x / jax.lax.psum(
    jnp.sum(x),
    axis_name='p',
    axis_index_groups=device_pairs),
    axis_name='p')
narr = norm(arr)
print(narr.shape, jnp.sum(narr))
print(*(jnp.sum(narr[i:i + 2]) for i in range(0, n_device, 2)))

(8, 25) 3.9999999999999996
1.0 1.0 1.0000000000000002 1.0000000000000002


- Nested maps: mixing collective ops

Here we combine `pmap()` and `vmap()` with two collective operations over different axes: one `pmax()` operates within the array located on each device (the `vmap()` axis), and another `pmax()` operates across devices (the `pmap()` axis).  

**A**: A function using two collective ops across different axes


In [104]:
values = jnp.arange(200)
arr = values.reshape(n_device, values.shape[0] // n_device)
print(arr.shape)


def _max_ratio(x):  # A
    # "v" names the vmap axis (within one device), "p" the pmap axis (devices).
    within_device = jax.lax.pmax(x, axis_name="v")
    across_devices = jax.lax.pmax(x, axis_name="p")
    return within_device / across_devices


f = jax.pmap(jax.vmap(_max_ratio, axis_name="v"), axis_name="p")
print(f(arr).shape)

(8, 25)
(8, 25)


We can rewrite the global normalization example from Listing 7.19 to get rid of the
`jnp.sum()` call on each device by replacing it with a global sum that also includes the batch axis
introduced by `vmap()`. To do this, we pass both axis names at once, `axis_name=('p','v')`, to the
`psum()` function, so the sum is computed across both axes simultaneously.


In [108]:
values = jnp.arange(200)
arr = values.reshape(n_device, values.shape[0] // n_device)
print(arr.shape)


def _global_normalize(x):
    # Summing over both named axes ("p" = devices, "v" = vmap batch) yields the
    # total of every element, so no per-device jnp.sum is needed.
    grand_total = jax.lax.psum(x, axis_name=("p", "v"))
    return x / grand_total


norm = jax.pmap(jax.vmap(_global_normalize, axis_name="v"), axis_name="p")
narr = norm(arr)
print(narr.shape, jnp.sum(narr))

(8, 25)
(8, 25) 1.0


In [110]:
# Nested pmap example for small arrays: the outer pmap maps the 2 rows and the
# inner pmap maps the 4 columns; psum over both axis names sums all 8 values.
def _normalize_all(x):
    return x / jax.lax.psum(x, axis_name=('rows', 'cols'))


_cols_pmapped = jax.pmap(_normalize_all, axis_name='cols')
arr = jnp.arange(8).reshape(2, 4)
n = jax.pmap(_cols_pmapped, axis_name='rows')(arr)

print(jnp.sum(n))

1.0


In [111]:
# Nested pmap example, written with decorators: functools.partial fixes the
# axis_name for each pmap level (outer = 'rows', inner = 'cols').

from functools import partial


@partial(jax.pmap, axis_name='rows')
@partial(jax.pmap, axis_name='cols')
def n(x):
    """Divide ``x`` by its sum over both the 'rows' and 'cols' mapped axes."""
    total = jax.lax.psum(x, axis_name=('rows', 'cols'))
    return x / total


arr = jnp.arange(8).reshape(2, 4)
normalized = n(arr)
print(jnp.sum(normalized))

1.0
