-
-
Notifications
You must be signed in to change notification settings - Fork 610
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Gradients for prod(x; dims) and cumprod(x), which take keywords & allow for zero entries #334
Conversation
Here's a sketch of a much faster
What's the policy on long messy functions in array.jl? And how do I get the result of the forward pass in the new |
Looks useful, would you be able to update the PR for latest master? Then maybe I can take a look at performance. The only real worry with long messy functions is whether they are differentiable and can work on the GPU, both of which basically imply using high-level operations rather than using indexing and loops. But in cases where there's a good speedup it's acceptable to make special cases; we'd just have to add a GPU-compatible kernel as well. One way to get the forward result is with julia> y, back = Tracker.forward(x -> x.*3, rand(5));
julia> y
Tracked 5-element Array{Float64,1}:
0.2520329546947273
2.1928414791486857
2.768906524399821
2.805451903950836
2.064442554488768
julia> back(rand(5))[1]
Tracked 5-element Array{Float64,1}:
1.2135785715420675
1.1689372317538684
0.062289695490364894
0.011664861853422304
2.2559453870972246 |
Great, will update this when I get a minute. And will leave array operation fall-backs. |
Taking a look at this, in order to use keyword
I'd like this to also select whether to call the function with loops, or generic-array versions. What should Edit: if I dispatch on
|
`prod(x::Array)` is much faster, and handles zeros correctly. `prod(x::Array; dims)` now uses a keyword, and just uses `mapslices(∇prod, x; dims)`. `prod(x::AbstractArray)` falls back to earlier `circshift...` methods. It would be easy to make `prod(f::Function, x::Array)` use fast `∇prod` but I didn't see a neat way to fall back, so left it alone. `cumprod(x)` now has a gradient; previously this gave Array{TrackedReal}. `cumprod(x; dims::Int)` works by something like `mapslices(∇cumprod, x, p, Δ; dims)` which I cooked up. (This was the case I originally wanted all this for!)
OK here's a tidier version. Not perfect but perhaps an improvement. Most immediately the following incorrect & missing gradients are fixed: using Flux
using Flux.Tracker: back!, gradcheck
rr = rand(10); rr[3]=0;
gradcheck(prod, rr) ## false on master
gradcheck(z -> sum(prod(z; dims=1)), rr) ## true on master
cumprod(param(rand(3))) isa TrackedArray ## false, Vector{TrackedReal} on master
gradcheck(z -> sum(cumprod(z)), rr) ## true on master For Fallback methods are called when And I haven't thought about 2nd derivatives at all, I left using BenchmarkTools
using Flux.Tracker: ∇prod, ∇prod_all, ∇prod_dim, ∇cumprod ## from this PR
∇prod_map(x, dims) = mapslices(∇prod, x; dims=dims) ## not type-stable :(
@btime ∇prod($rr) ## 977.333 ns
@btime ∇prod($(rand(10))) ## 92.597 ns -- much quicker if no zeros
@btime ∇prod_all($rr) ## 17.383 μs -- circshift fallback
@btime ∇prod_map($rr, 1) ## 7.848 μs -- mapslices is slow, but would be avoided
@btime ∇prod_dim($rr, $(prod(rr, dims=1)), 1) ## 21.568 μs -- circshift fallback
@btime ∇cumprod($rr) ## 207.054 ns BTW the gradient for cumprod is inspired by tensorflow/tensorflow#3862 (comment) although I actually think that formula is incorrect. See also Theano/Theano#5197 for others wondering about this. Edit: I improved my pseudo-mapslices, to speed up things like nn = rand(10,10); ## all nonzero
mm = rand(10,10); mm[2,3]=0;
@btime ∇prod($mm) ## 8.099 μs
@btime ∇prod($nn) ## 410.695 ns -- much faster without zeros
@btime ∇prod_all($mm) ## 44.752 ms -- circshift fallback very slow here
@btime ∇prod_map($mm, 1) ## 19.632 μs -- mapslices is slow
@btime ∇prod($mm, Val(1)) ## 3.357 μs -- now avoiding mapslices for dims::Int
@btime ∇prod_dim($mm, $(prod(mm, dims=1)), 1) ## 25.419 μs -- circshift fallback for dims::Int |
Closed in favour of #524. |
In the tagged version 0.5.3, there is a clever function for gradient of
prod(x)
allowing for it to have zero entries. This PR adds a similar function forprod(x,dim)
, fixing the following problem:On master branch, this neat function was (by mistake?) applied to
prod(z,dim)
instead, reversing false <--> true in the above.I don't fully understand what
(nobacksies(:sum,
is doing so I left it alone.I also left a naiive fallback function for
prod(x, dims)
with severaldims=(2,3)
etc. to pass tests. Perhaps I should add tests with zeros.However these functions are much slower than the naiive gradient --- 1000 times slower for me on Julia 0.6.3, although only 40 times slower on 0.7:
For
prod(x)
at least, it would be easy to add an if statement that only runs the slow versionprod(x)==0
. Forprod(x,dim)
this starts to sound messy.