# Data parallel training with JAX

This notebook compares two methods for data-parallel training with JAX (`pmap`, and data sharding + `jit`) against a single-GPU `jit` baseline, using a simple Flax neural network model.

It was run on a compute node with 2 Quadro RTX 6000 GPUs with jax 0.4.31 and flax 0.8.6.

In [1]:
import numpy as np

import jax
from jax import numpy as jnp
from jax.sharding import Mesh, PartitionSpec, NamedSharding

import flax.linen as nn
from flax import jax_utils
from flax.training.train_state import TrainState

import optax

In [2]:
# show which devices JAX has detected
print(f'JAX devices: {jax.devices()}')

JAX devices: [CudaDevice(id=0), CudaDevice(id=1)]


In [3]:
# A small fully connected network (multilayer perceptron).
class MLP(nn.Module):
  """`n_layers` ReLU-activated dense layers followed by one linear output layer."""
  n_layers: int    # number of hidden Dense layers
  hidden_dim: int  # width of every hidden layer
  output_dim: int  # width of the final layer

  def setup(self):
    # Create the submodules once; __call__ applies them in this order.
    hidden_layers = []
    for _ in range(self.n_layers):
      hidden_layers.append(nn.Dense(self.hidden_dim))
    self.layers = hidden_layers
    self.output_layer = nn.Dense(self.output_dim)

  def __call__(self, x: jax.Array) -> jax.Array:
    """Apply every hidden layer with a ReLU, then the output layer without activation."""
    activations = x
    for dense in self.layers:
      activations = nn.relu(dense(activations))
    return self.output_layer(activations)

### Available modes:
'jit_single': Use a single GPU, jit the train function.

'jit_multi': Use multiple GPUs, split the data across the GPUs, copy the training parameters to all GPUs, jit the train function. Recommended option.

'pmap': Use multiple GPUs, reshape the data with a new leading axis of size n_gpus, replicate the training parameters to all GPUs, pmap the train function. Deprecated but still works well.


### Performance:
pmap and jit_multi perform similarly; jit_single is slower for large models/batch sizes and faster otherwise.

For the hyperparameters below, one training step takes: jit_single: 1.2 s, jit_multi: 0.7 s, pmap: 0.7 s.

For batchsize 256 * 54, one training step takes: jit_single: out of memory, jit_multi: 1.3 s, pmap: 1.3 s.

In [4]:
# --- run mode ---------------------------------------------------------------
# one of 'jit_single', 'jit_multi' or 'pmap'; the modes are described above
mode = 'jit_single'

# --- problem size (deliberately large model) --------------------------------
batchsize = 256 * 27  # jit_multi/pmap can handle up to 256 * 54
input_dim = 1024
hidden_dim = 2056
n_layers = 90  # jit_multi/pmap can handle up to 130

# --- RNG key and optimizer --------------------------------------------------
key = jax.random.PRNGKey(0)
optimizer = optax.adam(learning_rate=1e-4)

# --- constant dummy inputs and targets (all ones) ---------------------------
batch_data = jnp.ones(shape=(batchsize, input_dim), dtype=jnp.float32)
label_data = jnp.ones(shape=(batchsize, 10), dtype=jnp.int32)

# --- the network ------------------------------------------------------------
model = MLP(
  n_layers=n_layers,
  hidden_dim=hidden_dim,
  output_dim=10,
)

def init_fn(k, x, model, optimizer):
  """Initialize `model` and bundle it with `optimizer` in a `TrainState`.

  Args:
    k: PRNG key passed to `model.init`.
    x: example input passed to `model.init`.
    model: module providing `init` and `apply`.
    optimizer: gradient transformation stored as the state's `tx`.

  Returns:
    The newly created `TrainState`.
  """
  # The complete dict returned by `model.init` is stored as `params`,
  # not just one entry of it.
  initial_variables = model.init(k, x)
  return TrainState.create(
    apply_fn=model.apply,
    tx=optimizer,
    params=initial_variables,
  )

# build the initial training state from the model, the dummy batch and the optimizer
initialized_state = init_fn(k=key, x=batch_data, model=model, optimizer=optimizer)


2024-12-05 12:10:33.894983: W external/xla/xla/service/gpu/nvptx_compiler.cc:836] The NVIDIA driver's CUDA version is 12.2 which is older than the PTX compiler version (12.6.20). Because the driver is older than the PTX compiler version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


In [5]:
# One optimization step: mean-squared-error loss, gradients, parameter update.
def train_step(state, x, y):
  """Run a single gradient step on the batch `(x, y)`.

  Args:
    state: training state providing `apply_fn`, `params` and `apply_gradients`.
    x: model inputs.
    y: targets compared against the model output.

  Returns:
    A tuple `(new_state, loss)`; `loss` is evaluated with the parameters
    from before the update.
  """
  def mse_loss(params):
    # x and y come from the enclosing scope; only `params` is differentiated
    residual = state.apply_fn(params, x) - y
    return jnp.mean(residual ** 2)

  loss, grads = jax.value_and_grad(mse_loss)(state.params)
  # `mode` is the notebook-level setting; it is read when this function is traced
  if mode == 'pmap':
    # average the gradients over the 'device' axis before the update
    grads = jax.lax.pmean(grads, axis_name='device')
  return state.apply_gradients(grads=grads), loss

In [6]:
# Prepare data, state and the compiled train step for the selected `mode`.
# Both multi-device modes split the batch evenly, so fail early with a clear
# message instead of an opaque sharding/reshape error.
if mode in ('jit_multi', 'pmap') and batchsize % jax.device_count() != 0:
  raise ValueError(
    f'batchsize ({batchsize}) must be divisible by the number of devices '
    f'({jax.device_count()}) for mode {mode!r}')

if mode == 'jit_multi':
  # 1-D device mesh with a single axis named 'data'.
  # Note the trailing comma: ('data') would just be the string 'data'.
  mesh = Mesh(devices=np.array(jax.devices()),
              axis_names=('data',))

  def mesh_sharding(pspec: PartitionSpec) -> NamedSharding:
    """Build a `NamedSharding` on `mesh` for the given partition spec."""
    return NamedSharding(mesh, pspec)

  # split the leading (batch) dimension of inputs and labels along the 'data' axis
  data_sharding = mesh_sharding(PartitionSpec('data'))
  batch_data = jax.device_put(batch_data, data_sharding)
  label_data = jax.device_put(label_data, data_sharding)

  # an empty PartitionSpec shards no dimension, i.e. the state is copied to all devices
  initialized_state = jax.device_put(initialized_state, mesh_sharding(PartitionSpec()))

  train_step_compiled = jax.jit(train_step)

elif mode == 'pmap':
  n_devices = jax.device_count()
  per_device_batch = batchsize // n_devices
  # replicate the initialized state across all devices
  # NOTE: this overwrites `initialized_state`; re-create it (cell above) before
  # re-running this cell, otherwise the state would be replicated a second time
  initialized_state = jax_utils.replicate(initialized_state)
  # extra leading axis of size n_devices for data parallelism
  batch_data = jnp.reshape(batch_data, (n_devices, per_device_batch, input_dim))
  label_data = jnp.reshape(label_data, (n_devices, per_device_batch, 10))
  # pmap the train_step function; every argument is mapped over its leading axis
  train_step_compiled = jax.pmap(train_step, axis_name="device", in_axes=(0, 0, 0))

elif mode == 'jit_single':
  # nothing to shard or replicate
  train_step_compiled = jax.jit(train_step)

else:
  raise ValueError(f'Unknown mode: {mode}')


In [7]:
# first call: compiles the train step for the chosen mode and runs one step
state, loss = train_step_compiled(
  initialized_state,
  batch_data,
  label_data,
)

2024-12-05 12:10:42.021469: W external/xla/xla/service/hlo_rematerialization.cc:3005] Can't reduce memory use below 8.78GiB (9431760227 bytes) by rematerialization; only reduced to 9.10GiB (9772517936 bytes), down from 13.81GiB (14831659588 bytes) originally


In [8]:
# shape of the optimizer's `mu` entry for the kernel of the first hidden layer
state.opt_state[0].mu['params']['layers_0']['kernel'].shape

(1024, 2056)

In [9]:
# Show on which devices the optimizer state and the input batch live.
first_kernel_mu = state.opt_state[0].mu['params']['layers_0']['kernel']
if mode == 'pmap':
  # In pmap mode the leading axis has one entry per device, so visualize one
  # slice per device instead of hard-coding indices 0 and 1.
  # we can see that the model is copied across all devices
  for device_index in range(jax.device_count()):
    jax.debug.visualize_array_sharding(first_kernel_mu[device_index])
  for device_index in range(jax.device_count()):
    jax.debug.visualize_array_sharding(batch_data[device_index])
else:
  # we can see that the model is copied across all devices and the data is
  # split if multiple devices are used
  jax.debug.visualize_array_sharding(first_kernel_mu)
  jax.debug.visualize_array_sharding(batch_data)

In [10]:
%%timeit
# block_until_ready waits for the asynchronously dispatched step to finish, so the whole step is timed
new_state, loss = jax.block_until_ready(train_step_compiled(initialized_state, batch_data, label_data))

1.19 s ± 1.23 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
