diff --git a/src/rulesets/LinearAlgebra/structured.jl b/src/rulesets/LinearAlgebra/structured.jl index 6b8e0ba5a..c7a43fd7a 100644 --- a/src/rulesets/LinearAlgebra/structured.jl +++ b/src/rulesets/LinearAlgebra/structured.jl @@ -274,3 +274,24 @@ function rrule(::typeof(tril), A::AbstractMatrix) end return tril(A), tril_pullback end + +_diag_view(X) = view(X, diagind(X)) +_diag_view(X::Diagonal) = parent(X) #Diagonal wraps a Vector of just Diagonal elements + +function rrule(::typeof(det), X::Union{Diagonal, AbstractTriangular}) + y = det(X) + s = conj!(y ./ _diag_view(X)) + function det_pullback(ȳ) + return (NO_FIELDS, Diagonal(ȳ .* s)) + end + return y, det_pullback +end + +function rrule(::typeof(logdet), X::Union{Diagonal, AbstractTriangular}) + y = logdet(X) + s = conj!(one(eltype(X)) ./ _diag_view(X)) + function logdet_pullback(ȳ) + return (NO_FIELDS, Diagonal(ȳ .* s)) + end + return y, logdet_pullback +end diff --git a/test/rulesets/LinearAlgebra/structured.jl b/test/rulesets/LinearAlgebra/structured.jl index a11d484af..695266835 100644 --- a/test/rulesets/LinearAlgebra/structured.jl +++ b/test/rulesets/LinearAlgebra/structured.jl @@ -156,4 +156,22 @@ rrule_test(Op, randn(n, n), (randn(n, n), randn(n, n)), (k, nothing)) end end + + @testset "det and logdet $S" for S in (Diagonal, UpperTriangular, LowerTriangular) + @testset "$op" for op in (det, logdet) + @testset "$T" for T in (Float64, ComplexF64) + n = 5 + # rand (not randn) so det will be postive, so logdet will be defined + X = S(3*rand(T, (n, n)) .+ 1) + X̄_acc = Diagonal(rand(T, (n, n))) # sensitivity is always a diagonal for these types + rrule_test(op, rand(T), (X, X̄_acc)) + end + @testset "return type" begin + X = S(3*rand(6, 6) .+ 1) + _, op_pullback = rrule(op, X) + X̄ = op_pullback(2.7)[2] + @test X̄ isa Diagonal + end + end + end end