diff --git a/Project.toml b/Project.toml index 22c18a54b..96ed5692d 100644 --- a/Project.toml +++ b/Project.toml @@ -4,6 +4,7 @@ version = "0.7.27" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" @@ -21,7 +22,6 @@ julia = "1" [extras] ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" -Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" @@ -29,4 +29,4 @@ SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["ChainRulesTestUtils", "Compat", "FiniteDifferences", "NaNMath", "Random", "SpecialFunctions", "Test"] +test = ["ChainRulesTestUtils", "FiniteDifferences", "NaNMath", "Random", "SpecialFunctions", "Test"] diff --git a/src/ChainRules.jl b/src/ChainRules.jl index 1e0989b40..279291ab3 100644 --- a/src/ChainRules.jl +++ b/src/ChainRules.jl @@ -4,6 +4,7 @@ using Reexport @reexport using ChainRulesCore using Base.Broadcast: materialize, materialize!, broadcasted, Broadcasted, broadcastable +using Compat using LinearAlgebra using LinearAlgebra.BLAS using Random diff --git a/src/rulesets/Base/arraymath.jl b/src/rulesets/Base/arraymath.jl index da81a9b61..fd2457f01 100644 --- a/src/rulesets/Base/arraymath.jl +++ b/src/rulesets/Base/arraymath.jl @@ -19,23 +19,47 @@ end ##### `*` ##### -function rrule(::typeof(*), A::AbstractMatrix{<:Real}, B::AbstractMatrix{<:Real}) +function rrule(::typeof(*), A::AbstractMatrix{<:Number}, B::AbstractMatrix{<:Number}) function times_pullback(Ȳ) - return (NO_FIELDS, @thunk(Ȳ * B'), @thunk(A' * Ȳ)) + return ( + NO_FIELDS, + InplaceableThunk( + @thunk(Ȳ * B'), + X̄ -> mul!(X̄, Ȳ, B', true, true) + ), + InplaceableThunk( + @thunk(A' * Ȳ), + X̄ -> mul!(X̄, A', Ȳ, true, true) + ) + ) end return A * B, times_pullback end -function rrule(::typeof(*), A::Real, B::AbstractArray{<:Real}) +function rrule(::typeof(*), A::Number, B::AbstractArray{<:Number}) function times_pullback(Ȳ) - return (NO_FIELDS, @thunk(dot(Ȳ, B)), @thunk(A * Ȳ)) + return ( + NO_FIELDS, + @thunk(dot(Ȳ, B)'), + InplaceableThunk( + @thunk(A' * Ȳ), + X̄ -> mul!(X̄, conj(A), Ȳ, true, true) + ) + ) end return A * B, times_pullback end -function rrule(::typeof(*), B::AbstractArray{<:Real}, A::Real) +function rrule(::typeof(*), B::AbstractArray{<:Number}, A::Number) function times_pullback(Ȳ) - return (NO_FIELDS, @thunk(A * Ȳ), @thunk(dot(Ȳ, B))) + return ( + NO_FIELDS, + InplaceableThunk( + @thunk(A' * Ȳ), + X̄ -> mul!(X̄, conj(A), Ȳ, true, true) + ), + @thunk(dot(Ȳ, B)'), + ) end return A * B, times_pullback end diff --git a/test/rulesets/Base/arraymath.jl b/test/rulesets/Base/arraymath.jl index e12060966..1ec987d36 100644 --- a/test/rulesets/Base/arraymath.jl +++ b/test/rulesets/Base/arraymath.jl @@ -6,26 +6,50 @@ rrule_test(inv, randn(T, N, N), (B, randn(T, N, N))) end - @testset "*" begin - @testset "Matrix-Matrix" begin - dims = [3,4,5] - for n in dims, m in dims, p in dims - # don't need to test square case multiple times - n > 3 && n == m == p && continue - A = randn(m, n) - B = randn(n, p) - Ȳ = randn(m, p) - rrule_test(*, Ȳ, (A, randn(m, n)), (B, randn(n, p))) - end + # TODO Test on Complex once we have Complex support for all of these. + @testset "*: $T" for T in (Float64, ComplexF64) + ⋆(a) = round.(5*randn(T, a)) # Helper to generate nice random values + ⋆(a, b) = ⋆((a, b)) # matrix + ⋆() = only(⋆(())) # scalar + + ⋆₂(a) = (⋆(a), ⋆(a)) # Helper to generate random matrix and its cotangent + ⋆₂(a, b) = ⋆₂((a, b)) #matrix + ⋆₂() = ⋆₂(()) # scalar + + @testset "Scalar-Array $dims" for dims in ((3,), (5,4), (2, 3, 4, 5)) + rrule_test(*, ⋆(dims), ⋆₂(), ⋆₂(dims)) + rrule_test(*, ⋆(dims), ⋆₂(dims), ⋆₂()) end - @testset "Scalar-AbstractArray" begin - for dims in ((3,), (5,4), (10,10), (2,3,4), (2,3,4,5)) - rrule_test(*, randn(dims), (1.5, 4.2), (randn(dims), randn(dims))) - rrule_test(*, randn(dims), (randn(dims), randn(dims)), (1.5, 4.2)) + + @testset "AbstractMatrix-AbstractMatrix" begin + @testset "n=$n, m=$m, p=$p" for n in (2, 5), m in (2, 4), p in (2, 3) + @testset "Array" begin + rrule_test(*, n⋆p, (n⋆₂m), (m⋆₂p)) + end + + @testset "SubArray - $indexname" for (indexname, m_index) in ( + ("fast", :), ("slow", Ref(m:-1:1)) + ) + rrule_test(*, n⋆p, view.(n⋆₂m, :, m_index), view.(m⋆₂p, m_index, :)) + rrule_test(*, n⋆p, n⋆₂m, view.(m⋆₂p, m_index, :)) + rrule_test(*, n⋆p, view.(n⋆₂m, :, m_index), m⋆₂p) + end + + @testset "Adjoints and Transposes" begin + rrule_test(*, n⋆p, Transpose.(m⋆₂n), Transpose.(p⋆₂m)) + rrule_test(*, n⋆p, Adjoint.(m⋆₂n), Adjoint.(p⋆₂m)) + + rrule_test(*, n⋆p, Transpose.(m⋆₂n), (m⋆₂p)) + rrule_test(*, n⋆p, Adjoint.(m⋆₂n), (m⋆₂p)) + + rrule_test(*, n⋆p, (n⋆₂m), Transpose.(p⋆₂m)) + rrule_test(*, n⋆p, (n⋆₂m), Adjoint.(p⋆₂m)) + end end end end + @testset "$f" for f in (/, \) @testset "Matrix" begin for n in 3:5, m in 3:5