diff --git a/flink-ml-parent/flink-ml-lib/src/main/java/org/apache/flink/ml/common/linalg/BLAS.java b/flink-ml-parent/flink-ml-lib/src/main/java/org/apache/flink/ml/common/linalg/BLAS.java index f5d9e8497e8e6..b1d0ba64cb77e 100644 --- a/flink-ml-parent/flink-ml-lib/src/main/java/org/apache/flink/ml/common/linalg/BLAS.java +++ b/flink-ml-parent/flink-ml-lib/src/main/java/org/apache/flink/ml/common/linalg/BLAS.java @@ -19,18 +19,51 @@ package org.apache.flink.ml.common.linalg; +import org.apache.flink.util.Preconditions; + /** * A utility class that provides BLAS routines over matrices and vectors. */ public class BLAS { + + // For level-2 and level-3 routines, we use the native BLAS. + // The NATIVE_BLAS instance tries to load BLAS implementations in the order: + // 1) optimized system libraries such as Intel MKL, + // 2) self-contained native builds using the reference Fortran from netlib.org, + // 3) F2J implementation. + // When using optimized system libraries, it is important to turn off their multi-thread support. + // Otherwise, it will conflict with Flink's executor and lead to performance loss. private static final com.github.fommil.netlib.BLAS NATIVE_BLAS = com.github.fommil.netlib.BLAS.getInstance(); + + // For level-1 routines, we use the Java (F2J) implementation. private static final com.github.fommil.netlib.BLAS F2J_BLAS = com.github.fommil.netlib.F2jBLAS.getInstance(); + /** + * \sum_i |x_i| over the n entries of x starting at offset . + */ + public static double asum(int n, double[] x, int offset) { + return F2J_BLAS.dasum(n, x, offset, 1); + } + + /** + * \sum_i |x_i| over all entries of the dense vector x . + */ + public static double asum(DenseVector x) { + return asum(x.data.length, x.data, 0); + } + + /** + * \sum_i |x_i| over the stored values of the sparse vector x . + */ + public static double asum(SparseVector x) { + return asum(x.values.length, x.values, 0); + } + /** * y += a * x . 
*/ public static void axpy(double a, double[] x, double[] y) { - assert x.length == y.length : "Array dimension mismatched."; + Preconditions.checkArgument(x.length == y.length, "Array dimension mismatched."); F2J_BLAS.daxpy(x.length, a, x, 1, y, 1); } @@ -38,7 +71,7 @@ public static void axpy(double a, double[] x, double[] y) { * y += a * x . */ public static void axpy(double a, DenseVector x, DenseVector y) { - assert x.data.length == y.data.length : "Vector dimension mismatched."; + Preconditions.checkArgument(x.data.length == y.data.length, "Vector dimension mismatched."); F2J_BLAS.daxpy(x.data.length, a, x.data, 1, y.data, 1); } @@ -55,24 +88,34 @@ public static void axpy(double a, SparseVector x, DenseVector y) { * y += a * x . */ public static void axpy(double a, DenseMatrix x, DenseMatrix y) { - assert x.m == y.m && x.n == y.n : "Matrix dimension mismatched."; + Preconditions.checkArgument(x.m == y.m && x.n == y.n, "Matrix dimension mismatched."); F2J_BLAS.daxpy(x.data.length, a, x.data, 1, y.data, 1); } + /** + * y[yOffset:yOffset+n] += a * x[xOffset:xOffset+n] . + */ + public static void axpy(int n, double a, double[] x, int xOffset, double[] y, int yOffset) { + F2J_BLAS.daxpy(n, a, x, xOffset, 1, y, yOffset, 1); + } + /** * x \cdot y . */ public static double dot(double[] x, double[] y) { - assert x.length == y.length : "Array dimension mismatched."; - return F2J_BLAS.ddot(x.length, x, 1, y, 1); + Preconditions.checkArgument(x.length == y.length, "Array dimension mismatched."); + double s = 0.; + for (int i = 0; i < x.length; i++) { + s += x[i] * y[i]; + } + return s; } /** * x \cdot y . 
*/ public static double dot(DenseVector x, DenseVector y) { - assert x.data.length == y.data.length : "Vector dimension mismatched."; - return F2J_BLAS.ddot(x.data.length, x.data, 1, y.data, 1); + return dot(x.getData(), y.getData()); } /** @@ -109,14 +152,18 @@ public static void scal(double a, DenseMatrix x) { public static void gemm(double alpha, DenseMatrix matA, boolean transA, DenseMatrix matB, boolean transB, double beta, DenseMatrix matC) { if (transA) { - assert matA.numCols() == matC.numRows() : "The columns of A does not match the rows of C"; + Preconditions.checkArgument(matA.numCols() == matC.numRows(), + "The columns of A does not match the rows of C"); } else { - assert matA.numRows() == matC.numRows() : "The rows of A does not match the rows of C"; + Preconditions.checkArgument(matA.numRows() == matC.numRows(), + "The rows of A does not match the rows of C"); } if (transB) { - assert matB.numRows() == matC.numCols() : "The rows of B does not match the columns of C"; + Preconditions.checkArgument(matB.numRows() == matC.numCols(), + "The rows of B does not match the columns of C"); } else { - assert matB.numCols() == matC.numCols() : "The columns of B does not match the columns of C"; + Preconditions.checkArgument(matB.numCols() == matC.numCols(), + "The columns of B does not match the columns of C"); } final int m = matC.numRows(); @@ -131,19 +178,56 @@ public static void gemm(double alpha, DenseMatrix matA, boolean transA, DenseMat } /** - * y := alpha * A * x + beta * y . + * Check the compatibility of matrix and vector sizes in gemv. 
*/ - public static void gemv(double alpha, DenseMatrix matA, boolean transA, - DenseVector x, double beta, DenseVector y) { + private static void gemvDimensionCheck(DenseMatrix matA, boolean transA, Vector x, Vector y) { if (transA) { - assert (matA.numCols() == y.size() && matA.numRows() == x.size()) : "Matrix and vector size mismatched."; + Preconditions.checkArgument(matA.numCols() == y.size() && matA.numRows() == x.size(), + "Matrix and vector size mismatched."); } else { - assert (matA.numRows() == y.size() && matA.numCols() == x.size()) : "Matrix and vector size mismatched."; + Preconditions.checkArgument(matA.numRows() == y.size() && matA.numCols() == x.size(), + "Matrix and vector size mismatched."); } + } + + /** + * y := alpha * op(A) * x + beta * y for a dense x, where op(A) is A^T if transA and A otherwise . + */ + public static void gemv(double alpha, DenseMatrix matA, boolean transA, + DenseVector x, double beta, DenseVector y) { + gemvDimensionCheck(matA, transA, x, y); final int m = matA.numRows(); final int n = matA.numCols(); final int lda = matA.numRows(); final String ta = transA ? "T" : "N"; NATIVE_BLAS.dgemv(ta, m, n, alpha, matA.getData(), lda, x.getData(), 1, beta, y.getData(), 1); } + + /** + * y := alpha * op(A) * x + beta * y for a sparse x, where op(A) is A^T if transA and A otherwise . 
+ */ + public static void gemv(double alpha, DenseMatrix matA, boolean transA, + SparseVector x, double beta, DenseVector y) { + gemvDimensionCheck(matA, transA, x, y); + final int m = matA.numRows(); + final int n = matA.numCols(); + if (transA) { + int start = 0; + for (int i = 0; i < n; i++) { + double s = 0.; + for (int j = 0; j < x.indices.length; j++) { + s += x.values[j] * matA.data[start + x.indices[j]]; + } + y.data[i] = beta * y.data[i] + alpha * s; + start += m; + } + } else { + scal(beta, y); + for (int i = 0; i < x.indices.length; i++) { + int index = x.indices[i]; + double value = alpha * x.values[i]; + F2J_BLAS.daxpy(m, value, matA.data, index * m, 1, y.data, 0, 1); + } + } + } } diff --git a/flink-ml-parent/flink-ml-lib/src/test/java/org/apache/flink/ml/common/linalg/BLASTest.java b/flink-ml-parent/flink-ml-lib/src/test/java/org/apache/flink/ml/common/linalg/BLASTest.java new file mode 100644 index 0000000000000..c30b0dfaee57b --- /dev/null +++ b/flink-ml-parent/flink-ml-lib/src/test/java/org/apache/flink/ml/common/linalg/BLASTest.java @@ -0,0 +1,188 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.flink.ml.common.linalg; + +import org.junit.Assert; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; + +/** + * Test cases for {@link BLAS}: asum, scal, dot, axpy, gemm, and dense/sparse gemv with their size checks. + */ +public class BLASTest { + private static final double TOL = 1.0e-8; + private DenseMatrix mat = new DenseMatrix(2, 3, new double[]{1, 4, 2, 5, 3, 6}); + private DenseVector dv1 = new DenseVector(new double[]{1, 2}); + private DenseVector dv2 = new DenseVector(new double[]{1, 2, 3}); + private SparseVector spv1 = new SparseVector(2, new int[]{0, 1}, new double[]{1, 2}); + private SparseVector spv2 = new SparseVector(3, new int[]{0, 2}, new double[]{1, 3}); + + @Rule + public ExpectedException thrown = ExpectedException.none(); + + @Test + public void testAsum() throws Exception { + Assert.assertEquals(3.0, BLAS.asum(dv1), TOL); + Assert.assertEquals(3.0, BLAS.asum(spv1), TOL); + } + + @Test + public void testScal() throws Exception { + DenseVector v1 = dv1.clone(); + BLAS.scal(0.5, v1); + Assert.assertArrayEquals(new double[]{0.5, 1.0}, v1.getData(), TOL); + + SparseVector v2 = spv1.clone(); + BLAS.scal(0.5, v2); + Assert.assertArrayEquals(spv1.getIndices(), v2.getIndices()); + Assert.assertArrayEquals(new double[]{0.5, 1.0}, v2.getValues(), TOL); + } + + @Test + public void testDot() throws Exception { + DenseVector v = DenseVector.ones(2); + Assert.assertEquals(3.0, BLAS.dot(dv1, v), TOL); + } + + @Test + public void testAxpy() throws Exception { + DenseVector v = DenseVector.ones(2); + BLAS.axpy(1.0, dv1, v); + Assert.assertArrayEquals(new double[]{2, 3}, v.getData(), TOL); + BLAS.axpy(1.0, spv1, v); + Assert.assertArrayEquals(new double[]{3, 5}, v.getData(), TOL); + BLAS.axpy(1, 1.0, new double[]{1}, 0, v.getData(), 1); + Assert.assertArrayEquals(new double[]{3, 6}, v.getData(), TOL); + } + + private DenseMatrix simpleMM(DenseMatrix m1, DenseMatrix m2) { + DenseMatrix mm = new DenseMatrix(m1.numRows(), m2.numCols()); + for (int i = 0; i < 
m1.numRows(); i++) { + for (int j = 0; j < m2.numCols(); j++) { + double s = 0.; + for (int k = 0; k < m1.numCols(); k++) { + s += m1.get(i, k) * m2.get(k, j); + } + mm.set(i, j, s); + } + } + return mm; + } + + @Test + public void testGemm() throws Exception { + DenseMatrix m32 = DenseMatrix.rand(3, 2); + DenseMatrix m24 = DenseMatrix.rand(2, 4); + DenseMatrix m34 = DenseMatrix.rand(3, 4); + DenseMatrix m42 = DenseMatrix.rand(4, 2); + DenseMatrix m43 = DenseMatrix.rand(4, 3); + + DenseMatrix a34 = DenseMatrix.zeros(3, 4); + BLAS.gemm(1.0, m32, false, m24, false, 0., a34); + Assert.assertArrayEquals(simpleMM(m32, m24).getData(), a34.getData(), TOL); + + BLAS.gemm(1.0, m32, false, m42, true, 0., a34); + Assert.assertArrayEquals(simpleMM(m32, m42.transpose()).getData(), a34.getData(), TOL); + + DenseMatrix a24 = DenseMatrix.zeros(2, 4); + BLAS.gemm(1.0, m32, true, m34, false, 0., a24); + Assert.assertArrayEquals(simpleMM(m32.transpose(), m34).getData(), a24.getData(), TOL); + + BLAS.gemm(1.0, m32, true, m43, true, 0., a24); + Assert.assertArrayEquals(simpleMM(m32.transpose(), m43.transpose()).getData(), a24.getData(), TOL); + } + + @Test + public void testGemmSizeCheck() throws Exception { + thrown.expect(IllegalArgumentException.class); + DenseMatrix m32 = DenseMatrix.rand(3, 2); + DenseMatrix m42 = DenseMatrix.rand(4, 2); + DenseMatrix a34 = DenseMatrix.zeros(3, 4); + BLAS.gemm(1.0, m32, false, m42, false, 0., a34); + } + + @Test + public void testGemmTransposeSizeCheck() throws Exception { + thrown.expect(IllegalArgumentException.class); + DenseMatrix m32 = DenseMatrix.rand(3, 2); + DenseMatrix m42 = DenseMatrix.rand(4, 2); + DenseMatrix a34 = DenseMatrix.zeros(3, 4); + BLAS.gemm(1.0, m32, true, m42, true, 0., a34); + } + + @Test + public void testGemvDense() throws Exception { + DenseVector y1 = DenseVector.ones(2); + BLAS.gemv(2.0, mat, false, dv2, 0., y1); + Assert.assertArrayEquals(new double[]{28, 64}, y1.data, TOL); + + DenseVector y2 = DenseVector.ones(2); 
+ BLAS.gemv(2.0, mat, false, dv2, 1., y2); + Assert.assertArrayEquals(new double[]{29, 65}, y2.data, TOL); + } + + @Test + public void testGemvDenseTranspose() throws Exception { + DenseVector y1 = DenseVector.ones(3); + BLAS.gemv(1.0, mat, true, dv1, 0., y1); + Assert.assertArrayEquals(new double[]{9, 12, 15}, y1.data, TOL); + + DenseVector y2 = DenseVector.ones(3); + BLAS.gemv(1.0, mat, true, dv1, 1., y2); + Assert.assertArrayEquals(new double[]{10, 13, 16}, y2.data, TOL); + } + + @Test + public void testGemvSparse() throws Exception { + DenseVector y1 = DenseVector.ones(2); + BLAS.gemv(2.0, mat, false, spv2, 0., y1); + Assert.assertArrayEquals(new double[]{20, 44}, y1.data, TOL); + + DenseVector y2 = DenseVector.ones(2); + BLAS.gemv(2.0, mat, false, spv2, 1., y2); + Assert.assertArrayEquals(new double[]{21, 45}, y2.data, TOL); + } + + @Test + public void testGemvSparseTranspose() throws Exception { + DenseVector y1 = DenseVector.ones(3); + BLAS.gemv(2.0, mat, true, spv1, 0., y1); + Assert.assertArrayEquals(new double[]{18, 24, 30}, y1.data, TOL); + + DenseVector y2 = DenseVector.ones(3); + BLAS.gemv(2.0, mat, true, spv1, 1., y2); + Assert.assertArrayEquals(new double[]{19, 25, 31}, y2.data, TOL); + } + + @Test + public void testGemvSizeCheck() throws Exception { + thrown.expect(IllegalArgumentException.class); + DenseVector y = DenseVector.ones(2); + BLAS.gemv(2.0, mat, false, dv1, 0., y); + } + + @Test + public void testGemvTransposeSizeCheck() throws Exception { + thrown.expect(IllegalArgumentException.class); + DenseVector y = DenseVector.ones(2); + BLAS.gemv(2.0, mat, true, dv1, 0., y); + } +}