diff --git a/src/SymWoodburyMatrices.jl b/src/SymWoodburyMatrices.jl index 038d2ab..9f3c4be 100644 --- a/src/SymWoodburyMatrices.jl +++ b/src/SymWoodburyMatrices.jl @@ -6,8 +6,8 @@ using Base.LinAlg.BLAS:gemm!,gemm,axpy! Represents a matrix of the form A + BDBᵀ. """ type SymWoodbury{T,AType, BType, DType} <: AbstractMatrix{T} - A::AType; - B::BType; + A::AType; + B::BType; D::DType; end @@ -28,7 +28,7 @@ end function SymWoodbury{T}(A, B::AbstractVector{T}, D::T) n = size(A, 1) k = 1 - if size(A, 2) != n || length(B) != n + if size(A, 2) != n || length(B) != n throw(DimensionMismatch("Sizes of B ($(size(B))) and/or D ($(size(D))) are inconsistent with A ($(size(A)))")) end SymWoodbury{T,typeof(A),typeof(B),typeof(D)}(A,B,D) @@ -47,14 +47,14 @@ function calc_inv(A, B, D) SymWoodbury(W,X,Z); end -Base.inv{T<:Any, AType<:Any, BType<:AbstractVector, DType<:Real}(O::SymWoodbury{T,AType,BType,DType}) = +Base.inv{T<:Any, AType<:Any, BType<:AbstractVector, DType<:Real}(O::SymWoodbury{T,AType,BType,DType}) = calc_inv(O.A, O.B, O.D) -Base.inv{T<:Any, AType<:Any, BType<:Any, DType<:AbstractMatrix}(O::SymWoodbury{T,AType,BType,DType}) = +Base.inv{T<:Any, AType<:Any, BType<:Any, DType<:AbstractMatrix}(O::SymWoodbury{T,AType,BType,DType}) = calc_inv(O.A, O.B, O.D) -# D is typically small, so this is acceptable. -Base.inv{T<:Any, AType<:Any, BType<:Any, DType<:SparseMatrixCSC}(O::SymWoodbury{T,AType,BType,DType}) = +# D is typically small, so this is acceptable. +Base.inv{T<:Any, AType<:Any, BType<:Any, DType<:SparseMatrixCSC}(O::SymWoodbury{T,AType,BType,DType}) = calc_inv(O.A, O.B, full(O.D)); \(W::SymWoodbury, R::StridedVecOrMat) = inv(W)*R @@ -100,7 +100,7 @@ on evaluation, i.e. `liftFactor(A)(x)` is the same as `inv(A)*x`. """ liftFactor(O::SymWoodbury) = liftFactorVars(O.A,O.B,O.D) -function *{T}(O::SymWoodbury{T}, x::Union{Matrix,Vector,SubArray}) +function *{T}(O::SymWoodbury{T}, x::Union{Matrix,Vector,SubArray}) o = O.A*x; plusBDBtx!(o, O.B, O.D, x) return o @@ -112,7 +112,7 @@ end plusBDBtx!(o, B::AbstractVector, D, x) = plusBDBtx!(o, reshape(B,size(B,1),1),D,x) -# Optimization - use specialized BLAS package +# Optimization - use specialized BLAS package function plusBDBtx!(o, B::Array{Float64,2}, D, x::Array{Float64,2}) w = D*gemm('T','N',B,x); gemm!('N','N',1.,B,w,1., o) @@ -124,12 +124,12 @@ function plusBDBtx!(o, B::Array{Float64,1}, d::Real, x::Union{Array{Float64,2}, axpy!(vecdot(B,x)*d, B, o) else w = d*gemm('T', 'N' ,reshape(B, size(B,1), 1),x); - gemm!('N','N',1.,B,w,1., o) + gemm!('N','N',1.,B,w,1., o) end end Base.Ac_mul_B{T}(O1::SymWoodbury{T}, x::AbstractVector{T}) = O1*x -Base.Ac_mul_B{T}(O1::SymWoodbury{T}, x::AbstractMatrix{T}) = O1*x +Base.Ac_mul_B(O1::SymWoodbury, x::AbstractMatrix) = O1*x +(O::SymWoodbury, M::SymWoodbury) = SymWoodbury(O.A + M.A, [O.B M.B], cat([1,2],O.D,M.D) ); @@ -148,7 +148,7 @@ function square(O::SymWoodbury) AB = O.A*O.B Z = [(AB + O.B) (AB - O.B)] R = O.D*(O.B'*O.B)*O.D/4 - D = [ O.D/2 + R -R + D = [ O.D/2 + R -R -R -O.D/2 + R ] SymWoodbury(A, Z, D) end @@ -159,11 +159,11 @@ except when they are the same, i.e. the user writes A'A or A*A' or A*A. Z(A + B*D*Bᵀ) = ZA + ZB*D*Bᵀ -This package will not support support left multiplication by a generic +This package will not support support left multiplication by a generic matrix, to keep return types consistent. """ function *(O1::SymWoodbury, O2::SymWoodbury) - if (O1 === O2) + if (O1 === O2) return square(O1) else if O1.A == O2.A && O1.B == O2.B && O1.D == O2.D @@ -174,7 +174,6 @@ function *(O1::SymWoodbury, O2::SymWoodbury) end end -Base.Ac_mul_B(O1::SymWoodbury, O2::SymWoodbury) = O1*O2 Base.A_mul_Bc(O1::SymWoodbury, O2::SymWoodbury) = O1*O2 conjm(O::SymWoodbury, M) = SymWoodbury(M*O.A*M', M*O.B, O.D); @@ -187,6 +186,6 @@ Base.sparse(O::SymWoodbury) = sparse(full(O)) # returns a pointer to the original matrix, this is consistent with the # behavior of Symmetric in Base. -Base.ctranspose(O::SymWoodbury) = O +Base.ctranspose(O::SymWoodbury) = O -Base.det(W::SymWoodbury) = det(convert(Woodbury, W)) \ No newline at end of file +Base.det(W::SymWoodbury) = det(convert(Woodbury, W)) diff --git a/src/WoodburyMatrices.jl b/src/WoodburyMatrices.jl index 5468087..f2d8750 100644 --- a/src/WoodburyMatrices.jl +++ b/src/WoodburyMatrices.jl @@ -38,15 +38,16 @@ function Woodbury{T}(A, U::AbstractMatrix{T}, C, V::AbstractMatrix{T}) end Cp = inv(inv(C) + V*(A\U)) # temporary space for allocation-free solver - tmpN1 = Array(T, N) - tmpN2 = Array(T, N) - tmpk1 = Array(T, k) - tmpk2 = Array(T, k) + tmpN1 = Array{T,1}(N) + tmpN2 = Array{T,1}(N) + tmpk1 = Array{T,1}(k) + tmpk2 = Array{T,1}(k) # don't copy A, it could be huge Woodbury{T,typeof(A),typeof(U),typeof(V),typeof(C),typeof(Cp)}(A, copy(U), copy(C), Cp, copy(V), tmpN1, tmpN2, tmpk1, tmpk2) end Woodbury{T}(A, U::Vector{T}, C, V::Matrix{T}) = Woodbury(A, reshape(U, length(U), 1), C, V) +@static if isdefined(:RowVector) Woodbury(A, U::AbstractVector, C, V::RowVector) = Woodbury(A, U, C, Matrix(V)) end size(W::Woodbury) = size(W.A) diff --git a/test/runtests.jl b/test/runtests.jl index c902f1e..4b64253 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -45,14 +45,14 @@ for elty in (Float32, Float64, Complex64, Complex128, Int) # Woodbury W = Woodbury(T, U, C, V) F = full(W) - @test_approx_eq W*v F*v + @test W*v ≈ F*v iFv = F\v @test norm(W\v - iFv)/norm(iFv) <= n*cond(F)*ε # Revisit. Condition number is wrong @test abs((det(W) - det(F))/det(F)) <= n*cond(F)*ε # Revisit. Condition number is wrong iWv = similar(iFv) if elty != Int iWv = A_ldiv_B!(W, copy(v)) - @test_approx_eq iWv iFv + @test iWv ≈ iFv end end @@ -110,7 +110,7 @@ for i in 1:5 # repeat 5 times by_hand[i, j] = v end - @test maxabs(out - by_hand) == 0.0 + @test maximum(abs,out - by_hand) == 0.0 end include("runtests_sym.jl") diff --git a/test/runtests_sym.jl b/test/runtests_sym.jl index 2858f5c..bbc40d3 100644 --- a/test/runtests_sym.jl +++ b/test/runtests_sym.jl @@ -25,53 +25,51 @@ for elty in (Float32, Float64, Complex64, Complex128, Int), AMat in (diagm,) ε = eps(abs2(float(one(elty)))) A = AMat(a) - + # Woodbury for W in (SymWoodbury(A, B, D), SymWoodbury(A, B[:,1][:], 2.)) F = full(W) - - @test_approx_eq (2*W)*v 2*(W*v) - @test_approx_eq W'*v W*v - @test_approx_eq (W'W)*v full(W)*(full(W)*v) - @test_approx_eq (W*W)*v full(W)*(full(W)*v) - @test_approx_eq (W*W')*v full(W)*(full(W)*v) - @test_approx_eq W[1:3,1:3]*v[1:3] full(W)[1:3,1:3]*v[1:3] - @test_approx_eq sparse(W) full(W) + @test (2*W)*v ≈ 2*(W*v) + @test W'*v ≈ W*v + @test (W'W)*v ≈ full(W)*(full(W)*v) + @test (W*W)*v ≈ full(W)*(full(W)*v) + @test (W*W')*v ≈ full(W)*(full(W)*v) + @test W[1:3,1:3]*v[1:3] ≈ full(W)[1:3,1:3]*v[1:3] + @test sparse(W) ≈ full(W) @test W === W' - @test_approx_eq W*eye(n) full(W) - @test_approx_eq W'*eye(n) full(W) + @test W*eye(n) ≈ full(W) + @test W'*eye(n) ≈ full(W) Z = randn(n,n) - @test_approx_eq full(W*Z) full(W)*Z + @test full(W*Z) ≈ full(W)*Z R = rand(n,n) for v = (rand(n, 1), view(rand(n,1), 1:n), view(rand(n,2),1:n,1:2)) - @test_approx_eq (2*W)*v 2*(W*v) - @test_approx_eq (W*2)*v 2*(W*v) - @test_approx_eq (W'W)*v full(W)*(full(W)*v) - @test_approx_eq (W*W)*v full(W)*(full(W)*v) - @test_approx_eq (W*W')*v full(W)*(full(W)*v) - @test_approx_eq W[1:3,1:3]*v[1:3] full(W)[1:3,1:3]*v[1:3] - @test_approx_eq full(WoodburyMatrices.conjm(W, R)) R*full(W)*R' - @test_approx_eq full((copy(W)'W)*v) full(W)*(full(W)*v) - @test_approx_eq full(W + A) full(W)+full(A) - @test_approx_eq full(A + W) full(W)+full(A) + @test (2*W)*v ≈ 2*(W*v) + @test (W*2)*v ≈ 2*(W*v) + @test (W'W)*v ≈ full(W)*(full(W)*v) + @test (W*W)*v ≈ full(W)*(full(W)*v) + @test (W*W')*v ≈ full(W)*(full(W)*v) + @test W[1:3,1:3]*v[1:3] ≈ full(W)[1:3,1:3]*v[1:3] + @test full(WoodburyMatrices.conjm(W, R)) ≈ R*full(W)*R' + @test full((copy(W)'W)*v) ≈ full(W)*(full(W)*v) + @test full(W + A) ≈ full(W)+full(A) + @test full(A + W) ≈ full(W)+full(A) end v = rand(n,1) - W2 = convert(Woodbury, W) - @test_approx_eq full(W2) full(W) + @test full(W2) ≈ full(W) if elty != Int - @test_approx_eq inv(W)*v inv(full(W))*v - @test_approx_eq W\v inv(full(W))*v - @test_approx_eq liftFactor(W)(v) inv(W)*v - @test_approx_eq WoodburyMatrices.partialInv(W)[1] inv(W).B - @test_approx_eq WoodburyMatrices.partialInv(W)[2] inv(W).D - @test_approx_eq det(W) det(full(W)) + @test inv(W)*v ≈ inv(full(W))*v + @test W\v ≈ inv(full(W))*v + @test liftFactor(W)(v) ≈ inv(W)*v + @test WoodburyMatrices.partialInv(W)[1] ≈ inv(W).B + @test WoodburyMatrices.partialInv(W)[2] ≈ inv(W).D + @test det(W) ≈ det(full(W)) end end @@ -84,11 +82,11 @@ for elty in (Float32, Float64, Complex64, Complex128, Int) elty = Float64 a1 = rand(n); B1 = rand(n,2); D1 = rand(2,2); v = rand(n) - a2 = rand(n); B2 = rand(n,2); D2 = rand(2,2); + a2 = rand(n); B2 = rand(n,2); D2 = rand(2,2); if elty == Int v = rand(1:100, n) - + a1 = rand(1:100, n) B1 = rand(1:100, 2, n) D1 = rand(1:100, 2, 2) @@ -109,7 +107,7 @@ for elty in (Float32, Float64, Complex64, Complex128, Int) end ε = eps(abs2(float(one(elty)))) - + # Woodbury A1 = diagm(a1) A2 = diagm(a2) @@ -121,9 +119,9 @@ for elty in (Float32, Float64, Complex64, Complex128, Int) W2r = SymWoodbury(A2, B2[:,1][:], 3.) for (W1, W2) = ((W1,W2), (W1r, W2), (W1, W2r), (W1r,W2r)) - @test_approx_eq (W1 + W2)*v (full(W1) + full(W2))*v - @test_approx_eq (full(W1) + W2)*v (full(W1) + full(W2))*v - @test_approx_eq (W1 + 2*diagm(a1))*v (full(W1) + full(2*diagm(a1)))*v + @test (W1 + W2)*v ≈ (full(W1) + full(W2))*v + @test (full(W1) + W2)*v ≈ (full(W1) + full(W2))*v + @test (W1 + 2*diagm(a1))*v ≈ (full(W1) + full(2*diagm(a1)))*v @test_throws MethodError W1*W2 end @@ -142,18 +140,18 @@ V = randn(n,1) @test size(W,1) == n @test size(W,2) == n -@test_approx_eq inv(W)*v inv(full(W))*v -@test_approx_eq (2*W)*v 2*(W*v) -@test_approx_eq (W'W)*v full(W)*(full(W)*v) -@test_approx_eq (W*W)*v full(W)*(full(W)*v) -@test_approx_eq (W*W')*v full(W)*(full(W)*v) -@test_approx_eq liftFactor(W)(v) inv(W)*v - -@test_approx_eq inv(W)*V inv(full(W))*V -@test_approx_eq (2*W)*V 2*(W*V) -@test_approx_eq (W'W)*V full(W)*(full(W)*V) -@test_approx_eq (W*W)*V full(W)*(full(W)*V) -@test_approx_eq (W*W')*V full(W)*(full(W)*V) +@test inv(W)*v ≈ inv(full(W))*v +@test (2*W)*v ≈ 2*(W*v) +@test (W'W)*v ≈ full(W)*(full(W)*v) +@test (W*W)*v ≈ full(W)*(full(W)*v) +@test (W*W')*v ≈ full(W)*(full(W)*v) +@test liftFactor(W)(v) ≈ inv(W)*v + +@test inv(W)*V ≈ inv(full(W))*V +@test (2*W)*V ≈ 2*(W*V) +@test (W'W)*V ≈ full(W)*(full(W)*V) +@test (W*W)*V ≈ full(W)*(full(W)*V) +@test (W*W')*V ≈ full(W)*(full(W)*V) # Mismatched sizes @test_throws DimensionMismatch SymWoodbury(rand(5,5),rand(5,2),rand(2,3))