Description
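Nested AD (forward-over-reverse) returns incorrect results under Reactant when the inner reverse-mode gradient is computed inside a loop. In addition, two successive runs of the same compiled function return different results: the first column of the output is wrong and changes between runs, while the second column matches the non-compiled Enzyme result. See the MWE below.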
Environment
- Julia Version 1.11.7
  Commit f2b3dbda30a (2025-09-08 12:10 UTC)
  Build Info:
    Official https://julialang.org/ release
  Platform Info:
    OS: Linux (x86_64-linux-gnu)
    CPU: 32 × 13th Gen Intel(R) Core(TM) i9-13900K
    WORD_SIZE: 64
    LLVM: libLLVM-16.0.6 (ORCJIT, alderlake)
    Threads: 32 default, 0 interactive, 16 GC (on 32 virtual cores)
- Reactant version: [3c362404] Reactant v0.2.169
- Enzyme version: [7da242da] Enzyme v0.13.85
Minimal Reproducible Example
```julia
using Reactant
using Enzyme

# Elementwise cube
function cubic(x)
    y = x .^ 3
    return y
end

# Reverse mode: one vector-Jacobian product (VJP) per column of `lambdas`,
# computed inside a loop
function vjp_cubic(x, lambdas)
    vjps = similar(lambdas)
    for i = 1:size(lambdas, 2)
        lambda = lambdas[:, i]
        eval_jac_T_v(x) = sum(cubic(x) .* lambda)
        vjps[:, i] .= Enzyme.gradient(Reverse, Const(eval_jac_T_v), x)[1]
    end
    return vjps
end

# Forward mode over the reverse-mode VJPs: directional derivative of
# `vjp_cubic` with respect to `x` in direction `v`
function jvp_vjp_cubic(v, x, lambdas)
    vjp_cubic_inline(x) = vjp_cubic(x, lambdas)
    return Enzyme.autodiff(Forward, Const(vjp_cubic_inline), Duplicated(x, v))[1]
end

x = ones(3)
x_r = Reactant.to_rarray(x)
v = ones(3)
v_r = Reactant.to_rarray(v)
lambdas = ones(3, 2)
lambdas_r = Reactant.to_rarray(lambdas)

jvp_vjp_cubic(v, x, lambdas)
#= Enzyme (non-compiled) gives the correct result:
3×2 Matrix{Float64}:
 6.0  6.0
 6.0  6.0
 6.0  6.0
=#

jvp_vjp_cubic_compiled = @compile jvp_vjp_cubic(v_r, x_r, lambdas_r)

# First run
jvp_vjp_cubic_compiled(v_r, x_r, lambdas_r)
#=
3×2 ConcretePJRTArray{Float64,2}:
 3.0  6.0
 3.0  6.0
 3.0  6.0
=#

# Second run
jvp_vjp_cubic_compiled(v_r, x_r, lambdas_r)
#=
3×2 ConcretePJRTArray{Float64,2}:
 0.0  6.0
 0.0  6.0
 0.0  6.0
=#
```
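For reference, the expected values can be checked by hand. This is a short derivation assuming the elementwise operations exactly as written in the MWE, with $\lambda$ denoting one column `lambdas[:, i]` and $\odot$ the elementwise product:

```math
\begin{aligned}
g_\lambda(x) &= \sum_j \lambda_j x_j^3, \\
\nabla g_\lambda(x) &= 3\, x^{\odot 2} \odot \lambda, \\
\left.\frac{d}{dt}\, \nabla g_\lambda(x + t v)\right|_{t=0} &= 6\, x \odot v \odot \lambda .
\end{aligned}
```

With $x = v = \lambda = (1, 1, 1)^\top$, every entry of every column is $6$, which agrees with the non-compiled Enzyme output and with the second column of the compiled output, but not with the first column of either compiled run.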