Skip to content

Commit

Permalink
Merge 8fbe739 into 4136420
Browse files Browse the repository at this point in the history
  • Loading branch information
oxinabox committed Oct 28, 2019
2 parents 4136420 + 8fbe739 commit 136d3c8
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRules"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "0.2.2-DEV"
version = "0.2.3"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
6 changes: 3 additions & 3 deletions src/rulesets/Base/mapreduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ function frule(::typeof(sum), x)
return sum(x), sum_pushforward
end

function rrule(::typeof(sum), x)
function rrule(::typeof(sum), x::AbstractArray{<:Real})
function sum_pullback(ȳ)
return (NO_FIELDS, @thunk(fill(ȳ, size(x))))
end
Expand All @@ -67,15 +67,15 @@ end
function rrule(::typeof(sum), f, x::AbstractArray{<:Real}; dims=:)
y, mr_pullback = rrule(mapreduce, f, Base.add_sum, x; dims=dims)
function sum_pullback(ȳ)
NO_FIELDS, DNE(), last(mr_pullback(ȳ))
return NO_FIELDS, DNE(), last(mr_pullback(ȳ))
end
return y, sum_pullback
end

function rrule(::typeof(sum), x::AbstractArray{<:Real}; dims=:)
y, inner_pullback = rrule(sum, identity, x; dims=dims)
function sum_pullback(ȳ)
NO_FIELDS, last(inner_pullback(ȳ))
return NO_FIELDS, last(inner_pullback(ȳ))
end
return y, sum_pullback
end
Expand Down

0 comments on commit 136d3c8

Please sign in to comment.