<a href="https://colab.research.google.com/github/ShawonAshraf/annotated-jax/blob/main/playground/jax_spmd.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import jax
jax.devices()

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

In [2]:
import jax.numpy as jnp

x = jnp.arange(32).reshape(4,8)
x.devices()

{TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)}

In [3]:
x.sharding

SingleDeviceSharding(device=TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), memory_kind=device)

In [4]:
jax.debug.visualize_array_sharding(x)

In [5]:
from jax.sharding import PartitionSpec as P

In [6]:
mesh_shape = (2, 4) # must equal to the number of devices when multiplied

mesh = jax.make_mesh(mesh_shape, ('x', 'y'))
sharding = jax.sharding.NamedSharding(mesh, P('x', 'y'))
print(sharding)

print(f"x before sharding: {x}")

sharded_x = jax.device_put(x, sharding)
print(f"x after sharding: {sharded_x}")

jax.debug.visualize_array_sharding(sharded_x)


NamedSharding(mesh=Mesh('x': 2, 'y': 4), spec=PartitionSpec('x', 'y'), memory_kind=device)
x before sharding: [[ 0  1  2  3  4  5  6  7]
 [ 8  9 10 11 12 13 14 15]
 [16 17 18 19 20 21 22 23]
 [24 25 26 27 28 29 30 31]]
x after sharding: [[ 0  1  2  3  4  5  6  7]
 [ 8  9 10 11 12 13 14 15]
 [16 17 18 19 20 21 22 23]
 [24 25 26 27 28 29 30 31]]


In [7]:
sharded_x.devices()

{TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)}

but how to know which data went to which device?

the compiler decides on that, you just need to provide the sharding strategy

In [8]:
@jax.jit
def f_elementwise(x):
  return 2 * jnp.sin(x) + 1

result = f_elementwise(sharded_x)
result

Array([[ 1.        ,  2.682942  ,  2.818595  ,  1.28224   , -0.513605  ,
        -0.9178486 ,  0.44116896,  2.3139732 ],
       [ 2.9787164 ,  1.824237  , -0.08804226, -0.99998045, -0.07314599,
         1.8403342 ,  2.9812148 ,  2.3005757 ],
       [ 0.42419332, -0.92279506, -0.50197446,  1.2997544 ,  2.8258905 ,
         2.6733112 ,  0.98229736, -0.69244075],
       [-0.81115675,  0.7352965 ,  2.525117  ,  2.912752  ,  1.5418116 ,
        -0.32726777, -0.97606325,  0.19192469]], dtype=float32)

In [9]:
jax.debug.visualize_array_sharding(result)

In [10]:
@jax.jit
def f_contract(x):
  return x.sum(axis=0)

result = f_contract(sharded_x)
jax.debug.visualize_array_sharding(result)
print(result)

[48 52 56 60 64 68 72 76]


jit decides how to manage inter device communication and how to shard data.

that's for the automatic sharding. You can still do semi-auto and manual sharding with sharding strategies

In [11]:
# manual sharding

from jax.experimental.shard_map import shard_map

manual_mesh = jax.make_mesh((8, ), ("x", ))
manual_sharding = jax.sharding.NamedSharding(manual_mesh, P('x'))

msharded_x = jax.device_put(result, manual_sharding)
jax.debug.visualize_array_sharding(msharded_x)

In [12]:
x = jnp.arange(32)
def f(x):
  return jnp.sum(x, keepdims=True)

shard_map(f, mesh=manual_mesh, in_specs=P('x'), out_specs=P('x'))(x)

Array([  6,  22,  38,  54,  70,  86, 102, 118], dtype=int32)

In [25]:
import numpy as np

# a quick batching attempt
batch_size = 32
ds_size = 16 * batch_size
dataset = np.array([ np.arange(8) for _ in range(ds_size) ])
print(dataset.shape)

def data_loader(dataset=dataset, batch_size=batch_size):
  # shuffle
  indices = np.arange(len(dataset))
  np.random.shuffle(indices)
  shuffled_dataset = dataset[indices]

  for i in range(0, len(shuffled_dataset), batch_size):
    yield jnp.array(dataset[i:i+batch_size])

# ================
for batch in data_loader():
  print(batch.shape)
  print(batch.devices())
  break

(512, 8)
(32, 8)
{TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)}


In [26]:
# shard dataloader
mesh_shape = (2, 4)
mes = jax.make_mesh(mesh_shape, ('x', 'y'))
sharding = jax.sharding.NamedSharding(mes, P('x', 'y'))

n_shards = mesh_shape[0] * mesh_shape[1]
print(f"{n_shards=}")
shard_size = len(dataset) // n_shards
print(f"{shard_size=}")

n_shards=8
shard_size=64


In [27]:
shards = [ dataset[i:i+shard_size] for i in range(0, ds_size, shard_size) ]
print(f"{len(shards)=}")

len(shards)=8


In [44]:
def sharded_data_loader(shard_index):
    shard = shards[shard_index]
    # Use your existing data loader to process the shard
    for batch in data_loader(shard):
        device_list = list(sharding.addressable_devices)
        device = device_list[shard_index]
        yield jax.device_put(batch, device)

# ===========
for i in range(n_shards):
    for batch in sharded_data_loader(i):
        print(batch.shape)
        print(batch.devices())


(32, 8)
{TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0)}
(32, 8)
{TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0)}
(32, 8)
{TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0)}
(32, 8)
{TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0)}
(32, 8)
{TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0)}
(32, 8)
{TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0)}
(32, 8)
{TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1)}
(32, 8)
{TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1)}
(32, 8)
{TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1)}
(32, 8)
{TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1)}
(32, 8)
{TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)}
(32, 8)
{TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)}
(32, 8)
{TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)}
(32, 8)
{TpuDevice(id=7, 

The mesh grid is 2x4, hence each device gets two batches here of size 32. (64 shard size, 32 + 32)