Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@ name: CI

on:
pull_request:
branches:
- master
push:
branches:
- master
Expand Down
55 changes: 28 additions & 27 deletions src/partials.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,18 @@ Base.convert(::Type{Partials{N,V}}, partials::Partials{N,V}) where {N,V} = parti
# Unary minus on partials; the element-wise work is delegated to `minus_tuple`.
@inline function Base.:-(partials::Partials)
    return Partials(minus_tuple(partials.values))
end
# Scalar-on-the-left multiplication; forwards to `partials * x` so that both
# argument orders share a single implementation.
@inline function Base.:*(x::Real, partials::Partials)
    return partials * x
end

# Multiply the partials by the real scalar `x`; `scale_tuple` does the per-element work.
@inline Base.:*(partials::Partials, x::Real) = Partials(scale_tuple(partials.values, x))

# Divide the partials by the real scalar `x`; `div_tuple_by_scalar` does the
# per-element work.
@inline function Base.:/(partials::Partials, x::Real)
    quotients = div_tuple_by_scalar(partials.values, x)
    return Partials(quotients)
end

# Combine two equally sized partials as `x_a * a + x_b * b`, element by element
# (the arithmetic itself lives in `mul_tuples`).
@inline function _mul_partials(a::Partials{N}, b::Partials{N}, x_a, x_b) where {N}
    combined = mul_tuples(a.values, b.values, x_a, x_b)
    return Partials(combined)
end

# Partials of a quotient with primal values `aval` and `bval`, expressed as the
# weighted combination `(1 / bval) * a - (aval / bval^2) * b`.
@inline function _div_partials(a::Partials, b::Partials, aval, bval)
    afactor = inv(bval)
    bfactor = -(aval / (bval * bval))
    return _mul_partials(a, b, afactor, bfactor)
end
Expand All @@ -90,33 +102,22 @@ end
#----------------------#

if NANSAFE_MODE_ENABLED
@inline function Base.:*(partials::Partials, x::Real)
x = ifelse(!isfinite(x) && iszero(partials), one(x), x)
return Partials(scale_tuple(partials.values, x))
end

@inline function Base.:/(partials::Partials, x::Real)
x = ifelse(x == zero(x) && iszero(partials), one(x), x)
return Partials(div_tuple_by_scalar(partials.values, x))
# A dual number with a zero partial is just an unperturbed, non-dual number.
# Hence, when it is propagated, the resulting dual number is unperturbed as well,
# i.e., its partial is zero too, regardless of the primal value.
# However, standard floating-point multiplication/division would return `NaN`
# if the primal is non-finite or zero.
# NaN-safe product of a single partial with a scalar: a zero partial yields a
# zero of the product's type, even when `x` is `NaN` or infinite.
@inline function _mul_partial(partial::Real, x::Real)
    product = partial * x
    if iszero(partial)
        return zero(product)
    end
    return product
end

@inline function _mul_partials(a::Partials{N}, b::Partials{N}, x_a, x_b) where N
x_a = ifelse(!isfinite(x_a) && iszero(a), one(x_a), x_a)
x_b = ifelse(!isfinite(x_b) && iszero(b), one(x_b), x_b)
return Partials(mul_tuples(a.values, b.values, x_a, x_b))
# NaN-safe quotient of a single partial by a scalar: a zero partial yields a
# zero of the quotient's type, even when `x` is zero or `NaN`.
@inline function _div_partial(partial::Real, x::Real)
    quotient = partial / x
    if iszero(partial)
        return zero(quotient)
    end
    return quotient
end
else
@inline function Base.:*(partials::Partials, x::Real)
return Partials(scale_tuple(partials.values, x))
end

@inline function Base.:/(partials::Partials, x::Real)
return Partials(div_tuple_by_scalar(partials.values, x))
end

@inline function _mul_partials(a::Partials{N}, b::Partials{N}, x_a, x_b) where N
return Partials(mul_tuples(a.values, b.values, x_a, x_b))
end
# Plain product, used when NaN-safe mode is disabled: no special-casing of zero
# partials, so ordinary floating-point results (including `NaN`) are returned.
@inline function _mul_partial(partial::Real, x::Real)
    return partial * x
end
# Plain quotient, used when NaN-safe mode is disabled: no special-casing of zero
# partials, so ordinary floating-point results (including `NaN`) are returned.
@inline function _div_partial(partial::Real, x::Real)
    return partial / x
end
end

# edge cases where N == 0 #
Expand Down Expand Up @@ -197,11 +198,11 @@ end
end

@generated function scale_tuple(tup::NTuple{N}, x) where N
return tupexpr(i -> :(tup[$i] * x), N)
return tupexpr(i -> :(_mul_partial(tup[$i], x)), N)
end

@generated function div_tuple_by_scalar(tup::NTuple{N}, x) where N
return tupexpr(i -> :(tup[$i] / x), N)
return tupexpr(i -> :(_div_partial(tup[$i], x)), N)
end

@generated function add_tuples(a::NTuple{N}, b::NTuple{N}) where N
Expand All @@ -217,7 +218,7 @@ end
end

@generated function mul_tuples(a::NTuple{N}, b::NTuple{N}, afactor, bfactor) where N
return tupexpr(i -> :((afactor * a[$i]) + (bfactor * b[$i])), N)
return tupexpr(i -> :(_mul_partial(a[$i], afactor) + _mul_partial(b[$i], bfactor)), N)
end

###################
Expand Down
9 changes: 9 additions & 0 deletions test/DerivativeTest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -113,4 +113,13 @@ end
@test ForwardDiff.derivative(x -> (1+im)*x, 0) == (1+im)
end

@testset "NaN-safe mode" begin
    # `log ∘ zero` does not depend on its input: the result must be exactly zero
    # in NaN-safe mode and `NaN` otherwise.
    deriv = ForwardDiff.derivative(log ∘ zero, 1.0)
    expectation = ForwardDiff.NANSAFE_MODE_ENABLED ? iszero : isnan
    @test expectation(deriv)
end

end # module
54 changes: 49 additions & 5 deletions test/GradientTest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -148,9 +148,15 @@ end
end

@testset "exponential function at base zero" begin
@test isequal(ForwardDiff.gradient(t -> t[1]^t[2], [0.0, -0.5]), [NaN, NaN])
@test isequal(ForwardDiff.gradient(t -> t[1]^t[2], [0.0, 0.0]), [NaN, NaN])
@test isequal(ForwardDiff.gradient(t -> t[1]^t[2], [0.0, 0.5]), [Inf, NaN])
if ForwardDiff.NANSAFE_MODE_ENABLED
@test isequal(ForwardDiff.gradient(t -> t[1]^t[2], [0.0, -0.5]), [-Inf, -Inf])
@test isequal(ForwardDiff.gradient(t -> t[1]^t[2], [0.0, 0.0]), [NaN, -Inf])
@test isequal(ForwardDiff.gradient(t -> t[1]^t[2], [0.0, 0.5]), [Inf, 0.0])
else
@test isequal(ForwardDiff.gradient(t -> t[1]^t[2], [0.0, -0.5]), [NaN, NaN])
@test isequal(ForwardDiff.gradient(t -> t[1]^t[2], [0.0, 0.0]), [NaN, NaN])
@test isequal(ForwardDiff.gradient(t -> t[1]^t[2], [0.0, 0.5]), [Inf, NaN])
end
@test isequal(ForwardDiff.gradient(t -> t[1]^t[2], [0.0, 1.5]), [0.0, 0.0])
end

Expand Down Expand Up @@ -207,11 +213,19 @@ end
end

@testset "gradient for exponential with NaNMath" begin
@test isnan(ForwardDiff.gradient(x -> NaNMath.pow(x[1],x[1]), [NaN, 1.0])[1])
if ForwardDiff.NANSAFE_MODE_ENABLED
@test isequal(ForwardDiff.gradient(x -> NaNMath.pow(x[1],x[2]), [NaN, 1.0]), [1.0, NaN])
else
@test isequal(ForwardDiff.gradient(x -> NaNMath.pow(x[1],x[2]), [NaN, 1.0]), [NaN, NaN])
end
@test ForwardDiff.gradient(x -> NaNMath.pow(x[1], x[2]), [1.0, 1.0]) == [1.0, 0.0]
@test isnan(ForwardDiff.gradient((x) -> NaNMath.pow(x[1], x[2]), [-1.0, 0.5])[1])

@test isnan(ForwardDiff.gradient(x -> x[1]^x[2], [NaN, 1.0])[1])
if ForwardDiff.NANSAFE_MODE_ENABLED
@test isequal(ForwardDiff.gradient(x -> x[1]^x[2], [NaN, 1.0]), [1.0, NaN])
else
@test isequal(ForwardDiff.gradient(x -> x[1]^x[2], [NaN, 1.0]), [NaN, NaN])
end
@test ForwardDiff.gradient(x -> x[1]^x[2], [1.0, 1.0]) == [1.0, 0.0]
@test_throws DomainError ForwardDiff.gradient(x -> x[1]^x[2], [-1.0, 0.5])
end
Expand Down Expand Up @@ -286,4 +300,34 @@ end
@test grad == SVector{3}(der, der, der)
end

@testset "NaN-safe mode" begin
    # Evaluate the gradient of `fn` at `pt` with the default configuration and
    # with explicit chunk sizes 1 and 2; every variant must agree.
    function allgradients(fn, pt)
        chunk1 = ForwardDiff.GradientConfig(fn, pt, ForwardDiff.Chunk{1}())
        chunk2 = ForwardDiff.GradientConfig(fn, pt, ForwardDiff.Chunk{2}())
        return (
            ForwardDiff.gradient(fn, pt),
            ForwardDiff.gradient(fn, pt, chunk1),
            ForwardDiff.gradient(fn, pt, chunk2),
        )
    end

    # issue #774
    f = x -> log(zero(x[1]) + x[2])
    for y in allgradients(f, [1.0, 0.0])
        if ForwardDiff.NANSAFE_MODE_ENABLED
            @test y == [0.0, Inf]
        else
            @test isequal(y, [NaN, Inf])
        end
    end

    # issue #745
    g = a -> a[1] * exp(-a[2])
    for b in allgradients(g, [1.0, -1e3])
        if ForwardDiff.NANSAFE_MODE_ENABLED
            @test b == [Inf, -Inf]
        else
            @test isequal(b, [NaN, NaN])
        end
    end
end

end # module
45 changes: 29 additions & 16 deletions test/PartialsTest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -111,24 +111,37 @@ samerng() = MersenneTwister(1)
@test (PARTIALS / X).values == map(v -> v / X, VALUES)

if N > 0
@test ForwardDiff._div_partials(PARTIALS, PARTIALS2, X, Y) == ForwardDiff._mul_partials(PARTIALS, PARTIALS2, inv(Y), -X/(Y^2))
@test ForwardDiff._mul_partials(PARTIALS, PARTIALS2, X, Y).values == map((a, b) -> (X * a) + (Y * b), VALUES, VALUES2)
@test ForwardDiff._mul_partials(ZERO_PARTIALS, PARTIALS, X, Y) == Y * PARTIALS
@test ForwardDiff._mul_partials(PARTIALS, ZERO_PARTIALS, X, Y) == X * PARTIALS
# Only zero partials
ALLZERO = Partials(ntuple(_ -> zero(T), N))
# Mix of zero and non-zero partials
FIRSTZERO = Partials(ntuple(i -> i == 1 ? zero(T) : rand(T), N))

# The following properties should always be satisfied, regardless of whether NaN-safe mode is enabled or disabled
# We use `isequal` for comparisons in the presence of `NaN`s
for p1 in (PARTIALS, ALLZERO, FIRSTZERO), p2 in (PARTIALS2, ALLZERO, FIRSTZERO), v1 in (X, NaN, Inf), v2 in (Y, NaN, Inf)
@test isequal(ForwardDiff._div_partials(p1, p2, v1, v2), ForwardDiff._mul_partials(p1, p2, inv(v2), -v1/(v2^2)))
@test isequal(ForwardDiff._mul_partials(p1, p2, v1, v2), v1 * p1 + v2 * p2)
end
for v1 in (X, NaN, Inf), v2 in (Y, NaN, Inf)
@test isequal(ForwardDiff._mul_partials(ZERO_PARTIALS, PARTIALS, v1, v2), v2 * PARTIALS)
@test isequal(ForwardDiff._mul_partials(PARTIALS, ZERO_PARTIALS, v1, v2), v1 * PARTIALS)
end

if ForwardDiff.NANSAFE_MODE_ENABLED
ZEROS = Partials((fill(zero(T), N)...,))

@test (NaN * ZEROS).values == ZEROS.values
@test (Inf * ZEROS).values == ZEROS.values
@test (ZEROS / 0).values == ZEROS.values

@test ForwardDiff._mul_partials(ZEROS, ZEROS, X, NaN).values == ZEROS.values
@test ForwardDiff._mul_partials(ZEROS, ZEROS, NaN, X).values == ZEROS.values
@test ForwardDiff._mul_partials(ZEROS, ZEROS, X, Inf).values == ZEROS.values
@test ForwardDiff._mul_partials(ZEROS, ZEROS, Inf, X).values == ZEROS.values
@test ForwardDiff._mul_partials(ZEROS, ZEROS, Inf, NaN).values == ZEROS.values
@test ForwardDiff._mul_partials(ZEROS, ZEROS, NaN, Inf).values == ZEROS.values
for f in ((p -> NaN * p), (p -> Inf * p), (p -> -Inf * p), (p -> p / 0), (p -> p / NaN), (p -> p / Inf), (p -> p / -Inf))
# Only zero partials
@test iszero(@inferred(f(ALLZERO)))

# Mix of zero and non-zero partials
z = @inferred(f(FIRSTZERO))
for i in 1:N
if iszero(FIRSTZERO[i])
@test iszero(z[i])
else
@test isequal(z[i], f(FIRSTZERO[i]))
end
end
end
end
end
end
Expand Down