-
Notifications
You must be signed in to change notification settings - Fork 22
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add dims for chunk and test_zygote #47
Conversation
You could hide it in a function, since it has only indices, or try to write the gradient for CR. But here, I think it's worth writing the gradient for this whole function. Much like the rule for |
The test facility comes from JuliaDiff/ChainRulesTestUtils.jl#211 |
Something like
doesn't work because actually the problem is not in differentiating |
The rrule needed should resemble this |
I implemented the simplest approach for the gradient of julia> using MLUtils, BenchmarkTools, Test, ChainRulesCore, ChainRulesTestUtils, Zygote
julia> x = rand(100, 200);
julia> @btime sum(chunk($x, 10)[1]);
1.374 μs (56 allocations: 2.59 KiB)
julia> @btime gradient(x -> sum(chunk(x, 10)[1]), $x);
207.899 μs (1598 allocations: 231.70 KiB) Maybe we can keep this for the time being and open a performance issue |
Implemented the rrule, now it is much better julia> x = rand(100, 200);
julia> @btime gradient(x -> sum(chunk(x, 10)[1]), $x);
19.878 μs (197 allocations: 166.16 KiB)
`` |
@mcabbott looks good? |
Codecov Report
@@ Coverage Diff @@
## main #47 +/- ##
==========================================
- Coverage 88.91% 88.80% -0.12%
==========================================
Files 13 13
Lines 379 411 +32
==========================================
+ Hits 337 365 +28
- Misses 42 46 +4
Continue to review full report at Codecov.
|
Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
For some reason, github doesn't let me ask for your review @mcabbott ... is this good now? |
Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Trying to fix FluxML/Flux.jl#1841 using @mcabbott's suggestion, but Zygote has trouble differentiating through
Iterators.partition
, I getAs a workaround I can define a custom partition function I guess.
cc @theabhirath @darsnack