In [3]:
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec, NamedSharding
from jax.experimental.shard_map import shard_map
import numpy as np
import time
from jax.experimental.pjit import pjit

## FSDP-style sharding

In [9]:
mesh = Mesh(np.array(jax.devices()).reshape(2, 2), axis_names=("data", "model"))

x = jax.random.normal(jax.random.PRNGKey(0), (4, 10))
A = jax.random.normal(jax.random.PRNGKey(0), (10, 10))

sharding = NamedSharding(mesh, PartitionSpec("data", None))


# fsdp style is batch dimension sharded
x = jax.device_put(x, NamedSharding(mesh, PartitionSpec("data", None)))
A = jax.device_put(A, NamedSharding(mesh, PartitionSpec("data", None)))

def sharded_matmul(x, A):
    A = jax.lax.all_gather(A, axis_name="data", axis=0, tiled=True)
    return x @ A

sharded_matmul = shard_map(sharded_matmul, mesh=mesh, in_specs=(PartitionSpec("data", None), PartitionSpec("data", None)), out_specs=PartitionSpec("data", None))
with mesh:
    B = sharded_matmul(x, A)
    B_non_sharded = x @ A
    assert np.allclose(B, B_non_sharded)

jax.debug.visualize_array_sharding(B)

## Column-wise Tensor Parallelism for MLP

In [None]:
mesh = Mesh(np.array(jax.devices()).reshape(2, 2), axis_names=("data", "model"))

x = jax.random.normal(jax.random.PRNGKey(0), (4, 10))
W1 = jax.random.normal(jax.random.PRNGKey(0), (10, 40))
W2 = jax.random.normal(jax.random.PRNGKey(0), (40, 10))

# tp style model dimension sharded
# x = jax.device_put(x, NamedSharding(mesh, PartitionSpec(None, "model")))
# A = jax.device_put(A, NamedSharding(mesh, PartitionSpec(None, "model")))

def sharded_matmul(x, W1, W2):
    x = jax.lax.all_gather(x, axis_name="model", axis=1, tiled=True) # [B, D_y] -> [B, D]
    temp = x @ W1  # [B, D] * [D, F_y] -> [B, F_y]
    temp = temp @ W2  # [B, F_y] * [F_y, D] -> [B, D] {U_y}
    return jax.lax.psum_scatter(temp, axis_name="model", scatter_dimension=1, tiled=True)
    

sharded_matmul = shard_map(sharded_matmul, mesh=mesh, in_specs=(PartitionSpec(None, "model"), PartitionSpec(None, "model"), PartitionSpec("model", None)), out_specs=PartitionSpec(None, "model"))
with mesh:
    B = sharded_matmul(x, W1, W2)
    B_non_sharded = x @ W1 @ W2
    assert np.allclose(B, B_non_sharded)

jax.debug.visualize_array_sharding(B)

(4, 10)


## FSDP + TP style sharding

In [None]:
mesh = Mesh(np.array(jax.devices()).reshape(2, 2), axis_names=("data", "model"))

x = jax.random.normal(jax.random.PRNGKey(0), (4, 10))
W1 = jax.random.normal(jax.random.PRNGKey(0), (10, 40))
W2 = jax.random.normal(jax.random.PRNGKey(0), (40, 10))

def sharded_matmul(x, W1, W2):
    x = jax.lax.all_gather(x, axis_name="model", axis=1, tiled=True) # TP gathers activations
    W1 = jax.lax.all_gather(W1, axis_name="data", axis=0, tiled=True) # FSDP gathers weights
    W2 = jax.lax.all_gather(W2, axis_name="data", axis=1, tiled=True)
    temp = x @ W1
    out = temp @ W2    
    return jax.lax.psum_scatter(out, axis_name="model", scatter_dimension=1, tiled=True)

sharded_matmul = shard_map(sharded_matmul, mesh=mesh, in_specs=(PartitionSpec("data", "model"), PartitionSpec("data", "model"), PartitionSpec("model", "data")), out_specs=PartitionSpec("data", "model"))
sharded_matmul_pjit = pjit(sharded_matmul, in_shardings=(PartitionSpec("data", "model"), PartitionSpec("data", "model"), PartitionSpec("model", "data")), out_shardings=PartitionSpec("data", "model"))
with mesh:
    B = sharded_matmul(x, W1, W2)
    B_pjit = sharded_matmul_pjit(x, W1, W2)
    B_non_sharded = x @ W1 @ W2
    assert np.allclose(B, B_non_sharded)
    assert np.allclose(B, B_pjit)

jax.debug.visualize_array_sharding(B)

## Benchmarking communication

In [75]:
# X = jnp.arange(4).reshape(2, 2)
# print(X)
X = jax.random.normal(jax.random.PRNGKey(0), (1000, 1000))

mesh = Mesh(np.array(jax.devices()).reshape(4, ), axis_names=("data", ))
X = jax.device_put(X, NamedSharding(mesh, PartitionSpec("data", None)))

all_gather = shard_map(lambda x: jax.lax.all_gather(x, axis_name="data", axis=0, tiled=True), mesh=mesh, in_specs=PartitionSpec("data", None), out_specs=PartitionSpec(None, None), check_rep=False)
all_reduce = shard_map(lambda x: jax.lax.psum(x, axis_name="data"), mesh=mesh, in_specs=PartitionSpec(None, None), out_specs=PartitionSpec(None, None), check_rep=False)
# reduce scatter about same speed as all gather if
reduce_scatter = shard_map(lambda x: jax.lax.psum_scatter(x, axis_name="data", scatter_dimension=0, tiled=True), mesh=mesh, in_specs=PartitionSpec(None, None), out_specs=PartitionSpec("data", None), check_rep=False)
all_to_all = shard_map(lambda x: jax.lax.all_to_all(x, axis_name="data", split_axis=1, concat_axis=0, tiled=True), mesh=mesh, in_specs=PartitionSpec("data", None), out_specs=PartitionSpec(None, "data"), check_rep=False)

def benchmark_func(fn, num_iters=25):
    start = time.time()
    jax.debug.visualize_array_sharding(X)
    for _ in range(5):
        X_after_fn = fn(X)
    end = time.time()
    X_after_fn.block_until_ready()

    print(f"Warmup time: {end - start} seconds")
    start = time.time()
    for _ in range(num_iters):
        X_after_fn = fn(X)
    end = time.time()

    X_after_fn.block_until_ready()
    # print(X_after_fn)
    jax.debug.visualize_array_sharding(X_after_fn)
    print(f"Time per iteration: {(end - start) / num_iters} seconds")


with mesh:
    # print("All gather")
    # benchmark_func(all_gather)

    # print("All reduce")
    # benchmark_func(all_reduce)

    # print("Reduce scatter")
    # benchmark_func(reduce_scatter)

    print("all to all")
    benchmark_func(all_to_all)

    X_T = all_to_all(X)
    X_T_T = all_to_all(X_T)
    assert np.allclose(X, X_T_T)

# jax.debug.visualize_array_sharding(X_after_all_gather)


all to all


Warmup time: 3.9934356212615967 seconds


Time per iteration: 0.8007647323608399 seconds


In [20]:
from flax import nnx

class Attention(nnx.Module):
    def __init__(self, hidden_dim, num_heads, rngs):
        self.num_heads = num_heads
        self.Wq = nnx.Linear(
            in_features=hidden_dim,
            out_features=hidden_dim,
            use_bias=False,
            kernel_init=nnx.with_partitioning(nnx.initializers.kaiming_normal(), PartitionSpec(None, "model")),
            rngs=rngs
        )
        self.Wk = nnx.Linear(
            in_features=hidden_dim,
            out_features=hidden_dim,
            use_bias=False,
            kernel_init=nnx.with_partitioning(nnx.initializers.kaiming_normal(), PartitionSpec(None, "model")),
            rngs=rngs
        )
        self.Wv = nnx.Linear(
            in_features=hidden_dim,
            out_features=hidden_dim,
            use_bias=False,
            kernel_init=nnx.with_partitioning(nnx.initializers.kaiming_normal(), PartitionSpec(None, "model")),
            rngs=rngs
        )
        self.Wo = nnx.Linear(
            in_features=hidden_dim,
            out_features=hidden_dim,
            use_bias=False,
            kernel_init=nnx.with_partitioning(nnx.initializers.kaiming_normal(), PartitionSpec("model", None)),
            rngs=rngs
        )

    def __call__(self, x):
        bsz, seq_len, hidden_dim = x.shape

        q, k, v = self.Wq(x), self.Wk(x), self.Wv(x)
        q = q.reshape(bsz, seq_len, self.num_heads, -1)
        k = k.reshape(bsz, seq_len, self.num_heads, -1)
        v = v.reshape(bsz, seq_len, self.num_heads, -1)
        q = q.transpose(0, 2, 1, 3)
        k = k.transpose(0, 2, 1, 3)
        v = v.transpose(0, 2, 1, 3)

        q *= 1/ jnp.sqrt(hidden_dim)
        scores = jnp.einsum("bhqd, bhkd -> bhqk", q, k)
        scores = jax.nn.softmax(scores, axis=-1)
        out = jnp.einsum("bhqk, bhkd -> bhqd", scores, v)
        out = out.transpose(0, 2, 1, 3)
        out = out.reshape(bsz, seq_len, -1)
        return self.Wo(out)


class MLP(nnx.Module):
    def __init__(self, hidden_dim, rngs):
        self.W1 = nnx.Linear(
            in_features=hidden_dim, out_features=hidden_dim * 4, use_bias=False, kernel_init=nnx.with_partitioning(nnx.initializers.kaiming_normal(), PartitionSpec(None, "model")), rngs=rngs
        )
        self.W2 = nnx.Linear(
            in_features=hidden_dim * 4, out_features=hidden_dim, use_bias=False, kernel_init=nnx.with_partitioning(nnx.initializers.kaiming_normal(), PartitionSpec("model", None)), rngs=rngs
        )

    def __call__(self, x):
        return self.W2(nnx.relu(self.W1(x)))

class Layer(nnx.Module):
    def __init__(self, hidden_dim, num_heads, rngs):
        self.attention = Attention(hidden_dim, num_heads, rngs)
        self.mlp = MLP(hidden_dim, rngs)

    def __call__(self, x):
        x = x + self.attention(x)
        x = x + self.mlp(x)
        return x

mesh = Mesh(np.array(jax.devices()).reshape(2, 2), axis_names=("data", "model"))
rngs = nnx.Rngs(0)


with mesh:
    X = jax.random.normal(jax.random.PRNGKey(0), (2, 128, 1024))
    X = jax.device_put(X, NamedSharding(mesh, PartitionSpec("data", None)))
    model = nnx.jit(Layer(hidden_dim=1024, num_heads=16, rngs=rngs))
    y = model(X)
    jax.debug.visualize_array_sharding(y.reshape(2, -1))



# model(X)