# Distributed training
- A training paradigm in which the training workload is spread across multiple worker nodes (devices or machines).
- Used to train large models.
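### Ways to perform distributed training
(These are the approaches listed in PyTorch's `torch.distributed` overview; the code below uses the Keras distribution API on the JAX backend.)
- DistributedDataParallel
- Fully Sharded Data Parallel
- Tensor Parallel
- Device Mesh
- Remote Procedure Call (RPC) distributed training
- Custom extensions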


In [1]:
import os

# The distribution API is only implemented for the JAX backend for now.
os.environ["KERAS_BACKEND"] = "jax"

import keras
from keras import layers
import jax
import numpy as np
from tensorflow import data as tf_data  # For dataset input.

2025-06-22 10:46:10.547143: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1750589170.761449      35 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1750589170.823454      35 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [5]:
# Enumerate the local GPUs here as well as in the mesh cell below, so this
# cell runs on a fresh kernel instead of depending on a later cell.
devices = jax.devices("gpu")
devices

[CudaDevice(id=0), CudaDevice(id=1)]

In [9]:
# Retrieve the locally available GPU devices (the recorded run found 2).
devices = jax.devices("gpu")

# Define a 2D device mesh with "data" and "model" axes: every device sits on
# the "data" axis and the "model" axis has size 1. The shape is derived from
# the device count instead of being hard-coded to (2, 1), so the cell works on
# any number of GPUs (and still yields (2, 1) on the recorded 2-GPU host).
mesh = keras.distribution.DeviceMesh(
    shape=(len(devices), 1), axis_names=["data", "model"], devices=devices
)

# A 2D layout describing how a tensor is distributed across the mesh: tensor
# dimension 0 maps to the "model" mesh axis and dimension 1 to the "data" mesh
# axis. On this mesh that is a [1, len(devices)] grid of shards.
layout_2d = keras.distribution.TensorLayout(axes=("model", "data"), device_mesh=mesh)

# A 4D layout for data-parallel image input: the batch dimension is split over
# the "data" axis and the other three dimensions are replicated (None).
# Despite the variable name, the batch axis is sharded, not replicated.
replicated_layout_4d = keras.distribution.TensorLayout(
    axes=("data", None, None, None), device_mesh=mesh
)

In [6]:
# Show the 2D layout: its axes order and the device mesh it is bound to.
layout_2d

<TensorLayout axes=('model', 'data'), device_mesh=<DeviceMesh shape=(2, 1), axis_names=['data', 'model']>>

In [10]:
# Show the 4D layout: batch axis on "data", remaining axes replicated (None).
replicated_layout_4d

<TensorLayout axes=('data', None, None, None), device_mesh=<DeviceMesh shape=(2, 1), axis_names=['data', 'model']>>

## Data Parallel
- Model weights are replicated across all devices in the `DeviceMesh`, and each device processes a portion of the input data.

In [26]:
def b():
    """Train and evaluate a small MLP with Keras ``DataParallel``.

    Model weights, RNG state, optimizer state and metrics are replicated on
    every device of a 1D ``"data"`` mesh. Each batch fed to ``fit`` /
    ``evaluate`` is split evenly along the batch dimension across those
    devices. No manual aggregation of losses is needed because the computation
    happens in a global context.

    Uses the notebook-level ``devices`` list. The distribution is active only
    for the duration of this call, so it does not leak into later cells.
    """
    # `keras.distribution.DataParallel(devices=devices)` is an equivalent
    # shortcut, and `DataParallel()` with no arguments detects all local
    # devices. The explicit 1D mesh is sized from `devices` so it matches
    # whatever hardware is present.
    mesh_1d = keras.distribution.DeviceMesh(
        shape=(len(devices),), axis_names=["data"], devices=devices
    )
    data_parallel = keras.distribution.DataParallel(device_mesh=mesh_1d)

    # Random regression data: 128 samples of shape (28, 28, 1) with 10 targets.
    # Named x_train / y_train so they are not shadowed by the Keras `inputs`
    # tensor below.
    x_train = np.random.normal(size=(128, 28, 28, 1))
    y_train = np.random.normal(size=(128, 10))
    # NOTE(review): for an even split the batch size (16) presumably has to be
    # divisible by len(devices) -- confirm when changing either.
    dataset = tf_data.Dataset.from_tensor_slices((x_train, y_train)).batch(16)

    # `scope()` makes `data_parallel` the current distribution and restores the
    # previous one on exit. A bare global `set_distribution` would stay active
    # for the rest of the session and silently distribute the `a()` baseline.
    with data_parallel.scope():
        inputs = layers.Input(shape=(28, 28, 1))
        y = layers.Flatten()(inputs)
        y = layers.Dense(units=200, use_bias=False, activation="relu")(y)
        y = layers.Dropout(0.4)(y)
        y = layers.Dense(units=10, activation="softmax")(y)
        model = keras.Model(inputs=inputs, outputs=y)

        model.compile(loss="mse")
        model.fit(dataset, epochs=3)
        model.evaluate(dataset)

In [2]:
def a():
    """Train and evaluate a small MLP without any Keras distribution.

    This is the baseline for the timing comparison against the distributed
    ``b()``. Another cell may set a distribution globally (for example via
    ``keras.distribution.set_distribution``). Any such distribution is
    suspended for the duration of this call and restored afterwards, so the
    baseline is never silently distributed.
    """
    previous_distribution = keras.distribution.distribution()
    keras.distribution.set_distribution(None)
    try:
        # Random regression data: 128 samples of shape (28, 28, 1) with 10
        # targets. Named x_train / y_train so they are not shadowed by the
        # Keras `inputs` tensor below.
        x_train = np.random.normal(size=(128, 28, 28, 1))
        y_train = np.random.normal(size=(128, 10))
        dataset = tf_data.Dataset.from_tensor_slices((x_train, y_train)).batch(16)

        inputs = layers.Input(shape=(28, 28, 1))
        y = layers.Flatten()(inputs)
        y = layers.Dense(units=200, use_bias=False, activation="relu")(y)
        y = layers.Dropout(0.4)(y)
        y = layers.Dense(units=10, activation="softmax")(y)
        model = keras.Model(inputs=inputs, outputs=y)

        model.compile(loss="mse")
        model.fit(dataset, epochs=3)
        model.evaluate(dataset)
    finally:
        # Put back whatever distribution (possibly None) was active before.
        keras.distribution.set_distribution(previous_distribution)

In [4]:
# Time one fit + evaluate run of the non-distributed baseline.
%time a()

Epoch 1/3
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 3ms/step - loss: 0.9947  
Epoch 2/3
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 4ms/step - loss: 0.9111
Epoch 3/3
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step - loss: 0.8611
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2ms/step - loss: 0.8228  
CPU times: user 1.02 s, sys: 132 ms, total: 1.16 s
Wall time: 1.07 s


In [30]:
# Time one fit + evaluate run under DataParallel, for comparison with a().
%time b()

Epoch 1/3
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 13ms/step - loss: 0.9529
Epoch 2/3
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 4ms/step - loss: 0.8703
Epoch 3/3
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 4ms/step - loss: 0.8244
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 4ms/step - loss: 0.7696  
CPU times: user 1.09 s, sys: 84 ms, total: 1.18 s
Wall time: 1.07 s


In [5]:
import time
from datetime import timedelta

# Measure with a monotonic clock: datetime.now() follows the system wall clock,
# which can be adjusted mid-measurement. Each call builds a fresh model, so the
# figure likely includes tracing/compilation (note the slower first epoch in
# the logs); treat a single run as a rough comparison.
start = time.perf_counter()
a()
elapsed = time.perf_counter() - start
# Format as a timedelta to keep the same "H:MM:SS.ffffff" output as before.
print('Duration: {}'.format(timedelta(seconds=elapsed)))

Epoch 1/3
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 3ms/step - loss: 1.1406  
Epoch 2/3
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 4ms/step - loss: 1.0297
Epoch 3/3
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 4ms/step - loss: 0.9958
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2ms/step - loss: 0.9517  
Duration: 0:00:01.000590


In [40]:
import time
from datetime import timedelta

# Measure with a monotonic clock: datetime.now() follows the system wall clock,
# which can be adjusted mid-measurement. Each call builds a fresh model, so the
# figure likely includes tracing/compilation (note the slower first epoch in
# the logs); treat a single run as a rough comparison.
start = time.perf_counter()
b()
elapsed = time.perf_counter() - start
# Format as a timedelta to keep the same "H:MM:SS.ffffff" output as before.
print('Duration: {}'.format(timedelta(seconds=elapsed)))

Epoch 1/3
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 13ms/step - loss: 1.0709
Epoch 2/3
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 4ms/step - loss: 0.9911
Epoch 3/3
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 4ms/step - loss: 0.9338
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 4ms/step - loss: 0.9021  
Duration: 0:00:01.103680


In [23]:
import numpy as np
from scipy import stats


## Model Parallel and LayoutMap
- Model parallelism (MP) splits model weights across devices; useful when the model weights are too large to fit on a single device.



In [17]:
# A 2D mesh with every device on the "data" axis and size 1 on the "model"
# axis. With a "model" axis of size 1 nothing is actually split across
# devices: this demonstrates the API, and a larger "model" axis is needed for
# real model parallelism.
mesh_2d = keras.distribution.DeviceMesh(
    shape=(len(devices), 1), axis_names=["data", "model"], devices=devices
)
layout_map = keras.distribution.LayoutMap(mesh_2d)
# Weights that match "d1/kernel" are sharded along their second dimension over
# the "model" mesh axis. All other weights are fully replicated.
layout_map["d1/kernel"] = (None, "model")
# NOTE: d1 is built with use_bias=False below, so this rule matches nothing;
# it only illustrates how a bias would be sharded.
layout_map["d1/bias"] = ("model",)

# Layouts can also be set for a layer's output: shard d2's output over "data".
layout_map["d2/output"] = ("data", None)

model_parallel = keras.distribution.ModelParallel(
    layout_map=layout_map, batch_dim_name="data"
)

# Build the dataset in this cell: a() and b() only create `dataset` as a local
# variable, so relying on a top-level `dataset` fails on a fresh kernel.
x_train = np.random.normal(size=(128, 28, 28, 1))
y_train = np.random.normal(size=(128, 10))
dataset = tf_data.Dataset.from_tensor_slices((x_train, y_train)).batch(16)

# Scope the distribution to this training run instead of setting it globally
# for the rest of the session.
with model_parallel.scope():
    inputs = layers.Input(shape=(28, 28, 1))
    y = layers.Flatten()(inputs)
    y = layers.Dense(units=200, use_bias=False, activation="relu", name="d1")(y)
    y = layers.Dropout(0.4)(y)
    y = layers.Dense(units=10, activation="softmax", name="d2")(y)
    model = keras.Model(inputs=inputs, outputs=y)

    # Input batches are sharded across the "data" axis of the mesh.
    model.compile(loss="mse")
    model.fit(dataset, epochs=3)
    eval_loss = model.evaluate(dataset)

# Display the evaluation loss as the cell's output.
eval_loss

Epoch 1/3
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 35ms/step - loss: 1.0292
Epoch 2/3
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 4ms/step - loss: 0.9319
Epoch 3/3
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 4ms/step - loss: 0.9104
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 4ms/step - loss: 0.8629  


0.872115433216095

In [16]:
# IPython help syntax: show ModelParallel's signature and docstring.
keras.distribution.ModelParallel?

[0;31mInit signature:[0m
[0mkeras[0m[0;34m.[0m[0mdistribution[0m[0;34m.[0m[0mModelParallel[0m[0;34m([0m[0;34m[0m
[0;34m[0m    [0;34m*[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mlayout_map[0m[0;34m=[0m[0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mbatch_dim_name[0m[0;34m=[0m[0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0;34m**[0m[0mkwargs[0m[0;34m,[0m[0;34m[0m
[0;34m[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mDocstring:[0m     
Distribution that shards model variables.

Compare to `DataParallel` which replicates the variables across all devices,
`ModelParallel` allows you to shard variables in addition to the input data.

To construct a `ModelParallel` distribution, you need to provide a
`DeviceMesh` and a `LayoutMap`.

1. `DeviceMesh` contains physical device information. The axis names in
    the mesh will be used to map the variable and data layout.
2. `LayoutMap` contains the mapping between variable paths to their
    

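## Different mesh shapes
On an 8-device host, the same `["data", "model"]` axes can trade data parallelism for model parallelism by changing the mesh shape.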

In [None]:
# Four ways to lay out 8 devices on the same ["data", "model"] axes, from pure
# data parallelism (8, 1) to pure model parallelism (1, 8). Every shape
# multiplies out to 8, so these describe an 8-device host. The examples are
# skipped elsewhere (the recorded run had 2 GPUs). Previously the cell wrapped
# this code in an unterminated single-quoted string, which is a SyntaxError.
if len(devices) == 8:
    full_data_parallel_mesh = keras.distribution.DeviceMesh(
        shape=(8, 1), axis_names=["data", "model"], devices=devices
    )
    more_data_parallel_mesh = keras.distribution.DeviceMesh(
        shape=(4, 2), axis_names=["data", "model"], devices=devices
    )
    more_model_parallel_mesh = keras.distribution.DeviceMesh(
        shape=(2, 4), axis_names=["data", "model"], devices=devices
    )
    full_model_parallel_mesh = keras.distribution.DeviceMesh(
        shape=(1, 8), axis_names=["data", "model"], devices=devices
    )
else:
    print(f"Skipping 8-device mesh examples: found {len(devices)} device(s).")