Skip to content

Commit

Permalink
Fix conversion from Tridiagonal and Triangular to Matrix
Browse files Browse the repository at this point in the history
  • Loading branch information
blegat committed Dec 12, 2019
1 parent 1bd06e4 commit 864b973
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 3 deletions.
8 changes: 7 additions & 1 deletion src/MutableArithmetics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ the equality of the representation is equivalent to the equality of the objects
begin represented.
"""
isequal_canonical(a, b) = a == b
function isequal_canonical(a::AT, b::AT) where AT <: Union{Array, LinearAlgebra.Symmetric}
function isequal_canonical(a::AT, b::AT) where AT <: Union{Array, LinearAlgebra.Symmetric, LinearAlgebra.UpperTriangular, LinearAlgebra.LowerTriangular}
return all(zip(a, b)) do elements
return isequal_canonical(elements...)
end
Expand All @@ -80,6 +80,12 @@ end
function isequal_canonical(x::LinearAlgebra.Transpose, y::LinearAlgebra.Transpose)
return isequal_canonical(parent(x), parent(y))
end
function isequal_canonical(x::LinearAlgebra.Diagonal, y::LinearAlgebra.Diagonal)
return isequal_canonical(parent(x), parent(y))
end
function isequal_canonical(x::LinearAlgebra.Tridiagonal, y::LinearAlgebra.Tridiagonal)
return isequal_canonical(x.dl, y.dl) && isequal_canonical(x.d, y.d) && isequal_canonical(x.du, y.du)
end

include("rewrite.jl")
include("dispatch.jl")
Expand Down
60 changes: 59 additions & 1 deletion src/Test/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,62 @@ function symmetric_matrix_uniform_scaling_test(x)
end
end

function triangular_test(x)
if !(x isa AbstractMatrix && size(x, 1) == size(x, 2))
return
end
n = LinearAlgebra.checksquare(x)
ut = LinearAlgebra.UpperTriangular(x)
add_test(ut, ut)
y = Matrix(ut)
for i in 1:n
for j in 1:(i - 1)
@test iszero(y[i, j])
@test MA.iszero!(y[i, j])
end
end
lt = LinearAlgebra.LowerTriangular(x)
add_test(lt, lt)
z = Matrix(lt)
for j in 1:n
for i in 1:(j - 1)
@test iszero(z[i, j])
@test MA.iszero!(z[i, j])
end
end
end

function diagonal_test(x)
if !(x isa AbstractVector && MA._one_indexed(x))
return
end
d = LinearAlgebra.Diagonal(x)
add_test(d, d)
y = Matrix(d)
t = LinearAlgebra.Tridiagonal(x[2:end], x, x[2:end])
add_test(t, t)
z = Matrix(t)
for i in eachindex(x)
@test MA.isequal_canonical(y[i, i], convert(eltype(y), x[i]))
@test MA.isequal_canonical(z[i, i], convert(eltype(z), x[i]))
end
n = length(x)
for j in 1:n
for i in 1:(j - 1)
@test iszero(y[i, j])
@test MA.iszero!(y[i, j])
@test iszero(y[j, i])
@test MA.iszero!(y[j, i])
if abs(i - j) > 1
@test iszero(z[i, j])
@test MA.iszero!(z[i, j])
@test iszero(z[j, i])
@test MA.iszero!(z[j, i])
end
end
end
end

const array_tests = Dict(
"matrix_vector_division" => matrix_vector_division_test,
"non_array" => non_array_test,
Expand All @@ -385,7 +441,9 @@ const array_tests = Dict(
"symmetric_unary" => symmetric_unary_test,
"symmetric_add" => symmetric_add_test,
"matrix_uniform_scaling" => matrix_uniform_scaling_test,
"symmetric_matrix_uniform_scaling" => symmetric_matrix_uniform_scaling_test
"symmetric_matrix_uniform_scaling" => symmetric_matrix_uniform_scaling_test,
"triangular" => triangular_test,
"diagonal" => diagonal_test
)

@test_suite array
4 changes: 3 additions & 1 deletion src/Test/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@ function _add_test(x, y)
end
function add_test(x, y)
_add_test(x, y)
_add_test(y, x)
if x !== y
_add_test(y, x)
end
end

function unary_test(x)
Expand Down
12 changes: 12 additions & 0 deletions src/dispatch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -233,3 +233,15 @@ end
function Base.:-(A::LinearAlgebra.Hermitian{<:AbstractMutable})
return LinearAlgebra.Hermitian(-parent(A), LinearAlgebra.sym_uplo(A.uplo))
end

# These three have specific methods that just redirect to `Matrix{T}` which
# does not work, e.g. if `zero(T)` has a different type than `T`.
function Base.Matrix(x::LinearAlgebra.Tridiagonal{T}) where T<:AbstractMutable
return Matrix{promote_type(promote_operation(zero, T), T)}(x)
end
function Base.Matrix(x::LinearAlgebra.UpperTriangular{T}) where T<:AbstractMutable
return Matrix{promote_type(promote_operation(zero, T), T)}(x)
end
function Base.Matrix(x::LinearAlgebra.LowerTriangular{T}) where T<:AbstractMutable
return Matrix{promote_type(promote_operation(zero, T), T)}(x)
end

0 comments on commit 864b973

Please sign in to comment.