Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "FiniteDifferences"
uuid = "26cc04aa-876d-5657-8c51-4c34ba976000"
version = "0.12.7"
version = "0.12.8"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
5 changes: 4 additions & 1 deletion src/difference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@ difference(::Real, ::T, ::T) where {T<:Integer} = NoTangent()

difference(ε::Real, y::T, x::T) where {T<:Number} = (y - x) / ε

difference(ε::Real, y::T, x::T) where {T<:StridedArray} = difference.(ε, y, x)
# we are a bit more relaced for AbstractArrays as they naturally represent a vector space
difference(ε::Real, y::AbstractArray, x) = difference.(ε, y, x)
# resolve ambiguity
difference(ε::Real, y::T, x::T) where {T<:AbstractArray} = difference.(ε, y, x)

function difference(ε::Real, y::T, x::T) where {T<:Tuple}
return Tangent{T}(difference.(ε, y, x)...)
Expand Down
4 changes: 3 additions & 1 deletion src/rand_tangent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,13 @@ rand_tangent(rng::AbstractRNG, x::Integer) = NoTangent()

rand_tangent(rng::AbstractRNG, x::T) where {T<:Number} = randn(rng, T)

# TODO: right now Julia don't allow `randn(rng, BigFloat)`
# TODO: right now Julia don't allow `randn(rng, BigFloat)`
# see: https://github.com/JuliaLang/julia/issues/17629
rand_tangent(rng::AbstractRNG, ::BigFloat) = big(randn(rng))

rand_tangent(rng::AbstractRNG, x::StridedArray) = rand_tangent.(Ref(rng), x)
rand_tangent(rng::AbstractRNG, x::Adjoint) = adjoint(rand_tangent(rng, parent(x)))
rand_tangent(rng::AbstractRNG, x::Transpose) = transpose(rand_tangent(rng, parent(x)))

function rand_tangent(rng::AbstractRNG, x::T) where {T<:Tuple}
return Tangent{T}(rand_tangent.(Ref(rng), x)...)
Expand Down
20 changes: 11 additions & 9 deletions test/difference.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,5 @@
using FiniteDifferences: rand_tangent, difference

function test_difference(ε::Real, x, dx)
y = x + ε * dx
dx_diff = difference(ε, y, x)
# TODO: `@test isapprox(dx, dx_diff)` once `isapprox` is defined appropriately
# see https://github.com/JuliaDiff/ChainRulesCore.jl/issues/184
@test typeof(dx) == typeof(dx_diff)
end

@testset "difference" begin

@testset "Primal: $(typeof(x))" for (ε, x) in [
Expand Down Expand Up @@ -56,7 +48,17 @@ end
(randn(), Adjoint(randn(ComplexF64, 3, 3))),
(randn(), Transpose(randn(3))),
]
test_difference(ε, x, rand_tangent(x))
# Construct a value that should be equal to the difference and check that it is
dx = rand_tangent(x)
y = x + ε * dx
dx_diff = difference(ε, y, x)

if x isa AbstractArray{<:Number} || x isa Number
@test x + dx ≈ x + dx_diff
else
# hard to check value if don't overload `≈` so for now we just check type
@test typeof(dx) == typeof(dx_diff)
end
end

# Ensure struct fallback errors for non-struct types.
Expand Down
26 changes: 15 additions & 11 deletions test/rand_tangent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@ using FiniteDifferences: rand_tangent
(randn(Complex{Float32}, 5, 4), Matrix{Complex{Float32}}),
([randn(5, 4), 4.0], Vector{Any}),

# Wrapper Arrays
(randn(5, 4)', Adjoint{Float64, Matrix{Float64}}),
(transpose(randn(5, 4)), Transpose{Float64, Matrix{Float64}}),


# Tuples.
((4.0, ), Tangent{Tuple{Float64}}),
((5.0, randn(3)), Tangent{Tuple{Float64, Vector{Float64}}}),
Expand Down Expand Up @@ -66,20 +71,19 @@ using FiniteDifferences: rand_tangent
Hermitian(randn(ComplexF64, 1, 1)),
Tangent{Hermitian{ComplexF64, Matrix{ComplexF64}}},
),
(
Adjoint(randn(ComplexF64, 3, 3)),
Tangent{Adjoint{ComplexF64, Matrix{ComplexF64}}},
),
(
Transpose(randn(3)),
Tangent{Transpose{Float64, Vector{Float64}}},
),
]
@test rand_tangent(rng, x) isa T_tangent
@test rand_tangent(x) isa T_tangent
@test x + rand_tangent(rng, x) isa typeof(x)
end

# Ensure struct fallback errors for non-struct types.
@test_throws ArgumentError invoke(rand_tangent, Tuple{AbstractRNG, Any}, rng, 5.0)
@testset "erroring cases" begin
# Ensure struct fallback errors for non-struct types.
@test_throws ArgumentError invoke(rand_tangent, Tuple{AbstractRNG, Any}, rng, 5.0)
end

@testset "compsition of addition" begin
x = Foo(1.5, 2, Foo(1.1, 3, [1.7, 1.4, 0.9]))
@test x + rand_tangent(x) isa typeof(x)
@test x + (rand_tangent(x) + rand_tangent(x)) isa typeof(x)
end
end