Skip to content

Commit

Permalink
add adjoint for reduce hvcat (#236)
Browse files Browse the repository at this point in the history
* add adjoint for reduce hvcat

* Apply suggestions from code review

Co-authored-by: Seth Axen <seth.axen@gmail.com>

* Change return type to DoesNotExist

Co-authored-by: Seth Axen <seth.axen@gmail.com>

* test with rrule_test

* use randn everywhere

* bump version

Co-authored-by: Seth Axen <seth.axen@gmail.com>
  • Loading branch information
piever and sethaxen committed Jul 16, 2020
1 parent e9aa66e commit 026a118
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 1 deletion.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRules"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "0.7.10"
version = "0.7.11"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
26 changes: 26 additions & 0 deletions src/rulesets/Base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,19 @@ function rrule(::typeof(hcat), A::AbstractArray, Bs::AbstractArray...)
return hcat(A, Bs...), hcat_pullback
end

function rrule(::typeof(reduce), ::typeof(hcat), As::AbstractVector{<:AbstractVecOrMat})
function reduce_hcat_pullback(ΔY)
sizes = size.(As, 2)
cumsizes = cumsum(sizes)
∂As = map(cumsizes, sizes) do post, diff
pre = post - diff + 1
return ΔY[:, pre:post]
end
return (NO_FIELDS, DoesNotExist(), ∂As)
end
return reduce(hcat, As), reduce_hcat_pullback
end

#####
##### `vcat`
#####
Expand All @@ -57,6 +70,19 @@ function rrule(::typeof(vcat), A::AbstractArray, Bs::AbstractArray...)
return vcat(A, Bs...), vcat_pullback
end

function rrule(::typeof(reduce), ::typeof(vcat), As::AbstractVector{<:AbstractVecOrMat})
function reduce_vcat_pullback(ΔY)
sizes = size.(As, 1)
cumsizes = cumsum(sizes)
∂As = map(cumsizes, sizes) do post, diff
pre = post - diff + 1
return ΔY[pre:post, :]
end
return (NO_FIELDS, DoesNotExist(), ∂As)
end
return reduce(vcat, As), reduce_vcat_pullback
end

#####
##### `fill`
#####
Expand Down
24 changes: 24 additions & 0 deletions test/rulesets/Base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,18 @@ end
@test dC view(H̄, :, 4:6)
end

@testset "reduce hcat" begin
A = randn(3, 2)
B = randn(3, 1)
C = randn(3, 3)
x = [A, B, C]
H, pullback = rrule(reduce, hcat, x)
@test H == reduce(hcat, x)
= randn(3, 6)
= randn.(size.(x))
rrule_test(reduce, H̄, (hcat, nothing), (x, x̄))
end

@testset "vcat" begin
A = randn(2, 4)
B = randn(1, 4)
Expand All @@ -48,6 +60,18 @@ end
@test dC view(V̄, 4:6, :)
end

@testset "reduce vcat" begin
A = randn(2, 4)
B = randn(1, 4)
C = randn(3, 4)
x = [A, B, C]
V, pullback = rrule(reduce, vcat, x)
@test V == reduce(vcat, x)
= randn(6, 4)
= randn.(size.(x))
rrule_test(reduce, V̄, (vcat, nothing), (x, x̄))
end

@testset "fill" begin
y, pullback = rrule(fill, 44, 4)
@test y == [44, 44, 44, 44]
Expand Down

2 comments on commit 026a118

@sethaxen
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/18025

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.7.11 -m "<description of version>" 026a118a78e7961a31037c689fd3c6b347fa3efa
git push origin v0.7.11

Please sign in to comment.