Permalink
Browse files

Improved tridiagonal and new Woodbury matrix algebra

- Converts Tridiagonal to being Lapack-compatible
- Wraps Lapack's major tridiagonal routines
- Adds a new "Woodbury" type for solving equations using the Woodbury matrix identity
- Moves the functionality into appropriate functions in base
- Provides a set of tests for this functionality
  • Loading branch information...
timholy committed Aug 30, 2012
1 parent 2a011e3 commit cbaad366a2a58cc22bfd31838b99327c8329098f
Showing with 594 additions and 108 deletions.
  1. +6 −0 base/export.jl
  2. +45 −0 base/factorizations.jl
  3. +3 −1 base/linalg.jl
  4. +184 −0 base/linalg_lapack.jl
  5. +299 −0 base/linalg_specialized.jl
  6. +1 −0 base/sysimg.jl
  7. +0 −61 extras/linalg_sparse.jl
  8. +0 −45 extras/sparse.jl
  9. +56 −0 test/lapack.jl
  10. +0 −1 test/sparse.jl
View
@@ -81,15 +81,20 @@ export
SubOrDArray,
SubString,
TransformedString,
+ Tridiagonal,
VecOrMat,
Vector,
VersionNumber,
WeakKeyDict,
+ Woodbury,
Zip,
Stat,
Factorization,
Cholesky,
LU,
+ LUTridiagonal,
+ LDLT,
+ LDLTTridiagonal,
QR,
QRP,
@@ -634,6 +639,7 @@ export
randsym,
rank,
rref,
+ solve,
svd,
svdvals,
trace,
View
@@ -281,3 +281,48 @@ end
##ToDo: Add methods for rank(A::QRP{T}) and adjust the (\) method accordingly
## Add rcond methods for Cholesky, LU, QR and QRP types
## Lower priority: Add LQ, QL and RQ factorizations
+
+
+#### Factorizations for Tridiagonal ####
+type LDLTTridiagonal{T} <: Factorization{T}
+ D::Vector{T}
+ E::Vector{T}
+end
+function LDLTTridiagonal{T<:LapackScalar}(A::Tridiagonal{T})
+ D = copy(A.d)
+ E = copy(A.dl)
+ _jl_lapack_pttrf(D, E)
+ LDLTTridiagonal(D, E)
+end
+LDLT(A::Tridiagonal) = LDLTTridiagonal(A)
+
+(\){T<:LapackScalar}(C::LDLTTridiagonal{T}, B::StridedVecOrMat{T}) =
+ _jl_lapack_pttrs(C.D, C.E, copy(B))
+
+type LUTridiagonal{T} <: Factorization{T}
+ lu::Tridiagonal{T}
+ ipiv::Vector{Int32}
+ function LUTridiagonal(lu::Tridiagonal{T}, ipiv::Vector{Int32})
+ m, n = size(lu)
+ m == numel(ipiv) ? new(lu, ipiv) : error("LU: dimension mismatch")
+ end
+end
+show(io, lu::LUTridiagonal) = print(io, "LU decomposition of ", summary(lu.lu))
+
+function LU{T<:LapackScalar}(A::Tridiagonal{T})
+ lu, ipiv = _jl_lapack_gttrf(copy(A))
+ LUTridiagonal{T}(lu, ipiv)
+end
+
+function lu(A::Tridiagonal)
+ error("lu(A) is not defined when A is Tridiagonal. Use LU(A) instead.")
+end
+
+function det(lu::LUTridiagonal)
+ prod(lu.lu.d) * (bool(sum(lu.ipiv .!= 1:n) % 2) ? -1 : 1)
+end
+
+det(A::Tridiagonal) = det(LU(A))
+
+(\){T<:LapackScalar}(lu::LUTridiagonal{T}, B::StridedVecOrMat{T}) =
+ _jl_lapack_gttrs('N', lu.lu, lu.ipiv, copy(B))
View
@@ -1,4 +1,6 @@
-## linalg.jl: Basic Linear Algebra interface specifications ##
+## linalg.jl: Basic Linear Algebra interface specifications and
+## specialized matrix types
+
#
# This file mostly contains commented functions which are supposed
# to be defined in type-specific linalg_<type>.jl files.
View
@@ -885,3 +885,187 @@ end
expm{T<:Union(Float32,Float64,Complex64,Complex128)}(A::StridedMatrix{T}) = expm!(copy(A))
expm{T<:Integer}(A::StridedMatrix{T}) = expm!(float(A))
+
+#### Tridiagonal matrix routines ####
+function \{T<:LapackScalar}(M::Tridiagonal{T}, rhs::StridedVecOrMat{T})
+ if stride(rhs, 1) == 1
+ x = copy(rhs)
+ Mc = copy(M)
+ Mlu, x = _jl_lapack_gtsv(Mc, x)
+ return x
+ end
+ solve(M, rhs) # use the Julia "fallback"
+end
+
+eig(M::Tridiagonal) = _jl_lapack_stev('V', copy(M))
+
+# Decompositions
+for (gttrf, pttrf, elty) in
+ ((:dgttrf_,:dpttrf_,:Float64),
+ (:sgttrf_,:spttrf_,:Float32),
+ (:zgttrf_,:zpttrf_,:Complex128),
+ (:cgttrf_,:cpttrf_,:Complex64))
+ @eval begin
+ function _jl_lapack_gttrf(M::Tridiagonal{$elty})
+ info = zero(Int32)
+ n = int32(length(M.d))
+ ipiv = Array(Int32, n)
+ ccall(dlsym(_jl_liblapack, $string(gttrf)),
+ Void,
+ (Ptr{Int32}, Ptr{$elty}, Ptr{$elty}, Ptr{$elty}, Ptr{$elty},
+ Ptr{Int32}, Ptr{Int32}),
+ &n, M.dl, M.d, M.du, M.dutmp, ipiv, &info)
+ if info != 0 throw(LapackException(info)) end
+ M, ipiv
+ end
+ function _jl_lapack_pttrf(D::Vector{$elty}, E::Vector{$elty})
+ info = zero(Int32)
+ n = int32(length(D))
+ if length(E) != n-1
+ error("subdiagonal must be one element shorter than diagonal")
+ end
+ ccall(dlsym(_jl_liblapack, $string(pttrf)),
+ Void,
+ (Ptr{Int32}, Ptr{$elty}, Ptr{$elty}, Ptr{Int32}),
+ &n, D, E, &info)
+ if info != 0 throw(LapackException(info)) end
+ D, E
+ end
+ end
+end
+# Direct solvers
+for (gtsv, ptsv, elty) in
+ ((:dgtsv_,:dptsv_,:Float64),
+ (:sgtsv_,:sptsv,:Float32),
+ (:zgtsv_,:zptsv,:Complex128),
+ (:cgtsv_,:cptsv,:Complex64))
+ @eval begin
+ function _jl_lapack_gtsv(M::Tridiagonal{$elty}, B::StridedVecOrMat{$elty})
+ if stride(B,1) != 1
+ error("_jl_lapack_gtsv: matrix columns must have contiguous elements");
+ end
+ info = zero(Int32)
+ n = int32(length(M.d))
+ nrhs = int32(size(B, 2))
+ ldb = int32(stride(B, 2))
+ ccall(dlsym(_jl_liblapack, $string(gtsv)),
+ Void,
+ (Ptr{Int32}, Ptr{Int32}, Ptr{$elty}, Ptr{$elty}, Ptr{$elty}, Ptr{$elty},
+ Ptr{Int32}, Ptr{Int32}),
+ &n, &nrhs, M.dl, M.d, M.du, B, &ldb, &info)
+ if info != 0 throw(LapackException(info)) end
+ M, B
+ end
+ function _jl_lapack_ptsv(M::Tridiagonal{$elty}, B::StridedVecOrMat{$elty})
+ if stride(B,1) != 1
+ error("_jl_lapack_ptsv: matrix columns must have contiguous elements");
+ end
+ info = zero(Int32)
+ n = int32(length(M.d))
+ nrhs = int32(size(B, 2))
+ ldb = int32(stride(B, 2))
+ ccall(dlsym(_jl_liblapack, $string(ptsv)),
+ Void,
+ (Ptr{Int32}, Ptr{Int32}, Ptr{$elty}, Ptr{$elty}, Ptr{$elty},
+ Ptr{Int32}, Ptr{Int32}),
+ &n, &nrhs, M.d, M.dl, B, &ldb, &info)
+ if info != 0 throw(LapackException(info)) end
+ M, B
+ end
+ end
+end
+# Solvers using decompositions
+for (gttrs, pttrs, elty) in
+ ((:dgttrs_,:dpttrs_,:Float64),
+ (:sgttrs_,:spttrs,:Float32),
+ (:zgttrs_,:zpttrs,:Complex128),
+ (:cgttrs_,:cpttrs,:Complex64))
+ @eval begin
+ function _jl_lapack_gttrs(trans::LapackChar, M::Tridiagonal{$elty}, ipiv::Vector{Int32}, B::StridedVecOrMat{$elty})
+ if stride(B,1) != 1
+ error("_jl_lapack_gttrs: matrix columns must have contiguous elements");
+ end
+ info = zero(Int32)
+ n = int32(length(M.d))
+ nrhs = int32(size(B, 2))
+ ldb = int32(stride(B, 2))
+ ccall(dlsym(_jl_liblapack, $string(gttrs)),
+ Void,
+ (Ptr{Uint8}, Ptr{Int32}, Ptr{Int32},
+ Ptr{$elty}, Ptr{$elty}, Ptr{$elty}, Ptr{$elty},
+ Ptr{Int32}, Ptr{$elty}, Ptr{Int32}, Ptr{Int32}),
+ &trans, &n, &nrhs, M.dl, M.d, M.du, M.dutmp, ipiv, B, &ldb, &info)
+ if info != 0 throw(LapackException(info)) end
+ B
+ end
+ function _jl_lapack_pttrs(D::Vector{$elty}, E::Vector{$elty}, B::StridedVecOrMat{$elty})
+ if stride(B,1) != 1
+ error("_jl_lapack_pttrs: matrix columns must have contiguous elements");
+ end
+ info = zero(Int32)
+ n = int32(length(D))
+ if length(E) != n-1
+ error("subdiagonal must be one element shorter than diagonal")
+ end
+ nrhs = int32(size(B, 2))
+ ldb = int32(stride(B, 2))
+ ccall(dlsym(_jl_liblapack, $string(pttrs)),
+ Void,
+ (Ptr{Int32}, Ptr{Int32}, Ptr{$elty}, Ptr{$elty}, Ptr{$elty},
+ Ptr{Int32}, Ptr{Int32}),
+ &n, &nrhs, D, E, B, &ldb, &info)
+ if info != 0 throw(LapackException(info)) end
+ B
+ end
+ end
+end
+# Eigenvalue-eigenvector (symmetric only)
+for (stev, elty) in
+ ((:dstev_,:Float64),
+ (:sstev_,:Float32),
+ (:zstev_,:Complex128),
+ (:cstev_,:Complex64))
+ @eval begin
+ function _jl_lapack_stev(Z::Array, M::Tridiagonal{$elty})
+ n = int32(length(M.d))
+ if isempty(Z)
+ job = 'N'
+ ldz = 1
+ work = Array($elty, 0)
+ Ztmp = work
+ else
+ if stride(Z,1) != 1
+ error("_jl_lapack_stev: eigenvector matrix columns must have contiguous elements");
+ end
+ if size(Z, 1) != n
+ error("_jl_lapack_stev: eigenvector matrix columns are not of the correct size")
+ end
+ Ztmp = Z
+ job = 'V'
+ ldz = int32(stride(Z, 2))
+ work = Array($elty, max(1, 2*n-2))
+ end
+ info = zero(Int32)
+ ccall(dlsym(_jl_liblapack, $string(stev)),
+ Void,
+ (Ptr{Uint8}, Ptr{Int32},
+ Ptr{$elty}, Ptr{$elty}, Ptr{$elty},
+ Ptr{Int32}, Ptr{$elty}, Ptr{Int32}),
+ &job, &n, M.d, M.dl, Ztmp, &ldz, work, &info)
+ if info != 0 throw(LapackException(info)) end
+ M.d
+ end
+ end
+end
+function _jl_lapack_stev(job::LapackChar, M::Tridiagonal)
+ if job == 'N' || job == 'n'
+ Z = []
+ elseif job == 'V' || job == 'v'
+ n = length(M.d)
+ Z = Array(eltype(M), n, n)
+ else
+ error("Job type not recognized")
+ end
+ D = _jl_lapack_stev(Z, M)
+ return D, Z
+end
Oops, something went wrong.

0 comments on commit cbaad36

Please sign in to comment.