diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 1d124772..909a7b08 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -2,8 +2,6 @@ name: CI on: pull_request: - branches: - - master push: branches: - master diff --git a/src/partials.jl b/src/partials.jl index a5316e3e..be55c1b3 100644 --- a/src/partials.jl +++ b/src/partials.jl @@ -82,6 +82,18 @@ Base.convert(::Type{Partials{N,V}}, partials::Partials{N,V}) where {N,V} = parti @inline Base.:-(partials::Partials) = Partials(minus_tuple(partials.values)) @inline Base.:*(x::Real, partials::Partials) = partials*x +@inline function Base.:*(partials::Partials, x::Real) + return Partials(scale_tuple(partials.values, x)) +end + +@inline function Base.:/(partials::Partials, x::Real) + return Partials(div_tuple_by_scalar(partials.values, x)) +end + +@inline function _mul_partials(a::Partials{N}, b::Partials{N}, x_a, x_b) where N + return Partials(mul_tuples(a.values, b.values, x_a, x_b)) +end + @inline function _div_partials(a::Partials, b::Partials, aval, bval) return _mul_partials(a, b, inv(bval), -(aval / (bval*bval))) end @@ -90,33 +102,22 @@ end #----------------------# if NANSAFE_MODE_ENABLED - @inline function Base.:*(partials::Partials, x::Real) - x = ifelse(!isfinite(x) && iszero(partials), one(x), x) - return Partials(scale_tuple(partials.values, x)) - end - - @inline function Base.:/(partials::Partials, x::Real) - x = ifelse(x == zero(x) && iszero(partials), one(x), x) - return Partials(div_tuple_by_scalar(partials.values, x)) + # A dual number with a zero partial is just an unperturbed non-dual number + # Hence when propagated the resulting dual number is unperturbed as well, + # ie., its partial is zero as well, regardless of the primal value + # However, standard floating point multiplication/division would return `NaN` + # if the primal is not-finite/zero + @inline function _mul_partial(partial::Real, x::Real) + y = partial * x + return iszero(partial) ? zero(y) : y end - - @inline function _mul_partials(a::Partials{N}, b::Partials{N}, x_a, x_b) where N - x_a = ifelse(!isfinite(x_a) && iszero(a), one(x_a), x_a) - x_b = ifelse(!isfinite(x_b) && iszero(b), one(x_b), x_b) - return Partials(mul_tuples(a.values, b.values, x_a, x_b)) + @inline function _div_partial(partial::Real, x::Real) + y = partial / x + return iszero(partial) ? zero(y) : y end else - @inline function Base.:*(partials::Partials, x::Real) - return Partials(scale_tuple(partials.values, x)) - end - - @inline function Base.:/(partials::Partials, x::Real) - return Partials(div_tuple_by_scalar(partials.values, x)) - end - - @inline function _mul_partials(a::Partials{N}, b::Partials{N}, x_a, x_b) where N - return Partials(mul_tuples(a.values, b.values, x_a, x_b)) - end + @inline _mul_partial(partial::Real, x::Real) = partial * x + @inline _div_partial(partial::Real, x::Real) = partial / x end # edge cases where N == 0 # @@ -197,11 +198,11 @@ end end @generated function scale_tuple(tup::NTuple{N}, x) where N - return tupexpr(i -> :(tup[$i] * x), N) + return tupexpr(i -> :(_mul_partial(tup[$i], x)), N) end @generated function div_tuple_by_scalar(tup::NTuple{N}, x) where N - return tupexpr(i -> :(tup[$i] / x), N) + return tupexpr(i -> :(_div_partial(tup[$i], x)), N) end @generated function add_tuples(a::NTuple{N}, b::NTuple{N}) where N @@ -217,7 +218,7 @@ end end @generated function mul_tuples(a::NTuple{N}, b::NTuple{N}, afactor, bfactor) where N - return tupexpr(i -> :((afactor * a[$i]) + (bfactor * b[$i])), N) + return tupexpr(i -> :(_mul_partial(a[$i], afactor) + _mul_partial(b[$i], bfactor)), N) end ################### diff --git a/test/DerivativeTest.jl b/test/DerivativeTest.jl index 4de1a6de..ab5a3631 100644 --- a/test/DerivativeTest.jl +++ b/test/DerivativeTest.jl @@ -113,4 +113,13 @@ end @test ForwardDiff.derivative(x -> (1+im)*x, 0) == (1+im) end +@testset "NaN-safe mode" begin + x = ForwardDiff.derivative(log ∘ zero, 1.0) + if ForwardDiff.NANSAFE_MODE_ENABLED + @test iszero(x) + else + @test isnan(x) + end +end + end # module diff --git a/test/GradientTest.jl b/test/GradientTest.jl index 82c2b2a8..9008ce8d 100644 --- a/test/GradientTest.jl +++ b/test/GradientTest.jl @@ -148,9 +148,15 @@ end end @testset "exponential function at base zero" begin - @test isequal(ForwardDiff.gradient(t -> t[1]^t[2], [0.0, -0.5]), [NaN, NaN]) - @test isequal(ForwardDiff.gradient(t -> t[1]^t[2], [0.0, 0.0]), [NaN, NaN]) - @test isequal(ForwardDiff.gradient(t -> t[1]^t[2], [0.0, 0.5]), [Inf, NaN]) + if ForwardDiff.NANSAFE_MODE_ENABLED + @test isequal(ForwardDiff.gradient(t -> t[1]^t[2], [0.0, -0.5]), [-Inf, -Inf]) + @test isequal(ForwardDiff.gradient(t -> t[1]^t[2], [0.0, 0.0]), [NaN, -Inf]) + @test isequal(ForwardDiff.gradient(t -> t[1]^t[2], [0.0, 0.5]), [Inf, 0.0]) + else + @test isequal(ForwardDiff.gradient(t -> t[1]^t[2], [0.0, -0.5]), [NaN, NaN]) + @test isequal(ForwardDiff.gradient(t -> t[1]^t[2], [0.0, 0.0]), [NaN, NaN]) + @test isequal(ForwardDiff.gradient(t -> t[1]^t[2], [0.0, 0.5]), [Inf, NaN]) + end @test isequal(ForwardDiff.gradient(t -> t[1]^t[2], [0.0, 1.5]), [0.0, 0.0]) end @@ -207,11 +213,19 @@ end end @testset "gradient for exponential with NaNMath" begin - @test isnan(ForwardDiff.gradient(x -> NaNMath.pow(x[1],x[1]), [NaN, 1.0])[1]) + if ForwardDiff.NANSAFE_MODE_ENABLED + @test isequal(ForwardDiff.gradient(x -> NaNMath.pow(x[1],x[2]), [NaN, 1.0]), [1.0, NaN]) + else + @test isequal(ForwardDiff.gradient(x -> NaNMath.pow(x[1],x[2]), [NaN, 1.0]), [NaN, NaN]) + end @test ForwardDiff.gradient(x -> NaNMath.pow(x[1], x[2]), [1.0, 1.0]) == [1.0, 0.0] @test isnan(ForwardDiff.gradient((x) -> NaNMath.pow(x[1], x[2]), [-1.0, 0.5])[1]) - @test isnan(ForwardDiff.gradient(x -> x[1]^x[2], [NaN, 1.0])[1]) + if ForwardDiff.NANSAFE_MODE_ENABLED + @test isequal(ForwardDiff.gradient(x -> x[1]^x[2], [NaN, 1.0]), [1.0, NaN]) + else + @test isequal(ForwardDiff.gradient(x -> x[1]^x[2], [NaN, 1.0]), [NaN, NaN]) + end @test ForwardDiff.gradient(x -> x[1]^x[2], [1.0, 1.0]) == [1.0, 0.0] @test_throws DomainError ForwardDiff.gradient(x -> x[1]^x[2], [-1.0, 0.5]) end @@ -286,4 +300,34 @@ end @test grad == SVector{3}(der, der, der) end +@testset "NaN-safe mode" begin + # issue #774 + f = x -> log(zero(x[1]) + x[2]) + x = [1.0, 0.0] + y1 = ForwardDiff.gradient(f, x) + y2 = ForwardDiff.gradient(f, x, ForwardDiff.GradientConfig(f, x, ForwardDiff.Chunk{1}())) + y3 = ForwardDiff.gradient(f, x, ForwardDiff.GradientConfig(f, x, ForwardDiff.Chunk{2}())) + for y in (y1, y2, y3) + if ForwardDiff.NANSAFE_MODE_ENABLED + @test y == [0.0, Inf] + else + @test isequal(y, [NaN, Inf]) + end + end + + # issue #745 + g = a -> a[1] * exp(-a[2]) + a = [1.0, -1e3] + b1 = ForwardDiff.gradient(g, a) + b2 = ForwardDiff.gradient(g, a, ForwardDiff.GradientConfig(g, a, ForwardDiff.Chunk{1}())) + b3 = ForwardDiff.gradient(g, a, ForwardDiff.GradientConfig(g, a, ForwardDiff.Chunk{2}())) + for b in (b1, b2, b3) + if ForwardDiff.NANSAFE_MODE_ENABLED + @test b == [Inf, -Inf] + else + @test isequal(b, [NaN, NaN]) + end + end +end + end # module diff --git a/test/PartialsTest.jl b/test/PartialsTest.jl index 8372a53e..23ebff3b 100644 --- a/test/PartialsTest.jl +++ b/test/PartialsTest.jl @@ -111,24 +111,37 @@ samerng() = MersenneTwister(1) @test (PARTIALS / X).values == map(v -> v / X, VALUES) if N > 0 - @test ForwardDiff._div_partials(PARTIALS, PARTIALS2, X, Y) == ForwardDiff._mul_partials(PARTIALS, PARTIALS2, inv(Y), -X/(Y^2)) - @test ForwardDiff._mul_partials(PARTIALS, PARTIALS2, X, Y).values == map((a, b) -> (X * a) + (Y * b), VALUES, VALUES2) - @test ForwardDiff._mul_partials(ZERO_PARTIALS, PARTIALS, X, Y) == Y * PARTIALS - @test ForwardDiff._mul_partials(PARTIALS, ZERO_PARTIALS, X, Y) == X * PARTIALS + # Only zero partials + ALLZERO = Partials(ntuple(_ -> zero(T), N)) + # Mix of zero and non-zero partials + FIRSTZERO = Partials(ntuple(i -> i == 1 ? zero(T) : rand(T), N)) + + # The following properties should always be satisfied, regardless of whether NaN-safe mode is enabled or disabled + # We use `isequal` for comparisons in the presence of `NaN`s + for p1 in (PARTIALS, ALLZERO, FIRSTZERO), p2 in (PARTIALS2, ALLZERO, FIRSTZERO), v1 in (X, NaN, Inf), v2 in (Y, NaN, Inf) + @test isequal(ForwardDiff._div_partials(p1, p2, v1, v2), ForwardDiff._mul_partials(p1, p2, inv(v2), -v1/(v2^2))) + @test isequal(ForwardDiff._mul_partials(p1, p2, v1, v2), v1 * p1 + v2 * p2) + end + for v1 in (X, NaN, Inf), v2 in (Y, NaN, Inf) + @test isequal(ForwardDiff._mul_partials(ZERO_PARTIALS, PARTIALS, v1, v2), v2 * PARTIALS) + @test isequal(ForwardDiff._mul_partials(PARTIALS, ZERO_PARTIALS, v1, v2), v1 * PARTIALS) + end if ForwardDiff.NANSAFE_MODE_ENABLED - ZEROS = Partials((fill(zero(T), N)...,)) - - @test (NaN * ZEROS).values == ZEROS.values - @test (Inf * ZEROS).values == ZEROS.values - @test (ZEROS / 0).values == ZEROS.values - - @test ForwardDiff._mul_partials(ZEROS, ZEROS, X, NaN).values == ZEROS.values - @test ForwardDiff._mul_partials(ZEROS, ZEROS, NaN, X).values == ZEROS.values - @test ForwardDiff._mul_partials(ZEROS, ZEROS, X, Inf).values == ZEROS.values - @test ForwardDiff._mul_partials(ZEROS, ZEROS, Inf, X).values == ZEROS.values - @test ForwardDiff._mul_partials(ZEROS, ZEROS, Inf, NaN).values == ZEROS.values - @test ForwardDiff._mul_partials(ZEROS, ZEROS, NaN, Inf).values == ZEROS.values + for f in ((p -> NaN * p), (p -> Inf * p), (p -> -Inf * p), (p -> p / 0), (p -> p / NaN), (p -> p / Inf), (p -> p / -Inf)) + # Only zero partials + @test iszero(@inferred(f(ALLZERO))) + + # Mix of zero and non-zero partials + z = @inferred(f(FIRSTZERO)) + for i in 1:N + if iszero(FIRSTZERO[i]) + @test iszero(z[i]) + else + @test isequal(z[i], f(FIRSTZERO[i])) + end + end + end end end end