# Basic Usage of jax-progress

This notebook shows how to drive a `jax-progress` progress bar from a `jax.lax.scan` loop, both for a single loop and for a loop batched with `jax.vmap`.

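```python
import time

import jax
import jax.numpy as jnp
from jax_progress import TqdmProgressMeter


def runtime_sleep(seconds):
    # Host-side sleep, invoked through jax.debug.callback inside the loops below,
    # so each step takes long enough for the progress bar to be visible.
    time.sleep(float(seconds))
```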
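`runtime_sleep` goes through `jax.debug.callback` rather than being called directly because JAX traces a function once and then reuses the compiled result, so a bare `time.sleep` (or any other plain Python side effect) would run only during tracing. The standalone sketch below uses only core JAX, not `jax-progress`; the names `f`, `record`, `trace_log` and `host_log` are illustrative. The callback receives array values rather than Python scalars, so converting with `float`, as `runtime_sleep` does, is a safe habit.

```python
import jax
import jax.numpy as jnp

trace_log = []  # appended to by plain Python code inside the jitted function
host_log = []   # appended to by the host callback


def record(x):
    # Host function: called with concrete values each time the compiled code runs.
    host_log.append(float(x))


@jax.jit
def f(x):
    trace_log.append("traced")     # plain Python side effect: runs only while tracing
    jax.debug.callback(record, x)  # staged into the computation: runs on every call
    return x * 2.0


f(jnp.float32(1.0))
f(jnp.float32(2.0))    # same shape and dtype, so the cached compilation is reused
jax.effects_barrier()  # wait for outstanding callbacks before inspecting host_log

print(trace_log)         # ['traced']
print(sorted(host_log))  # [1.0, 2.0]
```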

## Single Scan Loop

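Here we show a simple progress bar for a single `lax.scan` loop. The meter is created once; inside the function, `init` creates the bar state before the loop, `step` advances it on every iteration (the returned state is threaded through the scan carry), and `close` is called after the loop.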

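```python
nb_elements = 10
pbar = TqdmProgressMeter(
    total=nb_elements,
    # Builds the bar description from the `description_args` passed to `step`.
    description_callback=lambda state, args: f"Val: {float(args[0]):.2f}, Cum: {float(args[1]):.2f}",
    refresh_steps=2,
    max_bars=2,
)

def scanning(elements):
    # Create the progress-bar state before entering the loop.
    state = pbar.init(vmapped_element=elements)

    def scan_body(carry, x):
        cum, state = carry
        cum += x**2
        # Advance the bar; the updated state travels in the scan carry.
        state = pbar.step(state, description_args=(x, cum))
        # Slow the loop down on the host so the progress is visible.
        jax.debug.callback(runtime_sleep, 0.1)
        return (cum, state), cum

    # Keep the final carry so the bar is closed with the state produced by the loop.
    (_, state), cum_results = jax.lax.scan(scan_body, (0.0, state), elements)

    pbar.close(state)
    return cum_results

arr = jnp.linspace(0.0, 10.0, nb_elements)
_ = scanning(arr)
```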

```text
v_index: 0, v_size: 1
```
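Under the hood, a progress meter like this has to call back to the host from inside the compiled loop. The sketch below shows that mechanism in plain JAX and assumes nothing about how `jax-progress` is actually implemented: a step counter travels in the scan carry, and `jax.lax.cond` issues a `jax.debug.callback` only every `refresh_steps` steps, which is one plausible reading of what the `refresh_steps` argument above controls. `report` and `scan_with_progress` are hypothetical names, and the `updates` list stands in for a real bar.

```python
import jax
import jax.numpy as jnp

updates = []  # host-side record standing in for a progress bar


def report(step, cum):
    # Runs on the host with concrete values.
    updates.append((int(step), float(cum)))


def scan_with_progress(xs, refresh_steps=2):
    def body(carry, x):
        step, cum = carry
        step = step + 1
        cum = cum + x**2
        # Only cross to the host every `refresh_steps` steps.
        jax.lax.cond(
            step % refresh_steps == 0,
            lambda: jax.debug.callback(report, step, cum),
            lambda: None,
        )
        return (step, cum), cum

    init = (jnp.int32(0), jnp.float32(0.0))
    (final_step, _), cums = jax.lax.scan(body, init, xs)
    return final_step, cums


final_step, cums = scan_with_progress(jnp.arange(1.0, 7.0))
jax.effects_barrier()  # wait for all host callbacks to finish
print(sorted(updates))  # [(2, 5.0), (4, 30.0), (6, 91.0)]
```

With six elements and `refresh_steps=2`, the host sees three updates, at steps 2, 4 and 6. Throttling on the device side skips the host round trip on the other iterations, at the cost of the display lagging by up to `refresh_steps - 1` steps.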

## Vmapped Scan Loop

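Applying `jax.vmap` to `scanning` runs one scan per row of the input, and each vmapped instance can get its own progress bar.
If `max_bars` is not set (or is at least the batch size), every instance gets a bar. `pbar` above was created with `max_bars=2`, so only two of the three instances below are displayed.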

```python
arr_stacked = jnp.stack([arr, arr + 1.0, arr + 2.0], axis=0)  # shape (3, nb_elements)
_ = jax.vmap(scanning)(arr_stacked)
```

```text
v_index: 0, v_size: 3
v_index: -1, v_size: 3
v_index: -1, v_size: 3
```
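Under `jax.vmap`, each batched instance needs to know which task it is so its updates reach the right bar. The sketch below shows one way to do that in plain JAX, independent of `jax-progress` (`scan_task` and `report` are hypothetical names): pass `jnp.arange(batch_size)` as an extra vmapped argument and hand it to `jax.debug.callback`. When a debug callback's arguments are batched, JAX invokes it once per batch element with that element's values.

```python
import jax
import jax.numpy as jnp

seen = []  # (task_id, step) pairs received on the host


def report(task_id, step):
    # Called once per batch element, with that element's concrete values.
    seen.append((int(task_id), int(step)))


def scan_task(task_id, xs):
    def body(step, x):
        step = step + 1
        jax.debug.callback(report, task_id, step)
        return step, 2.0 * x

    _, ys = jax.lax.scan(body, jnp.int32(0), xs)
    return ys


batch = jnp.ones((3, 4))  # 3 tasks, 4 steps each
ys = jax.vmap(scan_task)(jnp.arange(3), batch)
jax.effects_barrier()  # make sure every callback has run

print(len(seen))  # 12: one call per task per step
```

Callbacks from different tasks may interleave, so a host-side display keyed on `task_id` should not assume any particular arrival order.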

## Vmapped Scan with Limited Bars

`max_bars` limits the number of bars displayed.
When there are more vmapped tasks than bars, bars are assigned to tasks, prioritizing the slower ones if their speeds differ; here all tasks run at the same speed.
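The assignment rule described above can be sketched in plain Python. This is a hypothetical policy, not the `jax-progress` implementation: it assumes "slower" means "fewer completed steps", gives the `max_bars` slots to the least-advanced tasks (ties broken by task index), and marks tasks without a bar with `-1`.

```python
def assign_bars(progress, max_bars):
    """Hypothetical policy: give the available bars to the least-advanced tasks.

    progress: completed-step count per task (an assumed measure of "slower").
    Returns one entry per task: its bar slot, or -1 if the task gets no bar.
    """
    # sorted() is stable, so tasks with equal progress keep their original order.
    order = sorted(range(len(progress)), key=lambda i: progress[i])
    slots = [-1] * len(progress)
    for slot, task in enumerate(order[:max_bars]):
        slots[task] = slot
    return slots


# Equal speeds: the first two tasks get the bars.
print(assign_bars([0, 0, 0, 0, 0], max_bars=2))  # [0, 1, -1, -1, -1]
# Tasks 1 and 3 lag behind, so they are the ones displayed.
print(assign_bars([3, 1, 5, 1, 4], max_bars=2))  # [-1, 0, -1, 1, -1]
```

A real meter would rerun such a policy as progress updates arrive, so the displayed tasks can change over time.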

In [9]:
from tqdm.auto import tqdm
tqdm?

[31mInit signature:[39m tqdm(*args, **kwargs)
[31mDocstring:[39m      Experimental IPython/Jupyter Notebook widget using tqdm!
[31mInit docstring:[39m
Supports the usual `tqdm.tqdm` parameters as well as those listed below.

Parameters
----------
display  : Whether to call `display(self.container)` immediately
    [default: True].
[31mFile:[39m           ~/micromamba/envs/fg/lib/python3.11/site-packages/tqdm/auto.py
[31mType:[39m           type
[31mSubclasses:[39m     

```python
pbar_limited = TqdmProgressMeter(
    total=nb_elements,
    description_callback=lambda state, args: f"Task {int(state.v_index) + 1}/{int(state.v_size)}",
    max_bars=2,  # only show 2 bars for the 5 tasks below
    refresh_steps=2,
)

def scanning_limited(elements):
    state = pbar_limited.init(vmapped_element=elements)

    def scan_body(carry, x):
        cum, state = carry
        cum += x**2
        state = pbar_limited.step(state, description_args=(x, cum))
        jax.debug.callback(runtime_sleep, 0.1)
        return (cum, state), cum

    # Close the bar with the final state returned by the loop.
    (_, state), cum_results = jax.lax.scan(scan_body, (0.0, state), elements)

    pbar_limited.close(state)
    return cum_results

arr_5 = jnp.stack([arr] * 5, axis=0)  # 5 identical tasks
_ = jax.vmap(scanning_limited)(arr_5)
```

```text
v_index: 0, v_size: 5
v_index: -1, v_size: 5
v_index: -1, v_size: 5
v_index: -1, v_size: 5
v_index: -1, v_size: 5
```