From 227e1315e0739ab71c83c9889c67e616a997a34b Mon Sep 17 00:00:00 2001 From: Yury Babak Date: Mon, 14 Aug 2017 20:08:15 +0300 Subject: [PATCH 1/2] IGNITE-5925: Get row/col for matrices WIP --- .../org/apache/ignite/ml/math/Matrix.java | 16 ++++++++++++++ .../ml/math/impls/matrix/AbstractMatrix.java | 21 +++++++++++++++++++ 2 files changed, 37 insertions(+) diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/Matrix.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/Matrix.java index 66de1a161a878..8c171a623a821 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/math/Matrix.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/Matrix.java @@ -413,6 +413,14 @@ interface Element { */ public Matrix setRow(int row, double[] data); + /** + * Get a specific row from matrix. + * + * @param row Row index. + * @return row. + */ + public Vector getRow(int row); + /** * Sets values for given column. * @@ -424,6 +432,14 @@ interface Element { */ public Matrix setColumn(int col, double[] data); + /** + * Get a specific row from matrix. + * + * @param col Col index. + * @return Col. + */ + public Vector getCol(int col); + /** * Sets given value without checking for index bounds. This method is marginally faster * than its {@link #set(int, int, double)} sibling. diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/matrix/AbstractMatrix.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/matrix/AbstractMatrix.java index 2195a70f21e19..f8f35665d8437 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/matrix/AbstractMatrix.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/matrix/AbstractMatrix.java @@ -41,6 +41,7 @@ import org.apache.ignite.ml.math.functions.IgniteFunction; import org.apache.ignite.ml.math.functions.IgniteTriFunction; import org.apache.ignite.ml.math.functions.IntIntToDoubleFunction; +import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; import org.apache.ignite.ml.math.impls.vector.MatrixVectorView; /** @@ -708,6 +709,16 @@ protected Matrix likeIdentity() { return this; } + /** {@inheritDoc} */ + @Override public Vector getRow(int row) { + Vector res = new DenseLocalOnHeapVector(rowSize()); + + for (int i = 0; i < columnSize(); i++) + res.setX(i, getX(row,i)); + + return res; + } + /** {@inheritDoc} */ @Override public Matrix setColumn(int col, double[] data) { checkColumnIndex(col); @@ -723,6 +734,16 @@ protected Matrix likeIdentity() { return this; } + /** {@inheritDoc} */ + @Override public Vector getCol(int col) { + Vector res = new DenseLocalOnHeapVector(columnSize()); + + for (int i = 0; i < rowSize(); i++) + res.setX(i, getX(i,col)); + + return res; + } + /** {@inheritDoc} */ @Override public Matrix setX(int row, int col, double val) { storageSet(row, col, val); From 49e5846469d73e356a960e8a26fac82508fc10d1 Mon Sep 17 00:00:00 2001 From: Yury Babak Date: Fri, 18 Aug 2017 13:24:25 +0300 Subject: [PATCH 2/2] IGNITE-5925: Get row/col for matrices - implemented --- .../java/org/apache/ignite/ml/math/Blas.java | 4 ++-- .../ml/math/impls/matrix/AbstractMatrix.java | 8 ++++++-- .../impls/matrix/MatrixImplementationsTest.java | 17 +++++++++++++++++ 3 files changed, 25 insertions(+), 4 deletions(-) diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/Blas.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/Blas.java index a61d79650c03f..4b83ede12283e 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/math/Blas.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/Blas.java @@ -298,7 +298,7 @@ public static void gemv(double alpha, Matrix a, Vector x, double beta, Vector y) throw new CardinalityException(a.columnSize(), y.size()); checkMatrixType(a, "gemv"); - checkVectorType(x,"gemv"); + checkVectorType(x, "gemv"); checkVectorType(y, "gemv"); if (alpha == 0.0 && beta == 1.0) @@ -322,7 +322,7 @@ public static void gemv(double alpha, Matrix a, Vector x, double beta, Vector y) /** * M := alpha * M. * @param m Matrix M. - * @param alpha Aplha. + * @param alpha Alpha. */ private static void scal(Matrix m, double alpha) { if (alpha != 1.0) diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/matrix/AbstractMatrix.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/matrix/AbstractMatrix.java index f8f35665d8437..06fb34c4139b1 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/matrix/AbstractMatrix.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/matrix/AbstractMatrix.java @@ -711,7 +711,9 @@ protected Matrix likeIdentity() { /** {@inheritDoc} */ @Override public Vector getRow(int row) { - Vector res = new DenseLocalOnHeapVector(rowSize()); + checkRowIndex(row); + + Vector res = new DenseLocalOnHeapVector(columnSize()); for (int i = 0; i < columnSize(); i++) res.setX(i, getX(row,i)); @@ -736,7 +738,9 @@ protected Matrix likeIdentity() { /** {@inheritDoc} */ @Override public Vector getCol(int col) { - Vector res = new DenseLocalOnHeapVector(columnSize()); + checkColumnIndex(col); + + Vector res = new DenseLocalOnHeapVector(rowSize()); for (int i = 0; i < rowSize(); i++) res.setX(i, getX(i,col)); diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/math/impls/matrix/MatrixImplementationsTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/math/impls/matrix/MatrixImplementationsTest.java index 89b62248cc5a1..8270da149cb54 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/math/impls/matrix/MatrixImplementationsTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/math/impls/matrix/MatrixImplementationsTest.java @@ -926,6 +926,23 @@ public void testAssignColumn() { }); } + /** */ + @Test + public void testGetRowCol(){ + consumeSampleMatrix((m,desc)-> { + if (! (m instanceof RandomMatrix)) + for (int i = 0; i < m.rowSize(); i++) + for (int j = 0; j < m.columnSize(); j++) + m.setX(i, j, i + j); + + for (int i = 0; i < m.rowSize(); i++) + assertNotNull("Unexpected value for " + desc + " at row " + i, m.getRow(i)); + + for (int i = 0; i < m.columnSize(); i++) + assertNotNull("Unexpected value for " + desc + " at col " + i, m.getCol(i)); + }); + } + /** */ private double[] fillArray(int len) { double[] newValues = new double[len];