Skip to content

Commit

Permalink
Merge 0f0d928 into c8679c6
Browse files Browse the repository at this point in the history
  • Loading branch information
oxinabox committed Sep 15, 2020
2 parents c8679c6 + 0f0d928 commit 99cc56e
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 20 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRules"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "0.7.18"
version = "0.7.19"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
6 changes: 4 additions & 2 deletions src/rulesets/Base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,17 @@
#####

function rrule(::typeof(reshape), A::AbstractArray, dims::Tuple{Vararg{Int}})
A_dims = size(A)
function reshape_pullback(Ȳ)
return (NO_FIELDS, reshape(Ȳ, dims), DoesNotExist())
return (NO_FIELDS, reshape(Ȳ, A_dims), DoesNotExist())
end
return reshape(A, dims), reshape_pullback
end

function rrule(::typeof(reshape), A::AbstractArray, dims::Int...)
A_dims = size(A)
function reshape_pullback(Ȳ)
∂A = reshape(Ȳ, dims)
∂A = reshape(Ȳ, A_dims)
return (NO_FIELDS, ∂A, fill(DoesNotExist(), length(dims))...)
end
return reshape(A, dims...), reshape_pullback
Expand Down
22 changes: 5 additions & 17 deletions test/rulesets/Base/array.jl
Original file line number Diff line number Diff line change
@@ -1,23 +1,11 @@
@testset "reshape" begin
A = rand(4, 5)
B, pullback = rrule(reshape, A, (5, 4))
@test B == reshape(A, (5, 4))
= randn(4, 5)
x = rand(4, 5)
= rand(4, 5)

(s̄, Ā, d̄) = pullback(Ȳ)
@test== NO_FIELDS
@testisa DoesNotExist
@test extern(Ā) == reshape(Ȳ, (5, 4))
= rand(2, 10)

B, pullback = rrule(reshape, A, 5, 4)
@test B == reshape(A, 5, 4)

= randn(4, 5)
(s̄, Ā, d̄1, d̄2) = pullback(Ȳ)
@test== NO_FIELDS
@test d̄1 isa DoesNotExist
@test d̄2 isa DoesNotExist
@test extern(Ā) == reshape(Ȳ, 5, 4)
rrule_test(reshape, ȳ, (x, x̄), ((2, 10), nothing))
rrule_test(reshape, ȳ, (x, x̄), (2, nothing), (10, nothing))
end

@testset "hcat" begin
Expand Down

0 comments on commit 99cc56e

Please sign in to comment.