
Assignment to multiple arrays is not differentiable on GPU since Zygote.jl 0.6.67 #1470

Closed
bicycle1885 opened this issue Nov 8, 2023 · 5 comments · Fixed by JuliaDiff/ChainRules.jl#760
Labels
bug (Something isn't working) · CUDA (All things GPU) · discussion · help wanted (Extra attention is needed)

Comments

@bicycle1885

I found that the following code doesn't work:

using CUDA, Zygote

function f(x)
    a, b = [x[1:4], x[5:8]]
    sum(a + b)
end

x = cu(randn(8))
@show Zygote.gradient(f, x)
ERROR: LoadError: MethodError: no method matching parent(::Type{SubArray{Union{ChainRulesCore.ZeroTangent, CuVector{Float32, CUDA.Mem.DeviceBuffer}, DenseCuVector{Float32, CUDA.Mem.DeviceBuffer}}, 0, Vector{Union{ChainRulesCore.ZeroTangent, CuVector{Float32, CUDA.Mem.DeviceBuffer}, DenseCuVector{Float32, CUDA.Mem.DeviceBuffer}}}, Tuple{Int64}, true}})

Closest candidates are:
  parent(::Union{LinearAlgebra.Adjoint{T, S}, LinearAlgebra.Transpose{T, S}} where {T, S})
   @ LinearAlgebra ~/.julia/juliaup/julia-1.9.3+0.x64.linux.gnu/share/julia/stdlib/v1.9/LinearAlgebra/src/adjtrans.jl:341
  parent(::Union{LinearAlgebra.LowerTriangular{T, S} where S<:AbstractMatrix{T}, LinearAlgebra.UnitLowerTriangular{T, S} where S<:AbstractMatrix{T}, LinearAlgebra.UnitUpperTriangular{T, S} where S<:AbstractMatrix{T}, LinearAlgebra.UpperTriangular{T, S} where S<:AbstractMatrix{T}} where T)
   @ LinearAlgebra ~/.julia/juliaup/julia-1.9.3+0.x64.linux.gnu/share/julia/stdlib/v1.9/LinearAlgebra/src/triangular.jl:164
  parent(::Union{LinearAlgebra.Hermitian{T, S}, LinearAlgebra.Symmetric{T, S}} where {T, S})
   @ LinearAlgebra ~/.julia/juliaup/julia-1.9.3+0.x64.linux.gnu/share/julia/stdlib/v1.9/LinearAlgebra/src/symmetric.jl:275
  ...

Stacktrace:
  [1] backend(#unused#::Type{SubArray{Union{ChainRulesCore.ZeroTangent, CuVector{Float32, CUDA.Mem.DeviceBuffer}, DenseCuVector{Float32, CUDA.Mem.DeviceBuffer}}, 0, Vector{Union{ChainRulesCore.ZeroTangent, CuVector{Float32, CUDA.Mem.DeviceBuffer}, DenseCuVector{Float32, CUDA.Mem.DeviceBuffer}}}, Tuple{Int64}, true}})
    @ GPUArraysCore ~/.julia/packages/GPUArraysCore/uOYfN/src/GPUArraysCore.jl:151
  [2] backend(x::SubArray{Union{ChainRulesCore.ZeroTangent, CuVector{Float32, CUDA.Mem.DeviceBuffer}, DenseCuVector{Float32, CUDA.Mem.DeviceBuffer}}, 0, Vector{Union{ChainRulesCore.ZeroTangent, CuVector{Float32, CUDA.Mem.DeviceBuffer}, DenseCuVector{Float32, CUDA.Mem.DeviceBuffer}}}, Tuple{Int64}, true})
    @ GPUArraysCore ~/.julia/packages/GPUArraysCore/uOYfN/src/GPUArraysCore.jl:149
  [3] _copyto!
    @ ~/.julia/packages/GPUArrays/dAUOE/src/host/broadcast.jl:70 [inlined]
  [4] materialize!
    @ ~/.julia/packages/GPUArrays/dAUOE/src/host/broadcast.jl:46 [inlined]
  [5] materialize!
    @ ./broadcast.jl:881 [inlined]
  [6] ∇getindex!(dx::Vector{Union{ChainRulesCore.ZeroTangent, CuVector{Float32, CUDA.Mem.DeviceBuffer}, DenseCuVector{Float32, CUDA.Mem.DeviceBuffer}}}, dy::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, inds::Int64)
    @ ChainRules ~/.julia/packages/ChainRules/Tvwnx/src/rulesets/Base/indexing.jl:147
  [7] ∇getindex(x::Vector{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, dy::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, inds::Int64)
    @ ChainRules ~/.julia/packages/ChainRules/Tvwnx/src/rulesets/Base/indexing.jl:89
  [8] (::ChainRules.var"#1582#1584"{Vector{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Int64}})()
    @ ChainRules ~/.julia/packages/ChainRules/Tvwnx/src/rulesets/Base/indexing.jl:69
  [9] unthunk
    @ ~/.julia/packages/ChainRulesCore/7MWx2/src/tangent_types/thunks.jl:204 [inlined]
 [10] unthunk(x::ChainRulesCore.InplaceableThunk{ChainRulesCore.Thunk{ChainRules.var"#1582#1584"{Vector{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Int64}}}, ChainRules.var"#1581#1583"{Vector{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Int64}}})
    @ ChainRulesCore ~/.julia/packages/ChainRulesCore/7MWx2/src/tangent_types/thunks.jl:237
 [11] wrap_chainrules_output
    @ ~/.julia/packages/Zygote/YYT6v/src/compiler/chainrules.jl:110 [inlined]
 [12] map
    @ ./tuple.jl:275 [inlined]
 [13] wrap_chainrules_output
    @ ~/.julia/packages/Zygote/YYT6v/src/compiler/chainrules.jl:111 [inlined]
 [14] ZBack
    @ ~/.julia/packages/Zygote/YYT6v/src/compiler/chainrules.jl:211 [inlined]
 [15] Pullback
    @ ./tuple.jl:89 [inlined]
 [16] Pullback
    @ ~/.julia/packages/Zygote/YYT6v/src/tools/builtins.jl:14 [inlined]
 [17] (::Zygote.Pullback{Tuple{typeof(Zygote.literal_indexed_iterate), Vector{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Val{2}, Int64}, Tuple{Zygote.Pullback{Tuple{typeof(Base.indexed_iterate), Vector{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Int64, Int64}, Tuple{Zygote.ZBack{Zygote.var"#plus_pullback#345"{Tuple{Int64, Int64}}}, Zygote.var"#2017#back#204"{typeof(identity)}, Zygote.ZBack{ChainRules.var"#getindex_pullback#1580"{Vector{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Tuple{Int64}, Tuple{ChainRulesCore.NoTangent}}}}}}})(Δ::Tuple{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Nothing})
    @ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface2.jl:0
 [18] Pullback
    @ ~/workspace/MLUtils.jl/tmp/diff.jl:4 [inlined]
 [19] (::Zygote.Pullback{Tuple{typeof(f), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Tuple{Zygote.Pullback{Tuple{typeof(Zygote.literal_indexed_iterate), Vector{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Val{1}}, Tuple{Zygote.Pullback{Tuple{typeof(Base.indexed_iterate), Vector{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Int64}, Tuple{Zygote.Pullback{Tuple{typeof(Base.indexed_iterate), Vector{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Int64, Int64}, Tuple{Zygote.ZBack{Zygote.var"#plus_pullback#345"{Tuple{Int64, Int64}}}, Zygote.var"#2017#back#204"{typeof(identity)}, Zygote.ZBack{ChainRules.var"#getindex_pullback#1580"{Vector{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Tuple{Int64}, Tuple{ChainRulesCore.NoTangent}}}}}}}}}, Zygote.ZBack{ChainRules.var"#getindex_pullback#1580"{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{UnitRange{Int64}}, Tuple{ChainRulesCore.NoTangent}}}, Zygote.ZBack{ChainRules.var"#getindex_pullback#1580"{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{UnitRange{Int64}}, Tuple{ChainRulesCore.NoTangent}}}, Zygote.var"#2033#back#213"{Zygote.var"#back#211"{2, 2, Zygote.Context{false}, Int64}}, Zygote.var"#2033#back#213"{Zygote.var"#back#211"{2, 1, Zygote.Context{false}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}, Zygote.Pullback{Tuple{typeof(Zygote.literal_indexed_iterate), Vector{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Val{2}, Int64}, Tuple{Zygote.Pullback{Tuple{typeof(Base.indexed_iterate), Vector{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Int64, Int64}, Tuple{Zygote.ZBack{Zygote.var"#plus_pullback#345"{Tuple{Int64, Int64}}}, Zygote.var"#2017#back#204"{typeof(identity)}, Zygote.ZBack{ChainRules.var"#getindex_pullback#1580"{Vector{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Tuple{Int64}, Tuple{ChainRulesCore.NoTangent}}}}}}}, Zygote.var"#4197#back#1441"{Zygote.var"#1437#1440"{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}, Zygote.var"#3589#back#1082"{Zygote.var"#1078#1081"}, 
Zygote.ZBack{ChainRules.var"#:_pullback#276"{Tuple{Int64, Int64}}}, Zygote.ZBack{ChainRules.var"#vect_pullback#1353"{2, Tuple{ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float32, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}, ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float32, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}}}}, Zygote.ZBack{ChainRules.var"#:_pullback#276"{Tuple{Int64, Int64}}}, Zygote.var"#2033#back#213"{Zygote.var"#back#211"{2, 1, Zygote.Context{false}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface2.jl:0
 [20] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{typeof(f), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Tuple{Zygote.Pullback{Tuple{typeof(Zygote.literal_indexed_iterate), Vector{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Val{1}}, Tuple{Zygote.Pullback{Tuple{typeof(Base.indexed_iterate), Vector{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Int64}, Tuple{Zygote.Pullback{Tuple{typeof(Base.indexed_iterate), Vector{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Int64, Int64}, Tuple{Zygote.ZBack{Zygote.var"#plus_pullback#345"{Tuple{Int64, Int64}}}, Zygote.var"#2017#back#204"{typeof(identity)}, Zygote.ZBack{ChainRules.var"#getindex_pullback#1580"{Vector{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Tuple{Int64}, Tuple{ChainRulesCore.NoTangent}}}}}}}}}, Zygote.ZBack{ChainRules.var"#getindex_pullback#1580"{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{UnitRange{Int64}}, Tuple{ChainRulesCore.NoTangent}}}, Zygote.ZBack{ChainRules.var"#getindex_pullback#1580"{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{UnitRange{Int64}}, Tuple{ChainRulesCore.NoTangent}}}, Zygote.var"#2033#back#213"{Zygote.var"#back#211"{2, 2, Zygote.Context{false}, Int64}}, Zygote.var"#2033#back#213"{Zygote.var"#back#211"{2, 1, Zygote.Context{false}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}, Zygote.Pullback{Tuple{typeof(Zygote.literal_indexed_iterate), Vector{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Val{2}, Int64}, Tuple{Zygote.Pullback{Tuple{typeof(Base.indexed_iterate), Vector{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Int64, Int64}, Tuple{Zygote.ZBack{Zygote.var"#plus_pullback#345"{Tuple{Int64, Int64}}}, Zygote.var"#2017#back#204"{typeof(identity)}, Zygote.ZBack{ChainRules.var"#getindex_pullback#1580"{Vector{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Tuple{Int64}, Tuple{ChainRulesCore.NoTangent}}}}}}}, Zygote.var"#4197#back#1441"{Zygote.var"#1437#1440"{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}, Zygote.var"#3589#back#1082"{Zygote.var"#1078#1081"}, 
Zygote.ZBack{ChainRules.var"#:_pullback#276"{Tuple{Int64, Int64}}}, Zygote.ZBack{ChainRules.var"#vect_pullback#1353"{2, Tuple{ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float32, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}, ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float32, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}}}}, Zygote.ZBack{ChainRules.var"#:_pullback#276"{Tuple{Int64, Int64}}}, Zygote.var"#2033#back#213"{Zygote.var"#back#211"{2, 1, Zygote.Context{false}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}}})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface.jl:45
 [21] gradient(f::Function, args::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer})
    @ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface.jl:97
 [22] top-level scope
    @ show.jl:1128
in expression starting at /home/kenta/workspace/MLUtils.jl/tmp/diff.jl:9

Please note that you'll need to run this on a GPU with Zygote.jl 0.6.67. The issue is not reproducible on the CPU or with Zygote.jl 0.6.66.

The example code is a bit artificial, but I ran into this when I tried to use the chunk function of MLUtils.jl. I've filed an issue there, but since the root cause is a recent change in Zygote.jl, I'm filing a separate issue here.
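A possible workaround on the user side (my assumption, not verified against this exact Zygote version): destructure a tuple rather than a Vector literal, so that no Vector{CuArray} is ever constructed and the failing ∇getindex! path is never reached. A plain-array sketch:

```julia
# Hypothetical workaround sketch: destructure a tuple, not a Vector literal.
# `x[1:4], x[5:8]` builds a Tuple, so the pullback never sees a
# Vector{CuArray}, the container type the failing rule chokes on.
function f_tuple(x)
    a, b = x[1:4], x[5:8]   # tuple destructuring, not [x[1:4], x[5:8]]
    sum(a + b)
end

f_tuple(collect(1.0:8.0))  # plain-array smoke test; returns 36.0
```

With CUDA and Zygote loaded, `Zygote.gradient(f_tuple, cu(randn(8)))` would then be expected to avoid this error, though I have not checked that on 0.6.67.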

@ToucheSir (Member)

Reduced to https://github.com/JuliaDiff/ChainRules.jl/blob/v1.58.0/src/rulesets/Base/indexing.jl#L147. This should happen for any GPU array type:

julia> using JLArrays; x = jl(ones(1));

julia> xs = Union{typeof(x),Nothing}[x, x]
2-element Vector{Union{Nothing, DenseJLVector{Float64}, JLVector{Float64}}}:
 [1.0]
 [1.0]

julia> view(xs, 1) .+= Ref(x)
ERROR: MethodError: no method matching parent(::Type{SubArray{Union{Nothing, DenseJLVector{Float64}, JLVector{Float64}}, 0, Vector{Union{Nothing, DenseJLVector{Float64}, JLVector{Float64}}}, Tuple{Int64}, true}})

Closest candidates are:
  parent(::Union{LinearAlgebra.Adjoint{T, S}, LinearAlgebra.Transpose{T, S}} where {T, S})
   @ LinearAlgebra /mnt/fastdisk/brianc_home/.julia/juliaup/julia-1.9.3+0.x64.linux.gnu/share/julia/stdlib/v1.9/LinearAlgebra/src/adjtrans.jl:341
  parent(::Union{LinearAlgebra.LowerTriangular{T, S} where S<:AbstractMatrix{T}, LinearAlgebra.UnitLowerTriangular{T, S} where S<:AbstractMatrix{T}, LinearAlgebra.UnitUpperTriangular{T, S} where S<:AbstractMatrix{T}, LinearAlgebra.UpperTriangular{T, S} where S<:AbstractMatrix{T}} where T)
   @ LinearAlgebra /mnt/fastdisk/brianc_home/.julia/juliaup/julia-1.9.3+0.x64.linux.gnu/share/julia/stdlib/v1.9/LinearAlgebra/src/triangular.jl:164
  parent(::Union{LinearAlgebra.Hermitian{T, S}, LinearAlgebra.Symmetric{T, S}} where {T, S})
   @ LinearAlgebra /mnt/fastdisk/brianc_home/.julia/juliaup/julia-1.9.3+0.x64.linux.gnu/share/julia/stdlib/v1.9/LinearAlgebra/src/symmetric.jl:275
  ...

Stacktrace:
 [1] backend(#unused#::Type{SubArray{Union{Nothing, DenseJLVector{Float64}, JLVector{Float64}}, 0, Vector{Union{Nothing, DenseJLVector{Float64}, JLVector{Float64}}}, Tuple{Int64}, true}})
   @ GPUArraysCore ~/.julia/packages/GPUArraysCore/uOYfN/src/GPUArraysCore.jl:151
 [2] backend(x::SubArray{Union{Nothing, DenseJLVector{Float64}, JLVector{Float64}}, 0, Vector{Union{Nothing, DenseJLVector{Float64}, JLVector{Float64}}}, Tuple{Int64}, true})
   @ GPUArraysCore ~/.julia/packages/GPUArraysCore/uOYfN/src/GPUArraysCore.jl:149
 [3] _copyto!
   @ ~/.julia/packages/GPUArrays/dAUOE/src/host/broadcast.jl:70 [inlined]
 [4] materialize!
   @ ~/.julia/packages/GPUArrays/dAUOE/src/host/broadcast.jl:46 [inlined]
 [5] materialize!(dest::SubArray{Union{Nothing, DenseJLVector{Float64}, JLVector{Float64}}, 0, Vector{Union{Nothing, DenseJLVector{Float64}, JLVector{Float64}}}, Tuple{Int64}, true}, bc::Base.Broadcast.Broadcasted{JLArrays.JLArrayStyle{0}, Nothing, typeof(+), Tuple{SubArray{Union{Nothing, DenseJLVector{Float64}, JLVector{Float64}}, 0, Vector{Union{Nothing, DenseJLVector{Float64}, JLVector{Float64}}}, Tuple{Int64}, true}, Base.RefValue{JLArray{Float64, 1}}}})
   @ Base.Broadcast ./broadcast.jl:881
 [6] top-level scope
   @ REPL[14]:1

My feeling is that the BroadcastStyle selected here is incorrect: broadcasting into a CPU array of GPU arrays should not treat the destination like a GPU array. However, I'm not sure whether this is something that could be safely changed. You may want to ask the JuliaGPU folks about this behaviour.
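For contrast, the same broadcast pattern works when the elements are ordinary CPU arrays (no GPU packages involved), which supports treating the destination as a plain CPU container:

```julia
# Same pattern as the reduced example above, but with a CPU Vector element
# type: broadcasting into a 0-dim view of a CPU Vector{Union{...}} is fine.
x = ones(1)
xs = Union{typeof(x), Nothing}[x, x]
view(xs, 1) .+= Ref(x)   # rebinds xs[1] to a freshly allocated [2.0]
xs[1]                    # == [2.0]; xs[2] is still the original [1.0]
```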

ToucheSir added the bug, help wanted, discussion, and CUDA labels on Nov 13, 2023
@nomadbl commented Dec 9, 2023

I've looked into this a bit.
The first issue is that the example itself uses an explicit Vector, which does not play well with GPUs. This probably creeps in through the rrule and is hard to spot intuitively. It shows up in the types in the stack trace, e.g.
∇getindex(x::Vector{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}...

I can modify the example to avoid the Vector as follows (using JLArrays lets me reproduce all of this on a CPU):

using JLArrays, Zygote

function f(x)
    a = x[1:4]
    b = x[5:8]
    sum(a + b)
end

x = jl(randn(8))
@show Zygote.gradient(f, x)

This gives the stack trace:

ERROR: Not implemented
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:35
  [2] derive(#unused#::Type, N::Int64, a::JLArray{Float64, 1}, osize::Tuple{Int64}, additional_offset::Int64)
    @ GPUArrays ~/.julia/packages/GPUArrays/dAUOE/src/host/construction.jl:143
  [3] unsafe_contiguous_view
    @ ~/.julia/packages/GPUArrays/dAUOE/src/host/base.jl:324 [inlined]
  [4] unsafe_view
    @ ~/.julia/packages/GPUArrays/dAUOE/src/host/base.jl:319 [inlined]
  [5] view
    @ ~/.julia/packages/GPUArrays/dAUOE/src/host/base.jl:315 [inlined]
  [6] ∇getindex!(dx::JLArray{Float64, 1}, dy::JLArray{Float64, 1}, inds::UnitRange{Int64})
    @ ChainRules ~/.julia/packages/ChainRules/DSuXy/src/rulesets/Base/indexing.jl:180
  [7] ∇getindex(x::JLArray{Float64, 1}, dy::JLArray{Float64, 1}, inds::UnitRange{Int64})
    @ ChainRules ~/.julia/packages/ChainRules/DSuXy/src/rulesets/Base/indexing.jl:89
  [8] (::ChainRules.var"#1583#1585"{JLArray{Float64, 1}, JLArray{Float64, 1}, Tuple{UnitRange{Int64}}})()
    @ ChainRules ~/.julia/packages/ChainRules/DSuXy/src/rulesets/Base/indexing.jl:69
  [9] unthunk
    @ ~/.julia/packages/ChainRulesCore/7MWx2/src/tangent_types/thunks.jl:204 [inlined]
 [10] unthunk(x::ChainRulesCore.InplaceableThunk{ChainRulesCore.Thunk{ChainRules.var"#1583#1585"{JLArray{Float64, 1}, JLArray{Float64, 1}, Tuple{UnitRange{Int64}}}}, ChainRules.var"#1582#1584"{JLArray{Float64, 1}, JLArray{Float64, 1}, Tuple{UnitRange{Int64}}}})
    @ ChainRulesCore ~/.julia/packages/ChainRulesCore/7MWx2/src/tangent_types/thunks.jl:237
 [11] wrap_chainrules_output
    @ ~/.julia/packages/Zygote/YYT6v/src/compiler/chainrules.jl:110 [inlined]
 [12] map
    @ ./tuple.jl:275 [inlined]
 [13] wrap_chainrules_output
    @ ~/.julia/packages/Zygote/YYT6v/src/compiler/chainrules.jl:111 [inlined]
 [14] ZBack
    @ ~/.julia/packages/Zygote/YYT6v/src/compiler/chainrules.jl:211 [inlined]
 [15] Pullback
    @ ./REPL[12]:3 [inlined]
 [16] (::Zygote.Pullback{Tuple{typeof(g), JLArray{Float64, 1}}, Tuple{Zygote.var"#3589#back#1082"{Zygote.var"#1078#1081"}, Zygote.ZBack{ChainRules.var"#:_pullback#276"{Tuple{Int64, Int64}}}, Zygote.ZBack{ChainRules.var"#getindex_pullback#1581"{JLArray{Float64, 1}, Tuple{UnitRange{Int64}}, Tuple{ChainRulesCore.NoTangent}}}, Zygote.ZBack{ChainRules.var"#getindex_pullback#1581"{JLArray{Float64, 1}, Tuple{UnitRange{Int64}}, Tuple{ChainRulesCore.NoTangent}}}, Zygote.ZBack{ChainRules.var"#:_pullback#276"{Tuple{Int64, Int64}}}, Zygote.var"#4197#back#1441"{Zygote.var"#1437#1440"{JLArray{Float64, 1}}}}})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface2.jl:0
 [17] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{typeof(g), JLArray{Float64, 1}}, Tuple{Zygote.var"#3589#back#1082"{Zygote.var"#1078#1081"}, Zygote.ZBack{ChainRules.var"#:_pullback#276"{Tuple{Int64, Int64}}}, Zygote.ZBack{ChainRules.var"#getindex_pullback#1581"{JLArray{Float64, 1}, Tuple{UnitRange{Int64}}, Tuple{ChainRulesCore.NoTangent}}}, Zygote.ZBack{ChainRules.var"#getindex_pullback#1581"{JLArray{Float64, 1}, Tuple{UnitRange{Int64}}, Tuple{ChainRulesCore.NoTangent}}}, Zygote.ZBack{ChainRules.var"#:_pullback#276"{Tuple{Int64, Int64}}}, Zygote.var"#4197#back#1441"{Zygote.var"#1437#1440"{JLArray{Float64, 1}}}}}})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface.jl:45
 [18] gradient(f::Function, args::JLArray{Float64, 1})
    @ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface.jl:97
 [19] top-level scope
    @ show.jl:1128

This points to the use of view in the pullback functions of getindex.
The last several stack frames can be reproduced with the following MWE:

using JLArrays
dx = jl(randn(8))
view(dx, 1) .= 1

I suggest replacing all such views with setting up a CPU array first:

dx_cpu = zeros(eltype(dx), size(dx))
for ind in eachindex(dx_cpu)
    dx_cpu[ind] = ind in inds ? 1 : 0
end

transferring it to the GPU:

copyto!(dx, dx_cpu)

and then multiplying by the incoming gradient:

dx .*= dy

@ToucheSir (Member)

We can't use mutation like this because the rules themselves need to be differentiable. The correct fix IMO would be to get GPUArrays broadcasting to ignore this particular case, which is why I mentioned bringing the issue up on the JuliaGPU side above.

@nomadbl commented Dec 10, 2023

Is it possible to use a Buffer to circumvent the mutation issue you mentioned?
Changing the GPUArrays behavior feels wrong to me, since I think it could cause some hard-to-track bugs for what is really an edge case. It would be interesting to hear their response, of course.

@ToucheSir (Member)

Buffer is a Zygote-only thing, so it doesn't make sense to use in ChainRules. It's also slow and doesn't support in-place broadcasts.

I would not call broadcasting into an array of arrays an edge case. It just so happens that this doesn't work for GPU arrays, but that could be a matter of Julia's broadcasting machinery not providing enough information to allow it to work. Opened JuliaGPU/GPUArrays.jl#505 to figure out what's going on.

In the meantime, it looks like a simple fix is to change the aforementioned line in ChainRules (which is in a function that allows mutation) to use a code path which avoids broadcasting. That's what the linked ChainRules PR does.
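The linked PR's exact diff isn't reproduced here, so the following is only an illustrative sketch of what "a code path which avoids broadcasting" can look like: accumulating with an explicit loop so no BroadcastStyle is ever consulted. The name `accumulate_nonbroadcast!` is hypothetical, not a ChainRules function:

```julia
# Hypothetical illustration (not the actual ChainRules patch): accumulate
# dy into dx at positions inds with an explicit loop instead of
# `view(dx, inds) .+= dy`, so Julia's broadcasting machinery (and hence any
# GPU BroadcastStyle) is never invoked on the destination view.
function accumulate_nonbroadcast!(dx::AbstractVector, dy::AbstractVector, inds)
    for (i, ind) in enumerate(inds)
        dx[ind] += dy[i]   # plain indexed assignment, no broadcast
    end
    return dx
end

dx = zeros(8); dy = ones(4)
accumulate_nonbroadcast!(dx, dy, 2:5)  # dx becomes [0,1,1,1,1,0,0,0]
```

Note this relies on scalar indexing of `dx`, which is cheap for the CPU container-of-arrays case that triggers this issue, where each `dx[ind] += dy[i]` adds whole element arrays out of place.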
