add dims for chunk and test_zygote #47

Merged
merged 13 commits into main from cl/chunk
Feb 10, 2022

Member

@CarloLucibello CarloLucibello commented Feb 9, 2022

I'm trying to fix FluxML/Flux.jl#1841 using @mcabbott's suggestion, but Zygote has trouble differentiating through Iterators.partition. I get:

julia> test_zygote(chunk, rand(10), 3)
test_rrule: chunk on Vector{Float64},Int64: Error During Test at /home/carlo/.julia/packages/ChainRulesTestUtils/XI7i2/src/testers.jl:195
  Got exception outside of a @test
  MethodError: no method matching size(::Base.Iterators.PartitionIterator{Base.OneTo{Int64}})
  Closest candidates are:
    size(::Union{LinearAlgebra.QR, LinearAlgebra.QRCompactWY, LinearAlgebra.QRPivoted}) at ~/julia/julia-1.7.1/share/julia/stdlib/v1.7/LinearAlgebra/src/qr.jl:567
    size(::Union{LinearAlgebra.QR, LinearAlgebra.QRCompactWY, LinearAlgebra.QRPivoted}, ::Integer) at ~/julia/julia-1.7.1/share/julia/stdlib/v1.7/LinearAlgebra/src/qr.jl:566
    size(::Union{LinearAlgebra.Cholesky, LinearAlgebra.CholeskyPivoted}) at ~/julia/julia-1.7.1/share/julia/stdlib/v1.7/LinearAlgebra/src/cholesky.jl:494
    ...
  Stacktrace:
    [1] axes
      @ ./abstractarray.jl:95 [inlined]
    [2] _tryaxes(x::Base.Iterators.PartitionIterator{Base.OneTo{Int64}})
      @ Zygote ~/.julia/packages/Zygote/FPUm3/src/lib/array.jl:184
    [3] map
      @ ./tuple.jl:221 [inlined]
    [4] ∇map(cx::Zygote.Context, f::MLUtils.var"#70#71"{Int64, Vector{Float64}}, args::Base.Iterators.PartitionIterator{Base.OneTo{Int64}})
      @ Zygote ~/.julia/packages/Zygote/FPUm3/src/lib/array.jl:199
    [5] _pullback(cx::Zygote.Context, #unused#::typeof(collect), g::Base.Generator{Base.Iterators.PartitionIterator{Base.OneTo{Int64}}, MLUtils.var"#70#71"{Int64, Vector{Float64}}})
      @ Zygote ~/.julia/packages/Zygote/FPUm3/src/lib/array.jl:244
    [6] _pullback
      @ ~/.julia/dev/MLUtils/src/utils.jl:165 [inlined]
    [7] _pullback(::Zygote.Context, ::MLUtils.var"##chunk#69", ::Int64, ::typeof(chunk), ::Vector{Float64}, ::Int64)
      @ Zygote ~/.julia/packages/Zygote/FPUm3/src/compiler/interface2.jl:0
    [8] _pullback
      @ ~/.julia/dev/MLUtils/src/utils.jl:164 [inlined]
    [9] _pullback(::Zygote.Context, ::typeof(chunk), ::Vector{Float64}, ::Int64)
      @ Zygote ~/.julia/packages/Zygote/FPUm3/src/compiler/interface2.jl:0

As a workaround, I guess I can define a custom partition function.

cc @theabhirath @darsnack

@mcabbott
Contributor

mcabbott commented Feb 9, 2022

You could hide it in a function, since it only involves indices, or try to write the gradient as a ChainRules (CR) rule.

But here, I think it's worth writing the gradient for this whole function. Much like the rule for eachcol(x), it can allocate just one similar(x), not one per slice.
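A minimal sketch of the "one buffer for all slices" idea, written in plain Julia so it runs without ChainRulesCore. The names `split_along` and `split_along_pullback` and the `len` argument are hypothetical and are not the MLUtils API; a real rule would wrap the same logic in an `rrule` for `chunk`.

```julia
# Hypothetical forward pass: split `x` along `dims` into views of length `len`
# (the last piece may be shorter). Also returns the index ranges used.
function split_along(x::AbstractArray, len::Int; dims::Int=ndims(x))
    n = size(x, dims)
    ranges = [i:min(i + len - 1, n) for i in 1:len:n]
    return [selectdim(x, dims, r) for r in ranges], ranges
end

# Hypothetical pullback: write every slice cotangent into ONE buffer shaped
# like `x`, instead of allocating a full-size array per slice.
function split_along_pullback(x::AbstractArray, ranges, dys; dims::Int=ndims(x))
    dx = fill!(similar(x), zero(eltype(x)))
    for (r, dy) in zip(ranges, dys)
        dy === nothing && continue      # this slice was not used downstream
        v = selectdim(dx, dims, r)      # view into the shared buffer
        v .+= dy
    end
    return dx
end

x = rand(3, 4)
ys, ranges = split_along(x, 3; dims=2)
# Cotangent of `sum(ys[1])`: ones for the first slice, nothing for the rest.
dx = split_along_pullback(x, ranges, [ones(size(ys[1])), nothing]; dims=2)
```

Only the region covered by a slice with a cotangent is touched, so the cost of the pullback is one allocation of the size of `x` plus one pass over the slices.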

@CarloLucibello
Member Author

The test facility comes from JuliaDiff/ChainRulesTestUtils.jl#211

@CarloLucibello
Member Author

Something like

using ChainRulesCore: @non_differentiable  # provides the macro used below

_partition(x...) = Iterators.partition(x...)
@non_differentiable _partition(x...)

doesn't work, because the problem is actually not in differentiating partition but in ∇map over a PartitionIterator.
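To illustrate the difference between the lazy iterator and a materialized one, here is a small sketch using only Base. The helper `eager_partition` is hypothetical (it is not part of MLUtils), and whether Zygote then differentiates the surrounding code is an assumption that is not verified here.

```julia
# Hypothetical helper: collect the index ranges into an ordinary Vector,
# which supports `size` and `axes`, instead of handing the lazy
# PartitionIterator on to `map`.
eager_partition(n::Integer, len::Integer) = collect(Iterators.partition(1:n, len))

lazy = Iterators.partition(1:10, 3)
ranges = eager_partition(10, 3)

# Check which of the two objects has a `size` method available.
println(applicable(size, lazy))
println(applicable(size, ranges))
println(ranges)
```

The stack trace above fails inside `axes`, which falls back to `size`; a collected `Vector` of ranges avoids that particular call on the iterator.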

@CarloLucibello
Member Author

I implemented the simplest approach for the gradient of chunk, but the performance is not good:

julia> using MLUtils, BenchmarkTools, Test, ChainRulesCore, ChainRulesTestUtils, Zygote

julia> x = rand(100, 200);

julia> @btime sum(chunk($x, 10)[1]);
  1.374 μs (56 allocations: 2.59 KiB)

julia> @btime gradient(x -> sum(chunk(x, 10)[1]), $x);
  207.899 μs (1598 allocations: 231.70 KiB)

Maybe we can keep this for the time being and open a performance issue.

@CarloLucibello
Member Author

I implemented the rrule; now it is much better:

julia> x = rand(100, 200);

julia> @btime gradient(x -> sum(chunk(x, 10)[1]), $x);
  19.878 μs (197 allocations: 166.16 KiB)

@CarloLucibello
Member Author

@mcabbott looks good?

@codecov-commenter

codecov-commenter commented Feb 9, 2022

Codecov Report

Merging #47 (1c6a8c5) into main (ea45061) will decrease coverage by 0.11%.
The diff coverage is 85.29%.


@@            Coverage Diff             @@
##             main      #47      +/-   ##
==========================================
- Coverage   88.91%   88.80%   -0.12%     
==========================================
  Files          13       13              
  Lines         379      411      +32     
==========================================
+ Hits          337      365      +28     
- Misses         42       46       +4     
| Impacted Files | Coverage Δ |
|---|---|
| src/utils.jl | 93.75% <85.29%> (-3.13%) ⬇️ |


CarloLucibello and others added 2 commits February 9, 2022 19:11
Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
@CarloLucibello
Member Author

For some reason, GitHub doesn't let me ask for your review, @mcabbott... Is this good now?

Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
@CarloLucibello CarloLucibello merged commit 4a60808 into main Feb 10, 2022
@CarloLucibello CarloLucibello mentioned this pull request Feb 15, 2022
3 tasks
@CarloLucibello CarloLucibello deleted the cl/chunk branch June 28, 2022 03:39
Successfully merging this pull request may close these issues.

Flux.chunk for multi-dimensional arrays
3 participants