Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 38 additions & 25 deletions src/TiledArray/math/linalg/rank-local.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,35 +140,48 @@ void heig(Matrix<T>& A, Matrix<T>& B, std::vector<T>& W) {
}

template <typename T>
void svd(Matrix<T>& A, std::vector<T>& S, Matrix<T>* U, Matrix<T>* VT) {
void svd(Job jobu, Job jobvt, Matrix<T>& A, std::vector<T>& S, Matrix<T>* U, Matrix<T>* VT) {
integer m = A.rows();
integer n = A.cols();
integer k = std::min(m, n);
T* a = A.data();
integer lda = A.rows();

S.resize(std::min(m, n));
S.resize(k);
T* s = S.data();

auto jobu = lapack::Job::NoVec;
T* u = nullptr;
integer ldu = m;
if (U) {
jobu = lapack::Job::AllVec;
U->resize(m, n);
T* u = nullptr;
T* vt = nullptr;
integer ldu = 1, ldvt = 1;
if( (jobu == Job::SomeVec or jobu == Job::AllVec) and (not U) )
TA_LAPACK_ERROR("Requested out-of-place right singular vectors with null U input");
if( (jobvt == Job::SomeVec or jobvt == Job::AllVec) and (not VT) )
TA_LAPACK_ERROR("Requested out-of-place left singular vectors with null VT input");

if( jobu == Job::SomeVec ) {
U->resize(m, k);
u = U->data();
ldu = U->rows();
ldu = m;
}

auto jobvt = lapack::Job::NoVec;
T* vt = nullptr;
integer ldvt = n;
if (VT) {
jobvt = lapack::Job::AllVec;
VT->resize(n, m);
if( jobu == Job::AllVec ) {
U->resize(m, m);
u = U->data();
ldu = m;
}

if( jobvt == Job::SomeVec ) {
VT->resize(k, n);
vt = VT->data();
ldvt = VT->rows();
ldvt = k;
}

if( jobvt == Job::AllVec ) {
VT->resize(n, n);
vt = VT->data();
ldvt = n;
}

TA_LAPACK(gesvd, jobu, jobvt, m, n, a, lda, s, u, ldu, vt, ldvt);
}

Expand All @@ -194,15 +207,15 @@ void lu_inv(Matrix<T>& A) {
TA_LAPACK(getri, n, a, lda, ipiv.data());
}

#define TA_LAPACK_EXPLICIT(MATRIX, VECTOR) \
template void cholesky(MATRIX&); \
template void cholesky_linv(MATRIX&); \
template void cholesky_solve(MATRIX&, MATRIX&); \
template void cholesky_lsolve(Op, MATRIX&, MATRIX&); \
template void heig(MATRIX&, VECTOR&); \
template void heig(MATRIX&, MATRIX&, VECTOR&); \
template void svd(MATRIX&, VECTOR&, MATRIX*, MATRIX*); \
template void lu_solve(MATRIX&, MATRIX&); \
#define TA_LAPACK_EXPLICIT(MATRIX, VECTOR) \
template void cholesky(MATRIX&); \
template void cholesky_linv(MATRIX&); \
template void cholesky_solve(MATRIX&, MATRIX&); \
template void cholesky_lsolve(Op, MATRIX&, MATRIX&); \
template void heig(MATRIX&, VECTOR&); \
template void heig(MATRIX&, MATRIX&, VECTOR&); \
template void svd(Job,Job,MATRIX&, VECTOR&, MATRIX*, MATRIX*); \
template void lu_solve(MATRIX&, MATRIX&); \
template void lu_inv(MATRIX&);

TA_LAPACK_EXPLICIT(Matrix<double>, std::vector<double>);
Expand Down
11 changes: 10 additions & 1 deletion src/TiledArray/math/linalg/rank-local.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

namespace TiledArray::math::linalg::rank_local {

using Job = ::lapack::Job;

template <typename T, int Options = ::Eigen::ColMajor>
using Matrix = ::Eigen::Matrix<T, ::Eigen::Dynamic, ::Eigen::Dynamic, Options>;

Expand All @@ -35,7 +37,14 @@ template <typename T>
void heig(Matrix<T> &A, Matrix<T> &B, std::vector<T> &W);

template <typename T>
void svd(Matrix<T> &A, std::vector<T> &S, Matrix<T> *U, Matrix<T> *VT);
void svd(Job jobu, Job jobvt, Matrix<T> &A, std::vector<T> &S, Matrix<T> *U, Matrix<T> *VT);

template <typename T>
void svd(Matrix<T> &A, std::vector<T> &S, Matrix<T> *U, Matrix<T> *VT) {
svd( U ? Job::SomeVec : Job::NoVec,
VT ? Job::SomeVec : Job::NoVec,
A, S, U, VT );
}

template <typename T>
void lu_solve(Matrix<T> &A, Matrix<T> &B);
Expand Down