From 1915fcc2d33a0510812ab7d4c9b2b29e6635f12d Mon Sep 17 00:00:00 2001 From: Ni Date: Wed, 21 Jun 2017 16:30:28 -0500 Subject: [PATCH] Add conversion & unit tests for pMatrix & ddiMatrix --- ChangeLog | 5 +++ inst/include/RcppArmadilloAs.h | 70 +++++++++++++++++++++++++++++++--- inst/unitTests/runit.sparse.R | 30 +++++++++++++++ 3 files changed, 100 insertions(+), 5 deletions(-) diff --git a/ChangeLog b/ChangeLog index 82a3702b..a0a81e9b 100644 --- a/ChangeLog +++ b/ChangeLog @@ -1,3 +1,8 @@ +2017-06-21 Binxiang Ni + + * inst/include/RcppArmadilloAs.h: Add conversion for pMatrix & ddiMatrix + * inst/unitTests/runit.sparse.R: Add unit tests for the conversion for pMatrix & ddiMatrix + 2017-06-19 Dirk Eddelbuettel * DESCRIPTION (Version, Date): Roll minor version diff --git a/inst/include/RcppArmadilloAs.h b/inst/include/RcppArmadilloAs.h index fe620d95..b7a50c87 100644 --- a/inst/include/RcppArmadilloAs.h +++ b/inst/include/RcppArmadilloAs.h @@ -79,7 +79,7 @@ namespace traits { }; // 14 June 2017 - // Add support for dgCMatrix, dtCMatrix and dsCMatrix + // Add support for sparse matrices other than dgCMatrix template class Exporter< arma::SpMat > { public: @@ -249,7 +249,7 @@ namespace traits { int nnz = rx.size(); IntegerVector i = IntegerVector(nnz); IntegerVector p = IntegerVector(ncol + 1); - NumericVector x = NumericVector(nnz); + Vector x = Vector(nnz); // Count the nnz in each column for(int n = 0; n < nnz; n++){ @@ -300,7 +300,7 @@ namespace traits { int nnz = rx.size(); IntegerVector i = IntegerVector(nnz); IntegerVector p = IntegerVector(ncol + 1); - NumericVector x = NumericVector(nnz); + Vector x = Vector(nnz); // Count the nnz in each column for(int n = 0; n < nnz; n++){ @@ -355,7 +355,7 @@ namespace traits { int nnz = rx.size(); IntegerVector i = IntegerVector(nnz); IntegerVector p = IntegerVector(ncol + 1); - NumericVector x = NumericVector(nnz); + Vector x = Vector(nnz); // Count the nnz in each column for(int n = 0; n < nnz; n++){ @@ -403,6 +403,65 @@ namespace traits { res = symmatl(res); } } + else if (type == "pMatrix") { + std::vector i; + IntegerVector p(ncol + 1); + IntegerVector x(ncol, 1); + IntegerVector perm = mat.slot("perm"); + + // Sort the row number by the column number + std::map colrow; + for(int tmp = 0; tmp < perm.size(); tmp++){ + colrow[perm[tmp]] = tmp; + } + + // Calculate i + for(std::map::iterator tmp = colrow.begin(); tmp != colrow.end(); tmp++){ + i.push_back(tmp -> second); + } + + // Calculate p + for(int tmp = 0; tmp < p.size(); tmp++){ + p[tmp] = tmp; + } + + // Making space for the elements + res.mem_resize(static_cast(x.size())); + + // Copying elements + std::copy(i.begin(), i.end(), arma::access::rwp(res.row_indices)); + std::copy(p.begin(), p.end(), arma::access::rwp(res.col_ptrs)); + std::copy(x.begin(), x.end(), arma::access::rwp(res.values)); + } + else if (type == "ddiMatrix") { + IntegerVector i(ncol); + IntegerVector p(ncol+1); + std::string diag = Rcpp::as(mat.slot("diag")); + Vector x = no_init(ncol); + if (diag == "U") { + x.fill(1); + } else { + x = Vector(mat.slot("x")); + } + + // Calculate i + for(int tmp = 0; tmp < i.size(); tmp++){ + i[tmp] = tmp; + } + + // Calculate p + for(int tmp = 0; tmp < p.size(); tmp++){ + p[tmp] = tmp; + } + + // Making space for the elements + res.mem_resize(static_cast(x.size())); + + // Copying elements + std::copy(i.begin(), i.end(), arma::access::rwp(res.row_indices)); + std::copy(p.begin(), p.end(), arma::access::rwp(res.col_ptrs)); + std::copy(x.begin(), x.end(), arma::access::rwp(res.values)); + } // Setting the sentinel arma::access::rw(res.col_ptrs[static_cast(ncol + 1)]) = @@ -410,7 +469,8 @@ namespace traits { return res; } - + + private: S4 mat ; } ; diff --git a/inst/unitTests/runit.sparse.R b/inst/unitTests/runit.sparse.R index acc95d34..e61ba70f 100644 --- a/inst/unitTests/runit.sparse.R +++ b/inst/unitTests/runit.sparse.R @@ -208,4 +208,34 @@ if (.runThisTest) { dsr <- methods::as(dsc, "RsparseMatrix") checkEquals(dgc, asSpMat(dsr), msg="asSpMat") } + + test.p2dgc <- function() { + mtxt <- c("0 1 0", + "0 0 1", + "1 0 0") + M <- as.matrix(read.table(textConnection(mtxt))) + dimnames(M) <- NULL + dgc <- methods::as(M, "dgCMatrix") + p <- as(as.integer(c(2,3,1)), "pMatrix") + checkEquals(dgc, asSpMat(p), msg="asSpMat") + } + + test.ddi2dgc <- function() { + mtxt <- c("1 0 0", + "0 1 0", + "0 0 1") + M <- as.matrix(read.table(textConnection(mtxt))) + dimnames(M) <- NULL + dgc <- methods::as(M, "dgCMatrix") + ddi <- methods::as(M, "diagonalMatrix") + checkEquals(dgc, asSpMat(ddi), msg="asSpMat") + + mtxt <- c("10 0", + "0 1") + M <- as.matrix(read.table(textConnection(mtxt))) + dimnames(M) <- NULL + dgc <- methods::as(M, "dgCMatrix") + ddi <- methods::as(M, "diagonalMatrix") + checkEquals(dgc, asSpMat(ddi), msg="asSpMat") + } }