AD error due to interaction of missing gradient of besselk and struct with vector #1204

Closed
st-- opened this issue Apr 13, 2022 · 5 comments · Fixed by #1205

Comments

@st--
Contributor

st-- commented Apr 13, 2022

Minimal reproducible example: note that `Foo`, which stores its value in a `Vector`, errors, whereas `Bar`, which stores it directly, works.

```julia
using Zygote
using SpecialFunctions: besselk
using Test

struct Foo{Tv<:Real}
    v::Vector{Tv}

    function Foo(v::Real)
        return new{typeof(v)}([v])
    end
end

struct Bar{Tv<:Real}
    v::Tv
end

function check_ad(k)
    return only(Zygote.gradient(x -> besselk(only(k.v), x), 2.1))
end

# Foo: the order of besselk is read from a Vector, and AD fails
@test_throws MethodError check_ad(Foo(1.5))
# Bar: the order is a plain field, and AD succeeds
@test check_ad(Bar(1.5)) isa Real
```

@devmotion
Collaborator

devmotion commented Apr 13, 2022

As mentioned in JuliaGaussianProcesses/KernelFunctions.jl#452, the problem seems to be that a `NotImplemented` tangent ends up in the pullback defined by

```julia
∇getindex(x::AbstractArray{T,N}, inds) where {T,N} = dy -> begin
```

but is not handled there.

Minimal example:

```julia
julia> using ChainRulesCore

julia> f(x) = x;

julia> @scalar_rule f(x) @not_implemented(":(")

julia> using Zygote

julia> Zygote.gradient(f, 0.1)
(NotImplemented(Main, #= REPL[5]:1 =#, :(),)

julia> Zygote.gradient(x -> f(only(x)), 0.1)
(NotImplemented(Main, #= REPL[5]:1 =#, :(),)

julia> Zygote.gradient(x -> f(only(x)), (0.1,))
((NotImplemented(Main, #= REPL[5]:1 =#, :(),),)

julia> Zygote.gradient(x -> f(only(x)), [0.1])
ERROR: MethodError: no method matching Zygote.OneElement(::ChainRulesCore.NotImplemented, ::Tuple{Int64}, ::Tuple{Base.OneTo{Int64}})
Closest candidates are:
  Zygote.OneElement(::T, ::I, ::A) where {N, T<:Number, I<:Tuple{Vararg{Int64, N}}, A<:Tuple{Vararg{AbstractUnitRange, N}}} at ~/.julia/packages/Zygote/H6vD3/src/lib/array.jl:53
Stacktrace:
  [1] (::Zygote.var"#433#435"{1, Float64, Vector{Float64}, Tuple{Int64}})(dy::ChainRulesCore.NotImplemented)
    @ Zygote ~/.julia/packages/Zygote/H6vD3/src/lib/array.jl:32
  [2] (::Zygote.var"#2300#back#429"{Zygote.var"#433#435"{1, Float64, Vector{Float64}, Tuple{Int64}}})(Δ::ChainRulesCore.NotImplemented)
    @ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
  [3] Pullback (repeats 2 times)
    @ ./array.jl:835 [inlined]
  [4] (::typeof(∂(iterate)))(Δ::Tuple{ChainRulesCore.NotImplemented, Nothing})
    @ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface2.jl:0
  [5] Pullback
    @ ./iterators.jl:1352 [inlined]
  [6] Pullback
    @ ./REPL[10]:1 [inlined]
  [7] (::typeof(∂(#8)))(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface2.jl:0
  [8] (::Zygote.var"#56#57"{typeof(∂(#8))})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface.jl:41
  [9] gradient(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface.jl:76
 [10] top-level scope
    @ REPL[10]:1
```

This also shows that the problem is not specific to `only` but is a general issue with `getindex`:

```julia
julia> Zygote.gradient(x -> f(x[1]), [0.1])
ERROR: MethodError: no method matching Zygote.OneElement(::ChainRulesCore.NotImplemented, ::Tuple{Int64}, ::Tuple{Base.OneTo{Int64}})
Closest candidates are:
  Zygote.OneElement(::T, ::I, ::A) where {N, T<:Number, I<:Tuple{Vararg{Int64, N}}, A<:Tuple{Vararg{AbstractUnitRange, N}}} at ~/.julia/packages/Zygote/H6vD3/src/lib/array.jl:53
Stacktrace:
 [1] (::Zygote.var"#433#435"{1, Float64, Vector{Float64}, Tuple{Int64}})(dy::ChainRulesCore.NotImplemented)
   @ Zygote ~/.julia/packages/Zygote/H6vD3/src/lib/array.jl:32
 [2] (::Zygote.var"#2300#back#429"{Zygote.var"#433#435"{1, Float64, Vector{Float64}, Tuple{Int64}}})(Δ::ChainRulesCore.NotImplemented)
   @ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
 [3] Pullback
   @ ~/.julia/packages/Zygote/H6vD3/src/tools/builtins.jl:15 [inlined]
 [4] (::typeof(∂(literal_getindex)))(Δ::ChainRulesCore.NotImplemented)
   @ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface2.jl:0
 [5] Pullback
   @ ./REPL[11]:1 [inlined]
 [6] (::typeof(∂(#10)))(Δ::Float64)
   @ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface2.jl:0
 [7] (::Zygote.var"#56#57"{typeof(∂(#10))})(Δ::Float64)
   @ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface.jl:41
 [8] gradient(f::Function, args::Vector{Float64})
   @ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface.jl:76
 [9] top-level scope
   @ REPL[11]:1
```
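The failure mode can be reproduced in miniature without Zygote. The pure-Julia sketch below is an illustration under stated assumptions, not Zygote's actual code: `ToyNotImplemented`, `ToyOneElement`, `dense` and `toy_getindex_pullback` are hypothetical names. A pullback for `x[i]` that wraps the incoming tangent in a number-only one-hot container works for a `Float64` tangent, but a not-implemented tangent has no matching constructor and raises a `MethodError`, which has the same shape as the error above.

```julia
# Hypothetical stand-ins, not Zygote's or ChainRulesCore's real types:
# `ToyNotImplemented` plays the role of ChainRulesCore.NotImplemented, and
# `ToyOneElement` mimics a one-hot gradient whose value must be a Number,
# like the `T<:Number` constraint in the `OneElement` signature above.
struct ToyNotImplemented end

struct ToyOneElement{T<:Number}
    val::T    # tangent of the single entry that was read
    ind::Int  # index of that entry
    len::Int  # length of the original vector
end

# Materialize the one-hot gradient as a dense vector.
dense(g::ToyOneElement{T}) where {T} = [i == g.ind ? g.val : zero(T) for i in 1:g.len]

# Toy pullback for `x[i]`: wraps the incoming tangent `dy` without checking
# its type, so a non-numeric tangent fails to match the constructor.
toy_getindex_pullback(x::AbstractVector, i::Int) = dy -> ToyOneElement(dy, i, length(x))

back = toy_getindex_pullback([0.1, 0.2], 1)
println(dense(back(1.0)))  # numeric tangent: one-hot gradient [1.0, 0.0]

# A not-implemented tangent reaches the constructor and cannot dispatch.
err = try
    back(ToyNotImplemented())
    nothing
catch e
    e
end
println(err isa MethodError)  # true
```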

@devmotion
Collaborator

This definition

```julia
julia> @inline Zygote.wrap_chainrules_output(x::ChainRulesCore.NotImplemented) = nothing
```

fixes the issue:

```julia
julia> Zygote.gradient(f, 0.1)
(nothing,)

julia> Zygote.gradient(x -> f(only(x)), 0.1)
(nothing,)

julia> Zygote.gradient(x -> f(only(x)), (0.1,))
(nothing,)

julia> Zygote.gradient(x -> f(only(x)), [0.1])
(nothing,)

julia> Zygote.gradient(x -> f(x[1]), [0.1])
(nothing,)
```

Clearly, the information about not-implemented derivatives is lost, but I guess this is how Zygote deals with special tangent types such as `AbstractZero` anyway, cf. the existing method

```julia
@inline wrap_chainrules_output(x::ChainRules.AbstractZero) = nothing
```

Since the (arguably suboptimal) `nothing` design is so fundamental to Zygote (at least currently), this also seems to be the easiest way to fix other similar errors without having to handle `NotImplemented` in every pullback.
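A rough sketch of why mapping `NotImplemented` to `nothing` at the ChainRules-to-Zygote boundary makes the error disappear, assuming a simplified model of Zygote. The types and functions here (`ToyAbstractZero`, `ToyZeroTangent`, `ToyNotImplemented`, `toy_wrap`, `guarded_getindex_pullback`) are hypothetical stand-ins that only mirror the dispatch pattern. Once every zero-like tangent is converted to `nothing`, a downstream pullback only has to check for `nothing`, at the cost of forgetting whether a tangent was genuinely zero or merely not implemented.

```julia
# Hypothetical stand-ins mirroring the dispatch pattern only; the real types
# are ChainRulesCore.AbstractZero, ZeroTangent and NotImplemented.
abstract type ToyAbstractZero end
struct ToyZeroTangent <: ToyAbstractZero end
struct ToyNotImplemented end

# Boundary conversion in the spirit of Zygote's `wrap_chainrules_output`:
# ordinary tangents pass through, zero-like tangents become `nothing`.
toy_wrap(x) = x
toy_wrap(::ToyAbstractZero) = nothing
toy_wrap(::ToyNotImplemented) = nothing   # the extra method proposed above
toy_wrap(xs::Tuple) = map(toy_wrap, xs)

# Downstream pullbacks then only need to handle numbers and `nothing`.
function guarded_getindex_pullback(x::AbstractVector{T}, i::Int) where {T<:Number}
    return function (dy)
        dy === nothing && return nothing  # no gradient reaches `x`
        dx = zeros(T, length(x))
        dx[i] = dy
        return dx
    end
end

back = guarded_getindex_pullback([0.1, 0.2], 1)
println(back(toy_wrap(3.0)))                  # [3.0, 0.0]
println(back(toy_wrap(ToyNotImplemented())))  # nothing
println(toy_wrap((1.0, ToyZeroTangent())))    # (1.0, nothing)
```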

@gaurav-arya

Treating `NotImplemented` identically to `AbstractZero` will lead to incorrect gradients whenever a gradient is taken with respect to something whose gradient is not implemented. In the examples above we get `nothing`, but we could just as easily get, e.g., 1.0 instead of 2.0 if part of the gradient contribution is tracked correctly and another part hits a `NotImplemented`. Perhaps it's very hard to handle them in a better way, but it certainly seems like a giant footgun that should be documented somewhere, if it isn't already :)
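The 1.0-versus-2.0 scenario can be sketched with a hand-written reverse pass, assuming a simplified model of Zygote's accumulation in which `nothing` counts as a zero contribution. `ToyNotImplemented`, `toy_accum`, `to_zygote` and `toy_gradient_g` are hypothetical names. For g(x) = f(x) + x with f(x) = x, the true derivative is 2.0; if the contribution through f is not implemented and silently becomes `nothing`, the accumulated gradient is 1.0, with no indication that anything was dropped.

```julia
# Hypothetical stand-in for ChainRulesCore.NotImplemented.
struct ToyNotImplemented end

# Zygote-style accumulation where `nothing` means "no contribution".
toy_accum(a, b) = a + b
toy_accum(::Nothing, b) = b
toy_accum(a, ::Nothing) = a
toy_accum(::Nothing, ::Nothing) = nothing

# Conversion that treats a not-implemented tangent exactly like a zero tangent.
to_zygote(x) = x
to_zygote(::ToyNotImplemented) = nothing

# Hand-written reverse pass for g(x) = f(x) + x, where f(x) = x but f's
# derivative rule is marked "not implemented". Both local derivatives are
# constants here, so the value of `x` itself is not needed.
function toy_gradient_g(x; dy = 1.0)
    from_f = to_zygote(ToyNotImplemented())  # correct value would be dy * 1.0
    from_x = dy                              # direct path: dy * 1.0
    return toy_accum(from_f, from_x)
end

println(toy_gradient_g(0.1))  # 1.0, although d/dx (f(x) + x) = 2.0
```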

@ToucheSir
Member

This is far from the only example of the lossy Zygote -> ChainRules (-> Zygote) conversion potentially causing issues for subsequent operations. Unfortunately, without doing at least part of #603, the status quo is likely to stay the same for the foreseeable future.

@oxinabox
Member

oxinabox commented Mar 1, 2023

This one is not just lossy; it's actively doing something that is normally incorrect.
It is only by chance that this is sometimes the correct behaviour, depending on what the correct implementation would be, or on the thing being unused.


5 participants