The chunk function is not differentiable on GPU #170

Closed
bicycle1885 opened this issue Nov 7, 2023 · 10 comments

Comments

@bicycle1885
Member

I found that operations involving the chunk function are not differentiable on GPU.

using CUDA, Flux

struct Model
    layers
end

function Model()
    dense = Dense(3 => 8)
    Model((;dense))
end

Flux.@functor Model

function (model::Model)(x)
    y = model.layers.dense(x)
    a, b = Flux.chunk(y, size = [4, 4], dims = 1)
    sum(a + b)
end

model = Model()
x = randn(Float32, 3, 10)

x, model = gpu((x, model))
@show model(x)
Flux.withgradient(model -> model(x), model)

When I try to run this, I see the following error:

ERROR: LoadError: MethodError: no method matching parent(::Type{SubArray{Union{ChainRulesCore.ZeroTangent, CuMatrix{Float32, CUDA.Mem.DeviceBuffer}, DenseCuMatrix{Float32, CUDA.Mem.DeviceBuffer}}, 0, Vector{Union{ChainRulesCore.ZeroTangent, CuMatrix{Float32, CUDA.Mem.DeviceBuffer}, DenseCuMatrix{Float32, CUDA.Mem.DeviceBuffer}}}, Tuple{Int64}, true}})

Closest candidates are:
  parent(!Matched::Union{LinearAlgebra.Adjoint{T, S}, LinearAlgebra.Transpose{T, S}} where {T, S})
   @ LinearAlgebra ~/.julia/juliaup/julia-1.9.3+0.x64.linux.gnu/share/julia/stdlib/v1.9/LinearAlgebra/src/adjtrans.jl:341
  parent(!Matched::Union{LinearAlgebra.LowerTriangular{T, S} where S<:AbstractMatrix{T}, LinearAlgebra.UnitLowerTriangular{T, S} where S<:AbstractMatrix{T}, LinearAlgebra.UnitUpperTriangular{T, S} where S<:AbstractMatrix{T}, LinearAlgebra.UpperTriangular{T, S} where S<:AbstractMatrix{T}} where T)
   @ LinearAlgebra ~/.julia/juliaup/julia-1.9.3+0.x64.linux.gnu/share/julia/stdlib/v1.9/LinearAlgebra/src/triangular.jl:164
  parent(!Matched::Union{LinearAlgebra.Hermitian{T, S}, LinearAlgebra.Symmetric{T, S}} where {T, S})
   @ LinearAlgebra ~/.julia/juliaup/julia-1.9.3+0.x64.linux.gnu/share/julia/stdlib/v1.9/LinearAlgebra/src/symmetric.jl:275
  ...

Full error message: log.txt

My environment is:

julia> versioninfo()
Julia Version 1.9.3
Commit bed2cd540a1 (2023-08-24 14:43 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 16 × Intel(R) Xeon(R) Gold 6134 CPU @ 3.20GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-14.0.6 (ORCJIT, skylake-avx512)
  Threads: 1 on 16 virtual cores
Environment:
  JULIA_PROJECT = @.

(tmp) pkg> st
Status `~/tmp/Project.toml`
  [052768ef] CUDA v5.0.0
  [587475ba] Flux v0.14.6
  [02a925ec] cuDNN v1.2.0

Manifest.toml and Project.toml are attached below (the file extensions were changed to .txt for uploading).
Manifest.txt
Project.txt

@bicycle1885
Member Author

I found that this might be caused by Zygote.jl 0.6.67 because the problem goes away when I downgrade Zygote.jl to 0.6.66.

@mcabbott
Contributor

mcabbott commented Nov 7, 2023

This looks like a bug in the ChainRules (CR) rrule that is now used here, since Zygote deleted its own rule. Any chance you can isolate it further, e.g. to a single getindex call which gives a similar error?

@bicycle1885
Member Author

I'm not sure what you mean by "a single getindex call which gives a similar error". Could you elaborate? I'll test it as soon as I understand.

@mcabbott
Contributor

mcabbott commented Nov 8, 2023

Sorry, what I mean is that chunk must end up doing some indexing, maybe like x[:,1], which uses the rule for getindex, which shows up in the stacktrace. I think the error can probably be reproduced by something like gradient(x -> sum(abs2, x[:, 2:3]), cu(rand(2,3))) but I'm not sure what the indices needed are.

This is the relevant bit of the stacktrace:

  [4] materialize!
    @ ~/.julia/packages/GPUArrays/dAUOE/src/host/broadcast.jl:46 [inlined]
  [5] materialize!
    @ ./broadcast.jl:881 [inlined]
  [6] ∇getindex!(dx::Vector{Union{ChainRulesCore.ZeroTangent, CuMatrix{Float32, CUDA.Mem.DeviceBuffer}, DenseCuMatrix{Float32, CUDA.Mem.DeviceBuffer}}}, dy::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, inds::Int64)
    @ ChainRules ~/.julia/packages/ChainRules/Tvwnx/src/rulesets/Base/indexing.jl:147
  [7] ∇getindex(x::Vector{SubArray{Float32, 2, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Tuple{UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}}, dy::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, inds::Int64)
    @ ChainRules ~/.julia/packages/ChainRules/Tvwnx/src/rulesets/Base/indexing.jl:89

Here Vector{SubArray{..., CuArray suggests that something like x = collect(eachcol(cu(rand(2,3)))) is being indexed.

@bicycle1885
Member Author

Thanks. I've tested two other cases that index more directly, without calling chunk. One is a, b = y[1:4,:], y[5:8,:] and the other is a, b = view(y, 1:4, :), view(y, 5:8, :). I confirmed that both work without any error on GPU.

function (model::Model)(x)
    y = model.layers.dense(x)

    # ERROR: LoadError: MethodError: no method matching parent(...
    #a, b = Flux.chunk(y, size = [4, 4], dims = 1)

    # this works
    #a, b = y[1:4,:], y[5:8,:]

    # this works
    a, b = view(y, 1:4, :), view(y, 5:8, :)

    sum(a + b)
end

@bicycle1885
Member Author

I discovered that the following pattern doesn't work. My guess is that lowering the destructuring assignment implicitly inserts some getindex calls, and those break differentiation.

    a, b = [y[1:4,:], y[5:8,:]]

@bicycle1885
Member Author

So, I reduced the code to the following. As in the case above, the error disappears if I use Zygote.jl 0.6.66 instead of 0.6.67.

using CUDA, Zygote

function f(x)
    a, b = [x[1:4], x[5:8]]
    sum(a + b)
end

x = cu(randn(8))
@show f(x)
@show Zygote.gradient(f, x)

@mcabbott
Contributor

mcabbott commented Nov 9, 2023

Thanks, that's helpful!

@ToucheSir
Contributor

Thanks for the MWE, will follow up on the Zygote issue.

@CarloLucibello
Member

The example in the OP works fine with the latest versions of the packages.
