Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "FiniteDifferences"
uuid = "26cc04aa-876d-5657-8c51-4c34ba976000"
version = "0.12.15"
version = "0.12.16"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
2 changes: 1 addition & 1 deletion src/FiniteDifferences.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ using StaticArrays

export to_vec, grad, jacobian, jvp, j′vp

include("rand_tangent.jl")
include("deprecated.jl")
include("methods.jl")
include("numerics.jl")
include("to_vec.jl")
Expand Down
36 changes: 25 additions & 11 deletions src/rand_tangent.jl → src/deprecated.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
function depwarn_rt()
Base.depwarn(
"FiniteDifferences.rand_tangent is deprecated, it has moved to ChainRulesTestUtils",
:rand_tangent
)
end

"""
rand_tangent([rng::AbstractRNG,] x)
Expand All @@ -7,27 +14,33 @@ Rather it is an arbitary value, that is generated using the `rng`.
"""
rand_tangent(x) = rand_tangent(Random.GLOBAL_RNG, x)

rand_tangent(rng::AbstractRNG, x::Symbol) = NoTangent()
rand_tangent(rng::AbstractRNG, x::AbstractChar) = NoTangent()
rand_tangent(rng::AbstractRNG, x::AbstractString) = NoTangent()
rand_tangent(rng::AbstractRNG, x::Symbol) = (depwarn_rt(); NoTangent())
rand_tangent(rng::AbstractRNG, x::AbstractChar) = (depwarn_rt(); NoTangent())
rand_tangent(rng::AbstractRNG, x::AbstractString) = (depwarn_rt(); NoTangent())

rand_tangent(rng::AbstractRNG, x::Integer) = NoTangent()
rand_tangent(rng::AbstractRNG, x::Integer) = (depwarn_rt(); NoTangent())

# Try and make nice numbers with short decimal representations for good error messages
# while also not biasing the sample space too much
function rand_tangent(rng::AbstractRNG, x::T) where {T<:Number}
# multiply by 9 to give a bigger range of values tested: no so tightly clustered around 0.
depwarn_rt()
# multiply by 9 to give a bigger range of values tested:
# not so tightly clustered around 0.
return round(9 * randn(rng, T), sigdigits=5, base=2)
end
rand_tangent(rng::AbstractRNG, x::Float64) = rand(rng, -9:0.01:9)
rand_tangent(rng::AbstractRNG, x::Float64) = (depwarn_rt(); rand(rng, -9:0.01:9))
function rand_tangent(rng::AbstractRNG, x::ComplexF64)
depwarn_rt()
return ComplexF64(rand(rng, -9:0.1:9), rand(rng, -9:0.1:9))
end

#BigFloat/MPFR is finicky about short numbers, this doesn't always work as well as it should

# multiply by 9 to give a bigger range of values tested: no so tightly clustered around 0.
rand_tangent(rng::AbstractRNG, ::BigFloat) = round(big(9 * randn(rng)), sigdigits=5, base=2)
function rand_tangent(rng::AbstractRNG, ::BigFloat)
depwarn_rt()
# multiply by 9 to give a bigger range of values tested:
# not so tightly clustered around 0.
return round(big(9 * randn(rng)), sigdigits=5, base=2)
end

rand_tangent(rng::AbstractRNG, x::StridedArray{T, 0}) where {T} = fill(rand_tangent(x[1]))
rand_tangent(rng::AbstractRNG, x::StridedArray) = rand_tangent.(Ref(rng), x)
Expand All @@ -53,11 +66,12 @@ function rand_tangent(rng::AbstractRNG, x::T) where {T}
end
if all(tangent isa NoTangent for tangent in tangents)
# if none of my fields can be perturbed then I can't be perturbed
depwarn_rt()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should this be outside of the if statement so that both branches get it?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, the other branch calls rand_tangent, so will git it somewhere else.
Having it outside the if would result in it being printed multiple times

return NoTangent()
else
Tangent{T}(; NamedTuple{field_names}(tangents)...)
end
end

rand_tangent(rng::AbstractRNG, ::Type) = NoTangent()
rand_tangent(rng::AbstractRNG, ::Module) = NoTangent()
rand_tangent(rng::AbstractRNG, ::Type) = (depwarn_rt(); NoTangent())
rand_tangent(rng::AbstractRNG, ::Module) = (depwarn_rt(); NoTangent())
16 changes: 13 additions & 3 deletions test/rand_tangent.jl → test/deprecated.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
using FiniteDifferences: rand_tangent
# Test struct for `rand_tangent` and `difference`.
struct Foo
a::Float64
b::Int
c::Any
end

@testset "generate_tangent" begin
# to avoid deprecation spam (and actually test deprecations) we will define a wrapper `rand_tangent` function for testing
rand_tangent(args...) = @test_deprecated FiniteDifferences.rand_tangent(args...)

@testset "rand_tangent" begin
rng = MersenneTwister(123456)

@testset "Primal: $(typeof(x)), Tangent: $T_tangent" for (x, T_tangent) in [
Expand Down Expand Up @@ -89,7 +97,9 @@ using FiniteDifferences: rand_tangent

@testset "erroring cases" begin
# Ensure struct fallback errors for non-struct types.
@test_throws ArgumentError invoke(rand_tangent, Tuple{AbstractRNG, Any}, rng, 5.0)
@test_throws ArgumentError invoke(
FiniteDifferences.rand_tangent, Tuple{AbstractRNG, Any}, rng, 5.0
)
end

@testset "compsition of addition" begin
Expand Down
8 changes: 1 addition & 7 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,11 @@ using Random
using StaticArrays
using Test

# Test struct for `rand_tangent` and `difference`.
struct Foo
a::Float64
b::Int
c::Any
end

Random.seed!(1)

@testset "FiniteDifferences" begin
include("rand_tangent.jl")
include("deprecated.jl")
include("methods.jl")
include("numerics.jl")
include("to_vec.jl")
Expand Down