In a recent Slack discussion, @mohamed82008 posted this useful code snippet that shouldn't go to waste.
With a little bit of work, this could be turned into a macro that automatically translates a `ChainRulesCore.frule` into a method of the same function that accepts a vector of `ForwardDiff.Dual` numbers, so that `ForwardDiff.gradient` and `ForwardDiff.jacobian` use the rule instead of differentiating through the function body.
Since ChainRulesCore is a very light dependency, would it make sense to include such a macro in ForwardDiff? Judging by the reactions in the Slack #autodiff channel, lots of people would find it useful.
```julia
using ChainRulesCore, ForwardDiff

macro ForwardDiff_frule(f)
    quote
        function $(esc(f))(x::Vector{<:ForwardDiff.Dual{T}}) where {T}
            # Split the duals into primal values and an n×N matrix of partials (row i = partials of x[i])
            xv, Δx = ForwardDiff.value.(x), reduce(vcat, transpose.(ForwardDiff.partials.(x)))
            # Push all N tangent directions through the frule in a single call
            out, Δf = ChainRulesCore.frule((NoTangent(), Δx), $(esc(f)), xv)
            if out isa Real
                # Scalar output: Δf holds the N partials of out
                return ForwardDiff.Dual{T}(out, ForwardDiff.Partials(Tuple(Δf)))
            elseif out isa Vector
                # Vector output: row i of Δf holds the N partials of out[i]
                return ForwardDiff.Dual{T}.(out, ForwardDiff.Partials.(Tuple.(eachrow(Δf))))
            else
                error("Unsupported output.")
            end
        end
    end
end

f1(x) = sum(x)
function ChainRulesCore.frule((_, Δx), ::typeof(f1), x::AbstractVector{<:Number})
    println("frule was used")
    return f1(x), sum(Δx, dims = 1)
end

f2(x) = x
function ChainRulesCore.frule((_, Δx), ::typeof(f2), x::AbstractVector{<:Number})
    println("frule was used")
    return f2(x), Δx
end

@ForwardDiff_frule f1
ForwardDiff.gradient(f1, rand(3))
# frule was used
# 3-element Vector{Float64}:
#  1.0
#  1.0
#  1.0

@ForwardDiff_frule f2
ForwardDiff.jacobian(f2, rand(3))
# frule was used
# 3×3 Matrix{Float64}:
#  1.0  0.0  0.0
#  0.0  1.0  0.0
#  0.0  0.0  1.0
```
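The macro relies on one layout convention: the tangent passed to `frule` is an n×N matrix whose row i holds the N partials of `x[i]`, and the rule must return its tangent in the same layout (N partials for a scalar output, e.g. the 1×N row that `f1`'s rule returns; an m×N matrix for a vector output). A minimal sketch of that decomposition, assuming the inputs are seeded with one-hot partials as when all 3 inputs are differentiated in one chunk; `DemoTag` is a hypothetical tag type used only for this illustration:

```julia
using ForwardDiff

struct DemoTag end  # hypothetical tag, stands in for the tag ForwardDiff generates internally

vals = [1.0, 2.0, 3.0]
# One-hot seeds: x[i] carries the partials vector e_i (N = 3 directions)
x = [ForwardDiff.Dual{DemoTag}(v, ForwardDiff.Partials(ntuple(j -> j == i ? 1.0 : 0.0, 3)))
     for (i, v) in enumerate(vals)]

# Same decomposition as in the macro: primal values plus an n×N partials matrix
xv = ForwardDiff.value.(x)
Δx = reduce(vcat, transpose.(ForwardDiff.partials.(x)))

# For f1(x) = sum(x), the rule's tangent sum(Δx, dims = 1) is a 1×N row, i.e. the gradient
Δf1 = sum(Δx, dims = 1)
```

With one-hot seeds `Δx` is the identity, which is why `f2`'s rule can return `Δx` unchanged and still yield the identity Jacobian shown above.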