From a31a9a145a92663618d4f3f5d40f24f5ba8e8166 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 25 Feb 2024 16:44:45 -0500 Subject: [PATCH 1/5] Add more functions for tracked componentarrays --- Project.toml | 2 +- ext/ComponentArraysReverseDiffExt.jl | 19 ++++++++++++++----- 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/Project.toml b/Project.toml index aaa2d476..00af2093 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ComponentArrays" uuid = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" authors = ["Jonnie Diegelman <47193959+jonniedie@users.noreply.github.com>"] -version = "0.15.10" +version = "0.15.11" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" diff --git a/ext/ComponentArraysReverseDiffExt.jl b/ext/ComponentArraysReverseDiffExt.jl index 72f22a1b..26e82999 100644 --- a/ext/ComponentArraysReverseDiffExt.jl +++ b/ext/ComponentArraysReverseDiffExt.jl @@ -2,7 +2,7 @@ module ComponentArraysReverseDiffExt using ComponentArrays, ReverseDiff -const TrackedComponentArray{V, D, N, DA, A, Ax} = ReverseDiff.TrackedArray{V,D,N,ComponentArray{V,N,A,Ax},DA} +const TrackedComponentArray{V,D,N,DA,A,Ax} = ReverseDiff.TrackedArray{V,D,N,ComponentArray{V,N,A,Ax},DA} maybe_tracked_array(val::AbstractArray, der, tape, inds, origin) = ReverseDiff.TrackedArray(val, der, tape) function maybe_tracked_array(val::Real, der, tape, inds, origin::AbstractVector) @@ -12,10 +12,10 @@ function maybe_tracked_array(val::Real, der, tape, inds, origin::AbstractVector) end for f in [:getindex, :view] - @eval function Base.$f(tca::TrackedComponentArray, inds::Union{Symbol, Val}...) - val = $f(ReverseDiff.value(tca), inds...) - der = Base.maybeview(ReverseDiff.deriv(tca), inds...) - t = ReverseDiff.tape(tca) + @eval function Base.$f(tca::TrackedComponentArray, inds::Union{Symbol,Val}...) + val = $f(ReverseDiff.value(tca), inds...) + der = Base.maybeview(ReverseDiff.deriv(tca), inds...) + t = ReverseDiff.tape(tca) return maybe_tracked_array(val, der, t, inds, tca) end end @@ -31,4 +31,13 @@ function Base.getproperty(tca::TrackedComponentArray, s::Symbol) end end +function Base.propertynames(::TrackedComponentArray{V,D,N,DA,A,Tuple{Ax}}) where {V,D,N,DA,A,Ax<:ComponentArrays.AbstractAxis} + return propertynames(ComponentArrays.indexmap(Ax)) +end + +function Base.NamedTuple(tca::TrackedComponentArray) + props = propertynames(tca) + return NamedTuple{props}(getproperty(tca, p) for p in props) +end + end From 2c4fed75041f59bb7a1128547ceffdaa315545b4 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 25 Feb 2024 17:11:47 -0500 Subject: [PATCH 2/5] Add rrule for NamedTuple conversion --- Project.toml | 4 +++- ext/ComponentArraysTruncatedStacktracesExt.jl | 7 +++++++ src/compat/chainrulescore.jl | 12 +++++++++++- 3 files changed, 21 insertions(+), 2 deletions(-) create mode 100644 ext/ComponentArraysTruncatedStacktracesExt.jl diff --git a/Project.toml b/Project.toml index 00af2093..1f0d2109 100644 --- a/Project.toml +++ b/Project.toml @@ -21,6 +21,7 @@ RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" +TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] @@ -31,6 +32,7 @@ ComponentArraysRecursiveArrayToolsExt = "RecursiveArrayTools" ComponentArraysReverseDiffExt = "ReverseDiff" ComponentArraysSciMLBaseExt = "SciMLBase" ComponentArraysTrackerExt = "Tracker" +ComponentArraysTruncatedStacktracesExt = "TruncatedStacktraces" ComponentArraysZygoteExt = "Zygote" [compat] @@ -45,8 +47,8 @@ PackageExtensionCompat = "1" RecursiveArrayTools = "2, 3" ReverseDiff = "1" SciMLBase = "1, 2" -StaticArraysCore = "1" StaticArrayInterface = "1" +StaticArraysCore = "1" Tracker = "0.2" Zygote = "0.6" julia = "1.6" diff --git a/ext/ComponentArraysTruncatedStacktracesExt.jl b/ext/ComponentArraysTruncatedStacktracesExt.jl new file mode 100644 index 00000000..47a212e6 --- /dev/null +++ b/ext/ComponentArraysTruncatedStacktracesExt.jl @@ -0,0 +1,7 @@ +module ComponentArraysTruncatedStacktracesExt + +using ComponentArrays, TruncatedStacktraces + +@truncate_stacktrace ComponentArray 1 + +end \ No newline at end of file diff --git a/src/compat/chainrulescore.jl b/src/compat/chainrulescore.jl index 39594a96..0e6b9d79 100644 --- a/src/compat/chainrulescore.jl +++ b/src/compat/chainrulescore.jl @@ -41,7 +41,17 @@ end # Prevent double projection (p::ChainRulesCore.ProjectTo{ComponentArray})(dx::ComponentArray) = dx -function (p::ChainRulesCore.ProjectTo{ComponentArray})(t::ChainRulesCore.Tangent{A, <:NamedTuple}) where {A} +function (p::ChainRulesCore.ProjectTo{ComponentArray})(t::ChainRulesCore.Tangent{A,<:NamedTuple}) where {A} nt = Functors.fmap(ChainRulesCore.backing, ChainRulesCore.backing(t)) return ComponentArray(nt) end + +function ChainRulesCore.rrule(::Type{CA}, nt::NamedTuple) where {CA<:ComponentArray} + y = CA(nt) + + function ∇NamedTupleToComponentArray(Δ::ComponentArray) + return ChainRulesCore.NoTangent(), NamedTuple(Δ) + end + + return y, ∇NamedTupleToComponentArray +end From 23b996f93411c2de07b5c9b5cb61ef37224282ba Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 25 Feb 2024 17:16:33 -0500 Subject: [PATCH 3/5] Add Functors compatibility --- src/ComponentArrays.jl | 2 ++ src/compat/functors.jl | 4 ++++ 2 files changed, 6 insertions(+) create mode 100644 src/compat/functors.jl diff --git a/src/ComponentArrays.jl b/src/ComponentArrays.jl index 7a83d923..1fcb90f1 100644 --- a/src/ComponentArrays.jl +++ b/src/ComponentArrays.jl @@ -52,6 +52,8 @@ include("compat/chainrulescore.jl") include("compat/static_arrays.jl") export @static_unpack +include("compat/functors.jl") + import PackageExtensionCompat: @require_extensions function __init__() @require_extensions diff --git a/src/compat/functors.jl b/src/compat/functors.jl new file mode 100644 index 00000000..b0074098 --- /dev/null +++ b/src/compat/functors.jl @@ -0,0 +1,4 @@ +function Functors.functor(::Type{<:ComponentArray}, c) + return ( + NamedTuple{propertynames(c)}(getproperty.((c,), propertynames(c))), ComponentArray) +end From ab718c6121ba4498bb249256b3e56c4106fc5d55 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 25 Feb 2024 17:52:53 -0500 Subject: [PATCH 4/5] Add Optimisers --- Project.toml | 4 ++++ ext/ComponentArraysOptimisersExt.jl | 21 +++++++++++++++++++ ext/ComponentArraysReverseDiffExt.jl | 4 ++++ ext/ComponentArraysTruncatedStacktracesExt.jl | 3 ++- src/utils.jl | 2 ++ 5 files changed, 33 insertions(+), 1 deletion(-) create mode 100644 ext/ComponentArraysOptimisersExt.jl diff --git a/Project.toml b/Project.toml index 1f0d2109..9026e3fd 100644 --- a/Project.toml +++ b/Project.toml @@ -17,6 +17,7 @@ StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" +Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" @@ -28,6 +29,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" ComponentArraysAdaptExt = "Adapt" ComponentArraysConstructionBaseExt = "ConstructionBase" ComponentArraysGPUArraysExt = "GPUArrays" +ComponentArraysOptimisersExt = "Optimisers" ComponentArraysRecursiveArrayToolsExt = "RecursiveArrayTools" ComponentArraysReverseDiffExt = "ReverseDiff" ComponentArraysSciMLBaseExt = "SciMLBase" @@ -58,9 +60,11 @@ Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" +Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" +TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/ext/ComponentArraysOptimisersExt.jl b/ext/ComponentArraysOptimisersExt.jl new file mode 100644 index 00000000..1240aefe --- /dev/null +++ b/ext/ComponentArraysOptimisersExt.jl @@ -0,0 +1,21 @@ +module ComponentArraysOptimisersExt + +using ComponentArrays, Optimisers + +# Optimisers can handle componentarrays by default, but we can vectorize the entire +# operation here instead of doing multiple smaller operations +Optimisers.setup(opt::AbstractRule, ps::ComponentArray) = Optimisers.setup(opt, getdata(ps)) + +function Optimisers.update(tree, ps::ComponentArray, gs::ComponentArray) + gs_flat = ComponentArrays.__value(getdata(gs)) # Safety against ReverseDiff + tree, ps_new = Optimisers.update(tree, getdata(ps), gs_flat) + return tree, ComponentArray(ps_new, getaxes(ps)) +end + +function Optimisers.update!(tree::Optimisers.Leaf, ps::ComponentArray, gs::ComponentArray) + gs_flat = ComponentArrays.__value(getdata(gs)) # Safety against ReverseDiff + tree, ps_new = Optimisers.update!(tree, getdata(ps), gs_flat) + return tree, ComponentArray(ps_new, getaxes(ps)) +end + +end \ No newline at end of file diff --git a/ext/ComponentArraysReverseDiffExt.jl b/ext/ComponentArraysReverseDiffExt.jl index 26e82999..8a9a473b 100644 --- a/ext/ComponentArraysReverseDiffExt.jl +++ b/ext/ComponentArraysReverseDiffExt.jl @@ -40,4 +40,8 @@ function Base.NamedTuple(tca::TrackedComponentArray) return NamedTuple{props}(getproperty(tca, p) for p in props) end +@inline ComponentArrays.__value(x::AbstractArray{<:ReverseDiff.TrackedReal}) = ReverseDiff.value.(x) +@inline ComponentArrays.__value(x::ReverseDiff.TrackedArray) = ReverseDiff.value(x) +@inline ComponentArrays.__value(x::TrackedComponentArray) = ComponentArray(ComponentArrays.__value(getdata(x)), getaxes(x)) + end diff --git a/ext/ComponentArraysTruncatedStacktracesExt.jl b/ext/ComponentArraysTruncatedStacktracesExt.jl index 47a212e6..2b9d57e6 100644 --- a/ext/ComponentArraysTruncatedStacktracesExt.jl +++ b/ext/ComponentArraysTruncatedStacktracesExt.jl @@ -1,6 +1,7 @@ module ComponentArraysTruncatedStacktracesExt -using ComponentArrays, TruncatedStacktraces +using ComponentArrays +import TruncatedStacktraces: @truncate_stacktrace @truncate_stacktrace ComponentArray 1 diff --git a/src/utils.jl b/src/utils.jl index 7a600005..b11c59dd 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -50,3 +50,5 @@ recursive_eltype(x::AbstractArray{<:Any}) = isempty(x) ? Base.Bottom : mapreduce recursive_eltype(x::Dict) = isempty(x) ? Base.Bottom : mapreduce(recursive_eltype, promote_type, values(x)) recursive_eltype(::AbstractArray{T,N}) where {T<:Number, N} = T recursive_eltype(x) = typeof(x) + +@inline __value(x) = x From a4e5293b49dd5dc1d891cfbcd122c3e654dccd17 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 25 Feb 2024 19:32:33 -0500 Subject: [PATCH 5/5] Add tests --- .github/workflows/ci.yml | 2 +- Project.toml | 4 +++ src/compat/chainrulescore.jl | 9 +++++++ src/compat/functors.jl | 5 +--- test/Project.toml | 2 ++ test/autodiff_tests.jl | 52 +++++++++++++++++++++++++++--------- test/runtests.jl | 10 +++++++ 7 files changed, 66 insertions(+), 18 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index ed8067a9..96e03ea0 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -18,8 +18,8 @@ jobs: - '1.6' - '1.8' - '1.9' + - '1.10' - '1' # Leave this line unchanged. '1' will automatically expand to the latest stable 1.x release of Julia. - - '1.10.0-beta3' os: - ubuntu-latest arch: diff --git a/Project.toml b/Project.toml index 9026e3fd..33ba6269 100644 --- a/Project.toml +++ b/Project.toml @@ -45,13 +45,17 @@ ConstructionBase = "1" ForwardDiff = "0.10" Functors = "0.4.4" GPUArrays = "8, 9, 10" +LinearAlgebra = "1" +Optimisers = "0.3" PackageExtensionCompat = "1" RecursiveArrayTools = "2, 3" ReverseDiff = "1" SciMLBase = "1, 2" StaticArrayInterface = "1" StaticArraysCore = "1" +Test = "1" Tracker = "0.2" +TruncatedStacktraces = "1.4" Zygote = "0.6" julia = "1.6" diff --git a/src/compat/chainrulescore.jl b/src/compat/chainrulescore.jl index 0e6b9d79..9f9211bd 100644 --- a/src/compat/chainrulescore.jl +++ b/src/compat/chainrulescore.jl @@ -49,6 +49,15 @@ end function ChainRulesCore.rrule(::Type{CA}, nt::NamedTuple) where {CA<:ComponentArray} y = CA(nt) + function ∇NamedTupleToComponentArray(Δ::AbstractArray) + if length(Δ) == length(y) + return ∇NamedTupleToComponentArray(ComponentArray(vec(Δ), getaxes(y))) + end + error("Got pullback input of shape $(size(Δ)) & type $(typeof(Δ)) for output " * + "of shape $(size(y)) & type $(typeof(y))") + return nothing + end + function ∇NamedTupleToComponentArray(Δ::ComponentArray) return ChainRulesCore.NoTangent(), NamedTuple(Δ) end diff --git a/src/compat/functors.jl b/src/compat/functors.jl index b0074098..3978ca12 100644 --- a/src/compat/functors.jl +++ b/src/compat/functors.jl @@ -1,4 +1 @@ -function Functors.functor(::Type{<:ComponentArray}, c) - return ( - NamedTuple{propertynames(c)}(getproperty.((c,), propertynames(c))), ComponentArray) -end +Functors.functor(::Type{<:ComponentVector}, c) = NamedTuple(c), ComponentVector diff --git a/test/Project.toml b/test/Project.toml index daa203a4..83210c59 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -9,9 +9,11 @@ JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb" LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" +Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" +TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77" Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/test/autodiff_tests.jl b/test/autodiff_tests.jl index 091784a0..b088642c 100644 --- a/test/autodiff_tests.jl +++ b/test/autodiff_tests.jl @@ -1,5 +1,5 @@ import FiniteDiff, ForwardDiff, ReverseDiff, Tracker, Zygote - +using Optimisers using Test F(a, x) = sum(abs2, a) * x^3 @@ -38,6 +38,22 @@ truth = ComponentArray(a = [32, 48], x = 156) @test out isa Vector{<:ForwardDiff.Dual} end +@testset "Optimisers Update" begin + ca_ = deepcopy(ca) + opt_st = Optimisers.setup(Adam(0.01), ca_) + gs_zyg = only(Zygote.gradient(F_idx_val, ca_)) + @test !(last(Optimisers.update(opt_st, ca_, gs_zyg)) ≈ ca) + Optimisers.update!(opt_st, ca_, gs_zyg) + @test !(ca_ ≈ ca) + + ca_ = deepcopy(ca) + opt_st = Optimisers.setup(Adam(0.01), ca_) + gs_rdiff = ReverseDiff.gradient(F_idx_val, ca_) + @test !(last(Optimisers.update(opt_st, ca_, gs_rdiff)) ≈ ca) + Optimisers.update!(opt_st, ca_, gs_rdiff) + @test !(ca_ ≈ ca) +end + @testset "Projection" begin gs_ca = Zygote.gradient(sum, ca)[1] @@ -76,18 +92,28 @@ end @test ∂r ≈ ∂r_ca end +function F_prop(x) + @assert propertynames(x) == (:x, :y) + return sum(abs2, x.x .- x.y) +end + +@testset "Preserve Properties" begin + x = ComponentArray(; x = [1.0, 5.0], y = [3.0, 4.0]) -# # This is commented out because the gradient operation itself is broken due to Zygote's inability -# # to support mutation and ComponentArray's use of mutation for construction from a NamedTuple. -# # It would be nice to support this eventually, so I'll just leave this commented (because @test_broken -# # wouldn't work here because the error happens before the test) -# @testset "Issues" begin -# function mysum(x::AbstractVector) -# y = ComponentVector(x=x) -# return sum(y) -# end + gs_z = only(Zygote.gradient(F_prop, x)) + gs_rdiff = ReverseDiff.gradient(F_prop, x) -# Δ = Zygote.gradient(mysum, rand(10)) + @test gs_z ≈ gs_rdiff +end + +@testset "Issues" begin + function mysum(x::AbstractVector) + y = ComponentVector(x=x) + z = ComponentVector(; z = x .^ 2) + return sum(y) + sum(abs2, z) + end -# @test Δ isa Vector{Float64} -# end + Δ = only(Zygote.gradient(mysum, rand(10))) + + @test Δ isa AbstractVector{Float64} +end diff --git a/test/runtests.jl b/test/runtests.jl index 4f4d4167..3e7b54eb 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -8,6 +8,8 @@ using StaticArrays using OffsetArrays using Test using Unitful +using Functors +import TruncatedStacktraces # This is loaded just to trigger the extension package ## Test setup @@ -690,6 +692,14 @@ end @test_throws ArgumentError axpby!(2, x, 3, y) end +@testset "Functors" begin + for carray in (ca, ca_Float32, ca_MVector, ca_SVector, ca_composed, ca2, caa) + θ, re = Functors.functor(carray) + @test θ isa NamedTuple + @test re(θ) == carray + end +end + @testset "Autodiff" begin include("autodiff_tests.jl") end