# Distributed PM Simulation

JAX supports distributing array computations across multiple devices (GPUs/TPUs) using its sharding API. The `fwd_model_tools` library accepts an optional `sharding` parameter in `gaussian_initial_conditions` that partitions the initial field across devices. All downstream operations (`lpt`, `nbody`, painting) automatically respect this sharding.

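This notebook demonstrates the distributed workflow using fake CPU devices for testing. On a multi-GPU machine, remove the fake-CPU configuration in the setup cell (the `jax_num_cpu_devices` and CPU-platform settings) so that JAX builds the mesh from your actual devices.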

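Make sure `rich` is installed. `jax.debug.visualize_array_sharding`, which is used below to display how arrays are laid out across devices, depends on it: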

`pip install rich`

## Setup

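The number of fake CPU devices is set with `jax.config.update("jax_num_cpu_devices", N)`. This setting, together with the platform selection, must be applied **before** JAX initializes its backends, i.e. before the first call such as `jax.devices()` or the first array creation.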

In [1]:
import os

# GPU memory settings. They only matter when running on a GPU and are kept so
# the same cell can be reused on a GPU machine.
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.97'
os.environ['TF_GPU_ALLOCATOR'] = 'cuda_malloc_async'


import jax

# All JAX configuration below must run before JAX initializes its backends,
# i.e. before the first jax.devices() call or array creation.
#   - jax_num_cpu_devices: number of (fake) CPU devices to expose.
#   - jax_platforms: restrict JAX to the CPU backend for this demo. Use 'cuda'
#     (or remove the line) to run on GPUs. This replaces the deprecated
#     'jax_platform_name' flag.
#   - jax_enable_x64: keep float32 (the default). It is set up front so it
#     applies to every array created afterwards.
jax.config.update("jax_num_cpu_devices", 4)  # Set this to the number of CPU devices you want to use
jax.config.update("jax_platforms", "cpu")
jax.config.update("jax_enable_x64", False)

import jax.numpy as jnp
import jax_cosmo as jc
import matplotlib.pyplot as plt
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P, AxisType

import fwd_model_tools as ffi

print(f"Number of devices: {jax.device_count()}")
print(f"Devices {jax.devices()}")
jax.print_environment_info()

  from pkg_resources import DistributionNotFound, get_distribution
ERROR:2026-02-24 14:34:39,129:jax._src.xla_bridge:491: Jax plugin configuration error: Exception when calling jax_plugins.xla_cuda12.initialize()
Traceback (most recent call last):
  File "/home/wassim/micromamba/envs/ffi11/lib/python3.11/site-packages/jax/_src/xla_bridge.py", line 489, in discover_pjrt_plugins
    plugin_module.initialize()
  File "/home/wassim/micromamba/envs/ffi11/lib/python3.11/site-packages/jax_plugins/xla_cuda12/__init__.py", line 328, in initialize
    _check_cuda_versions(raise_on_first_error=True)
  File "/home/wassim/micromamba/envs/ffi11/lib/python3.11/site-packages/jax_plugins/xla_cuda12/__init__.py", line 285, in _check_cuda_versions
    local_device_count = cuda_versions.cuda_device_count()
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: jaxlib/cuda/versions_helpers.cc:113: operation cuInit(0) failed: CUDA_ERROR_UNKNOWN


Number of devices: 4
Devices [CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3)]
jax:    0.9.0.1
jaxlib: 0.9.0.1
numpy:  2.4.2
python: 3.11.4 | packaged by conda-forge | (main, Jun 10 2023, 18:08:17) [GCC 12.2.0]
device info: cpu-4, 4 local devices"
process_count: 1
platform: uname_result(system='Linux', node='apc2324', release='6.8.0-100-generic', version='#100-Ubuntu SMP PREEMPT_DYNAMIC Tue Jan 13 16:40:06 UTC 2026', machine='x86_64')
XLA_PYTHON_CLIENT_MEM_FRACTION=0.97

$ nvidia-smi
Tue Feb 24 14:34:41 2026       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 570.211.01             Driver Version: 570.211.01     CUDA Version: 12.8     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |


## Create Device Mesh and Sharding

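A `Mesh` maps physical devices to named axes. The `PartitionSpec` then describes how array dimensions map to mesh axes. The `jaxpm` distributed backend expects a **2D mesh** so that the first two spatial dimensions of the 3D field are each partitioned across one mesh axis. Here the `y` axis has size 1, so only the first spatial dimension is actually split, into slabs along `x`. This is why the printed array shardings below show only `'x'`.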

In [2]:
import numpy as np

mesh = jax.make_mesh((4 , 1) , ('x', 'y') , axis_types=(AxisType.Auto , AxisType.Auto))
sharding = NamedSharding(mesh, P('x', 'y'))

print(f"Mesh shape: {mesh.shape}")
print(f"Partition spec: {sharding.spec}")

Mesh shape: OrderedDict([('x', 4), ('y', 1)])
Partition spec: PartitionSpec('x', 'y')


## Distributed Initial Conditions

Pass the `sharding` parameter to `gaussian_initial_conditions`. The returned `DensityField` stores the sharding info and all subsequent operations will respect it.

In [3]:
# Cosmology and random seed for the Gaussian initial conditions.
cosmo = jc.Planck18()
key = jax.random.PRNGKey(0)

# 3D particle-mesh grid and physical box size (box units not shown here).
mesh_size = (120, 120, 120)
box_size = (1000., 1000., 1000.)
# NOTE(review): with the 4-device mesh each device holds 120 / 4 = 30 cells
# along x, fewer than this halo width — confirm this is what ffi expects.
halo_size = (100, 100)

# Output map resolutions: `nside` for spherical maps; pixel count and
# angular extent for flat-sky maps (units presumably degrees — confirm).
nside = 512
flatsky = (512, 512)
field_size = (10., 10.)

# Passing `sharding` partitions the generated field across the device mesh.
initial_field = ffi.gaussian_initial_conditions(
    key,
    mesh_size,
    box_size,
    cosmo=cosmo,
    nside=nside,
    flatsky_npix=flatsky,
    field_size=field_size,
    sharding=sharding,
    halo_size=halo_size,
)

print(f"Initial field shape: {initial_field.array.shape}")
print(f"Sharding: {initial_field.array.sharding}")

Initial field shape: (120, 120, 120)
Sharding: NamedSharding(mesh=Mesh('x': 4, 'y': 1, axis_types=(Auto, Auto)), spec=PartitionSpec('x',), memory_kind=device)


## Run Distributed PM Pipeline

The pipeline is identical to single-device usage. JAX handles communication (halo exchanges, all-reduces) automatically based on the sharding.

In [4]:
# LPT initialization: first-order displacements `dx` and momenta `p` at
# ts=0.1 (presumably a scale factor, given t1=1.0 below — confirm).
dx, p = ffi.lpt(
    cosmo,
    initial_field,
    ts=0.1,
    order=1,
)

print(f"Displacement sharding: {dx.array.sharding}")
print(f"Momentum sharding: {p.array.sharding}")

# Render how a lower-dimensional slice of the displacement is laid out
# across devices (requires `rich`).
jax.debug.visualize_array_sharding(dx.array[..., 0, 0])

Displacement sharding: NamedSharding(mesh=Mesh('x': 4, 'y': 1, axis_types=(Auto, Auto)), spec=PartitionSpec('x',), memory_kind=device)
Momentum sharding: NamedSharding(mesh=Mesh('x': 4, 'y': 1, axis_types=(Auto, Auto)), spec=PartitionSpec('x',), memory_kind=device)


In [5]:
# N-body integration with particles painted onto flat-sky maps.
solver = ffi.ReversibleDoubleKickDrift(
    interp_kernel=ffi.NoInterp(
        painting=ffi.PaintingOptions(target="flat"),
    ),
)

densities = ffi.nbody(
    cosmo,
    dx,
    p,
    t1=1.0,
    dt0=0.05,
    nb_shells=4,
    solver=solver,
).block_until_ready()  # wait for the asynchronous computation to finish

print(f"Lightcone type: {type(densities).__name__}")
print(f"Lightcone shape: {densities.shape}")
print(f"Density sharding: {densities.array.sharding}")

# Device layout of the first entry along the leading (shell) axis.
jax.debug.visualize_array_sharding(densities.array[0, ...])

Lightcone type: FlatDensity
Lightcone shape: (4, 512, 512)
Density sharding: NamedSharding(mesh=Mesh('x': 4, 'y': 1, axis_types=(Auto, Auto)), spec=PartitionSpec(None, 'x'), memory_kind=device)


In [6]:
# Release the previous result before running the next simulation.
# globals().pop(..., None) is used instead of `del densities` so the cell can
# be re-run or executed out of order: a bare `del` raises NameError once the
# name has already been removed.
globals().pop("densities", None)

# N-body integration with particles painted onto spherical maps using
# bilinear interpolation.
solver = ffi.ReversibleDoubleKickDrift(
    interp_kernel=ffi.NoInterp(painting=ffi.PaintingOptions(target="spherical", scheme="bilinear")),
)

lightcone = ffi.nbody(
    cosmo, dx, p,
    t1=1.0, dt0=0.05,
    nb_shells=4,
    solver=solver,
)

print(f"Lightcone type: {type(lightcone).__name__}")
print(f"Lightcone shape: {lightcone.shape}")
# Report the actual device layout of the result, as the flat and 3D density
# cells do.
print(f"Density sharding: {lightcone.array.sharding}")
jax.debug.visualize_array_sharding(lightcone.array[0, ...])

Lightcone type: SphericalDensity
Lightcone shape: (4, 3145728)


In [7]:
# Display the `sharding` attribute of the lightcone object. Its spec
# ('x', 'y') matches the NamedSharding passed to gaussian_initial_conditions,
# while the arrays' own shardings print differently. This suggests it is stored
# sharding metadata rather than the layout of `lightcone.array` — TODO confirm.
# Use `lightcone.array.sharding` to see the actual device layout.
lightcone.sharding

NamedSharding(mesh=Mesh('x': 4, 'y': 1, axis_types=(Auto, Auto)), spec=PartitionSpec('x', 'y'), memory_kind=device)

In [7]:
# Release the previous result before running the next simulation.
# globals().pop(..., None) keeps the cell runnable even if the previous cell
# was skipped (a bare `del lightcone` would raise NameError in that case).
globals().pop("lightcone", None)

# N-body integration with particles painted onto spherical maps using
# nearest-grid-point ("ngp") assignment.
solver = ffi.ReversibleDoubleKickDrift(
    interp_kernel=ffi.NoInterp(painting=ffi.PaintingOptions(target="spherical", scheme="ngp")),
)

lightcone = ffi.nbody(
    cosmo, dx, p,
    t1=1.0, dt0=0.05,
    nb_shells=4,
    solver=solver,
)

print(f"Lightcone type: {type(lightcone).__name__}")
print(f"Lightcone shape: {lightcone.shape}")
# Report the actual device layout of the result, as the flat and 3D density
# cells do.
print(f"Density sharding: {lightcone.array.sharding}")
jax.debug.visualize_array_sharding(lightcone.array[0, ...])

Lightcone type: SphericalDensity
Lightcone shape: (4, 3145728)


In [11]:
# Release the spherical result before computing the (larger) 3D density
# output. globals().pop(..., None) keeps the cell re-runnable: after one run
# `lightcone` no longer exists, so a bare `del lightcone` would raise NameError.
globals().pop("lightcone", None)

# N-body integration with particles painted onto the 3D density mesh.
solver = ffi.ReversibleDoubleKickDrift(
    interp_kernel=ffi.NoInterp(painting=ffi.PaintingOptions(target="density")),
)

densities = ffi.nbody(
    cosmo, dx, p,
    t1=1.0, dt0=0.05,
    nb_shells=4,
    solver=solver,
)

print(f"Lightcone type: {type(densities).__name__}")
print(f"Lightcone shape: {densities.shape}")
print(f"Density sharding: {densities.array.sharding}")
# Device layout of a 2D slice: first snapshot, last axis fixed at index 0.
jax.debug.visualize_array_sharding(densities.array[0, ..., 0])

Lightcone type: DensityField
Lightcone shape: (4, 1200, 1200, 1200)
Density sharding: NamedSharding(mesh=Mesh('x': 4, 'y': 1, axis_types=(Auto, Auto)), spec=PartitionSpec(None, 'x'), memory_kind=device)
