diff --git a/src/ChainRules.jl b/src/ChainRules.jl index 07864512a..09dce121c 100644 --- a/src/ChainRules.jl +++ b/src/ChainRules.jl @@ -31,6 +31,7 @@ include("rulesets/Base/base.jl") include("rulesets/Base/fastmath_able.jl") include("rulesets/Base/evalpoly.jl") include("rulesets/Base/array.jl") +include("rulesets/Base/arraymath.jl") include("rulesets/Base/mapreduce.jl") include("rulesets/Statistics/statistics.jl") diff --git a/src/rulesets/Base/arraymath.jl b/src/rulesets/Base/arraymath.jl new file mode 100644 index 000000000..e0515b167 --- /dev/null +++ b/src/rulesets/Base/arraymath.jl @@ -0,0 +1,106 @@ +###### +###### `inv` +###### + +function frule((_, Δx), ::typeof(inv), x::AbstractArray) + Ω = inv(x) + return Ω, -Ω * Δx * Ω +end + +function rrule(::typeof(inv), x::AbstractArray) + Ω = inv(x) + function inv_pullback(ΔΩ) + return NO_FIELDS, -Ω' * ΔΩ * Ω' + end + return Ω, inv_pullback +end + +##### +##### `*` +##### + +function rrule(::typeof(*), A::AbstractMatrix{<:Real}, B::AbstractMatrix{<:Real}) + function times_pullback(Ȳ) + return (NO_FIELDS, @thunk(Ȳ * B'), @thunk(A' * Ȳ)) + end + return A * B, times_pullback +end + +function rrule(::typeof(*), A::Real, B::AbstractArray{<:Real}) + function times_pullback(Ȳ) + return (NO_FIELDS, @thunk(dot(Ȳ, B)), @thunk(A * Ȳ)) + end + return A * B, times_pullback +end + +function rrule(::typeof(*), B::AbstractArray{<:Real}, A::Real) + function times_pullback(Ȳ) + return (NO_FIELDS, @thunk(A * Ȳ), @thunk(dot(Ȳ, B))) + end + return A * B, times_pullback +end + + + +##### +##### `/` +##### + +function rrule(::typeof(/), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:Real}) + Aᵀ, dA_pb = rrule(adjoint, A) + Bᵀ, dB_pb = rrule(adjoint, B) + Cᵀ, dS_pb = rrule(\, Bᵀ, Aᵀ) + C, dC_pb = rrule(adjoint, Cᵀ) + function slash_pullback(Ȳ) + # Optimization note: dAᵀ, dBᵀ, dC are calculated no matter which partial you want + _, dC = dC_pb(Ȳ) + _, dBᵀ, dAᵀ = dS_pb(unthunk(dC)) + + ∂A = last(dA_pb(unthunk(dAᵀ))) + ∂B = last(dA_pb(unthunk(dBᵀ))) + + (NO_FIELDS, ∂A, ∂B) + end + return C, slash_pullback +end + +##### +##### `\` +##### + +function rrule(::typeof(\), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:Real}) + Y = A \ B + function backslash_pullback(Ȳ) + ∂A = @thunk begin + B̄ = A' \ Ȳ + Ā = -B̄ * Y' + _add!(Ā, (B - A * Y) * B̄' / A') + _add!(Ā, A' \ Y * (Ȳ' - B̄'A)) + Ā + end + ∂B = @thunk A' \ Ȳ + return NO_FIELDS, ∂A, ∂B + end + return Y, backslash_pullback + +end + +##### +##### `\`, `/` matrix-scalar_rule +##### + +function rrule(::typeof(/), A::AbstractArray{<:Real}, b::Real) + Y = A/b + function slash_pullback(Ȳ) + return (NO_FIELDS, @thunk(Ȳ/b), @thunk(-dot(Ȳ, Y)/b)) + end + return Y, slash_pullback +end + +function rrule(::typeof(\), b::Real, A::AbstractArray{<:Real}) + Y = b\A + function backslash_pullback(Ȳ) + return (NO_FIELDS, @thunk(-dot(Ȳ, Y)/b), @thunk(Ȳ/b)) + end + return Y, backslash_pullback +end diff --git a/src/rulesets/LinearAlgebra/dense.jl b/src/rulesets/LinearAlgebra/dense.jl index 172ac6c95..7b5953bde 100644 --- a/src/rulesets/LinearAlgebra/dense.jl +++ b/src/rulesets/LinearAlgebra/dense.jl @@ -1,9 +1,3 @@ -using LinearAlgebra: AbstractTriangular - -# Matrix wrapper types that we know are square and are thus potentially invertible. For -# these we can use simpler definitions for `/` and `\`. -const SquareMatrix{T} = Union{Diagonal{T},AbstractTriangular{T}} - ##### ##### `dot` ##### @@ -36,22 +30,6 @@ function rrule(::typeof(cross), a::AbstractVector{<:Real}, b::AbstractVector{<:R return Ω, cross_pullback end -##### -##### `inv` -##### - -function frule((_, Δx), ::typeof(inv), x::AbstractArray) - Ω = inv(x) - return Ω, -Ω * Δx * Ω -end - -function rrule(::typeof(inv), x::AbstractArray) - Ω = inv(x) - function inv_pullback(ΔΩ) - return NO_FIELDS, -Ω' * ΔΩ * Ω' - end - return Ω, inv_pullback -end ##### ##### `det` @@ -138,51 +116,6 @@ function rrule(::typeof(tr), x) end -##### -##### `*` -##### - -function rrule(::typeof(*), A::AbstractMatrix{<:Real}, B::AbstractMatrix{<:Real}) - function times_pullback(Ȳ) - return (NO_FIELDS, @thunk(Ȳ * B'), @thunk(A' * Ȳ)) - end - return A * B, times_pullback -end - -function rrule(::typeof(*), A::Real, B::AbstractArray{<:Real}) - function times_pullback(Ȳ) - return (NO_FIELDS, @thunk(dot(Ȳ, B)), @thunk(A * Ȳ)) - end - return A * B, times_pullback -end - -function rrule(::typeof(*), B::AbstractArray{<:Real}, A::Real) - function times_pullback(Ȳ) - return (NO_FIELDS, @thunk(A * Ȳ), @thunk(dot(Ȳ, B))) - end - return A * B, times_pullback -end - -##### -##### `\`, `/` matrix-scalar_rule - -function rrule(::typeof(/), A::AbstractArray{<:Real}, b::Real) - Y = A/b - function slash_pullback(Ȳ) - return (NO_FIELDS, @thunk(Ȳ/b), @thunk(-dot(Ȳ, Y)/b)) - end - return Y, slash_pullback -end - -function rrule(::typeof(\), b::Real, A::AbstractArray{<:Real}) - Y = b\A - function backslash_pullback(Ȳ) - return (NO_FIELDS, @thunk(-dot(Ȳ, Y)/b), @thunk(Ȳ/b)) - end - return Y, backslash_pullback -end - - ##### ##### `pinv` ##### @@ -278,71 +211,6 @@ function rrule(::typeof(pinv), A::AbstractMatrix{T}; kwargs...) where {T} return Y, pinv_pullback end -##### -##### `/` -##### - -function rrule(::typeof(/), A::AbstractMatrix{<:Real}, B::T) where T<:SquareMatrix{<:Real} - Y = A / B - function slash_pullback(Ȳ) - S = T.name.wrapper - ∂A = @thunk Ȳ / B' - ∂B = @thunk S(-Y' * (Ȳ / B')) - return (NO_FIELDS, ∂A, ∂B) - end - return Y, slash_pullback -end - -function rrule(::typeof(/), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:Real}) - Aᵀ, dA_pb = rrule(adjoint, A) - Bᵀ, dB_pb = rrule(adjoint, B) - Cᵀ, dS_pb = rrule(\, Bᵀ, Aᵀ) - C, dC_pb = rrule(adjoint, Cᵀ) - function slash_pullback(Ȳ) - # Optimization note: dAᵀ, dBᵀ, dC are calculated no matter which partial you want - _, dC = dC_pb(Ȳ) - _, dBᵀ, dAᵀ = dS_pb(unthunk(dC)) - - ∂A = last(dA_pb(unthunk(dAᵀ))) - ∂B = last(dA_pb(unthunk(dBᵀ))) - - (NO_FIELDS, ∂A, ∂B) - end - return C, slash_pullback -end - -##### -##### `\` -##### - -function rrule(::typeof(\), A::T, B::AbstractVecOrMat{<:Real}) where T<:SquareMatrix{<:Real} - Y = A \ B - function backslash_pullback(Ȳ) - S = T.name.wrapper - ∂A = @thunk S(-(A' \ Ȳ) * Y') - ∂B = @thunk A' \ Ȳ - return NO_FIELDS, ∂A, ∂B - end - return Y, backslash_pullback -end - -function rrule(::typeof(\), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:Real}) - Y = A \ B - function backslash_pullback(Ȳ) - ∂A = @thunk begin - B̄ = A' \ Ȳ - Ā = -B̄ * Y' - _add!(Ā, (B - A * Y) * B̄' / A') - _add!(Ā, A' \ Y * (Ȳ' - B̄'A)) - Ā - end - ∂B = @thunk A' \ Ȳ - return NO_FIELDS, ∂A, ∂B - end - return Y, backslash_pullback - -end - ##### ##### `norm` ##### diff --git a/src/rulesets/LinearAlgebra/structured.jl b/src/rulesets/LinearAlgebra/structured.jl index bbc5e8743..6b8e0ba5a 100644 --- a/src/rulesets/LinearAlgebra/structured.jl +++ b/src/rulesets/LinearAlgebra/structured.jl @@ -1,5 +1,31 @@ # Structured matrices +using LinearAlgebra: AbstractTriangular + +# Matrix wrapper types that we know are square and are thus potentially invertible. For +# these we can use simpler definitions for `/` and `\`. +const SquareMatrix{T} = Union{Diagonal{T}, AbstractTriangular{T}} + +function rrule(::typeof(/), A::AbstractMatrix{<:Real}, B::T) where T<:SquareMatrix{<:Real} + Y = A / B + function slash_pullback(Ȳ) + S = T.name.wrapper + ∂A = @thunk Ȳ / B' + ∂B = @thunk S(-Y' * (Ȳ / B')) + return (NO_FIELDS, ∂A, ∂B) + end + return Y, slash_pullback +end +function rrule(::typeof(\), A::T, B::AbstractVecOrMat{<:Real}) where T<:SquareMatrix{<:Real} + Y = A \ B + function backslash_pullback(Ȳ) + S = T.name.wrapper + ∂A = @thunk S(-(A' \ Ȳ) * Y') + ∂B = @thunk A' \ Ȳ + return NO_FIELDS, ∂A, ∂B + end + return Y, backslash_pullback +end ##### ##### `Diagonal` diff --git a/test/rulesets/Base/arraymath.jl b/test/rulesets/Base/arraymath.jl new file mode 100644 index 000000000..adcb6f766 --- /dev/null +++ b/test/rulesets/Base/arraymath.jl @@ -0,0 +1,66 @@ +@testset "arraymath" begin + @testset "inv(::Matrix{$T})" for T in (Float64, ComplexF64) + N = 3 + B = generate_well_conditioned_matrix(T, N) + frule_test(inv, (B, randn(T, N, N))) + rrule_test(inv, randn(T, N, N), (B, randn(T, N, N))) + end + + @testset "*" begin + @testset "Matrix-Matrix" begin + dims = [3,4,5] + for n in dims, m in dims, p in dims + # don't need to test square case multiple times + n > 3 && n == m == p && continue + A = randn(m, n) + B = randn(n, p) + Ȳ = randn(m, p) + rrule_test(*, Ȳ, (A, randn(m, n)), (B, randn(n, p))) + end + end + @testset "Scalar-AbstractArray" begin + for dims in ((3,), (5,4), (10,10), (2,3,4), (2,3,4,5)) + rrule_test(*, randn(dims), (1.5, 4.2), (randn(dims), randn(dims))) + rrule_test(*, randn(dims), (randn(dims), randn(dims)), (1.5, 4.2)) + end + end + end + + @testset "$f" for f in (/, \) + @testset "Matrix" begin + for n in 3:5, m in 3:5 + A = randn(m, n) + B = randn(m, n) + Ȳ = randn(size(f(A, B))) + rrule_test(f, Ȳ, (A, randn(m, n)), (B, randn(m, n))) + end + end + @testset "Vector" begin + x = randn(10) + y = randn(10) + ȳ = randn(size(f(x, y))...) + rrule_test(f, ȳ, (x, randn(10)), (y, randn(10))) + end + if f == (\) + @testset "Matrix $f Vector" begin + X = randn(10, 4) + y = randn(10) + ȳ = randn(size(f(X, y))...) + rrule_test(f, ȳ, (X, randn(size(X))), (y, randn(10))) + end + @testset "Vector $f Matrix" begin + x = randn(10) + Y = randn(10, 4) + ȳ = randn(size(f(x, Y))...) + rrule_test(f, ȳ, (x, randn(size(x))), (Y, randn(size(Y)))) + end + end + end + @testset "/ and \\ Scalar-AbstractArray" begin + A = randn(3, 4, 5) + Ā = randn(3, 4, 5) + Ȳ = randn(3, 4, 5) + rrule_test(/, Ȳ, (A, Ā), (7.2, 2.3)) + rrule_test(\, Ȳ, (7.2, 2.3), (A, Ā)) + end +end diff --git a/test/rulesets/LinearAlgebra/dense.jl b/test/rulesets/LinearAlgebra/dense.jl index 719fabad3..4a3517e4a 100644 --- a/test/rulesets/LinearAlgebra/dense.jl +++ b/test/rulesets/LinearAlgebra/dense.jl @@ -1,4 +1,4 @@ -@testset "linalg" begin +@testset "dense" begin @testset "dot" begin @testset "Vector{$T}" for T in (Float64, ComplexF64) M = 3 @@ -42,12 +42,6 @@ rrule_test(cross, ΔΩ, (x, x̄), (y, ȳ)) end end - @testset "inv(::Matrix{$T})" for T in (Float64, ComplexF64) - N = 3 - B = generate_well_conditioned_matrix(T, N) - frule_test(inv, (B, randn(T, N, N))) - rrule_test(inv, randn(T, N, N), (B, randn(T, N, N))) - end @testset "pinv" begin @testset "$T" for T in (Float64, ComplexF64) test_scalar(pinv, randn(T)) @@ -114,77 +108,6 @@ frule_test(tr, (randn(N, N), randn(N, N))) rrule_test(tr, randn(), (randn(N, N), randn(N, N))) end - @testset "*" begin - @testset "Matrix-Matrix" begin - dims = [3,4,5] - for n in dims, m in dims, p in dims - n > 3 && n == m == p && continue # don't need to test square case multiple times - A = randn(m, n) - B = randn(n, p) - Ȳ = randn(m, p) - rrule_test(*, Ȳ, (A, randn(m, n)), (B, randn(n, p))) - end - end - @testset "Scalar-AbstractArray" begin - for dims in ((3,), (5,4), (10,10), (2,3,4), (2,3,4,5)) - rrule_test(*, randn(dims), (1.5, 4.2), (randn(dims), randn(dims))) - rrule_test(*, randn(dims), (randn(dims), randn(dims)), (1.5, 4.2)) - end - end - end - @testset "$f" for f in [/, \] - @testset "Matrix" begin - for n in 3:5, m in 3:5 - A = randn(m, n) - B = randn(m, n) - Ȳ = randn(size(f(A, B))) - rrule_test(f, Ȳ, (A, randn(m, n)), (B, randn(m, n))) - end - end - @testset "Vector" begin - x = randn(10) - y = randn(10) - ȳ = randn(size(f(x, y))...) - rrule_test(f, ȳ, (x, randn(10)), (y, randn(10))) - end - if f == (/) - @testset "$T on the RHS" for T in (Diagonal, UpperTriangular, LowerTriangular) - RHS = T(randn(T == Diagonal ? 10 : (10, 10))) - Y = randn(5, 10) - Ȳ = randn(size(f(Y, RHS))...) - rrule_test(f, Ȳ, (Y, randn(size(Y))), (RHS, randn(size(RHS)))) - end - else - @testset "$T on LHS" for T in (Diagonal, UpperTriangular, LowerTriangular) - LHS = T(randn(T == Diagonal ? 10 : (10, 10))) - y = randn(10) - ȳ = randn(size(f(LHS, y))...) - rrule_test(f, ȳ, (LHS, randn(size(LHS))), (y, randn(10))) - Y = randn(10, 10) - Ȳ = randn(10, 10) - rrule_test(f, Ȳ, (LHS, randn(size(LHS))), (Y, randn(size(Y)))) - end - @testset "Matrix $f Vector" begin - X = randn(10, 4) - y = randn(10) - ȳ = randn(size(f(X, y))...) - rrule_test(f, ȳ, (X, randn(size(X))), (y, randn(10))) - end - @testset "Vector $f Matrix" begin - x = randn(10) - Y = randn(10, 4) - ȳ = randn(size(f(x, Y))...) - rrule_test(f, ȳ, (x, randn(size(x))), (Y, randn(size(Y)))) - end - end - end - @testset "/ and \\ Scalar-AbstractArray" begin - A = randn(3, 4, 5) - Ā = randn(3, 4, 5) - Ȳ = randn(3, 4, 5) - rrule_test(/, Ȳ, (A, Ā), (7.2, 2.3)) - rrule_test(\, Ȳ, (7.2, 2.3), (A, Ā)) - end @testset "norm" begin for dims in [(), (5,), (3, 2), (7, 3, 2)] A = randn(dims...) diff --git a/test/rulesets/LinearAlgebra/structured.jl b/test/rulesets/LinearAlgebra/structured.jl index 8ca1f0284..a11d484af 100644 --- a/test/rulesets/LinearAlgebra/structured.jl +++ b/test/rulesets/LinearAlgebra/structured.jl @@ -1,4 +1,23 @@ @testset "Structured Matrices" begin + @testset "/ and \\ on Square Matrixes" begin + @testset "//, $T on the RHS" for T in (Diagonal, UpperTriangular, LowerTriangular) + RHS = T(randn(T == Diagonal ? 10 : (10, 10))) + Y = randn(5, 10) + Ȳ = randn(size(/(Y, RHS))...) + rrule_test(/, Ȳ, (Y, randn(size(Y))), (RHS, randn(size(RHS)))) + end + + @testset "\\ $T on LHS" for T in (Diagonal, UpperTriangular, LowerTriangular) + LHS = T(randn(T == Diagonal ? 10 : (10, 10))) + y = randn(10) + ȳ = randn(size(\(LHS, y))...) + rrule_test(\, ȳ, (LHS, randn(size(LHS))), (y, randn(10))) + Y = randn(10, 10) + Ȳ = randn(10, 10) + rrule_test(\, Ȳ, (LHS, randn(size(LHS))), (Y, randn(size(Y)))) + end + end + @testset "Diagonal" begin N = 3 rrule_test(Diagonal, randn(N, N), (randn(N), randn(N))) diff --git a/test/runtests.jl b/test/runtests.jl index 740d87bca..154306904 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -22,6 +22,7 @@ println("Testing ChainRules.jl") include(joinpath("rulesets", "Base", "fastmath_able.jl")) include(joinpath("rulesets", "Base", "evalpoly.jl")) include(joinpath("rulesets", "Base", "array.jl")) + include(joinpath("rulesets", "Base", "arraymath.jl")) include(joinpath("rulesets", "Base", "mapreduce.jl")) end