# External Callbacks

In [1]:
import jax

@jax.jit
def f(x):
  """Add one, print the intermediate, and return twice the intermediate.

  The built-in `print` runs while jit traces the function, so it shows an
  abstract tracer rather than a concrete number (see the output below).
  """
  incremented = x + 1
  print("intermediate value: {}".format(incremented))
  return incremented * 2

result = f(2)

intermediate value: Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>


In [2]:
@jax.jit
def f(x):
  """Same computation, but report the intermediate with `jax.debug.print`.

  Unlike the built-in `print`, `jax.debug.print` reports the runtime value
  (3 for an input of 2, as shown below) instead of a tracer.
  """
  incremented = x + 1
  jax.debug.print("intermediate value: {}", incremented)
  return incremented * 2

result = f(2)

intermediate value: 3


## Types of Callbacks

### `pure_callback`

In [6]:
import jax
import jax.numpy as jnp
import numpy as np

def f_host(x):
  # call a numpy (not jax.numpy) operation:
  return np.sin(x).astype(x.dtype)

def f(x):
  result_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
  return jax.pure_callback(f_host, result_shape, x)

# Evaluate eagerly on a small float array.
x = jnp.arange(0.0, 5.0)
f(x)

Array([ 0.       ,  0.841471 ,  0.9092974,  0.14112  , -0.7568025],      dtype=float32)

In [7]:
# pure_callback is compatible with jit; the result matches the eager call.
f_jitted = jax.jit(f)
f_jitted(x)

Array([ 0.       ,  0.841471 ,  0.9092974,  0.14112  , -0.7568025],      dtype=float32)

In [8]:
def body_fun(_, x):
  """lax.scan step: pass the carry through unchanged and emit f(x)."""
  step_out = f(x)
  return _, step_out

# pure_callback also works inside lax.scan.
_final_carry, scanned = jax.lax.scan(body_fun, None, jnp.arange(5.0))
scanned

Array([ 0.       ,  0.841471 ,  0.9092974,  0.14112  , -0.7568025],      dtype=float32)

In [9]:
# pure_callback also composes with vmap; the output matches f(x).
vmapped_f = jax.vmap(f)
vmapped_f(x)

Array([ 0.       ,  0.841471 ,  0.9092974,  0.14112  , -0.7568025],      dtype=float32)

In [10]:
# Expected to fail: a bare pure_callback has no differentiation rule (error
# below); the custom_jvp section further down shows how to add one.
grad_f = jax.grad(f)
grad_f(x)

ValueError: Pure callbacks do not support JVP. Please use `jax.custom_jvp` to use callbacks while taking gradients.

In [11]:
def print_something():
  """Host function with a side effect (printing) and a constant int32 result."""
  message = 'printing something'
  print(message)
  zero = np.int32(0)
  return zero

@jax.jit
def f1():
  """Return the callback's result, so the call is kept and the print happens."""
  out = jax.pure_callback(print_something, np.int32(0))
  return out

f1();

printing something


In [12]:
@jax.jit
def f2():
  """Invoke the printing callback but discard its result, returning 1.0.

  Because the callback's output is unused, JAX may eliminate the call as dead
  code; the run below prints nothing.
  """
  _ = jax.pure_callback(print_something, np.int32(0))
  return 1.0

f2();

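### `pure_callback` and Exceptions

An exception raised inside a callback surfaces as a runtime error when the compiled computation executes. Combined with `vmap`, this can be surprising. A `lax.cond` whose predicate is batched evaluates both branches, so the raising callback runs even for inputs that would not take that branch. In the example below, `x = 0` triggers the error.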

In [13]:
import jax
import jax.numpy as jnp

def raise_via_callback(x):
  def _raise(x):
    raise ValueError(f"value of x is {x}")
  return jax.pure_callback(_raise, x, x)

def raise_if_negative(x):
  return jax.lax.cond(x < 0, raise_via_callback, lambda x: x, x)

x_batch = jnp.arange(0, 4)

# Element by element, cond executes only the branch that is taken, so the
# raising callback never runs for these non-negative inputs.
list(map(raise_if_negative, x_batch))  # does not raise

# Under vmap the predicate is batched, so both branches are evaluated and the
# callback runs even for x == 0 (see the traceback below).
jax.vmap(raise_if_negative)(x_batch)  # ValueError: value of x is 0

ERROR:jax._src.callback:jax.pure_callback failed
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/jax/_src/callback.py", line 86, in pure_callback_impl
    return tree_util.tree_map(np.asarray, callback(*args))
                                          ^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/jax/_src/callback.py", line 64, in __call__
    return tree_util.tree_leaves(self.callback_func(*args, **kwargs))
                                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<ipython-input-13-d6cd7cda810f>", line 6, in _raise
    raise ValueError(f"value of x is {x}")
ValueError: value of x is 0


XlaRuntimeError: INTERNAL: CpuCallback error: Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/usr/local/lib/python3.11/dist-packages/colab_kernel_launcher.py", line 37, in <module>
  File "/usr/local/lib/python3.11/dist-packages/traitlets/config/application.py", line 992, in launch_instance
  File "/usr/local/lib/python3.11/dist-packages/ipykernel/kernelapp.py", line 619, in start
  File "/usr/local/lib/python3.11/dist-packages/tornado/platform/asyncio.py", line 205, in start
  File "/usr/lib/python3.11/asyncio/base_events.py", line 608, in run_forever
  File "/usr/lib/python3.11/asyncio/base_events.py", line 1936, in _run_once
  File "/usr/lib/python3.11/asyncio/events.py", line 84, in _run
  File "/usr/local/lib/python3.11/dist-packages/tornado/ioloop.py", line 699, in <lambda>
  File "/usr/local/lib/python3.11/dist-packages/tornado/ioloop.py", line 750, in _run_callback
  File "/usr/local/lib/python3.11/dist-packages/tornado/gen.py", line 824, in inner
  File "/usr/local/lib/python3.11/dist-packages/tornado/gen.py", line 785, in run
  File "/usr/local/lib/python3.11/dist-packages/ipykernel/kernelbase.py", line 361, in process_one
  File "/usr/local/lib/python3.11/dist-packages/tornado/gen.py", line 233, in wrapper
  File "/usr/local/lib/python3.11/dist-packages/ipykernel/kernelbase.py", line 261, in dispatch_shell
  File "/usr/local/lib/python3.11/dist-packages/tornado/gen.py", line 233, in wrapper
  File "/usr/local/lib/python3.11/dist-packages/ipykernel/kernelbase.py", line 539, in execute_request
  File "/usr/local/lib/python3.11/dist-packages/tornado/gen.py", line 233, in wrapper
  File "/usr/local/lib/python3.11/dist-packages/ipykernel/ipkernel.py", line 302, in do_execute
  File "/usr/local/lib/python3.11/dist-packages/ipykernel/zmqshell.py", line 539, in run_cell
  File "/usr/local/lib/python3.11/dist-packages/IPython/core/interactiveshell.py", line 2975, in run_cell
  File "/usr/local/lib/python3.11/dist-packages/IPython/core/interactiveshell.py", line 3030, in _run_cell
  File "/usr/local/lib/python3.11/dist-packages/IPython/core/async_helpers.py", line 78, in _pseudo_sync_runner
  File "/usr/local/lib/python3.11/dist-packages/IPython/core/interactiveshell.py", line 3257, in run_cell_async
  File "/usr/local/lib/python3.11/dist-packages/IPython/core/interactiveshell.py", line 3473, in run_ast_nodes
  File "/usr/local/lib/python3.11/dist-packages/IPython/core/interactiveshell.py", line 3553, in run_code
  File "<ipython-input-13-d6cd7cda810f>", line 16, in <cell line: 0>
  File "/usr/local/lib/python3.11/dist-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
  File "/usr/local/lib/python3.11/dist-packages/jax/_src/api.py", line 1227, in vmap_f
  File "/usr/local/lib/python3.11/dist-packages/jax/_src/linear_util.py", line 193, in call_wrapped
  File "<ipython-input-13-d6cd7cda810f>", line 10, in raise_if_negative
  File "/usr/local/lib/python3.11/dist-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
  File "/usr/local/lib/python3.11/dist-packages/jax/_src/lax/control_flow/conditionals.py", line 301, in cond
  File "/usr/local/lib/python3.11/dist-packages/jax/_src/lax/control_flow/conditionals.py", line 273, in _cond
  File "/usr/local/lib/python3.11/dist-packages/jax/_src/lax/control_flow/conditionals.py", line 803, in cond_bind
  File "/usr/local/lib/python3.11/dist-packages/jax/_src/core.py", line 2782, in bind
  File "/usr/local/lib/python3.11/dist-packages/jax/_src/core.py", line 443, in bind_with_trace
  File "/usr/local/lib/python3.11/dist-packages/jax/_src/interpreters/batching.py", line 442, in process_primitive
  File "/usr/local/lib/python3.11/dist-packages/jax/_src/lax/control_flow/conditionals.py", line 395, in _cond_batching_rule
  File "/usr/local/lib/python3.11/dist-packages/jax/_src/core.py", line 260, in jaxpr_as_fun
  File "/usr/local/lib/python3.11/dist-packages/jax/_src/core.py", line 509, in eval_jaxpr
  File "/usr/local/lib/python3.11/dist-packages/jax/_src/lax/control_flow/loops.py", line 1218, in scan_bind
  File "/usr/local/lib/python3.11/dist-packages/jax/_src/core.py", line 2782, in bind
  File "/usr/local/lib/python3.11/dist-packages/jax/_src/core.py", line 443, in bind_with_trace
  File "/usr/local/lib/python3.11/dist-packages/jax/_src/core.py", line 949, in process_primitive
  File "/usr/local/lib/python3.11/dist-packages/jax/_src/dispatch.py", line 88, in apply_primitive
  File "/usr/local/lib/python3.11/dist-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
  File "/usr/local/lib/python3.11/dist-packages/jax/_src/pjit.py", line 332, in cache_miss
  File "/usr/local/lib/python3.11/dist-packages/jax/_src/pjit.py", line 190, in _python_pjit_helper
  File "/usr/local/lib/python3.11/dist-packages/jax/_src/core.py", line 2782, in bind
  File "/usr/local/lib/python3.11/dist-packages/jax/_src/core.py", line 443, in bind_with_trace
  File "/usr/local/lib/python3.11/dist-packages/jax/_src/core.py", line 949, in process_primitive
  File "/usr/local/lib/python3.11/dist-packages/jax/_src/pjit.py", line 1739, in _pjit_call_impl
  File "/usr/local/lib/python3.11/dist-packages/jax/_src/pjit.py", line 1721, in call_impl_cache_miss
  File "/usr/local/lib/python3.11/dist-packages/jax/_src/pjit.py", line 1675, in _pjit_call_impl_python
  File "/usr/local/lib/python3.11/dist-packages/jax/_src/profiler.py", line 333, in wrapper
  File "/usr/local/lib/python3.11/dist-packages/jax/_src/interpreters/pxla.py", line 1277, in __call__
  File "/usr/local/lib/python3.11/dist-packages/jax/_src/interpreters/mlir.py", line 2777, in _wrapped_callback
  File "/usr/local/lib/python3.11/dist-packages/jax/_src/callback.py", line 228, in _callback
  File "/usr/local/lib/python3.11/dist-packages/jax/_src/callback.py", line 89, in pure_callback_impl
  File "/usr/local/lib/python3.11/dist-packages/jax/_src/callback.py", line 64, in __call__
  File "<ipython-input-13-d6cd7cda810f>", line 6, in _raise
ValueError: value of x is 0

### `io_callback`

In [14]:
from jax.experimental import io_callback
from functools import partial

# Module-level generator with a fixed seed, so the printed samples are reproducible.
global_rng = np.random.default_rng(0)

def host_side_random_like(x):
  """Return uniform samples with x's shape and dtype, drawn from global_rng.

  Deliberately impure, with two side effects:
  - it prints a line describing the requested dtype and shape;
  - it advances the state of the module-level `global_rng`.
  """
  shape, dtype = x.shape, x.dtype
  print(f'generating {dtype}{list(shape)}')
  samples = global_rng.uniform(size=shape)
  return samples.astype(dtype)

@jax.jit
def numpy_random_like(x):
  """Draw host-side random numbers shaped like x through an unordered io_callback."""
  result_spec = x  # x only supplies the output shape/dtype
  return io_callback(host_side_random_like, result_spec, x)

# The zeros only serve as a shape/dtype template for the host-side sampler.
x = jnp.zeros((5,))
numpy_random_like(x)

generating float32[5]


Array([0.6369617 , 0.26978672, 0.04097353, 0.01652764, 0.8132702 ],      dtype=float32)

In [15]:
# An unordered io_callback can be vmapped: the host function runs once per
# element (five "generating float32[]" lines below).
batched_random_like = jax.vmap(numpy_random_like)
batched_random_like(x)

generating float32[]
generating float32[]
generating float32[]
generating float32[]
generating float32[]


Array([0.91275555, 0.60663575, 0.72949654, 0.543625  , 0.9350724 ],      dtype=float32)

In [16]:
@jax.jit
def numpy_random_like_ordered(x):
  """Like numpy_random_like, but with `ordered=True` on the io_callback.

  `ordered=True` requests that these callbacks run in program order. As the
  cells below show, such callbacks cannot be vmapped but do work in lax.scan.
  """
  result_spec = x  # x only supplies the output shape/dtype
  return io_callback(host_side_random_like, result_spec, x, ordered=True)

# Expected to fail: an ordered io_callback cannot be vmapped (error below).
batched_ordered = jax.vmap(numpy_random_like_ordered)
batched_ordered(x)

ValueError: Cannot `vmap` ordered IO callback.

In [17]:
def body_fun(_, x):
  """lax.scan step: pass the carry through unchanged and emit one ordered sample."""
  sample = numpy_random_like_ordered(x)
  return _, sample

# Unlike vmap, lax.scan works with ordered callbacks: one host call per step.
_final_carry, ordered_samples = jax.lax.scan(body_fun, None, jnp.arange(5.0))
ordered_samples

generating float32[]
generating float32[]
generating float32[]
generating float32[]
generating float32[]


Array([0.81585354, 0.0027385 , 0.8574043 , 0.03358557, 0.72965544],      dtype=float32)

In [18]:
# Expected to fail: io_callback has no JVP rule when it depends on x (error below).
grad_numpy_random_like = jax.grad(numpy_random_like)
grad_numpy_random_like(x)

ValueError: IO callbacks do not support JVP.

In [19]:
@jax.jit
def f(x):
  """Identity function that prints 'hello' on the host via io_callback.

  The callback takes no inputs and its result is unused, yet the call is kept:
  the grad cell below still prints 'hello'.
  """
  def _say_hello():
    print('hello')

  io_callback(_say_hello, None)
  return x

# The callback does not depend on x, so grad succeeds and 'hello' is still printed.
grad_f = jax.grad(f)
grad_f(1.0);

hello


### `debug.callback`

In [20]:
from jax import debug

def log_value(x):
  """Stand-in logger: print the value (a real logging call could go here)."""
  print("log:", x)

@jax.jit
def f(x):
  """Identity function that logs x from inside jit using jax.debug.callback."""
  observed = x
  debug.callback(log_value, observed)
  return observed

# The "log:" line is produced by the host-side callback when f executes.
logged_value = f(1.0)

log: 1.0


In [21]:
# Under vmap, debug.callback logs once per batch element (see the log lines).
x = jnp.arange(0.0, 5.0)
batched_f = jax.vmap(f)
batched_f(x);

log: 0.0
log: 1.0
log: 2.0
log: 3.0
log: 4.0


In [22]:
# debug.callback does not block differentiation: grad still logs the value.
grad_f = jax.grad(f)
grad_f(1.0);

log: 1.0


## `pure_callback` with `custom_jvp`

In [27]:
import jax
import jax.numpy as jnp
import scipy.special

def jv(v, z):
  v, z = jnp.asarray(v), jnp.asarray(z)

  # Require the order v to be integer type: this simplifies
  # the JVP rule below.
  assert jnp.issubdtype(v.dtype, jnp.integer)

  # Promote the input to inexact (float/complex).
  # Note that jnp.result_type() accounts for the enable_x64 flag.
  z = z.astype(jnp.result_type(float, z.dtype))

  # Wrap scipy function to return the expected dtype.
  _scipy_jv = lambda v, z: scipy.special.jv(v, z).astype(z.dtype)

  # Define the expected shape & dtype of output.
  result_shape_dtype = jax.ShapeDtypeStruct(
      shape=jnp.broadcast_shapes(v.shape, z.shape),
      dtype=z.dtype)

  # Use vmap_method="broadcast_all" because scipy.special.jv handles broadcasted inputs.
  return jax.pure_callback(_scipy_jv, result_shape_dtype, v, z)

In [28]:
from functools import partial
# Sample points, plus J_1 as a one-argument function (order fixed to 1).
# j1 binds the jv defined above; it is rebound after jv gets a custom JVP.
z = jnp.arange(0.0, 5.0)
j1 = partial(jv, 1)

In [29]:
j1_values = j1(z)
print(j1_values)

[ 0.          0.44005057  0.5767248   0.33905897 -0.06604332]


In [30]:
j1_jitted = jax.jit(j1)
print(j1_jitted(z))

[ 0.          0.44005057  0.5767248   0.33905897 -0.06604332]


In [31]:
j1_vmapped = jax.vmap(j1)
print(j1_vmapped(z))

[ 0.          0.44005057  0.5767248   0.33905897 -0.06604332]


In [32]:
# Expected to fail: the raw pure_callback has no JVP rule yet (error below).
grad_j1 = jax.grad(j1)
grad_j1(z)

ValueError: Pure callbacks do not support JVP. Please use `jax.custom_jvp` to use callbacks while taking gradients.

In [33]:
# NOTE: this rebinds jv to its custom_jvp wrapper. Re-running this cell alone
# would wrap the already-wrapped jv, so re-run the cell defining jv first.
jv = jax.custom_jvp(jv)

@jv.defjvp
def _jv_jvp(primals, tangents):
  """JVP rule for jv, using the Bessel recurrence for dJ_v/dz.

  dJ_v/dz = (J_{v-1}(z) - J_{v+1}(z)) / 2, and dJ_0/dz = -J_1(z).
  Only z carries a tangent. The rule calls the wrapped jv itself, so higher
  derivatives (e.g. jax.hessian) also work.
  """
  order, arg = primals
  _, arg_dot = tangents  # Note: v_dot is always 0 because v is integer.
  lower = jv(order - 1, arg)
  upper = jv(order + 1, arg)
  slope = jnp.where(order == 0, -upper, 0.5 * (lower - upper))
  primal_out = jv(order, arg)
  return primal_out, arg_dot * slope

In [34]:
# Re-create j1 so it binds the custom_jvp-wrapped jv rather than the original.
j1 = partial(jv, 1)
dj1 = jax.grad(j1)(2.0)
print(dj1)

-0.06447162


In [35]:
# Second derivatives work because the JVP rule is written in terms of jv itself.
hessian_j1 = jax.hessian(j1)
hessian_j1(2.0)

Array(-0.4003078, dtype=float32, weak_type=True)