DataLoader incompatible with Flux/Zygote #127

Closed
simonmandlik opened this issue Oct 26, 2022 · 2 comments
Comments

@simonmandlik

Due to the try/catch in the implementation of DataLoader, Zygote.jl cannot differentiate through the iteration:

using Flux, MLUtils

x = randn(10, 10)
m = Dense(10, 10)
ps = Flux.params(m)
mbs = MLUtils.DataLoader(x, batchsize=4, shuffle=true)

julia> mb_grad = gradient(() -> sum(m(first(mbs))), ps)
ERROR: Compiling Tuple{MLUtils.var"##BatchView#28", Int64, Bool, Val{nothing}, Type{BatchView}, ObsView{Matrix{Float64}, Vector{Int64}}}: try/catch is not supported.
Refer to the Zygote documentation for fixes.
https://fluxml.ai/Zygote.jl/latest/limitations

Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:35
  [2] instrument(ir::IRTools.Inner.IR)
    @ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/reverse.jl:121
  [3] #Primal#23
    @ ~/.julia/packages/Zygote/dABKa/src/compiler/reverse.jl:205 [inlined]
  [4] Zygote.Adjoint(ir::IRTools.Inner.IR; varargs::Nothing, normalise::Bool)
    @ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/reverse.jl:330
  [5] _generate_pullback_via_decomposition(T::Type)
    @ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/emit.jl:101
  [6] #s2924#1068
    @ ~/.julia/packages/Zygote/dABKa/src/compiler/interface2.jl:28 [inlined]
  [7] var"#s2924#1068"(::Any, ctx::Any, f::Any, args::Any)
    @ Zygote ./none:0
  [8] (::Core.GeneratedFunctionStub)(::Any, ::Vararg{Any})
    @ Core ./boot.jl:582
  [9] _pullback
    @ ~/.julia/packages/MLUtils/Th9Y3/src/batchview.jl:92 [inlined]
 [10] _pullback(::Zygote.Context{true}, ::Core.var"#Type##kw", ::NamedTuple{(:batchsize, :partial, :collate), Tuple{Int64, Bool, Val{nothing}}}, ::Type{BatchView}, ::ObsView{Matrix{Float64}, Vector{Int64}})
    @ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/interface2.jl:0
 [11] _pullback
    @ ~/.julia/packages/MLUtils/Th9Y3/src/eachobs.jl:161 [inlined]
 [12] _pullback(ctx::Zygote.Context{true}, f::typeof(iterate), args::DataLoader{Matrix{Float64}, Random._GLOBAL_RNG, Val{nothing}})
    @ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/interface2.jl:0
 [13] _pullback
    @ ./abstractarray.jl:424 [inlined]
 [14] _pullback(ctx::Zygote.Context{true}, f::typeof(first), args::DataLoader{Matrix{Float64}, Random._GLOBAL_RNG, Val{nothing}})
    @ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/interface2.jl:0
 [15] _pullback
    @ ./REPL[10]:1 [inlined]
 [16] _pullback(::Zygote.Context{true}, ::var"#5#6")
    @ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/interface2.jl:0
 [17] pullback(f::Function, ps::Zygote.Params{Zygote.Buffer{Any, Vector{Any}}})
    @ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/interface.jl:373
 [18] gradient(f::Function, args::Zygote.Params{Zygote.Buffer{Any, Vector{Any}}})
    @ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/interface.jl:96
 [19] top-level scope
    @ REPL[10]:1
 [20] top-level scope
    @ ~/.julia/packages/CUDA/DfvRa/src/initialization.jl:52

(jl_PGRlDC) pkg> st
Status `/private/var/folders/pr/bshcrht5423cbhtz_ls_h_7r0000gp/T/jl_PGRlDC/Project.toml`
  [587475ba] Flux v0.13.6
  [f1d291b0] MLUtils v0.2.11

But similar code worked with the old MLDataPattern package:

using Flux, MLDataPattern

x = randn(10, 10)
m = Dense(10, 10)
ps = Flux.params(m)
mbs = RandomBatches(x, size=4)

julia> mb_grad = gradient(() -> sum(m(first(mbs))), ps)
Grads(...)

(jl_NW4gr7) pkg> st
Status `/private/var/folders/pr/bshcrht5423cbhtz_ls_h_7r0000gp/T/jl_NW4gr7/Project.toml`
  [587475ba] Flux v0.13.6
  [9920b226] MLDataPattern v0.5.5
@ToucheSir
Contributor

That the old dataloader worked is probably a happy accident. Is there any reason you can't pull a batch from the dataloader (e.g. extract first(mbs) outside of the lambda) before calling gradient? Even if it did work before, it was likely causing unnecessary performance issues, because Zygote had to differentiate through the DataLoader code.
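For reference, a minimal sketch of that workaround, assuming the same toy setup as the reproduction above: materialize the batch before calling gradient, so the closure only contains the model call.

using Flux, MLUtils

x = randn(10, 10)
m = Dense(10, 10)
ps = Flux.params(m)
mbs = MLUtils.DataLoader(x, batchsize=4, shuffle=true)

# Iterate the DataLoader outside the differentiated closure;
# Zygote then never sees the try/catch in the iteration code.
batch = first(mbs)
mb_grad = gradient(() -> sum(m(batch)), ps)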

@simonmandlik simonmandlik changed the title from "DataLoader uncompatible with Flux/Zygote" to "DataLoader incompatible with Flux/Zygote" on Oct 27, 2022
@simonmandlik
Author

No, not really. The original idea was to make code like this work out of the box with the highest-level API (like Flux.train!, but looking at its code, it also loops over minibatches and computes the gradient for each one separately). Something like the sketch below is what that pattern looks like.
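A sketch of that loop under the same toy setup (the optimiser choice here is arbitrary, not taken from Flux.train! itself): the minibatch iteration runs outside gradient, and only the loss closure is differentiated.

using Flux, MLUtils

x = randn(10, 10)
m = Dense(10, 10)
ps = Flux.params(m)
opt = Descent(0.01)
mbs = MLUtils.DataLoader(x, batchsize=4, shuffle=true)

for batch in mbs
    # Only the loss computation is differentiated; iterating the
    # DataLoader happens in ordinary, non-traced Julia code.
    gs = gradient(() -> sum(m(batch)), ps)
    Flux.Optimise.update!(opt, ps, gs)
end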

Thanks!
