# Playing with JAX parallelization.

In [1]:
import os

# Ask XLA to expose the host CPU as 8 separate devices so sharding can be
# explored without accelerators. This has to run before JAX initialises its
# backends, which is why it sits in the very first cell.
os.environ.update({"XLA_FLAGS": '--xla_force_host_platform_device_count=8'})

In [2]:
import jax
# Show the devices JAX can see; the recorded output lists eight CpuDevice
# entries (ids 0-7), matching the device count forced via XLA_FLAGS.
jax.devices()

[CpuDevice(id=0),
 CpuDevice(id=1),
 CpuDevice(id=2),
 CpuDevice(id=3),
 CpuDevice(id=4),
 CpuDevice(id=5),
 CpuDevice(id=6),
 CpuDevice(id=7)]

In [3]:
import jax.numpy as jnp

# Build a 4x8 float array holding 0..31 and inspect where it lives before any
# explicit sharding is applied (the recorded output shows a single device).
arr = jnp.reshape(jnp.arange(32.0), (4, 8))
print(arr.devices(), arr.sharding, sep="\n")
jax.debug.visualize_array_sharding(arr)

{CpuDevice(id=0)}
SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=unpinned_host)


In [5]:
from jax.sharding import PartitionSpec as P

# Arrange the devices as a 2x4 grid with named axes, then split the array's
# first dimension over 'x' and its second dimension over 'y'.
mesh = jax.make_mesh(axis_shapes=(2, 4), axis_names=('x', 'y'))
sharding = jax.sharding.NamedSharding(mesh, P('x', 'y'))
sharded_arr = arr.to_device(sharding)

# Same inspection as for the unsharded array, now showing the distributed layout.
print(sharded_arr.devices(), sharded_arr.sharding, sep="\n")
jax.debug.visualize_array_sharding(sharded_arr)

{CpuDevice(id=7), CpuDevice(id=0), CpuDevice(id=4), CpuDevice(id=3), CpuDevice(id=2), CpuDevice(id=5), CpuDevice(id=6), CpuDevice(id=1)}
NamedSharding(mesh=Mesh('x': 2, 'y': 4), spec=PartitionSpec('x', 'y'), memory_kind=unpinned_host)


In [31]:
def forward(activations, weights_1, weights_2):
  """Run a two-layer network: dot -> ReLU -> dot -> softmax.

  Args:
    activations: input array fed to the first layer.
    weights_1: first-layer weights, combined with `activations` via `jnp.dot`.
    weights_2: second-layer weights, combined with the hidden layer via `jnp.dot`.

  Returns:
    The softmax of the second layer's output (taken over the last axis, the
    `jax.nn.softmax` default), i.e. normalised probabilities rather than raw
    scores.
  """
  hidden = jax.nn.relu(jnp.dot(activations, weights_1))
  return jax.nn.softmax(jnp.dot(hidden, weights_2))

def loss_fn(logits, labels=jnp.array([1] + 9*[0])):
  """Mean cross-entropy between predicted probabilities and target labels.

  Args:
    logits: despite the name, these values are passed straight to `jnp.log`,
      so they are treated as probabilities (e.g. a softmax output). The sum
      runs over axis 1, so the array needs at least two dimensions.
    labels: target weights, broadcast against `logits`. The default is a
      one-hot vector selecting class 0 out of 10.

  Returns:
    Scalar: the negated mean over rows of `sum(labels * log(logits), axis=1)`.
  """
  # Where the label is 0 the term contributes nothing, but evaluating
  # 0 * log(0) = 0 * -inf gives NaN and poisons the whole loss (and its
  # gradient) as soon as any probability underflows to exactly 0. Feeding 1.0
  # into the log at those positions makes the term exactly 0 instead.
  contributes = labels != 0
  safe_probs = jnp.where(contributes, logits, 1.0)
  return -jnp.mean(jnp.sum(labels * jnp.log(safe_probs), axis=1))

@jax.jit
def train_step(activations, weights_1, weights_2):
  """Jitted forward pass used as the benchmarked step.

  Args:
    activations: input batch passed to `forward`.
    weights_1: first-layer weights passed to `forward`.
    weights_2: second-layer weights passed to `forward`.

  Returns:
    The output of `forward` for the given inputs.
  """
  # The previous version also computed `jax.grad(loss_fn)(logits)` into a local
  # that was never used or returned. That gradient was taken with respect to
  # the network outputs rather than the weights, and an unused value inside a
  # jitted function is presumably dropped as dead code, so it only added
  # tracing work. To obtain real parameter gradients, differentiate a function
  # of the weights (loss_fn applied to forward(...)) and return the result.
  return forward(activations, weights_1, weights_2)

key = jax.random.key(1337)
activations = jax.random.normal(key, (1024, 8192))
weights_1 = jax.random.normal(key, (8192, 8192))
weights_2 = jax.random.normal(key, (8192, 10))
%timeit train_step(activations, weights_1, weights_2).block_until_ready()

mesh = jax.make_mesh((8,), ('batch',))
sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('batch'))
activations = activations.to_device(sharding)
weights_1 = weights_1.to_device(jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec()))
weights_2 = weights_2.to_device(jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec()))
%timeit train_step(activations, weights_1, weights_2).block_until_ready()

529 ms ± 38.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
651 ms ± 72.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
