In [24]:
import os

# Point JAX's compilation cache at a directory on disk. This is set before
# `import jax` below (presumably so JAX reads it at import time — TODO confirm).
os.environ["JAX_COMPILATION_CACHE_DIR"] = "/tmp/jax_compilation_cache"

import time
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
import optax
from flax.training import train_state
from jax.sharding import PartitionSpec as P

import jaxpp.api

# Short alias; the pipelined model calls it after every group of layers that
# forms one pipeline stage.
pipeline_yield = jaxpp.api.pipeline_enter_stage


In [25]:
# 1-D device mesh over the first four devices, with a single "fsdp" axis.
mesh = jax.sharding.Mesh(
    np.array(jax.devices()[:4]),
    ("fsdp",),
)
mesh

Mesh(device_ids=array([0, 1, 2, 3]), axis_names=('fsdp',), axis_types=(Auto,))

### Model Definition

In [26]:
import flax.linen as nn

# Define model parameters
NUM_MB = 4  # number of microbatches each batch is split into by `train_step`
BATCH_SIZE = 8  # leading (batch) dimension of `dummy_input`
SEQ_LEN = 2 * 1024  # sequence length of `dummy_input`
D_MODEL = 1024  # model width (trailing dimension of `dummy_input`)
MLP_DIM = 4 * D_MODEL  # hidden width of each block's feed-forward layer
NUM_LAYERS = 8  # number of `Block`s stacked in the transformer


class Block(nn.Module):
    """Transformer block: self-attention followed by a GELU feed-forward layer.

    Each sub-layer is added back to its input (residual connection) and the
    sum is passed through `nn.LayerNorm`. The output keeps the input's
    trailing dimension.
    """

    @nn.remat
    @nn.jit
    @nn.compact
    def __call__(self, x):
        # Self-attention sub-layer (8 heads), residual add, then LayerNorm.
        # NOTE(review): `qkv_features` is D_MODEL // 8 rather than D_MODEL —
        # confirm this projection width is intended.
        self_attention = nn.MultiHeadDotProductAttention(
            num_heads=8, qkv_features=D_MODEL // 8
        )
        x = x + self_attention(x)
        x = nn.LayerNorm()(x)

        # Feed-forward sub-layer: expand to MLP_DIM, apply GELU, project back
        # to the input width, residual add, then LayerNorm.
        hidden = nn.Dense(features=MLP_DIM)(x)
        hidden = nn.gelu(hidden)
        hidden = nn.Dense(features=x.shape[-1])(hidden)
        x = nn.LayerNorm()(x + hidden)

        return x


class Transformer(nn.Module):
    """Stack of `NUM_LAYERS` `Block`s applied one after another."""

    @nn.compact
    def __call__(self, x):
        hidden = x
        for layer_idx in range(NUM_LAYERS):
            # Explicit names give each block a fixed parameter key
            # ("block_0", "block_1", ...).
            layer = Block(name=f"block_{layer_idx}")
            hidden = layer(hidden)
        return hidden


# Example all-ones input and the un-pipelined model instance. The parameters
# themselves are created later, inside `init_state`.
dummy_input = jnp.ones((BATCH_SIZE, SEQ_LEN, D_MODEL))
transformer = Transformer()


def init_state(model):
    """Create a fresh `TrainState` for `model`.

    Parameters are initialized from `PRNGKey(0)` using the module-level
    `dummy_input`, and paired with an AdamW optimizer (learning rate 0.001).

    Args:
        model: Flax module providing `init` and `apply`.

    Returns:
        A `train_state.TrainState` built from `model.apply`, the initial
        parameters and the optimizer.
    """
    variables = model.init(jax.random.PRNGKey(0), dummy_input)
    initial_params = variables["params"]

    print("Model initialized successfully!")

    return train_state.TrainState.create(
        apply_fn=model.apply,
        params=initial_params,
        tx=optax.adamw(learning_rate=0.001),
    )


# Train state for the un-pipelined model (used by the SPMD baseline run).
state = init_state(transformer)


Model initialized successfully!


### JAX training step with gradient accumulation

In [27]:
def loss_fn(state, params, x, targets):
    """Mean squared error between the model's predictions and `targets`.

    Args:
        state: Object whose `apply_fn` runs the model given `{"params": ...}`.
        params: Parameters to evaluate (passed separately so gradients can be
            taken with respect to them).
        x: Model input for one (micro)batch.
        targets: Regression targets, same shape as the predictions.

    Returns:
        Scalar mean of the squared prediction error.
    """
    # Constrain both tensors to be sharded along their leading dimension
    # over the "fsdp" mesh axis.
    batch_spec = jax.sharding.PartitionSpec("fsdp")
    x, targets = jax.lax.with_sharding_constraint((x, targets), batch_spec)

    error = state.apply_fn({"params": params}, x) - targets
    return jnp.mean(error**2)


def train_step(state, xs, targets):
    """Run one optimizer step with gradient accumulation over microbatches.

    `xs` and `targets` are split along their leading dimension into `NUM_MB`
    microbatches. Loss and gradients are computed per microbatch inside a
    `jax.lax.scan` and summed (not averaged) before a single optimizer update.

    Args:
        state: Train state providing `params`, `apply_fn` and
            `apply_gradients`.
        xs: Full input batch; its leading dimension is split into `NUM_MB`
            chunks.
        targets: Full target batch, split the same way as `xs`.

    Returns:
        `(new_state, loss_acc)`, where `loss_acc` is the sum of the
        per-microbatch losses.
    """

    def split_microbatches(batch):
        return batch.reshape(NUM_MB, -1, *batch.shape[1:])

    value_and_grad = jax.value_and_grad(partial(loss_fn, state))

    def accumulate(running, microbatch):
        mb_x, mb_targets = microbatch
        loss_and_grads = value_and_grad(state.params, mb_x, mb_targets)
        return jax.tree.map(jnp.add, running, loss_and_grads), ()

    # Zero accumulators with the same structure as (loss, grads).
    zero_acc = jax.tree.map(jnp.zeros_like, (0, state.params))
    microbatches = (split_microbatches(xs), split_microbatches(targets))

    (loss_acc, grads_acc), _ = jax.lax.scan(accumulate, zero_acc, microbatches)

    return state.apply_gradients(grads=grads_acc), loss_acc


def train_with_progress(train_step, inputs, num_steps=6):
    """Run `num_steps` training steps, timing the middle ones and profiling the last.

    The steps are: two untimed warm-up steps, `num_steps - 3` timed steps,
    and one final step recorded with the JAX profiler into
    `./tutorial-trace`. The warm-up and profiled steps always run, so at
    least three steps are executed regardless of `num_steps`.

    Args:
        train_step: Callable `(state, x, targets) -> (new_state, loss)`.
        inputs: Tuple `(state, x, targets)`; `x` and `targets` are reused for
            every step.
        num_steps: Total number of steps to run.

    Returns:
        The state returned by the final (profiled) step.
    """
    updated_state, x, targets = inputs

    # Warmup
    for _ in range(2):
        updated_state, training_loss = train_step(updated_state, x, targets)
    jax.block_until_ready(updated_state)

    start_time = time.perf_counter()
    for i in range(2, num_steps - 1):
        updated_state, training_loss = train_step(updated_state, x, targets)

        if i % 2 == 1:
            print(f"Training loss after step {i+1}: {training_loss}")

    # Wait for outstanding device work so the timing covers the whole loop.
    jax.block_until_ready(updated_state)
    end_time = time.perf_counter()

    # The context manager stops the trace even if the step raises; otherwise
    # a failed step would leave the profiler session open and block any later
    # attempt to start a trace.
    with jax.profiler.trace("./tutorial-trace"):
        updated_state, training_loss = train_step(updated_state, x, targets)
        jax.block_until_ready(updated_state)

    print(f"Final training loss: {training_loss}")
    print(f"Training took: {end_time - start_time:.2f} seconds")
    return updated_state


### Sharding specification

In [28]:
# Re-declared here (same value as in the model-definition cell) so the
# microbatch count is visible next to the sharding setup.
NUM_MB = 4
print("Num microbatches: ", NUM_MB)

# The same all-ones tensor is used as both the input and the regression target.
xs = dummy_input
inputs = (state, xs, dummy_input)


# Data parallel on the batch + ZeRO 3 style sharding of the train state.
def fsdp_in_shardings(state):
    """Build the `(state, xs, targets)` input shardings for a train step.

    Every leaf of `state` with a non-empty `shape` is sharded along its
    leading dimension over the "fsdp" mesh axis; leaves that are rank-0 or
    have no `shape` attribute are replicated. The input and target batches
    are sharded along their leading dimension as well.

    The global `mesh` is looked up when this function is called, so the
    result depends on whichever mesh is bound at that point.

    Args:
        state: Pytree (e.g. a train state) to derive per-leaf shardings for.

    Returns:
        A 3-tuple `(state_shardings, xs_sharding, targets_sharding)`.
    """
    sharded = jax.NamedSharding(mesh, P("fsdp"))
    replicated = jax.NamedSharding(mesh, P())

    def leaf_sharding(leaf):
        has_dims = len(getattr(leaf, "shape", [])) > 0
        return sharded if has_dims else replicated

    return (jax.tree.map(leaf_sharding, state), sharded, sharded)


# Shardings for the SPMD baseline, built against the 1-D "fsdp" mesh.
in_shardings = fsdp_in_shardings(state)


Num microbatches:  4


In [29]:
# Simple SPMD training with micro-batching.
with mesh:
    jitted_train_step = jax.jit(train_step, in_shardings=in_shardings)
    # Lower and compile ahead of time; the compiled executable is what gets
    # stepped by `train_with_progress` below.
    compiled = jitted_train_step.lower(*inputs).compile()
# Place the state and batch on devices according to the same shardings.
sharded_inputs = jax.device_put(inputs, in_shardings)

# The returned final state is not needed for the baseline run.
_ = train_with_progress(compiled, sharded_inputs)

Training loss after step 4: 7.956762313842773
Final training loss: 7.927053451538086
Training took: 2.38 seconds


### Pipelined Transformer

In [30]:
class AnnotatedTransformer(nn.Module):
    """`NUM_LAYERS` stacked `Block`s with pipeline-stage markers.

    `pipeline_yield` is applied to the activations after every
    `NUM_LAYERS // num_stages` blocks, including after the final block.
    """

    # Number of pipeline stages the layer stack is divided into.
    # NOTE(review): assumes 1 <= num_stages <= NUM_LAYERS; a larger value
    # makes `NUM_LAYERS // num_stages` zero and the modulo below raises.
    num_stages: int

    @nn.compact
    def __call__(self, x):
        for layer_idx in range(NUM_LAYERS):
            x = Block(name=f"block_{layer_idx}")(x)
            layers_done = layer_idx + 1
            closes_stage = layers_done % (NUM_LAYERS // self.num_stages) == 0
            if closes_stage:
                x = pipeline_yield(x)  # NEW: added pipeline_yield
        return x

In [31]:
# Define the training step.
def train_step(state, xs, targets, schedule):
    """Run one pipelined optimizer step with gradient accumulation.

    Same structure as the scan-based step, but the per-microbatch
    loss/gradient computation is reduced with `jaxpp.api.treduce` under the
    given pipeline `schedule`. Results are combined with `jaxpp.api.Add`
    before a single optimizer update.

    Args:
        state: Train state providing `params`, `apply_fn` and
            `apply_gradients`.
        xs: Full input batch; its leading dimension is split into `NUM_MB`
            chunks.
        targets: Full target batch, split the same way as `xs`.
        schedule: Pipeline schedule object forwarded to `treduce`.

    Returns:
        `(new_state, loss_acc)` with the accumulated loss.
    """

    def split_microbatches(batch):
        return batch.reshape(NUM_MB, -1, *batch.shape[1:])

    value_and_grad = jax.value_and_grad(partial(loss_fn, state))

    def microbatch_step(microbatch):
        mb_x, mb_targets = microbatch
        return value_and_grad(state.params, mb_x, mb_targets)

    loss_acc, grads_acc = jaxpp.api.treduce(  # NEW: replaced scan with treduce
        microbatch_step,
        (split_microbatches(xs), split_microbatches(targets)),
        schedule=schedule,  # NEW: added schedule
        operation=jaxpp.api.Add,  # State update operation
    )

    return state.apply_gradients(grads=grads_acc), loss_acc

### MPMD Mesh

In [32]:
# Arrange the first NUM_STAGES devices as a (stage, fsdp) grid with one row per
# pipeline stage. This rebinds the global `mesh` that `fsdp_in_shardings` reads.
NUM_STAGES = 4
mesh = jax.sharding.Mesh(
    np.array(jax.devices())[:NUM_STAGES].reshape(NUM_STAGES, -1), ("stage", "fsdp")
)
# Wrap the mesh, naming "stage" as the MPMD axis.
mpmd_mesh = jaxpp.api.MpmdMesh(mesh, "stage")
mpmd_mesh

MpmdMesh(jax_mesh=Mesh(device_ids=array([[0],
       [1],
       [2],
       [3]]), axis_names=('stage', 'fsdp'), axis_types=(Auto, Auto)), mpmd_axis_name='stage')

In [33]:
# The per-stage sub-meshes: one mesh for each entry along the "stage" axis.
mpmd_mesh.unstack

[Mesh(device_ids=array([[0]]), axis_names=('stage', 'fsdp'), axis_types=(Auto, Auto)),
 Mesh(device_ids=array([[1]]), axis_names=('stage', 'fsdp'), axis_types=(Auto, Auto)),
 Mesh(device_ids=array([[2]]), axis_names=('stage', 'fsdp'), axis_types=(Auto, Auto)),
 Mesh(device_ids=array([[3]]), axis_names=('stage', 'fsdp'), axis_types=(Auto, Auto))]

In [34]:
# Build the stage-annotated model and a fresh train state for it.
ann_transformer = AnnotatedTransformer(num_stages=NUM_STAGES)
ann_state = init_state(ann_transformer)
# Recompute the shardings: `fsdp_in_shardings` now sees the 2-D (stage, fsdp) mesh.
in_shardings = fsdp_in_shardings(ann_state)
# Swap in the new state; the input batch and targets are reused unchanged.
inputs = (ann_state, *inputs[1:])

# Extract the PartitionSpec from each NamedSharding for use as `in_specs`.
mpmd_in_specs = jax.tree.map(lambda _: _.spec, in_shardings)
pp_train_step = jaxpp.api.mpmd_jit_with_loop(
    partial(
        train_step,
        schedule=jaxpp.api.Interleaved1F1B(num_stages=NUM_STAGES, mpmd_dim=NUM_STAGES),
    ),
    in_specs=mpmd_in_specs,
    # The updated state reuses the input state's specs; the loss gets an empty spec.
    out_specs=(mpmd_in_specs[0], jax.sharding.PartitionSpec()),
    mpmd_mesh=mpmd_mesh,
).compile(*inputs)

with mesh:
    updated_state = train_with_progress(pp_train_step, inputs)

Model initialized successfully!


Training loss after step 4: 7.956801891326904
Final training loss: 7.922821521759033
Training took: 1.70 seconds


In [35]:
# Print each block's first Dense kernel after pipelined training; the
# `mpmd_idxs` field of the repr shows which MPMD rank holds it. Looping over
# NUM_LAYERS keeps this cell in sync with the model depth.
for layer_idx in range(NUM_LAYERS):
    print(
        f"layer_{layer_idx}",
        updated_state.params[f"block_{layer_idx}"]["Dense_0"]["kernel"],
    )

layer_0 MpmdArray(shape=(1024, 4096), dtype=float32, mpmd_idxs=(0,), sharding=NamedSharding(mesh=Mesh('stage': 1, 'fsdp': 1, axis_types=(Auto, Auto)), spec=PartitionSpec('fsdp',), memory_kind=device))
layer_1 MpmdArray(shape=(1024, 4096), dtype=float32, mpmd_idxs=(0,), sharding=NamedSharding(mesh=Mesh('stage': 1, 'fsdp': 1, axis_types=(Auto, Auto)), spec=PartitionSpec('fsdp',), memory_kind=device))
layer_2 MpmdArray(shape=(1024, 4096), dtype=float32, mpmd_idxs=(1,), sharding=NamedSharding(mesh=Mesh('stage': 1, 'fsdp': 1, axis_types=(Auto, Auto)), spec=PartitionSpec('fsdp',), memory_kind=device))
layer_3 MpmdArray(shape=(1024, 4096), dtype=float32, mpmd_idxs=(1,), sharding=NamedSharding(mesh=Mesh('stage': 1, 'fsdp': 1, axis_types=(Auto, Auto)), spec=PartitionSpec('fsdp',), memory_kind=device))
layer_4 MpmdArray(shape=(1024, 4096), dtype=float32, mpmd_idxs=(2,), sharding=NamedSharding(mesh=Mesh('stage': 1, 'fsdp': 1, axis_types=(Auto, Auto)), spec=PartitionSpec('fsdp',), memory_kind=devi

### Schedule specification

In [36]:
# Tasks that MPMD rank 2 executes, in order, with one stage per rank.
MPMD_RANK = 2
list(
    map(
        str,
        jaxpp.api.Interleaved1F1B(
            num_stages=NUM_STAGES, mpmd_dim=NUM_STAGES
        ).tasks(NUM_MB)[MPMD_RANK],
    )
)

['fwd_2__0',
 'fwd_2__1',
 'fwd_2__2',
 'bwd_2__0',
 'fwd_2__3',
 'bwd_2__1',
 'bwd_2__2',
 'bwd_2__3']

### Tasks can be fused

In [37]:
# Tasks that MPMD rank 3 (the last rank) executes, with one stage per rank.
MPMD_RANK = 3
list(
    map(
        str,
        jaxpp.api.Interleaved1F1B(
            num_stages=NUM_STAGES, mpmd_dim=NUM_STAGES
        ).tasks(NUM_MB)[MPMD_RANK],
    )
)

['FusedTask(fwd_3__0, bwd_3__0)',
 'FusedTask(fwd_3__1, bwd_3__1)',
 'FusedTask(fwd_3__2, bwd_3__2)',
 'FusedTask(fwd_3__3, bwd_3__3)']

In [38]:
# Text form of the compiled pipelined step's jaxpr, split into lines for searching.
jaxpr_lines = str(pp_train_step.closed_jaxpr).splitlines()

In [39]:
# Locate the first line mentioning the `fwd_0` task and print a window around
# it (10 lines before, 50 lines from the match onward).
idx = next(
    (i for i, line in enumerate(jaxpr_lines) if "task_name=fwd_0" in line),
    None,
)
if idx is None:
    # Without this guard the window would silently be taken around the last
    # line (or fail outright on an empty jaxpr).
    print("No line containing 'task_name=fwd_0' found in the jaxpr")
else:
    # Clamp the start so a match within the first 10 lines does not produce a
    # negative (wrap-around) slice index.
    print("\n".join(jaxpr_lines[max(idx - 10, 0) : idx + 50]))

      donate_invars=(False,)
      task_name=before_loop_10_3
    ] dah
    dfl[35m:f32[2,2048,1024][39m dfm[35m:f32[2,2048,1024][39m dfn[35m:f32[2,2048,1024][39m dfo[35m:i32[][39m = task[
      task_info=(0, TaskType.FWD)
      call_jaxpr=jaxpr
      mpmd_idx=0
      latency=1
      call_counter=0
      donate_invars=(False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False)
      task_name=fwd_0
    ] dai dbp cll clm cln clo clp clq clr cls clt clu clv clw clx cly clz cma cmb
      cmc cmd cme cmf cmg cmh cmi cmj cmk cml cmm cmn cmo cmp cmq
    _[35m:i32[][39m = delete[mpmd_idx=0] dbp
    dfp[35m:f32[2,2048,1024][39m = transfer[
      src_mpmd_idx=0
      src_shardings=[NamedSharding(mesh=Mesh('stage': 1, 'fsdp': 1, axis_types=(Auto, Auto)), spec=PartitionSpec(), memory_kind=device)]
      tgt_m

In [40]:
# Tasks for MPMD rank 0 when NUM_LAYERS stages are spread over NUM_STAGES ranks.
list(
    map(
        str,
        jaxpp.api.Interleaved1F1B(
            num_stages=NUM_LAYERS, mpmd_dim=NUM_STAGES
        ).tasks(NUM_MB)[0],
    )
)

['fwd_0__0',
 'fwd_0__1',
 'fwd_0__2',
 'fwd_0__3',
 'fwd_4__0',
 'fwd_4__1',
 'fwd_4__2',
 'fwd_4__3',
 'bwd_4__0',
 'bwd_4__1',
 'bwd_4__2',
 'bwd_4__3',
 'bwd_0__0',
 'bwd_0__1',
 'bwd_0__2',
 'bwd_0__3']

In [41]:
# Tasks for MPMD rank 3 when NUM_LAYERS stages are spread over NUM_STAGES ranks.
list(
    map(
        str,
        jaxpp.api.Interleaved1F1B(
            num_stages=NUM_LAYERS, mpmd_dim=NUM_STAGES
        ).tasks(NUM_MB)[3],
    )
)

['fwd_3__0',
 'fwd_3__1',
 'fwd_3__2',
 'fwd_3__3',
 'FusedTask(fwd_7__0, bwd_7__0)',
 'FusedTask(fwd_7__1, bwd_7__1)',
 'FusedTask(fwd_7__2, bwd_7__2)',
 'FusedTask(fwd_7__3, bwd_7__3)',
 'bwd_3__0',
 'bwd_3__1',
 'bwd_3__2',
 'bwd_3__3']

In [42]:
# Switch the model to one pipeline stage per layer (NUM_LAYERS stages spread
# over NUM_STAGES ranks). `ann_state.apply_fn` is bound to this same module
# instance, so the new value is what `train_step` sees when it is traced below.
# NOTE(review): this mutates the module in place after construction — confirm
# the installed flax version permits it; building a new
# `AnnotatedTransformer(num_stages=NUM_LAYERS)` would be less fragile.
ann_transformer.num_stages = NUM_LAYERS
pp_train_step_interleaved = jaxpp.api.mpmd_jit_with_loop(
    partial(
        train_step,
        schedule=jaxpp.api.Interleaved1F1B(num_stages=NUM_LAYERS, mpmd_dim=NUM_STAGES),
    ),
    in_specs=mpmd_in_specs,
    # The updated state reuses the input state's specs; the loss gets an empty spec.
    out_specs=(mpmd_in_specs[0], jax.sharding.PartitionSpec()),
    mpmd_mesh=mpmd_mesh,
).compile(*inputs)

In [43]:
# Run the interleaved schedule from the same starting `inputs` as the previous
# pipelined run.
with mesh:
    updated_state_interleaved = train_with_progress(pp_train_step_interleaved, inputs)

Training loss after step 4: 7.956801891326904
Final training loss: 7.922821521759033
Training took: 1.35 seconds


In [44]:
# Print the step counter, then each block's first Dense kernel; the
# `mpmd_idxs` field of each repr shows which MPMD rank(s) hold the array.
# Looping over NUM_LAYERS keeps this cell in sync with the model depth.
print(updated_state_interleaved.step)
for layer_idx in range(NUM_LAYERS):
    print(updated_state_interleaved.params[f"block_{layer_idx}"]["Dense_0"]["kernel"])

MpmdArray(shape=(), dtype=int32, mpmd_idxs=(0, 1, 2, 3), sharding=NamedSharding(mesh=Mesh('stage': 4, 'fsdp': 1, axis_types=(Auto, Auto)), spec=PartitionSpec(), memory_kind=device))
MpmdArray(shape=(1024, 4096), dtype=float32, mpmd_idxs=(0,), sharding=NamedSharding(mesh=Mesh('stage': 1, 'fsdp': 1, axis_types=(Auto, Auto)), spec=PartitionSpec('fsdp',), memory_kind=device))
MpmdArray(shape=(1024, 4096), dtype=float32, mpmd_idxs=(1,), sharding=NamedSharding(mesh=Mesh('stage': 1, 'fsdp': 1, axis_types=(Auto, Auto)), spec=PartitionSpec('fsdp',), memory_kind=device))
MpmdArray(shape=(1024, 4096), dtype=float32, mpmd_idxs=(2,), sharding=NamedSharding(mesh=Mesh('stage': 1, 'fsdp': 1, axis_types=(Auto, Auto)), spec=PartitionSpec('fsdp',), memory_kind=device))
MpmdArray(shape=(1024, 4096), dtype=float32, mpmd_idxs=(3,), sharding=NamedSharding(mesh=Mesh('stage': 1, 'fsdp': 1, axis_types=(Auto, Auto)), spec=PartitionSpec('fsdp',), memory_kind=device))
MpmdArray(shape=(1024, 4096), dtype=float32, m

In [45]:
# Inspect the per-rank copies behind the step counter.
# NOTE(review): `_partially_addressable_arrays` is a private attribute and may
# change between jaxpp versions.
updated_state_interleaved.step._partially_addressable_arrays

OrderedDict([(0, Array(6, dtype=int32, weak_type=True)),
             (1, Array(6, dtype=int32, weak_type=True)),
             (2, Array(6, dtype=int32, weak_type=True)),
             (3, Array(6, dtype=int32, weak_type=True))])

In [46]:
# Devices backing each per-rank copy of the step counter.
[a.sharding.mesh.devices for a in updated_state_interleaved.step._partially_addressable_arrays.values()]

[array([[CudaDevice(id=0)]], dtype=object),
 array([[CudaDevice(id=1)]], dtype=object),
 array([[CudaDevice(id=2)]], dtype=object),
 array([[CudaDevice(id=3)]], dtype=object)]