Skip to content

Commit

Permalink
Merge 74d5ca1 into 5aaae8a
Browse files Browse the repository at this point in the history
  • Loading branch information
oxinabox committed Sep 17, 2020
2 parents 5aaae8a + 74d5ca1 commit ad2ce5d
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 8 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRules"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "0.7.19"
version = "0.7.20"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
15 changes: 15 additions & 0 deletions src/rulesets/LinearAlgebra/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,21 @@ function rrule(::typeof(*), A::AbstractMatrix{<:Real}, B::AbstractMatrix{<:Real}
return A * B, times_pullback
end

function rrule(::typeof(*), A::Real, B::AbstractArray{<:Real})
function times_pullback(Ȳ)
return (NO_FIELDS, @thunk(dot(Ȳ, B)), @thunk(A * Ȳ))
end
return A * B, times_pullback
end

function rrule(::typeof(*), B::AbstractArray{<:Real}, A::Real)
function times_pullback(Ȳ)
return (NO_FIELDS, @thunk(A * Ȳ), @thunk(dot(Ȳ, B)))
end
return A * B, times_pullback
end


#####
##### `pinv`
#####
Expand Down
22 changes: 15 additions & 7 deletions test/rulesets/LinearAlgebra/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -115,13 +115,21 @@
rrule_test(tr, randn(), (randn(N, N), randn(N, N)))
end
@testset "*" begin
dims = [3,4,5]
for n in dims, m in dims, p in dims
n > 3 && n == m == p && continue # don't need to test square case multiple times
A = randn(m, n)
B = randn(n, p)
= randn(m, p)
rrule_test(*, Ȳ, (A, randn(m, n)), (B, randn(n, p)))
@testset "Matrix-Matrix" begin
dims = [3,4,5]
for n in dims, m in dims, p in dims
n > 3 && n == m == p && continue # don't need to test square case multiple times
A = randn(m, n)
B = randn(n, p)
= randn(m, p)
rrule_test(*, Ȳ, (A, randn(m, n)), (B, randn(n, p)))
end
end
@testset "Scalar-AbstractArray" begin
for dims in ((3,), (5,4), (10,10), (2,3,4), (2,3,4,5))
rrule_test(*, randn(dims), (1.5, 4.2), (randn(dims), randn(dims)))
rrule_test(*, randn(dims), (randn(dims), randn(dims)), (1.5, 4.2))
end
end
end
@testset "$f" for f in [/, \]
Expand Down

0 comments on commit ad2ce5d

Please sign in to comment.