Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move Base's Array math into the Base ruleset folder #271

Merged
merged 3 commits into from
Oct 1, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/ChainRules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ include("rulesets/Base/base.jl")
include("rulesets/Base/fastmath_able.jl")
include("rulesets/Base/evalpoly.jl")
include("rulesets/Base/array.jl")
include("rulesets/Base/arraymath.jl")
include("rulesets/Base/mapreduce.jl")

include("rulesets/Statistics/statistics.jl")
Expand Down
128 changes: 128 additions & 0 deletions src/rulesets/Base/arraymath.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
######
###### `inv`
######

# Forward-mode rule for `inv`.
# Returns the primal `inv(x)` and its tangent -inv(x) * Δx * inv(x).
# The first entry of the tangent tuple (belonging to the function itself) is ignored.
function frule((_, Δx), ::typeof(inv), x::AbstractArray)
    xinv = inv(x)
    ∂xinv = -xinv * Δx * xinv
    return xinv, ∂xinv
end

# Reverse-mode rule for `inv`.
# Returns the primal Ω = inv(x) and a pullback that maps an output cotangent ΔΩ
# to the input cotangent -Ω' * ΔΩ * Ω' (no cotangent for the function itself).
function rrule(::typeof(inv), x::AbstractArray)
    xinv = inv(x)
    inv_pullback(ΔΩ) = (NO_FIELDS, -xinv' * ΔΩ * xinv')
    return xinv, inv_pullback
end

#####
##### `*`
#####

# Reverse-mode rule for real matrix-matrix multiplication `A * B`.
# The pullback lazily (via `@thunk`) produces Ȳ * B' for `A` and A' * Ȳ for `B`.
function rrule(::typeof(*), A::AbstractMatrix{<:Real}, B::AbstractMatrix{<:Real})
    Y = A * B
    function times_pullback(Ȳ)
        ∂A = @thunk Ȳ * B'
        ∂B = @thunk A' * Ȳ
        return NO_FIELDS, ∂A, ∂B
    end
    return Y, times_pullback
end

# Reverse-mode rule for `scalar * array` with real entries.
# The scalar's cotangent is dot(Ȳ, B); the array's cotangent is A * Ȳ.
# Both are wrapped in `@thunk` so they are only computed on demand.
function rrule(::typeof(*), A::Real, B::AbstractArray{<:Real})
    Y = A * B
    function times_pullback(Ȳ)
        ∂A = @thunk dot(Ȳ, B)
        ∂B = @thunk A * Ȳ
        return NO_FIELDS, ∂A, ∂B
    end
    return Y, times_pullback
end

# Reverse-mode rule for `array * scalar` with real entries.
# Note the argument naming: the array is `B` (first argument) and the scalar is `A`
# (second argument). Cotangents are returned in argument order: A * Ȳ for the
# array, then dot(Ȳ, B) for the scalar.
function rrule(::typeof(*), B::AbstractArray{<:Real}, A::Real)
    Y = A * B
    function times_pullback(Ȳ)
        ∂array = @thunk A * Ȳ
        ∂scalar = @thunk dot(Ȳ, B)
        return NO_FIELDS, ∂array, ∂scalar
    end
    return Y, times_pullback
end



#####
##### `/`
#####

# Reverse-mode rule for `A / B` where `B` is a `SquareMatrix` (a type defined
# elsewhere in the package; the tests use Diagonal, UpperTriangular and
# LowerTriangular on the right-hand side).
# The pullback returns Ȳ / B' for `A`, and -Y' * (Ȳ / B') re-wrapped in the
# parameter-free wrapper type of `B` for `B`.
function rrule(::typeof(/), A::AbstractMatrix{<:Real}, B::T) where T<:SquareMatrix{<:Real}
    Y = A / B
    function slash_pullback(Ȳ)
        # `T.name.wrapper` is the type of `B` without its type parameters.
        S = T.name.wrapper
        Ā = @thunk Ȳ / B'
        B̄ = @thunk S(-Y' * (Ȳ / B'))
        return NO_FIELDS, Ā, B̄
    end
    return Y, slash_pullback
end

# Reverse-mode rule for the generic real `A / B`.
# The primal is computed as (B' \ A')', composing the existing rules for
# `adjoint` and `\`; the pullback chains their pullbacks in reverse order.
function rrule(::typeof(/), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:Real})
    Aᵀ, dA_pb = rrule(adjoint, A)
    Bᵀ, dB_pb = rrule(adjoint, B)
    Cᵀ, dS_pb = rrule(\, Bᵀ, Aᵀ)
    C, dC_pb = rrule(adjoint, Cᵀ)
    function slash_pullback(Ȳ)
        # Optimization note: dAᵀ, dBᵀ, dC are calculated no matter which partial you want
        _, dC = dC_pb(Ȳ)
        # `\` was called as Bᵀ \ Aᵀ, so its cotangents come back in (Bᵀ, Aᵀ) order.
        _, dBᵀ, dAᵀ = dS_pb(unthunk(dC))

        ∂A = last(dA_pb(unthunk(dAᵀ)))
        # Use B's own adjoint pullback (previously dA_pb was used here by mistake,
        # leaving dB_pb unused).
        ∂B = last(dB_pb(unthunk(dBᵀ)))

        return (NO_FIELDS, ∂A, ∂B)
    end
    return C, slash_pullback
end

#####
##### `\`
#####

# Reverse-mode rule for `A \ B` where `A` is a `SquareMatrix` (a type defined
# elsewhere in the package; the tests use Diagonal, UpperTriangular and
# LowerTriangular on the left-hand side).
# The pullback returns -(A' \ Ȳ) * Y' re-wrapped in the parameter-free wrapper
# type of `A` for `A`, and A' \ Ȳ for `B`.
function rrule(::typeof(\), A::T, B::AbstractVecOrMat{<:Real}) where T<:SquareMatrix{<:Real}
    Y = A \ B
    function backslash_pullback(Ȳ)
        # `T.name.wrapper` is the type of `A` without its type parameters.
        S = T.name.wrapper
        Ā = @thunk S(-(A' \ Ȳ) * Y')
        B̄ = @thunk A' \ Ȳ
        return (NO_FIELDS, Ā, B̄)
    end
    return Y, backslash_pullback
end

# Reverse-mode rule for the generic real `A \ B`.
# The cotangent of `B` is A' \ Ȳ. The cotangent of `A` starts from -(A' \ Ȳ) * Y'
# and has two further terms accumulated onto it with `_add!`.
# NOTE(review): `_add!` is defined elsewhere; this code relies on it returning the
# accumulated array, since its result is the value of the thunk — confirm.
function rrule(::typeof(\), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:Real})
    Y = A \ B
    function backslash_pullback(Ȳ)
        A_cotangent = @thunk begin
            solved = A' \ Ȳ
            acc = -solved * Y'
            # Term involving the residual B - A * Y.
            _add!(acc, (B - A * Y) * solved' / A')
            _add!(acc, A' \ Y * (Ȳ' - solved' * A))
        end
        B_cotangent = @thunk A' \ Ȳ
        return (NO_FIELDS, A_cotangent, B_cotangent)
    end
    return Y, backslash_pullback
end

#####
##### `\`, `/` matrix-scalar_rule
#####

# Reverse-mode rule for `array / scalar` with real entries.
# The array's cotangent is Ȳ / b; the scalar's cotangent is -dot(Ȳ, Y) / b,
# where Y = A / b is the primal result.
function rrule(::typeof(/), A::AbstractArray{<:Real}, b::Real)
    Y = A/b
    function slash_pullback(Ȳ)
        ∂A = @thunk Ȳ/b
        ∂b = @thunk -dot(Ȳ, Y)/b
        return NO_FIELDS, ∂A, ∂b
    end
    return Y, slash_pullback
end

# Reverse-mode rule for `scalar \ array` with real entries.
# The scalar's cotangent is -dot(Ȳ, Y) / b and the array's cotangent is Ȳ / b,
# where Y = b \ A is the primal result.
function rrule(::typeof(\), b::Real, A::AbstractArray{<:Real})
    Y = b\A
    function backslash_pullback(Ȳ)
        ∂b = @thunk -dot(Ȳ, Y)/b
        ∂A = @thunk Ȳ/b
        return NO_FIELDS, ∂b, ∂A
    end
    return Y, backslash_pullback
end
126 changes: 0 additions & 126 deletions src/rulesets/LinearAlgebra/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,22 +36,6 @@ function rrule(::typeof(cross), a::AbstractVector{<:Real}, b::AbstractVector{<:R
return Ω, cross_pullback
end

#####
##### `inv`
#####

# Forward-mode rule for `inv`: returns inv(x) and the tangent -inv(x) * Δx * inv(x).
# The tangent of the function itself (first tuple entry) is ignored.
function frule((_, Δx), ::typeof(inv), x::AbstractArray)
    xinv = inv(x)
    ∂xinv = -xinv * Δx * xinv
    return xinv, ∂xinv
end

# Reverse-mode rule for `inv`: returns Ω = inv(x) and a pullback mapping the
# output cotangent ΔΩ to -Ω' * ΔΩ * Ω'.
function rrule(::typeof(inv), x::AbstractArray)
    xinv = inv(x)
    inv_pullback(ΔΩ) = (NO_FIELDS, -xinv' * ΔΩ * xinv')
    return xinv, inv_pullback
end

#####
##### `det`
Expand Down Expand Up @@ -138,51 +122,6 @@ function rrule(::typeof(tr), x)
end


#####
##### `*`
#####

# Reverse-mode rule for real matrix-matrix `A * B`.
# Cotangents (lazy, via `@thunk`): Ȳ * B' for `A` and A' * Ȳ for `B`.
function rrule(::typeof(*), A::AbstractMatrix{<:Real}, B::AbstractMatrix{<:Real})
    Y = A * B
    function times_pullback(Ȳ)
        ∂A = @thunk Ȳ * B'
        ∂B = @thunk A' * Ȳ
        return NO_FIELDS, ∂A, ∂B
    end
    return Y, times_pullback
end

# Reverse-mode rule for real `scalar * array`.
# Cotangents (lazy, via `@thunk`): dot(Ȳ, B) for the scalar, A * Ȳ for the array.
function rrule(::typeof(*), A::Real, B::AbstractArray{<:Real})
    Y = A * B
    function times_pullback(Ȳ)
        ∂A = @thunk dot(Ȳ, B)
        ∂B = @thunk A * Ȳ
        return NO_FIELDS, ∂A, ∂B
    end
    return Y, times_pullback
end

# Reverse-mode rule for real `array * scalar`.
# Here the array is named `B` (first argument) and the scalar `A` (second).
# Cotangents are in argument order: A * Ȳ for the array, dot(Ȳ, B) for the scalar.
function rrule(::typeof(*), B::AbstractArray{<:Real}, A::Real)
    Y = A * B
    function times_pullback(Ȳ)
        ∂array = @thunk A * Ȳ
        ∂scalar = @thunk dot(Ȳ, B)
        return NO_FIELDS, ∂array, ∂scalar
    end
    return Y, times_pullback
end

#####
##### `\`, `/` matrix-scalar_rule

# Reverse-mode rule for real `array / scalar`.
# With Y = A / b: the array's cotangent is Ȳ / b, the scalar's is -dot(Ȳ, Y) / b.
function rrule(::typeof(/), A::AbstractArray{<:Real}, b::Real)
    Y = A/b
    function slash_pullback(Ȳ)
        ∂A = @thunk Ȳ/b
        ∂b = @thunk -dot(Ȳ, Y)/b
        return NO_FIELDS, ∂A, ∂b
    end
    return Y, slash_pullback
end

# Reverse-mode rule for real `scalar \ array`.
# With Y = b \ A: the scalar's cotangent is -dot(Ȳ, Y) / b, the array's is Ȳ / b.
function rrule(::typeof(\), b::Real, A::AbstractArray{<:Real})
    Y = b\A
    function backslash_pullback(Ȳ)
        ∂b = @thunk -dot(Ȳ, Y)/b
        ∂A = @thunk Ȳ/b
        return NO_FIELDS, ∂b, ∂A
    end
    return Y, backslash_pullback
end


#####
##### `pinv`
#####
Expand Down Expand Up @@ -278,71 +217,6 @@ function rrule(::typeof(pinv), A::AbstractMatrix{T}; kwargs...) where {T}
return Y, pinv_pullback
end

#####
##### `/`
#####

# Reverse-mode rule for `A / B` with a `SquareMatrix` right-hand side
# (`SquareMatrix` is defined elsewhere in the package).
# Cotangents: Ȳ / B' for `A`; -Y' * (Ȳ / B') re-wrapped in the parameter-free
# wrapper type of `B` for `B`.
function rrule(::typeof(/), A::AbstractMatrix{<:Real}, B::T) where T<:SquareMatrix{<:Real}
    Y = A / B
    function slash_pullback(Ȳ)
        # `T.name.wrapper` is the type of `B` without its type parameters.
        S = T.name.wrapper
        Ā = @thunk Ȳ / B'
        B̄ = @thunk S(-Y' * (Ȳ / B'))
        return NO_FIELDS, Ā, B̄
    end
    return Y, slash_pullback
end

# Reverse-mode rule for the generic real `A / B`, computed as (B' \ A')' by
# composing the rules for `adjoint` and `\` and chaining their pullbacks.
function rrule(::typeof(/), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:Real})
    Aᵀ, dA_pb = rrule(adjoint, A)
    Bᵀ, dB_pb = rrule(adjoint, B)
    Cᵀ, dS_pb = rrule(\, Bᵀ, Aᵀ)
    C, dC_pb = rrule(adjoint, Cᵀ)
    function slash_pullback(Ȳ)
        # Optimization note: dAᵀ, dBᵀ, dC are calculated no matter which partial you want
        _, dC = dC_pb(Ȳ)
        # `\` was called as Bᵀ \ Aᵀ, so its cotangents come back in (Bᵀ, Aᵀ) order.
        _, dBᵀ, dAᵀ = dS_pb(unthunk(dC))

        ∂A = last(dA_pb(unthunk(dAᵀ)))
        # Use B's own adjoint pullback (dA_pb was used here by mistake before,
        # leaving dB_pb unused).
        ∂B = last(dB_pb(unthunk(dBᵀ)))

        return (NO_FIELDS, ∂A, ∂B)
    end
    return C, slash_pullback
end

#####
##### `\`
#####

# Reverse-mode rule for `A \ B` with a `SquareMatrix` left-hand side
# (`SquareMatrix` is defined elsewhere in the package).
# Cotangents: -(A' \ Ȳ) * Y' re-wrapped in the parameter-free wrapper type of
# `A` for `A`; A' \ Ȳ for `B`.
function rrule(::typeof(\), A::T, B::AbstractVecOrMat{<:Real}) where T<:SquareMatrix{<:Real}
    Y = A \ B
    function backslash_pullback(Ȳ)
        # `T.name.wrapper` is the type of `A` without its type parameters.
        S = T.name.wrapper
        Ā = @thunk S(-(A' \ Ȳ) * Y')
        B̄ = @thunk A' \ Ȳ
        return (NO_FIELDS, Ā, B̄)
    end
    return Y, backslash_pullback
end

# Reverse-mode rule for the generic real `A \ B`.
# `B` receives A' \ Ȳ. `A` receives -(A' \ Ȳ) * Y' plus two extra terms that are
# accumulated with `_add!`.
# NOTE(review): `_add!` is defined elsewhere; the thunk's value is the result of
# the final `_add!` call, so it presumably returns the accumulated array — confirm.
function rrule(::typeof(\), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:Real})
    Y = A \ B
    function backslash_pullback(Ȳ)
        A_cotangent = @thunk begin
            solved = A' \ Ȳ
            acc = -solved * Y'
            # Term involving the residual B - A * Y.
            _add!(acc, (B - A * Y) * solved' / A')
            _add!(acc, A' \ Y * (Ȳ' - solved' * A))
        end
        B_cotangent = @thunk A' \ Ȳ
        return (NO_FIELDS, A_cotangent, B_cotangent)
    end
    return Y, backslash_pullback
end

#####
##### `norm`
#####
Expand Down
82 changes: 82 additions & 0 deletions test/rulesets/Base/arraymath.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# Tests for the array-math rules: `inv`, `*`, `/` and `\`.
# NOTE(review): `frule_test`, `rrule_test` and `generate_well_conditioned_matrix`
# are defined outside this file; presumably the *_test helpers compare the rules
# against numerical derivatives — confirm against the test utilities.
# In the calls below each primal argument is paired with a random array of the
# same shape, and `rrule_test` additionally receives a random array shaped like
# the output as its second argument.
@testset "arraymath" begin
# `inv` is checked for both real and complex element types.
@testset "inv(::Matrix{$T})" for T in (Float64, ComplexF64)
N = 3
B = generate_well_conditioned_matrix(T, N)
frule_test(inv, (B, randn(T, N, N)))
rrule_test(inv, randn(T, N, N), (B, randn(T, N, N)))
end

@testset "*" begin
@testset "Matrix-Matrix" begin
dims = [3,4,5]
# Sweep over all (m × n) * (n × p) shape combinations drawn from `dims`.
for n in dims, m in dims, p in dims
# don't need to test square case multiple times
n > 3 && n == m == p && continue
A = randn(m, n)
B = randn(n, p)
Ȳ = randn(m, p)
rrule_test(*, Ȳ, (A, randn(m, n)), (B, randn(n, p)))
end
end
@testset "Scalar-AbstractArray" begin
# Arrays of rank 1 through 4, with the scalar on the left and then on the right.
for dims in ((3,), (5,4), (10,10), (2,3,4), (2,3,4,5))
rrule_test(*, randn(dims), (1.5, 4.2), (randn(dims), randn(dims)))
rrule_test(*, randn(dims), (randn(dims), randn(dims)), (1.5, 4.2))
end
end
end

# The same set of cases is run for both division operators.
@testset "$f" for f in [/, \]
@testset "Matrix" begin
for n in 3:5, m in 3:5
A = randn(m, n)
B = randn(m, n)
# The output shape depends on `f`, so it is taken from an actual primal call.
Ȳ = randn(size(f(A, B)))
rrule_test(f, Ȳ, (A, randn(m, n)), (B, randn(m, n)))
end
end
@testset "Vector" begin
x = randn(10)
y = randn(10)
ȳ = randn(size(f(x, y))...)
rrule_test(f, ȳ, (x, randn(10)), (y, randn(10)))
end
# Structured matrices go on the right-hand side for `/` and on the left-hand
# side for `\`.
if f == (/)
@testset "$T on the RHS" for T in (Diagonal, UpperTriangular, LowerTriangular)
# Diagonal is built from a length-10 vector, the triangular types from a 10×10 matrix.
RHS = T(randn(T == Diagonal ? 10 : (10, 10)))
Y = randn(5, 10)
Ȳ = randn(size(f(Y, RHS))...)
rrule_test(f, Ȳ, (Y, randn(size(Y))), (RHS, randn(size(RHS))))
end
else
@testset "$T on LHS" for T in (Diagonal, UpperTriangular, LowerTriangular)
# Diagonal is built from a length-10 vector, the triangular types from a 10×10 matrix.
LHS = T(randn(T == Diagonal ? 10 : (10, 10)))
# Vector right-hand side.
y = randn(10)
ȳ = randn(size(f(LHS, y))...)
rrule_test(f, ȳ, (LHS, randn(size(LHS))), (y, randn(10)))
# Matrix right-hand side.
Y = randn(10, 10)
Ȳ = randn(10, 10)
rrule_test(f, Ȳ, (LHS, randn(size(LHS))), (Y, randn(size(Y))))
end
# The mixed matrix/vector cases below sit in the `else` branch, so they only run for `\`.
@testset "Matrix $f Vector" begin
X = randn(10, 4)
y = randn(10)
ȳ = randn(size(f(X, y))...)
rrule_test(f, ȳ, (X, randn(size(X))), (y, randn(10)))
end
@testset "Vector $f Matrix" begin
x = randn(10)
Y = randn(10, 4)
ȳ = randn(size(f(x, Y))...)
rrule_test(f, ȳ, (x, randn(size(x))), (Y, randn(size(Y))))
end
end
end
# Division of a rank-3 array by a scalar, in both operator orientations.
@testset "/ and \\ Scalar-AbstractArray" begin
A = randn(3, 4, 5)
Ā = randn(3, 4, 5)
Ȳ = randn(3, 4, 5)
rrule_test(/, Ȳ, (A, Ā), (7.2, 2.3))
rrule_test(\, Ȳ, (7.2, 2.3), (A, Ā))
end
end
Loading