The use case is nesting two `custom_vjp` operators to perform reverse-mode AD over reverse-mode AD (to calculate an HVP over a function that uses `custom_vjp` to call into external, untraceable code). I have not used `custom_vjp` before and the documentation didn't seem to cover anything related to this problem, so please forgive me if I'm missing something obvious here!
When nesting two `custom_vjp` operations, the cotangents computed by the second `custom_vjp` do not seem to be propagated through the compute graph. A minimal example is below, using the function `y = x**3`; the computed gradient is 0, which implies that the `ddy` computed in the second `custom_vjp` (in `op_bck_bck`) is never propagated further.
```python
import jax

@jax.custom_vjp
def op(x):
    return x**3

def op_fwd(x):
    # return the output and the saved values
    return op(x), (x,)

@jax.custom_vjp
def op_bck(saved, dy):
    x, = saved
    return 3 * x**2 * dy,  # gradient of y = x^3

def op_bck_fwd(saved, dy):
    return op_bck(saved, dy), saved

def op_bck_bck(saved, ddx):
    x, = saved
    ddx, = ddx
    ddy = 3 * x**2 * ddx
    # return:
    # - dSaved (None as we don't want to diff through saved values)
    # - ddy (use chain rule)
    return (None,), ddy

op_bck.defvjp(op_bck_fwd, op_bck_bck)
op.defvjp(op_fwd, op_bck)

# test it
def jop(x):
    return x**3

xx = jax.random.normal(jax.random.PRNGKey(0), (5,))
ww = jax.random.normal(jax.random.PRNGKey(1), (5,))
ww2 = jax.random.normal(jax.random.PRNGKey(2), (5,))

def fn(op, x):
    def l(x):
        # first vjp
        return jax.grad(lambda z: op(z) @ ww)(x) @ ww2
    # should call second vjp
    return jax.grad(l)(x)

gt_jac = fn(jop, xx)
my_jac = fn(op, xx)

print(gt_jac)  # [-0.16732852 11.739656 -0.06793058 0.4734605 -0.15350063]
print(my_jac)  # [0. 0. 0. 0. 0.]
```
System info (python version, jaxlib version, accelerator, etc.)
- Returning `None` from your second backward function doesn't seem quite right. Even though that argument was a saved residual, we still want to differentiate with respect to it. You can see the VJP with respect to this argument as the expression `dx = 6 * ...` in the `op_bck_bck` of the sketch below.
- In `op_bck`'s custom derivative, it's useful to save the input value `dy` as a residual, since we need it to compute the VJP expression from the previous bullet (`dx = 6 * ...`).
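Here is a minimal sketch of what the corrected inner `custom_vjp` could look like, following the two bullets above. It is reconstructed from those points rather than copied from the original reply: `dy` is saved as a residual, and a real cotangent is returned for the saved `x` instead of `None`.

```python
import jax

@jax.custom_vjp
def op(x):
    return x**3

def op_fwd(x):
    return op(x), (x,)  # residuals: the input x

@jax.custom_vjp
def op_bck(saved, dy):
    x, = saved
    return 3 * x**2 * dy,  # VJP of y = x**3

def op_bck_fwd(saved, dy):
    # save dy as a residual too: the second backward needs it below
    return op_bck(saved, dy), (saved, dy)

def op_bck_bck(res, cts):
    (x,), dy = res
    ddx, = cts
    # differentiate 3 * x**2 * dy with respect to both inputs:
    dx = 6 * x * dy * ddx   # cotangent w.r.t. the saved x, instead of None
    ddy = 3 * x**2 * ddx    # cotangent w.r.t. dy (chain rule)
    return (dx,), ddy

op_bck.defvjp(op_bck_fwd, op_bck_bck)
op.defvjp(op_fwd, op_bck)
```

With these definitions, `fn(op, xx)` from the reproducer above should agree with the reference `fn(jop, xx)`.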
What do you think? Take a look and see if this looks correct to you. If something remains confusing, we can think about how to improve the docs, or consider adding an example somewhere (maybe in a future version of the AD cookbook).