Skip to content
86 changes: 84 additions & 2 deletions inst/include/RcppArmadilloAs.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,6 @@ namespace traits {
std::copy(x.begin(), x.end(), arma::access::rwp(res.values));
}
else if (type == "dtCMatrix") {
// The following 3 lines might be duplicate, but when the type == dgT or dgR, we have to include the lines inside the conditional statements rather than outside.
IntegerVector i = mat.slot("i");
IntegerVector p = mat.slot("p");
Vector<RTYPE> x = mat.slot("x");
Expand All @@ -130,7 +129,6 @@ namespace traits {
}
}
else if (type == "dsCMatrix") {
// The following 3 lines might be duplicate, but when the type == dgT or dgR, we have to include the lines inside the conditional statements rather than outside.
IntegerVector i = mat.slot("i");
IntegerVector p = mat.slot("p");
Vector<RTYPE> x = mat.slot("x");
Expand All @@ -150,6 +148,90 @@ namespace traits {
res = symmatl(res);
}
}
else if (type == "dgTMatrix") {
IntegerVector tj = mat.slot("j");
IntegerVector i = mat.slot("i");
Vector<RTYPE> x = mat.slot("x");
IntegerVector p = IntegerVector(ncol + 1);

int nnz = x.size();
// Count the number of nnz in each column
for(int idx = 0; idx < nnz; idx++){
int col = tj[idx];
p[col + 1]++;
}

// Cumsum p
for(int col = 0, cumsum = 0; col < ncol + 1; col++){
cumsum += p[col];
p[col] = cumsum;
}

res.mem_resize(static_cast<unsigned>(x.size()));
std::copy(i.begin(), i.end(), arma::access::rwp(res.row_indices));
std::copy(p.begin(), p.end(), arma::access::rwp(res.col_ptrs));
std::copy(x.begin(), x.end(), arma::access::rwp(res.values));
}
else if (type == "dtTMatrix") {
IntegerVector tj = mat.slot("j");
IntegerVector i = mat.slot("i");
Vector<RTYPE> x = mat.slot("x");
IntegerVector p = IntegerVector(ncol + 1);
std::string diag = Rcpp::as<std::string>(mat.slot("diag"));

int nnz = x.size();
// Count the number of nnz in each column
for(int idx = 0; idx < nnz; idx++){
int col = tj[idx];
p[col + 1]++;
}

// Cumsum p
for(int col = 0, cumsum = 0; col < ncol + 1; col++){
cumsum += p[col];
p[col] = cumsum;
}

res.mem_resize(static_cast<unsigned>(x.size()));
std::copy(i.begin(), i.end(), arma::access::rwp(res.row_indices));
std::copy(p.begin(), p.end(), arma::access::rwp(res.col_ptrs));
std::copy(x.begin(), x.end(), arma::access::rwp(res.values));

if (diag == "U"){
res.diag().ones();
}
}
else if (type == "dsTMatrix") {
IntegerVector tj = mat.slot("j");
IntegerVector i = mat.slot("i");
Vector<RTYPE> x = mat.slot("x");
IntegerVector p = IntegerVector(ncol + 1);
std::string uplo = Rcpp::as<std::string>(mat.slot("uplo"));

int nnz = x.size();
// Count the number of nnz in each column
for(int idx = 0; idx < nnz; idx++){
int col = tj[idx];
p[col + 1]++;
}

// Cumsum p
for(int col = 0, cumsum = 0; col < ncol + 1; col++){
cumsum += p[col];
p[col] = cumsum;
}

res.mem_resize(static_cast<unsigned>(x.size()));
std::copy(i.begin(), i.end(), arma::access::rwp(res.row_indices));
std::copy(p.begin(), p.end(), arma::access::rwp(res.col_ptrs));
std::copy(x.begin(), x.end(), arma::access::rwp(res.values));

if (uplo == "U") {
res = symmatu(res);
} else {
res = symmatl(res);
}
}

// Setting the sentinel
arma::access::rw(res.col_ptrs[static_cast<unsigned>(ncol + 1)]) =
Expand Down
57 changes: 49 additions & 8 deletions inst/unitTests/runit.sparse.R
Original file line number Diff line number Diff line change
Expand Up @@ -100,12 +100,11 @@ if (.runThisTest) {
M <- as.matrix(read.table(textConnection(mtxt)))
dimnames(M) <- NULL
dtc <- Matrix(M, sparse=TRUE)
dgc <- as(dtc, "dgCMatrix")

dgc <- methods::as(dtc, "dgCMatrix")
checkEquals(dgc, asSpMat(dtc), msg="asSpMat")

dtc@diag <- "U"
dgc <- as(dtc, "dgCMatrix")
dgc <- methods::as(dtc, "dgCMatrix")
checkEquals(dgc, asSpMat(dtc), msg="asSpMat")
}

Expand All @@ -118,13 +117,55 @@ if (.runThisTest) {
M <- as.matrix(read.table(textConnection(mtxt)))
dimnames(M) <- NULL
dsc <- Matrix(M, sparse=TRUE)
dgc <- as(dsc, "dgCMatrix")

dgc <- methods::as(dsc, "dgCMatrix")
checkEquals(dgc, asSpMat(dsc), msg="asSpMat")

dsc <- t(dsc)
dgc <- as(dsc, "dgCMatrix")

dgc <- methods::as(dsc, "dgCMatrix")
checkEquals(dgc, asSpMat(dsc), msg="asSpMat")
}
}

test.dgt2dgc <- function() {
dgt <- methods::as(SM, "dgTMatrix")
checkEquals(SM, asSpMat(dgt), msg="asSpMat")
}

test.dtt2dgc <- function() {
mtxt <- c("0 0 0 3",
"0 0 7 0",
"0 0 0 0",
"0 0 0 0")
M <- as.matrix(read.table(textConnection(mtxt)))
dimnames(M) <- NULL
dtc <- Matrix(M, sparse=TRUE)
dgc <- methods::as(dtc, "dgCMatrix")
dtt <- methods::as(dtc, "dtTMatrix")
checkEquals(dgc, asSpMat(dtt), msg="asSpMat")

dtc@diag <- "U"
dgc <- methods::as(dtc, "dgCMatrix")
dtt <- methods::as(dtc, "dtTMatrix")
checkEquals(dgc, asSpMat(dtt), msg="asSpMat")
}

test.dst2dgc <- function() {
mtxt <- c("10 0 1 0 3",
"0 10 0 1 0",
"1 0 10 0 1",
"0 1 0 10 0",
"3 0 1 0 10")
M <- as.matrix(read.table(textConnection(mtxt)))
dimnames(M) <- NULL
dsc <- Matrix(M, sparse=TRUE)
dgc <- methods::as(dsc, "dgCMatrix")
dst <- methods::as(dsc, "dsTMatrix")
checkEquals(dgc, asSpMat(dst), msg="asSpMat")

dsc <- t(dsc)
dgc <- methods::as(dsc, "dgCMatrix")
dst <- methods::as(dsc, "dsTMatrix")
checkEquals(dgc, asSpMat(dst), msg="asSpMat")
}
}