# Using jbar with shard_map

This notebook demonstrates how `jbar` works with `jax.shard_map` for distributed computation across multiple devices. We use virtual CPU devices for demonstration purposes.

In [1]:
import time
import jax
# These options only take effect if set before JAX initializes its backends,
# so they must run before the first device query or array operation.
# Use only the CPU platform.
jax.config.update('jax_platforms', 'cpu')
# NOTE(review): looks redundant with 'jax_platforms' above — confirm whether
# this older option is still needed for the JAX version in use.
jax.config.update('jax_platform_name', 'cpu')
# Expose the host CPU as 8 virtual devices so sharding can be demonstrated
# without GPUs/TPUs.
jax.config.update('jax_num_cpu_devices', 8)

import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P, NamedSharding
from jax import shard_map
from jbar import TqdmProgressMeter
from functools import partial
from jax import lax

def runtime_sleep(seconds):
    """Sleep on the host for ``seconds`` seconds.

    Intended as a ``jax.debug.callback`` target. The callback may deliver the
    argument as an array scalar rather than a Python number, so it is coerced
    with ``float`` before being passed to ``time.sleep``.
    """
    duration_s = float(seconds)
    time.sleep(duration_s)

## Setup Virtual Devices

The setup cell above asked JAX for 8 virtual CPU devices (`jax_num_cpu_devices=8`); here we confirm that JAX created them. This lets us demonstrate device sharding without actual GPUs or TPUs.

In [2]:
# The virtual CPU devices were requested in the setup cell via
# `jax_num_cpu_devices`; this cell only lists what JAX actually created.
devices = jax.devices()
device_count = len(devices)
print(f"Available devices: {device_count}")
print(f"Devices: {devices}")

Available devices: 8
Devices: [CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3), CpuDevice(id=4), CpuDevice(id=5), CpuDevice(id=6), CpuDevice(id=7)]


## Single Axis Sharding

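We create a 1D mesh of 8 devices along a single axis named `'x'`. We then shard the input's leading dimension across it, so each device receives one row.
Each device gets its own progress bar, and the bar's description shows which device (rank) is reporting progress.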

In [3]:
# Create a 1D mesh over the 8 virtual devices and make it the active mesh.
mesh = jax.make_mesh((8,), axis_names=('x',))
sharding = NamedSharding(mesh, P('x'))
jax.set_mesh(mesh)
print(f"Mesh shape: {mesh.shape}")

nb_elements = 10
pbar = TqdmProgressMeter(
    total=nb_elements,
    # Description string for each device's bar: 1-based rank / number of devices.
    description_callback=lambda state, args: f"Device {int(state.rank) + 1}/{int(state.size)}",
    refresh_steps=2
)

# Pass `mesh` explicitly so this function keeps using the 1D mesh even after a
# later cell changes the globally active mesh with `jax.set_mesh`.
@partial(shard_map, mesh=mesh, in_specs=P('x'), out_specs=P('x'))
def sharded_computation(x):
    """Run a scan with per-device progress reporting on one shard of the input.

    Each device receives a ``(1, nb_elements)`` row of ``input_data``. Note that
    the scan iterates over ``jnp.arange(nb_elements)`` and never reads ``x``,
    so every device computes the same running sum; the example is about the
    progress meter, not the arithmetic.

    Returns:
        This shard's ``(nb_elements,)`` running sum. ``out_specs=P('x')``
        concatenates the 8 shards, giving a ``(8 * nb_elements,)`` result.
    """
    state = pbar.init(spec=P('x'))

    def scan_body(carry, i):
        cum, state = carry
        cum += i
        state = pbar.step(state, description_args=())
        # Host-side sleep to slow each step down for demonstration purposes.
        jax.debug.callback(runtime_sleep, 1)
        return (cum, state), cum

    _, cum_results = jax.lax.scan(
        scan_body,
        (0.0, state),
        jnp.arange(nb_elements)
    )
    pbar.close(state)
    return cum_results


# Build the input and place it on the devices, one row per device.
# `jax.device_put` is the idiomatic way to shard a concrete array outside of
# `jit`; `with_sharding_constraint` is intended for traced code.
arr = jnp.linspace(0.0, 10.0, nb_elements)
input_data = jnp.stack([arr + i for i in range(8)], axis=0)  # (8, 10)
input_data = jax.device_put(input_data, sharding)

print(f"Input shape: {input_data.shape}")
print("Running sharded computation...")
try:
    results = sharded_computation(input_data).block_until_ready()
    print(f"Done! Results shape: {results.shape}")
finally:
    # Release the progress meter even if the computation raises.
    pbar.terminate()

  mesh = jax.make_mesh((8,), axis_names=('x',))


Mesh shape: OrderedDict([('x', 8)])
Input shape: (8, 10)
Running sharded computation...


0%|          | 0.0/10.0 [00:00<?]

0%|          | 0.0/10.0 [00:00<?]

0%|          | 0.0/10.0 [00:00<?]

0%|          | 0.0/10.0 [00:00<?]

0%|          | 0.0/10.0 [00:00<?]

0%|          | 0.0/10.0 [00:00<?]

0%|          | 0.0/10.0 [00:00<?]

0%|          | 0.0/10.0 [00:00<?]

Done! Results shape: (80,)


## 2D Mesh Sharding

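We create a 2D mesh with shape (2, 4) and shard the input across both axes. The output is sharded only along `'x'` (and replicated along `'y'`), which is why the result has shape `(20,)`.
This demonstrates how the progress meter computes a device's rank in a multi-dimensional mesh.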

In [4]:
# Create a 2D mesh (2 x 4 = 8 devices) and make it the active mesh.
# Reuses `nb_elements` and `arr` from the single-axis cell above.
mesh_2d = jax.make_mesh((2, 4), axis_names=('x', 'y'))
jax.set_mesh(mesh_2d)
sharding_2d = NamedSharding(mesh_2d, P('x', 'y'))
print(f"2D Mesh shape: {mesh_2d.shape}")

pbar_2d = TqdmProgressMeter(
    total=nb_elements,
    # Description string for each device's bar: 1-based rank / number of devices.
    description_callback=lambda state, args: f"Device {int(state.rank) + 1}/{int(state.size)}",
    refresh_steps=3
)

# The input is split over both mesh axes, but the output is sharded only over
# 'x' and treated as replicated over 'y'. That is consistent only because the
# scan never reads `x`, so every device computes the same values; if the body
# used `x`, `out_specs` would also need to include 'y'.
# `mesh` is passed explicitly so the function does not depend on whichever
# mesh happens to be globally active when it is called.
@partial(shard_map, mesh=mesh_2d, in_specs=P('x', 'y', None), out_specs=P('x'))
def sharded_computation_2d(x):
    """Run a scan with per-device progress reporting on a 2D-sharded input.

    Each device receives a ``(1, 1, nb_elements)`` block of ``input_2d``. As in
    the 1D example, the scan iterates over ``jnp.arange(nb_elements)`` (here
    accumulating squares) and does not read ``x``.

    Returns:
        This shard's ``(nb_elements,)`` running sum of squares. With
        ``out_specs=P('x')`` the 2 shards along 'x' are concatenated, giving a
        ``(2 * nb_elements,)`` result.
    """
    state = pbar_2d.init(spec=P('x', 'y'))

    def scan_body(carry, i):
        cum, state = carry
        cum += i**2
        state = pbar_2d.step(state, description_args=())
        # Host-side sleep to slow each step down for demonstration purposes.
        jax.debug.callback(runtime_sleep, 1)
        return (cum, state), cum

    _, cum_results = jax.lax.scan(
        scan_body,
        (0.0, state),
        jnp.arange(nb_elements)
    )
    pbar_2d.close(state)
    return cum_results

# Build the (2, 4, nb_elements) input; block [i, j] is `arr + i + 10 * j`.
input_2d = jnp.stack([
    jnp.stack([arr + i + j*10 for j in range(4)], axis=0)
    for i in range(2)
], axis=0)

# Place each (i, j) block on its device (`device_put` is the idiomatic call
# for concrete arrays outside `jit`).
input_2d = jax.device_put(input_2d, sharding_2d)

print(f"Input shape: {input_2d.shape}")
print("Running 2D sharded computation...")
try:
    results_2d = sharded_computation_2d(input_2d).block_until_ready()
    print(f"Done! Results shape: {results_2d.shape}")
finally:
    # Release the progress meter, as the single-axis example does, even if
    # the computation raises.
    pbar_2d.terminate()

  mesh_2d = jax.make_mesh((2, 4), axis_names=('x', 'y'))


2D Mesh shape: OrderedDict([('x', 2), ('y', 4)])
Input shape: (2, 4, 10)
Running 2D sharded computation...


0%|          | 0.0/10.0 [00:00<?]

0%|          | 0.0/10.0 [00:00<?]

0%|          | 0.0/10.0 [00:00<?]

0%|          | 0.0/10.0 [00:00<?]

0%|          | 0.0/10.0 [00:00<?]

0%|          | 0.0/10.0 [00:00<?]

0%|          | 0.0/10.0 [00:00<?]

0%|          | 0.0/10.0 [00:00<?]

Done! Results shape: (20,)
