Open
Labels: second order (zygote over zygote, or otherwise)
Description
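Instead of doing the harder thing of making `Dual` numbers work in general, why not target the higher-level API and give Zygote adjoints for `ForwardDiff.derivative`, `ForwardDiff.gradient`, and `ForwardDiff.jacobian` themselves? I played around with this a bit and didn't quite get it working, but someone might want to take it from here: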
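```julia
using ForwardDiff, Zygote, ZygoteRules, Test

# Adjoint for x -> ForwardDiff.derivative(f, x) = f'(x).
# Forward-over-reverse: ForwardDiff differentiates _f(y) = Δ * f'(y),
# where Δ * f'(y) comes from a Zygote pullback of f at y.
ZygoteRules.@adjoint function ForwardDiff.derivative(f, x)
    der = ForwardDiff.derivative(f, x)
    function derivative_adjoint(Δ)
        function _f(y)
            out, back = Zygote.pullback(f, y)
            back(Δ)[1]
        end
        # no cotangent for f itself, only for x
        (nothing, ForwardDiff.derivative(_f, x))
    end
    der, derivative_adjoint
end

# Same pattern for ForwardDiff.gradient (scalar-valued f)
ZygoteRules.@adjoint function ForwardDiff.gradient(f, x)
    grad = ForwardDiff.gradient(f, x)
    function gradient_adjoint(Δ)
        function _f(y)
            out, back = Zygote.pullback(f, y)
            back(Δ)[1]
        end
        (nothing, ForwardDiff.gradient(_f, x))
    end
    grad, gradient_adjoint
end

# Same pattern for ForwardDiff.jacobian (vector-valued f);
# vec flattens the pullback result before ForwardDiff.jacobian is applied
ZygoteRules.@adjoint function ForwardDiff.jacobian(f, x)
    jac = ForwardDiff.jacobian(f, x)
    function jacobian_adjoint(Δ)
        function _f(y)
            out, back = Zygote.pullback(f, y)
            vec(back(Δ)[1])
        end
        (nothing, ForwardDiff.jacobian(_f, x))
    end
    jac, jacobian_adjoint
end
```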
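For reference, the standard reverse-mode identities give the cotangent each adjoint should return for `x`. Here `Δ` is the incoming cotangent and has the shape of the wrapped call's output: a scalar for `derivative`, a length-n vector for `gradient` (with f: ℝⁿ → ℝ), and an m×n matrix for `jacobian` (with f: ℝⁿ → ℝᵐ). In the gradient and Jacobian cases `Δ` is therefore shaped like ∇f or J_f rather than like the output of `f`, so it does not match what `back` from `Zygote.pullback(f, y)` expects.

```math
\begin{aligned}
g(x) &= f'(x), & \bar{x} &= \Delta \, f''(x) \\
g(x) &= \nabla f(x), & \bar{x} &= \nabla_x \langle \Delta, \nabla f(x) \rangle = \nabla^2 f(x) \, \Delta \\
g(x) &= J_f(x), & \bar{x}_k &= \frac{\partial}{\partial x_k} \sum_{i,j} \Delta_{ij} \, J_f(x)_{ij} = \sum_{i,j} \Delta_{ij} \, \frac{\partial^2 f_i}{\partial x_j \, \partial x_k}
\end{aligned}
```

The middle line uses the symmetry of the Hessian, which holds when f is twice continuously differentiable.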
```julia
# Scalar case: Zygote over ForwardDiff.derivative
f(x) = 2x^2 + x
g(x) = ForwardDiff.derivative(f, x)
out, back = Zygote.pullback(g, 2.0)
stakehouse = back(1)[1]
@test typeof(stakehouse) <: Float64
@test stakehouse ≈ ForwardDiff.derivative(g, 2.0)

# Vector case: Zygote over ForwardDiff.jacobian, with three different reductions
f(x) = [2x[1]^2 + x[1], x[2]^2 * x[1]]

g(x) = sum(ForwardDiff.jacobian(f, x))
out, back = Zygote.pullback(g, [2.0, 3.2])
stakehouse = back(1.0)[1]
@test typeof(stakehouse) <: Vector
@test size(stakehouse) == (2,)
@test stakehouse ≈ ForwardDiff.gradient(g, [2.0, 3.2])

g(x) = prod(ForwardDiff.jacobian(f, x))
out, back = Zygote.pullback(g, [2.0, 3.2])
stakehouse = back(1.0)[1]
@test typeof(stakehouse) <: Vector
@test size(stakehouse) == (2,)
@test stakehouse ≈ ForwardDiff.gradient(g, [2.0, 3.2])

g(x) = sum(abs2, ForwardDiff.jacobian(f, x))
out, back = Zygote.pullback(g, [2.0, 3.2])
stakehouse = back(1.0)[1]
@test typeof(stakehouse) <: Vector
@test size(stakehouse) == (2,)
@test stakehouse ≈ ForwardDiff.gradient(g, [2.0, 3.2])
```
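One possible way to take it from here, sketched under assumptions and not checked against any particular Zygote or ForwardDiff release: swap the forward-over-reverse approach for forward-over-forward, computing the cotangent of `x` with nested ForwardDiff calls and contracting `Δ` against the inner `ForwardDiff.gradient`/`ForwardDiff.jacobian` output instead of passing it to `back`. That keeps `Δ` paired with something of its own shape. The test functions `q`, `h`, and `r` are hypothetical examples.

```julia
using ForwardDiff, Zygote, ZygoteRules

# Sketch: forward-over-forward adjoints. Δ is the cotangent of the
# *output* of the wrapped ForwardDiff call, so it is contracted with
# that output (computed on nested Duals) before differentiating in x.

# Returns Δ * f''(x)
ZygoteRules.@adjoint function ForwardDiff.derivative(f, x)
    derivative_adjoint(Δ) =
        (nothing, ForwardDiff.derivative(y -> Δ * ForwardDiff.derivative(f, y), x))
    return ForwardDiff.derivative(f, x), derivative_adjoint
end

# Δ is shaped like ∇f(x); returns ∇ₓ ⟨Δ, ∇f(x)⟩ = ∇²f(x) * Δ
ZygoteRules.@adjoint function ForwardDiff.gradient(f, x)
    gradient_adjoint(Δ) =
        (nothing, ForwardDiff.gradient(y -> sum(Δ .* ForwardDiff.gradient(f, y)), x))
    return ForwardDiff.gradient(f, x), gradient_adjoint
end

# Δ is shaped like the m×n Jacobian; returns ∇ₓ Σᵢⱼ Δᵢⱼ Jᵢⱼ(x)
ZygoteRules.@adjoint function ForwardDiff.jacobian(f, x)
    jacobian_adjoint(Δ) =
        (nothing, ForwardDiff.gradient(y -> sum(Δ .* ForwardDiff.jacobian(f, y)), x))
    return ForwardDiff.jacobian(f, x), jacobian_adjoint
end

# Hypothetical test functions
q(x) = 2x^2 + x                          # scalar -> scalar
h(x) = x[1]^2 * x[2]                     # vector -> scalar
r(x) = [2x[1]^2 + x[1], x[2]^2 * x[1]]   # vector -> vector

# Zygote over each ForwardDiff entry point
d_der = Zygote.gradient(x -> ForwardDiff.derivative(q, x), 2.0)[1]
d_grad = Zygote.gradient(x -> sum(abs2, ForwardDiff.gradient(h, x)), [1.0, 2.0])[1]
d_jac = Zygote.gradient(x -> sum(ForwardDiff.jacobian(r, x)), [2.0, 3.2])[1]
```

Forward-over-forward is likely to scale worse with the length of `x` than a forward-over-reverse version, but it only relies on ForwardDiff's support for nested `Dual` numbers, which `ForwardDiff.hessian` already uses.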