# 04 Debugging

Original Documentation: https://docs.jax.dev/en/latest/debugging.html


```python
import jax
import jax.numpy as jnp
```
## Print

`jax.debug.print()` prints the runtime values of traced arrays inside transformed functions, such as those passed to `jax.jit`, `jax.vmap`, and other transformations. Python's `print()` does not work here: it executes only at trace time, so it prints abstract tracer objects instead of concrete values.

Use Python `print()` for static values, such as dtypes and array shapes, which are already known at trace time.


```python
@jax.jit
def f(x):
    jax.debug.print("x={x}", x=x)
    y = jnp.sin(x)
    jax.debug.print("y={y}", y=y)
    return y


f(2.0)
```

```
x=2.0
y=0.9092974066734314

Array(0.9092974, dtype=float32, weak_type=True)
```

Similarly, under `jax.vmap`, only `jax.debug.print()` prints the individual values being mapped over; Python `print()` would show a tracer instead.


```python
@jax.vmap
def f(x):
    jax.debug.print("x={x}", x=x)
    y = jnp.sin(x)
    jax.debug.print("y={y}", y=y)
    return y


xs = jnp.arange(3)
f(xs)
```

```
x=0
x=1
x=2
y=0.0
y=0.8414709568023682
y=0.9092974066734314

Array([0.        , 0.84147096, 0.9092974 ], dtype=float32)
```

Note that all `x` prints occur first, followed by all `y` prints.
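Roughly speaking, this happens because `jax.vmap()` does not call the Python function once per element: it traces the function a single time with a per-example view of the input, so each `jax.debug.print()` becomes one batched operation that emits all of its elements before the next statement runs. A minimal sketch showing that the Python body executes only once (the names `h` and `calls` are illustrative):

```python
import jax
import jax.numpy as jnp

calls = []  # records the input shape each time the Python body runs


def h(x):
    # Plain Python side effect: executes during tracing, not once per element
    calls.append(x.shape)
    return jnp.sin(x)


ys = jax.vmap(h)(jnp.arange(3.0))
print(calls)     # one entry, the per-example shape: [()]
print(ys.shape)  # (3,)
```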

Instead, if we want sequential prints, we can use `jax.lax.map()`:


```python
from jax import lax


def f(x):
    jax.debug.print("x={x}", x=x)
    y = jnp.sin(x)
    jax.debug.print("y={y}", y=y)
    return y


xs = jnp.arange(3)
lax.map(f, xs)
```

```
y=0.0
x=0
y=0.8414709568023682
x=1
y=0.9092974066734314
x=2

Array([0.        , 0.84147096, 0.9092974 ], dtype=float32)
```
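`lax.map()` produces the prints one element at a time because it runs `f` as a sequential loop rather than vectorizing it. A minimal sketch of such a loop built on `lax.scan()`; the helper name `map_via_scan` is hypothetical, and this illustrates the idea rather than reproducing how `lax.map()` is implemented in any particular JAX version:

```python
import jax
import jax.numpy as jnp
from jax import lax


def map_via_scan(f, xs):
    """Apply f to each leading-axis slice of xs, one slice per loop iteration."""

    def body(carry, x):
        # carry is unused; each iteration emits f(x) as its per-step output
        return carry, f(x)

    _, ys = lax.scan(body, None, xs)
    return ys


xs = jnp.arange(3.0)
print(map_via_scan(jnp.sin, xs))
print(lax.map(jnp.sin, xs))
```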

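With `lax.map()` the elements are processed one at a time, but within each iteration the `y` print appears before the `x` print. By default, `jax.debug.print()` calls are unordered: the two print effects do not depend on each other, so the compiler is free to reorder them even though the value of `y` is computed from `x`.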

To keep the prints in program order, pass `ordered=True`:


```python
def f(x):
    jax.debug.print("x={x}", x=x, ordered=True)
    y = jnp.sin(x)
    jax.debug.print("y={y}", y=y, ordered=True)
    return y


xs = jnp.arange(3)
lax.map(f, xs)
```

```
x=0
y=0.0
x=1
y=0.8414709568023682
x=2
y=0.9092974066734314

Array([0.        , 0.84147096, 0.9092974 ], dtype=float32)
```
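The same ordering guarantee can be checked programmatically. A minimal sketch, assuming `jax.debug.callback()` (covered in the Callbacks section) also accepts `ordered=True`; the names `record` and `events` are hypothetical. `jax.effects_barrier()` waits for pending callbacks before the list is read:

```python
import functools

import jax
import jax.numpy as jnp
from jax import lax

events = []  # (tag, value) pairs in the order the host receives them


def record(tag, value):
    events.append((tag, float(value)))


def f(x):
    # ordered=True keeps these two host callbacks in program order
    jax.debug.callback(functools.partial(record, "x"), x, ordered=True)
    y = jnp.sin(x)
    jax.debug.callback(functools.partial(record, "y"), y, ordered=True)
    return y


lax.map(f, jnp.arange(3.0))
jax.effects_barrier()  # wait for all pending callbacks before reading events
print([tag for tag, _ in events])
```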

## Breakpoint

Use `jax.debug.breakpoint()` for PDB-style debugging: it pauses execution and opens the `jdb` debugger, where you can inspect runtime values (for example, `p x`) and resume with `c`.


```python
@jax.jit
def f(x):
    jax.debug.breakpoint()
    y, z = jnp.sin(x), jnp.cos(x)
    jax.debug.breakpoint()
    return y, z


f(2.0)
```

```
Entering jdb:
Entering jdb:
Array(2., dtype=float32)
Array(0.9092974, dtype=float32)

(Array(0.9092974, dtype=float32, weak_type=True),
 Array(-0.41614684, dtype=float32, weak_type=True))
```

For value-dependent breakpoints, wrap the breakpoint in `jax.lax.cond()` so that it only triggers when the runtime value meets a condition:


```python
def break_if_nonzero(x):
    def true_fun(x):
        pass  # x == 0: do nothing

    def false_fun(x):
        jax.debug.breakpoint()  # only reached when x != 0

    # lax.cond selects the branch at runtime from the traced value of x
    jax.lax.cond(x == 0, true_fun, false_fun, x)


@jax.jit
def f(x):
    break_if_nonzero(x)
    return jnp.sin(x)


f(2.0)  # Non-zero input
```

```
Entering jdb:
Array(2., dtype=float32)

Array(0.9092974, dtype=float32, weak_type=True)
```
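Because `jax.debug.breakpoint()` is interactive, a non-interactive variant is handy for checking which branch fires. A minimal sketch of the same `lax.cond()` pattern that records the value with `jax.debug.callback()` (covered in the Callbacks section) instead of pausing; `report_if_nonzero`, `record`, and `hits` are hypothetical names:

```python
import jax
import jax.numpy as jnp

hits = []  # values for which the "nonzero" branch actually ran


def record(x):
    hits.append(float(x))


def report_if_nonzero(x):
    def true_fun(x):
        pass  # x == 0: do nothing

    def false_fun(x):
        jax.debug.callback(record, x)  # runs only when this branch is taken

    # lax.cond picks a branch at runtime from the traced value of x
    jax.lax.cond(x == 0, true_fun, false_fun, x)


@jax.jit
def f(x):
    report_if_nonzero(x)
    return jnp.sin(x)


f(2.0)  # nonzero: callback fires
f(0.0)  # zero: callback does not fire
jax.effects_barrier()  # ensure pending callbacks have run
print(hits)
```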

## Callbacks

For more control over the host-side logic, use `jax.debug.callback()`, which calls an arbitrary Python function with the runtime values of its arguments. It works under transformations such as `jax.jit()`, `jax.vmap()`, and `jax.grad()`:


```python
import logging


def log_x(x):
    logging.warning(f"Logging value: {x}")


@jax.jit
def f(x):
    jax.debug.callback(log_x, x)
    return jnp.cos(x)


f(2.0)
```

```
Array(-0.41614684, dtype=float32, weak_type=True)
```
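Since the callback runs ordinary Python on the host, it can implement checks that are awkward to express inside traced code. A minimal sketch of a NaN detector, assuming the hypothetical helper `check_nan` and list `nan_reports`; `functools.partial` binds the static name so only the array is passed through the callback:

```python
import functools

import jax
import jax.numpy as jnp
import numpy as np

nan_reports = []  # names of intermediates that contained a NaN


def check_nan(name, value):
    # Host-side logic: plain NumPy on the concrete value passed to the callback
    if np.isnan(np.asarray(value)).any():
        nan_reports.append(name)


@jax.jit
def f(x):
    y = jnp.log(x)
    jax.debug.callback(functools.partial(check_nan, "log"), y)
    return y


f(1.0)   # log(1.0) = 0.0, nothing reported
f(-1.0)  # log(-1.0) = nan, reported
jax.effects_barrier()  # wait for pending callbacks
print(nan_reports)
```

For NaN hunting specifically, JAX also provides the `jax_debug_nans` configuration option (`jax.config.update("jax_debug_nans", True)`), which raises an error when an operation produces a NaN.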

Here is the same callback under a `jax.vmap()` transformation:


```python
def log_x(x):
    logging.warning(f"Logging value: {x}")


@jax.jit
def f(x):
    jax.debug.callback(log_x, x)
    return jnp.cos(x)


batch_f = jax.vmap(f, in_axes=0)

xs = jnp.arange(3)
batch_f(xs)
```

```
Array([ 1.        ,  0.5403023 , -0.41614684], dtype=float32)
```
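For tests or notebooks it can be handier to collect values than to log them. A minimal sketch, assuming the hypothetical `collect` helper and `seen` list, that records what the callback receives under `jax.vmap()` and `jax.grad()`:

```python
import jax
import jax.numpy as jnp

seen = []  # values received by the callback, in arrival order


def collect(x):
    seen.append(float(x))


def f(x):
    jax.debug.callback(collect, x)
    return jnp.cos(x)


jax.vmap(f)(jnp.arange(3.0))  # callback is invoked for each mapped element
jax.effects_barrier()
vmapped_seen = list(seen)
print(vmapped_seen)

seen.clear()
g = jax.grad(f)(2.0)  # callback receives the primal input, not a tangent
jax.effects_barrier()
print(seen, g)
```

Under `jax.vmap()` the callback receives each mapped element separately, matching the per-element prints shown earlier; under `jax.grad()` it sees the primal value, while the gradient itself is computed as usual.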