Skip to content

Commit

Permalink
Merge 84b5f71 into 5586a9b
Browse files Browse the repository at this point in the history
  • Loading branch information
oxinabox committed Oct 20, 2020
2 parents 5586a9b + 84b5f71 commit ff29066
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 23 deletions.
5 changes: 3 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ version = "0.7.27"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Expand All @@ -21,12 +23,11 @@ julia = "1"

[extras]
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["ChainRulesTestUtils", "Compat", "FiniteDifferences", "NaNMath", "Random", "SpecialFunctions", "Test"]
test = ["ChainRulesTestUtils", "FiniteDifferences", "NaNMath", "Random", "SpecialFunctions", "Test"]
3 changes: 3 additions & 0 deletions src/ChainRules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ using Reexport
@reexport using ChainRulesCore

using Base.Broadcast: materialize, materialize!, broadcasted, Broadcasted, broadcastable
using Compat
using LinearAlgebra
using LinearAlgebra.BLAS
using Random
Expand All @@ -22,6 +23,8 @@ if VERSION < v"1.3.0-DEV.142"
import LinearAlgebra: dot
end

# numbers that we know commute under multiplication
const CommutativeMulNumber = Union{Real,Complex}

include("rulesets/Core/core.jl")

Expand Down
44 changes: 38 additions & 6 deletions src/rulesets/Base/arraymath.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,23 +19,55 @@ end
##### `*`
#####

function rrule(::typeof(*), A::AbstractMatrix{<:Real}, B::AbstractMatrix{<:Real})
function rrule(
::typeof(*),
A::AbstractMatrix{<:CommutativeMulNumber},
B::AbstractMatrix{<:CommutativeMulNumber},
)
function times_pullback(Ȳ)
return (NO_FIELDS, @thunk(Ȳ * B'), @thunk(A' * Ȳ))
return (
NO_FIELDS,
InplaceableThunk(
@thunk(Ȳ * B'),
-> mul!(X̄, Ȳ, B', true, true)
),
InplaceableThunk(
@thunk(A' * Ȳ),
-> mul!(X̄, A', Ȳ, true, true)
)
)
end
return A * B, times_pullback
end

function rrule(::typeof(*), A::Real, B::AbstractArray{<:Real})
function rrule(
::typeof(*), A::CommutativeMulNumber, B::AbstractArray{<:CommutativeMulNumber}
)
function times_pullback(Ȳ)
return (NO_FIELDS, @thunk(dot(Ȳ, B)), @thunk(A * Ȳ))
return (
NO_FIELDS,
@thunk(dot(Ȳ, B)'),
InplaceableThunk(
@thunk(A' * Ȳ),
-> mul!(X̄, conj(A), Ȳ, true, true)
)
)
end
return A * B, times_pullback
end

function rrule(::typeof(*), B::AbstractArray{<:Real}, A::Real)
function rrule(
::typeof(*), B::AbstractArray{<:CommutativeMulNumber}, A::CommutativeMulNumber
)
function times_pullback(Ȳ)
return (NO_FIELDS, @thunk(A * Ȳ), @thunk(dot(Ȳ, B)))
return (
NO_FIELDS,
InplaceableThunk(
@thunk(A' * Ȳ),
-> mul!(X̄, conj(A), Ȳ, true, true)
),
@thunk(dot(Ȳ, B)'),
)
end
return A * B, times_pullback
end
Expand Down
53 changes: 38 additions & 15 deletions test/rulesets/Base/arraymath.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,26 +6,49 @@
rrule_test(inv, randn(T, N, N), (B, randn(T, N, N)))
end

@testset "*" begin
@testset "Matrix-Matrix" begin
dims = [3,4,5]
for n in dims, m in dims, p in dims
# don't need to test square case multiple times
n > 3 && n == m == p && continue
A = randn(m, n)
B = randn(n, p)
= randn(m, p)
rrule_test(*, Ȳ, (A, randn(m, n)), (B, randn(n, p)))
end
@testset "*: $T" for T in (Float64, ComplexF64)
(a) = round.(5*randn(T, a)) # Helper to generate nice random values
(a, b) = ((a, b)) # matrix
() = only((())) # scalar

(a) = ((a), (a)) # Helper to generate random matrix and its cotangent
(a, b) = ((a, b)) #matrix
() = (()) # scalar

@testset "Scalar-Array $dims" for dims in ((3,), (5,4), (2, 3, 4, 5))
rrule_test(*, (dims), (), (dims))
rrule_test(*, (dims), (dims), ())
end
@testset "Scalar-AbstractArray" begin
for dims in ((3,), (5,4), (10,10), (2,3,4), (2,3,4,5))
rrule_test(*, randn(dims), (1.5, 4.2), (randn(dims), randn(dims)))
rrule_test(*, randn(dims), (randn(dims), randn(dims)), (1.5, 4.2))

@testset "AbstractMatrix-AbstractMatrix" begin
@testset "n=$n, m=$m, p=$p" for n in (2, 5), m in (2, 4), p in (2, 3)
@testset "Array" begin
rrule_test(*, np, (n₂m), (m₂p))
end

@testset "SubArray - $indexname" for (indexname, m_index) in (
("fast", :), ("slow", Ref(m:-1:1))
)
rrule_test(*, np, view.(n₂m, :, m_index), view.(m₂p, m_index, :))
rrule_test(*, np, n₂m, view.(m₂p, m_index, :))
rrule_test(*, np, view.(n₂m, :, m_index), m₂p)
end

@testset "Adjoints and Transposes" begin
rrule_test(*, np, Transpose.(m₂n), Transpose.(p₂m))
rrule_test(*, np, Adjoint.(m₂n), Adjoint.(p₂m))

rrule_test(*, np, Transpose.(m₂n), (m₂p))
rrule_test(*, np, Adjoint.(m₂n), (m₂p))

rrule_test(*, np, (n₂m), Transpose.(p₂m))
rrule_test(*, np, (n₂m), Adjoint.(p₂m))
end
end
end
end


@testset "$f" for f in (/, \)
@testset "Matrix" begin
for n in 3:5, m in 3:5
Expand Down

0 comments on commit ff29066

Please sign in to comment.