# Code adapted from https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/scaling/JAX/pipeline_parallel_simple.html

## Imports

In [3]:
import os
import urllib.request
from urllib.error import HTTPError

# Base URL of the GitHub directory that hosts the tutorial's helper scripts.
base_url = "https://raw.githubusercontent.com/phlippe/uvadlc_notebooks/master/docs/tutorial_notebooks/scaling/JAX/"
# Helper modules imported by later cells (`single_gpu`, `data_parallel`, `utils`).
python_files = ["single_gpu.py", "data_parallel.py", "utils.py"]
# Download each file into the working directory unless a copy already exists there.
for file_name in python_files:
    if not os.path.isfile(file_name):
        file_url = base_url + file_name
        print(f"Downloading {file_url}...")
        try:
            urllib.request.urlretrieve(file_url, file_name)
        except HTTPError as e:
            # Best effort: report the HTTP failure and continue with the next file.
            # Only HTTPError is handled here; any other exception propagates.
            print(
                "Something went wrong. Please try to download the file directly from the GitHub repository, or contact the author with the full output including the following error:\n",
                e,
            )

Downloading https://raw.githubusercontent.com/phlippe/uvadlc_notebooks/master/docs/tutorial_notebooks/scaling/JAX/data_parallel.py...


In [5]:
from utils import simulate_CPU_devices

# NOTE(review): presumably makes JAX expose several virtual CPU devices so that the
# multi-device mesh used later can be built on a single host — confirm in utils.py.
simulate_CPU_devices()

In [7]:
import functools
from pprint import pprint
from typing import Any, Callable, Dict, Tuple

import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
import optax
from flax.core.frozen_dict import FrozenDict
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh
from jax.sharding import PartitionSpec as P
from ml_collections import ConfigDict

# Helper types
# Arbitrary nested container (dict/list/tuple/...) of array-like leaves.
PyTree = Any
# A parameter leaf: either a raw array or an array wrapped with partitioning metadata.
Parameter = jax.Array | nn.Partitioned
# Metric name -> tuple of arrays (in this file: `(summed value, element count)`).
Metrics = Dict[str, Tuple[jax.Array, ...]]

In [9]:
from data_parallel import fold_rng_over_axis, sync_gradients
from single_gpu import (
    Batch,
    TrainState,
    accumulate_gradients,
    get_num_params,
    print_metrics,
)

## Pipeline Parallelism with Micro-Batching

In [12]:
class MLPBlock(nn.Module):
    """Pre-norm residual MLP block.

    Applies LayerNorm -> Dense -> SiLU -> Dropout -> Dense and adds the block
    input back onto the result.

    Attributes:
        config: Model configuration; reads `dtype`, `hidden_size`, `mlp_expansion`
            and `dropout_rate`.
        train: Whether the block runs in training mode (dropout active).
    """

    config: ConfigDict
    train: bool

    @nn.compact
    def __call__(self, x: jax.Array) -> jax.Array:
        cfg = self.config
        # The output projection maps back to the input width so the skip connection fits.
        out_features = x.shape[-1]
        hidden = nn.LayerNorm(dtype=cfg.dtype, name="pre_norm")(x)
        hidden = nn.Dense(
            features=cfg.hidden_size * cfg.mlp_expansion,
            dtype=cfg.dtype,
            name="input_dense",
        )(hidden)
        hidden = nn.silu(hidden)
        hidden = nn.Dropout(rate=cfg.dropout_rate, deterministic=not self.train)(hidden)
        hidden = nn.Dense(features=out_features, dtype=cfg.dtype, name="output_dense")(hidden)
        return hidden + x

In [14]:
class MLPLayers(nn.Module):
    """Sequence of `config.num_layers` MLP blocks, applied with `nn.scan`.

    The parameters of all layers are stacked along axis 0 of each leaf. The
    scanned form is equivalent to instantiating and applying `num_layers`
    separately named blocks one after another in a Python loop.

    Attributes:
        config: Model configuration; reads `remat` and `num_layers` here and is
            forwarded to every block.
        train: Whether the blocks run in training mode.
    """

    config: ConfigDict
    train: bool

    @nn.compact
    def __call__(self, x: jax.Array) -> jax.Array:
        # Optionally rematerialize block activations during the backward pass.
        if "MLP" in self.config.remat:
            layer_cls = nn.remat(MLPBlock, prevent_cse=False)
        else:
            layer_cls = MLPBlock
        layer = layer_cls(config=self.config, train=self.train, name="block")

        def _apply_layer(module, carry, _):
            # The carry is the running activation; the scan emits no per-step output.
            return module(carry), ()

        scanned_layers = nn.scan(
            _apply_layer,
            variable_axes={"params": 0},
            split_rngs={"params": True, "dropout": True},
            length=self.config.num_layers,
        )
        x, _ = scanned_layers(layer, x, ())
        return x

In [16]:
def stack_params(
    params: PyTree, axis_name: str, axis: int = 0, mask_except: jax.Array | int | None = None
) -> PyTree:
    """Stacks sharded parameters along a given axis name.

    Must be called inside a context where `axis_name` is bound (e.g. `shard_map`)
    when `mask_except` is given, because the device index along that axis is queried.

    Args:
        params: PyTree of parameters.
        axis_name: Name of the axis to stack along.
        axis: Index of the axis to stack along. Negative values are interpreted
            relative to the expanded array, as in `jnp.expand_dims`.
        mask_except: If not None, only the `mask_except`-th shard will be non-zero.

    Returns:
        PyTree of parameters with the same structure as `params`, but with the leaf
        nodes replaced by `nn.Partitioned` objects with sharding over axis name added
        to `axis`-th axis of parameters.
    """

    def _stack(x: Parameter) -> Parameter:
        if isinstance(x, nn.Partitioned):
            value, names = x.value, x.names
        else:
            value, names = x, (None,) * x.ndim
        # Resolve negative axes against the expanded rank so that the position of the
        # new dimension and the position of its name always agree.
        insert_at = axis + value.ndim + 1 if axis < 0 else axis
        if mask_except is not None:
            axis_index = jax.lax.axis_index(axis_name)
            # zeros_like keeps the dtype of the parameter (a float literal would
            # promote integer leaves).
            value = jnp.where(axis_index == mask_except, value, jnp.zeros_like(value))
        value = jnp.expand_dims(value, insert_at)
        names = names[:insert_at] + (axis_name,) + names[insert_at:]
        return nn.Partitioned(value, names=names)

    # `jax.tree_map` is deprecated; `jax.tree_util.tree_map` is the stable spelling.
    return jax.tree_util.tree_map(
        _stack, params, is_leaf=lambda x: isinstance(x, nn.Partitioned)
    )


def unstack_params(params: PyTree, axis_name: str) -> PyTree:
    """Unstacks parameters along a given axis name.

    Inverse operation to `stack_params`.

    Args:
        params: PyTree of parameters.
        axis_name: Name of the axis to unstack along.

    Returns:
        PyTree of parameters with the same structure as `params`, but
        with the leaf nodes having the sharding over the axis name removed.
        Leaves that are not `nn.Partitioned` or do not carry `axis_name` are
        returned unchanged.
    """

    def _unstack(x: Parameter) -> Parameter:
        if not (isinstance(x, nn.Partitioned) and axis_name in x.names):
            return x
        names = x.names
        axis_idx = names.index(axis_name)
        # Drop the stacked dimension together with its name.
        value = x.value.squeeze(axis_idx)
        names = names[:axis_idx] + names[axis_idx + 1 :]
        # Without any remaining named axis the wrapper carries no information.
        if all(n is None for n in names):
            return value
        return nn.Partitioned(value, names=names)

    # `jax.tree_map` is deprecated; `jax.tree_util.tree_map` is the stable spelling.
    return jax.tree_util.tree_map(
        _unstack, params, is_leaf=lambda x: isinstance(x, nn.Partitioned)
    )

In [40]:
def execute_pipeline_step(
    module: nn.Module,
    state: jax.Array,
    input: jax.Array,
    *args,
    model_axis_name: str,
    **kwargs,
) -> Tuple[jax.Array, jax.Array]:
    """Run one micro-batch step of the pipeline on the current stage.

    Args:
        module: Flax module implementing the stage.
        state: Features received from the previous stage in the last step. Fed to
            the module on every stage except the first.
        input: Micro-batch taken from the original pipeline input. Fed to the
            module on the first stage only.
        *args: Extra positional arguments forwarded to the module.
        model_axis_name: Name of the model axis in the mesh/shard_map.
        **kwargs: Extra keyword arguments forwarded to the module.

    Returns:
        A pair `(next_state, output)`: the features handed on to the following
        stage, and the module output (zeros on every stage but the last).
    """
    stage_count = jax.lax.psum(1, model_axis_name)
    stage_id = jax.lax.axis_index(model_axis_name)
    on_first_stage = stage_id == 0
    on_last_stage = stage_id == stage_count - 1
    # The first stage consumes fresh micro-batches, later stages the communicated features.
    stage_input = jnp.where(on_first_stage, input, state)
    stage_output = module(stage_input, *args, **kwargs)
    # Only the last stage reports a real output; all others emit zeros.
    output = jnp.where(on_last_stage, stage_output, jnp.zeros_like(stage_output))
    # Shift the features one stage forward along a ring (last stage wraps to the first).
    ring = [(src, (src + 1) % stage_count) for src in range(stage_count)]
    next_state = jax.lax.ppermute(stage_output, model_axis_name, perm=ring)
    return (next_state, output)

In [36]:
@jax.named_scope("pipeline")  # Naming scope for profiling.
def execute_pipeline(
    module: nn.Module, x: jax.Array, *args, num_microbatches: int, model_axis_name: str, **kwargs
) -> jax.Array:
    """Execute a pipeline of stages on a batch of data.

    Uses the principle of GPipe in splitting the batch into micro-batches
    and running the pipeline stages in parallel.

    Args:
        module: Flax module representing the pipeline stage to execute.
        x: Batch of input data, only needed on device of the first stage. Data will be split into micro-batches.
        *args: Additional arguments to the module.
        num_microbatches: Number of micro-batches to split the batch into.
        model_axis_name: Name of the model axis in the mesh/shard_map.
        **kwargs: Additional keyword arguments to the module.

    Returns:
        Output of the last stage of the pipeline. For devices that are not
        the last stage, the output is zeros.

    Raises:
        AssertionError: If the batch size is not divisible by `num_microbatches`.
    """
    num_stages = jax.lax.psum(1, model_axis_name)
    # Structure the input data into micro-batches.
    batch_size = x.shape[0]
    assert (
        batch_size % num_microbatches == 0
    ), f"Batch size {batch_size} must be divisible by number of microbatches {num_microbatches}"
    microbatch_size = batch_size // num_microbatches
    microbatches = jnp.reshape(x, (num_microbatches, microbatch_size, *x.shape[1:]))
    inputs = jnp.concatenate(  # Add zeros for unused computation blocks in first stage.
        [
            microbatches,
            jnp.zeros((num_stages - 1, *microbatches.shape[1:]), dtype=microbatches.dtype),
        ],
        axis=0,
    )
    state = jnp.zeros_like(microbatches[0])
    num_iterations = inputs.shape[0]

    def _pipeline_step(stage_module, carry, microbatch):
        # Forward the extra positional arguments AFTER (module, state, input).
        # `functools.partial(execute_pipeline_step, *args)` would prepend them and
        # thereby bind them to the module/state/input parameters.
        return execute_pipeline_step(
            stage_module,
            carry,
            microbatch,
            *args,
            model_axis_name=model_axis_name,
            **kwargs,
        )

    # Run loop over pipeline steps.
    _, outputs = nn.scan(
        _pipeline_step,
        variable_broadcast={"params": True},
        split_rngs={"params": False, "dropout": True},
        length=num_iterations,
        in_axes=0,
        out_axes=0,
    )(module, state, inputs)
    # Take last N outputs (first ones are zeros from unused computation blocks in last stage).
    outputs = jnp.concatenate(outputs[-num_microbatches:], axis=0)
    return outputs

In [22]:
class PipelineModule(nn.Module):
    """Flax module that runs a stage module through `execute_pipeline`.

    Attributes:
        model_axis_name: Name of the model axis in the mesh/shard_map.
        num_microbatches: Number of micro-batches the input batch is split into.
        module_fn: Zero-argument factory returning the stage module.
    """

    model_axis_name: str
    num_microbatches: int
    module_fn: Callable[..., nn.Module]

    @nn.compact
    def __call__(self, *args, **kwargs):
        # Build the stage inside this module's scope so its variables are registered here.
        stage = self.module_fn()
        pipeline_outputs = execute_pipeline(
            stage,
            *args,
            num_microbatches=self.num_microbatches,
            model_axis_name=self.model_axis_name,
            **kwargs,
        )
        return pipeline_outputs

In [24]:
class ModelParallelismWrapper(nn.Module):
    """Adds model-axis sharding to the parameters of a wrapped module.

    The wrapped module sees plain (unstacked) parameters, while the parameters
    stored outside carry an extra leading axis annotated with the model axis
    name. During initialization each position on the model axis can draw its
    own parameters.

    Args:
        model_axis_name: Name of the model axis to shard over.
        module_fn: Function that returns the Flax module to wrap.
        mask_except_model_idx: If not None, only the `mask_except_model_idx`-th shard will be non-zero.
        split_rngs: If True, split the random number generators across the model axis.
        module_kwargs: Additional keyword arguments to pass to the module function.
    """

    model_axis_name: str
    module_fn: Callable[..., nn.Module]
    mask_except_model_idx: int | None = None
    split_rngs: bool = True
    module_kwargs: FrozenDict[str, Any] = FrozenDict({})

    @nn.compact
    def __call__(self, *args, **kwargs):
        axis = self.model_axis_name
        if self.split_rngs and self.is_initializing():
            # Fold the model-axis index into the params RNG so that every shard
            # is initialized with different values.
            params_rng = self.scope.rngs["params"]
            self.scope.rngs["params"] = params_rng.replace(
                rng=fold_rng_over_axis(params_rng.rng, axis)
            )
        # Factory for the wrapped module; it always lives under the name "sharded".
        sharded_fn = functools.partial(self.module_fn, name="sharded", **self.module_kwargs)
        # Going in: strip the model axis so the module sees its local parameters.
        strip_axis = functools.partial(unstack_params, axis_name=axis)
        # Going out: re-add the model axis (optionally zeroing all but one shard).
        add_axis = functools.partial(
            stack_params,
            axis_name=axis,
            mask_except=self.mask_except_model_idx,
        )
        wrapped = nn.map_variables(
            target=sharded_fn,
            trans_in_fn=strip_axis,
            trans_out_fn=add_axis,
            mapped_collections="params",
            mutable=True,
        )()
        return wrapped(*args, **kwargs)

In [26]:
class PPClassifier(nn.Module):
    """Classifier whose MLP trunk is executed with pipeline parallelism.

    Structure: input projection (kept on the first stage only), pipelined MLP
    layers, then LayerNorm and output projection (kept on the last stage only).

    Attributes:
        config: Model configuration.
        pipeline_module_class: Module class that executes the pipeline stages.
    """

    config: ConfigDict
    pipeline_module_class: Callable[..., nn.Module] = PipelineModule

    @nn.compact
    def __call__(self, x: jax.Array, train: bool) -> jax.Array:
        cfg = self.config
        axis = cfg.model_axis_name

        # Input projection; parameters are zeroed on all stages except index 0.
        input_layer_fn = functools.partial(nn.Dense, features=cfg.hidden_size, dtype=cfg.dtype)
        x = ModelParallelismWrapper(
            module_fn=input_layer_fn,
            model_axis_name=axis,
            mask_except_model_idx=0,
            name="input_dense",
        )(x)

        # Pipelined trunk: every stage holds its own copy of the MLP layers.
        stage_fn = functools.partial(MLPLayers, config=cfg, train=train, name="mlp_layers")
        pipeline_fn = functools.partial(
            self.pipeline_module_class,
            model_axis_name=axis,
            num_microbatches=cfg.num_microbatches,
            module_fn=stage_fn,
        )
        x = ModelParallelismWrapper(
            module_fn=pipeline_fn,
            model_axis_name=axis,
            name="pipeline",
        )(x)

        # Output head; parameters are zeroed on all stages except the last one.
        last_stage = cfg.model_axis_size - 1
        norm_fn = functools.partial(nn.LayerNorm, dtype=cfg.dtype)
        x = ModelParallelismWrapper(
            module_fn=norm_fn,
            model_axis_name=axis,
            mask_except_model_idx=last_stage,
            name="output_norm",
        )(x)
        head_fn = functools.partial(nn.Dense, features=cfg.num_classes, dtype=cfg.dtype)
        x = ModelParallelismWrapper(
            module_fn=head_fn,
            model_axis_name=axis,
            mask_except_model_idx=last_stage,
            name="output_dense",
        )(x)
        # Logits are always returned in float32, independent of the compute dtype.
        return x.astype(jnp.float32)

In [28]:
# Shape of the synthetic classification data.
data_config = ConfigDict(
    {
        "batch_size": 128,
        "num_classes": 10,
        "input_size": 784,
    }
)
# Model hyperparameters plus the mesh layout the model is written against.
model_config = ConfigDict(
    {
        "hidden_size": 512,
        "mlp_expansion": 1,
        "dropout_rate": 0.1,
        "num_layers": 8,
        "dtype": jnp.float32,
        "num_classes": data_config.num_classes,
        "remat": (),
        "data_axis_name": "data",
        "model_axis_name": "model",
        "model_axis_size": 4,
        "num_microbatches": 8,
    }
)
# The layers are distributed over the model axis, so each stage holds only its share.
model_config.num_layers = model_config.num_layers // model_config.model_axis_size
optimizer_config = ConfigDict(
    {
        "learning_rate": 1e-3,
        "num_minibatches": 1,
    }
)
# Top-level config; the axis settings are mirrored from the model config.
config = ConfigDict(
    {
        "model": model_config,
        "optimizer": optimizer_config,
        "data": data_config,
        "data_axis_name": model_config.data_axis_name,
        "model_axis_name": model_config.model_axis_name,
        "model_axis_size": model_config.model_axis_size,
        "seed": 42,
    }
)

In [30]:
# Arrange all available devices in a 2D grid: rows form the data axis and each row
# holds `model_axis_size` devices for the model axis.
device_array = np.array(jax.devices()).reshape(-1, config.model_axis_size)
mesh = Mesh(device_array, (config.data_axis_name, config.model_axis_name))
model_pp = PPClassifier(config=model_config)
optimizer = optax.adamw(
    learning_rate=config.optimizer.learning_rate,
)
# Derive independent RNG streams for model init, synthetic inputs and synthetic labels.
rng = jax.random.PRNGKey(config.seed)
model_init_rng, data_inputs_rng, data_labels_rng = jax.random.split(rng, 3)
# Random batch: normally distributed inputs and integer labels in [0, num_classes).
batch = Batch(
    inputs=jax.random.normal(data_inputs_rng, (config.data.batch_size, config.data.input_size)),
    labels=jax.random.randint(
        data_labels_rng, (config.data.batch_size,), 0, config.data.num_classes
    ),
)

In [32]:
def init_fn(rng: jax.random.PRNGKey, x: jax.Array, model: nn.Module) -> TrainState:
    """Initialize the model parameters and wrap them in a `TrainState`.

    Args:
        rng: RNG key; split into one key for parameter init and one stored in the state.
        x: Example input batch used to trace the model during initialization.
        model: Flax module to initialize.

    Returns:
        A `TrainState` holding the initialized parameters, the module-level
        `optimizer` and the remaining RNG key.
    """
    init_rng, rng = jax.random.split(rng)
    variables = model.init({"params": init_rng}, x, train=False)
    # Index instead of `.pop("params")`: on a FrozenDict, `pop` returns a
    # `(remaining, value)` pair, whereas indexing yields the parameters for both
    # FrozenDict and plain dict.
    params = variables["params"]
    state = TrainState.create(
        apply_fn=model.apply,
        params=params,
        tx=optimizer,
        rng=rng,
    )
    return state

In [42]:
# First pass: the partition specs of the state are not known yet, so wrap `init_fn`
# with a placeholder output spec (`P()`) and only evaluate output shapes.
init_pp_fn = shard_map(
    functools.partial(init_fn, model=model_pp),
    mesh,
    in_specs=(P(), P(config.data_axis_name)),
    out_specs=P(),
    check_rep=False,
)
state_pp_shapes = jax.eval_shape(init_pp_fn, model_init_rng, batch.inputs)
# Derive a PartitionSpec for every leaf of the state from its partitioning metadata.
state_pp_specs = nn.get_partition_spec(state_pp_shapes)

  return jax.tree_map(_unstack, params, is_leaf=lambda x: isinstance(x, nn.Partitioned))
  return jax.tree_map(_stack, params, is_leaf=lambda x: isinstance(x, nn.Partitioned))
  return jax.tree_map(_stack, params, is_leaf=lambda x: isinstance(x, nn.Partitioned))
  return jax.tree_map(_unstack, params, is_leaf=lambda x: isinstance(x, nn.Partitioned))
  return jax.tree_map(_stack, params, is_leaf=lambda x: isinstance(x, nn.Partitioned))
  return jax.tree_map(_unstack, params, is_leaf=lambda x: isinstance(x, nn.Partitioned))


In [44]:
# Show the partition spec derived for every model parameter.
pprint(state_pp_specs.params)

{'input_dense': {'sharded': {'bias': PartitionSpec('model', None),
                             'kernel': PartitionSpec('model', None, None)}},
 'output_dense': {'sharded': {'bias': PartitionSpec('model', None),
                              'kernel': PartitionSpec('model', None, None)}},
 'output_norm': {'sharded': {'bias': PartitionSpec('model', None),
                             'scale': PartitionSpec('model', None)}},
 'pipeline': {'sharded': {'mlp_layers': {'block': {'input_dense': {'bias': PartitionSpec('model', None, None),
                                                                   'kernel': PartitionSpec('model', None, None, None)},
                                                   'output_dense': {'bias': PartitionSpec('model', None, None),
                                                                    'kernel': PartitionSpec('model', None, None, None)},
                                                   'pre_norm': {'bias': PartitionSpec('model', None, None),
 

In [46]:
# Second pass: rebuild the sharded init function with the derived specs as
# `out_specs`, jit it, and run it to create the train state.
init_pp_fn = jax.jit(
    shard_map(
        functools.partial(init_fn, model=model_pp),
        mesh,
        in_specs=(P(), P(config.data_axis_name)),
        out_specs=state_pp_specs,
        check_rep=False,
    ),
)
state_pp = init_pp_fn(model_init_rng, batch.inputs)

  return jax.tree_map(_unstack, params, is_leaf=lambda x: isinstance(x, nn.Partitioned))
  return jax.tree_map(_stack, params, is_leaf=lambda x: isinstance(x, nn.Partitioned))
  return jax.tree_map(_stack, params, is_leaf=lambda x: isinstance(x, nn.Partitioned))
  return jax.tree_map(_unstack, params, is_leaf=lambda x: isinstance(x, nn.Partitioned))


In [48]:
# Show the shapes of the pipelined MLP block parameters.
# `jax.tree_map` is deprecated; `jax.tree_util.tree_map` is the stable spelling.
pprint(
    jax.tree_util.tree_map(
        lambda x: x.shape, state_pp.params["pipeline"]["sharded"]["mlp_layers"]["block"]
    )
)

{'input_dense': {'bias': Partitioned(value=(4, 2, 512),
                                     names=('model', None, None),
                                     mesh=None),
                 'kernel': Partitioned(value=(4, 2, 512, 512),
                                       names=('model', None, None, None),
                                       mesh=None)},
 'output_dense': {'bias': Partitioned(value=(4, 2, 512),
                                      names=('model', None, None),
                                      mesh=None),
                  'kernel': Partitioned(value=(4, 2, 512, 512),
                                        names=('model', None, None, None),
                                        mesh=None)},
 'pre_norm': {'bias': Partitioned(value=(4, 2, 512),
                                  names=('model', None, None),
                                  mesh=None),
              'scale': Partitioned(value=(4, 2, 512),
                                   names=('model', None, N

  jax.tree_map(lambda x: x.shape, state_pp.params["pipeline"]["sharded"]["mlp_layers"]["block"])


In [50]:
# Print one kernel entry for every index of the first two (stacked) axes of the
# pipelined input_dense kernel.
pprint(
    state_pp.params["pipeline"]["sharded"]["mlp_layers"]["block"]["input_dense"]["kernel"].value[
        :, :, 0, 0
    ]
)

Array([[ 0.01044598, -0.07416785],
       [-0.04605146,  0.0008348 ],
       [-0.00904123, -0.00018691],
       [ 0.00661926, -0.06117292]], dtype=float32)


In [52]:
# Print one kernel entry per model-axis shard for the input and output layers.
# PPClassifier masks these layers with `mask_except_model_idx` (0 for the input
# layer, `model_axis_size - 1` for the output layer).
print("Input Layer")
pprint(state_pp.params["input_dense"]["sharded"]["kernel"].value[:, 0, 0])
print("\nOutput layer")
pprint(state_pp.params["output_dense"]["sharded"]["kernel"].value[:, 0, 0])


Input Layer
Array([-0.0754908,  0.       ,  0.       ,  0.       ], dtype=float32)

Output layer
Array([ 0.        ,  0.        ,  0.        , -0.07138917], dtype=float32)


In [54]:
def loss_fn(
    params: PyTree, apply_fn: Any, batch: Batch, rng: jax.Array
) -> Tuple[jax.Array, Dict[str, Any]]:
    """Cross-entropy loss and metrics for the pipeline-parallel classifier.

    Args:
        params: Model parameters.
        apply_fn: Apply function of the model.
        batch: Batch providing `inputs` and integer `labels`.
        rng: RNG key from which the dropout key is derived.

    Returns:
        The mean loss and a dict mapping metric names to `(sum, count)` tuples.
        On every pipeline stage except the last, loss, sums and counts are zero.
    """
    # Fold the RNG over both the data and the model axis so that every device
    # draws a different dropout mask.
    dropout_rng = fold_rng_over_axis(rng, (config.data_axis_name, config.model_axis_name))
    logits = apply_fn(
        {"params": params},
        batch.inputs,
        train=True,
        rngs={"dropout": dropout_rng},
    )
    per_example_loss = optax.softmax_cross_entropy_with_integer_labels(logits, batch.labels)
    is_correct = jnp.equal(jnp.argmax(logits, axis=-1), batch.labels)
    # Only the last pipeline stage holds real logits; zero out everything elsewhere.
    stage_idx = jax.lax.axis_index(config.model_axis_name)
    stage_count = jax.lax.psum(1, config.model_axis_name)
    on_last_stage = stage_idx == stage_count - 1
    per_example_loss = jnp.where(on_last_stage, per_example_loss, 0.0)
    is_correct = jnp.where(on_last_stage, is_correct, False)
    num_examples = jnp.where(on_last_stage, batch.inputs.shape[0], 0)
    # Metrics are stored as (sum, count) so they can be aggregated across devices.
    step_metrics = {
        "loss": (per_example_loss.sum(), num_examples),
        "accuracy": (is_correct.sum(), num_examples),
    }
    return per_example_loss.mean(), step_metrics

In [56]:
def train_step_pp(
    state: TrainState,
    metrics: Metrics | None,
    batch: Batch,
) -> Tuple[TrainState, Metrics]:
    """Perform one training step of the pipeline-parallel model.

    Args:
        state: Current train state (parameters, optimizer state, RNG).
        metrics: Metrics accumulated so far, or None to start from this step's metrics.
        batch: Batch of training data.

    Returns:
        The updated train state and the accumulated metrics.
    """
    rng, step_rng = jax.random.split(state.rng)
    grads, step_metrics = accumulate_gradients(
        state,
        batch,
        step_rng,
        config.optimizer.num_minibatches,
        loss_fn=loss_fn,
    )
    # Update parameters. The gradients are synced over both the data and the model
    # axis before the optimizer update.
    with jax.named_scope("sync_gradients"):
        grads = sync_gradients(grads, (config.data_axis_name, config.model_axis_name))
    new_state = state.apply_gradients(grads=grads, rng=rng)
    # Sum metrics across replicas (both model and data axes).
    # `jax.tree_map` is deprecated; `jax.tree_util.tree_map` is the stable spelling.
    with jax.named_scope("sync_metrics"):
        step_metrics = jax.tree_util.tree_map(
            lambda x: jax.lax.psum(x, axis_name=(config.data_axis_name, config.model_axis_name)),
            step_metrics,
        )
    if metrics is None:
        metrics = step_metrics
    else:
        metrics = jax.tree_util.tree_map(jnp.add, metrics, step_metrics)
    return new_state, metrics

In [58]:
# Jit the sharded train step. State and metrics buffers are donated so they can be
# reused for the outputs.
train_step_pp_fn = jax.jit(
    shard_map(
        train_step_pp,
        mesh,
        in_specs=(state_pp_specs, P(), P(config.data_axis_name)),
        out_specs=(state_pp_specs, P()),
        check_rep=False,
    ),
    donate_argnames=("state", "metrics"),
)
# Determine the metric shapes without running a step, then start from zeroed metrics.
_, metric_shapes = jax.eval_shape(
    train_step_pp_fn,
    state_pp,
    None,
    batch,
)
# `jax.tree_map` is deprecated; `jax.tree_util.tree_map` is the stable spelling.
metrics_pp = jax.tree_util.tree_map(
    lambda x: jnp.zeros(x.shape, dtype=x.dtype), metric_shapes
)
state_pp, metrics_pp = train_step_pp_fn(state_pp, metrics_pp, batch)

  return jax.tree_map(_unstack, params, is_leaf=lambda x: isinstance(x, nn.Partitioned))
  return jax.tree_map(_stack, params, is_leaf=lambda x: isinstance(x, nn.Partitioned))
  return jax.tree_map(_unstack, params, is_leaf=lambda x: isinstance(x, nn.Partitioned))
  return jax.tree_map(_stack, params, is_leaf=lambda x: isinstance(x, nn.Partitioned))
  return jax.tree_map(_unstack, params, is_leaf=lambda x: isinstance(x, nn.Partitioned))
  return jax.tree_map(_stack, params, is_leaf=lambda x: isinstance(x, nn.Partitioned))
  return jax.tree_map(_unstack, params, is_leaf=lambda x: isinstance(x, nn.Partitioned))
  return jax.tree_map(_stack, params, is_leaf=lambda x: isinstance(x, nn.Partitioned))
  step_metrics = jax.tree_map(
  metrics_pp = jax.tree_map(lambda x: jnp.zeros(x.shape, dtype=x.dtype), metric_shapes)
  return jax.tree_map(_unstack, params, is_leaf=lambda x: isinstance(x, nn.Partitioned))
  return jax.tree_map(_stack, params, is_leaf=lambda x: isinstance(x, nn.Partitioned))
 

In [60]:
# Report the parameter count returned by `get_num_params` (imported from single_gpu).
print(f"Number of parameters: {get_num_params(state_pp):_}")

Number of parameters: 5_842_984


In [62]:
# Train for a few steps on the same batch.
for _ in range(15):
    state_pp, metrics_pp = train_step_pp_fn(state_pp, metrics_pp, batch)
# Run one more step starting from zeroed metrics so the printed values reflect that
# step only. `jax.tree_map` is deprecated; `jax.tree_util.tree_map` is the stable spelling.
final_metrics_pp = jax.tree_util.tree_map(
    lambda x: jnp.zeros(x.shape, dtype=x.dtype), metric_shapes
)
state_pp, final_metrics_pp = train_step_pp_fn(state_pp, final_metrics_pp, batch)
print_metrics(final_metrics_pp, title="Final Metrics - Pipeline")

## Pipeline Parallelism with Looping

In [65]:
import os
import urllib.request
from urllib.error import HTTPError

# Base URL of the GitHub directory that hosts the tutorial's helper scripts.
base_url = "https://raw.githubusercontent.com/phlippe/uvadlc_notebooks/master/docs/tutorial_notebooks/scaling/JAX/"
# Helper module imported by the following cell (`pipeline_parallel`).
python_files = ["pipeline_parallel.py"]
# Download each file into the working directory unless a copy already exists there.
for file_name in python_files:
    if not os.path.isfile(file_name):
        file_url = base_url + file_name
        print(f"Downloading {file_url}...")
        try:
            urllib.request.urlretrieve(file_url, file_name)
        except HTTPError as e:
            # Best effort: report the HTTP failure and continue with the next file.
            # Only HTTPError is handled here; any other exception propagates.
            print(
                "Something went wrong. Please try to download the file directly from the GitHub repository, or contact the author with the full output including the following error:\n",
                e,
            )

Downloading https://raw.githubusercontent.com/phlippe/uvadlc_notebooks/master/docs/tutorial_notebooks/scaling/JAX/pipeline_parallel.py...


In [71]:
from pipeline_parallel import (
    PPClassifier,
    get_default_pp_classifier_config,
    train_pipeline_model,
    train_step_pp,
)
from flax.struct import dataclass

In [73]:
@dataclass
class PipelineState:
    """State carried between the steps of the looping pipeline.

    Read and updated by `execute_looping_pipeline_step`. The `*_indices` arrays
    are indexed by the pipeline step; an entry of -1 in `input_indices`,
    `output_indices` or `update_indices` means "no such action in this step".
    """

    # Buffer indexed along axis 0: read as stage input and overwritten with
    # `last_state` when the step's update index is non-negative.
    inputs: jax.Array
    # Buffer indexed along axis 0: receives the module output when the step's
    # output index is non-negative.
    outputs: jax.Array
    # Per step: slot of `inputs` to feed to the module, or -1 to use `last_state`.
    input_indices: jax.Array
    # Per step: slot of `outputs` to write the module output to, or -1 for none.
    output_indices: jax.Array
    # Per step: slot of `inputs` to overwrite with `last_state`, or -1 for none.
    update_indices: jax.Array
    # Per step: index along the first axis of the parameters to use.
    params_indices: jax.Array
    # Output received from the neighbouring stage in the previous step.
    last_state: jax.Array
    # PyTree of RNG keys; every key is split once per step.
    rngs: PyTree

In [75]:
def execute_looping_pipeline_step(
    index: jax.Array | int,
    state: PipelineState,
    *args,
    module: nn.Module,
    params: PyTree,
    model_axis_name: str,
    **kwargs,
) -> PipelineState:
    """Single micro-batch pipeline step with loopback communication.

    Args:
        index: Pipeline step index (between 0 and num_loops * num_microbatches + num_stages - 2).
        state: State of the pipeline, including indices for controlling the execution.
        *args: Additional arguments to the module.
        module: Flax module representing the stage layer to execute.
        params: PyTree of parameters. The params for all layers should be stacked along the first axis.
        model_axis_name: Name of the model axis in the mesh/shard_map.
        **kwargs: Additional keyword arguments to the module.

    Returns:
        New state of the pipeline after the execution of the pipeline step, with potentially updated
        inputs, outputs, rngs, and last_state arrays.
    """
    num_stages = jax.lax.psum(1, model_axis_name)
    input_index = state.input_indices[index]
    output_index = state.output_indices[index]
    update_index = state.update_indices[index]
    params_index = state.params_indices[index]
    # Update inputs with last state. If update_index is -1, do not update.
    # This is used to buffer the communications back to first stage.
    # The index is clipped so the update stays in bounds; for -1 the slot is
    # rewritten with its own current content, i.e. left unchanged.
    clipped_update_index = jnp.clip(update_index, 0, state.inputs.shape[0] - 1)
    inputs = jax.lax.dynamic_update_index_in_dim(
        state.inputs,
        jnp.where(update_index >= 0, state.last_state, state.inputs[clipped_update_index]),
        clipped_update_index,
        axis=0,
    )
    # Select input of the current stage. For all stages except the first stage,
    # the input is the last output of the previous stage (i.e. last_state).
    step_input = jnp.where(
        input_index >= 0,
        inputs[input_index],
        state.last_state,
    )
    # Apply the module to the input. Select the right set of parameters based
    # on the loop index. Every RNG key is split in two: one half is carried on
    # to the next step, the other half is consumed by this step.
    # `jax.tree_map` is deprecated; `jax.tree_util.tree_map` is the stable spelling.
    rngs = jax.tree_util.tree_map(lambda rng: jax.random.split(rng, 2), state.rngs)
    rngs, step_rngs = (
        jax.tree_util.tree_map(lambda x: x[0], rngs),
        jax.tree_util.tree_map(lambda x: x[1], rngs),
    )
    params = jax.tree_util.tree_map(lambda x: x[params_index], params)
    output = module.apply(params, step_input, *args, **kwargs, rngs=step_rngs)
    # Update outputs with the output of the current stage. If output_index is -1,
    # do not update. This is used to buffer the final outputs of the last stage.
    clipped_output_index = jnp.clip(output_index, 0, state.outputs.shape[0] - 1)
    outputs = jax.lax.dynamic_update_index_in_dim(
        state.outputs,
        jnp.where(output_index >= 0, output, state.outputs[clipped_output_index]),
        clipped_output_index,
        axis=0,
    )
    # Communicate the last output to the next stage (the last stage wraps
    # around to the first one).
    last_state = jax.lax.ppermute(
        output,
        model_axis_name,
        perm=[(i, (i + 1) % num_stages) for i in range(num_stages)],
    )
    return state.replace(
        inputs=inputs,
        outputs=outputs,
        last_state=last_state,
        rngs=rngs,
    )

In [77]:
def prepare_looping_pipeline_indices(
    num_loops: int, num_microbatches: int, num_stages: int, stage_index: jax.Array | int
) -> Dict[str, jax.Array]:
    """Prepare indices for controlling the execution of the looping pipeline.

    Args:
        num_loops: Number of loops in the pipeline, or separate stage layers per device. num_loops=1 is equivalent to a non-looping pipeline.
        num_microbatches: Number of microbatches to split the batch into.
        num_stages: Number of stages/devices the pipeline is distributed over.
        stage_index: Index of the stage/device in the pipeline.

    Returns:
        Dictionary of indices for controlling the execution of the pipeline.
    """
    num_iterations = num_loops * num_microbatches + num_stages - 1
    index_array = -jnp.ones((num_iterations,), dtype=jnp.int32)
    # Only first stage uses inputs. Looping communications from last
    # stage are buffered in the inputs, so we repeatedly iterate over
    # the inputs.
    input_indices = jnp.where(
        stage_index == 0,
        index_array.at[: num_loops * num_microbatches].set(
            jnp.tile(jnp.arange(num_microbatches), reps=(num_loops,))
        ),
        index_array,
    )
    # For the first stage, identify input indices that we use to buffer
    # the communications from the last stage. For all other stages, we
    # use the last state from the previous stage as input.
    update_indices = jnp.where(
        stage_index == 0,
        index_array.at[num_stages : num_stages + (num_loops - 1) * num_microbatches].set(
            jnp.tile(jnp.arange(num_microbatches), reps=(num_loops - 1,))
        ),
        index_array,
    )
    # For the last stage, we use the outputs of the last loop as the
    # final outputs.
    output_indices = jnp.where(
        stage_index == num_stages - 1,
        index_array.at[-num_microbatches:].set(jnp.arange(num_microbatches)),
        index_array,
    )
    # For all stages, we iterate over the parameters of the different loops.
    # We use the 0-index for indices that fall into the pipeline bubble.
    params_indices = jnp.zeros_like(index_array)
    for i in range(num_loops):
        start_index = stage_index + i * num_microbatches
        params_indices = jax.lax.dynamic_update_slice_in_dim(
            params_indices,
            jnp.full(shape=(num_microbatches,), fill_value=i, dtype=params_indices.dtype),
            start_index,
            axis=0,
        )
    return {
        "input": input_indices,
        "output": output_indices,
        "update": update_indices,
        "params": params_indices,
    }

In [79]:
# Small example configuration used to visualize the index schedules.
num_stages = 3
num_loops = 2
num_microbatches = 4
# Print, for every stage, one row per schedule with one column per pipeline step.
for i in range(num_stages):
    indices = prepare_looping_pipeline_indices(
        num_loops=num_loops,
        num_microbatches=num_microbatches,
        num_stages=num_stages,
        stage_index=i,
    )
    # Header row with the step numbers, aligned with the index rows below it.
    s = ["step  : " + " ".join(f"{t:2d}" for t in range(len(indices["input"])))]
    for k, v in indices.items():
        s.append(f"{k:6s}: " + " ".join(f"{t:2d}" for t in v))
    # Centered title line, padded with "=" to the width of the longest row.
    max_len = max(map(len, s))
    s.insert(0, (f" Stage Index {i} ").center(max_len, "="))
    print("\n".join(s) + "\n")

step  :  0  1  2  3  4  5  6  7  8  9
input :  0  1  2  3  0  1  2  3 -1 -1
output: -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
update: -1 -1 -1  0  1  2  3 -1 -1 -1
params:  0  0  0  0  1  1  1  1  0  0

step  :  0  1  2  3  4  5  6  7  8  9
input : -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
output: -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
update: -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
params:  0  0  0  0  0  1  1  1  1  0

step  :  0  1  2  3  4  5  6  7  8  9
input : -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
output: -1 -1 -1 -1 -1 -1  0  1  2  3
update: -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
params:  0  0  0  0  0  0  1  1  1  1



In [81]:
@jax.named_scope("pipeline")
def execute_looping_pipeline(
    module: nn.Module,
    params: PyTree,
    x: jax.Array,
    rngs: PyTree,
    *args,
    num_loops: int,
    num_microbatches: int,
    model_axis_name: str,
    **kwargs,
):
    """Run a looping pipeline of stages over one batch of data.

    The pipeline stages are executed in parallel with a breadth-first strategy.

    Args:
        module: Flax module representing a single pipeline stage to execute.
        params: PyTree of parameters for the pipeline stages.
        x: Batch of input data, only needed on the device of the first stage.
            It is reshaped into ``num_microbatches`` micro-batches along axis 0.
        rngs: PyTree of random number generators for the pipeline stages.
        *args: Additional arguments to the module.
        num_loops: Number of loops in the pipeline, i.e. separate stage layers
            per device. ``num_loops=1`` is equivalent to a non-looping pipeline.
        num_microbatches: Number of micro-batches to split the batch into.
        model_axis_name: Name of the model axis in the mesh/shard_map.
        **kwargs: Additional keyword arguments to the module.

    Returns:
        Output of the last stage of the pipeline, with equivalent shape to the
        input ``x``. On devices that are not the last stage, the output is zeros.
    """
    # Summing the constant 1 over the model axis yields the number of stages.
    num_stages = jax.lax.psum(1, model_axis_name)
    assert num_stages > 1, "Pipeline must have at least 2 stages."
    batch_size = x.shape[0]
    assert (
        batch_size % num_microbatches == 0
    ), f"Batch size {batch_size} must be divisible by number of microbatches {num_microbatches}"

    # Split the batch along axis 0 into equally sized micro-batches.
    microbatches = jnp.reshape(
        x, (num_microbatches, batch_size // num_microbatches, *x.shape[1:])
    )

    # Per-step index schedule for the stage running on this device.
    schedule = prepare_looping_pipeline_indices(
        num_loops=num_loops,
        num_microbatches=num_microbatches,
        num_stages=num_stages,
        stage_index=jax.lax.axis_index(model_axis_name),
    )

    # Outputs and the carried activation both start out as zeros.
    initial_state = PipelineState(
        inputs=microbatches,
        outputs=jnp.zeros_like(microbatches),
        input_indices=schedule["input"],
        output_indices=schedule["output"],
        update_indices=schedule["update"],
        params_indices=schedule["params"],
        last_state=jnp.zeros_like(microbatches[0]),
        rngs=rngs,
    )

    # The pipeline is executed with a jax fori_loop, one iteration per entry
    # of the schedule. Alternatively, a scan could be used.
    step_fn = functools.partial(
        execute_looping_pipeline_step,
        *args,
        module=module,
        params=params,
        model_axis_name=model_axis_name,
        **kwargs,
    )
    final_state = jax.lax.fori_loop(
        0,
        schedule["input"].shape[0],
        body_fun=step_fn,
        init_val=initial_state,
    )

    # Merge the micro-batch axis back into the batch axis.
    stacked_outputs = final_state.outputs
    return jnp.reshape(stacked_outputs, (batch_size, *stacked_outputs.shape[2:]))

In [83]:
class LoopingPipelineModule(nn.Module):
    """Flax module that runs ``module_fn()`` as a looping pipeline.

    During initialization the wrapped module is scanned ``num_loops`` times so
    that a separate parameter set per loop is created, stacked along axis 0.
    During a regular forward pass, these stacked parameters are handed to
    ``execute_looping_pipeline``.
    """

    # Number of loops, i.e. parameter sets created for the wrapped module.
    num_loops: int
    # Name of the model axis, forwarded to ``execute_looping_pipeline``.
    model_axis_name: str
    # Number of micro-batches the input batch is split into.
    num_microbatches: int
    # Factory creating the module of a single pipeline stage.
    module_fn: Callable[..., nn.Module]

    @nn.compact
    def __call__(self, x: jax.Array, *args, **kwargs):
        """Initialize the per-loop parameters or execute the pipeline on ``x``.

        Args:
            x: Batch of input data.
            *args: Additional positional arguments forwarded to the stage module.
            **kwargs: Additional keyword arguments forwarded to the stage module.

        Returns:
            While initializing, the scanned module output on a strided
            sub-sample of ``x``, repeated ``num_microbatches`` times along
            axis 0. Otherwise, the result of ``execute_looping_pipeline``.
        """
        if self.is_initializing():
            # During initialization, we want to create a separate set of parameters
            # for each loop. We do this by scanning the module during init. Note that
            # we do not need to execute the pipeline, since we only need to create the
            # parameters.
            # Only every ``num_microbatches``-th sample is used as input here.
            sample_microbatch = x[:: self.num_microbatches]
            module = self.module_fn()
            scan_fn = nn.scan(
                lambda module, carry, _: (module(carry, *args, **kwargs), None),
                variable_axes={"params": 0},
                split_rngs={"params": True, "dropout": True},
                length=self.num_loops,
            )
            out, _ = scan_fn(module, sample_microbatch, ())
            # Repeat the rows to undo the sub-sampling of the batch axis
            # (assumes the batch size is divisible by num_microbatches).
            return jnp.repeat(out, self.num_microbatches, axis=0)
        else:
            # During the forward pass, we extract the initialized parameters for
            # all loops. In the pipeline, we then sub-index the parameters based on
            # the loop index.
            module = self.module_fn()
            params = module.variables
            # Since we make use of a non-flax transformation, we need to pass the
            # RNGs explicitly to the pipeline.
            rngs = {name: self.make_rng(name) for name in self.scope.rngs}
            # The four leading arguments are passed positionally: unpacking *args
            # after ``module=...`` keywords would bind args[0] to ``module`` a
            # second time and raise a TypeError for any non-empty ``args``.
            return execute_looping_pipeline(
                module,
                params,
                x,
                rngs,
                *args,
                num_loops=self.num_loops,
                num_microbatches=self.num_microbatches,
                model_axis_name=self.model_axis_name,
                **kwargs,
            )

In [85]:
def get_looping_classifier(config: ConfigDict) -> nn.Module:
    """Build a ``PPClassifier`` whose pipeline module is a ``LoopingPipelineModule``.

    The layer count of ``config`` is used as the number of loops, while the
    classifier itself receives a resolved copy of ``config`` with
    ``num_layers`` set to 1.

    Args:
        config: Model configuration; ``config.num_layers`` is read as the
            number of loops.

    Returns:
        The configured ``PPClassifier`` instance.
    """
    pipeline_cls = functools.partial(LoopingPipelineModule, num_loops=config.num_layers)
    # Work on a copy so that the caller's config keeps its original layer count.
    stage_config = config.copy_and_resolve_references()
    stage_config.num_layers = 1
    return PPClassifier(config=stage_config, pipeline_module_class=pipeline_cls)


# Build the looping pipeline classifier from the default pipeline-parallel
# config, and an AdamW optimizer with the configured learning rate.
config = get_default_pp_classifier_config()
model_lpp = get_looping_classifier(config.model)
optimizer = optax.adamw(learning_rate=config.optimizer.learning_rate)

In [89]:
# Arrange all devices in a 2D (data, model) mesh; the model axis has
# ``config.model_axis_size`` devices and the data axis takes the rest.
device_array = np.array(jax.devices()).reshape(-1, config.model_axis_size)
mesh = Mesh(device_array, (config.data_axis_name, config.model_axis_name))

# Independent RNG keys for the model init, the dummy inputs and the dummy labels.
rng = jax.random.PRNGKey(config.seed)
model_init_rng, data_inputs_rng, data_labels_rng = jax.random.split(rng, 3)

# Random batch: normally distributed inputs and uniformly sampled integer labels.
batch = Batch(
    inputs=jax.random.normal(
        data_inputs_rng,
        shape=(config.data.batch_size, config.data.input_size),
    ),
    labels=jax.random.randint(
        data_labels_rng,
        shape=(config.data.batch_size,),
        minval=0,
        maxval=config.data.num_classes,
    ),
)

In [91]:
def init_fn(rng: jax.Array, x: jax.Array, model: nn.Module) -> TrainState:
    """Initialize the model parameters and wrap them in a ``TrainState``.

    Args:
        rng: PRNG key. It is split once: one half initializes the parameters,
            the other half is stored in the returned state.
        x: Example input passed to ``model.init`` (called with ``train=False``).
        model: Flax module to initialize.

    Returns:
        A ``TrainState`` created from ``model.apply``, the initialized
        parameters, the module-level ``optimizer`` and the remaining PRNG key.
    """
    init_rng, rng = jax.random.split(rng)
    variables = model.init({"params": init_rng}, x, train=False)
    # Index instead of ``pop``: ``FrozenDict.pop`` returns a (rest, value)
    # tuple, whereas indexing yields the parameters for both plain dicts and
    # FrozenDicts.
    params = variables["params"]
    return TrainState.create(
        apply_fn=model.apply,
        params=params,
        tx=optimizer,
        rng=rng,
    )

In [93]:
# First pass: wrap the init function in shard_map with fully replicated
# outputs (out_specs=P()); it is only evaluated abstractly below.
init_lpp_fn = shard_map(
    functools.partial(init_fn, model=model_lpp),
    mesh=mesh,
    in_specs=(P(), P(config.data_axis_name)),
    out_specs=P(),
    check_rep=False,
)
# eval_shape traces the initialization without running it; the partition
# specs are then read from the resulting abstract train state.
state_lpp_shapes = jax.eval_shape(init_lpp_fn, model_init_rng, batch.inputs)
state_lpp_specs = nn.get_partition_spec(state_lpp_shapes)
pprint(state_lpp_specs.params)

{'input_dense': {'sharded': {'bias': PartitionSpec('model', None),
                             'kernel': PartitionSpec('model', None, None)}},
 'output_dense': {'sharded': {'bias': PartitionSpec('model', None),
                              'kernel': PartitionSpec('model', None, None)}},
 'output_norm': {'sharded': {'bias': PartitionSpec('model', None),
                             'scale': PartitionSpec('model', None)}},
 'pipeline': {'sharded': {'mlp_layers': {'block': {'input_dense': {'bias': PartitionSpec('model', None, None, None),
                                                                   'kernel': PartitionSpec('model', None, None, None, None)},
                                                   'output_dense': {'bias': PartitionSpec('model', None, None, None),
                                                                    'kernel': PartitionSpec('model', None, None, None, None)},
                                                   'pre_norm': {'bias': PartitionSpec

In [95]:
# Second pass: run the initialization under jit, now with the partition specs
# determined above as output specs of the shard_map.
init_lpp_fn = jax.jit(
    shard_map(
        functools.partial(init_fn, model=model_lpp),
        mesh,
        in_specs=(P(), P(config.data_axis_name)),
        out_specs=state_lpp_specs,
        check_rep=False,
    ),
)
state_lpp = init_lpp_fn(model_init_rng, batch.inputs)

# Print the parameter shapes of the pipeline block.
# ``jax.tree_util.tree_map`` is used because ``jax.tree_map`` is deprecated.
pprint(
    jax.tree_util.tree_map(
        lambda x: x.shape,
        state_lpp.params["pipeline"]["sharded"]["mlp_layers"]["block"],
    )
)

{'input_dense': {'bias': Partitioned(value=(4, 2, 1, 512),
                                     names=('model', None, None, None),
                                     mesh=None),
                 'kernel': Partitioned(value=(4, 2, 1, 512, 512),
                                       names=('model', None, None, None, None),
                                       mesh=None)},
 'output_dense': {'bias': Partitioned(value=(4, 2, 1, 512),
                                      names=('model', None, None, None),
                                      mesh=None),
                  'kernel': Partitioned(value=(4, 2, 1, 512, 512),
                                        names=('model', None, None, None, None),
                                        mesh=None)},
 'pre_norm': {'bias': Partitioned(value=(4, 2, 1, 512),
                                  names=('model', None, None, None),
                                  mesh=None),
              'scale': Partitioned(value=(4, 2, 1, 512),
          

  jax.tree_map(lambda x: x.shape, state_lpp.params["pipeline"]["sharded"]["mlp_layers"]["block"])


In [97]:
# Jitted, sharded train step: the state uses the partition specs computed at
# init, the metrics are replicated, and the batch is split over the data axis.
train_step_lpp_fn = jax.jit(
    shard_map(
        functools.partial(train_step_pp, config=config),
        mesh,
        in_specs=(state_lpp_specs, P(), P(config.data_axis_name)),
        out_specs=(state_lpp_specs, P()),
        check_rep=False,
    ),
    donate_argnames=("state", "metrics"),
)
# Trace the step once with ``None`` as metrics to obtain the metric shapes.
state_shapes, metric_shapes = jax.eval_shape(
    train_step_lpp_fn,
    state_lpp,
    None,
    batch,
)
# Start from zero-filled metrics of those shapes and run a first train step.
# ``jax.tree_util.tree_map`` is used because ``jax.tree_map`` is deprecated.
metrics_lpp = jax.tree_util.tree_map(
    lambda x: jnp.zeros(x.shape, dtype=x.dtype), metric_shapes
)
state_lpp, metrics_lpp = train_step_lpp_fn(state_lpp, metrics_lpp, batch)

  rngs = jax.tree_map(lambda rng: jax.random.split(rng, 2), state.rngs)
  rngs, step_rngs = jax.tree_map(lambda x: x[0], rngs), jax.tree_map(lambda x: x[1], rngs)
  params = jax.tree_map(lambda x: x[params_index], params)
  metrics_lpp = jax.tree_map(lambda x: jnp.zeros(x.shape, dtype=x.dtype), metric_shapes)
  rngs = jax.tree_map(lambda rng: jax.random.split(rng, 2), state.rngs)
  rngs, step_rngs = jax.tree_map(lambda x: x[0], rngs), jax.tree_map(lambda x: x[1], rngs)
  params = jax.tree_map(lambda x: x[params_index], params)


In [99]:
# Report the parameter count of the looping pipeline state, with "_" as
# thousands separator.
num_params_lpp = get_num_params(state_lpp)
print(f"Number of parameters: {num_params_lpp:_}")

Number of parameters: 5_842_984


In [101]:
# Run 15 train steps on the same batch, accumulating metrics.
for _ in range(15):
    state_lpp, metrics_lpp = train_step_lpp_fn(state_lpp, metrics_lpp, batch)
# Take one more step starting from zeroed metrics, so that the printed
# metrics come from this final step only.
# ``jax.tree_util.tree_map`` is used because ``jax.tree_map`` is deprecated.
final_metrics_lpp = jax.tree_util.tree_map(
    lambda x: jnp.zeros(x.shape, dtype=x.dtype), metric_shapes
)
state_lpp, final_metrics_lpp = train_step_lpp_fn(state_lpp, final_metrics_lpp, batch)
print_metrics(final_metrics_lpp, title="Final Metrics - Looping Pipeline")