In [67]:
import os

# These variables must be set before JAX initialises its backends, i.e. before
# `import jax` below; changing them after the import has no effect.
NUM_HOST_DEVICES = 8  # number of emulated CPU devices for pmap experiments

# `JAX_PLATFORMS` is the current setting; `JAX_PLATFORM_NAME` is its deprecated
# predecessor, kept so the notebook still pins CPU on older JAX versions.
os.environ['JAX_PLATFORMS'] = 'cpu'
os.environ['JAX_PLATFORM_NAME'] = 'cpu'

# Force NUM_HOST_DEVICES virtual CPU devices. Merge into any XLA_FLAGS the user
# already set (instead of overwriting them), replacing a previous device-count
# flag so re-running this cell never stacks duplicates.
_device_count_flag = '--xla_force_host_platform_device_count='
_other_xla_flags = [
    flag for flag in os.environ.get('XLA_FLAGS', '').split()
    if not flag.startswith(_device_count_flag)
]
os.environ['XLA_FLAGS'] = ' '.join(_other_xla_flags + [f'{_device_count_flag}{NUM_HOST_DEVICES}'])

import jax
import jax.numpy as jnp

In [68]:
# Fixed seed so the benchmark input is reproducible across kernel restarts.
rng = jax.random.PRNGKey(seed=0)
x = jax.random.normal(key=rng, shape=(128, 500, 500))

In [69]:
@jax.pmap
def f(x):
    """Sum every element of one device's shard.

    pmap removes the leading (device) axis before calling this, so the mapped
    result holds one partial sum per device (e.g. shape (8,) for an
    (8, 16, 500, 500) input), not the single grand total `f2` returns.
    """
    shard_total = x.sum()
    return shard_total


def f2(x):
    """Un-pmapped baseline: reduce all of `x` to one scalar sum."""
    grand_total = jnp.sum(x)
    return grand_total

In [70]:
x_split = x.reshape(8, 16, 500, 500)
f(x_split)
%timeit -n 5 -r 5 pmapped_result = f(x_split).block_until_ready()

17.9 ms ± 502 μs per loop (mean ± std. dev. of 5 runs, 5 loops each)


In [71]:
%timeit -n 5 -r 5 result = f2(x).block_until_ready()

17.2 ms ± 624 μs per loop (mean ± std. dev. of 5 runs, 5 loops each)


#### Reducing inside a pmapped function (per-device mean over axes 1 and 2)

In [82]:
@jax.pmap
def mean_over_spatial(x):
    """Average over axes 1 and 2 of one device's shard.

    pmap removes the leading (device) axis, so with the (8, 16, 256, 256, 10)
    input below each device sees a (16, 256, 256, 10) shard, reduces it to
    (16, 10), and the stacked result is (8, 16, 10). Axes 1 and 2 are presumably
    the spatial H/W dims (going by the name) -- confirm against the data layout.
    No psum/pmean is used, so results are NOT combined across devices.
    """
    return x.mean(axis=(1, 2))

In [83]:
# Derive a fresh key: `rng` was already consumed to build the benchmark array,
# and JAX's PRNG guidance is never to reuse a key for a new draw.
spatial_key = jax.random.fold_in(rng, 1)
# Leading axis = local device count so pmap maps one shard per device
# (8 with the host-device setup at the top).
# NOTE: this rebinds `x`; re-running the benchmark cells after this one would
# reshape this 5-D array instead of the (128, 500, 500) one and fail.
x = jax.random.normal(spatial_key, (jax.local_device_count(), 16, 256, 256, 10))
m = mean_over_spatial(x)

print(m.shape)

(8, 16, 10)


#### Iterator shape: stacking per-device batches for pmap

In [94]:
def generator(num_batches=1000, batch_shape=(16, 256, 256, 1)):
    """Yield `num_batches` synthetic all-ones arrays of shape `batch_shape`.

    Used below as dummy input to `device_batch`; the defaults reproduce the
    original fixed 1000 batches of shape (16, 256, 256, 1).

    Args:
        num_batches: how many arrays to yield.
        batch_shape: shape of each yielded array.

    Yields:
        `jnp.ones(batch_shape)`, created lazily one per iteration.
    """
    for _ in range(num_batches):
        yield jnp.ones(batch_shape)

In [95]:
def device_batch(batch_iterator, num_devices=None):
    """Group consecutive per-device batches into stacked batches for pmap.

    Every `num_devices` items drawn from `batch_iterator` are stacked leaf-wise
    (via `jax.tree_util.tree_map`) along a new leading axis, so each yielded
    pytree has leaves of shape (num_devices, *item_leaf_shape).

    Args:
        batch_iterator: iterable of per-device batches -- arrays or pytrees that
            all share the same structure.
        num_devices: group size; defaults to `jax.local_device_count()`.

    Yields:
        One stacked pytree per complete group of `num_devices` items.

    Raises:
        ValueError: if `num_devices` is less than 1 (raised on first iteration).

    Note:
        A trailing group with fewer than `num_devices` items is silently
        dropped, so every yielded batch has the same leading size.
    """
    if num_devices is None:
        num_devices = jax.local_device_count()
    if num_devices < 1:
        raise ValueError(f"num_devices must be >= 1, got {num_devices}")

    group = []
    for item in batch_iterator:
        group.append(item)
        if len(group) == num_devices:
            # Stack matching leaves of every item along a new leading device axis.
            yield jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *group)
            group = []

In [96]:
# Pull only the first stacked group from the lazy pipeline and check its layout.
batch_stream = device_batch(generator())
batch = next(batch_stream)
print(batch.shape)

(8, 16, 256, 256, 1)
