# Distributed PM Simulation

JAX supports distributing array computations across multiple devices (GPUs/TPUs) using its sharding API. The `fwd_model_tools` library accepts an optional `sharding` parameter in `gaussian_initial_conditions` that partitions the initial field across devices. All downstream operations (`lpt`, `nbody`, painting) automatically respect this sharding.

This notebook demonstrates the distributed workflow on simulated CPU devices, which is enough for testing. On a multi-GPU machine, remove the `XLA_FLAGS` and `jax_platform_name` overrides and build the mesh from your actual devices.

## Setup

The `XLA_FLAGS` environment variable must be set **before** importing JAX to create fake devices on CPU.

In [3]:
!pip install -q rich

In [4]:
import os

os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.97'
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=4'

import jax
jax.config.update("jax_platform_name", "cpu")
import jax.numpy as jnp
import jax_cosmo as jc
import matplotlib.pyplot as plt
from jax.sharding import AxisType, Mesh, NamedSharding, PartitionSpec as P

import fwd_model_tools as ffi

print(f"Number of devices: {jax.device_count()}")
print(f"Devices: {jax.devices()}")

Number of devices: 4
Devices: [CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3)]
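
The setup cell assigns `XLA_FLAGS` directly, which overwrites any flags already present in the environment. A minimal sketch of a helper that appends the device-count flag instead is shown here; the name `with_fake_cpu_devices` is hypothetical (it is not part of JAX or `fwd_model_tools`), and it would still have to run before the first `import jax`.

```python
import os

_FLAG = "--xla_force_host_platform_device_count"


def with_fake_cpu_devices(n, env=None):
    """Set the fake-device flag in XLA_FLAGS while keeping unrelated flags."""
    env = os.environ if env is None else env
    # Drop any earlier device-count flag so the new value is the only one.
    kept = [f for f in env.get("XLA_FLAGS", "").split() if not f.startswith(_FLAG)]
    env["XLA_FLAGS"] = " ".join(kept + [f"{_FLAG}={n}"])
    return env["XLA_FLAGS"]


# Demonstration on a plain dict, so the real environment is left untouched.
demo_env = {"XLA_FLAGS": "--xla_dump_to=/tmp/xla_dump"}
print(with_fake_cpu_devices(4, demo_env))
```

Calling `with_fake_cpu_devices(4)` with no second argument modifies `os.environ` in place.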


## Create Device Mesh and Sharding

A `Mesh` maps physical devices to named axes. The `PartitionSpec` then describes how array dimensions map to mesh axes. The `jaxpm` distributed backend expects a **2D mesh** so that the first two spatial dimensions of the 3D field are each partitioned across one mesh axis.

In [5]:
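# 2x2 device mesh: mesh axis 'x' shards array axis 0, mesh axis 'y' shards array axis 1
mesh = jax.make_mesh((2, 2), ('x', 'y'), axis_types=(AxisType.Auto, AxisType.Auto))
# Array axes not named in the spec (here the third axis) are left unsharded
sharding = NamedSharding(mesh, P('x', 'y'))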

print(f"Mesh shape: {mesh.shape}")
print(f"Partition spec: {sharding.spec}")

Mesh shape: OrderedDict([('x', 2), ('y', 2)])
Partition spec: PartitionSpec('x', 'y')
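
The hard-coded `(2, 2)` mesh only works with exactly four devices. As a sketch of how the same layout could be built from whatever devices are present, the block below picks the most square 2D factorisation of the device count and reports the per-device block of a 64³ field. The helper `mesh_dims` is a hypothetical name, and the sketch assumes each mesh dimension divides the corresponding array axis.

```python
import os

# Assumption: fall back to 4 fake CPU devices when XLA_FLAGS is not already set.
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=4")

import math

import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def mesh_dims(n_devices):
    """Most square 2D factorisation (a, b) of n_devices with a <= b."""
    a = math.isqrt(n_devices)
    while n_devices % a:
        a -= 1
    return a, n_devices // a


pdims = mesh_dims(jax.device_count())
devices = np.array(jax.devices()).reshape(pdims)
mesh = Mesh(devices, ("x", "y"))
sharding = NamedSharding(mesh, P("x", "y"))

# Axes 0 and 1 are split across the mesh; axis 2 stays whole on every device.
print(pdims, sharding.shard_shape((64, 64, 64)))
```

With the four fake devices used in this notebook, each device holds a `(32, 32, 64)` block.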


## Distributed Initial Conditions

Pass the `sharding` parameter to `gaussian_initial_conditions`. The returned `DensityField` stores the sharding info and all subsequent operations will respect it.

In [6]:
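key = jax.random.PRNGKey(0)
mesh_size = (64, 64, 64)           # cells per axis of the 3D grid
halo_size = (4, 4)
box_size = (500.0, 500.0, 500.0)   # Mpc/h
nside = 64                         # resolution of the spherical maps
flatsky = (64, 64)                 # pixels per axis of the flat-sky maps
field_size = (10., 10.)
cosmo = ffi.Planck18()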

initial_field = ffi.gaussian_initial_conditions(
    key, mesh_size, box_size,
    cosmo=cosmo,
    nside=nside,
    flatsky_npix=flatsky,
    field_size=field_size,
    sharding=sharding,
    halo_size=halo_size
)

print(f"Initial field shape: {initial_field.array.shape}")
print(f"Sharding: {initial_field.array.sharding}")

Initial field shape: (64, 64, 64)
Sharding: NamedSharding(mesh=Mesh('x': 2, 'y': 2, axis_types=(Auto, Auto)), spec=PartitionSpec('x', 'y'), memory_kind=device)
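
The printed `NamedSharding` gives the layout in the abstract. To see which slice of the global array each device actually holds, one option is to iterate over `addressable_shards`. The sketch below uses a zero-filled stand-in for `initial_field.array` (an assumption made so the block runs without `fwd_model_tools`) and rebuilds a mesh from the available devices.

```python
import os

# Assumption: fall back to 4 fake CPU devices when XLA_FLAGS is not already set.
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=4")

import math

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

n = jax.device_count()
a = math.isqrt(n)
while n % a:
    a -= 1
mesh = Mesh(np.array(jax.devices()).reshape(a, n // a), ("x", "y"))
sharding = NamedSharding(mesh, P("x", "y"))

# Stand-in for initial_field.array: any 3D array placed with the same sharding.
field = jax.device_put(jnp.zeros((64, 64, 64)), sharding)

for shard in field.addressable_shards:
    # shard.index is the tuple of slices this device owns in the global array.
    print(shard.device, shard.data.shape, shard.index)
```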


## Run Distributed PM Pipeline

The pipeline is identical to single-device usage. The communication between devices (halo exchanges, all-reduces) follows from the sharding attached to the arrays, so no call has to change.

In [7]:
# LPT initialization
dx, p = ffi.lpt(cosmo, initial_field, scale_factor_spec=0.1, order=1)

print(f"Displacement sharding: {dx.array.sharding}")
print(f"Momentum sharding: {p.array.sharding}")
jax.debug.visualize_array_sharding(dx.array[..., 0, 0])

Displacement sharding: NamedSharding(mesh=Mesh('x': 2, 'y': 2, axis_types=(Auto, Auto)), spec=PartitionSpec('x', 'y'), memory_kind=device)
Momentum sharding: NamedSharding(mesh=Mesh('x': 2, 'y': 2, axis_types=(Auto, Auto)), spec=PartitionSpec('x', 'y'), memory_kind=device)
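
Comparing printed sharding reprs by eye gets tedious. A possible programmatic check uses `Sharding.is_equivalent_to`. In the sketch below, `same_sharding` is a hypothetical helper and `scale` is a stand-in for a pipeline step such as `lpt`; it pins its output layout with `jax.lax.with_sharding_constraint`, which is one way a function can guarantee the layout it returns (whether `fwd_model_tools` does this internally is not shown in this notebook).

```python
import os

# Assumption: fall back to 4 fake CPU devices when XLA_FLAGS is not already set.
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=4")

import math

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

n = jax.device_count()
a = math.isqrt(n)
while n % a:
    a -= 1
mesh = Mesh(np.array(jax.devices()).reshape(a, n // a), ("x", "y"))
sharding = NamedSharding(mesh, P("x", "y"))


def same_sharding(x, y):
    """True if two arrays are laid out identically across devices."""
    return x.sharding.is_equivalent_to(y.sharding, x.ndim)


@jax.jit
def scale(x):
    # Stand-in for a pipeline step; the constraint pins the output layout.
    return jax.lax.with_sharding_constraint(0.5 * x, sharding)


field = jax.device_put(jnp.ones((64, 64, 64)), sharding)
out = scale(field)
print(same_sharding(field, out))
```

In the notebook itself the equivalent check would be `same_sharding(initial_field.array, dx.array)`.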


In [8]:
# N-body integration, painting each shell onto the flat-sky grid
solver = ffi.ReversibleDoubleKickDrift(
    interp_kernel=ffi.NoInterp(painting=ffi.PaintingOptions(target="flat")),
)

densities = ffi.nbody(
    cosmo, dx, p,
    t1=1.0, dt0=0.05,
    nb_shells=4,
    solver=solver,
).block_until_ready()

print(f"Lightcone type: {type(densities).__name__}")
print(f"Lightcone shape: {densities.shape}")
print(f"Density sharding: {densities.array.sharding}")
jax.debug.visualize_array_sharding(densities.array[0, ...])

positions sharding: NamedSharding(mesh=Mesh('x': 2, 'y': 2, axis_types=(Auto, Auto)), spec=PartitionSpec(None, 'x', 'y'), memory_kind=device)
xy sharding: NamedSharding(mesh=Mesh('x': 2, 'y': 2, axis_types=(Auto, Auto)), spec=PartitionSpec('x', 'y'), memory_kind=device)
dz sharding: NamedSharding(mesh=Mesh('x': 2, 'y': 2, axis_types=(Auto, Auto)), spec=PartitionSpec(None, 'x', 'y'), memory_kind=device)
dx array sharding in NoInterp: NamedSharding(mesh=Mesh('x': 2, 'y': 2, axis_types=(Auto, Auto)), spec=PartitionSpec('x', 'y'), memory_kind=device)
self.array sharding: NamedSharding(mesh=Mesh('x': 2, 'y': 2, axis_types=(Auto, Auto)), spec=PartitionSpec('x', 'y'), memory_kind=device)
self.array sharding after cast: NamedSharding(mesh=Mesh('x': 2, 'y': 2, axis_types=(Auto, Auto)), spec=PartitionSpec('x', 'y'), memory_kind=device)
data sharding before map: NamedSharding(mesh=Mesh('x': 2, 'y': 2, axis_types=(Auto, Auto)), spec=PartitionSpec(None, 'x', 'y'), memory_kind=device)
Furthest shell

In [9]:
# N-body integration, painting each shell onto the sphere with the bilinear scheme
solver = ffi.ReversibleDoubleKickDrift(
    interp_kernel=ffi.NoInterp(painting=ffi.PaintingOptions(target="spherical", scheme="bilinear")),
)

lightcone = ffi.nbody(
    cosmo, dx, p,
    t1=1.0, dt0=0.05,
    nb_shells=4,
    solver=solver,
)

print(f"Lightcone type: {type(lightcone).__name__}")
print(f"Lightcone shape: {lightcone.shape}")
jax.debug.visualize_array_sharding(lightcone.array)

dx array sharding in NoInterp: NamedSharding(mesh=Mesh('x': 2, 'y': 2, axis_types=(Auto, Auto)), spec=PartitionSpec('x', 'y'), memory_kind=device)
Lightcone type: SphericalDensity
Lightcone shape: (4, 49152)


Furthest shell is [218.75 156.25  93.75  31.25] Mpc/h, box extends to 250.0 Mpc/h
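
The lightcone shape `(4, 49152)` is consistent with one HEALPix map per shell: a HEALPix map of resolution `nside` has `12 * nside**2` pixels. That the spherical target is a HEALPix map is an inference from the `nside` parameter and the output shape, not something stated by the library output. A quick check of the arithmetic:

```python
def healpix_npix(nside):
    """Number of pixels in a HEALPix map of resolution nside."""
    return 12 * nside**2


nb_shells, nside = 4, 64
print((nb_shells, healpix_npix(nside)))  # (4, 49152)
```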


In [10]:
# N-body integration, painting each shell onto the sphere with the NGP scheme
solver = ffi.ReversibleDoubleKickDrift(
    interp_kernel=ffi.NoInterp(painting=ffi.PaintingOptions(target="spherical", scheme="ngp")),
)

lightcone = ffi.nbody(
    cosmo, dx, p,
    t1=1.0, dt0=0.05,
    nb_shells=4,
    solver=solver,
)

print(f"Lightcone type: {type(lightcone).__name__}")
print(f"Lightcone shape: {lightcone.shape}")
jax.debug.visualize_array_sharding(lightcone.array)

Painting shell at comoving center 218.75 Mpc/h with width 62.5 Mpc/h and shell index is 0
Painting shell at comoving center 156.25 Mpc/h with width 62.5 Mpc/h and shell index is 1
Painting shell at comoving center 93.75 Mpc/h with width 62.5 Mpc/h and shell index is 2
Painting shell at comoving center 31.25 Mpc/h with width 62.5 Mpc/h and shell index is 3
dx array sharding in NoInterp: NamedSharding(mesh=Mesh('x': 2, 'y': 2, axis_types=(Auto, Auto)), spec=PartitionSpec('x', 'y'), memory_kind=device)
Lightcone type: SphericalDensity
Lightcone shape: (4, 49152)


Furthest shell is [218.75 156.25  93.75  31.25] Mpc/h, box extends to 250.0 Mpc/h
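
The logged shell centers and widths match a simple geometry: `nb_shells` equal-width shells spanning the distance from the observer to half the box length, listed from the furthest shell inward. This rule is inferred from the log lines above rather than taken from the library source, and `shell_centers` is a hypothetical helper that reproduces the numbers:

```python
import numpy as np


def shell_centers(box_length, nb_shells):
    """Comoving centers (furthest first) and common width of equal-width shells."""
    half_box = box_length / 2.0           # 250 Mpc/h for a 500 Mpc/h box
    width = half_box / nb_shells
    centers = (np.arange(nb_shells) + 0.5) * width
    return centers[::-1], width


centers, width = shell_centers(500.0, 4)
print(centers, width)
```

For a 500 Mpc/h box and four shells this gives centers of 218.75, 156.25, 93.75 and 31.25 Mpc/h with a width of 62.5 Mpc/h, in the same order as the shell indices 0 to 3 in the log.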


In [11]:
# N-body integration, painting each shell onto the sphere with the rbf_neighbor scheme
solver = ffi.ReversibleDoubleKickDrift(
    interp_kernel=ffi.NoInterp(painting=ffi.PaintingOptions(target="spherical", scheme="rbf_neighbor")),
)

lightcone = ffi.nbody(
    cosmo, dx, p,
    t1=1.0, dt0=0.05,
    nb_shells=4,
    solver=solver,
)

print(f"Lightcone type: {type(lightcone).__name__}")
print(f"Lightcone shape: {lightcone.shape}")
jax.debug.visualize_array_sharding(lightcone.array)



Painting shell at comoving center 218.75 Mpc/h with width 62.5 Mpc/h and shell index is 0
Painting shell at comoving center 156.25 Mpc/h with width 62.5 Mpc/h and shell index is 1
Painting shell at comoving center 93.75 Mpc/h with width 62.5 Mpc/h and shell index is 2
Painting shell at comoving center 31.25 Mpc/h with width 62.5 Mpc/h and shell index is 3
dx array sharding in NoInterp: NamedSharding(mesh=Mesh('x': 2, 'y': 2, axis_types=(Auto, Auto)), spec=PartitionSpec('x', 'y'), memory_kind=device)
Lightcone type: SphericalDensity
Lightcone shape: (4, 49152)
Furthest shell is [218.75 156.25  93.75  31.25] Mpc/h, box extends to 250.0 Mpc/h


In [12]:
# N-body integration, keeping each shell as a 3D density field
solver = ffi.ReversibleDoubleKickDrift(
    interp_kernel=ffi.NoInterp(painting=ffi.PaintingOptions(target="density")),
)

densities = ffi.nbody(
    cosmo, dx, p,
    t1=1.0, dt0=0.05,
    nb_shells=4,
    solver=solver,
)

print(f"Lightcone type: {type(densities).__name__}")
print(f"Lightcone shape: {densities.shape}")
print(f"Density sharding: {densities.array.sharding}")
jax.debug.visualize_array_sharding(densities.array[0, ..., 0])

Painting shell at comoving center 218.75 Mpc/h with width 62.5 Mpc/h and shell index is 0
Painting shell at comoving center 156.25 Mpc/h with width 62.5 Mpc/h and shell index is 1
Painting shell at comoving center 93.75 Mpc/h with width 62.5 Mpc/h and shell index is 2
Painting shell at comoving center 31.25 Mpc/h with width 62.5 Mpc/h and shell index is 3
dx array sharding in NoInterp: NamedSharding(mesh=Mesh('x': 2, 'y': 2, axis_types=(Auto, Auto)), spec=PartitionSpec('x', 'y'), memory_kind=device)
Lightcone type: DensityField
Lightcone shape: (4, 64, 64, 64)
Density sharding: NamedSharding(mesh=Mesh('x': 2, 'y': 2, axis_types=(Auto, Auto)), spec=PartitionSpec(None, 'x', 'y'), memory_kind=device)
Furthest shell is [218.75 156.25  93.75  31.25] Mpc/h, box extends to 250.0 Mpc/h


In [13]:
# N-body integration, painting each shell onto the flat-sky grid
solver = ffi.ReversibleDoubleKickDrift(
    interp_kernel=ffi.NoInterp(painting=ffi.PaintingOptions(target="flat")),
)

densities = ffi.nbody(
    cosmo, dx, p,
    t1=1.0, dt0=0.05,
    nb_shells=4,
    solver=solver,
).block_until_ready()

print(f"Lightcone type: {type(densities).__name__}")
print(f"Lightcone shape: {densities.shape}")
print(f"Density sharding: {densities.array.sharding}")
jax.debug.visualize_array_sharding(densities.array[0, ...])

Painting shell at comoving center 218.75 Mpc/h with width 62.5 Mpc/h and shell index is 0
Painting shell at comoving center 156.25 Mpc/h with width 62.5 Mpc/h and shell index is 1
Painting shell at comoving center 93.75 Mpc/h with width 62.5 Mpc/h and shell index is 2
Painting shell at comoving center 31.25 Mpc/h with width 62.5 Mpc/h and shell index is 3
Furthest shell is [218.75 156.25  93.75  31.25] Mpc/h, box extends to 250.0 Mpc/h
Painting shell at comoving center 218.75 Mpc/h with width 62.5 Mpc/h and shell index is 0
Painting shell at comoving center 156.25 Mpc/h with width 62.5 Mpc/h and shell index is 1
Painting shell at comoving center 93.75 Mpc/h with width 62.5 Mpc/h and shell index is 2
Painting shell at comoving center 31.25 Mpc/h with width 62.5 Mpc/h and shell index is 3
Lightcone type: FlatDensity
Lightcone shape: (4, 64, 64)
Density sharding: NamedSharding(mesh=Mesh('x': 2, 'y': 2, axis_types=(Auto, Auto)), spec=PartitionSpec(None, 'x', 'y'), memory_kind=device)
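
Each flat-sky shell is a 64 × 64 map covering `field_size = (10., 10.)`. Assuming `field_size` is given in degrees (the notebook does not state the unit), the pixel scale works out as in the sketch below; `flat_pixel_scale_arcmin` is a hypothetical helper.

```python
def flat_pixel_scale_arcmin(field_size_deg, npix):
    """Pixel size in arcminutes along each axis of a flat-sky map.

    Assumption: field_size_deg is in degrees.
    """
    return tuple(60.0 * size / n for size, n in zip(field_size_deg, npix))


print(flat_pixel_scale_arcmin((10.0, 10.0), (64, 64)))  # (9.375, 9.375)
```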
