# Code adapted from 
- https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/scaling/JAX/data_parallel_intro.html
- http://d2l.ai/chapter_computational-performance/multiple-gpus-concise.html

In [3]:
import os

# Set this to True to run the model on CPU only.
USE_CPU_ONLY = True

# Collect the flags as individual tokens and join them with single spaces at the end.
# This guarantees that flags already present in XLA_FLAGS are never fused with the
# first flag appended here, and that no stray leading/trailing whitespace is written.
_xla_flags = os.environ.get("XLA_FLAGS", "").split()
if USE_CPU_ONLY:
    _xla_flags.append("--xla_force_host_platform_device_count=8")  # Simulate 8 devices
    # Enforce CPU-only execution
    os.environ["CUDA_VISIBLE_DEVICES"] = ""
else:
    # GPU flags
    _xla_flags.extend(
        [
            "--xla_gpu_enable_triton_softmax_fusion=true",
            "--xla_gpu_triton_gemm_any=false",
            "--xla_gpu_enable_async_collectives=true",
            "--xla_gpu_enable_latency_hiding_scheduler=true",
            "--xla_gpu_enable_highest_priority_async_stream=true",
        ]
    )
flags = " ".join(_xla_flags)
os.environ["XLA_FLAGS"] = flags

In [5]:
import functools
from typing import Any, Dict, Tuple

import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P

# Alias used to annotate arbitrarily nested containers (pytrees); deliberately untyped.
PyTree = Any
# Maps a metric name to a tuple of arrays.
Metrics = Dict[str, Tuple[jax.Array, ...]]

In [12]:
# Build a 1-D device mesh over all available devices, with a single axis named "i".
a = jnp.arange(8)
mesh = Mesh(np.array(jax.devices()), ("i",))
# Partition the array's only dimension across mesh axis "i".
sharding = NamedSharding(
    mesh,
    P("i"),
)
a_sharded = jax.device_put(a, sharding)
# Show which device holds which slice of the array.
jax.debug.visualize_array_sharding(a_sharded)

The history saving thread hit an unexpected error (OperationalError('attempt to write a readonly database')).History will not be written to the database.


In [14]:
# Apply an element-wise op to the sharded array and inspect how the result is laid out.
out = nn.tanh(a_sharded)
jax.debug.visualize_array_sharding(out)

In [16]:
# We can shard the batch dimension of the input x over the "i" axis, and the output
# dimension of the weight matrix w and bias b over the "j" axis.
# The reshape to (4, 2) requires exactly 8 devices.
mesh = Mesh(np.array(jax.devices()).reshape(4, 2), ("i", "j"))
batch_size = 192
input_dim = 64
output_dim = 128
x = jax.random.normal(jax.random.PRNGKey(0), (batch_size, input_dim))
w = jax.random.normal(jax.random.PRNGKey(1), (input_dim, output_dim))
b = jax.random.normal(jax.random.PRNGKey(2), (output_dim,))
# Place each operand on the mesh with its own partitioning.
x_sharded = jax.device_put(x, NamedSharding(mesh, P("i", None)))
w_sharded = jax.device_put(w, NamedSharding(mesh, P(None, "j")))
b_sharded = jax.device_put(b, NamedSharding(mesh, P("j")))
# Plain jnp code on the sharded operands; the result's layout is visualized below.
out = jnp.dot(x_sharded, w_sharded) + b_sharded
print("Output shape", out.shape)
jax.debug.visualize_array_sharding(out)

Output shape (192, 128)


In [22]:
# The transformation shard_map has been developed as an alternative to jax.pmap, which gives us more explicit control over the parallelization and communication

def matmul_fn(x: jax.Array, w: jax.Array, b: jax.Array) -> jax.Array:
    """Compute the affine map ``x @ w + b`` and print the shapes it receives.

    The prints make the per-device (local) shapes visible when the function is
    run under ``shard_map``.

    Args:
        x: Input matrix.
        w: Weight matrix.
        b: Bias added to the product.

    Returns:
        The result of ``jnp.dot(x, w) + b``.
    """
    for label, operand in (("x", x), ("w", w), ("b", b)):
        print(f"Local {label} shape", operand.shape)
    product = jnp.dot(x, w)
    return product + b

# Wrap matmul_fn so that every device runs it on its local shards: x is split over "i"
# along its first dimension, w and b over "j" along the output dimension, and the
# per-device results are assembled into an output partitioned over both mesh axes.
matmul_sharded = shard_map(
    matmul_fn, mesh, in_specs=(P("i", None), P(None, "j"), P("j")), out_specs=P("i", "j")
)

y = matmul_sharded(x_sharded, w_sharded, b_sharded)
print("Output shape", y.shape)
jax.debug.visualize_array_sharding(y)

Local x shape (48, 64)
Local w shape (64, 64)
Local b shape (64,)
Output shape (192, 128)


## Parallelizing operations, gathering and scattering

In [26]:
@functools.partial(shard_map, mesh=mesh, in_specs=P("i", "j"), out_specs=P("i", "j"))
def parallel_normalize(x: jax.Array) -> jax.Array:
    """Normalize the local shard using statistics averaged over mesh axis "j".

    Both the mean and the mean squared deviation are obtained with ``pmean`` over
    "j", i.e. they are averages over the devices along that axis.

    Args:
        x: The local shard of the input (partitioned over "i" and "j").

    Returns:
        The local shard with the cross-device mean subtracted and divided by the
        cross-device standard deviation.
    """
    centered = x - jax.lax.pmean(x, axis_name="j")
    variance = jax.lax.pmean(centered**2, axis_name="j")
    return centered / variance**0.5

In [28]:
# Run the normalization on the full array and check the statistics of the result on host.
out = parallel_normalize(x)
out = jax.device_get(out)
print("Mean", out.mean())
print("Std", out.std())

Mean -4.149236e-08
Std 1.0


In [30]:
@functools.partial(
    shard_map, mesh=mesh, in_specs=(P("i", None), P("i", None)), out_specs=P("i", None)
)
def matmul_with_weight_gather(x: jax.Array, w: jax.Array) -> jax.Array:
    """Multiply the local inputs with the full weight matrix gathered over axis "i".

    Both arguments arrive partitioned over "i" along their first dimension. The weight
    shards are concatenated back together along that dimension before the matmul.

    Args:
        x: Local shard of the inputs.
        w: Local shard of the weight matrix.

    Returns:
        The product of the local inputs and the gathered weight matrix.
    """
    print("Original w shape", w.shape)
    full_w = jax.lax.all_gather(w, axis_name="i", axis=0, tiled=True)
    print("Gathered w shape", full_w.shape)
    return jnp.dot(x, full_w)


# The sharded computation with gathered weights must match the plain matmul exactly.
out = matmul_with_weight_gather(x, w)
out = jax.device_get(out)
np.testing.assert_array_equal(out, jnp.dot(x, w))

Original w shape (16, 128)
Gathered w shape (64, 128)


In [32]:
@functools.partial(shard_map, mesh=mesh, in_specs=P("i", None), out_specs=P("i", None))
def scatter_example(x: jax.Array) -> jax.Array:
    """Sum the shards over mesh axis "i" and scatter the sum along dimension 1.

    Args:
        x: Local shard of the input, partitioned over "i" along its first dimension.

    Returns:
        This device's part of the summed array, as produced by ``psum_scatter``.
    """
    return jax.lax.psum_scatter(x, axis_name="i", scatter_dimension=1)


# Small 4x4 example input; scatter_example partitions it over "i" along the first
# dimension (in_specs=P("i", None)).
x_exmp = np.array(
    [
        [3, 1, 4, 1],
        [5, 9, 2, 6],
        [5, 3, 5, 8],
        [9, 7, 1, 2],
    ]
)
out = scatter_example(x_exmp)
print("Output", out)

Output [22 20 12 17]


In [34]:
@functools.partial(shard_map, mesh=mesh, in_specs=P("i"), out_specs=P("i"))
def ppermute_example(x: jax.Array) -> jax.Array:
    """Pass each device's shard to its successor along mesh axis "i" (cyclically).

    Args:
        x: Local shard of the input, partitioned over "i".

    Returns:
        The shard received from the preceding device on the "i" axis.
    """
    num_devices = mesh.shape["i"]
    # (source, destination) pairs: device k sends to device k + 1, wrapping around.
    ring = [(src, (src + 1) % num_devices) for src in range(num_devices)]
    return jax.lax.ppermute(x, axis_name="i", perm=ring)


# Four elements, partitioned over "i" by ppermute_example (in_specs=P("i")).
x_exmp = np.arange(4)
out = ppermute_example(x_exmp)
print("Output", out)

Output [3 0 1 2]


## Axis Indexing

In [36]:
def _mesh_coordinates() -> jax.Array:
    """Return this device's ("i", "j") mesh indices as an array of shape (1, 2)."""
    coords = jnp.stack(
        [
            jax.lax.axis_index("i"),  # Device index in mesh along the "i" axis
            jax.lax.axis_index("j"),  # Device index in mesh along the "j" axis
        ],
        axis=-1,
    )
    # Leading dimension of size 1, so the per-device rows are concatenated in the output.
    return coords[None]


# The output's first dimension is partitioned jointly over both mesh axes,
# giving one row per device.
axis_idx_fn = jax.jit(
    shard_map(
        _mesh_coordinates,
        mesh,
        in_specs=P(),
        out_specs=P(("i", "j")),
    )
)
out = jax.device_get(axis_idx_fn())
for i, (i_index, j_index) in enumerate(out):
    print(f"Device {i}: i-axis={i_index}, j-axis={j_index}")

Device 0: i-axis=0, j-axis=0
Device 1: i-axis=0, j-axis=1
Device 2: i-axis=1, j-axis=0
Device 3: i-axis=1, j-axis=1
Device 4: i-axis=2, j-axis=0
Device 5: i-axis=2, j-axis=1
Device 6: i-axis=3, j-axis=0
Device 7: i-axis=3, j-axis=1


In [38]:
def fold_rng_over_axis(rng: jax.random.PRNGKey, axis_name: str) -> jax.random.PRNGKey:
    """Derive a device-specific RNG key by folding in the index along an axis.

    Devices with different indices along ``axis_name`` obtain different keys, while
    devices sharing the same index obtain the same key. Must be called where
    ``axis_name`` is bound (e.g. inside ``shard_map``).

    Args:
        rng: The RNG key to start from.
        axis_name: Name of the axis whose device index is folded into the key.

    Returns:
        The key obtained from ``jax.random.fold_in(rng, <index along axis_name>)``.
    """
    return jax.random.fold_in(rng, jax.lax.axis_index(axis_name))

In [40]:
# Fold the key over the "i" axis only: devices that differ only in their "j" index
# are expected to end up with identical keys.
_fold_over_i = functools.partial(fold_rng_over_axis, axis_name="i")
fold_fn = jax.jit(
    shard_map(
        _fold_over_i,
        mesh,
        in_specs=P(),
        out_specs=P(("i", "j")),
    )
)
rng = jax.random.PRNGKey(0)
out = jax.device_get(fold_fn(rng))
# The per-device keys are concatenated; consecutive pairs of values are printed per device.
for i in range(out.shape[0] // 2):
    print(f"Device {i}: RNG={out[2*i:2*i+2]}")

Device 0: RNG=[1797259609 2579123966]
Device 1: RNG=[1797259609 2579123966]
Device 2: RNG=[ 928981903 3453687069]
Device 3: RNG=[ 928981903 3453687069]
Device 4: RNG=[4146024105 2718843009]
Device 5: RNG=[4146024105 2718843009]
Device 6: RNG=[2467461003 3840466878]
Device 7: RNG=[2467461003 3840466878]


# Data Parallelism

## Imports

In [44]:
import os
import urllib.request
from urllib.error import HTTPError

# Github URL where python scripts are stored.
base_url = "https://raw.githubusercontent.com/phlippe/uvadlc_notebooks/master/docs/tutorial_notebooks/scaling/JAX/"
# Files to download.
python_files = ["single_gpu.py", "utils.py"]

# URLError is the base class of HTTPError, so catching it also covers failures that
# happen before any HTTP response exists (DNS errors, refused connections, no network).
from urllib.error import URLError

# For each file, check whether it already exists. If not, try downloading it.
for file_name in python_files:
    if os.path.isfile(file_name):
        continue
    file_url = base_url + file_name
    print(f"Downloading {file_url}...")
    try:
        urllib.request.urlretrieve(file_url, file_name)
    except URLError as e:
        print(
            "Something went wrong. Please try to download the file directly from the GitHub repository, or contact the author with the full output including the following error:\n",
            e,
        )

In [46]:
from utils import simulate_CPU_devices

# NOTE(review): presumably configures XLA to expose several simulated CPU devices.
# The helper lives in the downloaded utils.py — confirm its behaviour there.
simulate_CPU_devices()

In [48]:
import functools
from pprint import pprint
from typing import Any, Callable, Dict, Sequence, Tuple

import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
import optax
from absl import logging
from jax import lax
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh
from jax.sharding import PartitionSpec as P
from ml_collections import ConfigDict

# Alias used to annotate arbitrarily nested containers (pytrees); deliberately untyped.
PyTree = Any
# Maps a metric name to a tuple of arrays.
Metrics = Dict[str, Tuple[jax.Array, ...]]

In [50]:
from single_gpu import Batch, TrainState, accumulate_gradients, print_metrics

In [52]:
def fold_rng_over_axis(rng: jax.random.PRNGKey, axis_name: str) -> jax.random.PRNGKey:
    """Make an RNG key depend on the device's position along a named axis.

    The device index along ``axis_name`` is folded into ``rng``, so devices at
    different positions on that axis (e.g. the model axis) draw different random
    numbers. Must be called where ``axis_name`` is bound (e.g. inside ``shard_map``).

    Args:
        rng: The RNG key to start from.
        axis_name: The axis name to fold the key over.

    Returns:
        A new RNG key, different for each device index along the axis.
    """
    device_position = jax.lax.axis_index(axis_name)
    folded_rng = jax.random.fold_in(rng, device_position)
    return folded_rng

## DP

Non-JAX (PyTorch): we essentially just wrap the network in `nn.DataParallel`, which splits each batch across the GPUs and sums up the gradients for us.

In [None]:
import torch
from torch import nn
from d2l import torch as d2l

def train(net, num_gpus, batch_size, lr):
    """Train ``net`` on Fashion-MNIST with ``nn.DataParallel`` and report test accuracy.

    Args:
        net: The network to train; its Linear/Conv2d weights are re-initialized.
        num_gpus: Number of devices requested through ``d2l.try_gpu``.
        batch_size: Batch size passed to the d2l Fashion-MNIST loader.
        lr: Learning rate for SGD.
    """
    train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
    devices = [d2l.try_gpu(gpu_idx) for gpu_idx in range(num_gpus)]

    def init_weights(module):
        # Exact type match on purpose: subclasses are left untouched.
        if type(module) in (nn.Linear, nn.Conv2d):
            nn.init.normal_(module.weight, std=0.01)

    net.apply(init_weights)
    # Set the model on multiple GPUs
    net = nn.DataParallel(net, device_ids=devices)
    sgd = torch.optim.SGD(net.parameters(), lr)
    criterion = nn.CrossEntropyLoss()
    timer = d2l.Timer()
    num_epochs = 10
    animator = d2l.Animator('epoch', 'test acc', xlim=[1, num_epochs])
    for epoch in range(num_epochs):
        net.train()
        timer.start()
        for X, y in train_iter:
            sgd.zero_grad()
            # Inputs go to the first device before the forward pass.
            X = X.to(devices[0])
            y = y.to(devices[0])
            batch_loss = criterion(net(X), y)
            batch_loss.backward()
            sgd.step()
        timer.stop()
        test_acc = d2l.evaluate_accuracy_gpu(net, test_iter)
        animator.add(epoch + 1, (test_acc,))
    print(f'test acc: {animator.Y[0][-1]:.2f}, {timer.avg():.1f} sec/epoch '
          f'on {str(devices)}')

JAX TL;DR:
- `shard_map` allows us to write the model code as though it operates on a single device, with the following exceptions:
    - Wrap the initialization and training step functions with `shard_map`
    - Fold the RNG key over the data axis so that each device draws a different Dropout mask
- We use `jax.lax.pmean` to average gradients and `jax.lax.psum` to sum metrics (loss, accuracy) across devices

In [71]:
class DPClassifier(nn.Module):
    """Two-layer MLP classifier: Dense -> SiLU -> Dropout -> Dense.

    The configuration provides ``hidden_size``, ``num_classes``, ``dropout_rate``
    and the computation ``dtype`` of the dense layers.
    """

    config: ConfigDict

    @nn.compact
    def __call__(self, x: jax.Array, train: bool) -> jax.Array:
        """Compute logits for ``x``; dropout is only active when ``train`` is True.

        Returns:
            The logits, cast to float32.
        """
        cfg = self.config
        hidden = nn.Dense(
            features=cfg.hidden_size,
            dtype=cfg.dtype,
            name="input_dense",
        )(x)
        hidden = nn.silu(hidden)
        hidden = nn.Dropout(rate=cfg.dropout_rate, deterministic=not train)(hidden)
        logits = nn.Dense(
            features=cfg.num_classes,
            dtype=cfg.dtype,
            name="output_dense",
        )(hidden)
        # Return the logits in float32 regardless of the layers' computation dtype.
        return logits.astype(jnp.float32)

In [73]:
# Shape of the (synthetic) classification data.
data_config = ConfigDict(
    {
        "batch_size": 128,
        "num_classes": 10,
        "input_size": 784,
    }
)
# Model hyperparameters; the number of classes follows the data configuration.
model_config = ConfigDict(
    {
        "hidden_size": 512,
        "dropout_rate": 0.1,
        "dtype": jnp.bfloat16,
        "num_classes": data_config.num_classes,
        "data_axis_name": "data",
    }
)
optimizer_config = ConfigDict(
    {
        "learning_rate": 1e-3,
        "num_minibatches": 4,
    }
)
# Top-level configuration bundling the three parts above.
config = ConfigDict(
    {
        "model": model_config,
        "optimizer": optimizer_config,
        "data": data_config,
        "data_axis_name": model_config.data_axis_name,
        "seed": 42,
    }
)

In [75]:
# Instantiate the data-parallel model and the AdamW optimizer from the configuration.
model_dp = DPClassifier(config=config.model)
optimizer = optax.adamw(
    learning_rate=config.optimizer.learning_rate,
)

In [77]:
# One key each for model initialization, the synthetic inputs and the synthetic labels.
rng = jax.random.PRNGKey(config.seed)
model_init_rng, data_inputs_rng, data_labels_rng = jax.random.split(rng, 3)
# Random inputs and integer labels in [0, num_classes) serve as a stand-in dataset.
_batch_inputs = jax.random.normal(
    data_inputs_rng, (config.data.batch_size, config.data.input_size)
)
_batch_labels = jax.random.randint(
    data_labels_rng, (config.data.batch_size,), 0, config.data.num_classes
)
batch = Batch(inputs=_batch_inputs, labels=_batch_labels)

In [78]:
def init_dp(rng: jax.random.PRNGKey, x: jax.Array, model: nn.Module) -> TrainState:
    """Initialize the model parameters and wrap them in a TrainState.

    Args:
        rng: RNG key; split into one key for parameter initialization and one that
            is stored in the returned state.
        x: Example input used to initialize the model.
        model: The module to initialize.

    Returns:
        A TrainState holding the parameters, the module-level ``optimizer`` and the
        remaining RNG key.
    """
    init_rng, rng = jax.random.split(rng)
    variables = model.init({"params": init_rng}, x, train=False)
    # Index instead of pop(): on a Flax FrozenDict, pop() returns a (rest, value) tuple
    # rather than the value, whereas indexing works for both dict and FrozenDict.
    params = variables["params"]
    state = TrainState.create(
        apply_fn=model.apply,
        params=params,
        tx=optimizer,
        rng=rng,
    )
    return state

In [81]:
# 1-D mesh over all devices; its single axis is the data-parallel axis.
device_array = np.array(jax.devices())
mesh = Mesh(device_array, (config.data_axis_name,))

In [83]:
# Initialization under shard_map: the RNG key is replicated (P()), the example inputs
# are split over the data axis, and the resulting state is declared replicated (P()).
init_dp_fn = jax.jit(
    shard_map(
        functools.partial(init_dp, model=model_dp),
        mesh,
        in_specs=(P(), P(config.data_axis_name)),
        out_specs=P(),
        check_rep=False,
    ),
)

In [85]:
state_dp = init_dp_fn(model_init_rng, batch.inputs)
print("DP Parameters")
# jax.tree_util.tree_map replaces the deprecated jax.tree_map alias.
pprint(jax.tree_util.tree_map(lambda x: (x.shape, x.sharding), state_dp.params))

DP Parameters
{'input_dense': {'bias': ((512,),
                          NamedSharding(mesh=Mesh('data': 8), spec=PartitionSpec())),
                 'kernel': ((784, 512),
                            NamedSharding(mesh=Mesh('data': 8), spec=PartitionSpec()))},
 'output_dense': {'bias': ((10,),
                           NamedSharding(mesh=Mesh('data': 8), spec=PartitionSpec())),
                  'kernel': ((512, 10),
                             NamedSharding(mesh=Mesh('data': 8), spec=PartitionSpec()))}}


  pprint(jax.tree_map(lambda x: (x.shape, x.sharding), state_dp.params))


In [87]:
def loss_fn(
    params: PyTree, apply_fn: Any, batch: Batch, rng: jax.Array
) -> Tuple[jax.Array, Dict[str, Any]]:
    """Compute the mean cross-entropy loss and the step metrics for one batch.

    Args:
        params: Model parameters.
        apply_fn: The model's apply function.
        batch: Batch providing ``inputs`` and integer ``labels``.
        rng: RNG key for dropout; folded over the data axis before use.

    Returns:
        The mean loss, and a dict mapping "loss" and "accuracy" to
        ``(sum over the batch, batch size)`` pairs.
    """
    # Dropout masks vary across the batch dimension, so every device should draw its own
    # mask. Folding the key over the data axis gives each device a different key.
    device_rng = fold_rng_over_axis(rng, config.data_axis_name)
    # From here on the computation is the same as on a single device.
    logits = apply_fn({"params": params}, batch.inputs, train=True, rngs={"dropout": device_rng})
    per_example_loss = optax.softmax_cross_entropy_with_integer_labels(logits, batch.labels)
    predictions = jnp.argmax(logits, axis=-1)
    num_correct = jnp.equal(predictions, batch.labels).sum()
    num_examples = batch.inputs.shape[0]
    step_metrics = {
        "loss": (per_example_loss.sum(), num_examples),
        "accuracy": (num_correct, num_examples),
    }
    return per_example_loss.mean(), step_metrics

In [89]:
def train_step_dp(
    state: TrainState,
    metrics: Metrics | None,
    batch: Batch,
) -> Tuple[TrainState, Metrics]:
    """Run one data-parallel training step.

    Args:
        state: Current training state (parameters, optimizer state, RNG key).
        metrics: Metrics accumulated so far, or None to start from this step's metrics.
        batch: The local batch of this device.

    Returns:
        The updated training state and the accumulated metrics.
    """
    rng, step_rng = jax.random.split(state.rng)
    grads, step_metrics = accumulate_gradients(
        state,
        batch,
        step_rng,
        config.optimizer.num_minibatches,
        loss_fn=loss_fn,
    )
    # Update parameters. We need to sync the gradients across devices before updating.
    # jax.tree_util.tree_map replaces the deprecated jax.tree_map alias.
    with jax.named_scope("sync_gradients"):
        grads = jax.tree_util.tree_map(
            lambda g: jax.lax.pmean(g, axis_name=config.data_axis_name), grads
        )
    new_state = state.apply_gradients(grads=grads, rng=rng)
    # Sum metrics across replicas. Alternatively, we could keep the metrics separate
    # and only synchronize them before logging. For simplicity, we sum them here.
    with jax.named_scope("sync_metrics"):
        step_metrics = jax.tree_util.tree_map(
            lambda x: jax.lax.psum(x, axis_name=config.data_axis_name), step_metrics
        )
    if metrics is None:
        metrics = step_metrics
    else:
        metrics = jax.tree_util.tree_map(jnp.add, metrics, step_metrics)
    return new_state, metrics

In [91]:
# Training step under shard_map: state and metrics are replicated (P()), the batch is
# split over the data axis. The state and metrics arguments are donated to jit.
train_step_dp_fn = jax.jit(
    shard_map(
        train_step_dp,
        mesh,
        in_specs=(P(), P(), P(config.data_axis_name)),
        out_specs=(P(), P()),
        check_rep=False,
    ),
    donate_argnames=("state", "metrics"),
)

In [93]:
# Evaluate the training step abstractly (shapes and dtypes only) to learn the structure
# of the metrics, then create matching zero-initialized accumulators.
_, metric_shapes = jax.eval_shape(
    train_step_dp_fn,
    state_dp,
    None,
    batch,
)
# jax.tree_util.tree_map replaces the deprecated jax.tree_map alias.
metrics_dp = jax.tree_util.tree_map(
    lambda x: jnp.zeros(x.shape, dtype=x.dtype), metric_shapes
)

  grads = jax.tree_map(lambda g: jax.lax.pmean(g, axis_name=config.data_axis_name), grads)
  step_metrics = jax.tree_map(
  metrics_dp = jax.tree_map(lambda x: jnp.zeros(x.shape, dtype=x.dtype), metric_shapes)


In [95]:
# Train for 15 steps on the same batch, accumulating metrics along the way.
for _ in range(15):
    state_dp, metrics_dp = train_step_dp_fn(state_dp, metrics_dp, batch)
# Run one more step with fresh accumulators so the printed metrics cover only that step.
# jax.tree_util.tree_map replaces the deprecated jax.tree_map alias.
final_metrics_dp = jax.tree_util.tree_map(
    lambda x: jnp.zeros(x.shape, dtype=x.dtype), metric_shapes
)
state_dp, final_metrics_dp = train_step_dp_fn(state_dp, final_metrics_dp, batch)
print_metrics(final_metrics_dp)

In [99]:
# Inspect shape and sharding of the trained parameters and of the final metrics.
# jax.tree_util.tree_map replaces the deprecated jax.tree_map alias.
print("DP Parameters")
pprint(jax.tree_util.tree_map(lambda x: (x.shape, x.sharding), state_dp.params))
print("Metrics")
pprint(jax.tree_util.tree_map(lambda x: (x.shape, x.sharding), final_metrics_dp))

## FSDP

In [59]:
# A parameter leaf is either a plain array or an array wrapped in nn.Partitioned,
# which additionally carries the axis names the value is partitioned over.
Parameter = jax.Array | nn.Partitioned

In [61]:
@jax.named_scope("shard_params")
def shard_params(params: PyTree, axis_name: str, min_weight_size: int = 2**18) -> PyTree:
    """Shard parameters across the given mesh axis.

    Must be called where ``axis_name`` is bound (e.g. inside ``shard_map``), since the
    device index and the axis size are queried from it.

    Args:
        params: The parameters to shard.
        axis_name: The axis to shard parameters across.
        min_weight_size: Size threshold for sharding. Parameters with at most this many
            values are not sharded.

    Returns:
        PyTree of same structure as params, but with leaves sharded over new axis if possible.
    """
    axis_idx = jax.lax.axis_index(axis_name)
    axis_size = jax.lax.psum(1, axis_name)

    def _split(x: Parameter) -> Parameter:
        if isinstance(x, nn.Partitioned):
            value, names = x.value, x.names
        else:
            value = x
            names = (None,) * value.ndim
        if axis_name in names:
            logging.warning(
                f"Parameter {value.shape} with names {names} already sharded on axis {axis_name}."
            )
            return x
        if value.size <= min_weight_size:
            # The message states "<=" to match the condition actually checked above.
            logging.info(
                f"Parameter {value.shape} with names {names} too small to shard, size {value.size} <= {min_weight_size}."
            )
            return x
        shape = value.shape
        idx = np.argsort(shape)[::-1]  # Shard along largest possible axis.
        for i in idx:
            # Only dimensions that divide evenly and are not partitioned yet qualify.
            if shape[i] % axis_size == 0 and names[i] is None:
                split_size = shape[i] // axis_size
                p_sharded = nn.Partitioned(
                    value=lax.dynamic_slice_in_dim(  # Shard to keep on present device.
                        value, axis_idx * split_size, split_size, axis=i
                    ),
                    names=names[:i] + (axis_name,) + names[i + 1 :],
                )
                return p_sharded
        logging.warning(
            f"Could not shard {value.shape} with names {names} on axis {axis_name}, no suitable axis found."
        )
        return x

    return jax.tree_util.tree_map(
        _split,
        params,
        is_leaf=lambda x: isinstance(
            x, nn.Partitioned
        ),  # Consider a nn.Partitioned object as a leaf.
    )

In [63]:
def gather_array_with_mean_grads(x: jax.Array, axis: int, axis_name: str):
    """All-gather ``x`` along ``axis``, averaging gradients in the backward pass.

    Args:
        x: The local shard to gather.
        axis: The array dimension along which the shards are concatenated.
        axis_name: The mesh axis to gather over.

    Returns:
        The gathered array. Its gradient is the scattered sum of the incoming
        gradients divided by the size of ``axis_name``.
    """
    num_shards = jax.lax.psum(1, axis_name)

    def _mean_scatter(cotangent):
        # Backward pass: sum over devices, keep only this device's slice, then average.
        summed = jax.lax.psum_scatter(
            cotangent, axis_name, scatter_dimension=axis, tiled=True
        )
        return summed / num_shards

    # Custom gradient, so the backward pass averages instead of using the default rule.
    @jax.custom_gradient
    def _gather(shard):
        gathered = jax.lax.all_gather(shard, axis_name, axis=axis, tiled=True)
        return gathered, _mean_scatter

    return _gather(x)

@jax.named_scope("gather_params")
def gather_params(params: PyTree, axis_name: str) -> PyTree:
    """Gather all parameters that are partitioned over ``axis_name``.

    Args:
        params: The parameters to gather.
        axis_name: The axis to gather parameters across.

    Returns:
        PyTree with the structure of ``params``. Leaves partitioned over ``axis_name``
        are replaced by their gathered value; they stay wrapped in nn.Partitioned only
        if they are still partitioned over another axis. Other leaves are unchanged.
    """

    def _is_partitioned(node) -> bool:
        return isinstance(node, nn.Partitioned)

    def _gather(p: Parameter) -> Parameter:
        if not (_is_partitioned(p) and axis_name in p.names):
            return p
        names = p.names
        shard_axis = names.index(axis_name)
        full_value = gather_array_with_mean_grads(
            p.value, axis=shard_axis, axis_name=axis_name
        )
        # Mark the gathered dimension as no longer partitioned.
        remaining = names[:shard_axis] + (None,) + names[shard_axis + 1 :]
        if all(name is None for name in remaining):
            # Nothing is partitioned any more: return the plain array.
            return full_value
        return nn.Partitioned(full_value, remaining)

    return jax.tree_util.tree_map(_gather, params, is_leaf=_is_partitioned)

In [65]:
# For computation, we want to gather params, compute, then shard params

def shard_module_params(
    target: nn.Module | Callable, axis_name: str, min_weight_size: int = 2**18
) -> nn.Module | Callable:
    """Wrap a module so that its parameters are stored sharded over ``axis_name``.

    The "params" collection is passed through ``gather_params`` on the way into the
    module and through ``shard_params`` on the way out.

    Args:
        target: The module to shard.
        axis_name: The axis name to shard parameters across.
        min_weight_size: Size threshold forwarded to ``shard_params``.

    Returns:
        The module with sharded parameters.
    """
    gather_fn = functools.partial(gather_params, axis_name=axis_name)
    shard_fn = functools.partial(
        shard_params, axis_name=axis_name, min_weight_size=min_weight_size
    )
    return nn.map_variables(
        target,
        trans_in_fn=gather_fn,
        trans_out_fn=shard_fn,
        mapped_collections="params",
        mutable=True,
    )

In [67]:
class FSDPClassifier(nn.Module):
    """Two-layer MLP classifier whose dense layers keep their parameters sharded.

    Same architecture as the data-parallel classifier (Dense -> SiLU -> Dropout ->
    Dense), but both dense layers are created through ``shard_module_params``.
    """

    config: ConfigDict

    @nn.compact
    def __call__(self, x: jax.Array, train: bool) -> jax.Array:
        """Compute logits for ``x``; dropout is only active when ``train`` is True.

        Returns:
            The logits, cast to float32.
        """
        cfg = self.config
        # Dense layer class with parameter sharding over the data axis.
        sharded_dense = shard_module_params(
            nn.Dense,
            axis_name=cfg.data_axis_name,
            min_weight_size=cfg.min_weight_size,
        )
        hidden = sharded_dense(
            features=cfg.hidden_size,
            dtype=cfg.dtype,
            name="input_dense",
        )(x)
        hidden = nn.silu(hidden)
        hidden = nn.Dropout(rate=cfg.dropout_rate, deterministic=not train)(hidden)
        logits = sharded_dense(
            features=cfg.num_classes,
            dtype=cfg.dtype,
            name="output_dense",
        )(hidden)
        # Return the logits in float32 regardless of the layers' computation dtype.
        return logits.astype(jnp.float32)

In [101]:
# Use a very small sharding threshold so that this small model's parameters get sharded.
config.model.min_weight_size = 2**4
model_fsdp = FSDPClassifier(config=config.model)

In [103]:
# First pass: wrap the initialization with a placeholder out_specs=P() and evaluate it
# abstractly (jax.eval_shape) to obtain the structure of the state, then derive the
# partition specs of that state from the result.
init_fsdp_fn = shard_map(
    functools.partial(init_dp, model=model_fsdp),
    mesh,
    in_specs=(P(), P(config.data_axis_name)),
    out_specs=P(),
    check_rep=False,
)
state_fsdp_shapes = jax.eval_shape(init_fsdp_fn, model_init_rng, batch.inputs)
state_fsdp_specs = nn.get_partition_spec(state_fsdp_shapes)
print("RNG", state_fsdp_specs.rng)
print("\nParameters")
pprint(state_fsdp_specs.params)
print("\nOptimizer state")
pprint(state_fsdp_specs.opt_state[0])

RNG PartitionSpec()

Parameters
{'input_dense': {'bias': PartitionSpec('data',),
                 'kernel': PartitionSpec('data', None)},
 'output_dense': {'bias': PartitionSpec(),
                  'kernel': PartitionSpec('data', None)}}

Optimizer state
ScaleByAdamState(count=PartitionSpec(), mu={'input_dense': {'bias': PartitionSpec('data',), 'kernel': PartitionSpec('data', None)}, 'output_dense': {'bias': PartitionSpec(), 'kernel': PartitionSpec('data', None)}}, nu={'input_dense': {'bias': PartitionSpec('data',), 'kernel': PartitionSpec('data', None)}, 'output_dense': {'bias': PartitionSpec(), 'kernel': PartitionSpec('data', None)}})


In [105]:
# Second pass: the real initialization, now using the derived partition specs as
# out_specs instead of the placeholder P().
init_fsdp_fn = jax.jit(
    shard_map(
        functools.partial(init_dp, model=model_fsdp),
        mesh,
        in_specs=(P(), P(config.data_axis_name)),
        out_specs=state_fsdp_specs,
        check_rep=False,
    )
)
state_fsdp = init_fsdp_fn(model_init_rng, batch.inputs)

In [107]:
print("FSDP Parameters")
# jax.tree_util.tree_map replaces the deprecated jax.tree_map alias.
pprint(jax.tree_util.tree_map(lambda x: x.shape, jax.device_get(state_fsdp.params)))

FSDP Parameters
{'input_dense': {'bias': Partitioned(value=(512,), names=('data',), mesh=None),
                 'kernel': Partitioned(value=(784, 512),
                                       names=('data', None),
                                       mesh=None)},
 'output_dense': {'bias': (10,),
                  'kernel': Partitioned(value=(512, 10),
                                        names=('data', None),
                                        mesh=None)}}


  pprint(jax.tree_map(lambda x: x.shape, jax.device_get(state_fsdp.params)))


In [109]:
def sync_gradients(
    grads: PyTree,
    axis_names: Sequence[str],
) -> PyTree:
    """Synchronize gradients across devices.

    Gradients for parameters that are replicated over a given axis are averaged across devices.
    Parameters that are partitioned over a given axis are considered to already have a mean of
    the gradients on each device, and hence do not need to be altered.

    Args:
        grads: The gradients to synchronize.
        axis_names: The axis names to synchronize gradients across.

    Returns:
        The gradients averaged over the specified axes if they are replicated.
    """

    def sync_grad(g: Parameter) -> Parameter:
        if isinstance(g, nn.Partitioned):
            # Tree leaves for flattening potentially nested axis (multiple names can exist for single array axis).
            replication_axis_names = [
                name for name in axis_names if name not in jax.tree_util.tree_leaves(g.names)
            ]
            if len(replication_axis_names) == 0:
                # Parameters partitioned over all axes.
                return g
            else:
                # Average over remaining replicated axes.
                return g.replace(value=jax.lax.pmean(g.value, axis_name=replication_axis_names))
        else:
            # Parameters are replicated over all axes.
            return jax.lax.pmean(g, axis_name=axis_names)

    # jax.tree_util.tree_map replaces the deprecated jax.tree_map alias.
    return jax.tree_util.tree_map(
        sync_grad, grads, is_leaf=lambda x: isinstance(x, nn.Partitioned)
    )

In [111]:
def train_step_fsdp(
    state: TrainState,
    metrics: Metrics | None,
    batch: Batch,
) -> Tuple[TrainState, Metrics]:
    """Run one training step with sharded (FSDP) parameters.

    Args:
        state: Current training state (parameters, optimizer state, RNG key).
        metrics: Metrics accumulated so far, or None to start from this step's metrics.
        batch: The local batch of this device.

    Returns:
        The updated training state and the accumulated metrics.
    """
    rng, step_rng = jax.random.split(state.rng)
    grads, step_metrics = accumulate_gradients(
        state,
        batch,
        step_rng,
        config.optimizer.num_minibatches,
        loss_fn=loss_fn,
    )
    # Update parameters. We need to sync the gradients across devices before updating.
    with jax.named_scope("sync_gradients"):
        grads = sync_gradients(grads, (config.data_axis_name,))
    new_state = state.apply_gradients(grads=grads, rng=rng)
    # Sum metrics across replicas. Alternatively, we could keep the metrics separate
    # and only synchronize them before logging. For simplicity, we sum them here.
    # jax.tree_util.tree_map replaces the deprecated jax.tree_map alias.
    with jax.named_scope("sync_metrics"):
        step_metrics = jax.tree_util.tree_map(
            lambda x: jax.lax.psum(x, axis_name=config.data_axis_name), step_metrics
        )
    if metrics is None:
        metrics = step_metrics
    else:
        metrics = jax.tree_util.tree_map(jnp.add, metrics, step_metrics)
    return new_state, metrics

In [113]:
# Training step under shard_map: the state follows the derived FSDP partition specs,
# the metrics are replicated (P()) and the batch is split over the data axis.
train_step_fsdp_fn = jax.jit(
    shard_map(
        train_step_fsdp,
        mesh,
        in_specs=(state_fsdp_specs, P(), P(config.data_axis_name)),
        out_specs=(state_fsdp_specs, P()),
        check_rep=False,
    ),
    donate_argnames=("state", "metrics"),
)
# Evaluate the step abstractly (shapes and dtypes only) to learn the structure of the
# metrics, then create matching zero-initialized accumulators.
_, metric_shapes = jax.eval_shape(
    train_step_fsdp_fn,
    state_fsdp,
    None,
    batch,
)
# jax.tree_util.tree_map replaces the deprecated jax.tree_map alias.
metrics_fsdp = jax.tree_util.tree_map(
    lambda x: jnp.zeros(x.shape, dtype=x.dtype), metric_shapes
)

  return jax.tree_map(sync_grad, grads, is_leaf=lambda x: isinstance(x, nn.Partitioned))
  step_metrics = jax.tree_map(
  metrics_fsdp = jax.tree_map(lambda x: jnp.zeros(x.shape, dtype=x.dtype), metric_shapes)


In [115]:
# Train for 15 steps on the same batch, accumulating metrics along the way.
for _ in range(15):
    state_fsdp, metrics_fsdp = train_step_fsdp_fn(state_fsdp, metrics_fsdp, batch)
# Run one more step with fresh accumulators so the printed metrics cover only that step.
# jax.tree_util.tree_map replaces the deprecated jax.tree_map alias.
final_metrics_fsdp = jax.tree_util.tree_map(
    lambda x: jnp.zeros(x.shape, dtype=x.dtype), metric_shapes
)
state_fsdp, final_metrics_fsdp = train_step_fsdp_fn(state_fsdp, final_metrics_fsdp, batch)
print_metrics(final_metrics_fsdp, "FSDP - Final metrics")