Error during gradient calculation #399

Closed
AzamatB opened this issue Nov 12, 2019 · 12 comments · Fixed by AzamatB/Zygote.jl#1 · May be fixed by #489

@AzamatB
Contributor

AzamatB commented Nov 12, 2019

I'm getting an error during gradient calculation and don't understand what is causing it.
My actual code, where this happens, is quite large, but I was able to reduce it to the following MRE:

using Flux
using Zygote
using LinearAlgebra

mutable struct State
   context
   decoding
   prediction
end

m = (state = State(zeros(Float32, 8, 7), zeros(Float32, 8, 7), zeros(Float32, 4, 7)),
     listen = RNN(5, 8),
     attention_ϕ = Dense(8, 8),
     attention_ψ = Dense(8, 8),
     spell = RNN(20, 8),
     infer = Chain(Dense(16, 4), logsoftmax))

function f(xs::AbstractVector{<:AbstractMatrix})
   hs = m.listen.(xs)
   Hs = cat(hs...; dims=3)
   ψHs = m.attention_ψ.(hs)

   ŷs = map(1:length(xs)) do _
      ϕSᵢᵀ = permutedims(m.attention_ϕ(m.state.decoding))
      Eᵢs = diag.((ϕSᵢᵀ,) .* ψHs)
      αᵢs = softmax(hcat(Eᵢs...)')
      m.state.context = dropdims(sum(reshape(αᵢs, 1, 7, :) .* Hs; dims=3); dims=3)
      m.state.prediction = m.infer([m.state.decoding; m.state.context])
      m.state.decoding = m.spell([m.state.decoding; m.state.prediction; m.state.context])
      return sum(m.state.prediction)
   end
   return sum(ŷs)
end

xs = [rand(Float32, 5,7) for _ ∈ 1:3]

θ = params(m)

Zygote.gradient(θ) do
   f(xs)
end

The error message I get is:

ERROR: MethodError: no method matching +(::Array{Array{Float32,2},1}, ::Tuple{Array{Float32,2},Array{Float32,2},Array{Float32,2}})
Closest candidates are:
  +(::Any, ::Any, ::Any, ::Any...) at operators.jl:529
  +(::Array, ::Array...) at arraymath.jl:44
  +(::Array, ::SparseArrays.SparseMatrixCSC) at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.3/SparseArrays/src/sparsematrix.jl:1647
  ...
Stacktrace:
 [1] accum(::Array{Array{Float32,2},1}, ::Tuple{Array{Float32,2},Array{Float32,2},Array{Float32,2}}) at /home/azamat/.julia/packages/Zygote/ycnjm/src/lib/lib.jl:8
 [2] (::typeof(∂(f)))(::Float32) at ./untitled-6447fe34214b7f2f51d14ffb8e06011c:20
 [3] (::Zygote.var"#28#29"{typeof(∂(f))})(::Float32) at /home/azamat/.julia/packages/Zygote/ycnjm/src/compiler/interface.jl:38
 [4] gradient(::Function, ::Array{Array{Float32,2},1}) at /home/azamat/.julia/packages/Zygote/ycnjm/src/compiler/interface.jl:47
 [5] top-level scope at untitled-6447fe34214b7f2f51d14ffb8e06011c:37

while calling f(xs) on its own works fine.

I tried to reduce the example further, but then the error disappears.

@AzamatB
Contributor Author

AzamatB commented Nov 12, 2019

I'm willing to work on a PR that fixes this if someone can explain what the problem is here.

@findmyway
Contributor

Hi @AzamatB,

The reason is that the accum function falls back to its default implementation, which simply calls +, as the error message shows.

To make your example work, you can remove the type constraint on f and define xs as a tuple: xs = Tuple(rand(Float32, 5,7) for _ ∈ 1:3). I'm not sure whether a PR is needed; maybe someone else can answer that.
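To make the mechanism concrete, here is a minimal plain-Julia sketch. `accum_sketch` is a hypothetical, simplified stand-in for Zygote's internal `accum` (an assumption about its general shape, not its actual source): tuple/tuple and array/array pairs are combined element by element, and everything else falls through to `+`. A Vector-shaped and a Tuple-shaped contribution therefore end up in `+(::Vector, ::Tuple)`, the same kind of call as in the error above, which is presumably also why switching `xs` to a tuple avoids it.

```julia
# Hypothetical, simplified stand-in for Zygote's `accum` (NOT its real source).
# `nothing` plays the role of a zero gradient.
accum_sketch(x, y) = x === nothing ? y : y === nothing ? x : x + y
# Same-kind containers are combined element by element.
accum_sketch(x::Tuple, y::Tuple) = accum_sketch.(x, y)
accum_sketch(x::AbstractArray, y::AbstractArray) = accum_sketch.(x, y)

a, b = ones(Float32, 2, 2), fill(2f0, 2, 2)

# Both contributions are tuples: the Tuple method applies.
both_tuples = accum_sketch((a, b), (a, b))

# One contribution is a Vector, the other a Tuple: only the generic method
# matches, and it calls `+(::Vector, ::Tuple)`, for which no method exists.
mixed_failed = try
    accum_sketch([a, b], (a, b))
    false
catch err
    err isa MethodError
end
```

Here `both_tuples` is `(a + a, b + b)`, while `mixed_failed` is `true`.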

@AzamatB
Contributor Author

AzamatB commented Nov 25, 2019

Hi @findmyway, thank you for pointing out the reason for the error! This is very helpful!

Would replacing this method definition with something like

accum(x::Union{Tuple,AbstractVector}, y::Union{Tuple,AbstractVector}) = accum.(x, y)

solve the issue here then?
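One way to see what the proposed method would do, without touching Zygote, is to try it on a stand-in function. `accum_proposed` below is hypothetical and illustrative only: a generic fallback plus exactly the `Union` method suggested above.

```julia
# Hypothetical stand-in for Zygote's `accum` (NOT its real source):
# a generic fallback plus the Union method proposed above.
accum_proposed(x, y) = x === nothing ? y : y === nothing ? x : x + y
accum_proposed(x::Union{Tuple,AbstractVector}, y::Union{Tuple,AbstractVector}) =
    accum_proposed.(x, y)

a, b = ones(Float32, 2, 2), fill(2f0, 2, 2)

# Vector and Tuple are now combined element by element; because one argument
# is a Vector, broadcasting returns a Vector.
g = accum_proposed([a, b], (a, b))

# Broadcasting also expands length-1 arguments, so these mismatched lengths
# are combined silently instead of raising an error.
h = accum_proposed([a], (a, b))
```

For `[a, b]` and `(a, b)` this gives the vector `[a + a, b + b]`. The second call shows one possible side effect worth checking before adopting such a rule: a length-1 vector and a length-2 tuple are combined without error, because broadcasting expands singleton dimensions.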

@findmyway
Contributor

I guess so, but I'm not sure whether there are any side effects.

@AzamatB
Contributor Author

AzamatB commented Jan 11, 2020

It looks like the line Hs = cat(hs...; dims=3) is to blame here. For example, replacing it with

h = first(hs)
Hsbuffer = Buffer(h, size(h,1), size(h,2), length(hs))
for k ∈ eachindex(hs)
   Hsbuffer[:,:,k] = hs[k]
end
Hs = copy(Hsbuffer)

works around this issue.
I tried to use this to reduce the example above further, but the error disappears whenever I try.
Hopefully this helps identify the culprit.

@AzamatB
Contributor Author

AzamatB commented Jan 28, 2020

It took me a while to hunt it down, but here is a minimal reproducible example:

using Zygote
julia> gradient(x -> sum(hcat(x...) * sum(x)), [rand(4,2), rand(4,2)])
ERROR: MethodError: no method matching +(::FillArrays.Fill{Array{Float64,2},1,Tuple{Base.OneTo{Int64}}}, ::Tuple{Array{Float64,2},Array{Float64,2}})

@AzamatB
Contributor Author

AzamatB commented Jan 28, 2020

Further reduced to:

using Zygote
julia> gradient(x -> sum(hcat(x...) * sum(x)), [rand(2), rand(2)])
ERROR: MethodError: no method matching +(::FillArrays.Fill{Array{Float64,1},1,Tuple{Base.OneTo{Int64}}}, ::Tuple{Array{Float64,1},Array{Float64,1}})

@AzamatB
Contributor Author

AzamatB commented Jan 28, 2020

Further reduced to:

using Zygote
julia> gradient(x -> vcat(x...)'x, rand(2))
ERROR: MethodError: no method matching +(::Array{Float64,1}, ::Tuple{Float64,Float64})
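Working this last example by hand shows where the two incompatible pieces come from. With v = vcat(x...), the function is v'x, so x reaches the result along two paths: through v (partial derivative x) and directly (partial derivative v, which also equals x). The true gradient is their sum, 2x. The sketch below is plain Julia with made-up input values; the Tuple shape of the splat contribution is taken from the error message above, not computed by Zygote.

```julia
x = [3.0, 4.0]

# f(x) = vcat(x...)' * x, i.e. v' * x with v = vcat(x...). Both partials equal x.
from_splat  = Tuple(x)   # contribution via v, shaped like a Tuple (as in the error)
from_direct = x          # contribution via the second operand, a Vector

# Adding them as-is reproduces the shape clash from the error message.
clash = try
    from_direct + from_splat
    false
catch err
    err isa MethodError
end

# Converting the tuple back to an array first gives the expected ∇f(x) = 2x.
total = from_direct + collect(from_splat)
```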

@AzamatB
Contributor Author

AzamatB commented Jan 28, 2020

Any ideas on what the problem is here? I'm happy to prepare a PR fixing this once I understand the underlying issue.

@AzamatB
Contributor Author

AzamatB commented Jan 30, 2020

Pasting @mohamed82008's comment from Slack:

The problem appears to be with splatting:

julia> gradient(x -> +(x...), rand(2))
((1.0, 1.0),)

The output is a tuple, not a vector.

While

julia> gradient(x -> sum(x), rand(2))
([1.0, 1.0],)

returns a vector. The first one should return a vector too.
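The expected behaviour can be phrased as a shape-restoring step after the splat: a gradient that comes back as a tuple, one entry per iterated element, is turned back into an array shaped like the input. `splat_grad_to_array` below is a hypothetical helper illustrating that idea in plain Julia; it is not necessarily how Zygote addresses the problem, and it does not handle `nothing` (zero) entries.

```julia
# Hypothetical helper (not part of Zygote): restore an array-shaped gradient
# from the per-element tuple produced when `x` is splatted as `f(x...)`.
# Splatting iterates `x` in column-major order and `reshape` fills in the
# same order, so entry k of `dx` lands on x[k].
function splat_grad_to_array(x::AbstractArray, dx::Tuple)
    length(dx) == length(x) ||
        throw(DimensionMismatch("expected $(length(x)) entries, got $(length(dx))"))
    return reshape(collect(dx), size(x))
end

# Vector case, matching the shape of `gradient(x -> +(x...), rand(2))`.
v = splat_grad_to_array(rand(2), (1.0, 1.0))

# Matrix case: the element at linear index k receives dx[k].
M = splat_grad_to_array(zeros(2, 2), (1.0, 2.0, 3.0, 4.0))
```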

@mcabbott
Member

It would be great to fix the splat bug. Here's another possible solution to the cat problem:

using LazyStack

function f(xs::AbstractVector{<:AbstractMatrix})
   hs = m.listen.(xs)
   Hs = stack(hs)
   ψHs = m.attention_ψ.(hs)
...
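For comparison, another splat-free way to build the same 3-D array in plain Julia is to concatenate horizontally with `reduce(hcat, hs)` and reshape. This is a swapped-in technique, not the LazyStack approach above, and whether it sidesteps the Zygote error has not been checked here; the sketch only confirms that it produces the same array as `cat(hs...; dims=3)`.

```julia
# Stand-ins for the hidden states: three 8×7 matrices, as in the MRE.
hs = [rand(Float32, 8, 7) for _ in 1:3]

# Splat-based construction used in the original code.
Hs_cat = cat(hs...; dims=3)

# Splat-free alternative: reduce(hcat, hs) lays the matrices side by side
# (8×21); because Julia arrays are column-major, reshaping to 8×7×3 puts
# hs[k] in the k-th slice.
h = first(hs)
Hs_reduce = reshape(reduce(hcat, hs), size(h, 1), size(h, 2), length(hs))
```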

@AzamatB
Contributor Author

AzamatB commented Apr 14, 2020

Closing this in favor of #599.
