# Introduction to debugging
This section introduces you to a set of built-in JAX debugging methods — `jax.debug.print()`, `jax.debug.breakpoint()`, and `jax.debug.callback()` — that you can use with various JAX transformations.

Let’s begin with `jax.debug.print()`.

## `jax.debug.print` for simple inspection
Here is a rule of thumb:

* Use `jax.debug.print()` for traced (dynamic) array values with `jax.jit()`, `jax.vmap()` and others.

* Use Python `print()` for static values, such as dtypes and array shapes.
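The following is a minimal sketch of this rule of thumb. Inside a jitted function, shape and dtype are static Python values that are available at trace time. The array contents are only known at run time. The function name `scale` is illustrative.

```python
import jax
import jax.numpy as jnp

trace_time_shapes = []  # filled by ordinary Python code while tracing

@jax.jit
def scale(x):
    # Shape and dtype are static, so Python print() shows concrete values.
    print("static:", x.shape, x.dtype)
    trace_time_shapes.append(x.shape)
    # The values are dynamic (traced), so use jax.debug.print() for them.
    jax.debug.print("dynamic: {}", x)
    return x * 2

out = scale(jnp.arange(3.0))
out.block_until_ready()
```

If you call `scale` again with an input of the same shape and dtype, JAX reuses the cached trace. The `static:` line is therefore not printed again, but the `jax.debug.print()` line runs on every call.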

Recall from the Just-in-time compilation tutorial that when you transform a function with `jax.jit()`, the Python code runs with abstract tracers in place of your arrays. Because of this, Python's `print()` only shows the tracer object, not the array value:

In [1]:
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
  print("print(x) ->", x)
  y = jnp.sin(x)
  print("print(y) ->", y)
  return y

result = f(2.)

print(x) -> Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>
print(y) -> Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>


In [2]:
@jax.jit
def f(x):
  jax.debug.print("jax.debug.print(x) -> {x}", x=x)
  y = jnp.sin(x)
  jax.debug.print("jax.debug.print(y) -> {y}", y=y)
  return y

result = f(2.)

jax.debug.print(x) -> 2.0
jax.debug.print(y) -> 0.9092974066734314


In [3]:
def f(x):
  jax.debug.print("jax.debug.print(x) -> {}", x)
  y = jnp.sin(x)
  jax.debug.print("jax.debug.print(y) -> {}", y)
  return y

xs = jnp.arange(3.)

result = jax.vmap(f)(xs)

jax.debug.print(x) -> 0.0
jax.debug.print(x) -> 1.0
jax.debug.print(x) -> 2.0
jax.debug.print(y) -> 0.0
jax.debug.print(y) -> 0.8414709568023682
jax.debug.print(y) -> 0.9092974066734314


In [4]:
result = jax.lax.map(f, xs)

jax.debug.print(x) -> 0.0
jax.debug.print(y) -> 0.0
jax.debug.print(x) -> 1.0
jax.debug.print(y) -> 0.8414709568023682
jax.debug.print(x) -> 2.0
jax.debug.print(y) -> 0.9092974066734314


In [5]:
def f(x):
  jax.debug.print("jax.debug.print(x) -> {}", x)
  return x ** 2

result = jax.grad(f)(1.)

jax.debug.print(x) -> 1.0


In [6]:
@jax.jit
def f(x, y):
  jax.debug.print("jax.debug.print(x) -> {}", x, ordered=True)
  jax.debug.print("jax.debug.print(y) -> {}", y, ordered=True)
  return x + y

f(1, 2)

jax.debug.print(x) -> 1
jax.debug.print(y) -> 2


Array(3, dtype=int32, weak_type=True)
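The `ordered=True` argument asks JAX to run these prints in program order. By default, debug prints are not guaranteed to run in program order. The sketch below records the order on the host using `jax.debug.callback()`, which also accepts `ordered`. Calling `jax.effects_barrier()` waits for pending callbacks to finish before the list is inspected. The helper `record` is illustrative.

```python
import jax

events = []  # host-side record of callback invocations

def record(name):
    # Returns a host callback that records (name, value).
    def _cb(value):
        events.append((name, int(value)))
    return _cb

@jax.jit
def add(x, y):
    jax.debug.callback(record("x"), x, ordered=True)
    jax.debug.callback(record("y"), y, ordered=True)
    return x + y

result = add(1, 2)
jax.effects_barrier()  # wait until all pending callbacks have run
```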

## `jax.debug.breakpoint` for `pdb`-like debugging
Summary: Use `jax.debug.breakpoint()` to pause the execution of your JAX program to inspect values.

To pause your compiled JAX program at specific points during debugging, use `jax.debug.breakpoint()`. The prompt is similar to Python's `pdb` and lets you inspect values in the call stack. In fact, `jax.debug.breakpoint()` is an application of `jax.debug.callback()` that captures information about the call stack.

To print all available commands during a `breakpoint` debugging session, use the `help` command.

In [7]:
def breakpoint_if_nonfinite(x):
  is_finite = jnp.isfinite(x).all()
  def true_fn(x):
    pass
  def false_fn(x):
    jax.debug.breakpoint()
  jax.lax.cond(is_finite, true_fn, false_fn, x)

@jax.jit
def f(x, y):
  z = x / y
  breakpoint_if_nonfinite(z)
  return z

f(2., 1.) # ==> No breakpoint

Array(2., dtype=float32, weak_type=True)
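Because `jax.debug.breakpoint()` opens an interactive prompt, it is awkward to exercise in a script. The following non-interactive sketch uses the same `jax.lax.cond` pattern but replaces the breakpoint with `jax.debug.callback()`, which records offending values on the host. The helper name `record_if_nonfinite` is hypothetical and is not a JAX API.

```python
import jax
import jax.numpy as jnp
import numpy as np

flagged = []  # non-finite values seen at run time

def record_if_nonfinite(x):
    # Hypothetical helper mirroring breakpoint_if_nonfinite above.
    is_finite = jnp.isfinite(x).all()
    def true_fn(x):
        pass
    def false_fn(x):
        # Runs on the host only when the predicate is False.
        jax.debug.callback(lambda v: flagged.append(np.array(v)), x)
    jax.lax.cond(is_finite, true_fn, false_fn, x)

@jax.jit
def divide(x, y):
    z = x / y
    record_if_nonfinite(z)
    return z

divide(2., 1.)  # finite: nothing recorded
divide(1., 0.)  # inf: recorded
jax.effects_barrier()
```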

## `jax.debug.callback` for more control during debugging
Both `jax.debug.print()` and `jax.debug.breakpoint()` are implemented using the more flexible `jax.debug.callback()`, which gives you greater control over the host-side logic run by a Python callback. It is compatible with `jax.jit()`, `jax.vmap()`, `jax.grad()`, and other transformations.

In [8]:
import logging

def log_value(x):
  logging.warning(f'Logged value: {x}')

@jax.jit
def f(x):
  jax.debug.callback(log_value, x)
  return x

f(1.0);



In [9]:
x = jnp.arange(5.0)
jax.vmap(f)(x);



In [10]:
jax.grad(f)(1.0);
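The cells above send their output through the `logging` module. The following variant appends to a host-side list instead, which makes the calls easy to inspect in a script. As in `log_value`, the callback receives concrete values. The name `collect` is illustrative.

```python
import jax
import jax.numpy as jnp

seen = []  # values received on the host

def collect(x):
    seen.append(float(x))

@jax.jit
def f(x):
    jax.debug.callback(collect, x)
    return x

f(1.0)
jax.effects_barrier()
jax.vmap(f)(jnp.arange(3.0))  # one callback per batch element
jax.effects_barrier()
jax.grad(f)(5.0)              # the callback sees the primal value
jax.effects_barrier()
```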



# External callbacks
This tutorial outlines how you can use various callback functions, which allow JAX runtimes to execute Python code on the host. Examples of JAX callbacks are `jax.pure_callback()`, `jax.experimental.io_callback()` and `jax.debug.callback()`. You can use them even under JAX transformations such as `jit()`, `vmap()`, and `grad()`.

## Why callbacks?
A callback routine lets you run code on the host at runtime. As a simple example, suppose you'd like to print the value of some variable during a computation. With a plain Python `print()` statement, it looks like this:

In [11]:
import jax

@jax.jit
def f(x):
  y = x + 1
  print("intermediate value: {}".format(y))
  return y * 2

result = f(2)

intermediate value: Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>


In [12]:
@jax.jit
def f(x):
  y = x + 1
  jax.debug.print("intermediate value: {}", y)
  return y * 2

result = f(2)

intermediate value: 3


## Exploring pure_callback
`jax.pure_callback()` is generally the callback function you should reach for when you want host-side execution of a pure function: i.e. a function that has no side-effects (such as printing values, reading data from disk, updating a global state, etc.).

The function you pass to `jax.pure_callback()` need not actually be pure, but it will be assumed pure by JAX’s transformations and higher-order functions, which means that it may be silently elided or called multiple times.

In [13]:
import jax
import jax.numpy as jnp
import numpy as np

def f_host(x):
  # call a numpy (not jax.numpy) operation:
  return np.sin(x).astype(x.dtype)

def f(x):
  result_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
  return jax.pure_callback(f_host, result_shape, x)

x = jnp.arange(5.0)
f(x)

Array([ 0.       ,  0.841471 ,  0.9092974,  0.14112  , -0.7568025],
      dtype=float32)

In [14]:
jax.jit(f)(x)

Array([ 0.       ,  0.841471 ,  0.9092974,  0.14112  , -0.7568025],
      dtype=float32)

In [15]:
jax.vmap(f)(x)



Array([ 0.       ,  0.841471 ,  0.9092974,  0.14112  , -0.7568025],
      dtype=float32)

In [16]:
def body_fun(_, x):
  return _, f(x)
jax.lax.scan(body_fun, None, jnp.arange(5.0))[1]

Array([ 0.       ,  0.841471 ,  0.9092974,  0.14112  , -0.7568025],
      dtype=float32)

In [17]:
jax.grad(f)(x)

ValueError: Pure callbacks do not support JVP. Please use `jax.custom_jvp` to use callbacks while taking gradients.

In [18]:
def print_something():
  print('printing something')
  return np.int32(0)

@jax.jit
def f1():
  return jax.pure_callback(print_something, np.int32(0))
f1();

printing something


In [19]:
@jax.jit
def f2():
  jax.pure_callback(print_something, np.int32(0))
  return 1.0
f2();
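The `result_shape_dtypes` argument can also be a pytree, such as a tuple, of `jax.ShapeDtypeStruct` objects. In that case, the host function must return a matching pytree of arrays. The following sketch returns both sine and cosine from a single host call. The name `sincos_host` is illustrative.

```python
import jax
import jax.numpy as jnp
import numpy as np

def sincos_host(x):
    # Runs on the host with NumPy; returns a tuple matching the spec below.
    return np.sin(x).astype(x.dtype), np.cos(x).astype(x.dtype)

@jax.jit
def sincos(x):
    spec = jax.ShapeDtypeStruct(x.shape, x.dtype)
    # One spec per output, in the same tuple structure the host returns.
    return jax.pure_callback(sincos_host, (spec, spec), x)

s, c = sincos(jnp.arange(3.0))
```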

## Exploring io_callback
In contrast to `jax.pure_callback()`, `jax.experimental.io_callback()` is explicitly meant to be used with impure functions, i.e. functions that do have side-effects.

As an example, here is a callback to a global host-side NumPy random generator. This is an impure operation because generating a random number in NumPy updates the random state as a side effect. (This is meant as a toy example of `io_callback`, not a recommended way of generating random numbers in JAX.)

In [18]:
from jax.experimental import io_callback
from functools import partial

global_rng = np.random.default_rng(0)

def host_side_random_like(x):
  """Generate a random array like x using the global_rng state"""
  # We have two side-effects here:
  # - printing the shape and dtype
  # - calling global_rng, thus updating its state
  print(f'generating {x.dtype}{list(x.shape)}')
  return global_rng.uniform(size=x.shape).astype(x.dtype)

@jax.jit
def numpy_random_like(x):
  return io_callback(host_side_random_like, x, x)

x = jnp.zeros(5)
numpy_random_like(x)

generating float32[5]


Array([0.6369617 , 0.26978672, 0.04097353, 0.01652764, 0.8132702 ],
      dtype=float32)

In [19]:
jax.vmap(numpy_random_like)(x)

generating float32[]
generating float32[]
generating float32[]
generating float32[]
generating float32[]


Array([0.91275555, 0.60663575, 0.72949654, 0.543625  , 0.9350724 ],
      dtype=float32)

In [20]:
@jax.jit
def numpy_random_like_ordered(x):
  return io_callback(host_side_random_like, x, x, ordered=True)

jax.vmap(numpy_random_like_ordered)(x)

ValueError: Cannot `vmap` ordered IO callback.

In [21]:
def body_fun(_, x):
  return _, numpy_random_like_ordered(x)
jax.lax.scan(body_fun, None, jnp.arange(5.0))[1]

generating float32[]
generating float32[]
generating float32[]
generating float32[]
generating float32[]


Array([0.81585354, 0.0027385 , 0.8574043 , 0.03358557, 0.72965544],
      dtype=float32)

In [22]:
jax.grad(numpy_random_like)(x)

ValueError: IO callbacks do not support JVP.

In [23]:
@jax.jit
def f(x):
  io_callback(lambda: print('hello'), None)
  return x

jax.grad(f)(1.0);

hello
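The following sketch shows an impure callback whose only purpose is its side effect. An ordered `io_callback` inside `jax.lax.scan` appends one entry per step to a host-side list. Passing `None` as the result spec means the callback returns nothing. The names `host_log`, `record`, and `run` are illustrative.

```python
import jax
import jax.numpy as jnp
from jax.experimental import io_callback

host_log = []  # host-side state mutated by the callback (a side effect)

def record(step, value):
    # Side effect only; returns None to match the None result spec.
    host_log.append((int(step), float(value)))

@jax.jit
def run(xs):
    def body(step, x):
        io_callback(record, None, step, x, ordered=True)
        return step + 1, x * 2
    _, ys = jax.lax.scan(body, 0, xs)
    return ys

ys = run(jnp.arange(3.0))
jax.effects_barrier()  # make sure all host-side effects have completed
```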


## Exploring debug.callback
Both `pure_callback` and `io_callback` enforce some assumptions about the purity of the function they’re calling, and limit in various ways what JAX transforms and compilation machinery may do. `debug.callback` essentially assumes nothing about the callback function, such that the action of the callback reflects exactly what JAX is doing during the course of a program. Further, `debug.callback` cannot return any value to the program.

In [24]:
from jax import debug

def log_value(x):
  # This could be an actual logging call; we'll use
  # print() for demonstration
  print("log:", x)

@jax.jit
def f(x):
  debug.callback(log_value, x)
  return x

f(1.0);

log: 1.0


In [25]:
x = jnp.arange(5.0)
jax.vmap(f)(x);

log: 0.0
log: 1.0
log: 2.0
log: 3.0
log: 4.0


In [26]:
jax.grad(f)(1.0);

log: 1.0
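Contrast this with `f2` in the `pure_callback` section. Because JAX treats a `jax.debug.callback()` call as having a side effect, the call is kept even when nothing in the returned value depends on it. The following sketch counts host calls for a function whose output ignores its input. The name `ignores_input` is illustrative.

```python
import jax

calls = []

@jax.jit
def ignores_input(x):
    # The callback's argument does not feed into the return value.
    jax.debug.callback(lambda v: calls.append(float(v)), x)
    return 1.0

out = ignores_input(3.0)
jax.effects_barrier()
```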


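## Example: pure_callback with custom_jvp
This example wraps `scipy.special.jv`, the Bessel function of the first kind, with `jax.pure_callback()`, then uses `jax.custom_jvp` to make it differentiable.

In [27]: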
import jax
import jax.numpy as jnp
import scipy.special

def jv(v, z):
  v, z = jnp.asarray(v), jnp.asarray(z)

  # Require the order v to be integer type: this simplifies
  # the JVP rule below.
  assert jnp.issubdtype(v.dtype, jnp.integer)

  # Promote the input to inexact (float/complex).
  # Note that jnp.result_type() accounts for the enable_x64 flag.
  z = z.astype(jnp.result_type(float, z.dtype))

  # Wrap scipy function to return the expected dtype.
  _scipy_jv = lambda v, z: scipy.special.jv(v, z).astype(z.dtype)

  # Define the expected shape & dtype of output.
  result_shape_dtype = jax.ShapeDtypeStruct(
      shape=jnp.broadcast_shapes(v.shape, z.shape),
      dtype=z.dtype)

  # Use vectorized=True because scipy.special.jv handles broadcast inputs.
  return jax.pure_callback(_scipy_jv, result_shape_dtype, v, z, vectorized=True)

In [28]:
from functools import partial
j1 = partial(jv, 1)
z = jnp.arange(5.0)

In [29]:
print(j1(z))

[ 0.          0.44005057  0.5767248   0.33905897 -0.06604332]




In [30]:
print(jax.jit(j1)(z))

[ 0.          0.44005057  0.5767248   0.33905897 -0.06604332]




In [31]:
print(jax.vmap(j1)(z))

[ 0.          0.44005057  0.5767248   0.33905897 -0.06604332]




In [32]:
jax.grad(j1)(z)



ValueError: Pure callbacks do not support JVP. Please use `jax.custom_jvp` to use callbacks while taking gradients.

In [33]:
jv = jax.custom_jvp(jv)

@jv.defjvp
def _jv_jvp(primals, tangents):
  v, z = primals
  _, z_dot = tangents  # Note: v_dot is always 0 because v is integer.
  jv_minus_1, jv_plus_1 = jv(v - 1, z), jv(v + 1, z)
  djv_dz = jnp.where(v == 0, -jv_plus_1, 0.5 * (jv_minus_1 - jv_plus_1))
  return jv(v, z), z_dot * djv_dz

j1 = partial(jv, 1)
print(jax.grad(j1)(2.0))

-0.06447162




In [34]:
jax.hessian(j1)(2.0)



Array(-0.4003078, dtype=float32, weak_type=True)
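One way to sanity-check the hand-written JVP rule is to compare its first and second derivatives against `scipy.special.jvp`, which evaluates derivatives of Bessel functions directly. The following sketch is a variant of `jv` above, not a drop-in replacement. It makes the integer order `v` a static, non-differentiable argument through `nondiff_argnums` instead of tracing it as an integer array. It also omits the `jnp.where` branch, because the recurrence already covers `v == 0` via the identity J_{-1} = -J_1.

```python
import jax
import jax.numpy as jnp
import scipy.special
from functools import partial

# Variant of jv: the integer order v is static and non-differentiable.
@partial(jax.custom_jvp, nondiff_argnums=(0,))
def jv(v, z):
    z = jnp.asarray(z, dtype=jnp.float32)
    out_spec = jax.ShapeDtypeStruct(z.shape, z.dtype)
    # Host-side SciPy call; v is a plain Python int captured by the closure.
    host = lambda z_host: scipy.special.jv(v, z_host).astype(z_host.dtype)
    return jax.pure_callback(host, out_spec, z)

@jv.defjvp
def _jv_jvp(v, primals, tangents):
    (z,), (z_dot,) = primals, tangents
    # dJ_v/dz = (J_{v-1} - J_{v+1}) / 2, which also holds for v == 0
    # because J_{-1} = -J_1 for integer order.
    djv_dz = 0.5 * (jv(v - 1, z) - jv(v + 1, z))
    return jv(v, z), z_dot * djv_dz

j1 = lambda z: jv(1, z)
grad_j1 = jax.grad(j1)(2.0)
hess_j1 = jax.hessian(j1)(2.0)

# scipy.special.jvp(v, z, n) evaluates the n-th derivative of J_v at z.
print(grad_j1, scipy.special.jvp(1, 2.0))
print(hess_j1, scipy.special.jvp(1, 2.0, n=2))
```

Because the JVP rule calls `jv` itself, the rule is applied again when differentiating a second time. This is what makes `jax.hessian` work without ever differentiating the `pure_callback` directly.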