Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More 0.7 updates and some fixes #22

Merged
merged 2 commits into from
Jul 2, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 0 additions & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ os:
- linux
- osx
julia:
- 0.6
- nightly
notifications:
email: true
Expand Down
3 changes: 1 addition & 2 deletions REQUIRE
Original file line number Diff line number Diff line change
@@ -1,2 +1 @@
julia 0.6
Compat 0.49
julia 0.7-alpha
41 changes: 22 additions & 19 deletions src/SymWoodburyMatrices.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
import Base:+,*,-,\,^,copy

using Compat.LinearAlgebra.BLAS:gemm!,gemm,axpy!
using Compat.SparseArrays
import Compat.SparseArrays.sparse
using LinearAlgebra.BLAS: gemm!, gemm, axpy!
using SparseArrays
import SparseArrays.sparse

"""
Represents a matrix of the form A + BDBᵀ.
"""
mutable struct SymWoodbury{T,AType, BType, DType} <: AbstractMatrix{T}
A::AType;
B::BType;
D::DType;
struct SymWoodbury{T,AType, BType, DType} <: Factorization{T}
A::AType
B::BType
D::DType
end

"""
Expand Down Expand Up @@ -39,7 +39,7 @@ end
convert(::Type{W}, O::SymWoodbury) where {W<:Woodbury} = Woodbury(O.A, O.B, O.D, O.B')

inv_invD_BtX(invD, B, X) = inv(invD - B'*X);
inv_invD_BtX(invD, B::AbstractVector, X) = inv(invD - vecdot(B,X));
inv_invD_BtX(invD, B::AbstractVector, X) = inv(invD - dot(B,X));

function calc_inv(A, B, D)
W = inv(A);
Expand Down Expand Up @@ -81,7 +81,7 @@ function liftFactorVars(A,B,D)
k = size(B,2)
M = [A B ;
B' -Di ];
M = lufact(M) # ldltfact once it's available.
M = lu(M) # ldltfact once it's available.
return x -> (M\[x; zeros(k,1)])[1:n,:];
end

Expand Down Expand Up @@ -123,18 +123,15 @@ end
# Minor optimization for the rank one case
function plusBDBtx!(o, B::Array{Float64,1}, d::Real, x::Union{Array{Float64,2}, SubArray})
if size(x,2) == 1
axpy!(vecdot(B,x)*d, B, o)
axpy!(dot(B,x)*d, B, o)
else
w = d*gemm('T', 'N' ,reshape(B, size(B,1), 1),x);
gemm!('N','N',1.,B,w,1., o)
end
end

Base.Ac_mul_B(O1::SymWoodbury{T}, x::AbstractVector{T}) where {T} = O1*x
Base.Ac_mul_B(O1::SymWoodbury, x::AbstractMatrix) = O1*x

+(O::SymWoodbury, M::SymWoodbury) = SymWoodbury(O.A + M.A, [O.B M.B],
cat([1,2],O.D,M.D) );
cat(O.D,M.D; dims=(1,2)) );
*(α::Real, O::SymWoodbury) = SymWoodbury(α*O.A, O.B, α*O.D);
*(O::SymWoodbury, α::Real) = SymWoodbury(α*O.A, O.B, α*O.D);
+(M::AbstractMatrix, O::SymWoodbury) = SymWoodbury(O.A + M, O.B, O.D);
Expand Down Expand Up @@ -176,18 +173,24 @@ function *(O1::SymWoodbury, O2::SymWoodbury)
end
end

Base.A_mul_Bc(O1::SymWoodbury, O2::SymWoodbury) = O1*O2

conjm(O::SymWoodbury, M) = SymWoodbury(M*O.A*M', M*O.B, O.D);

Base.getindex(O::SymWoodbury, I::UnitRange, I2::UnitRange) =
SymWoodbury(O.A[I,I], O.B[I,:], O.D);

# This is a slow hack, but generally these matrices aren't sparse.
Compat.SparseArrays.sparse(O::SymWoodbury) = sparse(Matrix(O))
SparseArrays.sparse(O::SymWoodbury) = sparse(Matrix(O))

# returns a pointer to the original matrix, this is consistent with the
# behavior of Symmetric in Base.
Compat.adjoint(O::SymWoodbury) = O
adjoint(O::SymWoodbury) = O

Compat.LinearAlgebra.det(W::SymWoodbury) = det(convert(Woodbury, W))
det(W::SymWoodbury) = det(convert(Woodbury, W))

function show(io::IO, W::SymWoodbury)
println(io, "Symmetric Woodbury factorization:\nA:")
show(io, MIME("text/plain"), W.A)
print(io, "\nB:\n")
Base.print_matrix(IOContext(io,:compact=>true), W.B)
print(io, "\nD: ", W.D)
end
44 changes: 17 additions & 27 deletions src/WoodburyMatrices.jl
Original file line number Diff line number Diff line change
@@ -1,27 +1,19 @@
__precompile__()


module WoodburyMatrices

using Compat
using Compat.LinearAlgebra
import Compat.LinearAlgebra: det, A_ldiv_B!
using LinearAlgebra
import LinearAlgebra: det, ldiv!, mul!, adjoint
import Base: *, \, convert, copy, show, similar, size

# TOOD: remove these definitions once Compat.jl catches up
@static if VERSION <= v"0.7.0-DEV.3185"
const ldiv! = A_ldiv_B!
const mul! = A_mul_B!
end

export Woodbury, SymWoodbury, liftFactor

#### Woodbury matrices ####
"""
`W = Woodbury(A, U, C, V)` creates a matrix `W` identical to `A + U*C*V` whose inverse will be calculated using
the Woodbury matrix identity.
"""
mutable struct Woodbury{T,AType,UType,VType,CType,CpType} <: AbstractMatrix{T}
struct Woodbury{T,AType,UType,VType,CType,CpType} <: Factorization{T}
A::AType
U::UType
C::CType
Expand All @@ -46,10 +38,10 @@ function Woodbury(A, U::AbstractMatrix{T}, C, V::AbstractMatrix{T}) where {T}
end
Cp = inv(convert(Matrix, inv(C) .+ V*(A\U)))
# temporary space for allocation-free solver
tmpN1 = Array{T,1}(uninitialized, N)
tmpN2 = Array{T,1}(uninitialized, N)
tmpk1 = Array{T,1}(uninitialized, k)
tmpk2 = Array{T,1}(uninitialized, k)
tmpN1 = Array{T,1}(undef, N)
tmpN2 = Array{T,1}(undef, N)
tmpk1 = Array{T,1}(undef, k)
tmpk2 = Array{T,1}(undef, k)

# Construct the struct based on the types of the copies,
# not the originals. See: https://github.com/JuliaLang/julia/issues/26294
Expand All @@ -58,27 +50,25 @@ function Woodbury(A, U::AbstractMatrix{T}, C, V::AbstractMatrix{T}) where {T}
end

Woodbury(A, U::Vector{T}, C, V::Matrix{T}) where {T} = Woodbury(A, reshape(U, length(U), 1), C, V)
@static if VERSION <= v"0.7.0-DEV.3040"
Woodbury(A, U::AbstractVector, C, V::RowVector) = Woodbury(A, U, C, Matrix(V))
else
Woodbury(A, U::AbstractVector, C, V::Adjoint) = Woodbury(A, U, C, Matrix(V))
end

Woodbury(A, U::AbstractVector, C, V::Adjoint) = Woodbury(A, U, C, Matrix(V))

size(W::Woodbury) = size(W.A)
size(W::Woodbury, d) = size(W.A, d)

function show(io::IO, W::Woodbury)
println(io, summary(W), ":")
print(io, "A:\n", W.A)
println(io, "Woodbury factorization:\nA:")
show(io, MIME("text/plain"), W.A)
print(io, "\nU:\n")
Base.print_matrix(io, W.U)
Base.print_matrix(IOContext(io, :compact=>true), W.U)
if isa(W.C, Matrix)
print(io, "\nC:\n")
Base.print_matrix(io, W.C)
Base.print_matrix(IOContext(io, :compact=>true), W.C)
else
print(io, "\nC: ", W.C)
end
print(io, "\nV:\n")
Base.print_matrix(io, W.V)
Base.print_matrix(IOContext(io, :compact=>true), W.V)
end

Base.Matrix(W::Woodbury{T}) where {T} = convert(Matrix{T}, W)
Expand All @@ -97,10 +87,10 @@ end

det(W::Woodbury)=det(W.A)*det(W.C)/det(W.Cp)

function A_ldiv_B!(W::Woodbury, B::AbstractVector)
function ldiv!(W::Woodbury, B::AbstractVector)
length(B) == size(W, 1) || throw(DimensionMismatch("Vector length $(length(B)) must match matrix size $(size(W,1))"))
copyto!(W.tmpN1, B)
Alu = lufact(W.A) # Note. This makes an allocation (unless A::LU). Alternative is to destroy W.A.
Alu = lu(W.A) # Note. This makes an allocation (unless A::LU). Alternative is to destroy W.A.
ldiv!(Alu, W.tmpN1)
mul!(W.tmpk1, W.V, W.tmpN1)
mul!(W.tmpk2, W.Cp, W.tmpk1)
Expand Down
9 changes: 3 additions & 6 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
using WoodburyMatrices
using Compat
using Compat.LinearAlgebra
using Compat.Random: srand
using Compat.SparseArrays
using Compat.Test
using LinearAlgebra, SparseArrays, Test
using Random: srand

@testset "WoodburyMatrices" begin
srand(123)
Expand Down Expand Up @@ -56,7 +53,7 @@ for elty in (Float32, Float64, ComplexF32, ComplexF64, Int)
@test abs((det(W) - det(F))/det(F)) <= n*cond(F)*ε # Revisit. Condition number is wrong
iWv = similar(iFv)
if elty != Int
iWv = A_ldiv_B!(W, copy(v))
iWv = ldiv!(W, copy(v))
@test iWv ≈ iFv
end
end
Expand Down
3 changes: 1 addition & 2 deletions test/runtests_sym.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
using Compat
using Compat.Test
using WoodburyMatrices
using Test

srand(123)
n = 5
Expand Down