In [1]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '3'

Use the jax.checkpoint() decorator (aliased as jax.remat()) with jax.grad() to control which intermediates are saved on the forward pass and which are recomputed on the backward pass, trading off memory and FLOPs.
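As a minimal sketch of the decorator form, the snippet below wraps a single layer in jax.checkpoint and differentiates a scalar loss with jax.grad. The names layer and loss and the all-ones inputs are illustrative assumptions, not part of the original notebook.

```python
import jax
import jax.numpy as jnp

@jax.checkpoint  # jax.remat is an alias for the same transformation
def layer(W, x):
    # Intermediates of this function are recomputed on the backward pass.
    return jnp.sin(jnp.dot(W, x))

def loss(W, x):
    # Reduce to a scalar so that jax.grad applies.
    return jnp.sum(layer(W, x))

W = jnp.ones((5, 4))
x = jnp.ones(4)

# Gradient with respect to W (the first argument); same shape as W.
grad_W = jax.grad(loss)(W, x)
```

The numerical result matches the undecorated function; only the memory/recompute schedule of the backward pass changes.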

In [2]:
import jax
import jax.numpy as jnp

In [5]:
import jax.ad_checkpoint


def g(W, x):
    y = jnp.dot(W, x)
    return jnp.sin(y)

def f(W1, W2, W3, x):
    x = g(W1, x)
    x = g(W2, x)
    x = g(W3, x)
    return x

W1 = jnp.ones((5, 4))
W2 = jnp.ones((6, 5))
W3 = jnp.ones((7, 6))
x = jnp.ones(4)

jax.ad_checkpoint.print_saved_residuals(f, W1, W2, W3, x)

f32[5,4] from the argument W1
f32[6,5] from the argument W2
f32[7,6] from the argument W3
f32[4] from the argument x
f32[5] output of sin from /tmp/ipykernel_3842389/194814268.py:6 (g)
f32[5] output of cos from /tmp/ipykernel_3842389/194814268.py:6 (g)
f32[6] output of sin from /tmp/ipykernel_3842389/194814268.py:6 (g)
f32[6] output of cos from /tmp/ipykernel_3842389/194814268.py:6 (g)
f32[7] output of cos from /tmp/ipykernel_3842389/194814268.py:6 (g)


In [7]:
import jax.ad_checkpoint


def g(W, x):
    y = jnp.dot(W, x)
    return jnp.sin(y)

def f(W1, W2, W3, x):
    x = jax.checkpoint(g)(W1, x)
    x = jax.checkpoint(g)(W2, x)
    x = jax.checkpoint(g)(W3, x)
    return x

W1 = jnp.ones((5, 4))
W2 = jnp.ones((6, 5))
W3 = jnp.ones((7, 6))
x = jnp.ones(4)

jax.ad_checkpoint.print_saved_residuals(f, W1, W2, W3, x)

f32[5,4] from the argument W1
f32[6,5] from the argument W2
f32[7,6] from the argument W3
f32[4] from the argument x
f32[5] output of sin from /tmp/ipykernel_3842389/2411689849.py:6 (g)
f32[6] output of sin from /tmp/ipykernel_3842389/2411689849.py:6 (g)
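
In the second listing only the inputs to each checkpointed call of g remain as residuals: the weights, x, and the two sin outputs that feed the next call. The cos values are no longer stored. A quick sanity check, sketched below under the assumption of a summed scalar loss (f_plain, f_remat and sum_of are hypothetical names), is that both versions produce the same gradient.

```python
import jax
import jax.numpy as jnp

def g(W, x):
    return jnp.sin(jnp.dot(W, x))

def f_plain(W1, W2, W3, x):
    return g(W3, g(W2, g(W1, x)))

def f_remat(W1, W2, W3, x):
    g_ckpt = jax.checkpoint(g)  # each call saves only its inputs
    return g_ckpt(W3, g_ckpt(W2, g_ckpt(W1, x)))

W1, W2, W3 = jnp.ones((5, 4)), jnp.ones((6, 5)), jnp.ones((7, 6))
x = jnp.ones(4)

def sum_of(fn):
    # Turn the vector-valued network into a scalar loss for jax.grad.
    return lambda *args: jnp.sum(fn(*args))

# jax.grad differentiates with respect to the first argument (W1) by default.
grad_plain = jax.grad(sum_of(f_plain))(W1, W2, W3, x)
grad_remat = jax.grad(sum_of(f_remat))(W1, W2, W3, x)
same = bool(jnp.allclose(grad_plain, grad_remat))
```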


In [8]:
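# Placeholder stages (assumed for illustration), so that f(x) = h(g(x)).
g = lambda x: jnp.sin(x)
h = lambda y: jnp.sum(jnp.cos(y))

def f_checkpoint(x):
  y = jax.checkpoint(g)(x)  # g's intermediates are recomputed on the backward pass
  z = h(y)
  return z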


Here f is a two-stage function, f(x) = h(g(x)). In other words, you apply jax.checkpoint() to g, the first stage of f, rather than to f itself. This way, when you evaluate jax.grad(f_checkpoint)(x), you get a computation like:

    1. Run the forward pass of g, discarding residual values.

    2. Run the forward pass of h, saving residuals.

    3. Run the backward pass of h, consuming residuals from step 2.

    4. Re-run the forward pass of g, saving residuals.

    5. Run the backward pass of g, consuming residuals from step 4.
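
The five steps above can be imitated by hand with jax.vjp. This is only a sketch of the schedule, not how jax.checkpoint is implemented internally, and the concrete g and h used here are hypothetical stand-ins for the two stages.

```python
import jax
import jax.numpy as jnp

def g(x):  # hypothetical first stage
    return jnp.sin(x)

def h(y):  # hypothetical second stage, returns a scalar
    return jnp.sum(y ** 2)

x = jnp.arange(3.0)

# Step 1: forward pass of g, keeping only its output (no residuals).
y = g(x)
# Step 2: forward pass of h; its residuals live inside h_vjp.
z, h_vjp = jax.vjp(h, y)
# Step 3: backward pass of h, consuming the residuals from step 2.
(y_bar,) = h_vjp(jnp.ones_like(z))
# Step 4: re-run the forward pass of g, this time keeping residuals in g_vjp.
_, g_vjp = jax.vjp(g, x)
# Step 5: backward pass of g, consuming the residuals from step 4.
(x_bar,) = g_vjp(y_bar)

# Reference: let jax.checkpoint arrange the same schedule.
def f_checkpoint(x):
    return h(jax.checkpoint(g)(x))

expected = jax.grad(f_checkpoint)(x)
```

Both x_bar and expected hold the gradient of h(g(x)) with respect to x.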


In general, jax.checkpoint(foo) is a new function with the same input-output behavior as foo, but it behaves differently under autodiff, particularly under jax.linearize() and jax.vjp() (and their wrappers, like jax.grad()) but not jax.jvp(). When a jax.checkpoint()-wrapped function is differentiated, only its inputs are stored on the forward pass. On the backward pass, the residuals (the intermediates of foo and the Jacobian coefficient values needed for the backward pass) are recomputed.

Put simply, with jax.checkpoint and reverse-mode autodiff, the forward pass stores only x, the input to foo. The intermediates of foo are then recomputed during the backward pass of jax.grad or jax.vjp.
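
A small check of the "same input-output behavior" claim, as a sketch: the function foo and the input value below are arbitrary assumptions. Outputs, forward-mode tangents and reverse-mode gradients should all agree between foo and jax.checkpoint(foo); only what is stored for the backward pass differs.

```python
import jax
import jax.numpy as jnp

def foo(x):  # arbitrary example function
    return jnp.sin(jnp.cos(x))

foo_ckpt = jax.checkpoint(foo)
x = jnp.float32(0.5)
t = jnp.float32(1.0)  # tangent direction for forward mode

# Same outputs.
same_value = bool(jnp.allclose(foo(x), foo_ckpt(x)))

# Forward mode (jax.jvp) gives the same tangent.
_, tangent_plain = jax.jvp(foo, (x,), (t,))
_, tangent_ckpt = jax.jvp(foo_ckpt, (x,), (t,))
same_tangent = bool(jnp.allclose(tangent_plain, tangent_ckpt))

# Reverse mode (jax.grad) gives the same gradient; only memory use differs.
same_grad = bool(jnp.allclose(jax.grad(foo)(x), jax.grad(foo_ckpt)(x)))
```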



    Without jax.checkpoint(), JAX’s autodiff tends to compute everything possible on the forward pass and store it for the backward pass.

    With a jax.checkpoint() decorator, you instead compute as little as possible on the forward pass and recompute values as needed on the backward pass.


You can also define policies for what to save and what to recompute by passing the policy argument of jax.checkpoint(). A policy can additionally offload selected values to CPU (host) memory once they have been computed, instead of keeping them in accelerator memory.
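
A minimal sketch of a policy, assuming the built-in jax.checkpoint_policies.dots_with_no_batch_dims_saveable, which marks the outputs of matrix multiplications without batch dimensions as saveable so that the cheaper elementwise work (here sin) is recomputed. The two-layer f and the name f_policy are illustrative choices.

```python
import jax
import jax.numpy as jnp
import jax.ad_checkpoint

def g(W, x):
    return jnp.sin(jnp.dot(W, x))

def f(W1, W2, x):
    return g(W2, g(W1, x))

# Save matmul outputs; recompute everything else on the backward pass.
policy = jax.checkpoint_policies.dots_with_no_batch_dims_saveable
f_policy = jax.checkpoint(f, policy=policy)

W1, W2, x = jnp.ones((5, 4)), jnp.ones((6, 5)), jnp.ones(4)

# Inspect which residuals the policy keeps.
jax.ad_checkpoint.print_saved_residuals(f_policy, W1, W2, x)

# Gradients with respect to both weight matrices.
grads = jax.grad(lambda *a: jnp.sum(f_policy(*a)), argnums=(0, 1))(W1, W2, x)
```

The gradients are the same as for the unwrapped f; the policy only changes which intermediates are kept between the forward and backward passes.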

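Note that when the differentiated function is compiled with jax.jit (for example, jax.jit applied to a function that contains a jax.grad call), XLA optimizes the computation automatically, including decisions about which values to save and which to recompute, so jax.checkpoint() is often not needed there.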