In [1]:
import jax
import jax.numpy as jnp
import equinox as eqx
from jax import make_jaxpr

# --- Case A: kappa as a Python float (STATIC) ---
kappa_py = 2.0

@eqx.filter_jit
def f_py(x):
    # kappa_py is captured from the closure.
    return kappa_py * jnp.sum(x**2)

x = jnp.arange(4., dtype=jnp.float32)
print("=== JAXPR with kappa as PY float (static) ===")
print(make_jaxpr(f_py)(x))   # look at invars vs consts


# --- Case B: kappa as a 0-D array (DYNAMIC) ---
kappa_arr = jnp.array(2.0, dtype=jnp.float32)

@eqx.filter_jit
def f_arr(x):
    # kappa_arr is captured but it's an ARRAY leaf → dynamic.
    return kappa_arr * jnp.sum(x**2)

print("\n=== JAXPR with kappa as 0-D ARRAY (dynamic) ===")
print(make_jaxpr(f_arr)(x))


=== JAXPR with kappa as PY float (static) ===
{ [34;1mlambda [39;22m; a[35m:f32[4][39m. [34;1mlet
    [39;22m_[35m:i32[][39m b[35m:f32[][39m = pjit[
      name=f_py
      jaxpr={ [34;1mlambda [39;22m; a[35m:f32[4][39m. [34;1mlet
          [39;22mc[35m:f32[4][39m = integer_pow[y=2] a
          d[35m:f32[][39m = reduce_sum[axes=(0,)] c
          b[35m:f32[][39m = mul 2.0:f32[] d
        [34;1min [39;22m(0:i32[], b) }
    ] a
  [34;1min [39;22m(b,) }

=== JAXPR with kappa as 0-D ARRAY (dynamic) ===
{ [34;1mlambda [39;22m; a[35m:f32[4][39m. [34;1mlet
    [39;22m_[35m:i32[][39m b[35m:f32[][39m = pjit[
      name=f_arr
      jaxpr={ [34;1mlambda [39;22mc[35m:f32[][39m; a[35m:f32[4][39m. [34;1mlet
          [39;22md[35m:f32[4][39m = integer_pow[y=2] a
          e[35m:f32[][39m = reduce_sum[axes=(0,)] d
          b[35m:f32[][39m = mul c e
        [34;1min [39;22m(0:i32[], b) }
    ] a
  [34;1min [39;22m(b,) }


In [7]:
import jax
from jax import make_jaxpr
import jax.numpy as jnp
import equinox as eqx
from equinox import filter_make_jaxpr

In [9]:
# A PyTree argument with mixed leaves
arg = {
    "W": jnp.ones((2,2), dtype=jnp.float32),  # array → dynamic
    "name": "layer",                           # str   → static
    "use_bias": False,                          # bool  → static
}

@eqx.filter_jit
def forward(mod, x):
    y = mod["W"] @ x
    return y + 1 if mod["use_bias"] else y

x = jnp.ones((2,), dtype=jnp.float32)
print("\n=== JAXPR for mixed PyTree argument (arrays dynamic; others static) ===")
print(filter_make_jaxpr(forward)(arg, x))



=== JAXPR for mixed PyTree argument (arrays dynamic; others static) ===
({ [34;1mlambda [39;22m; a[35m:f32[2,2][39m b[35m:f32[2][39m. [34;1mlet
    [39;22m_[35m:i32[][39m c[35m:f32[2][39m = pjit[
      name=forward
      jaxpr={ [34;1mlambda [39;22m; a[35m:f32[2,2][39m b[35m:f32[2][39m. [34;1mlet
          [39;22mc[35m:f32[2][39m = dot_general[
            dimension_numbers=(([1], [0]), ([], []))
            preferred_element_type=float32
          ] a b
        [34;1min [39;22m(0:i32[], c) }
    ] a b
  [34;1min [39;22m(c,) }, ShapeDtypeStruct(shape=(2,), dtype=float32), None)


In [14]:
from typing import Union

In [16]:
class Tiny(eqx.Module):
    kappa: Union[float, jax.Array]        # dynamic (array leaf)
    scheme: str             # static (python)
    def __init__(self):
        self.kappa = 1.5 #jnp.array(1.5, dtype=jnp.float32)
        self.scheme = "fourier"

    def rhs(self, u, t):
        # both captured through 'self'
        return self.kappa * (u + t)

model = Tiny()

@eqx.filter_jit
def rhs_wrapped(u, t):
    return model.rhs(u, t)

u = jnp.ones((3,), dtype=jnp.float32)
t = jnp.array(0.1, dtype=jnp.float32)

print("\n=== JAXPR class capture: array field dynamic; python field static ===")
print(make_jaxpr(rhs_wrapped)(u, t))



=== JAXPR class capture: array field dynamic; python field static ===
{ [34;1mlambda [39;22m; a[35m:f32[3][39m b[35m:f32[][39m. [34;1mlet
    [39;22m_[35m:i32[][39m c[35m:f32[3][39m = pjit[
      name=rhs_wrapped
      jaxpr={ [34;1mlambda [39;22m; a[35m:f32[3][39m b[35m:f32[][39m. [34;1mlet
          [39;22md[35m:f32[3][39m = add a b
          c[35m:f32[3][39m = mul 1.5:f32[] d
        [34;1min [39;22m(0:i32[], c) }
    ] a b
  [34;1min [39;22m(c,) }
