
Target ForwardDiff's public API? #769

@ChrisRackauckas


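Instead of doing the harder thing of making Zygote work with ForwardDiff's Dual numbers in general, why not target ForwardDiff's higher-level API, i.e. define Zygote adjoints for ForwardDiff.derivative, ForwardDiff.gradient and ForwardDiff.jacobian directly? I played around a bit and didn't quite get it working, but someone might want to take it from here: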

using ForwardDiff, Zygote, ZygoteRules, FiniteDiff, Test, Adapt

ZygoteRules.@adjoint function ForwardDiff.derivative(f,x)
    der = ForwardDiff.derivative(f,x)
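    function derivative_adjoint(Δ)
        # _f(y) = <Δ, f'(y)>, computed by a Zygote pullback of f at y.
        # Differentiating _f with ForwardDiff at x gives <Δ, f''(x)> (forward over reverse).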
        function _f(y)
            out,back = Zygote.pullback(f,y)
            back(Δ)[1]
        end
        (nothing,ForwardDiff.derivative(_f,x))
    end
    der, derivative_adjoint
end

ZygoteRules.@adjoint function ForwardDiff.gradient(f,x)
    grad = ForwardDiff.gradient(f,x)
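    function gradient_adjoint(Δ)
        # Δ has the shape of x, so it cannot seed the pullback of the scalar-valued f.
        # Instead, _f(y) = ∇f(y)ᵀΔ is a scalar whose gradient at x is H(x)Δ,
        # the Hessian-vector product this adjoint needs (forward over reverse).
        function _f(y)
            out,back = Zygote.pullback(f,y)
            sum(back(one(out))[1] .* Δ)
        end
        (nothing,ForwardDiff.gradient(_f,x))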
    end
    grad, gradient_adjoint
end

ZygoteRules.@adjoint function ForwardDiff.jacobian(f,x)
    jac = ForwardDiff.jacobian(f,x)
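    function jacobian_adjoint(Δ)
        # Δ has the shape of the m×n Jacobian while f returns a length-m vector,
        # so seed the pullback one column at a time: back(Δ[:, j])[1] == J(y)ᵀΔ[:, j].
        # Keeping entry j of each and summing gives Σᵢⱼ Δᵢⱼ J(y)ᵢⱼ, a scalar whose
        # gradient at x is the cotangent for x.
        function _f(y)
            out,back = Zygote.pullback(f,y)
            sum(j -> back(Δ[:, j])[1][j], axes(Δ, 2))
        end
        (nothing,ForwardDiff.gradient(_f,x))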
    end
    jac, jacobian_adjoint
end
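For reference, this is what each rule has to return, writing Δ for the incoming cotangent (same shape as the rule's output), x̄ for the cotangent passed back to x, H_f for the Hessian and J_f for the Jacobian of f. This is standard reverse-mode calculus rather than something stated in the issue; the point is that every right-hand side is the gradient (or derivative) of a scalar built from a single pullback of f, which is what makes ForwardDiff-over-Zygote applicable.

```latex
\begin{aligned}
\texttt{derivative}:\quad & g(x) = f'(x), & \bar{x} &= \tfrac{d}{dx}\,\langle \Delta, f'(x) \rangle = \langle \Delta, f''(x) \rangle \\
\texttt{gradient}:\quad & g(x) = \nabla f(x), & \bar{x} &= \nabla_x \big( \nabla f(x)^{\top} \Delta \big) = H_f(x)\, \Delta \\
\texttt{jacobian}:\quad & g(x) = J_f(x) \in \mathbb{R}^{m \times n}, & \bar{x} &= \nabla_x \sum_{i,j} \Delta_{ij}\, J_f(x)_{ij} = \nabla_x \sum_{j=1}^{n} \big( J_f(x)^{\top} \Delta_{:,j} \big)_j
\end{aligned}
```

The last equality uses the fact that a pullback of f seeded with a vector v returns J_fᵀ v, so each column of Δ needs its own seed.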

f(x) = 2x^2 + x
g(x) = ForwardDiff.derivative(f,x)
out,back = Zygote.pullback(g,2.0)
stakehouse = back(1.0)[1]
@test typeof(stakehouse) <: Float64
@test stakehouse ≈ ForwardDiff.derivative(g,2.0)

f(x) = [2x[1]^2 + x[1],x[2]^2 * x[1]]
g(x) = sum(ForwardDiff.jacobian(f,x))
out,back = Zygote.pullback(g,[2.0,3.2])
stakehouse = back(1.0)[1]
@test typeof(stakehouse) <: Vector
@test size(stakehouse) == (2,)
@test stakehouse ≈ ForwardDiff.gradient(g,[2.0,3.2])

g(x) = prod(ForwardDiff.jacobian(f,x))
out,back = Zygote.pullback(g,[2.0,3.2])
stakehouse = back(1.0)[1]
@test typeof(stakehouse) <: Vector
@test size(stakehouse) == (2,)
@test stakehouse ≈ ForwardDiff.gradient(g,[2.0,3.2])

g(x) = sum(abs2,ForwardDiff.jacobian(f,x))
out,back = Zygote.pullback(g,[2.0,3.2])
stakehouse = back(1.0)[1]
@test typeof(stakehouse) <: Vector
@test size(stakehouse) == (2,)
@test stakehouse ≈ ForwardDiff.gradient(g,[2.0,3.2])
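For the f used in the Jacobian tests, the expected gradients can also be worked out by hand, which gives a reference that does not depend on either package's second-order support. A minimal sketch, assuming only ForwardDiff and the Test stdlib; the *_exact helper names are illustrative and not part of the issue. It checks the closed forms against forward-over-forward ForwardDiff at the same point the tests use:

```julia
using ForwardDiff, Test

f(x) = [2x[1]^2 + x[1], x[2]^2 * x[1]]

# Jacobian of f, worked out by hand:
#   J(x) = [4x₁+1   0
#           x₂²     2x₁x₂]
J_exact(x) = [4x[1]+1 0; x[2]^2 2x[1]*x[2]]

# Gradients of the three scalar test functions, also by hand.
# sum(J)      = 4x₁ + 1 + x₂² + 2x₁x₂
grad_sum_exact(x) = [4 + 2x[2], 2x[2] + 2x[1]]
# prod(J)     ≡ 0, because J[1,2] is identically zero
grad_prod_exact(x) = zeros(2)
# sum(abs2,J) = (4x₁+1)² + x₂⁴ + 4x₁²x₂²
grad_sumabs2_exact(x) = [8(4x[1] + 1) + 8x[1]*x[2]^2, 4x[2]^3 + 8x[1]^2*x[2]]

x = [2.0, 3.2]
@test ForwardDiff.jacobian(f, x) ≈ J_exact(x)
# Forward-over-forward reference values for the second-order quantities
@test ForwardDiff.gradient(y -> sum(ForwardDiff.jacobian(f, y)), x) ≈ grad_sum_exact(x)
@test ForwardDiff.gradient(y -> prod(ForwardDiff.jacobian(f, y)), x) ≈ grad_prod_exact(x)
@test ForwardDiff.gradient(y -> sum(abs2, ForwardDiff.jacobian(f, y)), x) ≈ grad_sumabs2_exact(x)
```

Because J[1,2] is identically zero for this f, prod(J) is identically zero as well, so the prod test only checks that the rule returns a zero gradient; a Jacobian without structural zeros would exercise that path more thoroughly.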

Labels: second order (zygote over zygote, or otherwise)
