Hello,
I needed to compute the gradient of a gradient, and I was able to run the following code without error using StatefulLuxLayer, as described in the "Nested Automatic Differentiation" chapter of the documentation. However, when I checked the result against FiniteDiff.jl, as the documentation does, the two gradients do not agree.
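To state explicitly what I am trying to compute (my own reading of the setup, in case I have framed it wrong): for a vector v that does not depend on x, the quantity is

    ∇_x sum(∇_x ⟨v, smodel(x)⟩)

i.e. the gradient, with respect to x, of the sum of a vector-Jacobian product of the model.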
Code Example:
using Lux, Random, LinearAlgebra, Zygote, FiniteDiff

rng = Xoshiro(0)
model = Dense(3 => 3) # very simple model
ps, st = Lux.setup(rng, model)
x = ones32(rng, 3, 1)

grads_1 = Zygote.gradient(x) do x
    smodel = StatefulLuxLayer(model, ps, st)
    v = randn!(rng, similar(x))
    w = Zygote.gradient((z -> dot(v, z)) ∘ smodel, x) |> only
    return sum(w)
end |> only # shows [0.0; 0.0; 0.0;;]
grads_2 = FiniteDiff.finite_difference_gradient(x) do x
    smodel = StatefulLuxLayer(model, ps, st)
    v = randn!(rng, similar(x))
    w = Zygote.gradient((z -> dot(v, z)) ∘ smodel, x) |> only
    return sum(w)
end # shows [-260.231; 114.75144; 21.052755;;]
Is this the expected behavior of the package? I would be glad to know if anything in my code needs to be corrected.
Thank you.