In [2]:
import jax

# List the devices exposed by JAX's default backend (two CUDA GPUs in the
# recorded output); the cell displays the list.
jax.devices()

[cuda(id=0), cuda(id=1)]

In [3]:
import jax.numpy as jnp

# The values 0..31 laid out as a 4 x 8 matrix. It is created without an
# explicit sharding, so it lives on a single device (see the output below).
arr = jnp.reshape(jnp.arange(32.0), (4, 8))
arr.devices()

{cuda(id=0)}

In [4]:
# `arr` was created without an explicit sharding, so it reports a
# SingleDeviceSharding on one device.
arr.sharding

SingleDeviceSharding(device=cuda(id=0))

In [5]:
# Draw which device holds which part of `arr` (here: the whole array on one device).
# The rendered table is not captured in this text export.
# NOTE(review): visualize_array_sharding renders via the `rich` package — confirm it is installed.
jax.debug.visualize_array_sharding(arr)

In [6]:
# Build a 2-D device mesh with axes ("x", "y") and a sharding that splits array
# axis 0 over mesh axis "x" and array axis 1 over mesh axis "y".
# (Pardon the boilerplate; constructing a sharding will become easier in the future!)
from jax.sharding import Mesh
from jax.sharding import PartitionSpec
from jax.sharding import NamedSharding
from jax.experimental import mesh_utils

P = PartitionSpec
# A single row of devices: "x" has size 1 and "y" spans every visible device.
# This matches the recorded (1, 2) mesh on 2 GPUs and also runs on other device
# counts.
# NOTE: an array sharded along "y" must still have that dimension divisible by
# the device count.
devices = mesh_utils.create_device_mesh((1, jax.device_count()))
mesh = Mesh(devices, ("x", "y"))
sharding = NamedSharding(mesh, P("x", "y"))
print(sharding)

NamedSharding(mesh=Mesh('x': 1, 'y': 2), spec=PartitionSpec('x', 'y'))


In [7]:
# Split `arr` across the devices of `sharding`'s mesh. The data is *sharded*
# (each device holds a piece), not shared, so the name says so.
arr_sharded = jax.device_put(arr, sharding)
arr_shared = arr_sharded  # backward-compatible alias for the original, misspelled name
print(arr_sharded)
jax.debug.visualize_array_sharding(arr_sharded)

[[ 0.  1.  2.  3.  4.  5.  6.  7.]
 [ 8.  9. 10. 11. 12. 13. 14. 15.]
 [16. 17. 18. 19. 20. 21. 22. 23.]
 [24. 25. 26. 27. 28. 29. 30. 31.]]


In [8]:
@jax.jit
def layer(x, weights, bias):
    """Apply one dense layer followed by a logistic sigmoid.

    The function is traced and compiled by ``jax.jit``.

    Args:
        x: input; its last axis must match the first axis of ``weights``.
        weights: matrix applied on the right, as in ``x @ weights``.
        bias: added to the matrix product (standard broadcasting).

    Returns:
        ``sigmoid(x @ weights + bias)`` as a JAX array.
    """
    logits = x @ weights + bias
    return jax.nn.sigmoid(logits)

In [9]:
import numpy as np

# Seeded generator so the random inputs (and the displayed output) are reproducible.
rng = np.random.default_rng(0)

# Tuple items are evaluated left to right, so the generator is consumed in the
# same order as three separate assignments would consume it.
x, weights, bias = (
    rng.normal(size=(32,)),     # input vector
    rng.normal(size=(32, 4)),   # 32 inputs -> 4 outputs
    rng.normal(size=(4,)),      # one bias per output
)

layer(x, weights, bias)

Array([0.02138916, 0.8931117 , 0.5989196 , 0.9774251 ], dtype=float32)

In [10]:
# Build a 1-D mesh over every device on a single axis "x". Then split the
# contraction dimension of both `x` and `weights` across it; `bias` is passed
# as-is.
# NOTE: this rebinds `mesh` and `sharding` from the 2-D example above. Re-running
# that earlier cell after this one would pick up the 1-D sharding instead.
P = jax.sharding.PartitionSpec
mesh = jax.sharding.Mesh(jax.devices(), ("x",))
sharding = jax.sharding.NamedSharding(mesh, P("x"))

# A single Sharding is a tree prefix of the tuple, so it applies to both leaves.
x_sharded, weights_sharded = jax.device_put((x, weights), sharding)

# Same result as the unsharded call; the recorded outputs differ only in the
# last printed digit.
layer(x_sharded, weights_sharded, bias)

Array([0.02138915, 0.8931118 , 0.5989196 , 0.9774251 ], dtype=float32)

In [11]:
# Draw how the 32 elements of `x_sharded` are split across the devices of mesh axis "x".
jax.debug.visualize_array_sharding(x_sharded)

In [13]:
# 1-D mesh over every device, with a single axis named "data".
# The trailing comma matters: ("data") is just the string "data", not a tuple
# of axis names.
mesh = jax.sharding.Mesh(jax.devices(), ("data",))
data_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec("data"))