# Implementing Data Parallelism in JAX

In [1]:
import argparse
from dataclasses import dataclass

import numpy as np
import jax
import jax.numpy as jnp
from jax import random, lax
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P
from flax import linen as nn
from flax.training.train_state import TrainState
import optax



In [2]:
@dataclass
class TrainCfg:
    vocab_size: int = 512
    seq_len: int = 64
    d_model: int = 256
    n_heads: int = 8
    n_layers: int = 4
    d_ff: int = 1024
    batch_size: int = 32          # global batch (will be sharded across devices)
    steps: int = 50
    lr: float = 3e-4
    seed: int = 42
    dataset_size: int = 8192      # total tokens sequences

In [3]:

def make_toy_dataset(cfg: TrainCfg, rng_seed: int = 1234):
    rs = np.random.RandomState(rng_seed)
    data = rs.randint(0, cfg.vocab_size, size=(cfg.dataset_size, cfg.seq_len), dtype=np.int32)
    return data  # np.array on host (we'll feed batches each step)

def batch_iter(data: np.ndarray, batch_size: int):
    n = (len(data) // batch_size) * batch_size
    for i in range(0, n, batch_size):
        yield data[i:i+batch_size]

In [5]:
def sinusoidal_positional_embedding(max_len, d_model):
    pos = jnp.arange(max_len)[:, None]
    i = jnp.arange(0, d_model, 2)[None, :]
    inv_freq = jnp.exp(-(jnp.log(10000.0) / d_model) * i)
    pe = jnp.zeros((max_len, d_model), dtype=jnp.float32)
    pe = pe.at[:, 0::2].set(jnp.sin(pos * inv_freq))
    pe = pe.at[:, 1::2].set(jnp.cos(pos * inv_freq))
    return pe

class Block(nn.Module):
    d_model: int
    n_heads: int
    d_ff: int

    @nn.compact
    def __call__(self, x, *, deterministic: bool = True):
        y = nn.LayerNorm()(x)
        y = lax.with_sharding_constraint(y, P("data", None))
        y = nn.SelfAttention(
            num_heads=self.n_heads,
            qkv_features=self.d_model,
            use_bias=False,
            dropout_rate=0.0,
            deterministic=deterministic,
            decode=False,
            broadcast_dropout=False,
            dtype=jnp.float32,
        )(y)
        x = x + y
        y = nn.LayerNorm()(x)
        y = nn.gelu(nn.Dense(self.d_ff)(y))
        y = nn.Dense(self.d_model)(y)
        return x + y

class TinyTransformerLM(nn.Module):
    vocab_size: int
    d_model: int
    n_heads: int
    n_layers: int
    d_ff: int
    max_seq: int

    def setup(self):
        self.tok_emb = nn.Embed(self.vocab_size, self.d_model)
        self.pos_emb = self.param(
            "pos_emb",
            lambda key: sinusoidal_positional_embedding(self.max_seq, self.d_model),
        )
        self.blocks = [Block(self.d_model, self.n_heads, self.d_ff) for _ in range(self.n_layers)]
        self.ln_f = nn.LayerNorm()
        self.head = nn.Dense(self.vocab_size, use_bias=False)

    def __call__(self, tokens, *, deterministic: bool = True):
        # tokens: [B, T]
        B, T = tokens.shape
        x = self.tok_emb(tokens) + self.pos_emb[:T, :][None, ...]
        for blk in self.blocks:
            x = blk(x, deterministic=deterministic)
        x = self.ln_f(x)
        return self.head(x)  # [B, T, V]

## Device Mesh and Sharding

In [6]:

def make_mesh():
    devices = np.array(jax.devices())
    if devices.size < 2:
        raise RuntimeError("Need at least 2 devices for data parallelism.")
    return Mesh(devices, ("data",))

def replicate_tree_to_mesh(tree, mesh):
    return jax.tree_util.tree_map(
        lambda x: jax.device_put(x, NamedSharding(mesh, P())),  # P() => replicate across 'data'
        tree
    )

def shard_with_fsdp(tree, mesh):
    def shard_array(x):
        shape = x.shape
        if len(shape) == 0:
            return jax.device_put(x, NamedSharding(mesh, P()))
        biggest_axis = int(np.argmax(shape))
        spec = [None] * len(shape)
        spec[biggest_axis] = "data"
        return jax.device_put(x, NamedSharding(mesh, P(*spec)))

    return jax.tree_util.tree_map(shard_array, tree)

def shard_batch_to_mesh(batch_np, mesh, seq_len):
    arr = jnp.asarray(batch_np, dtype=jnp.int32)
    return jax.device_put(arr, NamedSharding(mesh, P("data", None)))  # [B, T] -> shard B across 'data'

## Training

In [7]:

def lm_loss_from_logits(logits, tokens):
    # next-token prediction
    targets = tokens[:, 1:]                   # [B, T-1]
    logit_slice = logits[:, :-1, :]           # [B, T-1, V]
    V = logit_slice.shape[-1]
    onehot = jax.nn.one_hot(targets, V, dtype=jnp.float32)
    logp = jax.nn.log_softmax(logit_slice, axis=-1)
    # mean over batch and time
    return -(onehot * logp).sum(-1).mean()

def create_train_state(rng, cfg: TrainCfg):
    model = TinyTransformerLM(
        vocab_size=cfg.vocab_size, d_model=cfg.d_model, n_heads=cfg.n_heads,
        n_layers=cfg.n_layers, d_ff=cfg.d_ff, max_seq=cfg.seq_len
    )
    # Mock input for shape inference
    dummy = jnp.zeros((8, cfg.seq_len), dtype=jnp.int32)
    params = model.init(rng, dummy, deterministic=True)["params"]
    tx = optax.adamw(cfg.lr)
    return TrainState(step=0, apply_fn=model.apply, params=params, tx=tx, opt_state=tx.init(params)), model


def build_step_fn(cfg: TrainCfg, model: TinyTransformerLM):
    def step(state: TrainState, tokens_sharded: jnp.ndarray):
        tokens_local = lax.with_sharding_constraint(tokens_sharded, P("data", None))

        def loss_fn(params):
            logits = state.apply_fn({"params": params}, tokens_local, deterministic=True)
            loss = lm_loss_from_logits(logits, tokens_local)
            return loss, logits

        (loss, _), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)

        new_state = state.apply_gradients(grads=grads)
        return new_state, loss

    step_jit = jax.jit(step, donate_argnums=(0,))
    return step_jit

def run(cfg: TrainCfg):
    rng = random.PRNGKey(cfg.seed)
    data = make_toy_dataset(cfg, rng_seed=1234)
    mesh = make_mesh()

    with jax.set_mesh(mesh):
        state, model = create_train_state(rng, cfg)    

    with jax.set_mesh(mesh):
        # For FSDP, it's as simple as annotating params and opt_state
        # state = state.replace(
        #     params=shard_with_fsdp(state.params, mesh),
        #     opt_state=shard_with_fsdp(state.opt_state, mesh),
        # )
        state = state.replace(
            params=replicate_tree_to_mesh(state.params, mesh),
            opt_state=replicate_tree_to_mesh(state.opt_state, mesh),
        )
        step_jit = build_step_fn(cfg, model)

        # Training
        it = batch_iter(data, cfg.batch_size)
        for s in range(cfg.steps):
            batch = next(it)
            tokens_sharded = shard_batch_to_mesh(batch, mesh, cfg.seq_len)

            state, loss = step_jit(state, tokens_sharded)
            if s % 10 == 0 or s == cfg.steps - 1:
                print(f"[jax-dp] step {s:04d}  loss {float(loss):.4f}")

    params_host = jax.tree_util.tree_map(lambda x: np.array(x.addressable_data(0)), state.params)
    return params_host


def train(steps: int = 50, batch_size: int = 32, seq_len: int = 64, vocab_size: int = 512, seed: int = 42):
    cfg = TrainCfg(
        vocab_size=vocab_size,
        seq_len=seq_len,
        batch_size=batch_size,
        steps=steps,
        seed=seed,
    )
    run(cfg)

train()

[jax-dp] step 0000  loss 6.7112
[jax-dp] step 0010  loss 6.3130
[jax-dp] step 0020  loss 6.2615
[jax-dp] step 0030  loss 6.2571
[jax-dp] step 0040  loss 6.2508
[jax-dp] step 0049  loss 6.2567
