
Zygote and ChainRules are broken #95

@torfjelde

Currently

using ComponentArrays, Zygote

f(a, sym) = getproperty(a, sym)
ca = ComponentArray(x = [1, 2], θ = [1.0, 100.0], deg = 2)
y, dy = Zygote.pullback(f, ca, Val{:x}())
dy(similar(y) .= 1)

fails with

BoundsError: attempt to access Tuple{Nothing, ComponentVector{Float64, Vector{Float64}, Tuple{Axis{(s = 1, m = 2:101)}}}} at index [3]

Stacktrace:
 [1] getindex
   @ ./tuple.jl:29 [inlined]
 [2] gradindex(x::Tuple{Nothing, ComponentVector{Float64, Vector{Float64}, Tuple{Axis{(s = 1, m = 2:101)}}}}, i::Int64)
   @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/reverse.jl:12
 [3] Pullback
   @ ./In[200]:1 [inlined]
 [4] (::typeof((f)))(Δ::Vector{Float64})
   @ Zygote ./compiler/interface2.jl:0
 [5] (::Zygote.var"#41#42"{typeof((_f))})(Δ::Vector{Float64})
   @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface.jl:41
 [6] top-level scope
   @ In[200]:3
 [7] eval
   @ ./boot.jl:360 [inlined]
 [8] include_string(mapexpr::typeof(REPL.softscope), mod::Module, code::String, filename::String)
   @ Base ./loading.jl:1094

AFAIK the following occurs:

  1. The rules in src/compat/zygote.jl are never hit; instead, we reach the rules in src/compat/chainrulescore.jl.
    • Interestingly, Zygote.pullback(getproperty, ca, Val{:x}()) does work. My guess is that this call hits the adjoint defined for Zygote.literal_getproperty, whereas every other path hits the ChainRules implementation.
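  2. The rules defined in src/compat/chainrulescore.jl seem to be incorrect: their pullbacks do not return a NoTangent for the key argument of getproperty, so they return one tangent too few. That is why gradindex fails above when it asks for index 3 of the 2-element Tuple{Nothing, ComponentVector{...}}.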
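For illustration, here is a minimal sketch of a rule with the expected shape. It assumes ChainRulesCore ≥ 1.0 (for NoTangent and unthunk) and that setproperty! works on a ComponentVector, as in `ca.x = ...`. The point is that the pullback returns one tangent per argument of the call, (∂getproperty, ∂x, ∂key), with NoTangent() for the non-differentiable key. This is not the package's actual rule. Defining it outside ComponentArrays.jl is type piracy and may shadow or overwrite the existing method.

```julia
using ChainRulesCore, ComponentArrays

# Sketch (not the package's actual rule): reverse rule for getproperty(x, Val(s)).
# Assumes setproperty! works on a ComponentVector, as in `ca.x = ...`.
function ChainRulesCore.rrule(::typeof(getproperty), x::ComponentVector, ::Val{s}) where {s}
    function getproperty_pullback(Δ)
        dx = zero(x)                     # zero cotangent with the same axes as x
        setproperty!(dx, s, unthunk(Δ))  # write Δ into the block selected by the key s
        # One tangent per argument: (∂getproperty, ∂x, ∂key).
        # The key is not differentiable, so it gets NoTangent().
        return NoTangent(), dx, NoTangent()
    end
    return getproperty(x, Val(s)), getproperty_pullback
end
```

With a rule shaped like this, the reproduction above should return a tangent for ca that has ones in the x block and zeros elsewhere, plus nothing for the Val key, because Zygote maps NoTangent() to nothing.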

AFAIK this also means that src/compat/zygote.jl can just be removed.

Also, is there a reason why ChainRulesCore.jl shouldn't be a regular dependency of ComponentArrays.jl? That would allow specifying compat bounds, which would keep implementations from "silently" going out of date, as seems to have happened here.
