Description
The documentation currently describes how to add custom derivative definitions using DiffRules. However, this seems to cover only basic custom derivatives. For example, I don't see how this approach would support a black-box function (e.g. an external binary) that returns both the function value and the gradient, without having to call it twice. I would therefore like to implement a method for `Dual` directly.
Below is my attempt at a working example. Is this the recommended way to do it? I put it together after reading the source code referred to in this issue. It would be nice to cover the ℝⁿ → ℝⁿ case as well. Perhaps this example can serve as a starting point for documenting this way of creating custom derivatives.
using ForwardDiff
module MyModule
using ForwardDiff
using LinearAlgebra
# ℝ → ℝ ——————————————————————————————————————————————
"Original function"
f0(x) = x^2
"""Returns f0(x) and its derivative
In actual usage, this will be a function ForwardDiff cannot differentiate through,
e.g. because it calls an external binary.
"""
fg(x) = (v=f0(x),d=2x)
"test function - calls `fg`"
f(x) = fg(x).v
"Custom derivative for `f` using `fg`"
function f(d::ForwardDiff.Dual{T}) where T
    x = ForwardDiff.value(d)
    y = fg(x)
    ForwardDiff.Dual{T}(y.v, y.d * ForwardDiff.partials(d))
end
# ℝⁿ → ℝ ——————————————————————————————————————————————
f_vs0(x) = x[1]+x[2]^2
fg_vs(x) = (v=f_vs0(x),d=[1.0,2x[2]])
f_vs(x) = fg_vs(x).v
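function f_vs(d::Vector{D}) where D<:ForwardDiff.Dual
    # Strip the dual parts and call the black box once on plain values
    x = ForwardDiff.value.(d)
    y = fg_vs(x)
    # Regroup the partials: one tuple per perturbation direction,
    # each holding that direction's partial for every input
    b_in = zip(collect.(ForwardDiff.partials.(d))...)
    # Chain rule: output partial k is gradient ⋅ (k-th partial of each input)
    b_arr = map(col -> y.d ⋅ col, b_in)
    p = ForwardDiff.Partials((b_arr...,))
    D(y.v, p)
end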
end
## Testing
# ℝ → ℝ
x = 2.3
b = ForwardDiff.derivative(MyModule.f,x)
display(b == ForwardDiff.derivative(MyModule.f0,x))
# ℝⁿ → ℝ
v = [-0.3,3.4]
g = ForwardDiff.gradient(MyModule.f_vs,v)
display(g ≈ ForwardDiff.gradient(MyModule.f_vs0,v))
## ℝ → ℝⁿ → ℝ
f = x->MyModule.f_vs([-1.2x,-0.5/x])
x2 = 0.7
b2 = ForwardDiff.derivative(f,x2)
f0 = x->MyModule.f_vs0([-1.2x,-0.5/x])
display(b2 ≈ ForwardDiff.derivative(f0,x2))
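For the ℝⁿ → ℝⁿ case mentioned above, the same pattern might extend as in the sketch below. It assumes the black box returns the value vector and the full Jacobian in one call. The names `fj_vv` and `f_vv` and the toy function are hypothetical, chosen only for illustration, and this is not a confirmed or recommended ForwardDiff pattern, just one way the chain rule could be applied by hand.

```julia
using ForwardDiff

# Hypothetical black box: returns the value and the Jacobian together.
# Stand-in for something ForwardDiff cannot differentiate through.
fj_vv(x) = (v = [x[1] * x[2], x[1] + x[2]^2],
            J = [x[2] x[1]; 1.0 2x[2]])

# Plain evaluation, used for non-dual input
f_vv(x) = fj_vv(x).v

# Custom derivative: a single call to `fj_vv`, then the chain rule by hand
function f_vv(d::Vector{ForwardDiff.Dual{T,V,N}}) where {T,V,N}
    x = ForwardDiff.value.(d)
    y = fj_vv(x)
    # P[j, k] is the k-th partial carried by input j
    P = [ForwardDiff.partials(d[j], k) for j in eachindex(d), k in 1:N]
    # Q[i, k] is the k-th partial of output i
    Q = y.J * P
    [ForwardDiff.Dual{T}(y.v[i], ForwardDiff.Partials(ntuple(k -> Q[i, k], N)))
     for i in eachindex(y.v)]
end

# Compare against the Jacobian reported by the black box itself
v = [-0.3, 3.4]
J = ForwardDiff.jacobian(f_vv, v)
display(J ≈ fj_vv(v).J)
```

Building the dense matrix `P` keeps the sketch short at the cost of some allocation; whether that matters depends on how expensive the black-box call is by comparison.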