Skip to content

Commit

Permalink
Merge #112
Browse files Browse the repository at this point in the history
112: Simplest prod(x; dims) gradient r=dhairyagandhi96 a=mcabbott

The current gradient for `prod(x; dims)` gives incorrect results, this PR fixes it (parallel to  FluxML/Tracker.jl#1 ):
```
julia> using Zygote, ForwardDiff

julia> r = rand(2,3,2);

julia> ForwardDiff.gradient(w->sum(prod(w, dims=(2,3))), r)
2×3×2 Array{Float64,3}:
[:, :, 1] =
 0.00131643  0.000954347  0.0051387 
 0.0177437   0.0354628    0.00934587

[:, :, 2] =
 0.00434307  0.0140455   0.00152818
 0.0151417   0.00464615  0.00451601

julia> Zygote.gradient(w->sum(prod(w, dims=(2,3))), r)[1] # wrong answer!
2×3×2 Array{Float64,3}:
[:, :, 1] =
 5.93867e-6  4.30525e-6  2.31817e-5
 1.60301e-5  3.2038e-5   8.44331e-6

[:, :, 2] =
 1.95925e-5  6.33622e-5  6.89391e-6
 1.36795e-5  4.19746e-6  4.07989e-6

julia> Zygote.@adjoint function prod(xs; dims = :) # as in this PR
         p = prod(xs; dims = dims)
         p, Δ -> (p ./ xs .* Δ,)
       end

julia> Zygote.refresh()

julia> Zygote.gradient(w->sum(prod(w, dims=(2,3))), r)[1] # now matches ForwardDiff
2×3×2 Array{Float64,3}:
[:, :, 1] =
 0.00131643  0.000954347  0.0051387 
 0.0177437   0.0354628    0.00934587

[:, :, 2] =
 0.00434307  0.0140455   0.00152818
 0.0151417   0.00464615  0.00451601
```
This does not handle zeros in the array correctly -- see FluxML/Flux.jl#524 for attempts to do that. The `circshift(...` operation deleted here was a correct (but slow) gradient for `prod(x)`, but is clearly independent of `dims`. 

The example above is almost the same as the one in the tests, which strangely passes, without this PR. Perhaps something is wrong with `gradtest`?
```
julia> @test gradtest(x -> prod(x, dims = (2, 3)), (3,4,5))
Test Passed

julia> @test gradtest(x -> prod(x), (3,4,5))
Test Passed
```

Co-authored-by: Michael Abbott <me@pseudomac>
  • Loading branch information
bors[bot] and Michael Abbott committed Feb 26, 2020
2 parents d7e8afc + 99244ed commit af498fa
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 10 deletions.
10 changes: 3 additions & 7 deletions src/lib/array.jl
Expand Up @@ -189,13 +189,9 @@ end
return sum(abs2, X; dims=dims), Δ::Union{Number, AbstractArray}->(nothing, ((2Δ) .* X))
end

@adjoint function prod(xs::AbstractArray{<:Number}; dims = :)
if dims === (:)
prod(xs), Δ -> (prod(xs) ./ xs .* Δ,)
else
prod(xs, dims = dims),
Δ -> (reshape(.*(circshift.([reshape(xs, length(xs))], 1:length(xs)-1)...), size(xs)) .* Δ,)
end
@adjoint function prod(xs; dims = :)
p = prod(xs; dims = dims)
p, Δ -> (p ./ xs .* Δ,)
end

function _pullback(cx::AContext, ::typeof(prod), f, xs::AbstractArray)
Expand Down
6 changes: 3 additions & 3 deletions test/gradcheck.jl
Expand Up @@ -64,8 +64,8 @@ Random.seed!(0)
@test gradtest(x -> sum(x[i] for i in 1:length(x)), randn(10))
@test_broken gradtest(x -> sum(i->x[i], 1:length(x)), randn(10)) # https://github.com/FluxML/Zygote.jl/issues/231

@test_broken gradtest(x -> prod(x, dims = (2, 3)), (3,4,5))
@test gradtest(x -> prod(x), (3,4,5))
@test gradtest(x -> prod(x, dims = (2, 3)), (3,4,5))
@test gradtest(x -> prod(x), (3,4))

@test gradtest(x -> softmax(x).*(1:3), 3)
@test gradtest(x -> softmax(x).*(1:3), (3,5))
Expand Down Expand Up @@ -112,7 +112,7 @@ end
end

@testset "conv: spatial_rank=$spatial_rank" for spatial_rank in (1, 2, 3)
x = rand(repeat([10], spatial_rank)..., 3, 2)
x = rand(repeat([5], spatial_rank)..., 3, 2)
w = rand(repeat([3], spatial_rank)..., 3, 3)
cdims = DenseConvDims(x, w)
@test gradtest((x, w) -> conv(x, w, cdims), x, w)
Expand Down

0 comments on commit af498fa

Please sign in to comment.