Skip to content

Commit

Permalink
correct logdet and det for Complex triangular/diagonal
Browse files Browse the repository at this point in the history
  • Loading branch information
oxinabox committed Oct 16, 2020
1 parent 319245d commit 1b6920b
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 9 deletions.
4 changes: 2 additions & 2 deletions src/rulesets/LinearAlgebra/structured.jl
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ _diag_view(X::Diagonal) = parent(X) #Diagonal wraps a Vector of just Diagonal e

function rrule(::typeof(det), X::Union{Diagonal, AbstractTriangular})
y = det(X)
s = y ./ _diag_view(X)
s = conj!(y ./ _diag_view(X))
function det_pullback(ȳ)
return (NO_FIELDS, Diagonal(ȳ .* s))
end
Expand All @@ -289,7 +289,7 @@ end

function rrule(::typeof(logdet), X::Union{Diagonal, AbstractTriangular})
y = logdet(X)
s = one(eltype(X)) ./ _diag_view(X)
s = conj!(one(eltype(X)) ./ _diag_view(X))
function logdet_pullback(ȳ)
return (NO_FIELDS, Diagonal(ȳ .* s))
end
Expand Down
16 changes: 9 additions & 7 deletions test/rulesets/LinearAlgebra/structured.jl
Original file line number Diff line number Diff line change
Expand Up @@ -157,15 +157,17 @@
end
end

@testset "det and logdet $T" for T in (Diagonal, UpperTriangular, LowerTriangular)
n = 5
# rand (not randn) so det will be postive, so logdet will be defined
X = T(3*rand(n, n) .+ 1)
X̄_acc = Diagonal(rand(n, n)) # sensitivity is always a diagonal for these types
@testset "det and logdet $S" for S in (Diagonal, UpperTriangular, LowerTriangular)
@testset "$op" for op in (det, logdet)
rrule_test(op, 2.7, (X, X̄_acc))

@testset "$T" for T in (Float64, ComplexF64)
n = 5
# rand (not randn) so det will be postive, so logdet will be defined
X = S(3*rand(T, (n, n)) .+ 1)
X̄_acc = Diagonal(rand(T, (n, n))) # sensitivity is always a diagonal for these types
rrule_test(op, rand(T), (X, X̄_acc))
end
@testset "return type" begin
X = S(3*rand(6, 6) .+ 1)
_, op_pullback = rrule(op, X)
= op_pullback(2.7)[2]
@testisa Diagonal
Expand Down

0 comments on commit 1b6920b

Please sign in to comment.