# Combining shard_map and vmap

This notebook demonstrates how `jbar` can simultaneously track both device-level parallelism (via shard_map) and task-level parallelism (via vmap). The progress meter displays information about both contexts.

In [1]:
import time
import jax
jax.config.update('jax_platforms', 'cpu')
jax.config.update('jax_platform_name', 'cpu')
jax.config.update('jax_num_cpu_devices', 8)

import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental import shard_map
from jbar import TqdmProgressMeter
from functools import partial
from jax import lax

def runtime_sleep(seconds):
    """Block the calling host thread for ``seconds`` seconds.

    Used as a ``jax.debug.callback`` target, which may hand the duration over
    as an array scalar; it is converted to a Python float before sleeping.
    """
    duration = float(seconds)
    time.sleep(duration)

## Set Up Virtual Devices

The 8 virtual CPU devices configured above (`jax_num_cpu_devices=8`) are arranged in a 2×4 mesh with axes `x` and `y`.

In [2]:
# The 8 virtual CPU devices were created by `jax_num_cpu_devices` in the first
# cell; this cell only lists them and arranges them into a 2x4 ('x', 'y') mesh.
devices = jax.devices()
assert len(devices) == 8, (
    f"expected 8 CPU devices, found {len(devices)}; restart the kernel so the "
    "jax_num_cpu_devices setting applies before JAX initializes its backend"
)
mesh = jax.make_mesh(
    (2, 4),
    axis_names=('x', 'y'),
    # Pin the current default axis type explicitly.
    # NOTE(review): presumably this silences the make_mesh warning seen in the
    # recorded output (the default axis_types is slated to change); confirm
    # against the installed JAX version.
    axis_types=(jax.sharding.AxisType.Auto,) * 2,
)
jax.set_mesh(mesh)
print(f"Mesh shape: {mesh.shape}")
print(f"Total devices: {len(devices)}")

Mesh shape: OrderedDict([('x', 2), ('y', 4)])
Total devices: 8


  mesh = jax.make_mesh((2, 4), axis_names=('x', 'y'))


## Vmap Over Sharded Computation

We `vmap` over 3 tasks, where each task runs a sharded computation across the 8 devices.
The progress meter tracks both:
- **Device context**: which device (`state.rank`, 0–7, displayed as 1–8)
- **Task context**: which vmap task (`state.v_index`, 0–2, displayed as 1–3)

That gives 3 tasks × 8 devices = 24 independent progress contexts, each with its own bar.

In [3]:
# Number of scan steps per device; also the progress-bar total.
nb_elements = 10
# Size of the vmap batch axis (number of independent tasks).
nb_tasks = 3

pbar = TqdmProgressMeter(
    total=nb_elements,
    description_callback=lambda state, args: (
        f"Device {int(state.rank) + 1}/{int(state.size)} "
        f"Task {int(state.v_index) + 1}/{int(state.v_size)}"
    ),
    refresh_steps=2,
)


@partial(jax.shard_map, in_specs=(P('x', 'y', None), P()), out_specs=P('x'))
def sharded_task(x_sharded, task_id):
    """Run one vmap task as a sharded scan, reporting progress on every device.

    Args:
        x_sharded: This task's data, globally shaped
            (mesh 'x' size, mesh 'y' size, nb_elements) and split by in_specs
            so each device holds a (1, 1, nb_elements) block. Its values are
            not read; it only gives shard_map a sharded operand.
        task_id: Scalar identifying this vmap task. It is handed to the
            progress meter as the vmapped element and does not enter the
            arithmetic.

    Returns:
        Per device, the running sums of i**2 for i in range(nb_elements)
        (shape (nb_elements,)); out_specs=P('x') concatenates them along 'x'.
    """
    meter_state = pbar.init(vmapped_element=task_id, spec=P('x', 'y'))

    def scan_body(carry, i):
        cum, st = carry
        # Only the loop index feeds the result; x_sharded and task_id do not.
        cum += i**2
        st = pbar.step(st, description_args=())
        # Slow each step down so the progress bars visibly update.
        jax.debug.callback(runtime_sleep, 0.05)
        return (cum, st), cum

    (_, final_state), cum_results = jax.lax.scan(
        scan_body,
        (0.0, meter_state),
        jnp.arange(nb_elements),
    )
    # Close with the state threaded through the scan rather than the initial
    # one, so the close is data-dependent on every step (unordered JAX
    # callbacks are only sequenced by data dependencies).
    # NOTE(review): confirm jbar's close() expects the final state.
    pbar.close(final_state)
    return cum_results


# One (n_x, n_y, nb_elements) slab per task, laid out to match the mesh:
# block (i, j) holds `arr + i + 10 * j`.
n_x, n_y = mesh.shape['x'], mesh.shape['y']
arr = jnp.linspace(0.0, 10.0, nb_elements)
base_data = (
    arr
    + jnp.arange(n_x)[:, None, None]
    + 10 * jnp.arange(n_y)[None, :, None]
)
base_data = lax.with_sharding_constraint(base_data, P('x', 'y', None))

# Stack nb_tasks offset copies for vmap: shape (nb_tasks, n_x, n_y, nb_elements).
vmap_data = jnp.stack([base_data + 100 * t for t in range(nb_tasks)], axis=0)
task_ids = jnp.arange(nb_tasks, dtype=jnp.float32)  # Task identifiers 0., 1., ...

print(f"Vmap data shape: {vmap_data.shape}")
print(f"Task IDs shape: {task_ids.shape}")
print("\nRunning vmapped sharded computation...")
print("This will show progress for device + task combinations\n")

# vmap adds the task axis on top of the sharded computation.
results = jax.vmap(sharded_task)(vmap_data, task_ids)

print(f"\nDone! Results shape: {results.shape}")
# out_specs=P('x') concatenates the n_x per-device results of each task.
assert results.shape == (nb_tasks, n_x * nb_elements)

Vmap data shape: (3, 2, 4, 10)
Task IDs shape: (3,)

Running vmapped sharded computation...
This will show progress for device + task combinations



0%|          | 0.0/10.0 [00:00<?]

0%|          | 0.0/10.0 [00:00<?]

0%|          | 0.0/10.0 [00:00<?]

0%|          | 0.0/10.0 [00:00<?]

0%|          | 0.0/10.0 [00:00<?]

0%|          | 0.0/10.0 [00:00<?]

0%|          | 0.0/10.0 [00:00<?]

0%|          | 0.0/10.0 [00:00<?]

0%|          | 0.0/10.0 [00:00<?]

0%|          | 0.0/10.0 [00:00<?]

0%|          | 0.0/10.0 [00:00<?]

0%|          | 0.0/10.0 [00:00<?]

0%|          | 0.0/10.0 [00:00<?]

0%|          | 0.0/10.0 [00:00<?]

0%|          | 0.0/10.0 [00:00<?]

0%|          | 0.0/10.0 [00:00<?]

0%|          | 0.0/10.0 [00:00<?]

0%|          | 0.0/10.0 [00:00<?]

0%|          | 0.0/10.0 [00:00<?]

0%|          | 0.0/10.0 [00:00<?]

0%|          | 0.0/10.0 [00:00<?]

0%|          | 0.0/10.0 [00:00<?]

0%|          | 0.0/10.0 [00:00<?]

0%|          | 0.0/10.0 [00:00<?]


Done! Results shape: (3, 20)


## With Limited Bars

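With 24 progress contexts (3 tasks × 8 devices), the display gets crowded quickly, so we pass `max_bars` to limit the number of progress bars shown. Note that in the recorded run below, `max_bars=1` still renders 8 bars rather than 1, so check how `jbar` applies this limit.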

In [6]:
pbar_limited = TqdmProgressMeter(
    total=nb_elements,
    description_callback=lambda state, args: (
        f"D{int(state.rank) + 1}/{int(state.size)} "
        f"T{int(state.v_index) + 1}/{int(state.v_size)}"
    ),
    # Intended to show only 1 bar out of the 24 executions.
    # NOTE(review): the recorded output of this cell still renders 8 bars, so
    # max_bars does not appear to cap the total; confirm how jbar applies it.
    max_bars=1,
    refresh_steps=2,
)


@partial(jax.shard_map, in_specs=(P('x', 'y', None), P()), out_specs=P('x'))
def sharded_task(x_sharded, task_id):
    """Run one vmap task as a sharded scan, reporting to ``pbar_limited``.

    Args:
        x_sharded: This task's data, globally shaped
            (mesh 'x' size, mesh 'y' size, nb_elements) and split by in_specs
            so each device holds a (1, 1, nb_elements) block. Its values are
            not read; it only gives shard_map a sharded operand.
        task_id: Scalar identifying this vmap task. It is handed to the
            progress meter as the vmapped element and does not enter the
            arithmetic.

    Returns:
        Per device, the running sums of i**2 for i in range(nb_elements)
        (shape (nb_elements,)); out_specs=P('x') concatenates them along 'x'.
    """
    meter_state = pbar_limited.init(vmapped_element=task_id, spec=P('x', 'y'))

    def scan_body(carry, i):
        cum, st = carry
        # Only the loop index feeds the result; x_sharded and task_id do not.
        cum += i**2
        st = pbar_limited.step(st, description_args=())
        # Slow each step down so the progress bars visibly update.
        jax.debug.callback(runtime_sleep, 0.05)
        return (cum, st), cum

    (_, final_state), cum_results = jax.lax.scan(
        scan_body,
        (0.0, meter_state),
        jnp.arange(nb_elements),
    )
    # Close with the state threaded through the scan rather than the initial
    # one, so the close is data-dependent on every step (unordered JAX
    # callbacks are only sequenced by data dependencies).
    # NOTE(review): confirm jbar's close() expects the final state.
    pbar_limited.close(final_state)
    return cum_results


# The inputs are the same as in the previous section; reuse `vmap_data` and
# `task_ids` rather than rebuilding them. Only the progress meter differs.
print(f"Vmap data shape: {vmap_data.shape}")
print(f"Task IDs shape: {task_ids.shape}")
print("\nRunning vmapped sharded computation...")
print("This will show progress for device + task combinations\n")

# vmap adds the task axis on top of the sharded computation.
results = jax.vmap(sharded_task)(vmap_data, task_ids)

print(f"\nDone! Results shape: {results.shape}")
# out_specs=P('x') concatenates the per-device results along the 'x' axis.
assert results.shape == (task_ids.shape[0], mesh.shape['x'] * nb_elements)

Vmap data shape: (3, 2, 4, 10)
Task IDs shape: (3,)

Running vmapped sharded computation...
This will show progress for device + task combinations



0%|          | 0.0/10.0 [00:00<?]

0%|          | 0.0/10.0 [00:00<?]

0%|          | 0.0/10.0 [00:00<?]

0%|          | 0.0/10.0 [00:00<?]

0%|          | 0.0/10.0 [00:00<?]

0%|          | 0.0/10.0 [00:00<?]

0%|          | 0.0/10.0 [00:00<?]

0%|          | 0.0/10.0 [00:00<?]


Done! Results shape: (3, 20)
