No method matching for mapreduce? #539

Closed
Red-Portal opened this issue Oct 11, 2021 · 17 comments

@Red-Portal

Hi,

I'm getting the following error:

ERROR: MethodError: no method matching (::ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}})(::NamedTuple{(:x,), Tuple{Float64}})
Closest candidates are:
  (::ChainRulesCore.ProjectTo{var"#s12", D} where {var"#s12"<:Real, D<:NamedTuple})(::Complex) at /home/msca8h/.julia/packages/ChainRulesCore/1L9My/src/projection.jl:179
  (::ChainRulesCore.ProjectTo)(::ChainRulesCore.Thunk) at /home/msca8h/.julia/packages/ChainRulesCore/1L9My/src/projection.jl:124
  (::ChainRulesCore.ProjectTo{T, D} where D<:NamedTuple)(::AbstractFloat) where T<:AbstractFloat at /home/msca8h/.julia/packages/ChainRulesCore/1L9My/src/projection.jl:167
  ...
Stacktrace:
  [1] (::ChainRulesCore.ProjectTo{Ref, NamedTuple{(:type, :x), Tuple{DataType, ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}}})(dx::Base.RefValue{Any})
    @ ChainRulesCore ~/.julia/packages/ChainRulesCore/1L9My/src/projection.jl:275
  [2] #56
    @ ~/.julia/packages/ChainRulesCore/1L9My/src/projection.jl:230 [inlined]
  [3] #4
    @ ./generator.jl:36 [inlined]
  [4] iterate
    @ ./generator.jl:47 [inlined]
  [5] collect(itr::Base.Generator{Base.Iterators.Zip{Tuple{Vector{ChainRulesCore.ProjectTo{Ref, NamedTuple{(:type, :x), Tuple{DataType, ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}}}}, Vector{Base.RefValue{Any}}}}, Base.var"#4#5"{ChainRulesCore.var"#56#57"}})
    @ Base ./array.jl:681
  [6] map
    @ ./abstractarray.jl:2383 [inlined]
  [7] (::ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:elements, :axes), Tuple{Vector{ChainRulesCore.ProjectTo{Ref, NamedTuple{(:type, :x), Tuple{DataType, ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}}}}, Tuple{Base.OneTo{Int64}}}}})(dx::Vector{Base.RefValue{Any}})
    @ ChainRulesCore ~/.julia/packages/ChainRulesCore/1L9My/src/projection.jl:230
  [8] (::ChainRules.var"#sum_pullback#1375"{Colon, typeof(getindex), ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:elements, :axes), Tuple{Vector{ChainRulesCore.ProjectTo{Ref, NamedTuple{(:type, :x), Tuple{DataType, ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}}}}, Tuple{Base.OneTo{Int64}}}}}, Vector{Zygote.var"#ad_pullback#45"{Tuple{typeof(getindex), Base.RefValue{Float64}}, typeof(∂(getindex))}}})(ȳ::Float64)
    @ ChainRules ~/.julia/packages/ChainRules/RyXef/src/rulesets/Base/mapreduce.jl:88
  [9] ZBack
    @ ~/.julia/packages/Zygote/Lw5Kf/src/compiler/chainrules.jl:168 [inlined]
 [10] Pullback
    @ ~/.julia/packages/DynamicPPL/RcfQU/src/threadsafe.jl:25 [inlined]
 [11] (::typeof(∂(getlogp)))(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/Lw5Kf/src/compiler/interface2

which pops up when differentiating through the following Turing model:

ard_kernel(σ², ℓ) = 
    σ²*KernelFunctions.TransformedKernel(
        KernelFunctions.Matern52Kernel(),
        KernelFunctions.ARDTransform(1 ./ ℓ))

Turing.@model function logisticgp(X, y, jitter=1e-6)
    n_features = size(X, 1)

    logα  ~ Normal(0, 1)
    logσ  ~ Normal(0, 1)
    logℓ  ~ MvNormal(n_features, 1)

    α²     = exp(logα*2)
    σ²     = exp(logσ*2)
    ℓ      = exp.(logℓ)
    kernel = ard_kernel(α², ℓ) 
    K      = KernelFunctions.kernelmatrix(kernel, X)
    K     += I*(σ² + jitter)

    f  ~ MvNormal(zeros(size(X,2)), K)
    y .~ Turing.BernoulliLogit.(f)
end

Is this issue related to ChainRules.jl? Not entirely sure.

@mcabbott
Member

Might be a bug here. Could you extend your example so that it reproduces the error? I.e. presumably it needs using Turing at the top, and something calling this logisticgp after it's defined.

@Red-Portal
Author

Red-Portal commented Oct 11, 2021

using Turing
using KernelFunctions
using Zygote
using LinearAlgebra

ard_kernel(σ², ℓ) = 
    σ²*KernelFunctions.TransformedKernel(
        KernelFunctions.Matern52Kernel(),
        KernelFunctions.ARDTransform(1 ./ ℓ))

Turing.@model function logisticgp(X, y, jitter=1e-6)
    n_features = size(X, 1)

    logα  ~ Normal(0, 1)
    logσ  ~ Normal(0, 1)
    logℓ  ~ MvNormal(n_features, 1)

    α²     = exp(logα*2)
    σ²     = exp(logσ*2)
    ℓ      = exp.(logℓ)
    kernel = ard_kernel(α², ℓ) 
    K      = KernelFunctions.kernelmatrix(kernel, X)
    K     += I*(σ² + jitter)

    f  ~ MvNormal(zeros(size(X,2)), K)
    y .~ Turing.BernoulliLogit.(f)
end
Turing.setadbackend(:zygote)
model = logisticgp(randn(10, 100), rand(100) .> 0.5)
sample(model, NUTS(), 10)

Hi @mcabbott,
here's a self-contained reproduction.

@mcabbott
Member

mcabbott commented Oct 11, 2021

OK, I can reproduce this (on 1.6, but I can't install on 1.7?). And ForwardDiff looks like it runs, with no other changes:

julia> using ForwardDiff, LinearAlgebra

julia> Turing.setadbackend(:forwarddiff);

julia> sample(model, NUTS(), 10)
┌ Info: Found initial step size
└   ϵ = 0.2
Sampling 100%|██████████████████████████████████████████████████████████████| Time: 0:00:39
Chains MCMC chain (10×124×1 Array{Float64, 3}):

Iterations        = 6:1:15
Number of chains  = 1
Samples per chain = 10
Wall duration     = 40.47 seconds

But I have no idea what's going on yet. The error still happens with JuliaDiff/ChainRulesCore.jl#488.

This must come from x::Vector{Ref{<:Real}}: its projector is being handed another vector of Refs as the tangent.

[7] (::ProjectTo{AbstractArray, NamedTuple{(:elements, :axes), Tuple{Vector{ProjectTo{Ref, ...)(dx::Vector{Base.RefValue{Any}})

But if you try that in isolation, it seems fine:

julia> using ChainRulesCore

julia> ProjectTo([Ref(1)])([Ref{Any}(1)])
1-element Vector{Tangent{Base.RefValue{Int64}, NamedTuple{(:x,), Tuple{Float64}}}}:
 Tangent{Base.RefValue{Int64}}(x = 1.0,)

The rule for sum(f, xs) is involved: https://github.com/JuliaDiff/ChainRules.jl/blob/main/src/rulesets/Base/mapreduce.jl#L88
Also, it's closing over getindex, or something?

[8] (::ChainRules.var"#sum_pullback#1375"{Colon, typeof(getindex), 

@devmotion
Member

but I can't install on 1.7?

This is known (and expected) due to Libtask_jll: TuringLang/Turing.jl#1713 (comment)

@Red-Portal
Author

Red-Portal commented Oct 11, 2021

I have a quick question, though: the model in question is currently extremely slow on a 60-dimensional dataset with about 150 data points. It's quite surprising that a few tens of data points can make such a dramatic difference. Would Zygote help in this case, or is ForwardDiff.jl's performance the best we can get?

@devmotion
Member

First of all, I would recommend that you use AbstractGPs or, usually easier and more user-friendly, one of the downstream packages such as Stheno instead of working with KernelFunctions directly. There also seem to be some possible efficiency improvements. E.g., I would suggest using

kernel(α², logℓ, σ²) = α² * (Matern52Kernel() ∘ ARDTransform(@. exp(-logℓ))) + σ² * WhiteKernel()

@model function logisticgp(X, y, jitter=1e-6)
    n_features, n_observations = size(X)

    logα  ~ Normal(0, 1)
    logσ  ~ Normal(0, 1)
    # MvNormal(::Int, ::Real) is deprecated
    logℓ  ~ MvNormal(zeros(n_features), I) # or a bit more efficient: logℓ  ~ MvNormal(Zeros(n_features), I)

    α² = exp(logα*2)
    σ² = exp(logσ*2)
    K = kernelmatrix(kernel(α², logℓ, σ² + jitter), X)
...
end
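For readers who want to see what this kernel evaluates to, here is a plain-Julia sketch (only the LinearAlgebra standard library, no KernelFunctions) of the Matérn-5/2 ARD kernel matrix plus the white-noise diagonal. The function names matern52 and ard_matern52_matrix are hypothetical, and the sketch assumes the columns of X are observations, as in the model above. It is meant for checking shapes and values by hand, not as a replacement for KernelFunctions.

```julia
using LinearAlgebra

# Hypothetical illustration: roughly what
#   α² * (Matern52Kernel() ∘ ARDTransform(@. exp(-logℓ))) + σ² * WhiteKernel()
# evaluates to on distinct inputs. Not KernelFunctions' implementation.

# Matérn-5/2 kernel as a function of the scaled Euclidean distance r.
matern52(r) = (1 + sqrt(5) * r + 5 * r^2 / 3) * exp(-sqrt(5) * r)

# X is n_features × n_observations (columns are observations).
function ard_matern52_matrix(X::AbstractMatrix, α², logℓ::AbstractVector, σ²)
    Z = X .* exp.(-logℓ)          # ARD: scale feature row i by 1/ℓᵢ with ℓ = exp.(logℓ)
    n = size(Z, 2)
    K = Matrix{Float64}(undef, n, n)
    for j in 1:n, i in 1:n
        r = norm(view(Z, :, i) - view(Z, :, j))
        K[i, j] = α² * matern52(r)
    end
    return K + σ² * I              # white-noise term: σ² added on the diagonal
end
```

Since logℓ has 60 entries and there are about 150 observations in the case discussed here, each evaluation builds a 150 × 150 matrix whose entries all depend on every length scale, which is the kind of workload where reverse mode tends to pay off.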

Due to all the matrix operations involved here and since logℓ is 60 dimensional and f is 150 dimensional, I would assume that reverse-mode is faster. Have you tried ReverseDiff, also with a compiled tape (https://turing.ml/dev/docs/using-turing/autodiff)?

@Red-Portal
Author

@devmotion Thanks for the tips. I'm indeed using AbstractGPs; I swapped it for MvNormal for simplicity. The computation time isn't that different, though. Interestingly enough, ReverseDiff actually seems to be slower (without compiled tapes). I'll try the compiled tapes and see if that helps.

@mcabbott
Member

I believe that https://discourse.julialang.org/t/ffts-in-probabilistic-models/69775/2 might be a similar issue. You can trigger that like this:

julia> using ChainRulesCore

julia> ProjectTo(Ref(1))(Ref{Any}(2))  # this is what's expected
Tangent{Base.RefValue{Int64}}(x = 2.0,)

julia> ProjectTo(Ref(3))(Ref{Any}((x=4,)))  # both a Ref and a NamedTuple
ERROR: MethodError: no method matching (::ProjectTo{Float64, NamedTuple{(), Tuple{}}})(::NamedTuple{(:x,), Tuple{Int64}})
  ...
Stacktrace:
 [1] (::ProjectTo{Ref, NamedTuple{(:type, :x), Tuple{DataType, ProjectTo{Float64, NamedTuple{(), Tuple{}}}}}})(dx::Base.RefValue{Any})
   @ ChainRulesCore ~/.julia/packages/ChainRulesCore/1L9My/src/projection.jl:275

And for mutable structs, Zygote wraps the NamedTuple it would make for a struct in a Ref, which is one more layer than expected, so this fails:

julia> pullback(x -> x[]^2, Ref(3))[2](1)  # no projection
(Base.RefValue{Any}((x = 6,)),)

julia> gradient(x -> x[]^2, Ref(3))  # with projection
ERROR: MethodError: no method matching (::ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}})(::NamedTuple{(:x,), Tuple{Int64}})

@oxinabox
Member

oxinabox commented Oct 15, 2021

And for mutable structs, Zygote wraps the NamedTuple it would make for a struct in a Ref, which is one more layer than expected, so this fails:

This needs to be handled in:
https://github.com/FluxML/Zygote.jl/blob/5887e46bf6280e3608dbed2e27f2229fa1456087/src/compiler/chainrules.jl#L124
to unwrap the Ref.
I am pretty sure it is a leftover from when Zygote sort of supported mutation, and doesn't do anything now.
(Even if it isn't, it still needs to be fixed there.)
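The fix described here amounts to removing one wrapper before projection. A minimal sketch, assuming the tangent arrives exactly as in the pullback output above; unwrap_ref_tangent is a hypothetical name, not the function in Zygote's chainrules.jl:

```julia
# Hypothetical helper, not Zygote's actual code: drop the extra Ref layer that
# Zygote puts around the NamedTuple gradient of a mutable struct (here a RefValue).
unwrap_ref_tangent(dx::Base.RefValue) = dx[]
unwrap_ref_tangent(dx) = dx   # any other tangent passes through unchanged

# The pullback output shown above for x -> x[]^2 at Ref(3):
dx = Base.RefValue{Any}((x = 6,))

# After unwrapping, only one layer remains: a plain NamedTuple with the single
# field :x, instead of a Ref around a NamedTuple.
unwrap_ref_tangent(dx)   # (x = 6,)

# For an array of mutable structs, apply the same unwrapping elementwise.
map(unwrap_ref_tangent, [Ref{Any}((x = 1,)), Ref{Any}((x = 2,))])
```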

@mcabbott
Member

Done. @Red-Portal, see if this is solved by Zygote v0.6.28.

@Red-Portal
Author

Hi @mcabbott

Unfortunately, the example still doesn't run.
Here's my up-to-date version of Zygote.

(@v1.6) pkg> status Zygote
      Status `~/.julia/environments/v1.6/Project.toml`
  [e88e6eb3] Zygote v0.6.28

I'm still on Julia 1.6

@mcabbott
Member

Do you get the same error or a different one?

@Red-Portal
Author

@mcabbott It seems more or less identical; only the line numbers are slightly different. Here's the error:

julia> sample(model, NUTS(), 10)
Sampling 100%|██████████████████████████████████████████| Time: 0:01:42
ERROR: MethodError: no method matching (::ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}})(::NamedTuple{(:x,), Tuple{Int64}})
Closest candidates are:
  (::ChainRulesCore.ProjectTo{var"#s12", D} where {var"#s12"<:Real, D<:NamedTuple})(::Complex) at /home/msca8h/.julia/packages/ChainRulesCore/bxKCw/src/projection.jl:186
  (::ChainRulesCore.ProjectTo{T, D} where D<:NamedTuple)(::AbstractFloat) where T<:AbstractFloat at /home/msca8h/.julia/packages/ChainRulesCore/bxKCw/src/projection.jl:170
  (::ChainRulesCore.ProjectTo)(::ChainRulesCore.Thunk) at /home/msca8h/.julia/packages/ChainRulesCore/bxKCw/src/projection.jl:124
  ...
Stacktrace:
  [1] (::ChainRulesCore.ProjectTo{ChainRulesCore.Tangent{Base.RefValue{Float64}, T} where T, NamedTuple{(:x,), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}}})(dx::Base.RefValue{Any})
    @ ChainRulesCore ~/.julia/packages/ChainRulesCore/bxKCw/src/projection.jl:284
  [2] #56
    @ ~/.julia/packages/ChainRulesCore/bxKCw/src/projection.jl:237 [inlined]
  [3] #4
    @ ./generator.jl:36 [inlined]
  [4] iterate
    @ ./generator.jl:47 [inlined]
  [5] collect(itr::Base.Generator{Base.Iterators.Zip{Tuple{Vector{ChainRulesCore.ProjectTo{ChainRulesCore.Tangent{Base.RefValue{Float64}, T} where T, NamedTuple{(:x,), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}}}}, Vector{Base.RefValue{Any}}}}, Base.var"#4#5"{ChainRulesCore.var"#56#57"}})
    @ Base ./array.jl:681
  [6] map
    @ ./abstractarray.jl:2383 [inlined]
  [7] (::ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:elements, :axes), Tuple{Vector{ChainRulesCore.ProjectTo{ChainRulesCore.Tangent{Base.RefValue{Float64}, T} where T, NamedTuple{(:x,), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}}}}, Tuple{Base.OneTo{Int64}}}}})(dx::Vector{Base.RefValue{Any}})
    @ ChainRulesCore ~/.julia/packages/ChainRulesCore/bxKCw/src/projection.jl:237
  [8] (::ChainRules.var"#sum_pullback#1375"{Colon, typeof(getindex), ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:elements, :axes), Tuple{Vector{ChainRulesCore.ProjectTo{ChainRulesCore.Tangent{Base.RefValue{Float64}, T} where T, NamedTuple{(:x,), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}}}}, Tuple{Base.OneTo{Int64}}}}}, Vector{Zygote.var"#ad_pullback#45"{Tuple{typeof(getindex), Base.RefValue{Float64}}, typeof(∂(getindex))}}})(ȳ::Int64)
    @ ChainRules ~/.julia/packages/ChainRules/RyXef/src/rulesets/Base/mapreduce.jl:88
  [9] ZBack
    @ ~/.julia/packages/Zygote/fDJjj/src/compiler/chainrules.jl:170 [inlined]
 [10] Pullback
    @ ~/.julia/packages/DynamicPPL/RcfQU/src/threadsafe.jl:25 [inlined]
 [11] (::typeof(∂(getlogp)))(Δ::Int64)
    @ Zygote ~/.julia/packages/Zygote/fDJjj/src/compiler/interface2.jl:0
 [12] Pullback
    @ ~/.julia/packages/DynamicPPL/RcfQU/src/model.jl:439 [inlined]
 [13] (::typeof(∂(evaluate_threadsafe)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/fDJjj/src/compiler/interface2.jl:0
 [14] Pullback
    @ ~/.julia/packages/DynamicPPL/RcfQU/src/model.jl:391 [inlined]
 [15] (::typeof(∂(λ)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/fDJjj/src/compiler/interface2.jl:0
 [16] Pullback
    @ ~/.julia/packages/DynamicPPL/RcfQU/src/model.jl:383 [inlined]
 [17] (::typeof(∂(λ)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/fDJjj/src/compiler/interface2.jl:0
 [18] #203
    @ ~/.julia/packages/Zygote/fDJjj/src/lib/lib.jl:203 [inlined]
 [19] #1733#back
    @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67 [inlined]
 [20] Pullback
    @ ~/.julia/packages/DynamicPPL/RcfQU/src/model.jl:396 [inlined]
 [21] (::typeof(∂(λ)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/fDJjj/src/compiler/interface2.jl:0
 [22] Pullback
    @ ~/.julia/packages/Turing/uMQmD/src/core/ad.jl:165 [inlined]
 [23] (::typeof(∂(λ)))(Δ::Int64)
    @ Zygote ~/.julia/packages/Zygote/fDJjj/src/compiler/interface2.jl:0
 [24] (::Zygote.var"#50#51"{typeof(∂(λ))})(Δ::Int64)
    @ Zygote ~/.julia/packages/Zygote/fDJjj/src/compiler/interface.jl:41
 [25] gradient_logp(backend::Turing.Core.ZygoteAD, θ::Vector{Float64}, vi::DynamicPPL.TypedVarInfo{NamedTuple{(:logα, :logσ, :logℓ, :f), Tuple{DynamicPPL.Metadata{Dict{AbstractPPL.VarName{:logα, Tuple{}}, Int64}, Vector{Normal{Float64}}, Vector{AbstractPPL.VarName{:logα, Tuple{}}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}, DynamicPPL.Metadata{Dict{AbstractPPL.VarName{:logσ, Tuple{}}, Int64}, Vector{Normal{Float64}}, Vector{AbstractPPL.VarName{:logσ, Tuple{}}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}, DynamicPPL.Metadata{Dict{AbstractPPL.VarName{:logℓ, Tuple{}}, Int64}, Vector{ZeroMeanIsoNormal{Tuple{Base.OneTo{Int64}}}}, Vector{AbstractPPL.VarName{:logℓ, Tuple{}}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}, DynamicPPL.Metadata{Dict{AbstractPPL.VarName{:f, Tuple{}}, Int64}, Vector{FullNormal}, Vector{AbstractPPL.VarName{:f, Tuple{}}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}}, Float64}, model::DynamicPPL.Model{typeof(logisticgp), (:X, :y, :jitter), (:jitter,), (), Tuple{Matrix{Float64}, BitVector, Float64}, Tuple{Float64}, DynamicPPL.DefaultContext}, sampler::DynamicPPL.Sampler{NUTS{Turing.Core.ZygoteAD, (), AdvancedHMC.DiagEuclideanMetric}}, context::DynamicPPL.DefaultContext)
    @ Turing.Core ~/.julia/packages/Turing/uMQmD/src/core/ad.jl:171
 [26] gradient_logp (repeats 2 times)
    @ ~/.julia/packages/Turing/uMQmD/src/core/ad.jl:83 [inlined]
 [27] ∂logπ∂θ
    @ ~/.julia/packages/Turing/uMQmD/src/inference/hmc.jl:433 [inlined]
 [28] ∂H∂θ
    @ ~/.julia/packages/AdvancedHMC/HQHnm/src/hamiltonian.jl:31 [inlined]
 [29] phasepoint
    @ ~/.julia/packages/AdvancedHMC/HQHnm/src/hamiltonian.jl:76 [inlined]

@mcabbott
Member

OK, thanks for checking! I got everything loaded and now get the same error on 1.6.

The line [7] (::ChainRulesCore.ProjectTo ... )(dx::Vector{Base.RefValue{Any}}) looks like it may have something to do with an array of mutable structs. But this doesn't seem to help:

@eval Zygote @inline wrap_chainrules_input(xs::AbstractArray{<:Ref}) = wrap_chainrules_input.(xs)
Zygote.refresh()

Maybe I already said this, but a few frames further down is sum_pullback, which is the rule that uses rrule_via_ad (linked below). Maybe something weird is happening in how the RuleConfig{>:HasReverseMode} machinery moves tangents back and forth?
https://github.com/JuliaDiff/ChainRules.jl/blob/main/src/rulesets/Base/mapreduce.jl#L88
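To illustrate the structure being discussed, here is a hypothetical plain-Julia sketch (the names sum_with_pullbacks and getindex_rule are made up; this is not ChainRules' actual rule). It runs a per-element rule on every entry, keeps the pullbacks, and collects their outputs. With a well-behaved per-element rule each tangent is a plain NamedTuple (x = ȳ,); the trace above suggests that Zygote's per-element pullback instead handed back Ref{Any}((x = ȳ,)), one layer too many for the projection step that follows.

```julia
# Hypothetical sketch, not ChainRules' implementation: a reverse-mode rule for
# sum(f, xs) that runs f's own rule on every element and keeps the pullbacks.
function sum_with_pullbacks(rrule_f, xs)
    results = map(rrule_f, xs)                        # each entry is (f(x), pullback)
    y = sum(first, results)
    sum_pullback(ȳ) = [back(ȳ) for (_, back) in results]
    return y, sum_pullback
end

# Hand-written rule for getindex on a Ref: the tangent of r[] with respect to r
# is a structural tangent with the single field :x, written as a plain NamedTuple.
getindex_rule(r::Base.RefValue) = (r[], ȳ -> (x = ȳ,))

y, back = sum_with_pullbacks(getindex_rule, Ref.(1:3))
back(1)   # returns [(x = 1,), (x = 1,), (x = 1,)]
```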

@mcabbott
Member

mcabbott commented Oct 15, 2021

Here's a more minimal example:

julia> using Zygote

julia> y, back = pullback(x -> sum(getindex, x), Ref.(1:3))
(6, Zygote.var"#52#53"{typeof(∂(#11))}(∂(#11)))

julia> back(1)
ERROR: MethodError: no method matching (::ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}})(::NamedTuple{(:x,), Tuple{Int64}})
  ...
Stacktrace:
  [1] (::ChainRulesCore.ProjectTo{ChainRulesCore.Tangent{Base.RefValue{Int64}}, NamedTuple{(:x,), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}}})(dx::Base.RefValue{Any})
    @ ChainRulesCore ~/.julia/packages/ChainRulesCore/bxKCw/src/projection.jl:284
  [2] #56
    @ ~/.julia/packages/ChainRulesCore/bxKCw/src/projection.jl:237 [inlined]
  [3] #4
    @ ./generator.jl:36 [inlined]
  [4] iterate
    @ ./generator.jl:47 [inlined]
  [5] collect(itr::Base.Generator{Base.Iterators.Zip{Tuple{Vector{ChainRulesCore.ProjectTo{ChainRulesCore.Tangent{Base.RefValue{Int64}}, NamedTuple{(:x,), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}}}}, Vector{Base.RefValue{Any}}}}, Base.var"#4#5"{ChainRulesCore.var"#56#57"}})
    @ Base ./array.jl:710
  [6] map
    @ ./abstractarray.jl:2860 [inlined]
  [7] (::ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:elements, :axes), Tuple{Vector{ChainRulesCore.ProjectTo{ChainRulesCore.Tangent{Base.RefValue{Int64}}, NamedTuple{(:x,), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}}}}, Tuple{Base.OneTo{Int64}}}}})(dx::Vector{Base.RefValue{Any}})
    @ ChainRulesCore ~/.julia/packages/ChainRulesCore/bxKCw/src/projection.jl:237
  [8] (::ChainRules.var"#sum_pullback#1375"{Colon, typeof(getindex), ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:elements, :axes), Tuple{Vector{ChainRulesCore.ProjectTo{ChainRulesCore.Tangent{Base.RefValue{Int64}}, NamedTuple{(:x,), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}}}}, Tuple{Base.OneTo{Int64}}}}}, Vector{Zygote.var"#ad_pullback#47"{Tuple{typeof(getindex), Base.RefValue{Int64}}, typeof(∂(getindex))}}})(ȳ::Int64)
    @ ChainRules ~/.julia/packages/ChainRules/RyXef/src/rulesets/Base/mapreduce.jl:88
  [9] ZBack
...

(@v1.7) pkg> st Zygote
      Status `~/.julia/environments/v1.7/Project.toml`
  [e88e6eb3] Zygote v0.6.28

And a version with broadcasting, without the sum(f, xs) rule. This one is fixed by the suggestion above (although the sum(f, xs) version is not):

julia> y, back = pullback(x -> sum(getindex.(x)), Ref.(1:3))
(6, Zygote.var"#52#53"{typeof(∂(#15))}(∂(#15)))

julia> back(1)
ERROR: MethodError: no method matching (::ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}})(::NamedTuple{(:x,), Tuple{Int64}})
  ...
Stacktrace:
  [1] (::ChainRulesCore.ProjectTo{ChainRulesCore.Tangent{Base.RefValue{Int64}}, NamedTuple{(:x,), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}}})(dx::Base.RefValue{Any})
    @ ChainRulesCore ~/.julia/packages/ChainRulesCore/bxKCw/src/projection.jl:284
...
  [7] (::ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:elements, :axes), Tuple{Vector{ChainRulesCore.ProjectTo{ChainRulesCore.Tangent{Base.RefValue{Int64}}, NamedTuple{(:x,), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}}}}, Tuple{Base.OneTo{Int64}}}}})(dx::Vector{Base.RefValue{Any}})
    @ ChainRulesCore ~/.julia/packages/ChainRulesCore/bxKCw/src/projection.jl:237
  [8] _project
    @ ~/.julia/packages/Zygote/fDJjj/src/compiler/chainrules.jl:142 [inlined]
  [9] unbroadcast(x::Vector{Base.RefValue{Int64}}, x̄::Vector{Base.RefValue{Any}})
    @ Zygote ~/.julia/packages/Zygote/fDJjj/src/lib/broadcast.jl:51

julia> @eval Zygote @inline wrap_chainrules_input(xs::AbstractArray{<:Ref}) = wrap_chainrules_input.(xs);

julia> Zygote.refresh()

julia> y, back = pullback(x -> sum(getindex.(x)), Ref.(1:3))  # now try again, same thing
(6, Zygote.var"#52#53"{typeof(∂(#21))}(∂(#21)))

julia> back(1)[1]  # ChainRules types escape, but otherwise OK.
3-element Vector{ChainRulesCore.Tangent{Base.RefValue{Int64}, NamedTuple{(:x,), Tuple{Float64}}}}:
 Tangent{Base.RefValue{Int64}}(x = 1.0,)
 Tangent{Base.RefValue{Int64}}(x = 1.0,)
 Tangent{Base.RefValue{Int64}}(x = 1.0,)
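As an independent check on the values above, a central finite difference perturbs each Ref's contents in turn. Each entry should come out close to 1, matching the x = 1.0 field of every Tangent. The helpers total and fd_gradient are hypothetical and not part of Zygote.

```julia
# Hypothetical finite-difference check of d sum(getindex, xs) / d xs[i][].
total(xs) = sum(getindex, xs)

function fd_gradient(f, xs; h = 1e-6)
    g = zeros(length(xs))
    for i in eachindex(xs)
        old = xs[i][]
        xs[i][] = old + h
        fplus = f(xs)
        xs[i][] = old - h
        fminus = f(xs)
        xs[i][] = old                       # restore the original value
        g[i] = (fplus - fminus) / (2h)      # central difference
    end
    return g
end

xs = Ref.([1.0, 2.0, 3.0])
fd_gradient(total, xs)   # ≈ [1.0, 1.0, 1.0]
```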

@mcabbott
Member

The example from #539 (comment) is still broken on the latest Zygote + Julia 1.6, although at least the error is different now:

julia> sample(model, NUTS(), 10)
Sampling 100%|██████████████████████████████████████████████████████████████| Time: 0:01:07
ERROR: MethodError: no method matching +(::NamedTuple{(:x,), Tuple{Float64}}, ::ChainRulesCore.Tangent{Base.RefValue{Float64}, NamedTuple{(:x,), Tuple{Float64}}})
  ...
Stacktrace:
  [1] accum(x::NamedTuple{(:x,), Tuple{Float64}}, y::ChainRulesCore.Tangent{Base.RefValue{Float64}, NamedTuple{(:x,), Tuple{Float64}}})
    @ Zygote ~/.julia/packages/Zygote/rv6db/src/lib/lib.jl:17
  [2] _broadcast_getindex_evalf
    @ ./broadcast.jl:648 [inlined]
...
  [6] materialize
    @ ./broadcast.jl:883 [inlined]
  [7] accum(x::Vector{NamedTuple{(:x,), Tuple{Float64}}}, ys::Vector{ChainRulesCore.AbstractTangent})
    @ Zygote ~/.julia/packages/Zygote/rv6db/src/lib/lib.jl:25
  [8] macro expansion
    @ ~/.julia/packages/Zygote/rv6db/src/lib/lib.jl:27 [inlined]
  [9] accum(x::NamedTuple{(:varinfo, :logps), Tuple{Nothing, Vector{NamedTuple{(:x,), Tuple{Float64}}}}}, y::NamedTuple{(:varinfo, :logps), Tuple{NamedTuple{(:metadata, :logp, :num_produce), Tuple{NamedTuple{(:logα, :logσ, :logℓ, :f), NTuple{4, NamedTuple{(:idcs, :vns, :ranges, :vals, :dists, :gids, :orders, :flags), Tuple{Nothing, Nothing, Nothing, Vector{Float64}, Nothing, Nothing, Nothing, Nothing}}}}, Nothing, Nothing}}, Vector{ChainRulesCore.AbstractTangent}}})
    @ Zygote ~/.julia/packages/Zygote/rv6db/src/lib/lib.jl:27
 [10] getindex
    @ ./tuple.jl:29 [inlined]
 [11] gradindex
    @ ~/.julia/packages/Zygote/rv6db/src/compiler/reverse.jl:12 [inlined]
 [12] Pullback
    @ ~/.julia/packages/DynamicPPL/RcfQU/src/model.jl:438 [inlined]
 [13] (::typeof(∂(evaluate_threadsafe)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/rv6db/src/compiler/interface2.jl:0
 [14] Pullback
    @ ~/.julia/packages/DynamicPPL/RcfQU/src/model.jl:391 [inlined]
 [15] (::typeof(∂(λ)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/rv6db/src/compiler/interface2.jl:0
 [16] Pullback
    @ ~/.julia/packages/DynamicPPL/RcfQU/src/model.jl:383 [inlined]
 [17] (::typeof(∂(λ)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/rv6db/src/compiler/interface2.jl:0
 [18] #203
    @ ~/.julia/packages/Zygote/rv6db/src/lib/lib.jl:203 [inlined]
 [19] #1734#back
    @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67 [inlined]
 [20] Pullback
    @ ~/.julia/packages/DynamicPPL/RcfQU/src/model.jl:396 [inlined]
 [21] (::typeof(∂(λ)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/rv6db/src/compiler/interface2.jl:0
 [22] Pullback
    @ ~/.julia/packages/Turing/uMQmD/src/core/ad.jl:165 [inlined]
 [23] (::typeof(∂(λ)))(Δ::Int64)
    @ Zygote ~/.julia/packages/Zygote/rv6db/src/compiler/interface2.jl:0
 [24] (::Zygote.var"#50#51"{typeof(∂(λ))})(Δ::Int64)
    @ Zygote ~/.julia/packages/Zygote/rv6db/src/compiler/interface.jl:41
 [25] gradient_logp(backend::Turing.Core.ZygoteAD, θ::Vector{Float64}, vi::DynamicPPL.TypedVarInfo{NamedTuple{(:logα, :logσ, :logℓ, :f), Tuple{DynamicPPL.Metadata{Dict{AbstractPPL.VarName{:logα, Tuple{}}, Int64}, Vector{Normal{Float64}}, Vector{AbstractPPL.VarName{:logα, Tuple{}}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}, DynamicPPL.Metadata{Dict{AbstractPPL.VarName{:logσ, Tuple{}}, Int64}, Vector{Normal{Float64}}, Vector{AbstractPPL.VarName{:logσ, Tuple{}}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}, DynamicPPL.Metadata{Dict{AbstractPPL.VarName{:logℓ, Tuple{}}, Int64}, Vector{ZeroMeanIsoNormal{Tuple{Base.OneTo{Int64}}}}, Vector{AbstractPPL.VarName{:logℓ, Tuple{}}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}, DynamicPPL.Metadata{Dict{AbstractPPL.VarName{:f, Tuple{}}, Int64}, Vector{FullNormal}, Vector{AbstractPPL.VarName{:f, Tuple{}}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}}, Float64}, model::DynamicPPL.Model{typeof(logisticgp), (:X, :y, :jitter), (:jitter,), (), Tuple{Matrix{Float64}, BitVector, Float64}, Tuple{Float64}, DynamicPPL.DefaultContext}, sampler::DynamicPPL.Sampler{NUTS{Turing.Core.ZygoteAD, (), AdvancedHMC.DiagEuclideanMetric}}, context::DynamicPPL.DefaultContext)
    @ Turing.Core ~/.julia/packages/Turing/uMQmD/src/core/ad.jl:171
 [26] gradient_logp (repeats 2 times)
    @ ~/.julia/packages/Turing/uMQmD/src/core/ad.jl:83 [inlined]
...

(jl_kb8biz) pkg> st Zygote
      Status `/private/var/folders/yq/4p2zwd614y59gszh7y9ypyhh0000gn/T/jl_kb8biz/Project.toml`
  [e88e6eb3] Zygote v0.6.29
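The new MethodError is about accumulation: Zygote's accum had to add a NamedTuple gradient to a ChainRulesCore Tangent, and no + method exists for that pair. Below is a minimal sketch of field-wise accumulation for the case where both sides are plain NamedTuples with the same field names; accum_fields is a hypothetical name, not Zygote's accum. A pair that is not two matching NamedTuples (a Ref stands in for the Tangent here, to avoid depending on ChainRulesCore) has no matching method, which mirrors the failure in the trace.

```julia
# Hypothetical sketch, not Zygote's code: add two gradients for the same struct
# field by field, as happens when a value is used in more than one place.
accum_fields(a::NamedTuple{K}, b::NamedTuple{K}) where {K} =
    NamedTuple{K}(map(+, values(a), values(b)))

accum_fields((x = 1.5,), (x = 2.0,))   # (x = 3.5,)

# A mixed pair has no method, analogous to +(::NamedTuple, ::Tangent) above.
applicable(accum_fields, (x = 1.0,), Ref{Any}((x = 1.0,)))   # false
```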

@mcabbott
Member

This now seems to be fixed, with the versions shown:

julia> sample(model, NUTS(), 10)
┌ Info: Found initial step size
└   ϵ = 0.8
Sampling 100%|██████████████████████████████████████████████████████████████| Time: 0:01:15
Chains MCMC chain (10×124×1 Array{Float64, 3}):

Iterations        = 6:1:15
Number of chains  = 1
Samples per chain = 10
Wall duration     = 76.08 seconds

(jl_dAz9hk) pkg> st
      Status `/private/var/folders/yq/4p2zwd614y59gszh7y9ypyhh0000gn/T/jl_dAz9hk/Project.toml`
  [ec8451be] KernelFunctions v0.10.26
  [fce5fe82] Turing v0.19.0
  [e88e6eb3] Zygote v0.6.30

julia> versioninfo()
Julia Version 1.6.0
Commit f9720dc2eb (2021-03-24 12:55 UTC)
Platform Info:
  OS: macOS (x86_64-apple-darwin19.6.0)
  CPU: Apple M1
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-11.0.1 (ORCJIT, westmere)
Environment:
  JULIA_NUM_THREADS = 4
