In [3]:
import os
# Expose only physical GPUs 0, 2 and 3 to this process. This cell comes before
# the `import jax` cell so the restriction is in place when JAX looks for
# devices; JAX then lists the three visible GPUs with ids 0-2.
os.environ['CUDA_VISIBLE_DEVICES'] = '0,2,3'

In [4]:
# Device parallelism for SPMD in JAX

In [7]:
import jax
import jax.numpy as jnp

In [6]:
jax.devices()

[CudaDevice(id=0), CudaDevice(id=1), CudaDevice(id=2)]

In [8]:
a = jnp.arange(5)

In [9]:
a.device

CudaDevice(id=0)

In [14]:
# Place this array explicitly on the second visible device instead of the
# default one (compare with `a`, which lives on device 0).
b = jnp.array([2,3,4], device=jax.devices()[1])

In [15]:
b.device

CudaDevice(id=1)

In [17]:
# When creating jax arrays from scratch, you also need to create its sharding

In [18]:
a.sharding

SingleDeviceSharding(device=CudaDevice(id=0), memory_kind=device)

In [20]:
jax.debug.visualize_array_sharding(a)

In [21]:
jax.debug.visualize_array_sharding(b)

To create an array with a non-trivial sharding, you can define a jax.sharding specification for the array and pass this to jax.device_put().

In [22]:
from jax.sharding import PartitionSpec

In [23]:
# 1-D device mesh with a single axis named 'x' spanning 3 devices.
mesh = jax.make_mesh((3,), ('x',))
# PartitionSpec('x',) splits an array's leading dimension across mesh axis 'x';
# any remaining dimensions are left unsplit.
sharding = jax.sharding.NamedSharding(mesh, PartitionSpec('x',))
print(sharding)

NamedSharding(mesh=Mesh('x': 3), spec=PartitionSpec('x',), memory_kind=device)


In [27]:
arr = jnp.arange(24.0).reshape(3, 8)

In [28]:
arr.shape

(3, 8)

In [29]:
arr_sharded = jax.device_put(arr, sharding)

In [30]:
# Now, arr_sharded's data is split across 3 GPUs.

In [31]:
# For an array that spans several devices, `.device` does not name a single
# device: as the output below shows, it reports the array's sharding.
arr_sharded.device

NamedSharding(mesh=Mesh('x': 3), spec=PartitionSpec('x',), memory_kind=device)

In [32]:
jax.debug.visualize_array_sharding(arr_sharded)

## 1. Automatic parallelism via JIT

Once you have partitioned your data, the easiest way to do parallel computation is to simply pass the data to a jax.jit-compiled function. In JAX, you only need to specify how you want your input and output to be partitioned, and the compiler will figure out how to (1) partition everything inside, and (2) compile the inter-device communication.

In [33]:
@jax.jit
def f(x):
    """Return ``2 * sin(x) + 1`` computed element-wise over ``x``."""
    doubled_sine = jnp.sin(x) * 2
    return doubled_sine + 1

In [36]:
result = f(arr_sharded)

In [37]:
jax.debug.visualize_array_sharding(result)

`f` is an element-wise function, so each accelerator operates only on the shard it holds and the output is sharded the same way as the input.

In [39]:
arr_sharded.sharding == result.sharding

True

In [40]:
arr_sharded.shape

(3, 8)

In [48]:
@jax.jit
def f(x):
    """Sum ``x`` over its leading axis (axis 0)."""
    column_totals = jnp.sum(x, axis=0)
    return column_totals

In [49]:
result = f(arr_sharded)

In [50]:
jax.debug.visualize_array_sharding(result)

In [53]:
# The result is replicated: summing over the sharded axis requires the GPUs to communicate, and each one ends up holding the full result

In [54]:
arr.sum(axis=1)

Array([ 28.,  92., 156.], dtype=float32)

In [55]:
@jax.jit
def f(x):
    """Sum ``x`` over axis 1, leaving the leading axis intact."""
    row_totals = jnp.sum(x, axis=1)
    return row_totals
# Apply the axis-1 sum to the sharded array and display how the output is
# laid out across devices.
result = f(arr_sharded)
jax.debug.visualize_array_sharding(result)

In [56]:
# In this case, each GPU holds only its own part of the result (the sum of its own row), so the output stays sharded

## 2. Semi-automated sharding with constraints

In [57]:
# If you want some control over the output sharding, you can constrain it as follows:

In [73]:
@jax.jit
def f(x):
    out = x.sum(axis=0)
    mesh = jax.make_mesh((3,), ('x',))
    sharding = jax.sharding.NamedSharding(mesh, PartitionSpec('x'))
    return jax.lax.with_sharding_constraint(out, sharding)

In [74]:
result = f(arr_sharded)

In [75]:
jax.debug.visualize_array_sharding(result)

## 3. Manual Partitioning with shard_map

jax.experimental.shard_map.shard_map() works by mapping a function across a particular mesh of devices

In [76]:
from jax.experimental.shard_map import shard_map

In [77]:
mesh = jax.make_mesh((3,), ('x',))

In [88]:
# Per-shard function: it is written as if it only ever sees one shard.
@jax.jit
def f_elementwise(x):
    """Print the shape of ``x`` at run time and return ``2 * sin(x) + 1``.

    The shape is printed with ``jax.debug.print`` so that, when the function
    is mapped over a mesh, the shape each shard sees shows up in the output.
    """
    shape_text = str(x.shape)
    jax.debug.print(shape_text)
    return jnp.sin(x) * 2 + 1

In [89]:
# Wrap the per-shard function so it is mapped over `mesh`:
#   in_specs  - the input's leading axis is split across mesh axis 'x', so
#               each call of f_elementwise receives one slice of the input;
#   out_specs - the per-shard outputs are stitched back together along the
#               same axis.
f_elementwise_sharded = shard_map(
    f_elementwise, 
    mesh=mesh,
    in_specs=PartitionSpec('x'),
    out_specs = PartitionSpec('x')
)

In [90]:
# NOTE: this rebinds `arr` (previously the (3, 8) float array) to a 1-D
# integer array of 24 elements; `arr_sharded` still refers to the sharded
# copy of the earlier (3, 8) array.
arr = jnp.arange(24)

In [91]:
f_elementwise_sharded(arr)

(8,)
(8,)
(8,)


Array([ 1.        ,  2.682942  ,  2.818595  ,  1.28224   , -0.513605  ,
       -0.9178486 ,  0.44116902,  2.3139732 ,  2.9787164 ,  1.824237  ,
       -0.08804214, -0.99998045, -0.07314587,  1.8403342 ,  2.9812148 ,
        2.3005757 ,  0.42419332, -0.92279494, -0.50197446,  1.2997544 ,
        2.8258905 ,  2.6733112 ,  0.98229736, -0.69244087], dtype=float32)

In [92]:
# As it can be observed, the function only sees "its" shard.

In [93]:
# Therefore, aggregation operations like sum operate independently. If you want it across devices, you need to use something like jax.lax.psum

In [96]:
def f(x):
    """Sum the local block ``x``, then add up those sums across axis 'x'.

    Must be called where the axis name 'x' is bound (e.g. under shard_map
    over a mesh with an 'x' axis), because it uses ``jax.lax.psum``.
    """
    local_total = jnp.sum(x)
    return jax.lax.psum(local_total, 'x')

# in_specs=PartitionSpec('x') splits the input's leading axis across mesh axis
# 'x'; out_specs=PartitionSpec() declares the psum'd scalar as not split along
# any mesh axis, i.e. every device holds the same total.
result = shard_map(f, mesh=mesh, in_specs=PartitionSpec('x'), out_specs=PartitionSpec())(arr_sharded)

In [97]:
result

Array(276., dtype=float32)

In [99]:
result.sharding

NamedSharding(mesh=Mesh('x': 3), spec=PartitionSpec(), memory_kind=device)

In [102]:
# Using jax.jit, if you shard the leading axis of both x and weights in the same way, then the matrix multiplication will automatically happen in parallel

In [103]:
# Alternatively, you can use jax.lax.with_sharding_constraint() in the function to automatically distribute unsharded inputs