Skip to content

Commit

Permalink
det and logdet for structure matrixes
Browse files Browse the repository at this point in the history
  • Loading branch information
oxinabox committed Oct 13, 2020
1 parent e2cfd58 commit cf5766f
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 0 deletions.
21 changes: 21 additions & 0 deletions src/rulesets/LinearAlgebra/structured.jl
Original file line number Diff line number Diff line change
Expand Up @@ -274,3 +274,24 @@ function rrule(::typeof(tril), A::AbstractMatrix)
end
return tril(A), tril_pullback
end

_diag_view(X) = view(X, diagind(X))
_diag_view(X::Diagonal) = parent(X) #Diagonal wraps a Vector of just Diagonal elements

function rrule(::typeof(det), X::Union{Diagonal, AbstractTriangular})
y = det(X)
s = y ./ _diag_view(X)
function det_pullback(ȳ)
return (NO_FIELDS, Diagonal(ȳ .* s))
end
return y, det_pullback
end

function rrule(::typeof(logdet), X::Union{Diagonal, AbstractTriangular})
y = logdet(X)
s = one(eltype(X)) ./ _diag_view(X)
function logdet_pullback(ȳ)
return (NO_FIELDS, Diagonal(ȳ .* s))
end
return y, logdet_pullback
end
16 changes: 16 additions & 0 deletions test/rulesets/LinearAlgebra/structured.jl
Original file line number Diff line number Diff line change
Expand Up @@ -156,4 +156,20 @@
rrule_test(Op, randn(n, n), (randn(n, n), randn(n, n)), (k, nothing))
end
end

@testset "det and logdet $T" for T in (Diagonal, UpperTriangular, LowerTriangular)
n = 5
# rand (not randn) so det will be postive, so logdet will be defined
X = T(rand(n, n))
X̄_acc = Diagonal(rand(n, n)) # sensitivity is always a diagonal for these types
@testset "$op" for op in (det, logdet)
rrule_test(op, 2.7, (X, X̄_acc))

@testset "return type" begin
_, op_pullback = rrule(op, X)
= op_pullback(2.7)[2]
@testisa Diagonal
end
end
end
end

0 comments on commit cf5766f

Please sign in to comment.