diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index ed8067a9..96e03ea0 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -18,8 +18,8 @@ jobs: - '1.6' - '1.8' - '1.9' + - '1.10' - '1' # Leave this line unchanged. '1' will automatically expand to the latest stable 1.x release of Julia. - - '1.10.0-beta3' os: - ubuntu-latest arch: diff --git a/Project.toml b/Project.toml index aaa2d476..33ba6269 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ComponentArrays" uuid = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" authors = ["Jonnie Diegelman <47193959+jonniedie@users.noreply.github.com>"] -version = "0.15.10" +version = "0.15.11" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" @@ -17,20 +17,24 @@ StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" +Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" +TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] ComponentArraysAdaptExt = "Adapt" ComponentArraysConstructionBaseExt = "ConstructionBase" ComponentArraysGPUArraysExt = "GPUArrays" +ComponentArraysOptimisersExt = "Optimisers" ComponentArraysRecursiveArrayToolsExt = "RecursiveArrayTools" ComponentArraysReverseDiffExt = "ReverseDiff" ComponentArraysSciMLBaseExt = "SciMLBase" ComponentArraysTrackerExt = "Tracker" +ComponentArraysTruncatedStacktracesExt = "TruncatedStacktraces" ComponentArraysZygoteExt = "Zygote" [compat] @@ -41,13 +45,17 @@ ConstructionBase = "1" ForwardDiff = "0.10" Functors = "0.4.4" GPUArrays = "8, 9, 10" +LinearAlgebra = "1" +Optimisers = "0.3" PackageExtensionCompat = "1" RecursiveArrayTools = "2, 3" ReverseDiff = "1" SciMLBase = "1, 2" -StaticArraysCore = "1" StaticArrayInterface = "1" +StaticArraysCore = "1" +Test = "1" Tracker = "0.2" +TruncatedStacktraces = "1.4" Zygote = "0.6" julia = "1.6" @@ -56,9 +64,11 @@ Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" +Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" +TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/ext/ComponentArraysOptimisersExt.jl b/ext/ComponentArraysOptimisersExt.jl new file mode 100644 index 00000000..1240aefe --- /dev/null +++ b/ext/ComponentArraysOptimisersExt.jl @@ -0,0 +1,21 @@ +module ComponentArraysOptimisersExt + +using ComponentArrays, Optimisers + +# Optimisers can handle componentarrays by default, but we can vectorize the entire +# operation here instead of doing multiple smaller operations +Optimisers.setup(opt::AbstractRule, ps::ComponentArray) = Optimisers.setup(opt, getdata(ps)) + +function Optimisers.update(tree, ps::ComponentArray, gs::ComponentArray) + gs_flat = ComponentArrays.__value(getdata(gs)) # Safety against ReverseDiff + tree, ps_new = Optimisers.update(tree, getdata(ps), gs_flat) + return tree, ComponentArray(ps_new, getaxes(ps)) +end + +function Optimisers.update!(tree::Optimisers.Leaf, ps::ComponentArray, gs::ComponentArray) + gs_flat = ComponentArrays.__value(getdata(gs)) # Safety against ReverseDiff + tree, ps_new = Optimisers.update!(tree, getdata(ps), gs_flat) + return tree, ComponentArray(ps_new, getaxes(ps)) +end + +end \ No newline at end of file diff --git a/ext/ComponentArraysReverseDiffExt.jl b/ext/ComponentArraysReverseDiffExt.jl index 72f22a1b..8a9a473b 100644 --- a/ext/ComponentArraysReverseDiffExt.jl +++ b/ext/ComponentArraysReverseDiffExt.jl @@ -2,7 +2,7 @@ module ComponentArraysReverseDiffExt using ComponentArrays, ReverseDiff -const TrackedComponentArray{V, D, N, DA, A, Ax} = ReverseDiff.TrackedArray{V,D,N,ComponentArray{V,N,A,Ax},DA} +const TrackedComponentArray{V,D,N,DA,A,Ax} = ReverseDiff.TrackedArray{V,D,N,ComponentArray{V,N,A,Ax},DA} maybe_tracked_array(val::AbstractArray, der, tape, inds, origin) = ReverseDiff.TrackedArray(val, der, tape) function maybe_tracked_array(val::Real, der, tape, inds, origin::AbstractVector) @@ -12,10 +12,10 @@ function maybe_tracked_array(val::Real, der, tape, inds, origin::AbstractVector) end for f in [:getindex, :view] - @eval function Base.$f(tca::TrackedComponentArray, inds::Union{Symbol, Val}...) - val = $f(ReverseDiff.value(tca), inds...) - der = Base.maybeview(ReverseDiff.deriv(tca), inds...) - t = ReverseDiff.tape(tca) + @eval function Base.$f(tca::TrackedComponentArray, inds::Union{Symbol,Val}...) + val = $f(ReverseDiff.value(tca), inds...) + der = Base.maybeview(ReverseDiff.deriv(tca), inds...) + t = ReverseDiff.tape(tca) return maybe_tracked_array(val, der, t, inds, tca) end end @@ -31,4 +31,17 @@ function Base.getproperty(tca::TrackedComponentArray, s::Symbol) end end +function Base.propertynames(::TrackedComponentArray{V,D,N,DA,A,Tuple{Ax}}) where {V,D,N,DA,A,Ax<:ComponentArrays.AbstractAxis} + return propertynames(ComponentArrays.indexmap(Ax)) +end + +function Base.NamedTuple(tca::TrackedComponentArray) + props = propertynames(tca) + return NamedTuple{props}(getproperty(tca, p) for p in props) +end + +@inline ComponentArrays.__value(x::AbstractArray{<:ReverseDiff.TrackedReal}) = ReverseDiff.value.(x) +@inline ComponentArrays.__value(x::ReverseDiff.TrackedArray) = ReverseDiff.value(x) +@inline ComponentArrays.__value(x::TrackedComponentArray) = ComponentArray(ComponentArrays.__value(getdata(x)), getaxes(x)) + end diff --git a/ext/ComponentArraysTruncatedStacktracesExt.jl b/ext/ComponentArraysTruncatedStacktracesExt.jl new file mode 100644 index 00000000..2b9d57e6 --- /dev/null +++ b/ext/ComponentArraysTruncatedStacktracesExt.jl @@ -0,0 +1,8 @@ +module ComponentArraysTruncatedStacktracesExt + +using ComponentArrays +import TruncatedStacktraces: @truncate_stacktrace + +@truncate_stacktrace ComponentArray 1 + +end \ No newline at end of file diff --git a/src/ComponentArrays.jl b/src/ComponentArrays.jl index 7a83d923..1fcb90f1 100644 --- a/src/ComponentArrays.jl +++ b/src/ComponentArrays.jl @@ -52,6 +52,8 @@ include("compat/chainrulescore.jl") include("compat/static_arrays.jl") export @static_unpack +include("compat/functors.jl") + import PackageExtensionCompat: @require_extensions function __init__() @require_extensions diff --git a/src/compat/chainrulescore.jl b/src/compat/chainrulescore.jl index 39594a96..9f9211bd 100644 --- a/src/compat/chainrulescore.jl +++ b/src/compat/chainrulescore.jl @@ -41,7 +41,26 @@ end # Prevent double projection (p::ChainRulesCore.ProjectTo{ComponentArray})(dx::ComponentArray) = dx -function (p::ChainRulesCore.ProjectTo{ComponentArray})(t::ChainRulesCore.Tangent{A, <:NamedTuple}) where {A} +function (p::ChainRulesCore.ProjectTo{ComponentArray})(t::ChainRulesCore.Tangent{A,<:NamedTuple}) where {A} nt = Functors.fmap(ChainRulesCore.backing, ChainRulesCore.backing(t)) return ComponentArray(nt) end + +function ChainRulesCore.rrule(::Type{CA}, nt::NamedTuple) where {CA<:ComponentArray} + y = CA(nt) + + function ∇NamedTupleToComponentArray(Δ::AbstractArray) + if length(Δ) == length(y) + return ∇NamedTupleToComponentArray(ComponentArray(vec(Δ), getaxes(y))) + end + error("Got pullback input of shape $(size(Δ)) & type $(typeof(Δ)) for output " * + "of shape $(size(y)) & type $(typeof(y))") + return nothing + end + + function ∇NamedTupleToComponentArray(Δ::ComponentArray) + return ChainRulesCore.NoTangent(), NamedTuple(Δ) + end + + return y, ∇NamedTupleToComponentArray +end diff --git a/src/compat/functors.jl b/src/compat/functors.jl new file mode 100644 index 00000000..3978ca12 --- /dev/null +++ b/src/compat/functors.jl @@ -0,0 +1 @@ +Functors.functor(::Type{<:ComponentVector}, c) = NamedTuple(c), ComponentVector diff --git a/src/utils.jl b/src/utils.jl index 7a600005..b11c59dd 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -50,3 +50,5 @@ recursive_eltype(x::AbstractArray{<:Any}) = isempty(x) ? Base.Bottom : mapreduce recursive_eltype(x::Dict) = isempty(x) ? Base.Bottom : mapreduce(recursive_eltype, promote_type, values(x)) recursive_eltype(::AbstractArray{T,N}) where {T<:Number, N} = T recursive_eltype(x) = typeof(x) + +@inline __value(x) = x diff --git a/test/Project.toml b/test/Project.toml index daa203a4..83210c59 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -9,9 +9,11 @@ JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb" LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" +Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" +TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77" Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/test/autodiff_tests.jl b/test/autodiff_tests.jl index 091784a0..b088642c 100644 --- a/test/autodiff_tests.jl +++ b/test/autodiff_tests.jl @@ -1,5 +1,5 @@ import FiniteDiff, ForwardDiff, ReverseDiff, Tracker, Zygote - +using Optimisers using Test F(a, x) = sum(abs2, a) * x^3 @@ -38,6 +38,22 @@ truth = ComponentArray(a = [32, 48], x = 156) @test out isa Vector{<:ForwardDiff.Dual} end +@testset "Optimisers Update" begin + ca_ = deepcopy(ca) + opt_st = Optimisers.setup(Adam(0.01), ca_) + gs_zyg = only(Zygote.gradient(F_idx_val, ca_)) + @test !(last(Optimisers.update(opt_st, ca_, gs_zyg)) ≈ ca) + Optimisers.update!(opt_st, ca_, gs_zyg) + @test !(ca_ ≈ ca) + + ca_ = deepcopy(ca) + opt_st = Optimisers.setup(Adam(0.01), ca_) + gs_rdiff = ReverseDiff.gradient(F_idx_val, ca_) + @test !(last(Optimisers.update(opt_st, ca_, gs_rdiff)) ≈ ca) + Optimisers.update!(opt_st, ca_, gs_rdiff) + @test !(ca_ ≈ ca) +end + @testset "Projection" begin gs_ca = Zygote.gradient(sum, ca)[1] @@ -76,18 +92,28 @@ end @test ∂r ≈ ∂r_ca end +function F_prop(x) + @assert propertynames(x) == (:x, :y) + return sum(abs2, x.x .- x.y) +end + +@testset "Preserve Properties" begin + x = ComponentArray(; x = [1.0, 5.0], y = [3.0, 4.0]) -# # This is commented out because the gradient operation itself is broken due to Zygote's inability -# # to support mutation and ComponentArray's use of mutation for construction from a NamedTuple. -# # It would be nice to support this eventually, so I'll just leave this commented (because @test_broken -# # wouldn't work here because the error happens before the test) -# @testset "Issues" begin -# function mysum(x::AbstractVector) -# y = ComponentVector(x=x) -# return sum(y) -# end + gs_z = only(Zygote.gradient(F_prop, x)) + gs_rdiff = ReverseDiff.gradient(F_prop, x) -# Δ = Zygote.gradient(mysum, rand(10)) + @test gs_z ≈ gs_rdiff +end + +@testset "Issues" begin + function mysum(x::AbstractVector) + y = ComponentVector(x=x) + z = ComponentVector(; z = x .^ 2) + return sum(y) + sum(abs2, z) + end -# @test Δ isa Vector{Float64} -# end + Δ = only(Zygote.gradient(mysum, rand(10))) + + @test Δ isa AbstractVector{Float64} +end diff --git a/test/runtests.jl b/test/runtests.jl index 4f4d4167..3e7b54eb 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -8,6 +8,8 @@ using StaticArrays using OffsetArrays using Test using Unitful +using Functors +import TruncatedStacktraces # This is loaded just to trigger the extension package ## Test setup @@ -690,6 +692,14 @@ end @test_throws ArgumentError axpby!(2, x, 3, y) end +@testset "Functors" begin + for carray in (ca, ca_Float32, ca_MVector, ca_SVector, ca_composed, ca2, caa) + θ, re = Functors.functor(carray) + @test θ isa NamedTuple + @test re(θ) == carray + end +end + @testset "Autodiff" begin include("autodiff_tests.jl") end