Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Accomodate upcoming change in package Matrix (fixes #415) #417

Merged
merged 3 commits into from Jun 5, 2023
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
209 changes: 73 additions & 136 deletions inst/include/RcppArmadillo/interface/RcppArmadilloAs.h
Expand Up @@ -123,38 +123,38 @@ namespace traits {

// Get the type of sparse matrix
std::string type = Rcpp::as<std::string>(mat.slot("class"));

if (type == "dgCMatrix" || mat.is("dgCMatrix")) {
IntegerVector i = mat.slot("i");
IntegerVector p = mat.slot("p");
Vector<RTYPE> x = mat.slot("x");

// Making space for the elements
res.mem_resize(static_cast<unsigned>(x.size()));

// In order to access the internal arrays of the SpMat class
res.sync();

// Copying elements
std::copy(i.begin(), i.end(), arma::access::rwp(res.row_indices));
std::copy(p.begin(), p.end(), arma::access::rwp(res.col_ptrs));
std::copy(x.begin(), x.end(), arma::access::rwp(res.values));
#define DO_RESULT \
do { \
/* Allocate: */ \
res.mem_resize(static_cast<unsigned>(x.size())); \
\
/* To access arrays internal to SpMat class: */ \
res.sync(); \
\
/* Copy: */ \
std::copy(i.begin(), i.end(), \
arma::access::rwp(res.row_indices)); \
std::copy(p.begin(), p.end(), \
arma::access::rwp(res.col_ptrs)); \
std::copy(x.begin(), x.end(), \
arma::access::rwp(res.values)); \
} while (0)

DO_RESULT;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Old school :) Maybe we have too much C around R to not let go of macros.

}
else if (type == "dtCMatrix" || mat.is("dtCMatrix")) {
IntegerVector i = mat.slot("i");
IntegerVector p = mat.slot("p");
Vector<RTYPE> x = mat.slot("x");
std::string diag = Rcpp::as<std::string>(mat.slot("diag"));

// Making space for the elements
res.mem_resize(static_cast<unsigned>(x.size()));

// In order to access the internal arrays of the SpMat class
res.sync();

// Copying elements
std::copy(i.begin(), i.end(), arma::access::rwp(res.row_indices));
std::copy(p.begin(), p.end(), arma::access::rwp(res.col_ptrs));
std::copy(x.begin(), x.end(), arma::access::rwp(res.values));
DO_RESULT;

if (diag == "U") {
res.diag().ones();
Expand All @@ -166,16 +166,7 @@ namespace traits {
Vector<RTYPE> x = mat.slot("x");
std::string uplo = Rcpp::as<std::string>(mat.slot("uplo"));

// Making space for the elements
res.mem_resize(static_cast<unsigned>(x.size()));

// In order to access the internal arrays of the SpMat class
res.sync();

// Copying elements
std::copy(i.begin(), i.end(), arma::access::rwp(res.row_indices));
std::copy(p.begin(), p.end(), arma::access::rwp(res.col_ptrs));
std::copy(x.begin(), x.end(), arma::access::rwp(res.values));
DO_RESULT;

if (uplo == "U") {
res = symmatu(res);
Expand Down Expand Up @@ -250,16 +241,7 @@ namespace traits {
last = tmp;
}

// Making space for the elements
res.mem_resize(static_cast<unsigned>(x.size()));

// In order to access the internal arrays of the SpMat class
res.sync();

// Copying elements
std::copy(i.begin(), i.end(), arma::access::rwp(res.row_indices));
std::copy(p.begin(), p.end(), arma::access::rwp(res.col_ptrs));
std::copy(x.begin(), x.end(), arma::access::rwp(res.values));
DO_RESULT;
}
else if (type == "dtRMatrix" || mat.is("dtRMatrix")) {
IntegerVector rj = mat.slot("j");
Expand Down Expand Up @@ -304,16 +286,7 @@ namespace traits {
last = tmp;
}

// Making space for the elements
res.mem_resize(static_cast<unsigned>(x.size()));

// In order to access the internal arrays of the SpMat class
res.sync();

// Copying elements
std::copy(i.begin(), i.end(), arma::access::rwp(res.row_indices));
std::copy(p.begin(), p.end(), arma::access::rwp(res.col_ptrs));
std::copy(x.begin(), x.end(), arma::access::rwp(res.values));
DO_RESULT;

if (diag == "U"){
res.diag().ones();
Expand Down Expand Up @@ -362,16 +335,7 @@ namespace traits {
last = tmp;
}

// Making space for the elements
res.mem_resize(static_cast<unsigned>(x.size()));

// In order to access the internal arrays of the SpMat class
res.sync();

// Copying elements
std::copy(i.begin(), i.end(), arma::access::rwp(res.row_indices));
std::copy(p.begin(), p.end(), arma::access::rwp(res.col_ptrs));
std::copy(x.begin(), x.end(), arma::access::rwp(res.values));
DO_RESULT;

if (uplo == "U") {
res = symmatu(res);
Expand All @@ -380,82 +344,61 @@ namespace traits {
}
}
else if (type == "indMatrix" || mat.is("indMatrix")) {
std::vector<int> i;
IntegerVector p(ncol + 1);
IntegerVector x(nrow, 1);
IntegerVector perm = mat.slot("perm");

typedef std::pair<int, int> Key;
typedef std::set<Key> Set;
Set permiSet;

// Sort i;
int nnz = perm.size();
for(int tmp = 0; tmp < nnz; tmp++){
Key permi(perm[tmp], tmp);
permiSet.insert(permi);
}

for(Set::iterator tmp = permiSet.begin(); tmp != permiSet.end(); tmp++){
i.push_back(tmp -> second);
}

// Count the number of nnz in each column
for(int idx = 0; idx < nnz; idx++){
int col = perm[idx];
p[col]++;
}

// Cumsum p
for(int col = 0, cumsum = 0; col < ncol + 1; col++){
cumsum += p[col];
p[col] = cumsum;
IntegerVector p(ncol + 1);
IntegerVector i(perm.size());
IntegerVector x(perm.size());

if (!mat.hasSlot("margin") ||
as<IntegerVector>(mat.slot("margin"))[0] == 1) {
int *work = reinterpret_cast<int *>(
R_alloc((std::size_t) ncol, sizeof(int)));
std::memset(work, 0, ncol * sizeof(int));
for (int ii = 0; ii < nrow; ++ii)
work[perm[ii] - 1]++;
for (int jj = 0; jj < ncol; ++jj) {
p[jj + 1] = p[jj] + work[jj];
work[jj] = p[jj];
}
for (int ii = 0; ii < nrow; ++ii) {
i[work[perm[ii] - 1]++] = ii;
x[ii] = 1;
}
} else {
for (int jj = 0; jj < ncol; ++jj) {
p[jj] = jj;
i[jj] = perm[jj] - 1;
x[jj] = 1;
}
p[ncol] = ncol;
}

// Making space for the elements
res.mem_resize(static_cast<unsigned>(x.size()));

// In order to access the internal arrays of the SpMat class
res.sync();

// Copying elements
std::copy(i.begin(), i.end(), arma::access::rwp(res.row_indices));
std::copy(p.begin(), p.end(), arma::access::rwp(res.col_ptrs));
std::copy(x.begin(), x.end(), arma::access::rwp(res.values));
DO_RESULT;
}
else if (type == "pMatrix" || mat.is("pMatrix")) {
std::vector<int> i;
IntegerVector p(ncol + 1);
IntegerVector x(ncol, 1);
IntegerVector perm = mat.slot("perm");

// Sort the row number by the column number
typedef std::map <int, int> Map;
Map colrow;
for(int tmp = 0; tmp < perm.size(); tmp++){
colrow[perm[tmp]] = tmp;
}

// Calculate i
for(Map::iterator tmp = colrow.begin(); tmp != colrow.end(); tmp++){
i.push_back(tmp -> second);
}

// Calculate p
for(int tmp = 0; tmp < p.size(); tmp++){
p[tmp] = tmp;
IntegerVector p(ncol + 1);
IntegerVector i(ncol);
IntegerVector x(ncol);

if (!mat.hasSlot("margin") ||
as<IntegerVector>(mat.slot("margin"))[0] == 1) {
for (int jj = 0; jj < ncol; ++jj)
i[perm[jj] - 1] = jj;
for (int jj = 0; jj < ncol; ++jj) {
p[jj] = jj;
x[jj] = 1;
}
} else {
for (int jj = 0; jj < ncol; ++jj) {
p[jj] = jj;
i[jj] = perm[jj] - 1;
x[jj] = 1;
}
}
p[ncol] = ncol;

// Making space for the elements
res.mem_resize(static_cast<unsigned>(x.size()));

// In order to access the internal arrays of the SpMat class
res.sync();

// Copying elements
std::copy(i.begin(), i.end(), arma::access::rwp(res.row_indices));
std::copy(p.begin(), p.end(), arma::access::rwp(res.col_ptrs));
std::copy(x.begin(), x.end(), arma::access::rwp(res.values));
DO_RESULT;
}
else if (type == "ddiMatrix" || mat.is("ddiMatrix")) {
std::vector<int> i;
Expand Down Expand Up @@ -484,16 +427,10 @@ namespace traits {
p.push_back(tmpp);
}

// Making space for the elements
res.mem_resize(static_cast<unsigned>(x.size()));
DO_RESULT;

// In order to access the internal arrays of the SpMat class
res.sync();
#undef DO_RESULT

// Copying elements
std::copy(i.begin(), i.end(), arma::access::rwp(res.row_indices));
std::copy(p.begin(), p.end(), arma::access::rwp(res.col_ptrs));
std::copy(x.begin(), x.end(), arma::access::rwp(res.values));
}
else {
Rcpp::stop(type + " is not supported.");
Expand Down