Skip to content

Commit

Permalink
use add!! and rename _mulsubtrans!!
Browse files Browse the repository at this point in the history
  • Loading branch information
oxinabox committed Oct 5, 2020
1 parent 2be1af2 commit 5c96ffe
Show file tree
Hide file tree
Showing 6 changed files with 23 additions and 33 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
ChainRulesCore = "0.9.7"
ChainRulesCore = "0.9.12"
ChainRulesTestUtils = "0.4.2, 0.5"
Compat = "3"
FiniteDifferences = "0.10"
Expand Down
4 changes: 2 additions & 2 deletions src/rulesets/Base/arraymath.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,8 @@ function rrule(::typeof(\), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:R
∂A = @thunk begin
= A' \
= -* Y'
_add!(Ā, (B - A * Y) *' / A')
_add!(Ā, A' \ Y * (Ȳ' -'A))
= add!!(Ā, (B - A * Y) *' / A')
= add!!(Ā, A' \ Y * (Ȳ' -'A))
end
∂B = @thunk A' \
Expand Down
16 changes: 8 additions & 8 deletions src/rulesets/LinearAlgebra/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -155,12 +155,12 @@ function frule((_, ΔA), ::typeof(pinv), A::AbstractMatrix{T}; kwargs...) where
# contract over the largest dimension
if m n
∂Y = -Y * (ΔA * Y)
_add!(∂Y, (ΔA' - Y * (A * ΔA')) * (Y' * Y)) # (I - Y A) ΔA' Y' Y
_add!(∂Y, Y * (Y' * ΔA') * (I - A * Y)) # Y Y' ΔA' (I - A Y)
∂Y = add!!(∂Y, (ΔA' - Y * (A * ΔA')) * (Y' * Y)) # (I - Y A) ΔA' Y' Y
∂Y = add!!(∂Y, Y * (Y' * ΔA') * (I - A * Y)) # Y Y' ΔA' (I - A Y)
else
∂Y = -(Y * ΔA) * Y
_add!(∂Y, (I - Y * A) * (ΔA' * Y') * Y) # (I - Y A) ΔA' Y' Y
_add!(∂Y, (Y * Y') * (ΔA' - (ΔA' * A) * Y)) # Y Y' ΔA' (I - A Y)
∂Y = add!!(∂Y, (I - Y * A) * (ΔA' * Y') * Y) # (I - Y A) ΔA' Y' Y
∂Y = add!!(∂Y, (Y * Y') * (ΔA' - (ΔA' * A) * Y)) # Y Y' ΔA' (I - A Y)
end
return Y, ∂Y
end
Expand Down Expand Up @@ -199,12 +199,12 @@ function rrule(::typeof(pinv), A::AbstractMatrix{T}; kwargs...) where {T}
# contract over the largest dimension
if m n
∂A = (Y' * -ΔY) * Y'
_add!(∂A, (Y' * Y) * (ΔY' - (ΔY' * Y) * A)) # Y' Y ΔY' (I - Y A)
_add!(∂A, (I - A * Y) * (ΔY' * Y) * Y') # (I - A Y) ΔY' Y Y'
∂A = add!!(∂A, (Y' * Y) * (ΔY' - (ΔY' * Y) * A)) # Y' Y ΔY' (I - Y A)
∂A = add!!(∂A, (I - A * Y) * (ΔY' * Y) * Y') # (I - A Y) ΔY' Y Y'
elseif m > n
∂A = Y' * (-ΔY * Y')
_add!(∂A, Y' * (Y * ΔY') * (I - Y * A)) # Y' Y ΔY' (I - Y A)
_add!(∂A, (ΔY' - A * (Y * ΔY')) * (Y * Y')) # (I - A Y) ΔY' Y Y'
∂A = add!!(∂A, Y' * (Y * ΔY') * (I - Y * A)) # Y' Y ΔY' (I - Y A)
∂A = add!!(∂A, (ΔY' - A * (Y * ΔY')) * (Y * Y')) # (I - A Y) ΔY' Y Y'
end
return (NO_FIELDS, ∂A)
end
Expand Down
14 changes: 7 additions & 7 deletions src/rulesets/LinearAlgebra/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,18 +49,18 @@ function svd_rev(USV::SVD, Ū, s̄, V̄)
# place functions here are significantly faster than their out-of-place, naively
# implemented counterparts, and allocate no additional memory.
Ut = U'
FUᵀŪ = _mulsubtrans!(Ut*Ū, F) # F .* (UᵀŪ - ŪᵀU)
FVᵀV̄ = _mulsubtrans!(Vt*V̄, F) # F .* (VᵀV̄ - V̄ᵀV)
ImUUᵀ = _eyesubx!(U*Ut) # I - UUᵀ
ImVVᵀ = _eyesubx!(V*Vt) # I - VVᵀ
FUᵀŪ = _mulsubtrans!!(Ut*Ū, F) # F .* (UᵀŪ - ŪᵀU)
FVᵀV̄ = _mulsubtrans!!(Vt*V̄, F) # F .* (VᵀV̄ - V̄ᵀV)
ImUUᵀ = _eyesubx!(U*Ut) # I - UUᵀ
ImVVᵀ = _eyesubx!(V*Vt) # I - VVᵀ

S = Diagonal(s)
=isa AbstractZero ?: Diagonal(s̄)

# TODO: consider using MuladdMacro here
Ā = _add!(U * FUᵀŪ * S, ImUUᵀ */ S)) * Vt
Ā = _add!(Ā, U ** Vt)
Ā = _add!(Ā, U * _add!(S * FVᵀV̄ * Vt, (S \') * ImVVᵀ))
Ā = add!!(U * FUᵀŪ * S, ImUUᵀ */ S)) * Vt
Ā = add!!(Ā, U ** Vt)
Ā = add!!(Ā, U * add!!(S * FVᵀV̄ * Vt, (S \') * ImVVᵀ))

return Ā
end
Expand Down
17 changes: 4 additions & 13 deletions src/rulesets/LinearAlgebra/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# to any particular rule definition

# F .* (X - X'), overwrites X if possible
function _mulsubtrans!(X::AbstractMatrix{T}, F::AbstractMatrix{T}) where T<:Real
function _mulsubtrans!!(X::AbstractMatrix{T}, F::AbstractMatrix{T}) where T<:Real
k = size(X, 1)
@inbounds for j = 1:k, i = 1:j # Iterate the upper triangle
if i == j
Expand All @@ -13,9 +13,9 @@ function _mulsubtrans!(X::AbstractMatrix{T}, F::AbstractMatrix{T}) where T<:Real
end
return X
end
_mulsubtrans!(X::AbstractZero, F::AbstractZero) = X
_mulsubtrans!(X::AbstractZero, F::AbstractMatrix{<:Real}) = X
_mulsubtrans!(X::AbstractMatrix{<:Real}, F::AbstractZero) = F
_mulsubtrans!!(X::AbstractZero, F::AbstractZero) = X
_mulsubtrans!!(X::AbstractZero, F::AbstractMatrix{<:Real}) = X
_mulsubtrans!!(X::AbstractMatrix{<:Real}, F::AbstractZero) = F

# I - X, overwrites X
function _eyesubx!(X::AbstractMatrix)
Expand All @@ -26,13 +26,4 @@ function _eyesubx!(X::AbstractMatrix)
return X
end

# X + Y, overwrites X if possible
function _add!(X::AbstractVecOrMat, Y::AbstractVecOrMat)
@inbounds for i = eachindex(X, Y)
X[i] += Y[i]
end
return X
end
_add!(X, Y) = X + Y # handles all `AbstractZero` overloads

_extract_imag(x) = complex(0, imag(x))
3 changes: 1 addition & 2 deletions test/rulesets/LinearAlgebra/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,8 @@ using ChainRules: level2partition, level3partition, chol_blocked_rev, chol_unblo
@testset "Helper functions" begin
X = randn(10, 10)
Y = randn(10, 10)
@test ChainRules._mulsubtrans!(copy(X), Y) Y .* (X - X')
@test ChainRules._mulsubtrans!!(copy(X), Y) Y .* (X - X')
@test ChainRules._eyesubx!(copy(X)) I - X
@test ChainRules._add!(copy(X), Y) X + Y
end
end
@testset "cholesky" begin
Expand Down

0 comments on commit 5c96ffe

Please sign in to comment.