From 032213f4165d87a222ad881036d68aa02e3aa878 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Fri, 2 Oct 2020 20:02:15 +0100 Subject: [PATCH 01/10] WIP: Optimized Strided MatMul --- src/ChainRules.jl | 1 + src/rulesets/LinearAlgebra/strided.jl | 37 ++++++++++++++++++++++++++ test/rulesets/LinearAlgebra/strided.jl | 20 ++++++++++++++ test/runtests.jl | 1 + 4 files changed, 59 insertions(+) create mode 100644 src/rulesets/LinearAlgebra/strided.jl create mode 100644 test/rulesets/LinearAlgebra/strided.jl diff --git a/src/ChainRules.jl b/src/ChainRules.jl index 1e0989b40..daf373e38 100644 --- a/src/ChainRules.jl +++ b/src/ChainRules.jl @@ -40,6 +40,7 @@ include("rulesets/Statistics/statistics.jl") include("rulesets/LinearAlgebra/utils.jl") include("rulesets/LinearAlgebra/blas.jl") include("rulesets/LinearAlgebra/dense.jl") +include("rulesets/LinearAlgebra/strided.jl") include("rulesets/LinearAlgebra/structured.jl") include("rulesets/LinearAlgebra/factorization.jl") diff --git a/src/rulesets/LinearAlgebra/strided.jl b/src/rulesets/LinearAlgebra/strided.jl new file mode 100644 index 000000000..6fe3c2ebe --- /dev/null +++ b/src/rulesets/LinearAlgebra/strided.jl @@ -0,0 +1,37 @@ +# Use BLAS.gemm for strided matrix-matrix multiplication sensitivites. 
+const RS = StridedMatrix{<:Number} +const RST = Transpose{<:Number, <:RS} +const RSA = Adjoint{<:Number, <:RS} + +# Note: weird spacing here is intentional to make this readable as a table +for (TA, TB, tCA, tDA, CA, DA, tCB, tDB, CB, DB) in [ + (RS, RS, 'N', 'C', :Ȳ, :B, 'C', 'N', :A, :Ȳ), + (RST, RS, 'N', 'T', :B, :Ȳ, 'N', 'N', :A, :Ȳ), + (RS, RST, 'N', 'N', :Ȳ, :B, 'T', 'N', :Ȳ, :A), + (RST, RST, 'T', 'T', :B, :Ȳ, 'T', 'T', :Ȳ, :A), + (RSA, RS, 'N', 'C', :B, :Ȳ, 'N', 'N', :A, :Ȳ), + (RS, RSA, 'N', 'N', :Ȳ, :B, 'C', 'N', :Ȳ, :A), + (RSA, RSA, 'C', 'C', :B, :Ȳ, 'C', 'C', :Ȳ, :A), +] + @eval function rrule(::typeof(*), A::$TA, B::$TB) + function strided_matmul_pullback(Ȳ) + @show :A=>($tCA, $tDA, $CA, $DA) + @show :B=>($tCB, $tDB, $CB, $DB) + # TODO: I think we are messing up what is transposed for GEMM + Ā = LinearAlgebra.BLAS.gemm($tCA, $tDA, $CA, $DA) + B̄ = LinearAlgebra.BLAS.gemm($tCB, $tDB, $CB, $DB) + #== + Ā = InplaceableThunk( + @thunk(LinearAlgebra.BLAS.gemm($tCA, $tDA, $CA, $DA)), + X̄ -> LinearAlgebra.BLAS.gemm!($tCA, $tDA, 1.0, $CA, $DA, 1.0, X̄), + ) + B̄ = InplaceableThunk( + @thunk(LinearAlgebra.BLAS.gemm($tCB, $tDB, $CB, $DB)), + X̄ -> LinearAlgebra.BLAS.gemm!($tCB, $tDB, 1.0, $CB, $DB, 1.0, X̄), + ) + ==# + return (NO_FIELDS, Ā, B̄) + end + return A*B, strided_matmul_pullback + end +end diff --git a/test/rulesets/LinearAlgebra/strided.jl b/test/rulesets/LinearAlgebra/strided.jl new file mode 100644 index 000000000..ed14c859f --- /dev/null +++ b/test/rulesets/LinearAlgebra/strided.jl @@ -0,0 +1,20 @@ +@testset "strided.jl" begin + @testset "Matrix-Matrix" begin + dims = [3]#,4,5] + + ⋆(a, b) = rand(-9.0:9.0, a, b) # Helper to generate random matrix + ⋆₂(a, b) = (a⋆b, a⋆b) # Helper to generate random matrix and its cotangent + @testset "n=$n, m=$m, p=$p" for n in dims, m in dims, p in dims + rrule_test(*, n⋆p, (n⋆₂m), (m⋆₂p)) + + rrule_test(*, n⋆p, Transpose.(m⋆₂n), Transpose.(p⋆₂m)) + rrule_test(*, n⋆p, Adjoint.(m⋆₂n), Adjoint.(p⋆₂m)) + + 
rrule_test(*, n⋆p, Transpose.(m⋆₂n), (m⋆₂p)) + rrule_test(*, n⋆p, Adjoint.(m⋆₂n), (m⋆₂p)) + + rrule_test(*, n⋆p, (n⋆₂m), Transpose.(p⋆₂m)) + rrule_test(*, n⋆p, (n⋆₂m), Adjoint.(p⋆₂m)) + end + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 21aaee18d..d66455935 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -38,6 +38,7 @@ println("Testing ChainRules.jl") @testset "LinearAlgebra" begin include(joinpath("rulesets", "LinearAlgebra", "dense.jl")) + include(joinpath("rulesets", "LinearAlgebra", "stided.jl")) include(joinpath("rulesets", "LinearAlgebra", "structured.jl")) include(joinpath("rulesets", "LinearAlgebra", "factorization.jl")) include(joinpath("rulesets", "LinearAlgebra", "blas.jl")) From 8d13fd12e4891d69c232d4d941e170c2726debb5 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Fri, 2 Oct 2020 20:12:13 +0100 Subject: [PATCH 02/10] add todos --- src/rulesets/LinearAlgebra/strided.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/rulesets/LinearAlgebra/strided.jl b/src/rulesets/LinearAlgebra/strided.jl index 6fe3c2ebe..014183ec5 100644 --- a/src/rulesets/LinearAlgebra/strided.jl +++ b/src/rulesets/LinearAlgebra/strided.jl @@ -17,10 +17,10 @@ for (TA, TB, tCA, tDA, CA, DA, tCB, tDB, CB, DB) in [ function strided_matmul_pullback(Ȳ) @show :A=>($tCA, $tDA, $CA, $DA) @show :B=>($tCB, $tDB, $CB, $DB) - # TODO: I think we are messing up what is transposed for GEMM + # TODO: for testing purposes just starting with this Ā = LinearAlgebra.BLAS.gemm($tCA, $tDA, $CA, $DA) B̄ = LinearAlgebra.BLAS.gemm($tCB, $tDB, $CB, $DB) - #== + #== # TODO uncomment this inplace version Ā = InplaceableThunk( @thunk(LinearAlgebra.BLAS.gemm($tCA, $tDA, $CA, $DA)), X̄ -> LinearAlgebra.BLAS.gemm!($tCA, $tDA, 1.0, $CA, $DA, 1.0, X̄), From a4b1aaf31439625b773099657cec4d8a532e2fce Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Thu, 8 Oct 2020 11:18:14 +0100 Subject: [PATCH 03/10] use Compat so can have 5 arg mul! 
--- Project.toml | 4 ++-- src/ChainRules.jl | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 22c18a54b..96ed5692d 100644 --- a/Project.toml +++ b/Project.toml @@ -4,6 +4,7 @@ version = "0.7.27" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" @@ -21,7 +22,6 @@ julia = "1" [extras] ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" -Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" @@ -29,4 +29,4 @@ SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["ChainRulesTestUtils", "Compat", "FiniteDifferences", "NaNMath", "Random", "SpecialFunctions", "Test"] +test = ["ChainRulesTestUtils", "FiniteDifferences", "NaNMath", "Random", "SpecialFunctions", "Test"] diff --git a/src/ChainRules.jl b/src/ChainRules.jl index daf373e38..323d5c839 100644 --- a/src/ChainRules.jl +++ b/src/ChainRules.jl @@ -4,6 +4,7 @@ using Reexport @reexport using ChainRulesCore using Base.Broadcast: materialize, materialize!, broadcasted, Broadcasted, broadcastable +using Compat using LinearAlgebra using LinearAlgebra.BLAS using Random From ee51ff7b8f51490f49e9498d9c029dc985d052a4 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Thu, 8 Oct 2020 18:47:34 +0100 Subject: [PATCH 04/10] Add tests for various forms of matmul --- test/rulesets/Base/arraymath.jl | 58 ++++++++++++++++++++++++--------- 1 file changed, 43 insertions(+), 15 deletions(-) diff --git a/test/rulesets/Base/arraymath.jl b/test/rulesets/Base/arraymath.jl index e12060966..07fd6b7ee 100644 --- a/test/rulesets/Base/arraymath.jl +++ 
b/test/rulesets/Base/arraymath.jl @@ -6,26 +6,54 @@ rrule_test(inv, randn(T, N, N), (B, randn(T, N, N))) end - @testset "*" begin - @testset "Matrix-Matrix" begin - dims = [3,4,5] - for n in dims, m in dims, p in dims - # don't need to test square case multiple times - n > 3 && n == m == p && continue - A = randn(m, n) - B = randn(n, p) - Ȳ = randn(m, p) - rrule_test(*, Ȳ, (A, randn(m, n)), (B, randn(n, p))) - end + # TODO Test on Complex once we have Complex support for all of these. + @testset "*: $T" for T in (Float64,) + ⋆(a) = round.(5*randn(T, a)) # Helper to generate nice random values + ⋆(a, b) = ⋆((a, b)) # matrix + ⋆() = only(⋆(())) # scalar + + ⋆₂(a) = (⋆(a), ⋆(a)) # Helper to generate random matrix and its cotangent + ⋆₂(a, b) = ⋆₂((a, b)) #matrix + ⋆₂() = ⋆₂(()) # scalar + + @testset "Scalar-Array $dims" for dims in ((3,), (5,4), (10,10), (2,3,4), (2,3,4,5)) + rrule_test(*, ⋆(dims), ⋆₂(), ⋆₂(dims)) + rrule_test(*, ⋆(dims), ⋆₂(dims), ⋆₂()) end - @testset "Scalar-AbstractArray" begin - for dims in ((3,), (5,4), (10,10), (2,3,4), (2,3,4,5)) - rrule_test(*, randn(dims), (1.5, 4.2), (randn(dims), randn(dims))) - rrule_test(*, randn(dims), (randn(dims), randn(dims)), (1.5, 4.2)) + + @testset "AbstractMatrix-AbstractMatrix" begin + dims = [2, 4, 5, 10] # small matrixes have some special cases + @testset "n=$n, m=$m, p=$p" for n in dims, m in dims, p in dims + @testset "Array" begin + rrule_test(*, n⋆p, (n⋆₂m), (m⋆₂p)) + end + + #== # This breaks FiniteDiff + # TODO uncomment after we have: https://github.com/JuliaDiff/FiniteDifferences.jl/pull/110 + @testset "SubArray - $indexname" for (indexname, m_index) in ( + ("fast", :), ("slow", Ref(m:-1:1)) + ) + rrule_test(*, n⋆p, view.(n⋆₂m, :, m_index), view.(m⋆₂p, m_index, :)) + rrule_test(*, n⋆p, n⋆₂m, view.(m⋆₂p, m_index, :)) + rrule_test(*, n⋆p, view.(n⋆₂m, :, m_index), m⋆₂p) + end + ==# + + @testset "Adjoints and Transposes" begin + rrule_test(*, n⋆p, Transpose.(m⋆₂n), Transpose.(p⋆₂m)) + rrule_test(*, n⋆p, 
Adjoint.(m⋆₂n), Adjoint.(p⋆₂m)) + + rrule_test(*, n⋆p, Transpose.(m⋆₂n), (m⋆₂p)) + rrule_test(*, n⋆p, Adjoint.(m⋆₂n), (m⋆₂p)) + + rrule_test(*, n⋆p, (n⋆₂m), Transpose.(p⋆₂m)) + rrule_test(*, n⋆p, (n⋆₂m), Adjoint.(p⋆₂m)) + end end end end + @testset "$f" for f in (/, \) @testset "Matrix" begin for n in 3:5, m in 3:5 From d1a93b9b198d9e493eea2b039f3451e174e6cddb Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Thu, 8 Oct 2020 19:09:49 +0100 Subject: [PATCH 05/10] Accumulate matmul inplace --- src/ChainRules.jl | 1 - src/rulesets/Base/arraymath.jl | 32 +++++++++++++++++++--- src/rulesets/LinearAlgebra/strided.jl | 37 -------------------------- test/rulesets/LinearAlgebra/strided.jl | 20 -------------- 4 files changed, 28 insertions(+), 62 deletions(-) delete mode 100644 src/rulesets/LinearAlgebra/strided.jl delete mode 100644 test/rulesets/LinearAlgebra/strided.jl diff --git a/src/ChainRules.jl b/src/ChainRules.jl index 323d5c839..279291ab3 100644 --- a/src/ChainRules.jl +++ b/src/ChainRules.jl @@ -41,7 +41,6 @@ include("rulesets/Statistics/statistics.jl") include("rulesets/LinearAlgebra/utils.jl") include("rulesets/LinearAlgebra/blas.jl") include("rulesets/LinearAlgebra/dense.jl") -include("rulesets/LinearAlgebra/strided.jl") include("rulesets/LinearAlgebra/structured.jl") include("rulesets/LinearAlgebra/factorization.jl") diff --git a/src/rulesets/Base/arraymath.jl b/src/rulesets/Base/arraymath.jl index da81a9b61..eb8cfe519 100644 --- a/src/rulesets/Base/arraymath.jl +++ b/src/rulesets/Base/arraymath.jl @@ -19,23 +19,47 @@ end ##### `*` ##### -function rrule(::typeof(*), A::AbstractMatrix{<:Real}, B::AbstractMatrix{<:Real}) +function rrule(::typeof(*), A::AbstractMatrix{<:Number}, B::AbstractMatrix{<:Number}) function times_pullback(Ȳ) - return (NO_FIELDS, @thunk(Ȳ * B'), @thunk(A' * Ȳ)) + return ( + NO_FIELDS, + InplaceableThunk( + @thunk(Ȳ * B'), + X̄ -> mul!(X̄, Ȳ, B', true, true) + ), + InplaceableThunk( + @thunk(A' * Ȳ), + X̄ -> mul!(X̄, A', Ȳ, true, true) + ) 
+ ) end return A * B, times_pullback end function rrule(::typeof(*), A::Real, B::AbstractArray{<:Real}) function times_pullback(Ȳ) - return (NO_FIELDS, @thunk(dot(Ȳ, B)), @thunk(A * Ȳ)) + return ( + NO_FIELDS, + @thunk(dot(Ȳ, B)), + InplaceableThunk( + @thunk(A * Ȳ), + X̄ -> mul!(X̄, A, Ȳ, true, true) + ) + ) end return A * B, times_pullback end function rrule(::typeof(*), B::AbstractArray{<:Real}, A::Real) function times_pullback(Ȳ) - return (NO_FIELDS, @thunk(A * Ȳ), @thunk(dot(Ȳ, B))) + return ( + NO_FIELDS, + InplaceableThunk( + @thunk(A * Ȳ), + X̄ -> mul!(X̄, A, Ȳ, true, true) + ), + @thunk(dot(Ȳ, B)), + ) end return A * B, times_pullback end diff --git a/src/rulesets/LinearAlgebra/strided.jl b/src/rulesets/LinearAlgebra/strided.jl deleted file mode 100644 index 014183ec5..000000000 --- a/src/rulesets/LinearAlgebra/strided.jl +++ /dev/null @@ -1,37 +0,0 @@ -# Use BLAS.gemm for strided matrix-matrix multiplication sensitivites. -const RS = StridedMatrix{<:Number} -const RST = Transpose{<:Number, <:RS} -const RSA = Adjoint{<:Number, <:RS} - -# Note: weird spacing here is intentional to make this readable as a table -for (TA, TB, tCA, tDA, CA, DA, tCB, tDB, CB, DB) in [ - (RS, RS, 'N', 'C', :Ȳ, :B, 'C', 'N', :A, :Ȳ), - (RST, RS, 'N', 'T', :B, :Ȳ, 'N', 'N', :A, :Ȳ), - (RS, RST, 'N', 'N', :Ȳ, :B, 'T', 'N', :Ȳ, :A), - (RST, RST, 'T', 'T', :B, :Ȳ, 'T', 'T', :Ȳ, :A), - (RSA, RS, 'N', 'C', :B, :Ȳ, 'N', 'N', :A, :Ȳ), - (RS, RSA, 'N', 'N', :Ȳ, :B, 'C', 'N', :Ȳ, :A), - (RSA, RSA, 'C', 'C', :B, :Ȳ, 'C', 'C', :Ȳ, :A), -] - @eval function rrule(::typeof(*), A::$TA, B::$TB) - function strided_matmul_pullback(Ȳ) - @show :A=>($tCA, $tDA, $CA, $DA) - @show :B=>($tCB, $tDB, $CB, $DB) - # TODO: for testing purposes just starting with this - Ā = LinearAlgebra.BLAS.gemm($tCA, $tDA, $CA, $DA) - B̄ = LinearAlgebra.BLAS.gemm($tCB, $tDB, $CB, $DB) - #== # TODO uncomment this inplace version - Ā = InplaceableThunk( - @thunk(LinearAlgebra.BLAS.gemm($tCA, $tDA, $CA, $DA)), - X̄ -> 
LinearAlgebra.BLAS.gemm!($tCA, $tDA, 1.0, $CA, $DA, 1.0, X̄), - ) - B̄ = InplaceableThunk( - @thunk(LinearAlgebra.BLAS.gemm($tCB, $tDB, $CB, $DB)), - X̄ -> LinearAlgebra.BLAS.gemm!($tCB, $tDB, 1.0, $CB, $DB, 1.0, X̄), - ) - ==# - return (NO_FIELDS, Ā, B̄) - end - return A*B, strided_matmul_pullback - end -end diff --git a/test/rulesets/LinearAlgebra/strided.jl b/test/rulesets/LinearAlgebra/strided.jl deleted file mode 100644 index ed14c859f..000000000 --- a/test/rulesets/LinearAlgebra/strided.jl +++ /dev/null @@ -1,20 +0,0 @@ -@testset "strided.jl" begin - @testset "Matrix-Matrix" begin - dims = [3]#,4,5] - - ⋆(a, b) = rand(-9.0:9.0, a, b) # Helper to generate random matrix - ⋆₂(a, b) = (a⋆b, a⋆b) # Helper to generate random matrix and its cotangent - @testset "n=$n, m=$m, p=$p" for n in dims, m in dims, p in dims - rrule_test(*, n⋆p, (n⋆₂m), (m⋆₂p)) - - rrule_test(*, n⋆p, Transpose.(m⋆₂n), Transpose.(p⋆₂m)) - rrule_test(*, n⋆p, Adjoint.(m⋆₂n), Adjoint.(p⋆₂m)) - - rrule_test(*, n⋆p, Transpose.(m⋆₂n), (m⋆₂p)) - rrule_test(*, n⋆p, Adjoint.(m⋆₂n), (m⋆₂p)) - - rrule_test(*, n⋆p, (n⋆₂m), Transpose.(p⋆₂m)) - rrule_test(*, n⋆p, (n⋆₂m), Adjoint.(p⋆₂m)) - end - end -end From 4c91f9ba7e726afb325636ca53eb5bcb6cdf0d00 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Thu, 8 Oct 2020 19:24:26 +0100 Subject: [PATCH 06/10] Support and test complex matmul and matrix scaling --- src/rulesets/Base/arraymath.jl | 12 ++++++------ test/rulesets/Base/arraymath.jl | 4 ++-- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/rulesets/Base/arraymath.jl b/src/rulesets/Base/arraymath.jl index eb8cfe519..a6aca3855 100644 --- a/src/rulesets/Base/arraymath.jl +++ b/src/rulesets/Base/arraymath.jl @@ -36,13 +36,13 @@ function rrule(::typeof(*), A::AbstractMatrix{<:Number}, B::AbstractMatrix{<:Num return A * B, times_pullback end -function rrule(::typeof(*), A::Real, B::AbstractArray{<:Real}) +function rrule(::typeof(*), A::Number, B::AbstractArray{<:Number}) function 
times_pullback(Ȳ) return ( NO_FIELDS, - @thunk(dot(Ȳ, B)), + @thunk(dot(Ȳ, B)'), InplaceableThunk( - @thunk(A * Ȳ), + @thunk(A' * Ȳ), X̄ -> mul!(X̄, A, Ȳ, true, true) ) ) @@ -50,15 +50,15 @@ function rrule(::typeof(*), A::Real, B::AbstractArray{<:Real}) return A * B, times_pullback end -function rrule(::typeof(*), B::AbstractArray{<:Real}, A::Real) +function rrule(::typeof(*), B::AbstractArray{<:Number}, A::Number) function times_pullback(Ȳ) return ( NO_FIELDS, InplaceableThunk( - @thunk(A * Ȳ), + @thunk(A' * Ȳ), X̄ -> mul!(X̄, A, Ȳ, true, true) ), - @thunk(dot(Ȳ, B)), + @thunk(dot(Ȳ, B)'), ) end return A * B, times_pullback diff --git a/test/rulesets/Base/arraymath.jl b/test/rulesets/Base/arraymath.jl index 07fd6b7ee..fc1bc77b9 100644 --- a/test/rulesets/Base/arraymath.jl +++ b/test/rulesets/Base/arraymath.jl @@ -7,7 +7,7 @@ end # TODO Test on Complex once we have Complex support for all of these. - @testset "*: $T" for T in (Float64,) + @testset "*: $T" for T in (Float64, ComplexF64) ⋆(a) = round.(5*randn(T, a)) # Helper to generate nice random values ⋆(a, b) = ⋆((a, b)) # matrix ⋆() = only(⋆(())) # scalar @@ -22,7 +22,7 @@ end @testset "AbstractMatrix-AbstractMatrix" begin - dims = [2, 4, 5, 10] # small matrixes have some special cases + dims = [2, 5, 10] # small matrixes can have some special cases @testset "n=$n, m=$m, p=$p" for n in dims, m in dims, p in dims @testset "Array" begin rrule_test(*, n⋆p, (n⋆₂m), (m⋆₂p)) From 5a5e0f60266d2299db53926f7b583a59fb1eebb7 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Fri, 16 Oct 2020 15:51:37 +0100 Subject: [PATCH 07/10] re-enable tests on SubArrays --- test/rulesets/Base/arraymath.jl | 3 --- 1 file changed, 3 deletions(-) diff --git a/test/rulesets/Base/arraymath.jl b/test/rulesets/Base/arraymath.jl index fc1bc77b9..f10a8629b 100644 --- a/test/rulesets/Base/arraymath.jl +++ b/test/rulesets/Base/arraymath.jl @@ -28,8 +28,6 @@ rrule_test(*, n⋆p, (n⋆₂m), (m⋆₂p)) end - #== # This breaks FiniteDiff - # TODO uncomment 
after we have: https://github.com/JuliaDiff/FiniteDifferences.jl/pull/110 @testset "SubArray - $indexname" for (indexname, m_index) in ( ("fast", :), ("slow", Ref(m:-1:1)) ) @@ -37,7 +35,6 @@ rrule_test(*, n⋆p, n⋆₂m, view.(m⋆₂p, m_index, :)) rrule_test(*, n⋆p, view.(n⋆₂m, :, m_index), m⋆₂p) end - ==# @testset "Adjoints and Transposes" begin rrule_test(*, n⋆p, Transpose.(m⋆₂n), Transpose.(p⋆₂m)) From 9bf2738cc6e2cc12e762eb33c8cbfabe638a1342 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Mon, 19 Oct 2020 15:37:44 +0100 Subject: [PATCH 08/10] cut down some excess in tests --- test/rulesets/Base/arraymath.jl | 5 ++--- test/runtests.jl | 1 - 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/test/rulesets/Base/arraymath.jl b/test/rulesets/Base/arraymath.jl index f10a8629b..1ec987d36 100644 --- a/test/rulesets/Base/arraymath.jl +++ b/test/rulesets/Base/arraymath.jl @@ -16,14 +16,13 @@ ⋆₂(a, b) = ⋆₂((a, b)) #matrix ⋆₂() = ⋆₂(()) # scalar - @testset "Scalar-Array $dims" for dims in ((3,), (5,4), (10,10), (2,3,4), (2,3,4,5)) + @testset "Scalar-Array $dims" for dims in ((3,), (5,4), (2, 3, 4, 5)) rrule_test(*, ⋆(dims), ⋆₂(), ⋆₂(dims)) rrule_test(*, ⋆(dims), ⋆₂(dims), ⋆₂()) end @testset "AbstractMatrix-AbstractMatrix" begin - dims = [2, 5, 10] # small matrixes can have some special cases - @testset "n=$n, m=$m, p=$p" for n in dims, m in dims, p in dims + @testset "n=$n, m=$m, p=$p" for n in (2, 5), m in (2, 4), p in (2, 3) @testset "Array" begin rrule_test(*, n⋆p, (n⋆₂m), (m⋆₂p)) end diff --git a/test/runtests.jl b/test/runtests.jl index d66455935..21aaee18d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -38,7 +38,6 @@ println("Testing ChainRules.jl") @testset "LinearAlgebra" begin include(joinpath("rulesets", "LinearAlgebra", "dense.jl")) - include(joinpath("rulesets", "LinearAlgebra", "stided.jl")) include(joinpath("rulesets", "LinearAlgebra", "structured.jl")) include(joinpath("rulesets", "LinearAlgebra", "factorization.jl")) 
include(joinpath("rulesets", "LinearAlgebra", "blas.jl")) From 540c127f80e1df2d234dbf78be4ca0dd76ec8be6 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Mon, 19 Oct 2020 15:37:59 +0100 Subject: [PATCH 09/10] fix complex scalar * array --- src/rulesets/Base/arraymath.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/rulesets/Base/arraymath.jl b/src/rulesets/Base/arraymath.jl index a6aca3855..fd2457f01 100644 --- a/src/rulesets/Base/arraymath.jl +++ b/src/rulesets/Base/arraymath.jl @@ -43,7 +43,7 @@ function rrule(::typeof(*), A::Number, B::AbstractArray{<:Number}) @thunk(dot(Ȳ, B)'), InplaceableThunk( @thunk(A' * Ȳ), - X̄ -> mul!(X̄, A, Ȳ, true, true) + X̄ -> mul!(X̄, conj(A), Ȳ, true, true) ) ) end @@ -56,7 +56,7 @@ function rrule(::typeof(*), B::AbstractArray{<:Number}, A::Number) NO_FIELDS, InplaceableThunk( @thunk(A' * Ȳ), - X̄ -> mul!(X̄, A, Ȳ, true, true) + X̄ -> mul!(X̄, conj(A), Ȳ, true, true) ), @thunk(dot(Ȳ, B)'), ) From 4a459e69d4ef97375ed1fabd9d2accf621237255 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Mon, 19 Oct 2020 22:38:51 +0100 Subject: [PATCH 10/10] Update test/rulesets/Base/arraymath.jl --- test/rulesets/Base/arraymath.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/test/rulesets/Base/arraymath.jl b/test/rulesets/Base/arraymath.jl index 1ec987d36..8bc996860 100644 --- a/test/rulesets/Base/arraymath.jl +++ b/test/rulesets/Base/arraymath.jl @@ -6,7 +6,6 @@ rrule_test(inv, randn(T, N, N), (B, randn(T, N, N))) end - # TODO Test on Complex once we have Complex support for all of these. @testset "*: $T" for T in (Float64, ComplexF64) ⋆(a) = round.(5*randn(T, a)) # Helper to generate nice random values ⋆(a, b) = ⋆((a, b)) # matrix