Automatic ChainRules compatibility #579

@gdalle

In a recent Slack discussion, @mohamed82008 posted the code snippet below, which is too useful to go to waste.
With a little work, it could be turned into a macro that automatically translates a ChainRulesCore.frule into its ForwardDiff.Dual-compatible counterpart.
Since ChainRulesCore is a very light dependency, would it make sense to include something like this in ForwardDiff? Judging by the reactions on the Slack #autodiff channel, lots of people would find it useful.

using ChainRulesCore, ForwardDiff

macro ForwardDiff_frule(f)
    quote
        # Overload `f` for vectors of dual numbers that share the tag `T`.
        function $(esc(f))(x::Vector{<:ForwardDiff.Dual{T}}) where {T}
            # Primal values, plus an (inputs × partials) matrix of input tangents.
            xv = ForwardDiff.value.(x)
            Δx = reduce(vcat, transpose.(ForwardDiff.partials.(x)))
            out, Δf = ChainRulesCore.frule((NoTangent(), Δx), $(esc(f)), xv)
            # Repack the primal output and its tangent into dual numbers.
            if out isa Real
                return ForwardDiff.Dual{T}(out, ForwardDiff.Partials(Tuple(Δf)))
            elseif out isa Vector
                return ForwardDiff.Dual{T}.(out, ForwardDiff.Partials.(Tuple.(eachrow(Δf))))
            else
                error("Unsupported output.")
            end
        end
    end
end

f1(x) = sum(x)
function ChainRulesCore.frule((_, Δx), ::typeof(f1), x::AbstractVector{<:Number})
  println("frule was used")
  return f1(x), sum(Δx, dims = 1)
end

f2(x) = x
function ChainRulesCore.frule((_, Δx), ::typeof(f2), x::AbstractVector{<:Number})
  println("frule was used")
  return f2(x), Δx
end

@ForwardDiff_frule f1
ForwardDiff.gradient(f1, rand(3))
# frule was used
# 3-element Vector{Float64}:
#  1.0
#  1.0
#  1.0

@ForwardDiff_frule f2
ForwardDiff.jacobian(f2, rand(3))
# frule was used
# 3×3 Matrix{Float64}:
#  1.0  0.0  0.0
#  0.0  1.0  0.0
#  0.0  0.0  1.0
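
To make the data layout concrete, here is a small sketch of what the macro body does by hand for a two-element input. `DemoTag` is a hypothetical tag type introduced only for this illustration (ForwardDiff normally generates its own tags), and the tangent passed back is the one the `f1` rule above would return.

```julia
using ForwardDiff

# Hypothetical tag type, used only for this illustration.
struct DemoTag end

# Two dual numbers seeded the way ForwardDiff.gradient would seed them:
# each input carries a unit partial in its own slot.
x = [ForwardDiff.Dual{DemoTag}(1.0, 1.0, 0.0),
     ForwardDiff.Dual{DemoTag}(2.0, 0.0, 1.0)]

# Unpack: primal values, and one row of partials per input.
xv = ForwardDiff.value.(x)                                # [1.0, 2.0]
Δx = reduce(vcat, transpose.(ForwardDiff.partials.(x)))   # [1.0 0.0; 0.0 1.0]

# What the f1 rule returns for this input: the primal sum and a 1×2 tangent.
out, Δf = sum(xv), sum(Δx, dims = 1)

# Repack into a single dual number, as in the `out isa Real` branch.
d = ForwardDiff.Dual{DemoTag}(out, ForwardDiff.Partials(Tuple(Δf)))
```

Each row of `Δx` holds the partials of one input, so a rule that sums over `dims = 1` yields one tangent entry per seeded direction, which is what `ForwardDiff.Partials` expects.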
