Array getindex rule unable to handle Zero types and NotImplemented #697

Open
ToucheSir opened this issue Mar 12, 2023 · 7 comments

@ToucheSir
Contributor

ToucheSir commented Mar 12, 2023

I've been revisiting FluxML/Zygote.jl#1328 as part of a larger PR, and discovered this behaviour while running https://github.com/FluxML/Zygote.jl/blob/108e5a19d8fa7187f6eaece7a142c48d71dfd0d2/test/chainrules.jl#L275.

MWE:

julia> _, back = rrule(getindex, [1], 1)
(1, ChainRules.var"#getindex_pullback#1601"{Vector{Int64}, Tuple{Int64}, Tuple{NoTangent}}([1], (1,), (NoTangent(),)))

julia> gs = back(@not_implemented("test"))
(NoTangent(), InplaceableThunk(ChainRules.var"#..., Thunk(ChainRules.var"#...)), NoTangent())

julia> unthunk(gs[2])
ERROR: MethodError: Cannot `convert` an object of type Bool to an object of type ChainRulesCore.NotImplemented
Closest candidates are:
  convert(::Type{T}, ::T) where T at Base.jl:61
  ChainRulesCore.NotImplemented(::Any, ::Any, ::Any) at ~/.julia/packages/ChainRulesCore/a4mIA/src/tangent_types/notimplemented.jl:30
Stacktrace:
 [1] fill!(dest::Vector{ChainRulesCore.NotImplemented}, x::Bool)
   @ Base ./array.jl:351
 [2] _setindex_zero(x::Vector{Int64}, dy::ChainRulesCore.NotImplemented, inds::Int64)
   @ ChainRules ~/.julia/packages/ChainRules/bEtjZ/src/rulesets/Base/indexing.jl:104
 [3] ∇getindex(x::Vector{Int64}, dy::ChainRulesCore.NotImplemented, inds::Int64)
   @ ChainRules ~/.julia/packages/ChainRules/bEtjZ/src/rulesets/Base/indexing.jl:88
 [4] (::ChainRules.var"#1603#1605"{Vector{Int64}, ChainRulesCore.NotImplemented, Tuple{Int64}})()
   @ ChainRules ~/.julia/packages/ChainRules/bEtjZ/src/rulesets/Base/indexing.jl:73
 [5] unthunk
   @ ~/.julia/packages/ChainRulesCore/a4mIA/src/tangent_types/thunks.jl:204 [inlined]
 [6] unthunk(x::InplaceableThunk{Thunk{ChainRules.var"#1603#1605"{Vector{Int64}, ChainRulesCore.NotImplemented, Tuple{Int64}}}, ChainRules.var"#1602#1604"{Vector{Int64}, ChainRulesCore.NotImplemented, Tuple{Int64}}})
   @ ChainRulesCore ~/.julia/packages/ChainRulesCore/a4mIA/src/tangent_types/thunks.jl:237
 [7] top-level scope
   @ REPL[11]:1

julia> _, back = rrule(getindex, [1], [1])
([1], ChainRules.var"#getindex_pullback#1601"{Vector{Int64}, Tuple{Vector{Int64}}, Tuple{NoTangent}}([1], ([1],), (NoTangent(),)))

julia> gs = back([NoTangent()])
(NoTangent(), InplaceableThunk(ChainRules.var"#..., Thunk(ChainRules.var"#...)), NoTangent())

julia> unthunk(gs[2])
ERROR: MethodError: Cannot `convert` an object of type Bool to an object of type NoTangent
Closest candidates are:
  convert(::Type{T}, ::T) where T at Base.jl:61
Stacktrace:
 [1] fill!(dest::Vector{NoTangent}, x::Bool)
   @ Base ./array.jl:351
 [2] _setindex_zero(x::Vector{Int64}, dy::Vector{NoTangent}, inds::Vector{Int64})
   @ ChainRules ~/.julia/packages/ChainRules/bEtjZ/src/rulesets/Base/indexing.jl:105
 [3] ∇getindex(x::Vector{Int64}, dy::Vector{NoTangent}, inds::Vector{Int64})
   @ ChainRules ~/.julia/packages/ChainRules/bEtjZ/src/rulesets/Base/indexing.jl:88
 [4] (::ChainRules.var"#1603#1605"{Vector{Int64}, Vector{NoTangent}, Tuple{Vector{Int64}}})()
   @ ChainRules ~/.julia/packages/ChainRules/bEtjZ/src/rulesets/Base/indexing.jl:73
 [5] unthunk
   @ ~/.julia/packages/ChainRulesCore/a4mIA/src/tangent_types/thunks.jl:204 [inlined]
 [6] unthunk(x::InplaceableThunk{Thunk{ChainRules.var"#1603#1605"{Vector{Int64}, Vector{NoTangent}, Tuple{Vector{Int64}}}}, ChainRules.var"#1602#1604"{Vector{Int64}, Vector{NoTangent}, Tuple{Vector{Int64}}}})
   @ ChainRulesCore ~/.julia/packages/ChainRulesCore/a4mIA/src/tangent_types/thunks.jl:237
 [7] top-level scope
   @ REPL[14]:1

The lines at fault are:

_setindex_zero(x::AbstractArray{<:Number}, dy, inds::Integer...) = fill!(similar(x, typeof(dy), axes(x)), false)
_setindex_zero(x::AbstractArray{<:Number}, dy, inds...) = fill!(similar(x, eltype(dy), axes(x)), false)

I would imagine that taking x's eltype into account for the final array would be beneficial, but I'm not familiar enough with all the edge cases to be sure. Maybe the correct solution is to catch this at a higher level.
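To illustrate the failure mechanism in isolation, here is a minimal sketch that needs only Base. `Marker` is a hypothetical stand-in for a singleton tangent type such as `NoTangent`; it is not part of ChainRulesCore. The point is that `fill!` has to convert `false` to the destination eltype, and no such `convert` method exists for a marker type.

```julia
# Hypothetical stand-in for a marker type such as NoTangent (not from ChainRulesCore).
struct Marker end

x = [1]

# Mirrors the allocation in `_setindex_zero`: the eltype comes from the cotangent, not from `x`.
dest = similar(x, Marker, axes(x))

# `fill!` needs `convert(Marker, false)`, which has no method, so this throws.
err = try
    fill!(dest, false)
    nothing
catch e
    e
end

println(err isa MethodError)
```

The same call succeeds when the eltype is numeric, e.g. `fill!(similar(x, Float64, axes(x)), false)`, because `convert(Float64, false)` is defined.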

@oxinabox added and removed the bug label on Mar 13, 2023
@oxinabox
Member

> Maybe the correct solution is to catch this at a higher level.

Yes, this should be handled in the AD system before it calls the pullback.
If the input to the pullback is one of these types, we already know the output (it will be the same object), so the rules don't need to handle it.
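A rough sketch of what such a guard on the AD side could look like, assuming ChainRulesCore is installed. `guarded_pullback` and its `nargs` argument are hypothetical names introduced here for illustration; they are not part of ChainRulesCore, ChainRules, or Zygote. The sketch short-circuits when the incoming cotangent is itself an `AbstractZero` or a `NotImplemented`, and otherwise defers to the rule's pullback.

```julia
using ChainRulesCore

# Hypothetical helper, not an existing API: `back` is a rule's pullback and
# `nargs` is the number of tangents that pullback would return.
guarded_pullback(back, dy, nargs::Integer) = back(dy)

# If the cotangent is a zero-like or not-implemented marker, skip the rule
# entirely and propagate the same object for every input.
guarded_pullback(back, dy::AbstractZero, nargs::Integer) = ntuple(_ -> dy, nargs)
guarded_pullback(back, dy::ChainRulesCore.NotImplemented, nargs::Integer) = ntuple(_ -> dy, nargs)

# Toy pullback standing in for a real rule: (tangent for f, tangent for x).
toy_back(dy) = (NoTangent(), 2 * dy)

println(guarded_pullback(toy_back, 3.0, 2))
println(guarded_pullback(toy_back, ZeroTangent(), 2))
```

Note that this only covers a cotangent that is a bare marker object; an array whose elements are markers still reaches the rule.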

@ToucheSir
Contributor Author

If I'm not mistaken, it'd have to be handled in ∇getindex? That's what I meant by higher level. Otherwise the AD system would have to know a priori that this particular pullback blows up when it encounters a NotImplemented, or is a no-op when it's passed an array of AbstractZero.

@oxinabox
Member

Oh, it is arrays containing Zeros/NotImplemented?
Yeah, we should probably handle that.
Though if the array contains only Zeros, it should really be simplified down to a single Zero somewhere (not sure where).

@ToucheSir
Contributor Author

Collapsing arrays of Zeros seems reasonable. Is there a rule for what happens when that array contains both ZeroTangent and NoTangent? Would simplify_cotangents(x::Array{<:AbstractZero}) = ZeroTangent() (or = NoTangent()) be sufficient?

@oxinabox
Member

oxinabox commented Mar 13, 2023

> Is there a rule for what happens when that array contains both ZeroTangent and NoTangent?

Technically speaking, we should keep them separate.
In practice, seeing a mix should be really rare,
and no one actually treats them differently.
Simplifying them all to NoTangent() is probably fine.
(NotImplemented dominates NoTangent, which dominates ZeroTangent.)
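The dominance order described above could be sketched roughly as follows, assuming ChainRulesCore is installed. `simplify_cotangents` is the hypothetical name floated earlier in this thread, not an existing function, and the choice to leave empty arrays and arrays containing any ordinary value untouched is an assumption of this sketch.

```julia
using ChainRulesCore

const NotImpl = ChainRulesCore.NotImplemented

# Hypothetical helper: collapse an array made up purely of marker objects
# into a single marker, following NotImplemented > NoTangent > ZeroTangent.
function simplify_cotangents(dy::AbstractArray)
    isempty(dy) && return dy
    # Arrays holding any ordinary value (e.g. a number) are returned unchanged.
    all(d -> d isa Union{AbstractZero,NotImpl}, dy) || return dy
    # NotImplemented dominates: return the first one found.
    i = findfirst(d -> d isa NotImpl, dy)
    i === nothing || return dy[i]
    # Otherwise NoTangent dominates ZeroTangent.
    return any(d -> d isa NoTangent, dy) ? NoTangent() : ZeroTangent()
end

println(simplify_cotangents([NoTangent(), ZeroTangent()]))
println(simplify_cotangents([ZeroTangent(), ZeroTangent()]))
```

An array that mixes numbers with markers, such as `Any[2.0, NoTangent()]`, passes through unchanged, so the rule would still need to cope with those elementwise.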

@oxinabox
Member

oxinabox commented May 8, 2023

I can't make an MWE for this that isn't purely zeros (which should be removed before hitting the rrule).

E.g. the following works:

using ChainRules, ChainRulesCore, Test
using ChainRulesCore: NotImplemented  # imported explicitly in case it is not exported

_, back3 = rrule(getindex, [10, 0, -1], :)
gs3 = back3([2.0, NoTangent(), (@not_implemented "test2")])
num, notan, not_imp = unthunk(gs3[2])
@test num isa Real
@test iszero(notan)  # We don't care if this gets converted to a 0.0
@test not_imp isa NotImplemented

@ToucheSir
Contributor Author

I will try to get back to this and the PR which spawned it this weekend. IIRC doing some types of collapsing made certain Zygote tests very unhappy.
