In [17]:
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec, NamedSharding
from jax.experimental.shard_map import shard_map
import numpy as np

In [21]:
# Build a 2x2 device mesh named ("data", "model").
# The reshape requires exactly four visible devices.
mesh = Mesh(np.array(jax.devices()).reshape(2, 2), axis_names=("data", "model"))

# Derive one subkey per array: reusing a single key for several draws makes
# the resulting arrays statistically dependent instead of independent.
x_key, A_key = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(x_key, (4, 10))
A = jax.random.normal(A_key, (10, 10))

# Axis 0 is split across the "data" mesh axis; axis 1 is left unsplit.
sharding = NamedSharding(mesh, PartitionSpec("data", None))


# fsdp style: the batch and the weight matrix are both sharded along axis 0
x = jax.device_put(x, sharding)
A = jax.device_put(A, sharding)

def sharded_matmul(x, A):
    """Multiply the local block of ``x`` by the fully gathered ``A``.

    Intended to run per device inside a mapping that binds the ``"data"``
    axis name (this notebook uses ``shard_map``).

    Args:
        x: The block of the left operand held by this device.
        A: The slice of the right operand held by this device; the slices
            from every device on ``"data"`` are concatenated along axis 0.

    Returns:
        The product of the local ``x`` with the reassembled ``A``.
    """
    # tiled=True concatenates the slices along axis 0 instead of stacking
    # them on a new leading axis.
    gathered_A = jax.lax.all_gather(A, axis_name="data", axis=0, tiled=True)
    return jnp.matmul(x, gathered_A)

# Turn the per-device function into one that operates on the global arrays:
# axis 0 of both inputs and of the output is split over the "data" mesh axis.
sharded_matmul = shard_map(
    sharded_matmul,
    mesh=mesh,
    in_specs=(
        PartitionSpec("data", None),
        PartitionSpec("data", None),
    ),
    out_specs=PartitionSpec("data", None),
)
with mesh:
    B = sharded_matmul(x, A)

# Show which devices hold which part of the result.
jax.debug.visualize_array_sharding(B)

In [41]:
# Build a 2x2 device mesh named ("data", "model").
# The reshape requires exactly four visible devices.
mesh = Mesh(np.array(jax.devices()).reshape(2, 2), axis_names=("data", "model"))

# Derive one subkey per array: reusing a single key for several draws makes
# the resulting arrays statistically dependent instead of independent.
x_key, A_key = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(x_key, (4, 10))
A = jax.random.normal(A_key, (10, 10))

# tp style: the model (column) dimension is sharded.
# Explicit placement is left disabled here; the arrays are handed to shard_map
# as created. Uncomment to place them up front:
# x = jax.device_put(x, NamedSharding(mesh, PartitionSpec(None, "model")))
# A = jax.device_put(A, NamedSharding(mesh, PartitionSpec(None, "model")))

def sharded_matmul(x, A):
    """Reassemble ``x`` along its columns, then multiply by the local ``A``.

    Intended to run per device inside a mapping that binds the ``"model"``
    axis name (this notebook uses ``shard_map``).

    Args:
        x: The column slice of the left operand held by this device; slices
            from every device on ``"model"`` are concatenated along axis 1.
        A: The block of the right operand held by this device.

    Returns:
        The product of the reassembled ``x`` with the local ``A``.
    """
    # tiled=True concatenates along axis 1 rather than adding a new axis.
    x_full = jax.lax.all_gather(x, axis_name="model", axis=1, tiled=True)
    return x_full @ A
    

# Turn the per-device function into one that operates on the global arrays:
# axis 1 of both inputs and of the output is split over the "model" mesh axis.
sharded_matmul = shard_map(
    sharded_matmul,
    mesh=mesh,
    in_specs=(
        PartitionSpec(None, "model"),
        PartitionSpec(None, "model"),
    ),
    out_specs=PartitionSpec(None, "model"),
)
with mesh:
    B = sharded_matmul(x, A)

# Show which devices hold which part of the result.
jax.debug.visualize_array_sharding(B)

In [49]:
# Build a 2x2 device mesh named ("data", "model").
# The reshape requires exactly four visible devices.
mesh = Mesh(np.array(jax.devices()).reshape(2, 2), axis_names=("data", "model"))

# Derive one subkey per array: reusing a single key for several draws makes
# the resulting arrays statistically dependent instead of independent.
x_key, W1_key, W2_key = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(x_key, (4, 10))
W1 = jax.random.normal(W1_key, (10, 40))
W2 = jax.random.normal(W2_key, (40, 10))

# tp style: the model dimension is sharded (columns of x and W1, rows of W2).
# Explicit placement is left disabled here; the arrays are handed to shard_map
# as created. Uncomment to place them up front:
# x = jax.device_put(x, NamedSharding(mesh, PartitionSpec(None, "model")))
# W1 = jax.device_put(W1, NamedSharding(mesh, PartitionSpec(None, "model")))
# W2 = jax.device_put(W2, NamedSharding(mesh, PartitionSpec("model", None)))

def sharded_matmul(x, W1, W2):
    """Apply two chained matmuls per device, then reduce-scatter the result.

    Intended to run per device inside a mapping that binds the ``"model"``
    axis name (this notebook uses ``shard_map``). In the shape comments a
    ``_y`` suffix marks a dimension split over that axis and ``{U_y}`` marks
    per-device partial sums that have not been reduced yet.

    Args:
        x: Column slice of the input held by this device.
        W1: Block of the first weight matrix held by this device.
        W2: Block of the second weight matrix held by this device.

    Returns:
        This device's slice (along axis 1) of the partial products summed
        over ``"model"``.
    """
    # [B, D_y] -> [B, D]: concatenate the column slices from every device.
    x_full = jax.lax.all_gather(x, axis_name="model", axis=1, tiled=True)
    # [B, D] @ [D, F_y] -> [B, F_y]
    hidden = x_full @ W1
    # [B, F_y] @ [F_y, D] -> [B, D] {U_y}
    partial_out = hidden @ W2
    # Sum the partial results across devices and keep one slice of axis 1.
    return jax.lax.psum_scatter(
        partial_out, axis_name="model", scatter_dimension=1, tiled=True
    )
    

# Turn the per-device function into one that operates on the global arrays:
# x and W1 are split along axis 1, W2 along axis 0, all over the "model" mesh
# axis; the output is split along axis 1.
sharded_matmul = shard_map(
    sharded_matmul,
    mesh=mesh,
    in_specs=(
        PartitionSpec(None, "model"),
        PartitionSpec(None, "model"),
        PartitionSpec("model", None),
    ),
    out_specs=PartitionSpec(None, "model"),
)
with mesh:
    B = sharded_matmul(x, W1, W2)

# Show which devices hold which part of the result.
jax.debug.visualize_array_sharding(B)