diff --git a/src/rulesets/LinearAlgebra/structured.jl b/src/rulesets/LinearAlgebra/structured.jl index 6b8e0ba5a..af081e950 100644 --- a/src/rulesets/LinearAlgebra/structured.jl +++ b/src/rulesets/LinearAlgebra/structured.jl @@ -274,3 +274,24 @@ function rrule(::typeof(tril), A::AbstractMatrix) end return tril(A), tril_pullback end + +_diag_view(X) = view(X, diagind(X)) +_diag_view(X::Diagonal) = parent(X) #Diagonal wraps a Vector of just Diagonal elements + +function rrule(::typeof(det), X::Union{Diagonal, AbstractTriangular}) + y = det(X) + s = y ./ _diag_view(X) + function det_pullback(ȳ) + return (NO_FIELDS, Diagonal(ȳ .* s)) + end + return y, det_pullback +end + +function rrule(::typeof(logdet), X::Union{Diagonal, AbstractTriangular}) + y = logdet(X) + s = one(eltype(X)) ./ _diag_view(X) + function logdet_pullback(ȳ) + return (NO_FIELDS, Diagonal(ȳ .* s)) + end + return y, logdet_pullback +end diff --git a/test/rulesets/LinearAlgebra/structured.jl b/test/rulesets/LinearAlgebra/structured.jl index a11d484af..044694c0c 100644 --- a/test/rulesets/LinearAlgebra/structured.jl +++ b/test/rulesets/LinearAlgebra/structured.jl @@ -156,4 +156,20 @@ rrule_test(Op, randn(n, n), (randn(n, n), randn(n, n)), (k, nothing)) end end + + @testset "det and logdet $T" for T in (Diagonal, UpperTriangular, LowerTriangular) + n = 5 + # rand (not randn) so det will be postive, so logdet will be defined + X = T(rand(n, n)) + X̄_acc = Diagonal(rand(n, n)) # sensitivity is always a diagonal for these types + @testset "$op" for op in (det, logdet) + rrule_test(op, 2.7, (X, X̄_acc)) + + @testset "return type" begin + _, op_pullback = rrule(op, X) + X̄ = op_pullback(2.7)[2] + @test X̄ isa Diagonal + end + end + end end