Violating assumptions about Dual comparisons #609

Open
gpeairs opened this issue Nov 15, 2022 · 8 comments

@gpeairs

gpeairs commented Nov 15, 2022

There are already a few issues talking about #481 (#606, #607) and related problems (#480). I've been running into the breakage hinted at in this comment and this workaround, where:

julia> x, y = 1, ForwardDiff.Dual(1,1);

julia> x < y || x > y || x == y
false

julia> x >= y && x <= y
true

This seems to be an issue for piecewise functions and wrapper types that assume x <= y is equivalent to x == y || x < y.

For example, something like this previously worked:

import ForwardDiff

struct WrappedNumber{T<:Number}
    value::T
end
for op in (:isless, :(==))
    @eval begin
        Base.$op(x::WrappedNumber, y::WrappedNumber) = ($op)(x.value, y.value)
        Base.$op(x::Number, y::WrappedNumber) = ($op)(x, y.value)
        Base.$op(x::WrappedNumber, y::Number) = ($op)(x.value, y)
    end
end
for op in (:*, :+, :-)
    @eval begin
        Base.$op(x::WrappedNumber, y::WrappedNumber) = WrappedNumber(($op)(x.value, y.value))
        Base.$op(x::Number, y::WrappedNumber) = WrappedNumber(($op)(x, y.value))
        Base.$op(x::WrappedNumber, y::Number) = WrappedNumber(($op)(x.value, y))
    end
end

function f_piecewise(x)
    lb = WrappedNumber(0.0)
    ub = WrappedNumber(1.0)
    (lb <= x < ub) && return 0.5*x*x
    x >= ub && return 0.5*ub*ub + (x - ub)
    return 0.0*x
end

function ForwardDiff.derivative(f, x::WrappedNumber{T}) where T
    r = f(WrappedNumber(ForwardDiff.Dual(x.value, one(T))))
    return ForwardDiff.partials(r.value)[1]
end

And now:

julia> ForwardDiff.derivative(f_piecewise, WrappedNumber(0.99))
0.99

julia> ForwardDiff.derivative(f_piecewise, WrappedNumber(1.01))
1.0

julia> ForwardDiff.derivative(f_piecewise, WrappedNumber(1.0))
0.0 # Used to be 1.0

Maybe I'm not supposed to do this. But it's been very convenient, and I suspect this isn't the only case with problems of this kind. Or maybe the issue is just a matter of following through on the changes to == and isequal for other comparisons?

@mcabbott
Member

First, perhaps a cleaner version of the complaint is this:

julia> using ForwardDiff: Dual

julia> function f_3(x; lb = 0.0, ub = 1.0)
           (lb <= x < ub) && return 0.5*x*x
           x >= ub && return 0.5*ub*ub + (x - ub)
           x < lb && return 0.0*x
           missing
       end
f_3 (generic function with 1 method)

julia> f_3.(-0.5:0.5:1.5)
5-element Vector{Float64}:
 -0.0
  0.0
  0.125
  0.5
  1.0

julia> f_3.(Dual.(-0.5:0.5:1.5, 1))
5-element Vector{Dual{Nothing, Float64, 1}}:
 Dual{Nothing}(-0.0,0.0)
  Dual{Nothing}(0.0,0.0)
  Dual{Nothing}(0.125,0.5)
  Dual{Nothing}(0.5,1.0)
  Dual{Nothing}(1.0,1.0)

julia> f_3.(WrappedNumber.(-0.5:0.5:1.5))
5-element Vector{WrappedNumber{Float64}}:
 WrappedNumber{Float64}(-0.0)
 WrappedNumber{Float64}(0.0)
 WrappedNumber{Float64}(0.125)
 WrappedNumber{Float64}(0.5)
 WrappedNumber{Float64}(1.0)

julia> f_3.(WrappedNumber.(Dual.(-0.5:0.5:1.5, 1)))  # ForwardDiff v0.10.33
5-element Vector{Union{Missing, WrappedNumber{Dual{Nothing, Float64, 1}}}}:
 WrappedNumber{Dual{Nothing, Float64, 1}}(Dual{Nothing}(-0.0,0.0))
 missing
 WrappedNumber{Dual{Nothing, Float64, 1}}(Dual{Nothing}(0.125,0.5))
 missing
 WrappedNumber{Dual{Nothing, Float64, 1}}(Dual{Nothing}(1.0,1.0))

This code expects that the conditions exhaust all possibilities, but they do not. Because by overloading only isless and ==, it gets fallback definitions for <= etc:

julia> Dual(1,2) == Dual(1,3)
false

julia> Dual(1,2) <= Dual(1,3)
true

julia> WrappedNumber(Dual(1,2)) <= WrappedNumber(Dual(1,3))
false
# @less points to this definition:   <=(x, y) = (x < y) | (x == y)

Note that there is a fallback definition for == too, which would also produce these results, even on the old ForwardDiff.

That would happen had the code chosen to overload, say, for op in (:isless, :isequal) instead. Is it somehow canonical to overload exactly (:isless, :(==))?

I don't quite know what to suggest here. I don't see how ForwardDiff using a different definition for <= would change this. Is there a definition for isless which you have in mind?

I know this may not help much, but I never trust myself to write code like the above and not introduce an edge case (which random tests will never find), so perhaps I assumed such expressions were rare in the wild. I would try to write instead single tests, like

function f_4(x; lb = 0.0, ub = 1.0)
    if x < lb
        0.0*x
    elseif x < ub  # and >= lb
        0.5*x*x
    else 
        0.5*ub*ub + (x - ub)
    end
end

@gpeairs
Author

gpeairs commented Nov 16, 2022

Thanks! I'm not super familiar with ForwardDiff internals so I don't have a fix on this side in mind. On my side, yeah, I should be able to rewrite all my multiple comparisons to avoid this issue by only using <, and maybe that's better style, anyway. These cases don't exactly jump out at me, though, so I'd worry about missing some.

Is it somehow canonical to overload exactly (:isless, :(==))?

Kind of? My understanding is that

  • isequal falls back to ==, and you mainly want to implement isequal yourself if you're doing something funny with floats or missing
  • == is what you want to implement for types with a notion of equality, and falls back to ===
  • isless is the default comparison used by sort, should give a total order together with isequal, and should also be implemented for numeric types with NaN; it's linked with isequal in that it's expected that exactly one of (isless(x,y), isless(y,x), isequal(x,y)) is true (this is now violated for ForwardDiff.Dual)
  • < falls back to isless but should be implemented if you have a canonical partial order; it's linked with == in that (e.g.) you want to have isequal(NaN, NaN) and isless(1.0, NaN) but not NaN == NaN or 1.0 < NaN (specifically if you have NaN != NaN then you should probably also have !(1.0 < NaN), and vice versa)
  • The other comparisons fall back to the above pretty naturally

The upshot is that you typically want to overload either (:isless, :(==)) or all of (:isless, :(==), :<, :isequal). Anything else and you risk inconsistency, I think.
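
A minimal sketch of the float behavior those bullets rely on, using only Base. The helper exactly_one_order is a hypothetical name made up for this illustration; it is not part of Base or ForwardDiff.

```julia
# Base behavior for floats: isless/isequal form a total order,
# while < and == follow IEEE semantics for NaN.
@show isequal(NaN, NaN)   # true
@show NaN == NaN          # false
@show isless(1.0, NaN)    # true
@show 1.0 < NaN           # false

# Hypothetical helper: checks the total-order expectation that exactly one of
# isless(x, y), isless(y, x), isequal(x, y) holds.
exactly_one_order(x, y) = count((isless(x, y), isless(y, x), isequal(x, y))) == 1

@show exactly_one_order(1.0, NaN)    # true
@show exactly_one_order(NaN, NaN)    # true
@show exactly_one_order(-0.0, 0.0)   # true, since isless(-0.0, 0.0)
```

The same helper can be pointed at any other number type to see whether it keeps that promise.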

@mcabbott
Member

I don't disagree with these bullet points, but overloading == without hash seems a bit dodgy, and since these are numbers they probably ought to think about NaN.

Whenever I have tried to make a new number type, I have run into never-ending problems (esp. ambiguities against other such types). There's a package which aims to automate this, and I note that it overloads far more things: https://github.com/SimonDanisch/AbstractNumbers.jl/blob/master/src/overloads.jl

@gpeairs
Author

gpeairs commented Nov 16, 2022

For sure, my WrappedNumber was just a minimal example. Good to know about AbstractNumbers, too.

@mcabbott
Member

Not the same as Dual numbers, but a similar species is provided by Measurements.jl

julia> using Measurements

julia> f_3.((-0.5:0.5:1.5) .± 0.001)
5-element Vector{Measurement{Float64}}:
  -0.0 ± 0.0
   0.0 ± 0.0
 0.125 ± 0.0005
   0.5 ± 0.001
   1.0 ± 0.001

julia> f_3.((WrappedNumber.(-0.5:0.5:1.5)) .± 0.001)
ERROR: MethodError: no method matching measurement(::WrappedNumber{Float64}, ::Float64)

julia> f_3.(WrappedNumber.((-0.5:0.5:1.5) .± 0.001))
5-element Vector{Union{Missing, WrappedNumber{Measurement{Float64}}}}:
 WrappedNumber{Measurement{Float64}}(-0.0 ± 0.0)
 missing
 WrappedNumber{Measurement{Float64}}(0.125 ± 0.0005)
 missing
 WrappedNumber{Measurement{Float64}}(1.0 ± 0.001)

None of these functions work with units:

julia> using Unitful

julia> f_3.((-0.5:0.5:1.5) .* 1u"m")
ERROR: DimensionError: 0.0 and -0.5 m are not dimensionally compatible.

julia> f_4.((-0.5:0.5:1.5) .* 1u"m"; lb = 0u"m", ub = 1u"m")
ERROR: DimensionError: 0.5 m² and 0.0 m are not dimensionally compatible.

@gpeairs
Author

gpeairs commented Nov 16, 2022

Yeah, Measurements.jl also violates the total-order expectation for isless and isequal, and also overloads <= so it's not < || ==.

It seems where I'm coming down is

  • < can always be a partial order, so if I want to work generically with numeric types, including Dual, I can't assume exactly one of (x < y, x == y, y < x) is true.
  • x <= y doesn't promise that it's equivalent to x < y || x == y. Composite types should probably forward it as x.value <= y etc., as AbstractNumbers.jl does, rather than let it fall back to x.value < y || x.value == y. This is ugly but hard to see a way around. (So e.g. Unitful should probably be patched.)
  • x < y || x >= y is also not necessarily exhaustive, although it happens to be true for Dual and Measurement types (ignoring NaN). If I followed the above bullet, then what I did would work -- but sort of by coincidence.
  • The total-order promise isn't strictly followed outside of Base, so using isless and isequal doesn't solve my problem, not that I'd particularly want to clutter my code with that.

If you agree, then this issue can probably be closed.
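
To make the forwarding point concrete, here is a rough sketch that runs without ForwardDiff. ToyDual, FallbackWrapper and ForwardingWrapper are hypothetical names invented for this illustration. ToyDual only mimics the comparison behavior reported in this thread (== looks at both fields, < and <= look at the value only); it is not ForwardDiff's implementation.

```julia
# Hypothetical stand-in for a Dual-like number.
struct ToyDual
    value::Float64
    partial::Float64
end
Base.:(==)(x::ToyDual, y::ToyDual) = x.value == y.value && x.partial == y.partial
Base.:(<)(x::ToyDual, y::ToyDual) = x.value < y.value
Base.:(<=)(x::ToyDual, y::ToyDual) = x.value <= y.value

# Wrapper that defines only == and <, so <= comes from Base's generic
# fallback  <=(x, y) = (x < y) | (x == y).
struct FallbackWrapper{T}
    value::T
end
Base.:(==)(x::FallbackWrapper, y::FallbackWrapper) = x.value == y.value
Base.:(<)(x::FallbackWrapper, y::FallbackWrapper) = x.value < y.value

# Wrapper that forwards each comparison, including <=, to the wrapped value.
struct ForwardingWrapper{T}
    value::T
end
for op in (:(==), :<, :<=)
    @eval Base.$op(x::ForwardingWrapper, y::ForwardingWrapper) = ($op)(x.value, y.value)
end

a, b = ToyDual(1.0, 2.0), ToyDual(1.0, 3.0)
@show a <= b                                        # true
@show FallbackWrapper(a) <= FallbackWrapper(b)      # false: neither a < b nor a == b
@show ForwardingWrapper(a) <= ForwardingWrapper(b)  # true: agrees with the wrapped type
```

With forwarding, the wrapper inherits whatever the inner type decides for <=, which is why it sidesteps the fallback but still depends on the inner type's choices.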

@gpeairs
Author

gpeairs commented Nov 16, 2022

I think the main alternative would be to have isless(x::Dual, y::Dual) = isless((x.value, x.partials), (y.value, y.partials)), have < fall back on that in the same way it does for floats, and deal with mixed-type comparisons by promotion. Then we have the total order and comparison equivalence we expect, and the difference between isless, isequal and <, == is just the usual deal with NaN. This is more or less what Unitful does, I think.

The drawback is that we're saying x < y when mathematically that might not make sense and might have consequences elsewhere.
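
A rough sketch of what that alternative could look like, on a hypothetical LexDual type rather than ForwardDiff.Dual. The single partial field, the explicit < method and the hash method are assumptions made for this illustration, not a proposed patch.

```julia
# Hypothetical two-field dual number with lexicographic comparisons.
struct LexDual
    value::Float64
    partial::Float64
end

# Total order and equality via Base's tuple methods, which compare element by element.
Base.isless(x::LexDual, y::LexDual) = isless((x.value, x.partial), (y.value, y.partial))
Base.isequal(x::LexDual, y::LexDual) = isequal((x.value, x.partial), (y.value, y.partial))
Base.hash(x::LexDual, h::UInt) = hash((x.value, x.partial), h)

# IEEE-style counterparts; <=, > and >= then come from Base's generic fallbacks.
Base.:(==)(x::LexDual, y::LexDual) = (x.value, x.partial) == (y.value, y.partial)
Base.:(<)(x::LexDual, y::LexDual) = (x.value, x.partial) < (y.value, y.partial)

a, b = LexDual(1.0, 2.0), LexDual(1.0, 3.0)
@show a < b                                # true, decided by the partials
@show a <= b                               # true, via (a < b) | (a == b)
@show count((a < b, a == b, b < a)) == 1   # true: exactly one holds
```

The first result is the drawback in action: a < b holds even though both values are 1.0.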

@mcabbott
Member

I think that the same rule for > was suggested in #377. Want to try it out & look for weird consequences?

#481 tried to change as few things as possible to solve existing issues.

#480 argued that one consequence of changing > would be that gradient(f, x) and -gradient(x -> -f(x), x) may sometimes disagree, e.g. when f contains clamp(x, 0, 1) and we are against the boundary. In that case the comparison is against a non-Dual number... maybe that's necessary for this problem to show up. If it is in fact a problem.
