diff --git a/src/TiledArray/math/linalg/rank-local.cpp b/src/TiledArray/math/linalg/rank-local.cpp index e26ec3da59..5413ad0e5f 100644 --- a/src/TiledArray/math/linalg/rank-local.cpp +++ b/src/TiledArray/math/linalg/rank-local.cpp @@ -140,35 +140,48 @@ void heig(Matrix& A, Matrix& B, std::vector& W) { } template -void svd(Matrix& A, std::vector& S, Matrix* U, Matrix* VT) { +void svd(Job jobu, Job jobvt, Matrix& A, std::vector& S, Matrix* U, Matrix* VT) { integer m = A.rows(); integer n = A.cols(); + integer k = std::min(m, n); T* a = A.data(); integer lda = A.rows(); - S.resize(std::min(m, n)); + S.resize(k); T* s = S.data(); - auto jobu = lapack::Job::NoVec; - T* u = nullptr; - integer ldu = m; - if (U) { - jobu = lapack::Job::AllVec; - U->resize(m, n); + T* u = nullptr; + T* vt = nullptr; + integer ldu = 1, ldvt = 1; + if( (jobu == Job::SomeVec or jobu == Job::AllVec) and (not U) ) + TA_LAPACK_ERROR("Requested out-of-place right singular vectors with null U input"); + if( (jobvt == Job::SomeVec or jobvt == Job::AllVec) and (not VT) ) + TA_LAPACK_ERROR("Requested out-of-place left singular vectors with null VT input"); + + if( jobu == Job::SomeVec ) { + U->resize(m, k); u = U->data(); - ldu = U->rows(); + ldu = m; } - auto jobvt = lapack::Job::NoVec; - T* vt = nullptr; - integer ldvt = n; - if (VT) { - jobvt = lapack::Job::AllVec; - VT->resize(n, m); + if( jobu == Job::AllVec ) { + U->resize(m, m); + u = U->data(); + ldu = m; + } + + if( jobvt == Job::SomeVec ) { + VT->resize(k, n); vt = VT->data(); - ldvt = VT->rows(); + ldvt = k; } + if( jobvt == Job::AllVec ) { + VT->resize(n, n); + vt = VT->data(); + ldvt = n; + } + TA_LAPACK(gesvd, jobu, jobvt, m, n, a, lda, s, u, ldu, vt, ldvt); } @@ -194,15 +207,15 @@ void lu_inv(Matrix& A) { TA_LAPACK(getri, n, a, lda, ipiv.data()); } -#define TA_LAPACK_EXPLICIT(MATRIX, VECTOR) \ - template void cholesky(MATRIX&); \ - template void cholesky_linv(MATRIX&); \ - template void cholesky_solve(MATRIX&, MATRIX&); \ - template void cholesky_lsolve(Op, MATRIX&, MATRIX&); \ - template void heig(MATRIX&, VECTOR&); \ - template void heig(MATRIX&, MATRIX&, VECTOR&); \ - template void svd(MATRIX&, VECTOR&, MATRIX*, MATRIX*); \ - template void lu_solve(MATRIX&, MATRIX&); \ +#define TA_LAPACK_EXPLICIT(MATRIX, VECTOR) \ + template void cholesky(MATRIX&); \ + template void cholesky_linv(MATRIX&); \ + template void cholesky_solve(MATRIX&, MATRIX&); \ + template void cholesky_lsolve(Op, MATRIX&, MATRIX&); \ + template void heig(MATRIX&, VECTOR&); \ + template void heig(MATRIX&, MATRIX&, VECTOR&); \ + template void svd(Job,Job,MATRIX&, VECTOR&, MATRIX*, MATRIX*); \ + template void lu_solve(MATRIX&, MATRIX&); \ template void lu_inv(MATRIX&); TA_LAPACK_EXPLICIT(Matrix, std::vector); diff --git a/src/TiledArray/math/linalg/rank-local.h b/src/TiledArray/math/linalg/rank-local.h index 144c844e3c..87d4b44a56 100644 --- a/src/TiledArray/math/linalg/rank-local.h +++ b/src/TiledArray/math/linalg/rank-local.h @@ -10,6 +10,8 @@ namespace TiledArray::math::linalg::rank_local { +using Job = ::lapack::Job; + template using Matrix = ::Eigen::Matrix; @@ -35,7 +37,14 @@ template void heig(Matrix &A, Matrix &B, std::vector &W); template -void svd(Matrix &A, std::vector &S, Matrix *U, Matrix *VT); +void svd(Job jobu, Job jobvt, Matrix &A, std::vector &S, Matrix *U, Matrix *VT); + +template +void svd(Matrix &A, std::vector &S, Matrix *U, Matrix *VT) { + svd( U ? Job::SomeVec : Job::NoVec, + VT ? Job::SomeVec : Job::NoVec, + A, S, U, VT ); +} template void lu_solve(Matrix &A, Matrix &B);