Skip to content

Commit

Permalink
Merge 51e52e8 into 0e551fb
Browse files Browse the repository at this point in the history
  • Loading branch information
ararslan committed Jun 11, 2019
2 parents 0e551fb + 51e52e8 commit 7904998
Show file tree
Hide file tree
Showing 7 changed files with 78 additions and 36 deletions.
6 changes: 3 additions & 3 deletions src/ChainRules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@ include("rules.jl")
include("rules/base.jl")
include("rules/array.jl")
include("rules/broadcast.jl")
include("rules/linalg/utils.jl")
include("rules/linalg/blas.jl")
include("rules/linalg/dense.jl")
include("rules/linalg/diagonal.jl")
include("rules/linalg/symmetric.jl")
include("rules/linalg/structured.jl")
include("rules/linalg/factorization.jl")
include("rules/blas.jl")
include("rules/nanmath.jl")
include("rules/specialfunctions.jl")

Expand Down
File renamed without changes.
2 changes: 0 additions & 2 deletions src/rules/linalg/diagonal.jl

This file was deleted.

27 changes: 0 additions & 27 deletions src/rules/linalg/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,33 +59,6 @@ function svd_rev(USV::SVD, Ū::AbstractMatrix, s̄::AbstractVector, V̄::Abstra
return
end

function _mulsubtrans!(X::AbstractMatrix{T}, F::AbstractMatrix{T}) where T<:Real
k = size(X, 1)
@inbounds for j = 1:k, i = 1:j # Iterate the upper triangle
if i == j
X[i,i] = zero(T)
else
X[i,j], X[j,i] = F[i,j] * (X[i,j] - X[j,i]), F[j,i] * (X[j,i] - X[i,j])
end
end
X
end

function _eyesubx!(X::AbstractMatrix{T}) where T<:Real
n, m = size(X)
@inbounds for j = 1:m, i = 1:n
X[i,j] = (i == j) - X[i,j]
end
X
end

function _add!(X::AbstractMatrix{T}, Y::AbstractMatrix{T}) where T<:Real
@inbounds for i = eachindex(X, Y)
X[i] += Y[i]
end
X
end

#####
##### `cholesky`
#####
Expand Down
43 changes: 43 additions & 0 deletions src/rules/linalg/structured.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Structured matrices

#####
##### `Diagonal`
#####

rrule(::Type{<:Diagonal}, d::AbstractVector) = Diagonal(d), Rule(diag)

rrule(::typeof(diag), A::AbstractMatrix) = diag(A), Rule(Diagonal)

#####
##### `Symmetric`
#####

rrule(::Type{<:Symmetric}, A::AbstractMatrix) = Symmetric(A), Rule(_symmetric_back)

_symmetric_back(ΔΩ) = UpperTriangular(ΔΩ) + LowerTriangular(ΔΩ)' - Diagonal(ΔΩ)
_symmetric_back(ΔΩ::Union{Diagonal,UpperTriangular}) = ΔΩ

#####
##### `Adjoint`
#####

# TODO: Deal with complex-valued arrays as well
rrule(::Type{<:Adjoint}, A::AbstractVecOrMat{<:Real}) = Adjoint(A), Rule(adjoint)

rrule(::typeof(adjoint), A::AbstractVecOrMat{<:Real}) = adjoint(A), Rule(adjoint)

#####
##### `Transpose`
#####

rrule(::Type{<:Transpose}, A::AbstractVecOrMat) = Transpose(A), Rule(transpose)

rrule(::typeof(transpose), A::AbstractVecOrMat) = transpose(A), Rule(transpose)

#####
##### Triangular matrices
#####

rrule(::Type{<:UpperTriangular}, A::AbstractMatrix) = UpperTriangular(A), Rule(Matrix)

rrule(::Type{<:LowerTriangular}, A::AbstractMatrix) = LowerTriangular(A), Rule(Matrix)
4 changes: 0 additions & 4 deletions src/rules/linalg/symmetric.jl

This file was deleted.

32 changes: 32 additions & 0 deletions src/rules/linalg/utils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Some utility functions for optimizing linear algebra operations that aren't specific
# to any particular rule definition

# F .* (X - X'), overwrites X
function _mulsubtrans!(X::AbstractMatrix{T}, F::AbstractMatrix{T}) where T<:Real
k = size(X, 1)
@inbounds for j = 1:k, i = 1:j # Iterate the upper triangle
if i == j
X[i,i] = zero(T)
else
X[i,j], X[j,i] = F[i,j] * (X[i,j] - X[j,i]), F[j,i] * (X[j,i] - X[i,j])
end
end
X
end

# I - X, overwrites X
function _eyesubx!(X::AbstractMatrix)
n, m = size(X)
@inbounds for j = 1:m, i = 1:n
X[i,j] = (i == j) - X[i,j]
end
X
end

# X + Y, overwrites X
function _add!(X::AbstractVecOrMat{T}, Y::AbstractVecOrMat{T}) where T<:Real
@inbounds for i = eachindex(X, Y)
X[i] += Y[i]
end
X
end

0 comments on commit 7904998

Please sign in to comment.