Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
*.jl.*.cov
*.jl.mem
/Manifest.toml
dev/

# Docs:
docs/build/
Expand Down
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "FiniteDifferences"
uuid = "26cc04aa-876d-5657-8c51-4c34ba976000"
version = "0.12.6"
version = "0.12.7"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand All @@ -11,7 +11,7 @@ Richardson = "708f8203-808e-40c0-ba2d-98a6953ed40d"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"

[compat]
ChainRulesCore = "0.9"
ChainRulesCore = "0.9.44"
Richardson = "1.2"
StaticArrays = "0.12, 1.0"
julia = "1"
Expand Down
97 changes: 79 additions & 18 deletions docs/Manifest.toml
Original file line number Diff line number Diff line change
@@ -1,39 +1,59 @@
# This file is machine-generated - editing it directly is not advised

[[ArgTools]]
uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f"

[[Artifacts]]
uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"

[[Base64]]
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"

[[ChainRulesCore]]
deps = ["LinearAlgebra", "MuladdMacro", "SparseArrays"]
git-tree-sha1 = "15081c431bb25848ad9b0d172a65794f3a3e197a"
deps = ["Compat", "LinearAlgebra", "SparseArrays"]
git-tree-sha1 = "b391f22252b8754f4440de1f37ece49d8a7314bb"
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
version = "0.9.24"
version = "0.9.44"

[[Compat]]
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
git-tree-sha1 = "e4e2b39db08f967cc1360951f01e8a75ec441cab"
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
version = "3.30.0"

[[Dates]]
deps = ["Printf"]
uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"

[[DelimitedFiles]]
deps = ["Mmap"]
uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab"

[[Distributed]]
deps = ["Random", "Serialization", "Sockets"]
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"

[[DocStringExtensions]]
deps = ["LibGit2", "Markdown", "Pkg", "Test"]
git-tree-sha1 = "50ddf44c53698f5e784bbebb3f4b21c5807401b1"
git-tree-sha1 = "9d4f64f79012636741cf01133158a54b24924c32"
uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
version = "0.8.3"
version = "0.8.4"

[[Documenter]]
deps = ["Base64", "Dates", "DocStringExtensions", "IOCapture", "InteractiveUtils", "JSON", "LibGit2", "Logging", "Markdown", "REPL", "Test", "Unicode"]
git-tree-sha1 = "a4875e0763112d6d017126f3944f4133abb342ae"
uuid = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
version = "0.25.5"

[[Downloads]]
deps = ["ArgTools", "LibCURL", "NetworkOptions"]
uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6"

[[FiniteDifferences]]
deps = ["ChainRulesCore", "LinearAlgebra", "Printf", "Random", "Richardson", "StaticArrays"]
path = ".."
uuid = "26cc04aa-876d-5657-8c51-4c34ba976000"
version = "0.11.5"
version = "0.12.7"

[[IOCapture]]
deps = ["Logging"]
Expand All @@ -51,10 +71,22 @@ git-tree-sha1 = "81690084b6198a2e1da36fcfda16eeca9f9f24e4"
uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
version = "0.21.1"

[[LibCURL]]
deps = ["LibCURL_jll", "MozillaCACerts_jll"]
uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21"

[[LibCURL_jll]]
deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"]
uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0"

[[LibGit2]]
deps = ["Printf"]
deps = ["Base64", "NetworkOptions", "Printf", "SHA"]
uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"

[[LibSSH2_jll]]
deps = ["Artifacts", "Libdl", "MbedTLS_jll"]
uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8"

[[Libdl]]
uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"

Expand All @@ -69,30 +101,35 @@ uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
deps = ["Base64"]
uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"

[[MbedTLS_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1"

[[Mmap]]
uuid = "a63ad114-7e13-5084-954f-fe012c677804"

[[MuladdMacro]]
git-tree-sha1 = "c6190f9a7fc5d9d5915ab29f2134421b12d24a68"
uuid = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221"
version = "0.2.2"
[[MozillaCACerts_jll]]
uuid = "14a3606d-f60d-562e-9121-12d972cd8159"

[[NetworkOptions]]
uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908"

[[Parsers]]
deps = ["Dates"]
git-tree-sha1 = "50c9a9ed8c714945e01cd53a21007ed3865ed714"
git-tree-sha1 = "c8abc88faa3f7a3950832ac5d6e690881590d6dc"
uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0"
version = "1.0.15"
version = "1.1.0"

[[Pkg]]
deps = ["Dates", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"]
deps = ["Artifacts", "Dates", "Downloads", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"]
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"

[[Printf]]
deps = ["Unicode"]
uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7"

[[REPL]]
deps = ["InteractiveUtils", "Markdown", "Sockets"]
deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"]
uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"

[[Random]]
Expand All @@ -111,6 +148,10 @@ uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"
[[Serialization]]
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"

[[SharedArrays]]
deps = ["Distributed", "Mmap", "Random", "Serialization"]
uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383"

[[Sockets]]
uuid = "6462fe0b-24de-5631-8697-dd941f90decc"

Expand All @@ -120,16 +161,24 @@ uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

[[StaticArrays]]
deps = ["LinearAlgebra", "Random", "Statistics"]
git-tree-sha1 = "9da72ed50e94dbff92036da395275ed114e04d49"
git-tree-sha1 = "c635017268fd51ed944ec429bcc4ad010bcea900"
uuid = "90137ffa-7385-5640-81b9-e52037218182"
version = "1.0.1"
version = "1.2.0"

[[Statistics]]
deps = ["LinearAlgebra", "SparseArrays"]
uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[[TOML]]
deps = ["Dates"]
uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76"

[[Tar]]
deps = ["ArgTools", "SHA"]
uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e"

[[Test]]
deps = ["Distributed", "InteractiveUtils", "Logging", "Random"]
deps = ["InteractiveUtils", "Logging", "Random", "Serialization"]
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[[UUIDs]]
Expand All @@ -138,3 +187,15 @@ uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"

[[Unicode]]
uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"

[[Zlib_jll]]
deps = ["Libdl"]
uuid = "83775a58-1f1d-513f-b197-d71354ab007a"

[[nghttp2_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d"

[[p7zip_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0"
14 changes: 7 additions & 7 deletions src/difference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,21 @@ If `(y - x) / ε` is defined, then this operation is equivalent to doing that. F
where these operations aren't defined, `difference` can still be defined without commiting
type piracy while `-` and `/` cannot.
"""
difference(::Real, ::T, ::T) where {T<:Symbol} = DoesNotExist()
difference(::Real, ::T, ::T) where {T<:AbstractChar} = DoesNotExist()
difference(::Real, ::T, ::T) where {T<:AbstractString} = DoesNotExist()
difference(::Real, ::T, ::T) where {T<:Integer} = DoesNotExist()
difference(::Real, ::T, ::T) where {T<:Symbol} = NoTangent()
difference(::Real, ::T, ::T) where {T<:AbstractChar} = NoTangent()
difference(::Real, ::T, ::T) where {T<:AbstractString} = NoTangent()
difference(::Real, ::T, ::T) where {T<:Integer} = NoTangent()

difference(ε::Real, y::T, x::T) where {T<:Number} = (y - x) / ε

difference(ε::Real, y::T, x::T) where {T<:StridedArray} = difference.(ε, y, x)

function difference(ε::Real, y::T, x::T) where {T<:Tuple}
return Composite{T}(difference.(ε, y, x)...)
return Tangent{T}(difference.(ε, y, x)...)
end

function difference(ε::Real, ys::T, xs::T) where {T<:NamedTuple}
return Composite{T}(; map((y, x) -> difference(ε, y, x), ys, xs)...)
return Tangent{T}(; map((y, x) -> difference(ε, y, x), ys, xs)...)
end

function difference(ε::Real, y::T, x::T) where {T}
Expand All @@ -38,7 +38,7 @@ function difference(ε::Real, y::T, x::T) where {T}
tangents = map(field_names) do field_name
difference(ε, getfield(y, field_name), getfield(x, field_name))
end
return Composite{T}(; NamedTuple{field_names}(tangents)...)
return Tangent{T}(; NamedTuple{field_names}(tangents)...)
else
return NO_FIELDS
end
Expand Down
18 changes: 9 additions & 9 deletions src/rand_tangent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@ Returns a randomly generated tangent vector appropriate for the primal value `x`
"""
rand_tangent(x) = rand_tangent(Random.GLOBAL_RNG, x)

rand_tangent(rng::AbstractRNG, x::Symbol) = DoesNotExist()
rand_tangent(rng::AbstractRNG, x::AbstractChar) = DoesNotExist()
rand_tangent(rng::AbstractRNG, x::AbstractString) = DoesNotExist()
rand_tangent(rng::AbstractRNG, x::Symbol) = NoTangent()
rand_tangent(rng::AbstractRNG, x::AbstractChar) = NoTangent()
rand_tangent(rng::AbstractRNG, x::AbstractString) = NoTangent()

rand_tangent(rng::AbstractRNG, x::Integer) = DoesNotExist()
rand_tangent(rng::AbstractRNG, x::Integer) = NoTangent()

rand_tangent(rng::AbstractRNG, x::T) where {T<:Number} = randn(rng, T)

Expand All @@ -20,11 +20,11 @@ rand_tangent(rng::AbstractRNG, ::BigFloat) = big(randn(rng))
rand_tangent(rng::AbstractRNG, x::StridedArray) = rand_tangent.(Ref(rng), x)

function rand_tangent(rng::AbstractRNG, x::T) where {T<:Tuple}
return Composite{T}(rand_tangent.(Ref(rng), x)...)
return Tangent{T}(rand_tangent.(Ref(rng), x)...)
end

function rand_tangent(rng::AbstractRNG, xs::T) where {T<:NamedTuple}
return Composite{T}(; map(x -> rand_tangent(rng, x), xs)...)
return Tangent{T}(; map(x -> rand_tangent(rng, x), xs)...)
end

function rand_tangent(rng::AbstractRNG, x::T) where {T}
Expand All @@ -37,11 +37,11 @@ function rand_tangent(rng::AbstractRNG, x::T) where {T}
tangents = map(field_names) do field_name
rand_tangent(rng, getfield(x, field_name))
end
if all(tangent isa DoesNotExist for tangent in tangents)
if all(tangent isa NoTangent for tangent in tangents)
# if none of my fields can be perturbed then I can't be perturbed
return DoesNotExist()
return NoTangent()
else
Composite{T}(; NamedTuple{field_names}(tangents)...)
Tangent{T}(; NamedTuple{field_names}(tangents)...)
end
else
return NO_FIELDS
Expand Down
8 changes: 4 additions & 4 deletions src/to_vec.jl
Original file line number Diff line number Diff line change
Expand Up @@ -176,15 +176,15 @@ end


# ChainRulesCore Differentials
function FiniteDifferences.to_vec(x::Composite{P}) where{P}
function FiniteDifferences.to_vec(x::Tangent{P}) where{P}
x_canon = canonicalize(x) # to be safe, fill in every field and put in primal order.
x_inner = ChainRulesCore.backing(x_canon)
x_vec, back_inner = FiniteDifferences.to_vec(x_inner)
function Composite_from_vec(y_vec)
function Tangent_from_vec(y_vec)
y_back = back_inner(y_vec)
return Composite{P, typeof(y_back)}(y_back)
return Tangent{P, typeof(y_back)}(y_back)
end
return x_vec, Composite_from_vec
return x_vec, Tangent_from_vec
end

function FiniteDifferences.to_vec(x::AbstractZero)
Expand Down
Loading