Update abs diff rule to 0 at non-differentiable point #98

Merged
merged 1 commit into from
Jun 2, 2023

Conversation

agerlach
Contributor

@agerlach agerlach commented Jun 1, 2023

This PR updates the diffrule for abs to return 0 at the non-differentiable point x = 0. The current implementation returns 1. Although 1 is a valid choice (it is a subgradient of abs at 0), it can prevent convergence in gradient descent. Returning 0 is the behavior the ChainRules.jl docs advise.
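To illustrate the convergence concern, here is a minimal sketch in plain Julia. The names `old_abs_deriv`, `new_abs_deriv`, and `descend` are hypothetical, and the two rule bodies are assumptions about the general shape of the old and new rules, not copies of the DiffRules source.

```julia
# Hypothetical stand-ins for the two candidate derivative rules of abs.
old_abs_deriv(x) = signbit(x) ? -one(x) : one(x)  # yields 1 at x == 0
new_abs_deriv(x) = sign(x)                        # yields 0 at x == 0

# Plain gradient descent on f(x) = abs(x) with a fixed step size.
function descend(deriv, x0; step = 0.1, iters = 5)
    x = x0
    for _ in 1:iters
        x -= step * deriv(x)
    end
    return x
end

# Start exactly at the minimizer x = 0.
x_old = descend(old_abs_deriv, 0.0)  # pushed away from 0 and keeps bouncing
x_new = descend(new_abs_deriv, 0.0)  # stays at the minimizer

println("old rule ends at ", x_old, ", new rule ends at ", x_new)
```

With a derivative of 1 at zero, the iterate is pushed off the minimizer and oscillates around it; with a derivative of 0, the minimizer is a fixed point of the iteration.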

This also has the added benefit of no longer requiring the number type to work in a boolean context (the ternary operator), which types such as IntervalArithmetic.Interval do not support. This is the use case that led me to make this PR.

using IntervalArithmetic, ForwardDiff

ForwardDiff.derivative(abs, -2.0 .. 2.0)
ERROR: TypeError: non-boolean (Interval{Float64}) used in boolean context
Stacktrace:
 [1] _abs_deriv(x::Interval{Float64})
   @ DiffRules ~/.julia/packages/DiffRules/wKSai/src/rules.jl:73
 [2] abs
   @ ~/.julia/packages/ForwardDiff/vXysl/src/dual.jl:240 [inlined]

With this PR:

ForwardDiff.derivative(abs, -2.0 .. 2.0) # [-1, 1]
ForwardDiff.derivative(abs, 0.0 .. 2.0)  # [0, 1]
ForwardDiff.derivative(abs, -3.0 .. 1.0) # [ -1, -1]
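The difference between a branching rule and a branch-free rule can be sketched without any packages. `ToyInterval`, `branchy_deriv`, and `branchfree_deriv` are hypothetical names, and `ToyInterval` is only a stand-in, not the IntervalArithmetic.jl type. The sketch assumes that `signbit` on an interval returns something other than a `Bool`, which is what the TypeError in the stack trace suggests.

```julia
# ToyInterval is a hypothetical stand-in, not the IntervalArithmetic.jl type.
struct ToyInterval
    lo::Float64
    hi::Float64
end

# Assumption: signbit on an interval yields an interval, not a Bool.
Base.signbit(x::ToyInterval) = ToyInterval(signbit(x.hi), signbit(x.lo))
Base.sign(x::ToyInterval) = ToyInterval(sign(x.lo), sign(x.hi))
Base.one(x::ToyInterval) = ToyInterval(1.0, 1.0)
Base.:-(x::ToyInterval) = ToyInterval(-x.hi, -x.lo)

# A rule that branches needs a Bool condition.
branchy_deriv(x) = signbit(x) ? -one(x) : one(x)
# A branch-free rule only needs sign to be defined for the type.
branchfree_deriv(x) = sign(x)

x = ToyInterval(-2.0, 2.0)
d = branchfree_deriv(x)

# The branching rule throws because the condition is not a Bool.
threw = try
    branchy_deriv(x)
    false
catch err
    err isa TypeError
end

println("branch-free result: [", d.lo, ", ", d.hi, "], branching rule threw: ", threw)
```

The branch-free form delegates the decision to the type's own `sign` method, so an interval that straddles zero can return an enclosure instead of having to pick one branch.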

The diffrule for abs carries the following comment, which I'm not sure how to interpret, as the rule doesn't work with IntervalArithmetic.Interval or Intervals.Interval. Additionally, the current definition assumes that 0 is not in the interval.

# We provide this hook for special number types like `Interval`

@agerlach changed the title from "Update abs diff rule" to "Update abs diff rule to 0 at non-differentiable point" Jun 1, 2023
@oxinabox
Member

oxinabox commented Jun 1, 2023

Tracker.jl breakage is unrelated.

I believe this is correct. (But of course I do; I am an explicit proponent of this property.)

I will merge this tomorrow unless someone raises good objections.

@agerlach
Contributor Author

agerlach commented Jun 1, 2023

Re: Tracker.jl I was hoping that was the case. Thanks

@codecov

codecov bot commented Jun 1, 2023

Codecov Report

Patch coverage: 100.00% and no project coverage change.

Comparison is base (2001650) 97.86% compared to head (fee3857) 97.86%.

Additional details and impacted files
@@           Coverage Diff           @@
##           master      #98   +/-   ##
=======================================
  Coverage   97.86%   97.86%           
=======================================
  Files           3        3           
  Lines         187      187           
=======================================
  Hits          183      183           
  Misses          4        4           
Impacted Files Coverage Δ
src/rules.jl 100.00% <100.00%> (ø)


@devmotion
Member

> The diffrule for abs has the following comment, which I'm not sure how to interpret.

git blame shows it was added in #33 and there the explanation is arguably a bit clearer: DiffRules._abs_deriv is intended as a hook that downstream packages such as the ones for interval arithmetic can overload. Based on a JuliaHub search (https://juliahub.com/ui/Search?q=_abs_deriv&type=code) it seems that no public package actually overloads it (anymore).
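As a rough illustration of what overloading such a hook looks like, here is a self-contained sketch. `ToyRules`, `abs_rule`, and `Span` are hypothetical names; the module only mimics the idea of a hook function and is not the DiffRules implementation.

```julia
# Hypothetical rules package exposing a hook function.
module ToyRules
# Generic fallback used for ordinary real numbers.
_abs_deriv(x) = signbit(x) ? -one(x) : one(x)
# The rule itself only ever calls the hook.
abs_rule(x) = _abs_deriv(x)
end

# Hypothetical downstream number type.
struct Span
    lo::Float64
    hi::Float64
end

# Downstream overload: add a method to the hook for the new type.
ToyRules._abs_deriv(x::Span) = Span(sign(x.lo), sign(x.hi))

r_scalar = ToyRules.abs_rule(-3.0)            # uses the generic fallback
r_span = ToyRules.abs_rule(Span(-2.0, 2.0))   # dispatches to the overload

println(r_scalar, " ", r_span)
```

The rule never needs to know about `Span`; multiple dispatch selects the downstream method whenever the argument has that type.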

@devmotion
Member

Some additional historical context: It seems the rule for abs was originally added in #11, and there it was suggested to use signbit since at that time it was used in abs(::ForwardDiff.Dual) (which was later removed in JuliaDiff/ForwardDiff.jl#311).

@oxinabox oxinabox merged commit 2cf092b into JuliaDiff:master Jun 2, 2023
11 of 13 checks passed
@agerlach
Contributor Author

agerlach commented Jun 2, 2023

@devmotion Thanks for the extra context.

@andreasnoack
Member

andreasnoack commented Jun 2, 2023

I think we should revert this. It breaks higher order derivatives for some differentiable functions. E.g.

julia> ForwardDiff.hessian(t -> abs(t[1])^2, [0.0])
1×1 Matrix{Float64}:
 2.0

(TestDiffRules) pkg> add DiffRules@1.14
   Resolving package versions...
    Updating `~/TestDiffRules/Project.toml`
  [b552c78f] ↑ DiffRules v1.13.0 ⇒ v1.14.0
    Updating `~/TestDiffRules/Manifest.toml`
  [b552c78f] ↑ DiffRules v1.13.0 ⇒ v1.14.0

julia> ForwardDiff.hessian(t -> abs(t[1])^2, [0.0])
1×1 Matrix{Float64}:
 0.0

The example here is of course trivial, but abs is used in many places for numerical reasons in differentiable functions, such as https://github.com/JuliaStats/Distributions.jl/blob/2dee35e13eacb0909c6b2189f229ce93c04d2560/src/univariate/continuous/logistic.jl#L82.
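One way to see where the Hessian values 2.0 and 0.0 come from is to apply the chain rule by hand: the second derivative of abs(t)^2 is 2*abs'(t)^2 + 2*abs(t)*abs''(t). The sketch below evaluates that expression for both rules. The names `old_rule`, `new_rule`, and `second_derivative_abs2` are hypothetical, and the sketch assumes that automatic differentiation treats the derivative of either rule as 0, which is an assumption about how the nested differentiation behaves rather than a statement taken from the ForwardDiff source.

```julia
# Hypothetical stand-ins for the old and new derivative rules of abs.
old_rule(x) = signbit(x) ? -one(x) : one(x)  # 1 at x == 0
new_rule(x) = sign(x)                        # 0 at x == 0

# Chain rule by hand for f(t) = abs(t)^2:
#   f''(t) = 2*abs'(t)^2 + 2*abs(t)*abs''(t)
# Assumption: abs''(t) is taken to be 0, so the second term vanishes.
second_derivative_abs2(rule, t) = 2 * rule(t)^2 + 2 * abs(t) * zero(t)

h_old = second_derivative_abs2(old_rule, 0.0)
h_new = second_derivative_abs2(new_rule, 0.0)

println("old rule: ", h_old, ", new rule: ", h_new)
```

At t = 0 only the first term can contribute, so the result is determined entirely by the value the rule assigns to abs'(0): 2 when it is 1, and 0 when it is 0. The true second derivative of t^2 is 2.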
