In [1]:
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec, NamedSharding
from jax.experimental import mesh_utils

import flax.linen as nn
import optax
import numpy as np
import functools

# --- Configuration ---
INPUT_FEATURES = 64
HIDDEN_DIM = 128
OUTPUT_DIM = 10
LEARNING_RATE = 1e-3
NUM_STEPS = 100

# --- Device Setup ---
NUM_DEVICES = jax.local_device_count()
if NUM_DEVICES < 2:
    print(f"WARNING: Only {NUM_DEVICES} device found. "
          "Replication across devices is trivial but less illustrative.")

# Global batch size should be divisible by number of devices for data parallelism
GLOBAL_BATCH_SIZE = 32 * NUM_DEVICES
DEVICE_BATCH_SIZE = GLOBAL_BATCH_SIZE // NUM_DEVICES

# --- 1. Setup Device Mesh ---
# Create a 1D mesh for data parallelism. The axis 'data' will be used
# to shard the batch dimension and to average gradients.
device_mesh = mesh_utils.create_device_mesh((NUM_DEVICES,))
mesh = Mesh(devices=device_mesh, axis_names=('data',))
print(f"Created Mesh: {mesh}")



Created Mesh: Mesh('data': 1, axis_types=(Auto,))


In [2]:

# --- 2. Define Partitioning Specifications ---
P = PartitionSpec # Alias for convenience

# Specification for Replication: Don't shard along any mesh axis.
# P(None,) means don't shard along the first (and only) mesh axis 'data'.
replicated_spec = P(None,)

# Specification for Data Sharding (Batch Dimension):
# Shard the first dimension (batch) along the 'data' mesh axis.
# Keep other dimensions (features) un-sharded (replicated within a device).
data_sharding_spec = P('data', None) # Shard axis 0, replicate axis 1
label_sharding_spec = P('data',)     # Shard axis 0 (for 1D labels)

# --- Create NamedShardings (binding Specs to the Mesh) ---
replicated_sharding = NamedSharding(mesh, replicated_spec)
data_sharding = NamedSharding(mesh, data_sharding_spec)
label_sharding = NamedSharding(mesh, label_sharding_spec)

print(f"\nReplicated Sharding: {replicated_sharding}")
print(f"Data Sharding: {data_sharding}")
print(f"Label Sharding: {label_sharding}")





Replicated Sharding: NamedSharding(mesh=Mesh('data': 1, axis_types=(Auto,)), spec=PartitionSpec(None,), memory_kind=unpinned_host)
Data Sharding: NamedSharding(mesh=Mesh('data': 1, axis_types=(Auto,)), spec=PartitionSpec('data', None), memory_kind=unpinned_host)
Label Sharding: NamedSharding(mesh=Mesh('data': 1, axis_types=(Auto,)), spec=PartitionSpec('data',), memory_kind=unpinned_host)


In [3]:
# --- 3. Define the Neural Network (Same as before) ---
class SimpleMLP(nn.Module):
    hidden_dim: int
    output_dim: int

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=self.hidden_dim, name="dense_layer_1")(x)
        x = nn.relu(x)
        x = nn.Dense(features=self.output_dim, name="dense_layer_2")(x)
        return x



In [4]:
# --- 4. Initialize Model and Optimizer with Replication ---
key = jax.random.PRNGKey(0)
model_key, params_key, data_key = jax.random.split(key, 3)

model = SimpleMLP(hidden_dim=HIDDEN_DIM, output_dim=OUTPUT_DIM)

# Create dummy input *per device* shape for initialization if needed,
# though global often works if init isn't shape-dependent.
# Using device shape avoids potential issues if init logic uses input shape.
dummy_device_input = jnp.ones([DEVICE_BATCH_SIZE, INPUT_FEATURES])

# Get abstract structure first (runs on host)
abstract_variables = jax.eval_shape(model.init, model_key, dummy_device_input)
abstract_params = abstract_variables['params']

# Create the sharding specification tree for parameters - all replicated
# Use tree.map to apply the replicated_sharding to every leaf in the params tree
params_sharding_tree = jax.tree.map(lambda _: replicated_sharding, abstract_params)
print("\nParameter Sharding Tree (all replicated):")
print(params_sharding_tree)

# Initialize parameters using JIT with out_shardings to place them directly
# onto devices with the specified (replicated) sharding.
@functools.partial(jax.jit, out_shardings=params_sharding_tree)
def initialize_params(key):
    return model.init(key, dummy_device_input)['params']

print("\nInitializing parameters with replicated sharding...")
params = initialize_params(params_key)
print("Parameter initialization complete.")

# Verify parameter sharding (optional check)
print("Verifying Parameter Sharding:")
jax.tree_util.tree_map_with_path(
    lambda path, x: print(f"  Param: {path}, Shape: {x.shape}, Sharding: {x.sharding}"),
    params
)
# Check that all shardings are indeed the `replicated_sharding` object
assert all(leaf.sharding == replicated_sharding
           for leaf in jax.tree_util.tree_leaves(params))
print("Parameter sharding verified.")




Parameter Sharding Tree (all replicated):
{'dense_layer_1': {'bias': NamedSharding(mesh=Mesh('data': 1, axis_types=(Auto,)), spec=PartitionSpec(None,), memory_kind=unpinned_host), 'kernel': NamedSharding(mesh=Mesh('data': 1, axis_types=(Auto,)), spec=PartitionSpec(None,), memory_kind=unpinned_host)}, 'dense_layer_2': {'bias': NamedSharding(mesh=Mesh('data': 1, axis_types=(Auto,)), spec=PartitionSpec(None,), memory_kind=unpinned_host), 'kernel': NamedSharding(mesh=Mesh('data': 1, axis_types=(Auto,)), spec=PartitionSpec(None,), memory_kind=unpinned_host)}}

Initializing parameters with replicated sharding...
Parameter initialization complete.
Verifying Parameter Sharding:
  Param: (DictKey(key='dense_layer_1'), DictKey(key='bias')), Shape: (128,), Sharding: NamedSharding(mesh=Mesh('data': 1, axis_types=(Auto,)), spec=PartitionSpec(None,), memory_kind=unpinned_host)
  Param: (DictKey(key='dense_layer_1'), DictKey(key='kernel')), Shape: (64, 128), Sharding: NamedSharding(mesh=Mesh('data':

In [5]:
# Optimizer (Optax)
optimizer = optax.adam(LEARNING_RATE)

print("\nGetting abstract optimizer state structure...")
abstract_opt_state = jax.eval_shape(optimizer.init, params)
print("Abstract Opt State Structure:")
print(abstract_opt_state)

def get_opt_state_sharding(leaf):
    return NamedSharding(mesh, P())

print("\nCreating Optimizer State Sharding Tree (handling scalars)...")
opt_state_sharding_tree = jax.tree_util.tree_map(
    get_opt_state_sharding, abstract_opt_state
)
print("Optimizer State Sharding Tree:")
print(jax.tree_util.tree_map(lambda x: x.spec, opt_state_sharding_tree)) 


@functools.partial(jax.jit, out_shardings=opt_state_sharding_tree)
def initialize_optimizer_state(p):
    return optimizer.init(p)

print("\nInitializing optimizer state with corrected sharding...")
opt_state = initialize_optimizer_state(params)
print("Optimizer state initialization complete.")

print("\nVerifying Optimizer State Sharding:")
jax.tree_util.tree_map_with_path(
     lambda path, x: print(f"  Opt State: {path}, Type: {type(x).__name__}, "
                           f"Shape: {getattr(x, 'shape', 'N/A')}, "
                           f"Sharding: {getattr(x, 'sharding', 'N/A')}"),
     opt_state
)

def check_opt_state_leaf(actual_leaf, named_sharding_leaf):
    intended_spec = named_sharding_leaf.spec 

    if hasattr(actual_leaf, 'sharding'):
        assert actual_leaf.sharding.spec == intended_spec, \
            f"Sharding mismatch for array: Actual={actual_leaf.sharding.spec}, Expected={intended_spec}"
    elif isinstance(actual_leaf, (int, jnp.integer, float, jnp.floating)):
        assert intended_spec == P(), \
            f"Sharding mismatch for scalar: Expected=P(), Got={intended_spec}"
    else:
        pass

jax.tree_util.tree_map(check_opt_state_leaf, opt_state, opt_state_sharding_tree)

print("Optimizer state sharding verified.")




Getting abstract optimizer state structure...
Abstract Opt State Structure:
(ScaleByAdamState(count=ShapeDtypeStruct(shape=(), dtype=int32), mu={'dense_layer_1': {'bias': ShapeDtypeStruct(shape=(128,), dtype=float32), 'kernel': ShapeDtypeStruct(shape=(64, 128), dtype=float32)}, 'dense_layer_2': {'bias': ShapeDtypeStruct(shape=(10,), dtype=float32), 'kernel': ShapeDtypeStruct(shape=(128, 10), dtype=float32)}}, nu={'dense_layer_1': {'bias': ShapeDtypeStruct(shape=(128,), dtype=float32), 'kernel': ShapeDtypeStruct(shape=(64, 128), dtype=float32)}, 'dense_layer_2': {'bias': ShapeDtypeStruct(shape=(10,), dtype=float32), 'kernel': ShapeDtypeStruct(shape=(128, 10), dtype=float32)}}), EmptyState())

Creating Optimizer State Sharding Tree (handling scalars)...
Optimizer State Sharding Tree:
(ScaleByAdamState(count=PartitionSpec(), mu={'dense_layer_1': {'bias': PartitionSpec(), 'kernel': PartitionSpec()}, 'dense_layer_2': {'bias': PartitionSpec(), 'kernel': PartitionSpec()}}, nu={'dense_layer_1

In [12]:

# --- 5. Prepare and Shard Data Batch ---
# Create *global* batch data on Host (CPU) first
dummy_global_x = jax.random.normal(data_key, (GLOBAL_BATCH_SIZE, INPUT_FEATURES))
# Use integers for typical classification labels
dummy_global_y = jax.random.randint(data_key, (GLOBAL_BATCH_SIZE,), 0, OUTPUT_DIM)

# Shard the global batch onto devices using jax.device_put
# This transfers data from host and distributes it according to the specs.
print(f"\nSharding initial data batch (Global shape: X={dummy_global_x.shape}, Y={dummy_global_y.shape})")

# It's often good practice to JIT the sharding if done repeatedly in a loop
#@functools.partial(jax.jit, static_argnums=(1,2)) # Shardings are static
def shard_batch(batch_data, data_named_sharding, label_named_sharding):
    x, y = batch_data
    x = jax.device_put(x, data_named_sharding)
    y = jax.device_put(y, label_named_sharding)
    return x, y

sharded_x, sharded_y = shard_batch(
    (dummy_global_x, dummy_global_y),
    data_sharding,      # Use the data sharding spec P('data', None)
    label_sharding      # Use the label sharding spec P('data',)
)

print("Verifying data sharding:")
print(f"  Sharded X shape per device: {sharded_x.shape}") # Shape reflects per-device slice
print(f"  Sharded X sharding: {sharded_x.sharding}")
assert sharded_x.sharding == data_sharding
assert sharded_x.shape == (DEVICE_BATCH_SIZE, INPUT_FEATURES)

print(f"  Sharded Y shape per device: {sharded_y.shape}")
print(f"  Sharded Y sharding: {sharded_y.sharding}")
assert sharded_y.sharding == label_sharding
assert sharded_y.shape == (DEVICE_BATCH_SIZE,)





Sharding initial data batch (Global shape: X=(32, 64), Y=(32,))
Verifying data sharding:
  Sharded X shape per device: (32, 64)
  Sharded X sharding: NamedSharding(mesh=Mesh('data': 1, axis_types=(Auto,)), spec=PartitionSpec('data', None), memory_kind=unpinned_host)
  Sharded Y shape per device: (32,)
  Sharded Y sharding: NamedSharding(mesh=Mesh('data': 1, axis_types=(Auto,)), spec=PartitionSpec('data',), memory_kind=unpinned_host)


In [7]:
# --- 6. Define the Training Step ---

# Loss function (cross-entropy) - operates on local batch
def loss_fn(params, batch_x, batch_y, model_obj):
    logits = model_obj.apply({'params': params}, batch_x)
    one_hot_labels = jax.nn.one_hot(batch_y, num_classes=OUTPUT_DIM)
    # Calculate loss on the local device's batch slice
    loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot_labels))
    return loss

# Define the update step, JITted with sharding specifications
@functools.partial(jax.jit,
                   # Specify input shardings (must match how data/state IS sharded)
                   in_shardings=(params_sharding_tree,  # Replicated
                                 opt_state_sharding_tree,# Replicated
                                 data_sharding,         # Sharded P('data', None)
                                 label_sharding),       # Sharded P('data',)
                   # Specify output shardings (how results SHOULD BE sharded)
                   out_shardings=(params_sharding_tree, # Replicated
                                  opt_state_sharding_tree,# Replicated
                                  replicated_sharding)  # Loss (scalar, replicated)
                   )
def train_step(current_params, current_opt_state, batch_x, batch_y):
    # Calculate loss and gradients on the local batch using the local param replica.
    # Gradients will initially be replicated (like params).
    loss, grads = jax.value_and_grad(loss_fn)(current_params, batch_x, batch_y, model)

    # *** CRUCIAL FOR REPLICATION / DATA PARALLELISM ***
    # Average gradients across the 'data' mesh axis using pmean (parameter mean).
    # This computes the mean of 'grads' across all devices in the 'data' axis.
    # The result is replicated back to all devices in that axis.
    averaged_grads = jax.lax.pmean(grads, axis_name='data')

    # Also average the loss across devices for consistent reporting
    averaged_loss = jax.lax.pmean(loss, axis_name='data')

    # Compute updates using the optimizer (using the *averaged* gradients)
    updates, new_opt_state = optimizer.update(averaged_grads, current_opt_state, current_params)

    # Apply updates to parameters. Since params/updates are replicated,
    # this happens identically on all devices, keeping them synchronized.
    new_params = optax.apply_updates(current_params, updates)

    return new_params, new_opt_state, averaged_loss

# --- 7. Training Loop ---
print("\nStarting training loop...")
for step in range(NUM_STEPS):
    # In a real loop, load and shard a *new* batch here.
    # We reuse the same sharded batch for simplicity.
    current_x = sharded_x
    current_y = sharded_y

    # Execute the JITted training step
    params, opt_state, loss = train_step(params, opt_state, current_x, current_y)

    if step % 10 == 0 or step == NUM_STEPS - 1:
        # Loss is replicated, so we can just print it from any device (device 0 implicitly)
        # Use .item() to get scalar value from the JAX array.
        print(f"Step: {step:3d}, Loss: {loss.item():.4f}")

print("Training finished.")

# --- 8. Final Verification (Optional) ---
print("\nFinal Parameter Sharding Check:")
assert all(leaf.sharding == replicated_sharding
           for leaf in jax.tree_util.tree_leaves(params))
print("All parameters remain replicated.")



SyntaxError: incomplete input (2588790154.py, line 69)