diff --git a/Project.toml b/Project.toml index eb38d70..03d210e 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "FiniteDifferences" uuid = "26cc04aa-876d-5657-8c51-4c34ba976000" -version = "0.12.7" +version = "0.12.8" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/difference.jl b/src/difference.jl index 502007e..a35ee4f 100644 --- a/src/difference.jl +++ b/src/difference.jl @@ -18,7 +18,10 @@ difference(::Real, ::T, ::T) where {T<:Integer} = NoTangent() difference(ε::Real, y::T, x::T) where {T<:Number} = (y - x) / ε -difference(ε::Real, y::T, x::T) where {T<:StridedArray} = difference.(ε, y, x) +# we are a bit more relaced for AbstractArrays as they naturally represent a vector space +difference(ε::Real, y::AbstractArray, x) = difference.(ε, y, x) +# resolve ambiguity +difference(ε::Real, y::T, x::T) where {T<:AbstractArray} = difference.(ε, y, x) function difference(ε::Real, y::T, x::T) where {T<:Tuple} return Tangent{T}(difference.(ε, y, x)...) diff --git a/src/rand_tangent.jl b/src/rand_tangent.jl index fef39ef..cdca3aa 100644 --- a/src/rand_tangent.jl +++ b/src/rand_tangent.jl @@ -13,11 +13,13 @@ rand_tangent(rng::AbstractRNG, x::Integer) = NoTangent() rand_tangent(rng::AbstractRNG, x::T) where {T<:Number} = randn(rng, T) -# TODO: right now Julia don't allow `randn(rng, BigFloat)` +# TODO: right now Julia don't allow `randn(rng, BigFloat)` # see: https://github.com/JuliaLang/julia/issues/17629 rand_tangent(rng::AbstractRNG, ::BigFloat) = big(randn(rng)) rand_tangent(rng::AbstractRNG, x::StridedArray) = rand_tangent.(Ref(rng), x) +rand_tangent(rng::AbstractRNG, x::Adjoint) = adjoint(rand_tangent(rng, parent(x))) +rand_tangent(rng::AbstractRNG, x::Transpose) = transpose(rand_tangent(rng, parent(x))) function rand_tangent(rng::AbstractRNG, x::T) where {T<:Tuple} return Tangent{T}(rand_tangent.(Ref(rng), x)...) diff --git a/test/difference.jl b/test/difference.jl index ba4f147..3b906be 100644 --- a/test/difference.jl +++ b/test/difference.jl @@ -1,13 +1,5 @@ using FiniteDifferences: rand_tangent, difference -function test_difference(ε::Real, x, dx) - y = x + ε * dx - dx_diff = difference(ε, y, x) - # TODO: `@test isapprox(dx, dx_diff)` once `isapprox` is defined appropriately - # see https://github.com/JuliaDiff/ChainRulesCore.jl/issues/184 - @test typeof(dx) == typeof(dx_diff) -end - @testset "difference" begin @testset "Primal: $(typeof(x))" for (ε, x) in [ @@ -56,7 +48,17 @@ end (randn(), Adjoint(randn(ComplexF64, 3, 3))), (randn(), Transpose(randn(3))), ] - test_difference(ε, x, rand_tangent(x)) + # Construct a value that should be equal to the difference and check that it is + dx = rand_tangent(x) + y = x + ε * dx + dx_diff = difference(ε, y, x) + + if x isa AbstractArray{<:Number} || x isa Number + @test x + dx ≈ x + dx_diff + else + # hard to check value if don't overload `≈` so for now we just check type + @test typeof(dx) == typeof(dx_diff) + end end # Ensure struct fallback errors for non-struct types. diff --git a/test/rand_tangent.jl b/test/rand_tangent.jl index 6f2b2f9..876f91c 100644 --- a/test/rand_tangent.jl +++ b/test/rand_tangent.jl @@ -24,6 +24,11 @@ using FiniteDifferences: rand_tangent (randn(Complex{Float32}, 5, 4), Matrix{Complex{Float32}}), ([randn(5, 4), 4.0], Vector{Any}), + # Wrapper Arrays + (randn(5, 4)', Adjoint{Float64, Matrix{Float64}}), + (transpose(randn(5, 4)), Transpose{Float64, Matrix{Float64}}), + + # Tuples. ((4.0, ), Tangent{Tuple{Float64}}), ((5.0, randn(3)), Tangent{Tuple{Float64, Vector{Float64}}}), @@ -66,20 +71,19 @@ using FiniteDifferences: rand_tangent Hermitian(randn(ComplexF64, 1, 1)), Tangent{Hermitian{ComplexF64, Matrix{ComplexF64}}}, ), - ( - Adjoint(randn(ComplexF64, 3, 3)), - Tangent{Adjoint{ComplexF64, Matrix{ComplexF64}}}, - ), - ( - Transpose(randn(3)), - Tangent{Transpose{Float64, Vector{Float64}}}, - ), ] @test rand_tangent(rng, x) isa T_tangent @test rand_tangent(x) isa T_tangent - @test x + rand_tangent(rng, x) isa typeof(x) end - # Ensure struct fallback errors for non-struct types. - @test_throws ArgumentError invoke(rand_tangent, Tuple{AbstractRNG, Any}, rng, 5.0) + @testset "erroring cases" begin + # Ensure struct fallback errors for non-struct types. + @test_throws ArgumentError invoke(rand_tangent, Tuple{AbstractRNG, Any}, rng, 5.0) + end + + @testset "compsition of addition" begin + x = Foo(1.5, 2, Foo(1.1, 3, [1.7, 1.4, 0.9])) + @test x + rand_tangent(x) isa typeof(x) + @test x + (rand_tangent(x) + rand_tangent(x)) isa typeof(x) + end end