Skip to content

Commit

Permalink
Merge 1b6920b into f63f385
Browse files Browse the repository at this point in the history
  • Loading branch information
oxinabox committed Oct 16, 2020
2 parents f63f385 + 1b6920b commit 28546fc
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 0 deletions.
21 changes: 21 additions & 0 deletions src/rulesets/LinearAlgebra/structured.jl
Original file line number Diff line number Diff line change
Expand Up @@ -274,3 +274,24 @@ function rrule(::typeof(tril), A::AbstractMatrix)
end
return tril(A), tril_pullback
end

_diag_view(X) = view(X, diagind(X))
_diag_view(X::Diagonal) = parent(X) #Diagonal wraps a Vector of just Diagonal elements

function rrule(::typeof(det), X::Union{Diagonal, AbstractTriangular})
y = det(X)
s = conj!(y ./ _diag_view(X))
function det_pullback(ȳ)
return (NO_FIELDS, Diagonal(ȳ .* s))
end
return y, det_pullback
end

function rrule(::typeof(logdet), X::Union{Diagonal, AbstractTriangular})
y = logdet(X)
s = conj!(one(eltype(X)) ./ _diag_view(X))
function logdet_pullback(ȳ)
return (NO_FIELDS, Diagonal(ȳ .* s))
end
return y, logdet_pullback
end
18 changes: 18 additions & 0 deletions test/rulesets/LinearAlgebra/structured.jl
Original file line number Diff line number Diff line change
Expand Up @@ -156,4 +156,22 @@
rrule_test(Op, randn(n, n), (randn(n, n), randn(n, n)), (k, nothing))
end
end

@testset "det and logdet $S" for S in (Diagonal, UpperTriangular, LowerTriangular)
@testset "$op" for op in (det, logdet)
@testset "$T" for T in (Float64, ComplexF64)
n = 5
# rand (not randn) so det will be postive, so logdet will be defined
X = S(3*rand(T, (n, n)) .+ 1)
X̄_acc = Diagonal(rand(T, (n, n))) # sensitivity is always a diagonal for these types
rrule_test(op, rand(T), (X, X̄_acc))
end
@testset "return type" begin
X = S(3*rand(6, 6) .+ 1)
_, op_pullback = rrule(op, X)
= op_pullback(2.7)[2]
@testisa Diagonal
end
end
end
end

0 comments on commit 28546fc

Please sign in to comment.