
Recommended method for creating custom derivatives / gradients using Duals? #413

@sdewaele

The documentation currently describes how to add custom derivative definitions using DiffRules. However, this approach seems to cover only basic custom derivatives. For example, I don't see how it would support a black-box function (e.g. an external binary) that returns both the function value and the gradient, without calling it twice. I would therefore like to implement a method for Dual directly.
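To make the "without calling it twice" requirement concrete, here is a minimal sketch that counts how often a black box is invoked during ForwardDiff.gradient. The names blackbox, NCALLS and h are hypothetical, and the test function x[1]*x[2] + sin(x[1]) merely stands in for an external binary that returns its value and gradient together.

```julia
using ForwardDiff

# Hypothetical call counter standing in for the cost of an external binary.
const NCALLS = Ref(0)

# Hypothetical black box: returns the value and the gradient of
# x[1]*x[2] + sin(x[1]) in a single call, as an external tool might.
function blackbox(x::Vector{Float64})
    NCALLS[] += 1
    v = x[1] * x[2] + sin(x[1])
    g = [x[2] + cos(x[1]), x[1]]
    return v, g
end

# Plain Float64 entry point.
h(x::Vector{Float64}) = first(blackbox(x))

# Dual entry point: one black-box call supplies both value and gradient;
# the chain rule contracts that gradient with the partials of the inputs.
function h(d::AbstractVector{D}) where {D<:ForwardDiff.Dual}
    v, g = blackbox(ForwardDiff.value.(d))
    p = sum(g[i] * ForwardDiff.partials(d[i]) for i in eachindex(d))
    return D(v, p)
end

NCALLS[] = 0
grad = ForwardDiff.gradient(h, [1.0, 2.0])
# A 2-element input fits in a single chunk, so blackbox runs once here.
```

For inputs longer than ForwardDiff's chunk size, the gradient is evaluated in several chunks, so the black box would be called once per chunk rather than once in total.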

Below is my attempt at an example of this approach. Is this the recommended way to do it? I put it together after reading the source code referred to in this issue. It would be nice to cover the ℝⁿ → ℝⁿ case as well. Perhaps this example can serve as a starting point for documenting this way of creating custom derivatives.
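For reference, the chain rule that each custom Dual method has to implement can be written out as follows. The notation is introduced here for illustration: $\dot{x}_{i,k}$ is the $k$-th partial carried by input $i$, $\dot{y}_{j,k}$ the $k$-th partial of output $j$, and $J$ the Jacobian returned by the black box. The ℝ → ℝ line corresponds to y.d*ForwardDiff.partials(d) in the code below.

```latex
\begin{aligned}
\mathbb{R}\to\mathbb{R}:&\quad f\Big(x + \sum_{k} \dot{x}_{k}\,\epsilon_k\Big) = f(x) + \sum_{k} f'(x)\,\dot{x}_{k}\,\epsilon_k \\
\mathbb{R}^n\to\mathbb{R}:&\quad \dot{y}_{k} = \sum_{i=1}^{n} \frac{\partial f}{\partial x_i}(x)\,\dot{x}_{i,k} \\
\mathbb{R}^n\to\mathbb{R}^m:&\quad \dot{y}_{j,k} = \sum_{i=1}^{n} J_{ji}(x)\,\dot{x}_{i,k}, \qquad J_{ji} = \frac{\partial f_j}{\partial x_i}
\end{aligned}
```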

using ForwardDiff

module MyModule
  using ForwardDiff
  using LinearAlgebra
  # ℝ → ℝ ——————————————————————————————————————————————
  "Original function"
  f0(x) = x^2

  """Returns f0(x) and its derivative
  In actual usage, this will be a function ForwardDiff cannot differentiate through,
  e.g. because it calls an external binary.
  """
  fg(x) = (v=f0(x),d=2x)
  
  "test function - calls `fg`"
  f(x) = fg(x).v

  "Custom derivative for `f` using `fg`"
  function f(d::ForwardDiff.Dual{T}) where T
    x = ForwardDiff.value(d)
    y = fg(x)
    ForwardDiff.Dual{T}(y.v,y.d*ForwardDiff.partials(d))
  end

  # ℝⁿ → ℝ ——————————————————————————————————————————————
  f_vs0(x) = x[1]+x[2]^2
  fg_vs(x) = (v=f_vs0(x),d=[1.0,2x[2]])
  f_vs(x) = fg_vs(x).v
  function f_vs(d::Vector{D}) where D<:ForwardDiff.Dual
    x = ForwardDiff.value.(d)
    y = fg_vs(x)
    # Chain rule: the output partials are the gradient y.d contracted with
    # the partials carried by the inputs, ∑ᵢ y.d[i] * partials(d[i]).
    p = sum(y.d[i] * ForwardDiff.partials(d[i]) for i in eachindex(d))
    D(y.v,p)
  end
end

## Testing

# ℝ → ℝ
x = 2.3
b = ForwardDiff.derivative(MyModule.f,x)
display(b == ForwardDiff.derivative(MyModule.f0,x))

# ℝⁿ → ℝ
v = [-0.3,3.4]
g = ForwardDiff.gradient(MyModule.f_vs,v)
display(all(g .== ForwardDiff.gradient(MyModule.f_vs0,v)))

## ℝ → ℝⁿ → ℝ
f = x->MyModule.f_vs([-1.2x,-0.5/x])
x2 = 0.7
b2 = ForwardDiff.derivative(f,x2)
f0 = x->MyModule.f_vs0([-1.2x,-0.5/x])
display(b2 == ForwardDiff.derivative(f0,x2))
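A possible extension to ℝⁿ → ℝⁿ follows the same pattern: each output entry gets the Jacobian row contracted with the input partials. This is a minimal sketch, assuming the black box returns the value and the Jacobian J (with J[j, i] = ∂F_j/∂x_i) in one call; fj, F and F0 are hypothetical names for an ℝ² → ℝ² example.

```julia
using ForwardDiff

# Hypothetical black box for ℝ² → ℝ²: returns F(x) and its Jacobian
# (J[j, i] = ∂F_j/∂x_i) in a single call, as an external binary might.
function fj(x::Vector{Float64})
    v = [x[1]^2 * x[2], 5x[1] + sin(x[2])]
    J = [2x[1]*x[2]  x[1]^2;
         5.0         cos(x[2])]
    return (v = v, J = J)
end

# Reference implementation that ForwardDiff can differentiate directly.
F0(x) = [x[1]^2 * x[2], 5x[1] + sin(x[2])]

# Plain Float64 entry point.
F(x::Vector{Float64}) = fj(x).v

# Dual entry point: output j gets partials ∑ᵢ J[j, i] * partials(d[i]).
function F(d::AbstractVector{D}) where {D<:ForwardDiff.Dual}
    y = fj(ForwardDiff.value.(d))
    dp = ForwardDiff.partials.(d)
    return [D(y.v[j], sum(y.J[j, i] * dp[i] for i in eachindex(d)))
            for j in eachindex(y.v)]
end

xv = [1.0, 2.0]
Jc = ForwardDiff.jacobian(F, xv)   # uses the custom Dual method
```

Because the Dual method receives all input partials at once, the Jacobian comes from a single black-box call per chunk, just as in the ℝⁿ → ℝ case.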
