From 587fb7a6aef4c4b4757207ee783042e140268831 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 13 May 2024 14:15:44 -0400 Subject: [PATCH 1/2] Add Tracker.data --- Project.toml | 2 +- ext/ComponentArraysTrackerExt.jl | 2 ++ test/autodiff_tests.jl | 10 ++++++++++ 3 files changed, 13 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 77618dd0..79eaf6e5 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ComponentArrays" uuid = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" authors = ["Jonnie Diegelman <47193959+jonniedie@users.noreply.github.com>"] -version = "0.15.12" +version = "0.15.13" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" diff --git a/ext/ComponentArraysTrackerExt.jl b/ext/ComponentArraysTrackerExt.jl index 55754cb2..b0e68413 100644 --- a/ext/ComponentArraysTrackerExt.jl +++ b/ext/ComponentArraysTrackerExt.jl @@ -10,6 +10,8 @@ end Tracker.extract_grad!(ca::ComponentArray) = Tracker.extract_grad!(getdata(ca)) +Tracker.data(ca::ComponentArray) = ComponentArray(Tracker.data(getdata(ca)), getaxes(ca)) + function Base.materialize(bc::Base.Broadcast.Broadcasted{Tracker.TrackedStyle, Nothing, typeof(zero), <:Tuple{<:ComponentVector}}) ca = first(bc.args) diff --git a/test/autodiff_tests.jl b/test/autodiff_tests.jl index b088642c..e39d462e 100644 --- a/test/autodiff_tests.jl +++ b/test/autodiff_tests.jl @@ -117,3 +117,13 @@ end @test Δ isa AbstractVector{Float64} end + +@testset "Tracker untrack" begin + ps = Tracker.param(ComponentArray(; a = rand(2))) + @test eltype(getdata(ps)) isa Tracker.TrackerReal + + ps_data = Tracker.data(ps) + @test !(eltype(getdata(ps_data)) isa Tracker.TrackedReal) +end + + From 6e72b9865af0e242632608ca731cb4ced3b576e5 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 13 May 2024 14:16:17 -0400 Subject: [PATCH 2/2] Add Tracker.data --- test/autodiff_tests.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/autodiff_tests.jl b/test/autodiff_tests.jl index e39d462e..68b371de 100644 --- a/test/autodiff_tests.jl +++ b/test/autodiff_tests.jl @@ -120,10 +120,10 @@ end @testset "Tracker untrack" begin ps = Tracker.param(ComponentArray(; a = rand(2))) - @test eltype(getdata(ps)) isa Tracker.TrackerReal + @test eltype(getdata(ps)) <: Tracker.TrackedReal{Float64} ps_data = Tracker.data(ps) - @test !(eltype(getdata(ps_data)) isa Tracker.TrackedReal) + @test !(eltype(getdata(ps_data)) <: Tracker.TrackedReal{Float64}) + @test eltype(getdata(ps_data)) <: Float64 end -