In [1]:
import equinox as eqx

In [8]:
import jax.numpy as jnp
from jax import lax
from jax.experimental import io_callback
from tqdm import tqdm


def f(x, y):
    return 2 * x + y


x, y = 3, 4

# lowered = jax.jit(f).lower(x, y)
lowered = eqx.filter_jit(f).lower(x, y)

# Print lowered HLO
print(lowered.as_text())

compiled = lowered.compile()

# Query for cost analysis, print FLOP estimate
# print(compiled.cost_analysis()[0]['flops'])
# 2.0

# Execute the compiled function!
compiled(x, y)

module @jit_f {
  func.func public @main() {
    return
  }
}



TypeError: function compiled for PyTreeDef(((((None,), (None, None)),), {})), called with PyTreeDef(((*, *), {}))

In [None]:
def make_device_calls(num: int):
    pbar = tqdm(total=1)
    print_rate = 20
    assert print_rate != 0

    remainder = num % print_rate

    steps_taken = 0
    num_inits = 0  # Track how many parallel calls we get

    def host_update_pbar(*, num_steps: int):
        nonlocal steps_taken
        nonlocal num_inits
        nonlocal pbar

        if num_steps == 0:
            num_inits += 1
            pbar = tqdm(initial=steps_taken, total=num_inits * num)
        else:
            pbar.update(num_steps)

    def host_close_pbar():
        nonlocal pbar
        pbar.close()

    def device_update_pbar(*, step: int):
        lax.cond(
            step == 0,
            lambda: io_callback(host_update_pbar, None, ordered=False, num_steps=0),
            lambda: None,
        )

        lax.cond(
            # update tqdm every multiple of `print_rate` except at the end
            (step % print_rate == 0) & (step != num - remainder),
            lambda: io_callback(
                host_update_pbar, None, ordered=False, num_steps=print_rate
            ),
            lambda: None,
        )

        lax.cond(
            # update tqdm by `remainder`
            step == num - remainder,
            lambda: io_callback(
                host_update_pbar, None, ordered=False, num_steps=remainder
            ),
            lambda: None,
        )

    def device_close_pbar(step: int):
        lax.cond(
            step == num - 1,
            lambda: io_callback(host_close_pbar, None, ordered=False),
            lambda: None,
        )

    return device_update_pbar, device_close_pbar


def loop_tqdm(num: int):
    update_pbar, close_pbar = make_device_calls(num=num)

    def inner(func):
        # This is body function of loop
        def wrapper(i, val):
            update_pbar(step=i)
            result = func(i, val)
            close_pbar(step=i)
            return result

        return wrapper

    return inner


def scan_tqdm(num: int):
    update_pbar, close_pbar = make_device_calls(num=num)

    def inner(func):
        def wrapper(carry, x):
            assert isinstance(x, tuple)
            step, *x_inner = x

            update_pbar(step=step)
            result = func(carry, x_inner)

            close_pbar(step=step)

            return result

        return wrapper

    return inner

In [None]:
@loop_tqdm(num=10)
def body_fun(i, val):
    return val + i


lax.fori_loop(0, 10, body_fun, 0)

100%|██████████| 10/10 [00:00<00:00, 15087.42it/s]


Array(45, dtype=int32, weak_type=True)

In [None]:
@scan_tqdm(num=10)
def body_fun(carry, x):
    x = x[0]
    return carry + x, carry + x


lax.scan(body_fun, 0, (jnp.arange(10), jnp.arange(10)))

Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>
Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>


100%|██████████| 10/10 [00:00<00:00, 13976.35it/s]


(Array(45, dtype=int32),
 Array([ 0,  1,  3,  6, 10, 15, 21, 28, 36, 45], dtype=int32))

In [None]:
def host_func(*args, **kwargs):
    print(args)
    print(kwargs)


@eqx.filter_jit
def jax_func():
    io_callback(
        host_func, jnp.empty((1,)), jnp.asarray(1), ordered=False, kw=jnp.asarray(2)
    )
    return jnp.asarray([1])


jax_func()

(array(1, dtype=int32),)
{'kw': array(2, dtype=int32)}


Array([1], dtype=int32)