# Nested Vmap with jbar

This notebook demonstrates how `jbar` works with nested `vmap` operations. Each vmap level sees only its immediate context.

In [1]:
import time
import jax
import jax.numpy as jnp
from jbar import TqdmProgressMeter

def runtime_sleep(seconds):
    """Pause the calling host thread for ``seconds`` seconds.

    The argument is coerced with ``float`` before being passed to
    ``time.sleep``, so any value that ``float`` accepts can be given.
    """
    delay = float(seconds)
    time.sleep(delay)

## Simple Nested Vmap

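Here we nest two vmaps, each wrapping a `jax.lax.scan`:
- Outer vmap: 4 tasks (`inner_size`), each scanning over 3 steps (`outer_size`)
- Inner vmap: 10 tasks (`outer_elements`) per outer step, each scanning over 4 steps (`inner_elements`)

Each level has its own progress meter. Note that each vmap level sees only its immediate context - the inner vmap sees v_size=10, not the total 40 tasks.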

In [4]:
# Demo dimensions. The comments state what each constant ends up driving in
# the code of this cell (the names alone are easy to misread).
outer_elements = 10  # rows of `arr`    -> width of the inner vmap
inner_elements = 4   # columns of `arr` -> number of inner scan steps
outer_size = 3       # copies stacked first -> number of outer scan steps
inner_size = 4       # copies stacked last  -> width of the outer vmap

nb_elements = outer_elements * inner_elements

# Outer meter: `step` is called once per outer scan iteration, and that scan
# runs over `outer_size` slices.
outer_pbar = TqdmProgressMeter(
    total=outer_size,
    description_callback=lambda state, args: f"Outer task {int(state.v_index) + 1}/{int(state.v_size)}",
    refresh_steps=1,
    max_bars=9
)

# Inner meter: `step` is called once per inner scan iteration, and that scan
# runs over the `inner_elements` entries of one row. `inner_size` happens to
# have the same value but controls the outer vmap width, so it must not be
# used as the total here.
inner_pbar = TqdmProgressMeter(
    total=inner_elements,  # Number of inner scan steps
    description_callback=lambda state, args: f"Group {int(args)} => Inner task {int(state.v_index) + 1}/{int(state.v_size)}",
    refresh_steps=1,
    max_bars=11,
    leave=True
)

def inner_computation(inner_arr, group_id):
    """Running sum of squares over ``inner_arr`` with progress reporting.

    Intended to be wrapped in ``jax.vmap``. The function scans over
    ``inner_arr``; every iteration adds the square of the current element to
    a running sum, advances ``inner_pbar`` (forwarding ``group_id`` as the
    description argument) and triggers a one-second host-side sleep.

    Returns the running sum as it stood after each iteration.
    """
    meter_state = inner_pbar.init(vmapped_element=inner_arr)

    def accumulate(carry, value):
        running_sum, pbar_state = carry
        running_sum = running_sum + value**2
        pbar_state = inner_pbar.step(pbar_state, description_args=group_id)
        # Sleep for one second on the host at every scan iteration.
        jax.debug.callback(runtime_sleep, 1)
        return (running_sum, pbar_state), running_sum

    initial_carry = (0.0, meter_state)
    (_, meter_state), running_sums = jax.lax.scan(accumulate, initial_carry, inner_arr)
    inner_pbar.close(meter_state)
    return running_sums

def outer_computation(outer_arr):
    """Scan over ``outer_arr`` and vmap ``inner_computation`` at every step.

    Intended to be wrapped in ``jax.vmap`` itself. Each scan iteration takes
    one slice along the leading axis of ``outer_arr``, maps
    ``inner_computation`` over the leading axis of that slice, and advances
    ``outer_pbar`` by one step.

    Args:
        outer_arr: array whose leading axis is scanned; each slice must have
            a leading axis that ``inner_computation`` can be vmapped over.

    Returns:
        The per-step outputs of the vmapped ``inner_computation``, stacked
        along a new leading axis by ``jax.lax.scan``.
    """
    state = outer_pbar.init(vmapped_element=outer_arr)

    def scan_body(carry, x):
        counter, state = carry
        # `counter` is passed with in_axes=None: every inner task of this
        # step receives the same step index as its `group_id`.
        inner_results = jax.vmap(inner_computation, in_axes=(0, None))(x, counter)
        state = outer_pbar.step(state, description_args=())
        return (counter + 1, state), inner_results

    # Only the step counter and the meter state are carried. The previous
    # running `results` accumulator was never read after the scan, so it has
    # been dropped; the stacked outputs are unchanged.
    (_, state), all_results = jax.lax.scan(scan_body, (0, state), outer_arr)
    outer_pbar.close(state)

    return all_results


# Build the nested input by stacking copies of the same 2-D array.
# `arr` has shape (outer_elements, inner_elements) = (10, 4).
arr = jnp.linspace(0.0, 10.0, nb_elements).reshape((outer_elements, inner_elements))
arr_stacks = jnp.stack([arr] * outer_size)  # Shape (3, 10, 4)
arr_stacks = jnp.stack([arr_stacks] * inner_size)  # Shape (4, 3, 10, 4)

print("Running nested vmap...")
# vmap maps over the leading axis (size inner_size = 4), so each mapped call
# of outer_computation receives a (3, 10, 4) slice.
results = jax.vmap(outer_computation)(arr_stacks)
print(f"Done! Results shape: {results.shape}")
# Terminate both meters once the computation has finished.
inner_pbar.terminate()
outer_pbar.terminate()

Running nested vmap...


0%|          | 0.0/3.0 [00:00<?]

0%|          | 0.0/3.0 [00:00<?]

0%|          | 0.0/3.0 [00:00<?]

0%|          | 0.0/3.0 [00:00<?]

0%|          | 0.0/4.0 [00:00<?]

0%|          | 0.0/4.0 [00:00<?]

0%|          | 0.0/4.0 [00:00<?]

0%|          | 0.0/4.0 [00:00<?]

0%|          | 0.0/4.0 [00:00<?]

0%|          | 0.0/4.0 [00:00<?]

0%|          | 0.0/4.0 [00:00<?]

0%|          | 0.0/4.0 [00:00<?]

0%|          | 0.0/4.0 [00:00<?]

0%|          | 0.0/4.0 [00:00<?]

0%|          | 0.0/4.0 [00:00<?]

0%|          | 0.0/4.0 [00:00<?]

0%|          | 0.0/4.0 [00:00<?]

0%|          | 0.0/4.0 [00:00<?]

0%|          | 0.0/4.0 [00:00<?]

0%|          | 0.0/4.0 [00:00<?]

0%|          | 0.0/4.0 [00:00<?]

0%|          | 0.0/4.0 [00:00<?]

0%|          | 0.0/4.0 [00:00<?]

0%|          | 0.0/4.0 [00:00<?]

0%|          | 0.0/4.0 [00:00<?]

0%|          | 0.0/4.0 [00:00<?]

0%|          | 0.0/4.0 [00:00<?]

0%|          | 0.0/4.0 [00:00<?]

0%|          | 0.0/4.0 [00:00<?]

0%|          | 0.0/4.0 [00:00<?]

0%|          | 0.0/4.0 [00:00<?]

0%|          | 0.0/4.0 [00:00<?]

0%|          | 0.0/4.0 [00:00<?]

0%|          | 0.0/4.0 [00:00<?]

Done! Results shape: (4, 3, 10, 4)
