Skip to content

Commit

Permalink
Faster eigvecs for SymTridiagonal toeplitz (#108)
Browse files Browse the repository at this point in the history
* Faster eigvecs for SymTridiagonal

* Test for more sizes
  • Loading branch information
jishnub committed Jul 27, 2023
1 parent d71545e commit fadca89
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 6 deletions.
53 changes: 48 additions & 5 deletions src/eigen.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ function _eigvals_toeplitz(T)
return vals
end

_eigvec_eltype(A::Union{SymTridiagonal,
Symmetric{<:Any,<:Tridiagonal}}) = float(eltype(A))
_eigvec_eltype(A) = complex(float(eltype(A)))

_eigvec_prefactor(A, cm1, c1, m) = sqrt(complex(cm1/c1))^m
_eigvec_prefactor(A::Union{SymTridiagonal, Symmetric{<:Any, <:Tridiagonal}}, cm1, c1, m) = oneunit(_eigvec_eltype(A))

Expand All @@ -54,9 +58,6 @@ end
_eigvec_prefactors(A::Union{SymTridiagonal, Symmetric{<:Any, <:Tridiagonal}}, cm1, c1) =
Fill(_eigvec_prefactor(A, cm1, c1, 1), size(A,1))

_eigvec_eltype(A::SymTridiagonal) = float(eltype(A))
_eigvec_eltype(A) = complex(float(eltype(A)))

@static if !isdefined(Base, :eachcol)
eachcol(A) = (view(A,:,i) for i in axes(A,2))
end
Expand All @@ -79,9 +80,9 @@ function _eigvecs_toeplitz(T)
c1 = T[1,2] # superdiagonal
prefactors = _eigvec_prefactors(T, cm1, c1)
for q in axes(M,2)
qrev = n+1-q # match the default eigenvalue sorting
for j in 1:cld(n,2)
M[j, q] = prefactors[j] * sinpi(j*qrev/(n+1))
jphase = 2isodd(j) - 1
M[j, q] = prefactors[j] * jphase * sinpi(j * q/(n+1))
end
phase = iseven(n+q) ? 1 : -1
for j in cld(n,2)+1:n
Expand All @@ -92,6 +93,48 @@ function _eigvecs_toeplitz(T)
return M
end

function _eigvecs_toeplitz(T::Union{SymTridiagonal, Symmetric{<:Any,<:Tridiagonal}})
require_one_based_indexing(T)
n = checksquare(T)
M = Matrix{_eigvec_eltype(T)}(undef, n, n)
n == 0 && return M
n == 1 && return fill!(M, oneunit(eltype(M)))
for q in 1:cld(n,2)
for j in 1:q
jphase = 2isodd(j) - 1
M[j, q] = jphase * sinpi(j * q/(n+1))
end
end
for q in 1:cld(n,2)
for j in q+1:cld(n,2)
qphase = 2isodd(q) - 1
jphase = 2isodd(j) - 1
phase = qphase * jphase
M[j, q] = phase * M[q, j]
end
end
for q in 1:cld(n,2)
phase = iseven(n+q) ? 1 : -1
for j in cld(n,2)+1:n
M[j, q] = phase * M[n+1-j,q]
end
end
for q in cld(n,2)+1:n
for j in 1:cld(n,2)
qphase = 2isodd(q) - 1
jphase = 2isodd(j) - 1
phase = qphase * jphase
M[j, q] = phase * M[q, j]
end
phase = iseven(n+q) ? 1 : -1
for j in cld(n,2)+1:n
M[j, q] = phase * M[n+1-j,q]
end
end
_normalizecols!(M, T)
return M
end

function _eigvecs(A)
n = size(A,1)
if n <= 2 # repeated roots possible
Expand Down
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -584,7 +584,7 @@ end
@testset "eigen" begin
sortby = x -> (real(x), imag(x))
@testset "Tridiagonal Toeplitz" begin
_sizes = (1, 2, 6, 10)
_sizes = (1, 2, 5, 6, 10, 15)
sizes = VERSION >= v"1.6" ? (0, _sizes...) : _sizes
@testset for n in sizes
@testset "Tridiagonal" begin
Expand Down

0 comments on commit fadca89

Please sign in to comment.