Skip to content

Commit

Permalink
add Circulant math features (#43)
Browse files Browse the repository at this point in the history
* add Circulant math features

* add more tests
  • Loading branch information
vincentcp authored and andreasnoack committed Jan 31, 2019
1 parent cbe29c3 commit 8273041
Show file tree
Hide file tree
Showing 2 changed files with 128 additions and 3 deletions.
51 changes: 48 additions & 3 deletions src/ToeplitzMatrices.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ module ToeplitzMatrices
using Compat, StatsBase, Compat.LinearAlgebra


import Base: convert, *, \, getindex, print_matrix, size, Matrix
import Compat.LinearAlgebra: BlasReal, DimensionMismatch, tril, triu, inv,
import Base: convert, *, \, getindex, print_matrix, size, Matrix, +, -, copy, similar, sqrt
import Compat.LinearAlgebra: BlasReal, DimensionMismatch, tril, triu, inv, pinv, eigvals,
cholesky, cholesky!
import Compat: copyto!

Expand All @@ -18,7 +18,7 @@ if VERSION < v"0.7-"
else
using FFTW
using FFTW: Plan
import LinearAlgebra: mul!, ldiv!
import LinearAlgebra: mul!, ldiv!, eigvals, pinv
flipdim(A, d) = reverse(A, dims=d)
end

Expand Down Expand Up @@ -384,6 +384,51 @@ function chan(A::AbstractMatrix{T}) where T
return Circulant(v)
end

function pinv(C::Circulant{T}, tolerance::T = eps(T)) where T<:Real
vdft = copy(C.vcvr_dft)
vdft[abs.(vdft).<tolerance] .= Inf
vdft .= 1 ./ vdft
return Circulant(real(C.dft \ vdft), copy(vdft), similar(vdft), C.dft)
end

function pinv(C::Circulant{T}, tolerance::Real = eps(real(T))) where T<:Number
vdft = copy(C.vcvr_dft)
vdft[abs.(vdft).<tolerance] .= Inf
vdft .= 1 ./ vdft
return Circulant(C.dft \ vdft, copy(vdft), similar(vdft), C.dft)
end

eigvals(C::Circulant) = copy(C.vcvr_dft)
sqrt(C::Circulant{T}) where T<:Real = Circulant(real(ifft(sqrt.(C.vcvr_dft))))
sqrt(C::Circulant) = Circulant(ifft(sqrt.(C.vcvr_dft)))
copy(C::Circulant) = Circulant(copy(C.vc))
similar(C::Circulant) = Circulant(similar(C.vc))
function copyto!(dest::Circulant{U,S}, src::Circulant{V,S}) where {U,V,S}
copyto!(dest.vc, src.vc)
copyto!(dest.vcvr_dft, src.vcvr_dft)
end

function (+)(C1::Circulant, C2::Circulant)
@boundscheck (size(C1)==size(C2)) || throw(BoundsError())
Circulant(C1.vc+C2.vc)
end

function (-)(C1::Circulant, C2::Circulant)
@boundscheck (size(C1)==size(C2)) || throw(BoundsError())
Circulant(C1.vc-C2.vc)
end

(-)(C::Circulant) = Circulant(-C.vc)

function (*)(C1::Circulant, C2::Circulant)
@boundscheck (size(C1)==size(C2)) || throw(BoundsError())
Circulant(ifft(C1.vcvr_dft.*C2.vcvr_dft))
end

(*)(scalar::Number, C::Circulant) = Circulant(scalar*C.vc)
(*)(C::Circulant,scalar::Number) = Circulant(scalar*C.vc)


# Triangular
mutable struct TriangularToeplitz{T<:Number,S<:Number} <: AbstractToeplitz{T}
ve::Vector{T}
Expand Down
80 changes: 80 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ if VERSION < v"0.7-"
const ldiv! = A_ldiv_B!
end

using Compat: copyto!

ns = 101
nl = 2000

Expand Down Expand Up @@ -222,10 +224,88 @@ end
@test Hankel(Hankel(A)) == Hankel(A)
end

@testset "Circulant mathematics" begin
C1 = Circulant(rand(5))
C2 = Circulant(rand(5))
C3 = Circulant{ComplexF64}(rand(5))
C4 = Circulant(ones(5))
C5 = Circulant{ComplexF64}(ones(5))
M1 = Matrix(C1)
M2 = Matrix(C2)
M3 = Matrix(C3)
M4 = Matrix(C4)
M5 = Matrix(C5)

C = C1*C2
@test C isa Circulant
@test C M1*M2

C = C1-C2
@test C isa Circulant
@test C M1-M2

C = C1+C2
@test C isa Circulant
@test C M1+M2

C = 2C1
@test C isa Circulant
@test C 2M1

C = C1*2
@test C isa Circulant
@test C M1*2

C = -C1
@test C isa Circulant
@test C -M1

C = inv(C1)
@test C isa Circulant
@test C inv(M1)
C = inv(C3)
@test C isa Circulant
@test C inv(M3)

C = pinv(C1)
@test C isa Circulant
@test C pinv(M1)
C = pinv(C3)
@test C isa Circulant
@test C pinv(M3)
C = pinv(C4)
@test C isa Circulant
@test C pinv(M4)
C = pinv(C5)
@test C isa Circulant
@test C pinv(M5)

C = sqrt(C1)
@test C isa Circulant
@test C*C C1
C = sqrt(C3)
@test C isa Circulant
@test C*C C3

C = copy(C1)
@test C isa Circulant
C2 = similar(C1)
copyto!(C2, C1)
@test C1 C2

v1 = eigvals(C1)
v2 = eigvals(M1)
for v1i in v1
@test minimum(abs.(v1i .- v2)) < sqrt(eps(Float64))
end
end

if VERSION v"0.7"
@testset "Cholesky" begin
T = SymmetricToeplitz(exp.(-0.5 .* range(0, stop=5, length=100)))
@test cholesky(T).U cholesky(Matrix(T)).U
@test cholesky(T).L cholesky(Matrix(T)).L
end


end

0 comments on commit 8273041

Please sign in to comment.