test_add!!_behaviour has strong assumptions on fields #267

Closed
theogf opened this issue Jan 24, 2023 · 2 comments

@theogf
Contributor

theogf commented Jan 24, 2023

When calling test_rrule on a struct that has a Tuple as one of its fields (e.g. StructArray or SArray), _test_cotangent fails because Tuples cannot be added together with +.

I think using broadcast on this line would solve the issue: https://github.com/JuliaDiff/ChainRulesCore.jl/blob/ed9a0073ff83cb3b1f4619303e41f4dd5d8c4825/src/tangent_types/tangent.jl#L301
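
For context, the underlying limitation can be reproduced in plain Julia without any packages. The snippet below is only a sketch of the behaviour described above; the variable names are illustrative and are not taken from ChainRulesCore.

```julia
a = (1.0, 2.0, 3.0)
b = (0.5, 0.5, 0.5)

# `+` has no method for two Tuples, so a direct sum throws a MethodError.
plus_failed = try
    a + b
    false
catch err
    err isa MethodError
end

# Broadcasting applies `+` element by element and returns a Tuple.
c = a .+ b
```

This is why swapping `+` for `.+` at the linked line looks attractive as a fix for Tuple-valued fields.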

Here is an MWE.

using ChainRulesCore, ChainRulesTestUtils, StaticArrays  # SArray is provided by StaticArrays
function ChainRulesCore.rrule(::typeof(collect), x::X) where {S, T, N, L, X<:SArray{S, T, N, L}}
    # The `data` field of the Tangent is a plain Tuple, which triggers the failure
    collect_rrule(Δ::AbstractArray) = NoTangent(), Tangent{X}(data = ntuple(i -> Δ[i], Val(L)))
    return collect(x), collect_rrule
end
A = SArray{Tuple{3, 1, 2}}(ntuple(i -> 3.5i, 6))
test_rrule(collect, A)

@oxinabox
Member

While broadcast would fix that specific case, I think a call to elementwise_add is better;
it would do the same thing in the Tuple case.
But I can imagine that for something with other types inside it, like NamedTuple, broadcast would work.
Further, if one side has a scalar and the other a Tuple, then broadcast would "work", which it shouldn't.
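
The scalar-versus-Tuple concern can be illustrated in plain Julia. In the sketch below, `strict_add` is a hypothetical helper written only for this example; it is not ChainRulesCore's `elementwise_add`, and it merely shows what a stricter element-wise addition could look like.

```julia
# Broadcasting silently expands the scalar across the Tuple.
loose = 1.0 .+ (1.0, 2.0)

# Hypothetical strict helper (not part of ChainRulesCore): both arguments
# must be Tuples of the same length N, otherwise no method matches.
strict_add(a::NTuple{N,Any}, b::NTuple{N,Any}) where {N} = map(+, a, b)

same_shape = strict_add((1.0, 2.0), (3.0, 4.0))

# A scalar paired with a Tuple is rejected instead of being broadcast.
scalar_rejected = try
    strict_add(1.0, (1.0, 2.0))
    false
catch err
    err isa MethodError
end
```

Constraining both arguments to the same Tuple length means a shape mismatch surfaces as an error rather than producing a plausible-looking result.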

However, I am not 100% sure this is actually a problem, since the example type is not a correct tangent type:
it has a field which is not a valid tangent type, because a plain Tuple does not implement a vector space (which requires overloading +).
I believe the correct code is:

using ChainRulesCore, ChainRulesTestUtils, StaticArrays  # SArray is provided by StaticArrays
function ChainRulesCore.rrule(::typeof(collect), x::X) where {S, T, N, L, X<:SArray{S, T, N, L}}
    # Wrap the Tuple in its own Tangent so that the field supports +
    collect_rrule(Δ::AbstractArray) = NoTangent(), Tangent{X}(data = Tangent{NTuple{L, T}}(ntuple(i -> Δ[i], Val(L))...))
    return collect(x), collect_rrule
end
A = SArray{Tuple{3, 1, 2}}(ntuple(i -> 3.5i, 6))
test_rrule(collect, A)

Or, much simpler, use a natural tangent:

using ChainRulesCore, ChainRulesTestUtils, StaticArrays  # SArray is provided by StaticArrays
function ChainRulesCore.rrule(::typeof(collect), x::X) where {S, T, N, L, X<:SArray{S, T, N, L}}
    # The array cotangent is passed through unchanged as a natural tangent
    collect_rrule(Δ::AbstractArray) = NoTangent(), Δ
    return collect(x), collect_rrule
end
A = SArray{Tuple{3, 1, 2}}(ntuple(i -> 3.5i, 6))
test_rrule(collect, A)

Strictly speaking one should use projection:

using ChainRulesCore, ChainRulesTestUtils, StaticArrays  # SArray is provided by StaticArrays
function ChainRulesCore.rrule(::typeof(collect), x::X) where {S, T, N, L, X<:SArray{S, T, N, L}}
    y = collect(x)
    # Project onto the tangent space of the input x, since the pullback returns a tangent for x
    proj = ProjectTo(x)
    collect_rrule(Δ::AbstractArray) = NoTangent(), proj(Δ)
    return y, collect_rrule
end
A = SArray{Tuple{3, 1, 2}}(ntuple(i -> 3.5i, 6))
test_rrule(collect, A)

@theogf
Contributor Author

theogf commented Jan 25, 2023

Thanks for the investigation, it is really helpful!

I have indeed started to look more into ProjectTo. I tended to forget that StructArrays are, well... arrays!

@theogf theogf closed this as completed Jan 25, 2023