Skip to content

Commit

Permalink
Merge c5cd12b into cbf3eb7
Browse files Browse the repository at this point in the history
  • Loading branch information
oxinabox committed Oct 1, 2020
2 parents cbf3eb7 + c5cd12b commit a6fce9d
Show file tree
Hide file tree
Showing 8 changed files with 220 additions and 210 deletions.
1 change: 1 addition & 0 deletions src/ChainRules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ include("rulesets/Base/base.jl")
include("rulesets/Base/fastmath_able.jl")
include("rulesets/Base/evalpoly.jl")
include("rulesets/Base/array.jl")
include("rulesets/Base/arraymath.jl")
include("rulesets/Base/mapreduce.jl")

include("rulesets/Statistics/statistics.jl")
Expand Down
106 changes: 106 additions & 0 deletions src/rulesets/Base/arraymath.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
######
###### `inv`
######

function frule((_, Δx), ::typeof(inv), x::AbstractArray)
Ω = inv(x)
return Ω, -Ω * Δx * Ω
end

function rrule(::typeof(inv), x::AbstractArray)
Ω = inv(x)
function inv_pullback(ΔΩ)
return NO_FIELDS, -Ω' * ΔΩ * Ω'
end
return Ω, inv_pullback
end

#####
##### `*`
#####

function rrule(::typeof(*), A::AbstractMatrix{<:Real}, B::AbstractMatrix{<:Real})
function times_pullback(Ȳ)
return (NO_FIELDS, @thunk(Ȳ * B'), @thunk(A' * Ȳ))
end
return A * B, times_pullback
end

function rrule(::typeof(*), A::Real, B::AbstractArray{<:Real})
function times_pullback(Ȳ)
return (NO_FIELDS, @thunk(dot(Ȳ, B)), @thunk(A * Ȳ))
end
return A * B, times_pullback
end

function rrule(::typeof(*), B::AbstractArray{<:Real}, A::Real)
function times_pullback(Ȳ)
return (NO_FIELDS, @thunk(A * Ȳ), @thunk(dot(Ȳ, B)))
end
return A * B, times_pullback
end



#####
##### `/`
#####

function rrule(::typeof(/), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:Real})
Aᵀ, dA_pb = rrule(adjoint, A)
Bᵀ, dB_pb = rrule(adjoint, B)
Cᵀ, dS_pb = rrule(\, Bᵀ, Aᵀ)
C, dC_pb = rrule(adjoint, Cᵀ)
function slash_pullback(Ȳ)
# Optimization note: dAᵀ, dBᵀ, dC are calculated no matter which partial you want
_, dC = dC_pb(Ȳ)
_, dBᵀ, dAᵀ = dS_pb(unthunk(dC))

∂A = last(dA_pb(unthunk(dAᵀ)))
∂B = last(dA_pb(unthunk(dBᵀ)))

(NO_FIELDS, ∂A, ∂B)
end
return C, slash_pullback
end

#####
##### `\`
#####

function rrule(::typeof(\), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:Real})
Y = A \ B
function backslash_pullback(Ȳ)
∂A = @thunk begin
= A' \
= -* Y'
_add!(Ā, (B - A * Y) *' / A')
_add!(Ā, A' \ Y * (Ȳ' -'A))
end
∂B = @thunk A' \
return NO_FIELDS, ∂A, ∂B
end
return Y, backslash_pullback

end

#####
##### `\`, `/` matrix-scalar_rule
#####

function rrule(::typeof(/), A::AbstractArray{<:Real}, b::Real)
Y = A/b
function slash_pullback(Ȳ)
return (NO_FIELDS, @thunk(Ȳ/b), @thunk(-dot(Ȳ, Y)/b))
end
return Y, slash_pullback
end

function rrule(::typeof(\), b::Real, A::AbstractArray{<:Real})
Y = b\A
function backslash_pullback(Ȳ)
return (NO_FIELDS, @thunk(-dot(Ȳ, Y)/b), @thunk(Ȳ/b))
end
return Y, backslash_pullback
end
132 changes: 0 additions & 132 deletions src/rulesets/LinearAlgebra/dense.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,3 @@
using LinearAlgebra: AbstractTriangular

# Matrix wrapper types that we know are square and are thus potentially invertible. For
# these we can use simpler definitions for `/` and `\`.
const SquareMatrix{T} = Union{Diagonal{T},AbstractTriangular{T}}

#####
##### `dot`
#####
Expand Down Expand Up @@ -36,22 +30,6 @@ function rrule(::typeof(cross), a::AbstractVector{<:Real}, b::AbstractVector{<:R
return Ω, cross_pullback
end

#####
##### `inv`
#####

function frule((_, Δx), ::typeof(inv), x::AbstractArray)
Ω = inv(x)
return Ω, -Ω * Δx * Ω
end

function rrule(::typeof(inv), x::AbstractArray)
Ω = inv(x)
function inv_pullback(ΔΩ)
return NO_FIELDS, -Ω' * ΔΩ * Ω'
end
return Ω, inv_pullback
end

#####
##### `det`
Expand Down Expand Up @@ -138,51 +116,6 @@ function rrule(::typeof(tr), x)
end


#####
##### `*`
#####

function rrule(::typeof(*), A::AbstractMatrix{<:Real}, B::AbstractMatrix{<:Real})
function times_pullback(Ȳ)
return (NO_FIELDS, @thunk(Ȳ * B'), @thunk(A' * Ȳ))
end
return A * B, times_pullback
end

function rrule(::typeof(*), A::Real, B::AbstractArray{<:Real})
function times_pullback(Ȳ)
return (NO_FIELDS, @thunk(dot(Ȳ, B)), @thunk(A * Ȳ))
end
return A * B, times_pullback
end

function rrule(::typeof(*), B::AbstractArray{<:Real}, A::Real)
function times_pullback(Ȳ)
return (NO_FIELDS, @thunk(A * Ȳ), @thunk(dot(Ȳ, B)))
end
return A * B, times_pullback
end

#####
##### `\`, `/` matrix-scalar_rule

function rrule(::typeof(/), A::AbstractArray{<:Real}, b::Real)
Y = A/b
function slash_pullback(Ȳ)
return (NO_FIELDS, @thunk(Ȳ/b), @thunk(-dot(Ȳ, Y)/b))
end
return Y, slash_pullback
end

function rrule(::typeof(\), b::Real, A::AbstractArray{<:Real})
Y = b\A
function backslash_pullback(Ȳ)
return (NO_FIELDS, @thunk(-dot(Ȳ, Y)/b), @thunk(Ȳ/b))
end
return Y, backslash_pullback
end


#####
##### `pinv`
#####
Expand Down Expand Up @@ -278,71 +211,6 @@ function rrule(::typeof(pinv), A::AbstractMatrix{T}; kwargs...) where {T}
return Y, pinv_pullback
end

#####
##### `/`
#####

function rrule(::typeof(/), A::AbstractMatrix{<:Real}, B::T) where T<:SquareMatrix{<:Real}
Y = A / B
function slash_pullback(Ȳ)
S = T.name.wrapper
∂A = @thunk/ B'
∂B = @thunk S(-Y' * (Ȳ / B'))
return (NO_FIELDS, ∂A, ∂B)
end
return Y, slash_pullback
end

function rrule(::typeof(/), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:Real})
Aᵀ, dA_pb = rrule(adjoint, A)
Bᵀ, dB_pb = rrule(adjoint, B)
Cᵀ, dS_pb = rrule(\, Bᵀ, Aᵀ)
C, dC_pb = rrule(adjoint, Cᵀ)
function slash_pullback(Ȳ)
# Optimization note: dAᵀ, dBᵀ, dC are calculated no matter which partial you want
_, dC = dC_pb(Ȳ)
_, dBᵀ, dAᵀ = dS_pb(unthunk(dC))

∂A = last(dA_pb(unthunk(dAᵀ)))
∂B = last(dA_pb(unthunk(dBᵀ)))

(NO_FIELDS, ∂A, ∂B)
end
return C, slash_pullback
end

#####
##### `\`
#####

function rrule(::typeof(\), A::T, B::AbstractVecOrMat{<:Real}) where T<:SquareMatrix{<:Real}
Y = A \ B
function backslash_pullback(Ȳ)
S = T.name.wrapper
∂A = @thunk S(-(A' \ Ȳ) * Y')
∂B = @thunk A' \
return NO_FIELDS, ∂A, ∂B
end
return Y, backslash_pullback
end

function rrule(::typeof(\), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:Real})
Y = A \ B
function backslash_pullback(Ȳ)
∂A = @thunk begin
= A' \
= -* Y'
_add!(Ā, (B - A * Y) *' / A')
_add!(Ā, A' \ Y * (Ȳ' -'A))
end
∂B = @thunk A' \
return NO_FIELDS, ∂A, ∂B
end
return Y, backslash_pullback

end

#####
##### `norm`
#####
Expand Down
26 changes: 26 additions & 0 deletions src/rulesets/LinearAlgebra/structured.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,31 @@
# Structured matrices
using LinearAlgebra: AbstractTriangular

# Matrix wrapper types that we know are square and are thus potentially invertible. For
# these we can use simpler definitions for `/` and `\`.
const SquareMatrix{T} = Union{Diagonal{T}, AbstractTriangular{T}}

function rrule(::typeof(/), A::AbstractMatrix{<:Real}, B::T) where T<:SquareMatrix{<:Real}
Y = A / B
function slash_pullback(Ȳ)
S = T.name.wrapper
∂A = @thunk/ B'
∂B = @thunk S(-Y' * (Ȳ / B'))
return (NO_FIELDS, ∂A, ∂B)
end
return Y, slash_pullback
end

function rrule(::typeof(\), A::T, B::AbstractVecOrMat{<:Real}) where T<:SquareMatrix{<:Real}
Y = A \ B
function backslash_pullback(Ȳ)
S = T.name.wrapper
∂A = @thunk S(-(A' \ Ȳ) * Y')
∂B = @thunk A' \
return NO_FIELDS, ∂A, ∂B
end
return Y, backslash_pullback
end

#####
##### `Diagonal`
Expand Down
66 changes: 66 additions & 0 deletions test/rulesets/Base/arraymath.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
@testset "arraymath" begin
@testset "inv(::Matrix{$T})" for T in (Float64, ComplexF64)
N = 3
B = generate_well_conditioned_matrix(T, N)
frule_test(inv, (B, randn(T, N, N)))
rrule_test(inv, randn(T, N, N), (B, randn(T, N, N)))
end

@testset "*" begin
@testset "Matrix-Matrix" begin
dims = [3,4,5]
for n in dims, m in dims, p in dims
# don't need to test square case multiple times
n > 3 && n == m == p && continue
A = randn(m, n)
B = randn(n, p)
= randn(m, p)
rrule_test(*, Ȳ, (A, randn(m, n)), (B, randn(n, p)))
end
end
@testset "Scalar-AbstractArray" begin
for dims in ((3,), (5,4), (10,10), (2,3,4), (2,3,4,5))
rrule_test(*, randn(dims), (1.5, 4.2), (randn(dims), randn(dims)))
rrule_test(*, randn(dims), (randn(dims), randn(dims)), (1.5, 4.2))
end
end
end

@testset "$f" for f in (/, \)
@testset "Matrix" begin
for n in 3:5, m in 3:5
A = randn(m, n)
B = randn(m, n)
= randn(size(f(A, B)))
rrule_test(f, Ȳ, (A, randn(m, n)), (B, randn(m, n)))
end
end
@testset "Vector" begin
x = randn(10)
y = randn(10)
= randn(size(f(x, y))...)
rrule_test(f, ȳ, (x, randn(10)), (y, randn(10)))
end
if f == (\)
@testset "Matrix $f Vector" begin
X = randn(10, 4)
y = randn(10)
= randn(size(f(X, y))...)
rrule_test(f, ȳ, (X, randn(size(X))), (y, randn(10)))
end
@testset "Vector $f Matrix" begin
x = randn(10)
Y = randn(10, 4)
= randn(size(f(x, Y))...)
rrule_test(f, ȳ, (x, randn(size(x))), (Y, randn(size(Y))))
end
end
end
@testset "/ and \\ Scalar-AbstractArray" begin
A = randn(3, 4, 5)
= randn(3, 4, 5)
= randn(3, 4, 5)
rrule_test(/, Ȳ, (A, Ā), (7.2, 2.3))
rrule_test(\, Ȳ, (7.2, 2.3), (A, Ā))
end
end
Loading

0 comments on commit a6fce9d

Please sign in to comment.