Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Specialize triu/tril for toeplitz matrix types #123

Merged
merged 8 commits into from
Jan 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions src/ToeplitzMatrices.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,23 @@ function isdiag(A::AbstractToeplitz)
all(iszero, @view vr[2:end]) && all(iszero, @view vc[2:end])
end

function zero!(v::AbstractVector, inds = eachindex(v))
if eltype(v) <: Number && isconcretetype(eltype(v))
if inds == eachindex(v)
v .= zero(eltype(v))
else
v[inds] .= zero(eltype(v))
end
else
if inds == eachindex(v)
v .= zero.(v)
else
@views v[inds] .= zero.(v[inds])
end
end
return v
end

"""
ToeplitzFactorization

Expand Down
39 changes: 17 additions & 22 deletions src/special.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ function size(A::AbstractToeplitzSingleVector)
end

adjoint(A::AbstractToeplitzSingleVector) = transpose(conj(A))
function zero!(A::AbstractToeplitzSingleVector)
fill!(parent(A), zero(eltype(A)))
function zero!(A::AbstractToeplitzSingleVector, inds = eachindex(parent(A)))
zero!(parent(A), inds)
return A
end

Expand Down Expand Up @@ -201,6 +201,8 @@ for TYPE in (:UpperTriangular, :LowerTriangular)
end
end

_copymutable(A::AbstractToeplitzSingleVector) = basetype(A)(_copymutable(parent(A)))

# Triangular
for TYPE in (:AbstractMatrix, :AbstractVector)
@eval begin
Expand All @@ -218,39 +220,32 @@ for TYPE in (:AbstractMatrix, :AbstractVector)
end

# tril and triu
tril(A::Union{SymmetricToeplitz,Circulant}, k::Integer=0) = tril!(Toeplitz(A),k)
triu(A::Union{SymmetricToeplitz,Circulant}, k::Integer=0) = triu!(Toeplitz(A),k)
function _tridiff!(A::TriangularToeplitz, k::Integer)
i1, iend = firstindex(A.v), lastindex(A.v)
inds = max(i1, k+2):iend
if k >= 0
if isconcretetype(typeof(A.v))
for i in k+2:lastindex(A.v)
A.v[i] = zero(eltype(A))
end
else
A.v = vcat(A.v[1:k+1], zero(A.v[k+2:end]))
end
zero!(A, inds)
else
zero!(A)
end
A
end
tril!(A::UpperTriangularToeplitz, k::Integer) = _tridiff!(A,k)
triu!(A::LowerTriangularToeplitz, k::Integer) = _tridiff!(A,-k)
tril!(A::UpperTriangularToeplitz, k::Integer=0) = _tridiff!(A,k)
triu!(A::LowerTriangularToeplitz, k::Integer=0) = _tridiff!(A,-k)

function _trisame!(A::TriangularToeplitz, k::Integer)
i1, iend = firstindex(A.v), lastindex(A.v)
inds = i1:min(-k,iend)
if k < 0
if isconcretetype(typeof(A.v))
for i in 1:-k
A.v[i] = zero(eltype(A))
end
else
A.v=vcat(A.v[1:-k+1], zero(A.v[-k+2:end]))
end
zero!(A, inds)
end
A
end
tril!(A::LowerTriangularToeplitz, k::Integer) = _trisame!(A,k)
triu!(A::UpperTriangularToeplitz, k::Integer) = _trisame!(A,-k)
tril!(A::LowerTriangularToeplitz, k::Integer=0) = _trisame!(A,k)
triu!(A::UpperTriangularToeplitz, k::Integer=0) = _trisame!(A,-k)

tril(A::TriangularToeplitz, k::Integer=0) = tril!(_copymutable(A), k)
triu(A::TriangularToeplitz, k::Integer=0) = triu!(_copymutable(A), k)

isdiag(A::Union{Circulant, LowerTriangularToeplitz, SymmetricToeplitz}) = all(iszero, @view _vc(A)[2:end])
isdiag(A::UpperTriangularToeplitz) = all(iszero, @view _vr(A)[2:end])
58 changes: 26 additions & 32 deletions src/toeplitz.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ Toeplitz(A::AbstractMatrix) = Toeplitz{eltype(A)}(A)
Toeplitz{T}(A::AbstractMatrix) where {T} = Toeplitz{T}(copy(_vc(A)), copy(_vr(A)))

AbstractToeplitz{T}(A::Toeplitz) where T = Toeplitz{T}(A)
convert(::Type{Toeplitz{T}}, A::AbstractToeplitz) where {T} = Toeplitz{T}(A)
convert(::Type{Toeplitz}, A::AbstractToeplitz) = Toeplitz(A)
convert(::Type{Toeplitz{T}}, A::AbstractToeplitz) where {T} = A isa Toeplitz{T} ? A : Toeplitz{T}(A)
convert(::Type{Toeplitz}, A::AbstractToeplitz) = A isa Toeplitz ? A : Toeplitz(A)

# Retrieve an entry
Base.@propagate_inbounds function getindex(A::AbstractToeplitz, i::Integer, j::Integer)
Expand All @@ -56,53 +56,47 @@ end

checknonaliased(A::Toeplitz) = Base.mightalias(A.vc, A.vr) && throw(ArgumentError("Cannot modify Toeplitz matrices in place with aliased data"))

function _copymutable(v::AbstractVector)
w = similar(v)
w .= v
return w
end
_copymutable(A::Toeplitz) = Toeplitz(_copymutable(A.vc), _copymutable(A.vr))

function tril!(A::Toeplitz, k::Integer=0)
checknonaliased(A)

if k >= 0
if isconcretetype(typeof(A.vr))
for i in k+2:lastindex(A.vr)
A.vr[i] = zero(eltype(A))
end
else
A.vr=vcat(A.vr[1:k+1], zero(A.vr[k+2:end]))
end
i1, iend = firstindex(A.vr), lastindex(A.vr)
inds = max(k+2,i1):iend
zero!(A.vr, inds)
else
fill!(A.vr, zero(eltype(A)))
if isconcretetype(typeof(A.vc))
for i in 1:-k
A.vc[i]=zero(eltype(A))
end
else
A.vc=vcat(zero(A.vc[1:-k]), A.vc[-k+1:end])
end
i1, iend = firstindex(A.vc), lastindex(A.vc)
inds = i1:min(-k,iend)
zero!(A.vr)
zero!(A.vc, inds)
end
A
end
function triu!(A::Toeplitz, k::Integer=0)
checknonaliased(A)

if k <= 0
if isconcretetype(typeof(A.vc))
for i in -k+2:lastindex(A.vc)
A.vc[i] = zero(eltype(A))
end
else
A.vc=vcat(A.vc[1:-k+1], zero(A.vc[-k+2:end]))
end
i1, iend = firstindex(A.vc), lastindex(A.vc)
inds = max(-k+2,i1):iend
zero!(A.vc, inds)
else
fill!(A.vc, zero(eltype(A)))
if isconcretetype(typeof(A.vr))
for i in 1:k
A.vr[i]=zero(eltype(A))
end
else
A.vr=vcat(zero(A.vr[1:k]), A.vr[k+1:end])
end
i1, iend = firstindex(A.vr), lastindex(A.vr)
inds = i1:min(k,iend)
zero!(A.vc)
zero!(A.vr, inds)
end
A
end

tril(A::AbstractToeplitz, k::Integer=0) = tril!(convert(Toeplitz, _copymutable(A)), k)
triu(A::AbstractToeplitz, k::Integer=0) = triu!(convert(Toeplitz, _copymutable(A)), k)

adjoint(A::AbstractToeplitz) = transpose(conj(A))
transpose(A::AbstractToeplitz) = Toeplitz(A.vr, A.vc)
function AbstractMatrix{T}(A::AbstractToeplitz) where {T}
Expand Down
51 changes: 49 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ end
end

@testset "General Interface" begin
for Toep in (:Toeplitz, :Circulant, :SymmetricToeplitz, :UpperTriangularToeplitz, :LowerTriangularToeplitz, :Hankel)
@testset for Toep in (:Toeplitz, :Circulant, :SymmetricToeplitz, :UpperTriangularToeplitz, :LowerTriangularToeplitz, :Hankel)
@eval (A = [1.0 3.0; 3.0 4.0]; TA=$Toep(A); A = Matrix(TA))
@eval (B = [2 1 ; 1 5 ]; TB=$Toep(B); B = Matrix(TB))

Expand Down Expand Up @@ -372,7 +372,54 @@ end
T=copy(TA)
end
@test fill!(Toeplitz(zeros(2,2)),1) == ones(2,2)


@testset "triu/tril for immutable" begin
A = Toeplitz(1:3, 1:4)
M = Matrix(A)
for k in -5:5
@test triu(A, k) == triu(M, k)
@test tril(A, k) == tril(M, k)
end
@testset for T in (Circulant, UpperTriangularToeplitz, LowerTriangularToeplitz, SymmetricToeplitz)
A = T(1:3)
M = Matrix(A)
for k in -5:5
@test triu(A, k) == triu(M, k)
@test tril(A, k) == tril(M, k)
end
end
end

@testset "triu/tril for non-concrete eltype" begin
T = Toeplitz{Union{Float64,ComplexF64}}(Float64.(1:3), Float64.(1:3))
M = Matrix(T)
for k in -5:5
@test tril(T, k) == tril(M, k)
@test triu(T, k) == triu(M, k)
end
@testset for T in (Circulant, SymmetricToeplitz)
A = T{Union{Float64,ComplexF64}}(Float64.(1:3))
M = Matrix(A)
for k in -5:5
@test triu(A, k) == triu(M, k)
@test tril(A, k) == tril(M, k)
end
end

A = UpperTriangularToeplitz{Union{Float64,ComplexF64}}(Float64.(1:3))
@test triu(A) == A
@test triu(A, -1) == A
@test triu(A, 1) == UpperTriangularToeplitz([0,2,3])
@test tril(A, 1) == UpperTriangularToeplitz([1,2,0])
@test tril(A, -1) == UpperTriangularToeplitz(zeros(3))
A = LowerTriangularToeplitz{Union{Float64,ComplexF64}}(Float64.(1:3))
@test tril(A) == A
@test tril(A,1) == A
@test tril(A,-1) == LowerTriangularToeplitz([0,2,3])
@test triu(A, 1) == LowerTriangularToeplitz(zeros(3))
@test triu(A, -1) == LowerTriangularToeplitz([1,2,0])
end

@testset "diag" begin
H = Hankel(1:11, 4, 8)
@test diag(H) ≡ 1:2:7
Expand Down