
Gradients for prod(x; dims) and cumprod(x), which take keywords & allow for zero entries #334

Closed

Conversation

mcabbott
Member

In the tagged version 0.5.3, there is a clever function for the gradient of prod(x), allowing x to have zero entries. This PR adds a similar function for prod(x, dim), fixing the following problem:

Flux.Tracker.gradcheck(prod, [1.2, 2.3, 0.0, 4.5]) ## true
Flux.Tracker.gradcheck(z -> sum(prod(z,1)), [1.2, 2.3, 0.0, 4.5]) ## false

On the master branch, this neat function was (by mistake?) applied to prod(z, dim) instead, reversing false <--> true in the above.
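
For concreteness (my illustration, not code from this PR): the i-th component of the gradient of prod(x) is the product of all the other entries, which equals prod(x)/x[i] only when x[i] != 0:

x = [1.2, 2.3, 0.0, 4.5]

prod(x) ./ x  ## naive rule gives [0.0, 0.0, NaN, 0.0] -- wrong at the zero entry

[prod(x[j] for j in eachindex(x) if j != i) for i in eachindex(x)]
## correct gradient, ≈ [0.0, 0.0, 12.42, 0.0]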

I don't fully understand what `nobacksies(:sum, ...)` is doing, so I left it alone.

I also left a naive fallback function for prod(x, dims) with several dims = (2,3) etc. to pass tests. Perhaps I should add tests with zeros.

However, these functions are much slower than the naive gradient: 1000 times slower for me on Julia 0.6.3, although only 40 times slower on 0.7:

∇prod(x) = prod(x) ./ x  ## naive gradient; wrong when x has zero entries

## zero-safe: elementwise product of the circularly shifted copies,
## so entry i collects every element except x[i]
∇prod0(x::Vector) = .*(circshift.([x], 1:length(x)-1)...)

∇prod0(x::Array, dim::Int) =
    .*(circshift.([x], tuple.(Iterators.repeated(0, dim-1)..., 1:size(x, dim)-1))...)

using BenchmarkTools
rr = rand(5)

@btime ∇prod($rr)
@btime ∇prod0($rr)
@btime ∇prod0($rr, 1)

For prod(x) at least, it would be easy to add an if statement that only runs the slow version when prod(x) == 0. For prod(x, dim) this starts to sound messy.

@mcabbott
Member Author

mcabbott commented Jul 23, 2018

Here's a sketch of a much faster prod(x) gradient:

function ∇prod1(x::Vector)
  f = prod(x)  ## ideally re-used from back_
  f != 0 && return f ./ x
  z = find(iszero, x)  ## indices of the zeros; type unstable?
  length(z) > 1 && return zeros(x)  ## two or more zeros: every partial derivative is zero
  ∇ = zeros(x)
  ∇[z[1]] = prod(x[i] for i in eachindex(x) if i != z[1])
  return ∇
end

∇prod1(x::Array, dim) = mapslices(∇prod1, x, dim)

@btime ∇prod1($rr)

rr[2] = 0;

@btime ∇prod0($rr)
@btime ∇prod1($rr)

What's the policy on long messy functions in array.jl?

And how do I get the result of the forward pass in the new @grad way of writing?

@MikeInnes
Member

Looks useful; would you be able to update the PR to the latest master? Then maybe I can take a look at performance.

The only real worry with long messy functions is whether they are differentiable and can work on the GPU, both of which basically imply using high-level operations rather than indexing and loops. But in cases where there's a good speedup it's acceptable to make special cases; we'd just have to add a GPU-compatible kernel as well.
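
To make that distinction concrete, here is a rough sketch (illustrative only, not code from this PR) of a zero-safe ∇prod built purely from high-level operations -- count, prod, broadcast -- the style that has a chance of working on GPU arrays:

function ∇prod_generic(x::AbstractArray)
    nz = count(iszero, x)
    nz == 0 && return prod(x) ./ x  ## no zeros: the usual formula
    nz > 1 && return zero(x)  ## two or more zeros: the gradient vanishes everywhere
    p = prod(xi -> iszero(xi) ? one(xi) : xi, x)  ## product of the nonzero entries
    return iszero.(x) .* p  ## nonzero only at the position of the single zero
end

∇prod_generic([1.2, 2.3, 0.0, 4.5])  ## ≈ [0.0, 0.0, 12.42, 0.0]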

One way to get the forward result is with Tracker.forward:

julia> y, back = Tracker.forward(x -> x.*3, rand(5));

julia> y
Tracked 5-element Array{Float64,1}:
 0.2520329546947273
 2.1928414791486857
 2.768906524399821 
 2.805451903950836 
 2.064442554488768 

julia> back(rand(5))[1]
Tracked 5-element Array{Float64,1}:
 1.2135785715420675  
 1.1689372317538684  
 0.062289695490364894
 0.011664861853422304
 2.2559453870972246  
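
In the @grad style, the forward result can simply be bound before constructing the pullback closure, roughly like this (a sketch; ∇prod_zero_safe is a placeholder name):

@grad function prod(xs)
    p = prod(data(xs))  ## forward result, captured by the closure below
    p, Δ -> (nobacksies(:prod, data(Δ) .* ∇prod_zero_safe(data(xs), p)),)
end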

@mcabbott
Member Author

mcabbott commented Sep 7, 2018

Great, will update this when I get a minute. And will leave array-operation fallbacks.

@mcabbott
Member Author

mcabbott commented Sep 8, 2018

Taking a look at this: in order to use the keyword dims, I think I want a single @grad and to dispatch on several inner functions like _std, something like this:

@grad prod(xs; dims=:) = _prod(xs, dims)
_prod(xs::T1, ::Colon) = begin p = prod(xs.data); p, Δ -> (nobacksies(:prod, ∇prod_new(xs.data, p, Δ) ),) end
_prod(xs::T1, dims) = prod(xs.data; dims=dims), Δ -> (nobacksies(:prod, mapslices(∇prod_new, xs.data; dims=dims) .* Δ), )
_prod(xs::T2, ::Colon) = ## fall-back for other than dense CPU arrays
_prod(xs::T2, dims::Int) = ## ...

I'd like this to also select whether to call the loop-based function or the generic-array versions. What should T1 and T2 be for this? Naively T1 = TrackedArray{<:Array}, but that's not right.

Edit: if I dispatch on xs.data, then perhaps like this? But DenseArray is too narrow: it excludes transposed arrays...

@grad prod(xs; dims=:) = _prod(xs.data, prod(xs.data), dims)
_prod(xs::DenseArray, p, ::Colon) = p, Δ -> (nobacksies(:prod, ∇prod_new(xs, p, Δ) ),) 
_prod(xs, p, ::Colon) = ## fall-back for other than dense CPU arrays

`prod(x::Array)` is much faster, and handles zeros correctly. 
`prod(x::Array; dims)` now takes a keyword, and just uses `mapslices(∇prod, x; dims)`. 
`prod(x::AbstractArray)` falls back to the earlier `circshift...` methods. 
It would be easy to make `prod(f::Function, x::Array)` use the fast `∇prod`, but I didn't see a neat way to fall back, so I left it alone. 

`cumprod(x)` now has a gradient; previously this gave an Array{TrackedReal}. 
`cumprod(x; dims::Int)` works via something like `mapslices(∇cumprod, x, p, Δ; dims)`, which I cooked up. 
(This was the case I originally wanted all this for!)
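
For reference, the zero-free part of the cumprod rule can be written with a reversed cumsum, since ∂y[j]/∂x[i] = y[j]/x[i] for j >= i. A sketch assuming no zero entries (not the PR's actual ∇cumprod, which also handles zeros):

function ∇cumprod_nozeros(x::Vector, Δ::Vector)
    y = cumprod(x)
    reverse(cumsum(reverse(Δ .* y))) ./ x  ## entry i is (Σ_{j>=i} Δ[j]*y[j]) / x[i]
end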
@mcabbott mcabbott changed the title Gradient for prod(x, dim) which allows for zero Gradients for prod(x; dims) and cumprod(x), which take keywords & allow for zero entries Sep 8, 2018
@mcabbott
Member Author

mcabbott commented Sep 8, 2018

OK here's a tidier version. Not perfect but perhaps an improvement.

Most immediately, the following incorrect and missing gradients are fixed:

using Flux
using Flux.Tracker: back!, gradcheck
rr = rand(10); rr[3]=0;

gradcheck(prod, rr) ## false on master 
gradcheck(z -> sum(prod(z; dims=1)), rr) ## true on master

cumprod(param(rand(3))) isa TrackedArray ## false, Vector{TrackedReal} on master
gradcheck(z -> sum(cumprod(z)), rr) ## true on master

For prod(::Vector) at least, this is much faster, 20-200 times. For more complicated cases I call mapslices(), which is quite slow; if you instead take rr = rand(5) below, then ∇prod_map is slower than the fallback ∇prod_dim. But perhaps mapslices will be improved.

Fallback methods are called when x is not an Array. Probably this is tighter than necessary, since the fast versions would work fine for transposed arrays etc.

And I haven't thought about second derivatives at all: I left nobacksies(:prod, ...) there, and inserted data(Δ) where needed.

using BenchmarkTools
using Flux.Tracker: ∇prod, ∇prod_all, ∇prod_dim, ∇cumprod ## from this PR
∇prod_map(x, dims) = mapslices(∇prod, x; dims=dims) ## not type-stable :( 

@btime ∇prod($rr) ## 977.333 ns 
@btime ∇prod($(rand(10))) ## 92.597 ns -- much quicker if no zeros
@btime ∇prod_all($rr) ## 17.383 μs -- circshift fallback

@btime ∇prod_map($rr, 1) ## 7.848 μs -- mapslices is slow, but would be avoided
@btime ∇prod_dim($rr, $(prod(rr, dims=1)), 1) ## 21.568 μs -- circshift fallback

@btime ∇cumprod($rr) ## 207.054 ns

BTW the gradient for cumprod is inspired by tensorflow/tensorflow#3862 (comment), although I actually think that formula is incorrect. See also Theano/Theano#5197 for others wondering about this.

Edit: I improved my pseudo-mapslices, to speed up things like prod(matrix, dims=1). Some tests on matrices:

nn = rand(10,10); ## all nonzero
mm = rand(10,10); mm[2,3]=0;

@btime ∇prod($mm) ## 8.099 μs
@btime ∇prod($nn) ## 410.695 ns -- much faster without zeros
@btime ∇prod_all($mm) ## 44.752 ms -- circshift fallback very slow here

@btime ∇prod_map($mm, 1) ## 19.632 μs  -- mapslices is slow
@btime ∇prod($mm, Val(1)) ## 3.357 μs -- now avoiding mapslices for dims::Int
@btime ∇prod_dim($mm, $(prod(mm, dims=1)), 1) ## 25.419 μs --  circshift fallback for dims::Int

@mcabbott
Member Author

Closed in favour of #524.

@mcabbott mcabbott closed this Dec 20, 2018