diff --git a/.gitignore b/.gitignore index 6caafe5..084bd37 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,7 @@ *.jl.*.cov *.jl.mem /Manifest.toml +dev/ # Docs: docs/build/ diff --git a/Project.toml b/Project.toml index 95b00a9..eb38d70 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "FiniteDifferences" uuid = "26cc04aa-876d-5657-8c51-4c34ba976000" -version = "0.12.6" +version = "0.12.7" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" @@ -11,7 +11,7 @@ Richardson = "708f8203-808e-40c0-ba2d-98a6953ed40d" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" [compat] -ChainRulesCore = "0.9" +ChainRulesCore = "0.9.44" Richardson = "1.2" StaticArrays = "0.12, 1.0" julia = "1" diff --git a/docs/Manifest.toml b/docs/Manifest.toml index db7ead3..8f4ef47 100644 --- a/docs/Manifest.toml +++ b/docs/Manifest.toml @@ -1,27 +1,43 @@ # This file is machine-generated - editing it directly is not advised +[[ArgTools]] +uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" + +[[Artifacts]] +uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" + [[Base64]] uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" [[ChainRulesCore]] -deps = ["LinearAlgebra", "MuladdMacro", "SparseArrays"] -git-tree-sha1 = "15081c431bb25848ad9b0d172a65794f3a3e197a" +deps = ["Compat", "LinearAlgebra", "SparseArrays"] +git-tree-sha1 = "b391f22252b8754f4440de1f37ece49d8a7314bb" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "0.9.24" +version = "0.9.44" + +[[Compat]] +deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"] +git-tree-sha1 = "e4e2b39db08f967cc1360951f01e8a75ec441cab" +uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" +version = "3.30.0" [[Dates]] deps = ["Printf"] uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" +[[DelimitedFiles]] +deps = ["Mmap"] +uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab" + [[Distributed]] deps = ["Random", "Serialization", "Sockets"] uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" [[DocStringExtensions]] deps = ["LibGit2", "Markdown", "Pkg", "Test"] -git-tree-sha1 = "50ddf44c53698f5e784bbebb3f4b21c5807401b1" +git-tree-sha1 = "9d4f64f79012636741cf01133158a54b24924c32" uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" -version = "0.8.3" +version = "0.8.4" [[Documenter]] deps = ["Base64", "Dates", "DocStringExtensions", "IOCapture", "InteractiveUtils", "JSON", "LibGit2", "Logging", "Markdown", "REPL", "Test", "Unicode"] @@ -29,11 +45,15 @@ git-tree-sha1 = "a4875e0763112d6d017126f3944f4133abb342ae" uuid = "e30172f5-a6a5-5a46-863b-614d45cd2de4" version = "0.25.5" +[[Downloads]] +deps = ["ArgTools", "LibCURL", "NetworkOptions"] +uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" + [[FiniteDifferences]] deps = ["ChainRulesCore", "LinearAlgebra", "Printf", "Random", "Richardson", "StaticArrays"] path = ".." uuid = "26cc04aa-876d-5657-8c51-4c34ba976000" -version = "0.11.5" +version = "0.12.7" [[IOCapture]] deps = ["Logging"] @@ -51,10 +71,22 @@ git-tree-sha1 = "81690084b6198a2e1da36fcfda16eeca9f9f24e4" uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" version = "0.21.1" +[[LibCURL]] +deps = ["LibCURL_jll", "MozillaCACerts_jll"] +uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21" + +[[LibCURL_jll]] +deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"] +uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0" + [[LibGit2]] -deps = ["Printf"] +deps = ["Base64", "NetworkOptions", "Printf", "SHA"] uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" +[[LibSSH2_jll]] +deps = ["Artifacts", "Libdl", "MbedTLS_jll"] +uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8" + [[Libdl]] uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" @@ -69,22 +101,27 @@ uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" deps = ["Base64"] uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" +[[MbedTLS_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" + [[Mmap]] uuid = "a63ad114-7e13-5084-954f-fe012c677804" -[[MuladdMacro]] -git-tree-sha1 = "c6190f9a7fc5d9d5915ab29f2134421b12d24a68" -uuid = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221" -version = "0.2.2" +[[MozillaCACerts_jll]] +uuid = "14a3606d-f60d-562e-9121-12d972cd8159" + +[[NetworkOptions]] +uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" [[Parsers]] deps = ["Dates"] -git-tree-sha1 = "50c9a9ed8c714945e01cd53a21007ed3865ed714" +git-tree-sha1 = "c8abc88faa3f7a3950832ac5d6e690881590d6dc" uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" -version = "1.0.15" +version = "1.1.0" [[Pkg]] -deps = ["Dates", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"] +deps = ["Artifacts", "Dates", "Downloads", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" [[Printf]] @@ -92,7 +129,7 @@ deps = ["Unicode"] uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" [[REPL]] -deps = ["InteractiveUtils", "Markdown", "Sockets"] +deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"] uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" [[Random]] @@ -111,6 +148,10 @@ uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" [[Serialization]] uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" +[[SharedArrays]] +deps = ["Distributed", "Mmap", "Random", "Serialization"] +uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383" + [[Sockets]] uuid = "6462fe0b-24de-5631-8697-dd941f90decc" @@ -120,16 +161,24 @@ uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" [[StaticArrays]] deps = ["LinearAlgebra", "Random", "Statistics"] -git-tree-sha1 = "9da72ed50e94dbff92036da395275ed114e04d49" +git-tree-sha1 = "c635017268fd51ed944ec429bcc4ad010bcea900" uuid = "90137ffa-7385-5640-81b9-e52037218182" -version = "1.0.1" +version = "1.2.0" [[Statistics]] deps = ["LinearAlgebra", "SparseArrays"] uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +[[TOML]] +deps = ["Dates"] +uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76" + +[[Tar]] +deps = ["ArgTools", "SHA"] +uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" + [[Test]] -deps = ["Distributed", "InteractiveUtils", "Logging", "Random"] +deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [[UUIDs]] @@ -138,3 +187,15 @@ uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" [[Unicode]] uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" + +[[Zlib_jll]] +deps = ["Libdl"] +uuid = "83775a58-1f1d-513f-b197-d71354ab007a" + +[[nghttp2_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d" + +[[p7zip_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0" diff --git a/src/difference.jl b/src/difference.jl index 1cfe76d..502007e 100644 --- a/src/difference.jl +++ b/src/difference.jl @@ -11,21 +11,21 @@ If `(y - x) / ε` is defined, then this operation is equivalent to doing that. F where these operations aren't defined, `difference` can still be defined without commiting type piracy while `-` and `/` cannot. """ -difference(::Real, ::T, ::T) where {T<:Symbol} = DoesNotExist() -difference(::Real, ::T, ::T) where {T<:AbstractChar} = DoesNotExist() -difference(::Real, ::T, ::T) where {T<:AbstractString} = DoesNotExist() -difference(::Real, ::T, ::T) where {T<:Integer} = DoesNotExist() +difference(::Real, ::T, ::T) where {T<:Symbol} = NoTangent() +difference(::Real, ::T, ::T) where {T<:AbstractChar} = NoTangent() +difference(::Real, ::T, ::T) where {T<:AbstractString} = NoTangent() +difference(::Real, ::T, ::T) where {T<:Integer} = NoTangent() difference(ε::Real, y::T, x::T) where {T<:Number} = (y - x) / ε difference(ε::Real, y::T, x::T) where {T<:StridedArray} = difference.(ε, y, x) function difference(ε::Real, y::T, x::T) where {T<:Tuple} - return Composite{T}(difference.(ε, y, x)...) + return Tangent{T}(difference.(ε, y, x)...) end function difference(ε::Real, ys::T, xs::T) where {T<:NamedTuple} - return Composite{T}(; map((y, x) -> difference(ε, y, x), ys, xs)...) + return Tangent{T}(; map((y, x) -> difference(ε, y, x), ys, xs)...) end function difference(ε::Real, y::T, x::T) where {T} @@ -38,7 +38,7 @@ function difference(ε::Real, y::T, x::T) where {T} tangents = map(field_names) do field_name difference(ε, getfield(y, field_name), getfield(x, field_name)) end - return Composite{T}(; NamedTuple{field_names}(tangents)...) + return Tangent{T}(; NamedTuple{field_names}(tangents)...) else return NO_FIELDS end diff --git a/src/rand_tangent.jl b/src/rand_tangent.jl index c6967a9..fef39ef 100644 --- a/src/rand_tangent.jl +++ b/src/rand_tangent.jl @@ -5,11 +5,11 @@ Returns a randomly generated tangent vector appropriate for the primal value `x` """ rand_tangent(x) = rand_tangent(Random.GLOBAL_RNG, x) -rand_tangent(rng::AbstractRNG, x::Symbol) = DoesNotExist() -rand_tangent(rng::AbstractRNG, x::AbstractChar) = DoesNotExist() -rand_tangent(rng::AbstractRNG, x::AbstractString) = DoesNotExist() +rand_tangent(rng::AbstractRNG, x::Symbol) = NoTangent() +rand_tangent(rng::AbstractRNG, x::AbstractChar) = NoTangent() +rand_tangent(rng::AbstractRNG, x::AbstractString) = NoTangent() -rand_tangent(rng::AbstractRNG, x::Integer) = DoesNotExist() +rand_tangent(rng::AbstractRNG, x::Integer) = NoTangent() rand_tangent(rng::AbstractRNG, x::T) where {T<:Number} = randn(rng, T) @@ -20,11 +20,11 @@ rand_tangent(rng::AbstractRNG, ::BigFloat) = big(randn(rng)) rand_tangent(rng::AbstractRNG, x::StridedArray) = rand_tangent.(Ref(rng), x) function rand_tangent(rng::AbstractRNG, x::T) where {T<:Tuple} - return Composite{T}(rand_tangent.(Ref(rng), x)...) + return Tangent{T}(rand_tangent.(Ref(rng), x)...) end function rand_tangent(rng::AbstractRNG, xs::T) where {T<:NamedTuple} - return Composite{T}(; map(x -> rand_tangent(rng, x), xs)...) + return Tangent{T}(; map(x -> rand_tangent(rng, x), xs)...) end function rand_tangent(rng::AbstractRNG, x::T) where {T} @@ -37,11 +37,11 @@ function rand_tangent(rng::AbstractRNG, x::T) where {T} tangents = map(field_names) do field_name rand_tangent(rng, getfield(x, field_name)) end - if all(tangent isa DoesNotExist for tangent in tangents) + if all(tangent isa NoTangent for tangent in tangents) # if none of my fields can be perturbed then I can't be perturbed - return DoesNotExist() + return NoTangent() else - Composite{T}(; NamedTuple{field_names}(tangents)...) + Tangent{T}(; NamedTuple{field_names}(tangents)...) end else return NO_FIELDS diff --git a/src/to_vec.jl b/src/to_vec.jl index 79c7d7c..3395106 100644 --- a/src/to_vec.jl +++ b/src/to_vec.jl @@ -176,15 +176,15 @@ end # ChainRulesCore Differentials -function FiniteDifferences.to_vec(x::Composite{P}) where{P} +function FiniteDifferences.to_vec(x::Tangent{P}) where{P} x_canon = canonicalize(x) # to be safe, fill in every field and put in primal order. x_inner = ChainRulesCore.backing(x_canon) x_vec, back_inner = FiniteDifferences.to_vec(x_inner) - function Composite_from_vec(y_vec) + function Tangent_from_vec(y_vec) y_back = back_inner(y_vec) - return Composite{P, typeof(y_back)}(y_back) + return Tangent{P, typeof(y_back)}(y_back) end - return x_vec, Composite_from_vec + return x_vec, Tangent_from_vec end function FiniteDifferences.to_vec(x::AbstractZero) diff --git a/test/rand_tangent.jl b/test/rand_tangent.jl index 3104b67..6f2b2f9 100644 --- a/test/rand_tangent.jl +++ b/test/rand_tangent.jl @@ -6,11 +6,11 @@ using FiniteDifferences: rand_tangent @testset "Primal: $(typeof(x)), Tangent: $T_tangent" for (x, T_tangent) in [ # Things without sensible tangents. - ("hi", DoesNotExist), - ('a', DoesNotExist), - (:a, DoesNotExist), - (true, DoesNotExist), - (4, DoesNotExist), + ("hi", NoTangent), + ('a', NoTangent), + (:a, NoTangent), + (true, NoTangent), + (4, NoTangent), # Numbers. (5.0, Float64), @@ -25,54 +25,54 @@ using FiniteDifferences: rand_tangent ([randn(5, 4), 4.0], Vector{Any}), # Tuples. - ((4.0, ), Composite{Tuple{Float64}}), - ((5.0, randn(3)), Composite{Tuple{Float64, Vector{Float64}}}), + ((4.0, ), Tangent{Tuple{Float64}}), + ((5.0, randn(3)), Tangent{Tuple{Float64, Vector{Float64}}}), # NamedTuples. - ((a=4.0, ), Composite{NamedTuple{(:a,), Tuple{Float64}}}), - ((a=5.0, b=1), Composite{NamedTuple{(:a, :b), Tuple{Float64, Int}}}), + ((a=4.0, ), Tangent{NamedTuple{(:a,), Tuple{Float64}}}), + ((a=5.0, b=1), Tangent{NamedTuple{(:a, :b), Tuple{Float64, Int}}}), # structs. - (Foo(5.0, 4, rand(rng, 3)), Composite{Foo}), - (Foo(4.0, 3, Foo(5.0, 2, 4)), Composite{Foo}), + (Foo(5.0, 4, rand(rng, 3)), Tangent{Foo}), + (Foo(4.0, 3, Foo(5.0, 2, 4)), Tangent{Foo}), (sin, typeof(NO_FIELDS)), - # all fields DoesNotExist implies DoesNotExist - (Pair(:a, "b"), DoesNotExist), - (1:10, DoesNotExist), - (1:2:10, DoesNotExist), + # all fields NoTangent implies NoTangent + (Pair(:a, "b"), NoTangent), + (1:10, NoTangent), + (1:2:10, NoTangent), # LinearAlgebra types (also just structs). ( UpperTriangular(randn(3, 3)), - Composite{UpperTriangular{Float64, Matrix{Float64}}}, + Tangent{UpperTriangular{Float64, Matrix{Float64}}}, ), ( Diagonal(randn(2)), - Composite{Diagonal{Float64, Vector{Float64}}}, + Tangent{Diagonal{Float64, Vector{Float64}}}, ), ( SVector{2, Float64}(1.0, 2.0), - Composite{typeof(SVector{2, Float64}(1.0, 2.0))}, + Tangent{typeof(SVector{2, Float64}(1.0, 2.0))}, ), ( SMatrix{2, 2, ComplexF64}(1.0, 2.0, 3.0, 4.0), - Composite{typeof(SMatrix{2, 2, ComplexF64}(1.0, 2.0, 3.0, 4.0))}, + Tangent{typeof(SMatrix{2, 2, ComplexF64}(1.0, 2.0, 3.0, 4.0))}, ), ( Symmetric(randn(2, 2)), - Composite{Symmetric{Float64, Matrix{Float64}}}, + Tangent{Symmetric{Float64, Matrix{Float64}}}, ), ( Hermitian(randn(ComplexF64, 1, 1)), - Composite{Hermitian{ComplexF64, Matrix{ComplexF64}}}, + Tangent{Hermitian{ComplexF64, Matrix{ComplexF64}}}, ), ( Adjoint(randn(ComplexF64, 3, 3)), - Composite{Adjoint{ComplexF64, Matrix{ComplexF64}}}, + Tangent{Adjoint{ComplexF64, Matrix{ComplexF64}}}, ), ( Transpose(randn(3)), - Composite{Transpose{Float64, Vector{Float64}}}, + Tangent{Transpose{Float64, Vector{Float64}}}, ), ] @test rand_tangent(rng, x) isa T_tangent diff --git a/test/to_vec.jl b/test/to_vec.jl index f8c4cde..abf1019 100644 --- a/test/to_vec.jl +++ b/test/to_vec.jl @@ -20,7 +20,7 @@ end Base.size(x::FillVector) = (x.len,) Base.getindex(x::FillVector, n::Int) = x.x -# For testing Composite{ThreeFields} +# For testing Tangent{ThreeFields} struct ThreeFields a b @@ -142,39 +142,39 @@ end end @testset "ChainRulesCore Differentials" begin - @testset "Composite{Tuple}" begin + @testset "Tangent{Tuple}" begin @testset "basic" begin x_tup = (1.0, 2.0, 3.0) - x_comp = Composite{typeof(x_tup)}(x_tup...) + x_comp = Tangent{typeof(x_tup)}(x_tup...) test_to_vec(x_comp) end @testset "nested" begin x_inner = (2, 3) x_outer = (1, x_inner) - x_comp = Composite{typeof(x_outer)}(1, Composite{typeof(x_inner)}(2, 3)) + x_comp = Tangent{typeof(x_outer)}(1, Tangent{typeof(x_inner)}(2, 3)) test_to_vec(x_comp; check_inferred=false) end end - @testset "Composite Struct" begin + @testset "Tangent Struct" begin @testset "NamedTuple basic" begin nt = (; a=1.0, b=20.0) - comp = Composite{typeof(nt)}(; nt...) + comp = Tangent{typeof(nt)}(; nt...) test_to_vec(comp) end @testset "Struct" begin - test_to_vec(Composite{ThreeFields}(; a=10.0, b=20.0, c=30.0)) - test_to_vec(Composite{ThreeFields}(; a=10.0, b=20.0,)) # broken on Julia 1.6.0, fixed on 1.6.1 - test_to_vec(Composite{ThreeFields}(; a=10.0, c=30.0)) - test_to_vec(Composite{ThreeFields}(; c=30.0, a=10.0, b=20.0)) + test_to_vec(Tangent{ThreeFields}(; a=10.0, b=20.0, c=30.0)) + test_to_vec(Tangent{ThreeFields}(; a=10.0, b=20.0,)) # broken on Julia 1.6.0, fixed on 1.6.1 + test_to_vec(Tangent{ThreeFields}(; a=10.0, c=30.0)) + test_to_vec(Tangent{ThreeFields}(; c=30.0, a=10.0, b=20.0)) end end @testset "AbstractZero" begin - test_to_vec(Zero()) - test_to_vec(DoesNotExist()) + test_to_vec(ZeroTangent()) + test_to_vec(NoTangent()) end end