diff --git a/Project.toml b/Project.toml index bfa4222e..ed964167 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ComponentArrays" uuid = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" authors = ["Jonnie Diegelman <47193959+jonniedie@users.noreply.github.com>"] -version = "0.15.16" +version = "0.15.17" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" diff --git a/ext/ComponentArraysTrackerExt.jl b/ext/ComponentArraysTrackerExt.jl index b0e68413..405fb0cb 100644 --- a/ext/ComponentArraysTrackerExt.jl +++ b/ext/ComponentArraysTrackerExt.jl @@ -1,5 +1,6 @@ module ComponentArraysTrackerExt +using ArrayInterface: ArrayInterface using ComponentArrays, Tracker function Tracker.param(ca::ComponentArray) @@ -34,4 +35,10 @@ end return ComponentArrays._getindex(Base.getindex, x, v) end +function ArrayInterface.restructure(x::ComponentVector, + y::ComponentVector{T, <:TrackedArray}) where {T} + getaxes(x) == getaxes(y) || error("Axes must match") + return y +end + end diff --git a/test/Project.toml b/test/Project.toml index 83210c59..3776b7e2 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,4 +1,5 @@ [deps] +ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" ArrayInterfaceCore = "30b0a656-2188-435a-8636-2ec0e6a096e2" BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" diff --git a/test/autodiff_tests.jl b/test/autodiff_tests.jl index 68b371de..0b8c9683 100644 --- a/test/autodiff_tests.jl +++ b/test/autodiff_tests.jl @@ -1,5 +1,5 @@ import FiniteDiff, ForwardDiff, ReverseDiff, Tracker, Zygote -using Optimisers +using Optimisers, ArrayInterface using Test F(a, x) = sum(abs2, a) * x^3 @@ -127,3 +127,8 @@ end @test eltype(getdata(ps_data)) <: Float64 end +@testset "ArrayInterface restructure TrackedArray" begin + ps = ComponentArray(; a = rand(2), b = (; c = rand(2))) + ps_tracked = Tracker.param(ps) + @test ArrayInterface.restructure(ps, ps_tracked) isa ComponentVector{<:Any, <:Tracker.TrackedArray} +end