Skip to content

Commit

Permalink
Specialize triu/tril for toeplitz matrix types (#123)
Browse files Browse the repository at this point in the history
* Specialize triu/tril

* Tests for immutable arrays

* Fix indexing and assignment in inplace triu/tril

* Fix indexing for triangular

* Reduce code duplication by reusing zeros!

* Tests for triangular

* Tests for different triangular part

* Split zero! for vectors and use for toeplitz
  • Loading branch information
jishnub committed Jan 8, 2024
1 parent 065ed7e commit 8006e93
Show file tree
Hide file tree
Showing 4 changed files with 109 additions and 56 deletions.
17 changes: 17 additions & 0 deletions src/ToeplitzMatrices.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,23 @@ function isdiag(A::AbstractToeplitz)
all(iszero, @view vr[2:end]) && all(iszero, @view vc[2:end])
end

function zero!(v::AbstractVector, inds = eachindex(v))
if eltype(v) <: Number && isconcretetype(eltype(v))
if inds == eachindex(v)
v .= zero(eltype(v))
else
v[inds] .= zero(eltype(v))
end
else
if inds == eachindex(v)
v .= zero.(v)
else
@views v[inds] .= zero.(v[inds])
end
end
return v
end

"""
ToeplitzFactorization
Expand Down
39 changes: 17 additions & 22 deletions src/special.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ function size(A::AbstractToeplitzSingleVector)
end

adjoint(A::AbstractToeplitzSingleVector) = transpose(conj(A))
function zero!(A::AbstractToeplitzSingleVector)
fill!(parent(A), zero(eltype(A)))
function zero!(A::AbstractToeplitzSingleVector, inds = eachindex(parent(A)))
zero!(parent(A), inds)
return A
end

Expand Down Expand Up @@ -201,6 +201,8 @@ for TYPE in (:UpperTriangular, :LowerTriangular)
end
end

_copymutable(A::AbstractToeplitzSingleVector) = basetype(A)(_copymutable(parent(A)))

# Triangular
for TYPE in (:AbstractMatrix, :AbstractVector)
@eval begin
Expand All @@ -218,39 +220,32 @@ for TYPE in (:AbstractMatrix, :AbstractVector)
end

# tril and triu
tril(A::Union{SymmetricToeplitz,Circulant}, k::Integer=0) = tril!(Toeplitz(A),k)
triu(A::Union{SymmetricToeplitz,Circulant}, k::Integer=0) = triu!(Toeplitz(A),k)
function _tridiff!(A::TriangularToeplitz, k::Integer)
i1, iend = firstindex(A.v), lastindex(A.v)
inds = max(i1, k+2):iend
if k >= 0
if isconcretetype(typeof(A.v))
for i in k+2:lastindex(A.v)
A.v[i] = zero(eltype(A))
end
else
A.v = vcat(A.v[1:k+1], zero(A.v[k+2:end]))
end
zero!(A, inds)
else
zero!(A)
end
A
end
tril!(A::UpperTriangularToeplitz, k::Integer) = _tridiff!(A,k)
triu!(A::LowerTriangularToeplitz, k::Integer) = _tridiff!(A,-k)
tril!(A::UpperTriangularToeplitz, k::Integer=0) = _tridiff!(A,k)
triu!(A::LowerTriangularToeplitz, k::Integer=0) = _tridiff!(A,-k)

function _trisame!(A::TriangularToeplitz, k::Integer)
i1, iend = firstindex(A.v), lastindex(A.v)
inds = i1:min(-k,iend)
if k < 0
if isconcretetype(typeof(A.v))
for i in 1:-k
A.v[i] = zero(eltype(A))
end
else
A.v=vcat(A.v[1:-k+1], zero(A.v[-k+2:end]))
end
zero!(A, inds)
end
A
end
tril!(A::LowerTriangularToeplitz, k::Integer) = _trisame!(A,k)
triu!(A::UpperTriangularToeplitz, k::Integer) = _trisame!(A,-k)
tril!(A::LowerTriangularToeplitz, k::Integer=0) = _trisame!(A,k)
triu!(A::UpperTriangularToeplitz, k::Integer=0) = _trisame!(A,-k)

tril(A::TriangularToeplitz, k::Integer=0) = tril!(_copymutable(A), k)
triu(A::TriangularToeplitz, k::Integer=0) = triu!(_copymutable(A), k)

isdiag(A::Union{Circulant, LowerTriangularToeplitz, SymmetricToeplitz}) = all(iszero, @view _vc(A)[2:end])
isdiag(A::UpperTriangularToeplitz) = all(iszero, @view _vr(A)[2:end])
58 changes: 26 additions & 32 deletions src/toeplitz.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ Toeplitz(A::AbstractMatrix) = Toeplitz{eltype(A)}(A)
Toeplitz{T}(A::AbstractMatrix) where {T} = Toeplitz{T}(copy(_vc(A)), copy(_vr(A)))

AbstractToeplitz{T}(A::Toeplitz) where T = Toeplitz{T}(A)
convert(::Type{Toeplitz{T}}, A::AbstractToeplitz) where {T} = Toeplitz{T}(A)
convert(::Type{Toeplitz}, A::AbstractToeplitz) = Toeplitz(A)
convert(::Type{Toeplitz{T}}, A::AbstractToeplitz) where {T} = A isa Toeplitz{T} ? A : Toeplitz{T}(A)
convert(::Type{Toeplitz}, A::AbstractToeplitz) = A isa Toeplitz ? A : Toeplitz(A)

# Retrieve an entry
Base.@propagate_inbounds function getindex(A::AbstractToeplitz, i::Integer, j::Integer)
Expand All @@ -56,53 +56,47 @@ end

checknonaliased(A::Toeplitz) = Base.mightalias(A.vc, A.vr) && throw(ArgumentError("Cannot modify Toeplitz matrices in place with aliased data"))

function _copymutable(v::AbstractVector)
w = similar(v)
w .= v
return w
end
_copymutable(A::Toeplitz) = Toeplitz(_copymutable(A.vc), _copymutable(A.vr))

function tril!(A::Toeplitz, k::Integer=0)
checknonaliased(A)

if k >= 0
if isconcretetype(typeof(A.vr))
for i in k+2:lastindex(A.vr)
A.vr[i] = zero(eltype(A))
end
else
A.vr=vcat(A.vr[1:k+1], zero(A.vr[k+2:end]))
end
i1, iend = firstindex(A.vr), lastindex(A.vr)
inds = max(k+2,i1):iend
zero!(A.vr, inds)
else
fill!(A.vr, zero(eltype(A)))
if isconcretetype(typeof(A.vc))
for i in 1:-k
A.vc[i]=zero(eltype(A))
end
else
A.vc=vcat(zero(A.vc[1:-k]), A.vc[-k+1:end])
end
i1, iend = firstindex(A.vc), lastindex(A.vc)
inds = i1:min(-k,iend)
zero!(A.vr)
zero!(A.vc, inds)
end
A
end
function triu!(A::Toeplitz, k::Integer=0)
checknonaliased(A)

if k <= 0
if isconcretetype(typeof(A.vc))
for i in -k+2:lastindex(A.vc)
A.vc[i] = zero(eltype(A))
end
else
A.vc=vcat(A.vc[1:-k+1], zero(A.vc[-k+2:end]))
end
i1, iend = firstindex(A.vc), lastindex(A.vc)
inds = max(-k+2,i1):iend
zero!(A.vc, inds)
else
fill!(A.vc, zero(eltype(A)))
if isconcretetype(typeof(A.vr))
for i in 1:k
A.vr[i]=zero(eltype(A))
end
else
A.vr=vcat(zero(A.vr[1:k]), A.vr[k+1:end])
end
i1, iend = firstindex(A.vr), lastindex(A.vr)
inds = i1:min(k,iend)
zero!(A.vc)
zero!(A.vr, inds)
end
A
end

tril(A::AbstractToeplitz, k::Integer=0) = tril!(convert(Toeplitz, _copymutable(A)), k)
triu(A::AbstractToeplitz, k::Integer=0) = triu!(convert(Toeplitz, _copymutable(A)), k)

adjoint(A::AbstractToeplitz) = transpose(conj(A))
transpose(A::AbstractToeplitz) = Toeplitz(A.vr, A.vc)
function AbstractMatrix{T}(A::AbstractToeplitz) where {T}
Expand Down
51 changes: 49 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ end
end

@testset "General Interface" begin
for Toep in (:Toeplitz, :Circulant, :SymmetricToeplitz, :UpperTriangularToeplitz, :LowerTriangularToeplitz, :Hankel)
@testset for Toep in (:Toeplitz, :Circulant, :SymmetricToeplitz, :UpperTriangularToeplitz, :LowerTriangularToeplitz, :Hankel)
@eval (A = [1.0 3.0; 3.0 4.0]; TA=$Toep(A); A = Matrix(TA))
@eval (B = [2 1 ; 1 5 ]; TB=$Toep(B); B = Matrix(TB))

Expand Down Expand Up @@ -372,7 +372,54 @@ end
T=copy(TA)
end
@test fill!(Toeplitz(zeros(2,2)),1) == ones(2,2)


@testset "triu/tril for immutable" begin
A = Toeplitz(1:3, 1:4)
M = Matrix(A)
for k in -5:5
@test triu(A, k) == triu(M, k)
@test tril(A, k) == tril(M, k)
end
@testset for T in (Circulant, UpperTriangularToeplitz, LowerTriangularToeplitz, SymmetricToeplitz)
A = T(1:3)
M = Matrix(A)
for k in -5:5
@test triu(A, k) == triu(M, k)
@test tril(A, k) == tril(M, k)
end
end
end

@testset "triu/tril for non-concrete eltype" begin
T = Toeplitz{Union{Float64,ComplexF64}}(Float64.(1:3), Float64.(1:3))
M = Matrix(T)
for k in -5:5
@test tril(T, k) == tril(M, k)
@test triu(T, k) == triu(M, k)
end
@testset for T in (Circulant, SymmetricToeplitz)
A = T{Union{Float64,ComplexF64}}(Float64.(1:3))
M = Matrix(A)
for k in -5:5
@test triu(A, k) == triu(M, k)
@test tril(A, k) == tril(M, k)
end
end

A = UpperTriangularToeplitz{Union{Float64,ComplexF64}}(Float64.(1:3))
@test triu(A) == A
@test triu(A, -1) == A
@test triu(A, 1) == UpperTriangularToeplitz([0,2,3])
@test tril(A, 1) == UpperTriangularToeplitz([1,2,0])
@test tril(A, -1) == UpperTriangularToeplitz(zeros(3))
A = LowerTriangularToeplitz{Union{Float64,ComplexF64}}(Float64.(1:3))
@test tril(A) == A
@test tril(A,1) == A
@test tril(A,-1) == LowerTriangularToeplitz([0,2,3])
@test triu(A, 1) == LowerTriangularToeplitz(zeros(3))
@test triu(A, -1) == LowerTriangularToeplitz([1,2,0])
end

@testset "diag" begin
H = Hankel(1:11, 4, 8)
@test diag(H) 1:2:7
Expand Down

0 comments on commit 8006e93

Please sign in to comment.