using Zygote, ForwardDiff, SciMLSensitivity, SciMLBase, LinearSolve, ComponentArrays,
    FiniteDiff

# Loss computed through LinearSolve.jl
function loss_function(θ)
    (; A, b) = θ
    prob = LinearProblem(A, b)
    sol = solve(prob, nothing)
    return sum(sol.u)
end

# Same loss computed with the backslash operator
function loss_function_chainrules(θ)
    (; A, b) = θ
    x = A \ b
    return sum(x)
end

A = Float32[1 0; 1 -2]; b = Float32[32; -4];
θ = ComponentArray(; A, b)
loss_function(θ) ≈ loss_function_chainrules(θ) # true
Zygote.gradient(loss_function, θ) # fails
Zygote.gradient(loss_function_chainrules, θ) # works
ForwardDiff.gradient(loss_function, θ) # fails
ForwardDiff.gradient(loss_function_chainrules, θ) # works
FiniteDiff.finite_difference_gradient(loss_function, θ) # works
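A minimal non-working example: differentiating through solve on a LinearProblem fails with both Zygote and ForwardDiff, while the equivalent A \ b version works.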
We need SciML/LinearSolve.jl#322 before we can work on the adjoint code.
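
For context, here is a rough sketch of what the reverse pass for a dense linear solve computes. It is not the SciMLSensitivity or LinearSolve implementation. It uses only the standard identities for x = A \ b: with cotangent x̄, solve Aᵀλ = x̄, then b̄ = λ and Ā = -λxᵀ. The helper names (`loss_and_gradient`, `fd_gradient`) and the Float64 2×2 system are assumptions chosen for illustration.

```julia
using LinearAlgebra

# Loss from the example, written with the backslash solve.
loss(A, b) = sum(A \ b)

# Hand-written reverse pass for x = A \ b followed by sum(x).
# Hypothetical helper, for illustration only.
function loss_and_gradient(A, b)
    x = A \ b
    xbar = ones(eltype(x), length(x))  # cotangent of x for sum(x)
    λ = A' \ xbar                      # adjoint solve: Aᵀ λ = x̄
    bbar = λ                           # gradient with respect to b
    Abar = -λ * x'                     # gradient with respect to A
    return sum(x), Abar, bbar
end

# Central finite differences over a flat vector, used only as a check.
function fd_gradient(f, v; h = 1e-6)
    g = similar(v)
    for i in eachindex(v)
        e = zero(v)
        e[i] = h
        g[i] = (f(v + e) - f(v - e)) / (2h)
    end
    return g
end

# Illustrative system (assumed values, Float64 for a tighter check).
A = [1.0 0.0; 1.0 -2.0]
b = [32.0, -4.0]

val, Abar, bbar = loss_and_gradient(A, b)

fd_b = fd_gradient(bb -> loss(A, bb), b)
fd_A = reshape(fd_gradient(a -> loss(reshape(a, size(A)), b), vec(A)), size(A))

println(isapprox(fd_b, bbar; atol = 1e-4))
println(isapprox(fd_A, Abar; atol = 1e-4))
```

The only extra cost over the forward solve is one solve with the transposed matrix, so an adjoint rule can in principle reuse a factorization of A where the solver exposes one.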