From 68bc950ef90788ffde4495d69f569aa1be4776d2 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 5 Apr 2023 13:57:39 -0400 Subject: [PATCH 1/2] Add ProjectTo for CA --- Project.toml | 2 +- src/compat/chainrulescore.jl | 12 +++++++++--- test/autodiff_tests.jl | 6 ++++++ 3 files changed, 16 insertions(+), 4 deletions(-) diff --git a/Project.toml b/Project.toml index 3deabdd4..ad8a2a38 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ComponentArrays" uuid = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" authors = ["Jonnie Diegelman <47193959+jonniedie@users.noreply.github.com>"] -version = "0.13.9" +version = "0.13.10" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" diff --git a/src/compat/chainrulescore.jl b/src/compat/chainrulescore.jl index f9406591..0b4b94bf 100644 --- a/src/compat/chainrulescore.jl +++ b/src/compat/chainrulescore.jl @@ -1,4 +1,4 @@ -function ChainRulesCore.rrule(::typeof(getproperty), x::ComponentArray, s::Union{Symbol, Val}) +function ChainRulesCore.rrule(::typeof(getproperty), x::ComponentArray, s::Union{Symbol,Val}) function getproperty_adjoint(Δ) zero_x = zero(similar(x, eltype(Δ))) setproperty!(zero_x, s, Δ) @@ -8,6 +8,12 @@ function ChainRulesCore.rrule(::typeof(getproperty), x::ComponentArray, s::Union return getproperty(x, s), getproperty_adjoint end -ChainRulesCore.rrule(::typeof(getdata), x::ComponentArray) = getdata(x), Δ->(ChainRulesCore.NoTangent(), ComponentArray(Δ, getaxes(x))) +ChainRulesCore.rrule(::typeof(getdata), x::ComponentArray) = getdata(x), Δ -> (ChainRulesCore.NoTangent(), ComponentArray(Δ, getaxes(x))) -ChainRulesCore.rrule(::Type{ComponentArray}, data, axes) = ComponentArray(data, axes), Δ->(ChainRulesCore.NoTangent(), getdata(Δ), ChainRulesCore.NoTangent()) +ChainRulesCore.rrule(::Type{ComponentArray}, data, axes) = ComponentArray(data, axes), Δ -> (ChainRulesCore.NoTangent(), getdata(Δ), ChainRulesCore.NoTangent()) + +function ChainRulesCore.ProjectTo(ca::ComponentArray) + return ChainRulesCore.ProjectTo{ComponentArray}(; project=ChainRulesCore.ProjectTo(getdata(ca)), axes=getaxes(ca)) +end + +(p::ChainRulesCore.ProjectTo{ComponentArray})(dx::AbstractArray) = ComponentArray(p.project(dx), p.axes) diff --git a/test/autodiff_tests.jl b/test/autodiff_tests.jl index 6004524d..4f270ab1 100644 --- a/test/autodiff_tests.jl +++ b/test/autodiff_tests.jl @@ -50,6 +50,12 @@ truth = ComponentArray(a = [32, 48], x = 156) @test out isa Vector{<:ForwardDiff.Dual} end +@testset "Projection" begin + gs_ca = Zygote.gradient(sum, ca)[1] + + @test gs_ca isa ComponentArray +end + # # This is commented out because the gradient operation itself is broken due to Zygote's inability # # to support mutation and ComponentArray's use of mutation for construction from a NamedTuple. From d6549908a89afcc573392b74ba6c357bd7bae31d Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 6 Apr 2023 10:59:19 -0400 Subject: [PATCH 2/2] Fix tests: gradients for Int --> Float --- test/autodiff_tests.jl | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/test/autodiff_tests.jl b/test/autodiff_tests.jl index 4f270ab1..2b155c47 100644 --- a/test/autodiff_tests.jl +++ b/test/autodiff_tests.jl @@ -32,17 +32,9 @@ truth = ComponentArray(a = [32, 48], x = 156) @test zygote_full ≈ truth end - # Not sure why this doesn't work in v1.2, but I don't want to drop the tests for that just - # for this to work - if VERSION ≥ v"1.6" - @test ComponentArray(x=4,) == Zygote.gradient(ComponentArray(x=2,)) do c - (;c...,).x^2 - end[1] - else - @test_skip ComponentArray(x=4,) == Zygote.gradient(ComponentArray(x=2,)) do c - (;c...,).x^2 - end[1] - end + @test ComponentArray(x=4.0,) ≈ Zygote.gradient(ComponentArray(x=2,)) do c + (;c...,).x^2 + end[1] # Issue #148 ps = ComponentArray(;bias = rand(4))