<a href="https://colab.research.google.com/github/PWhiddy/jax-experiments/blob/main/while_loop_exit_not_optimized_by_jit.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [22]:
import jax
import jax.numpy as jnp
from functools import partial

def dynamic_sum_loop(x, work_iters=10_000, idle_iters=2):
    """Sum ``0 + 1 + ... + (n - 1)`` with a ``lax.while_loop`` whose trip count depends on ``x``.

    ``n`` is ``work_iters`` when ``x == 0`` and ``idle_iters`` otherwise. With the
    defaults, ``x == 0`` yields ``sum(range(10_000)) == 49_995_000`` and any other
    value yields ``0 + 1 == 1``.

    Under ``jax.vmap`` the loop predicate becomes batched, and the batched loop keeps
    running until *every* lane's predicate is false. Lanes that finished early are
    masked rather than exited, so a batch with a single ``x == 0`` element costs about
    as much as a batch of all zeros (see the ``%%timeit`` results). This is a property
    of ``vmap`` batching, not something ``jit`` adds or fails to optimize.

    Args:
        x: value inspected per call / per vmapped lane; only ``x == 0`` matters.
        work_iters: trip count for inputs equal to 0.
        idle_iters: trip count for all other inputs.

    Returns:
        The accumulated sum as a scalar of JAX's default int dtype (int32 in the
        outputs shown here). NOTE: with int32, the sum overflows once the trip count
        exceeds 65_536, since n * (n - 1) / 2 > 2**31 - 1.
    """
    target = jnp.where(x == 0, work_iters, idle_iters)

    # Only the counter and running total are loop-carried. `target` is closed over
    # instead of being threaded through the state on every iteration, and the input
    # itself is never needed inside the loop.
    def cond_fun(state):
        i, _ = state
        return i < target

    def body_fun(state):
        i, total = state
        return i + 1, total + i

    _, final_total = jax.lax.while_loop(cond_fun, body_fun, (0, 0))
    return final_total

# Batch of inputs: element 0 takes the long path (x == 0), every other element the short one.
x = jnp.arange(100_000)

# Vectorize over the leading axis first, then compile the batched function once.
vmap_dynamic_sum = jax.vmap(dynamic_sum_loop)
jitted_vmap_dynamic_sum = jax.jit(vmap_dynamic_sum)
result = jitted_vmap_dynamic_sum(x)

for label, value in (("First 5 results:", result[:5]), ("Result shape:", result.shape)):
    print(label, value)

First 5 results: [49995000        1        1        1        1]
Result shape: (100000,)


In [23]:
# Unbatched, un-jitted call on the long path (x == 0); should match sum(range(10_000)).
dynamic_sum_loop(x=jnp.array(0))

Array(49995000, dtype=int32, weak_type=True)

In [24]:
# Pure-Python reference for the x == 0 lane: 0 + 1 + ... + 9_999.
sum(range(10_000))

49995000

In [25]:
# Small batch through the jitted + vmapped function: lane 0 is long, lanes 1 and 2 are short.
jitted_vmap_dynamic_sum(jnp.array([0, 1, 2]))

Array([49995000,        1,        1], dtype=int32, weak_type=True)

In [26]:
# Benchmark inputs that differ only in how many lanes take the long (x == 0) path.
# All three share the same shape and dtype. jnp.ones/jnp.zeros would otherwise default
# to float32 while jnp.arange is int32, so the timings would compare different compiled
# executables, and the first float32 timing would include a fresh compilation.
BATCH_SIZE = 100_000
input_single_work = jnp.arange(BATCH_SIZE)  # only element 0 is zero: a single long lane
input_no_work = jnp.ones(BATCH_SIZE, dtype=input_single_work.dtype)  # no zeros: all lanes short
input_all_work = jnp.zeros(BATCH_SIZE, dtype=input_single_work.dtype)  # all zeros: all lanes long

In [27]:
# Only element 0 of input_single_work is zero, so only one lane needs the long loop.
jitted_vmap_dynamic_sum(input_single_work)

Array([49995000,        1,        1, ...,        1,        1,        1],      dtype=int32, weak_type=True)

In [28]:
%%timeit
jitted_vmap_dynamic_sum(input_no_work)

173 µs ± 5.2 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [29]:
%%timeit
jitted_vmap_dynamic_sum(input_all_work)

238 ms ± 21.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [30]:
%%timeit
jitted_vmap_dynamic_sum(input_single_work)

239 ms ± 24.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
