diff --git a/Project.toml b/Project.toml index cbcb4e7..d22bca3 100644 --- a/Project.toml +++ b/Project.toml @@ -9,6 +9,12 @@ FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" +[weakdeps] +StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" + +[extensions] +ToeplitzMatricesStatsBaseExt = "StatsBase" + [compat] AbstractFFTs = "0.4, 0.5, 1" Aqua = "0.6" @@ -23,7 +29,8 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Aqua", "FFTW", "Pkg", "Random", "Test"] +test = ["Aqua", "FFTW", "Pkg", "Random", "StatsBase", "Test"] diff --git a/ext/ToeplitzMatricesStatsBaseExt.jl b/ext/ToeplitzMatricesStatsBaseExt.jl new file mode 100644 index 0000000..b6aa746 --- /dev/null +++ b/ext/ToeplitzMatricesStatsBaseExt.jl @@ -0,0 +1,26 @@ +module ToeplitzMatricesStatsBaseExt + +using ToeplitzMatrices +using StatsBase + +function StatsBase.levinson(A::AbstractToeplitz, B::AbstractVecOrMat) + StatsBase.levinson!(zeros(size(B)), A, copy(B)) +end + +# extend levinson +function StatsBase.levinson!(x::StridedVector, A::SymmetricToeplitz, b::StridedVector) + StatsBase.levinson!(A.vc, b, x) +end + +function StatsBase.levinson!(C::StridedMatrix, A::SymmetricToeplitz, B::StridedMatrix) + n = size(B, 2) + if n != size(C, 2) + throw(DimensionMismatch("input and output matrices must have same number of columns")) + end + for j = 1:n + StatsBase.levinson!(view(C, :, j), A, view(B, :, j)) + end + C +end + +end diff --git a/src/ToeplitzMatrices.jl b/src/ToeplitzMatrices.jl index 6ca3f71..b96ffa7 100644 --- a/src/ToeplitzMatrices.jl +++ b/src/ToeplitzMatrices.jl @@ -1,5 +1,4 @@ module ToeplitzMatrices -# import StatsBase: levinson!, levinson import DSP: conv import Base: adjoint, convert, transpose, size, getindex, similar, copy, getproperty, inv, sqrt, copyto!, reverse, conj, zero, fill!, checkbounds, real, imag, isfinite, DimsInteger, iszero @@ -15,7 +14,6 @@ import LinearAlgebra: issymmetric, ishermitian import LinearAlgebra: eigvals, eigvecs, eigen import AbstractFFTs: Plan, plan_fft! -import StatsBase using FillArrays using LinearAlgebra @@ -92,4 +90,8 @@ maybereal(::Type{<:Real}, x) = real(x) include("directLinearSolvers.jl") +if !isdefined(Base, :get_extension) + include("../ext/ToeplitzMatricesStatsBaseExt.jl") +end + end #module diff --git a/src/linearalgebra.jl b/src/linearalgebra.jl index e35aa50..15b84bf 100644 --- a/src/linearalgebra.jl +++ b/src/linearalgebra.jl @@ -135,8 +135,6 @@ function ldiv!(A::Toeplitz, b::StridedVector) copyto!(b, IterativeLinearSolvers.cgs(A, zeros(eltype(b), length(b)), b, preconditioner, 1000, 100eps())[1]) end -StatsBase.levinson(A::AbstractToeplitz, B::AbstractVecOrMat) = StatsBase.levinson!(zeros(size(B)), A, copy(B)) - # SymmetricToeplitz function factorize(A::SymmetricToeplitz{T}) where {T<:Number} @@ -174,26 +172,13 @@ end """ cholesky(T::SymmetricToeplitz) -Implementation of the Bareiss Algorhithm, adapted from "On the stability of the Bareiss and +Implementation of the Bareiss Algorithm, adapted from "On the stability of the Bareiss and related Toeplitz factorization algorithms", Bojanczyk et al, 1993. """ function cholesky(T::SymmetricToeplitz) return cholesky!(Matrix{eltype(T)}(undef, size(T, 1), size(T, 1)), T) end -# extend levinson -StatsBase.levinson!(x::StridedVector, A::SymmetricToeplitz, b::StridedVector) = StatsBase.levinson!(A.vc, b, x) -function StatsBase.levinson!(C::StridedMatrix, A::SymmetricToeplitz, B::StridedMatrix) - n = size(B, 2) - if n != size(C, 2) - throw(DimensionMismatch("input and output matrices must have same number of columns")) - end - for j = 1:n - StatsBase.levinson!(view(C, :, j), A, view(B, :, j)) - end - C -end - # circulant const CirculantFactorization{T, V<:AbstractVector{T}} = ToeplitzFactorization{T,Circulant{T,V}} function factorize(C::Circulant) diff --git a/test/runtests.jl b/test/runtests.jl index fd8e49d..d777a8a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -8,7 +8,10 @@ using FFTW: fft @testset "code quality" begin Aqua.test_ambiguities(ToeplitzMatrices, recursive=false) # Aqua.test_all includes Base and Core in ambiguity testing - Aqua.test_all(ToeplitzMatrices, ambiguities=false, piracy=false) + Aqua.test_all(ToeplitzMatrices, ambiguities=false, piracy=false, + # only test formatting on VERSION >= v1.7 + # https://github.com/JuliaTesting/Aqua.jl/issues/105#issuecomment-1551405866 + project_toml_formatting = VERSION >= v"1.7") end ns = 101