# Nested Parallelism with jaxDecomp

This notebook demonstrates how to perform batched FFTs (nested parallelism) with `jaxDecomp`.
We use `jax.vmap` to vectorize over a chain axis, so that JAX handles the batch parallelism while `jaxDecomp` handles the distributed FFT within each group.

In [1]:
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=16"  # Force 16 virtual devices for simulation
os.environ["JAX_PLATFORM_NAME"] = "cpu"
os.environ["JAX_PLATFORMS"] = "cpu"

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P, NamedSharding
import numpy as np
import jaxdecomp.fft as jdfft
from functools import partial
from jax.experimental.multihost_utils import process_allgather
gather = partial(process_allgather, tiled=True)  # Helper to gather data from all devices to host

## 1. Define the Full Mesh

We treat the 16 devices as 2 groups (axis `c`) of 8 devices each (axes `x` and `y`, laid out 4 × 2).
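The grouping can be sketched with plain integers standing in for devices. This is only an illustration: it assumes `jax.devices()` returns devices ordered by id, and `device_id` is a hypothetical helper that mimics the row-major layout of `reshape(2, 4, 2)`.

```python
# Integer ids stand in for jax.devices(); assumes devices are ordered by id.
N_GROUPS, NX, NY = 2, 4, 2

def device_id(c, x, y):
    # Row-major position in a (2, 4, 2) grid, as produced by reshape(2, 4, 2).
    return (c * NX + x) * NY + y

# Members of each group: all (x, y) positions for a fixed c.
groups = {
    c: [device_id(c, x, y) for x in range(NX) for y in range(NY)]
    for c in range(N_GROUPS)
}

# Devices that sit at the same (x, y) position in different groups.
partners = [
    tuple(device_id(c, x, y) for c in range(N_GROUPS))
    for x in range(NX)
    for y in range(NY)
]

print(groups[0])
print(groups[1])
print(partners[:3])
```

Under that ordering assumption, the first group holds ids 0 to 7 and the second holds ids 8 to 15, so devices `d` and `d + 8` occupy the same position within their respective groups.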

In [2]:
devices = jax.devices()
mesh = Mesh(np.array(devices).reshape(2, 4, 2), ("c", "x", "y"))  # 2 groups * 8 devices/group = 16 total
mesh_inner = Mesh(np.array(devices[:8]).reshape(4, 2), ("x", "y"))  # Single-group mesh (first 8 devices), used for verification


## 2. Define the Per-Group Logic

This function expects the data of a single group (the chain dimension is stripped).
It uses the standard auto-sharding tools from `jdfft`.

In [3]:
def per_group_logic(arr):
    # arr shape: (Z, Y, X), with no chain axis.
    # jdfft expects a global-looking array to apply its sharding constraints.
    # Debug callback to verify the sharding seen inside the function
    jax.debug.inspect_array_sharding(arr, callback=lambda sharding: print(f"Sharding in per_group_logic: {sharding}"))
    print(f"shape here is {arr.shape}")  # Under jit, this prints at trace time
    # Perform 3D FFT on the sharded array (distributed across x and y)
    return jdfft.pfft3d(arr)

## 3. Create the Batched Operation

We use `vmap` to vectorize over the chain axis.
The JAX compiler automatically handles the batch parallelism over `c` and the tensor parallelism over `x` and `y`.

In [4]:
batched_op = jax.jit(jax.vmap(per_group_logic)) # JIT compile the vmapped function for performance

## 4. Run

We initialize global data with shape (Chain=2, Z=32, Y=32, X=32) and apply sharding constraints.
JAX will run 2 distributed FFTs in parallel:
- Group 0 handles batch 0 (distributed over its 8 devices)
- Group 1 handles batch 1 (distributed over its 8 devices)
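To make the layout concrete, the per-device shard shape can be worked out by hand. The sketch below uses a hypothetical helper, `local_shard_shape`, which is not part of JAX; it assumes every sharded dimension divides evenly by its mesh axis size and that a `PartitionSpec` shorter than the array rank leaves the trailing dimensions unsharded.

```python
def local_shard_shape(global_shape, spec, mesh_shape):
    """Shape of the block held by one device (hypothetical helper, not a JAX API)."""
    local = []
    for dim, axis in zip(global_shape, spec):
        parts = 1 if axis is None else mesh_shape[axis]
        if dim % parts != 0:
            raise ValueError(f"dimension {dim} is not divisible by {parts}")
        local.append(dim // parts)
    return tuple(local)

mesh_shape = {"c": 2, "x": 4, "y": 2}
# P("c", "x", "y") on a 4D array: the last dimension (X) stays unsharded.
spec = ("c", "x", "y", None)

shape = local_shard_shape((2, 32, 32, 32), spec, mesh_shape)
print(shape)  # (1, 8, 16, 32)
```

Each device therefore holds a single chain element and a Z/Y pencil of the 32³ volume, with the X axis kept whole.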

In [5]:
# Global Data: (Chain=2, Z=32, Y=32, X=32)
global_data = jax.random.normal(jax.random.key(0), (2, 32, 32, 32)) # Create random data on host

# Constraint: Shard C over c, Z over x, Y over y
input_sharding = NamedSharding(mesh, P("c", "x", "y")) # Define global sharding: partitioned along c, x, y
sharded_global_data = jax.lax.with_sharding_constraint(global_data, input_sharding)  # Distribute the data according to the global sharding

print(f"Input Sharding: {sharded_global_data.sharding}")

# Run the batched FFT
output = batched_op(sharded_global_data)

# Extract first batch element for verification
inner = global_data[0]  # Shape: (Z=32, Y=32, X=32) for the first group
inner = jax.lax.with_sharding_constraint(inner, NamedSharding(mesh_inner, P("x", "y")))  # Shard over the single-group mesh: Z over x, Y over y
# Run the single-group FFT as a reference
single_group_output = per_group_logic(inner)

print(f"Output Shape: {output.shape}")
print(f"Output Sharding: {output.sharding}")

Input Sharding: NamedSharding(mesh=Mesh('c': 2, 'x': 4, 'y': 2, axis_types=(Auto, Auto, Auto)), spec=PartitionSpec('c', 'x', 'y'), memory_kind=device)
shape here is (32, 32, 32)
Sharding in per_group_logic: NamedSharding(mesh=Mesh('c': 2, 'x': 4, 'y': 2, axis_types=(Auto, Auto, Auto)), spec=PartitionSpec('c', 'x', 'y'), memory_kind=device)


Sharding in per_group_logic: NamedSharding(mesh=Mesh('x': 4, 'y': 2, axis_types=(Auto, Auto)), spec=PartitionSpec('x', 'y'), memory_kind=device)
shape here is (32, 32, 32)
Output Shape: (2, 32, 32, 32)
Output Sharding: NamedSharding(mesh=Mesh('c': 2, 'x': 4, 'y': 2, axis_types=(Auto, Auto, Auto)), spec=PartitionSpec('c', 'x', 'y'), memory_kind=device)


## 5. Verification

We verify that the first batch element of the batched output matches the result of running the per-group logic on that element alone.

In [6]:
batched_out = gather(output) # Collect distributed results to host for comparison
single_out = gather(single_group_output) # Collect distributed results to host for comparison

print(f"is output close to single group result? {jnp.allclose(batched_out[0], single_out)}")

is output close to single group result? True
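The same consistency check can be illustrated on a single process with NumPy. This is only an analogue, not the notebook's method: it uses `np.fft.fftn` rather than `jdfft.pfft3d`, involves no sharding, and `jaxDecomp` may order its output axes differently from NumPy.

```python
import numpy as np

rng = np.random.default_rng(0)
batch = rng.standard_normal((2, 8, 8, 8))  # (Chain, Z, Y, X), small for speed

# Batched transform: FFT over the three spatial axes, chain axis left alone.
batched = np.fft.fftn(batch, axes=(1, 2, 3))

# Reference: transform each chain element on its own.
singles = [np.fft.fftn(batch[i]) for i in range(batch.shape[0])]

matches = [bool(np.allclose(batched[i], singles[i])) for i in range(batch.shape[0])]
print(matches)
```

Because the transform never mixes data along the chain axis, each batch element can be computed independently, which is what allows the two device groups to work without communicating with each other.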


We can also inspect the sharding of the output to understand it better.
The next cell shows a 2D slice spanning the batch axis and the first of the sharded axes (z).

In [7]:
jax.debug.visualize_array_sharding(output[..., 0, 0])

If we look at the sharding of the output and split it along the batch axis, we can see that the devices come in pairs, one from each group:

```
0 - 8
1 - 9
2 - 10
etc ...
```

The two devices in each pair independently run the same logic on different batch elements. Each group shards its data across the z and y axes, as expected.

In [8]:
jax.debug.visualize_array_sharding(output[0, ..., 0])

## 6. HLO Inspection

We inspect the HLO to verify the independence of the groups.

In [9]:
HLO = batched_op.lower(sharded_global_data).compile().as_text()
# We look for the all-to-all instructions, which handle the FFT data redistribution.
# The replica_groups should show that devices communicate only within their own group (never across the c axis).
print("Relevant HLO instruction showing independent groups:")
found = False
for line in HLO.split("\n"):
    if "all-to-all" in line and "replica_groups" in line:
        print(line.strip())
        found = True
if not found:
    print("No all-to-all with replica_groups found. Check HLO for other communication primitives.")

Relevant HLO instruction showing independent groups:
%all-to-all = (c64[1,8,16,16]{3,2,1,0}, c64[1,8,16,16]{3,2,1,0}) all-to-all(%wrapped_slice, %wrapped_slice.1), channel_id=1, replica_groups={{0,1},{2,3},{4,5},{6,7},{8,9},{10,11},{12,13},{14,15}}, metadata={op_name="jit(per_group_logic)/vmap(jit(_do_pfft))/jit(pfft_impl)/custom_partitioning" stack_frame_id=12}
%all-to-all.1 = (c64[1,16,8,8]{3,2,1,0}, c64[1,16,8,8]{3,2,1,0}, c64[1,16,8,8]{3,2,1,0}, c64[1,16,8,8]{3,2,1,0}) all-to-all(%wrapped_slice.2, %wrapped_slice.3, %wrapped_slice.4, %wrapped_slice.5), channel_id=1, replica_groups={{0,2,4,6},{1,3,5,7},{8,10,12,14},{9,11,13,15}}, metadata={op_name="jit(per_group_logic)/vmap(jit(_do_pfft))/jit(pfft_impl)/custom_partitioning" stack_frame_id=12}
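The independence claim can also be checked programmatically rather than by eye. The sketch below parses the `replica_groups` field from shortened copies of the two lines above and tests that no communicator spans both device groups. It assumes the ids in `replica_groups` follow the mesh device order, so that ids 0 to 7 form group 0 and ids 8 to 15 form group 1; `parse_replica_groups` and `stays_within_group` are hypothetical helpers.

```python
import re

DEVICES_PER_GROUP = 8  # assumption: ids 0-7 are group 0, ids 8-15 are group 1

def parse_replica_groups(hlo_line):
    """Extract replica_groups={{...},{...}} as a list of lists of ints."""
    match = re.search(r"replica_groups=\{((?:\{[\d,]+\},?)+)\}", hlo_line)
    if match is None:
        return []
    return [
        [int(tok) for tok in group.split(",")]
        for group in re.findall(r"\{([\d,]+)\}", match.group(1))
    ]

def stays_within_group(replica_groups):
    """True if every communicator contains devices from a single group only."""
    return all(
        len({dev // DEVICES_PER_GROUP for dev in group}) == 1
        for group in replica_groups
    )

# Shortened copies of the two all-to-all lines printed above.
line_y = "all-to-all(...), replica_groups={{0,1},{2,3},{4,5},{6,7},{8,9},{10,11},{12,13},{14,15}}, metadata={...}"
line_x = "all-to-all(...), replica_groups={{0,2,4,6},{1,3,5,7},{8,10,12,14},{9,11,13,15}}, metadata={...}"

independent = [stays_within_group(parse_replica_groups(l)) for l in (line_y, line_x)]
print(independent)
```

Under the stated id assumption, both instructions pass: the first exchanges data between pairs of devices and the second between sets of four, and none of those sets mixes ids below 8 with ids from 8 upward.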
