From 26a2efa65e9f09df358e1021ebf45e3735e2ec6c Mon Sep 17 00:00:00 2001
From: pferrel
Date: Mon, 2 Oct 2017 11:39:54 -0700
Subject: [PATCH 1/3] minimum speedup fix

Override assign(Matrix, DoubleDoubleFunction) in SparseRowMatrix. When
function.isLikeMult() is true, each row is updated only at the indices
returned by that row's iterateNonZero(); otherwise every column of the
row is visited.

---
 .../apache/mahout/math/SparseRowMatrix.java | 37 +++++++++++++++++++
 1 file changed, 37 insertions(+)

diff --git a/math/src/main/java/org/apache/mahout/math/SparseRowMatrix.java b/math/src/main/java/org/apache/mahout/math/SparseRowMatrix.java
index 6e0676911c..1fb4c33d2c 100644
--- a/math/src/main/java/org/apache/mahout/math/SparseRowMatrix.java
+++ b/math/src/main/java/org/apache/mahout/math/SparseRowMatrix.java
@@ -19,8 +19,11 @@
 
 import org.apache.mahout.math.flavor.MatrixFlavor;
 import org.apache.mahout.math.flavor.TraversingStructureEnum;
+import org.apache.mahout.math.function.DoubleDoubleFunction;
 import org.apache.mahout.math.function.Functions;
 
+import java.util.Iterator;
+
 /**
  * sparse matrix with general element values whose rows are accessible quickly. Implemented as a row
  * array of either SequentialAccessSparseVectors or RandomAccessSparseVectors.
@@ -132,6 +135,40 @@ public Matrix viewPart(int[] offset, int[] size) {
     return new MatrixView(this, offset, size);
   }
 
+  @Override
+  public Matrix assign(Matrix other, DoubleDoubleFunction function) {
+    int rows = rowSize();
+    if (rows != other.rowSize()) {
+      throw new CardinalityException(rows, other.rowSize());
+    }
+    int columns = columnSize();
+    if (columns != other.columnSize()) {
+      throw new CardinalityException(columns, other.columnSize());
+    }
+    for (int row = 0; row < rows; row++) {
+      if( function.isLikeMult()) { // TODO: is this a sufficient test?
+        // TODO: this may cause an exception if the row type is not compatible but it is currently guaranteed to be
+        // a SequentialAccessSparseVector, should "try" here just in case and Warn
+        // TODO: can we use iterateNonZero on both rows until the index is the same to get better speedup?
+        Iterator<Vector.Element> sparseRowIterator = ((SequentialAccessSparseVector) this.rowVectors[row])
+                .iterateNonZero();
+        // TODO: SASVs have an iterateNonZero that returns zeros, this should not hurt but is far from optimal
+        // this might perform much better if SparseRowMatrix were backed by RandomAccessSparseVectors, which
+        // are backed by fastutil hashmaps and the iterateNonZero does only return nonZeros.
+        while (sparseRowIterator.hasNext()) {
+          Vector.Element element = sparseRowIterator.next();
+          int col = element.index();
+          setQuick(row, col, function.apply(element.get(), other.getQuick(row, col)));
+        }
+      } else {
+        for (int col = 0; col < columns; col++) {
+          setQuick(row, col, function.apply(getQuick(row, col), other.getQuick(row, col)));
+        }
+      }
+    }
+    return this;
+  }
+
   @Override
   public Matrix assignColumn(int column, Vector other) {
     if (rowSize() != other.size()) {

From 9330a2ed6d1211459c57863a5d664377c55aa747 Mon Sep 17 00:00:00 2001
From: pferrel
Date: Mon, 2 Oct 2017 12:27:47 -0700
Subject: [PATCH 2/3] minimum speedup fix with cast exception check

Obtain the row iterator inside a try block. If the cast to
SequentialAccessSparseVector throws ClassCastException, log a warning
and update that row column by column instead.

Note: this patch leaves `return this;` inside the row loop; patch 3/3
moves it back to the end of the method.

---
 .../apache/mahout/math/SparseRowMatrix.java | 42 +++++++++++++------
 1 file changed, 29 insertions(+), 13 deletions(-)

diff --git a/math/src/main/java/org/apache/mahout/math/SparseRowMatrix.java b/math/src/main/java/org/apache/mahout/math/SparseRowMatrix.java
index 1fb4c33d2c..40d0c9e761 100644
--- a/math/src/main/java/org/apache/mahout/math/SparseRowMatrix.java
+++ b/math/src/main/java/org/apache/mahout/math/SparseRowMatrix.java
@@ -21,6 +21,8 @@
 import org.apache.mahout.math.flavor.TraversingStructureEnum;
 import org.apache.mahout.math.function.DoubleDoubleFunction;
 import org.apache.mahout.math.function.Functions;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 import java.util.Iterator;
 
@@ -33,6 +35,8 @@ public class SparseRowMatrix extends AbstractMatrix {
 
   private final boolean randomAccessRows;
 
+  private static final Logger log = LoggerFactory.getLogger(SparseRowMatrix.class);
+
   /**
    * Construct a sparse matrix starting with the provided row vectors.
    *
@@ -146,27 +150,39 @@ public Matrix assign(Matrix other, DoubleDoubleFunction function) {
       throw new CardinalityException(columns, other.columnSize());
     }
     for (int row = 0; row < rows; row++) {
-      if( function.isLikeMult()) { // TODO: is this a sufficient test?
-        // TODO: this may cause an exception if the row type is not compatible but it is currently guaranteed to be
-        // a SequentialAccessSparseVector, should "try" here just in case and Warn
-        // TODO: can we use iterateNonZero on both rows until the index is the same to get better speedup?
+      try {
         Iterator<Vector.Element> sparseRowIterator = ((SequentialAccessSparseVector) this.rowVectors[row])
                 .iterateNonZero();
-        // TODO: SASVs have an iterateNonZero that returns zeros, this should not hurt but is far from optimal
-        // this might perform much better if SparseRowMatrix were backed by RandomAccessSparseVectors, which
-        // are backed by fastutil hashmaps and the iterateNonZero does only return nonZeros.
-        while (sparseRowIterator.hasNext()) {
-          Vector.Element element = sparseRowIterator.next();
-          int col = element.index();
-          setQuick(row, col, function.apply(element.get(), other.getQuick(row, col)));
+        if (function.isLikeMult()) { // TODO: is this a sufficient test?
+          // TODO: this may cause an exception if the row type is not compatible but it is currently guaranteed to be
+          // a SequentialAccessSparseVector, should "try" here just in case and Warn
+          // TODO: can we use iterateNonZero on both rows until the index is the same to get better speedup?
+
+          // TODO: SASVs have an iterateNonZero that returns zeros, this should not hurt but is far from optimal
+          // this might perform much better if SparseRowMatrix were backed by RandomAccessSparseVectors, which
+          // are backed by fastutil hashmaps and the iterateNonZero does only return nonZeros.
+          while (sparseRowIterator.hasNext()) {
+            Vector.Element element = sparseRowIterator.next();
+            int col = element.index();
+            setQuick(row, col, function.apply(element.get(), other.getQuick(row, col)));
+          }
+        } else {
+          for (int col = 0; col < columns; col++) {
+            setQuick(row, col, function.apply(getQuick(row, col), other.getQuick(row, col)));
+          }
         }
-      } else {
+
+      } catch (ClassCastException e) {
+        // Warn and use default implementation
+        log.warn("Error casting the row to SequentialAccessSparseVector, this should never happen because " +
+            "SparseRowMatrix is always made of SequentialAccessSparseVectors. Proceeding with non-optimized " +
+            "implementation.");
         for (int col = 0; col < columns; col++) {
           setQuick(row, col, function.apply(getQuick(row, col), other.getQuick(row, col)));
         }
       }
+      return this;
     }
-    return this;
   }
 
   @Override

From 722bd11f01e7250f99f21f17ec7211bf5abb2089 Mon Sep 17 00:00:00 2001
From: pferrel
Date: Mon, 2 Oct 2017 13:33:07 -0700
Subject: [PATCH 3/3] added cast exception logging to SparseRowMatrix

Move `return this;` from inside the row loop of assign() to the end of
the method, so the loop runs over every row before returning.

---
 math/src/main/java/org/apache/mahout/math/SparseRowMatrix.java | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/math/src/main/java/org/apache/mahout/math/SparseRowMatrix.java b/math/src/main/java/org/apache/mahout/math/SparseRowMatrix.java
index 40d0c9e761..25e5acc975 100644
--- a/math/src/main/java/org/apache/mahout/math/SparseRowMatrix.java
+++ b/math/src/main/java/org/apache/mahout/math/SparseRowMatrix.java
@@ -181,8 +181,8 @@ public Matrix assign(Matrix other, DoubleDoubleFunction function) {
           setQuick(row, col, function.apply(getQuick(row, col), other.getQuick(row, col)));
         }
       }
-      return this;
     }
+    return this;
   }
 
   @Override