-
Notifications
You must be signed in to change notification settings - Fork 13.8k
[FLINK-14153][ml] Add to BLAS a method that performs DenseMatrix and SparseVector multiplication. #9732
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
[FLINK-14153][ml] Add to BLAS a method that performs DenseMatrix and SparseVector multiplication. #9732
Changes from all commits
6b34d9e
dbf3b11
74c52fb
c247440
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,26 +19,59 @@ | |
|
||
package org.apache.flink.ml.common.linalg; | ||
|
||
import org.apache.flink.util.Preconditions; | ||
|
||
/** | ||
* A utility class that provides BLAS routines over matrices and vectors. | ||
*/ | ||
public class BLAS { | ||
|
||
// For level-1 routines, we use Java implementation. | ||
private static final com.github.fommil.netlib.BLAS NATIVE_BLAS = com.github.fommil.netlib.BLAS.getInstance(); | ||
|
||
// For level-2 and level-3 routines, we use the native BLAS. | ||
// The NATIVE_BLAS instance tries to load BLAS implementations in the order: | ||
// 1) optimized system libraries such as Intel MKL, | ||
// 2) self-contained native builds using the reference Fortran from netlib.org, | ||
// 3) F2J implementation. | ||
// If to use optimized system libraries, it is important to turn of their multi-thread support. | ||
// Otherwise, it will conflict with Flink's executor and leads to performance loss. | ||
private static final com.github.fommil.netlib.BLAS F2J_BLAS = com.github.fommil.netlib.F2jBLAS.getInstance(); | ||
|
||
/** | ||
* \sum_i |x_i| . | ||
*/ | ||
public static double asum(int n, double[] x, int offset) { | ||
return F2J_BLAS.dasum(n, x, offset, 1); | ||
} | ||
|
||
/** | ||
* \sum_i |x_i| . | ||
*/ | ||
public static double asum(DenseVector x) { | ||
return asum(x.data.length, x.data, 0); | ||
} | ||
|
||
/** | ||
* \sum_i |x_i| . | ||
*/ | ||
public static double asum(SparseVector x) { | ||
return asum(x.values.length, x.values, 0); | ||
} | ||
|
||
/** | ||
* y += a * x . | ||
*/ | ||
public static void axpy(double a, double[] x, double[] y) { | ||
assert x.length == y.length : "Array dimension mismatched."; | ||
Preconditions.checkArgument(x.length == y.length, "Array dimension mismatched."); | ||
F2J_BLAS.daxpy(x.length, a, x, 1, y, 1); | ||
} | ||
|
||
/** | ||
* y += a * x . | ||
*/ | ||
public static void axpy(double a, DenseVector x, DenseVector y) { | ||
assert x.data.length == y.data.length : "Vector dimension mismatched."; | ||
Preconditions.checkArgument(x.data.length == y.data.length, "Vector dimension mismatched."); | ||
F2J_BLAS.daxpy(x.data.length, a, x.data, 1, y.data, 1); | ||
} | ||
|
||
|
@@ -55,24 +88,34 @@ public static void axpy(double a, SparseVector x, DenseVector y) { | |
* y += a * x . | ||
*/ | ||
public static void axpy(double a, DenseMatrix x, DenseMatrix y) { | ||
assert x.m == y.m && x.n == y.n : "Matrix dimension mismatched."; | ||
Preconditions.checkArgument(x.m == y.m && x.n == y.n, "Matrix dimension mismatched."); | ||
F2J_BLAS.daxpy(x.data.length, a, x.data, 1, y.data, 1); | ||
} | ||
|
||
/** | ||
* y[yOffset:yOffset+n] += a * x[xOffset:xOffset+n] . | ||
*/ | ||
public static void axpy(int n, double a, double[] x, int xOffset, double[] y, int yOffset) { | ||
F2J_BLAS.daxpy(n, a, x, xOffset, 1, y, yOffset, 1); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. A higher level question: I found a very interesting read from SPARK's mailing list and seems like there are some considerations regarding this. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks, I have clarified in the inline doc that we should use F2J_BLAS for level-1 routines and NATIVE_BLAS for level-2 and level-3 routines. This is also the practices adopted by SparkML. The read from SPARK's mailing list you give here indeed shows the pitfalls when using NATIVE_BLAS. It makes clear that the underlying native BLAS libarary should not use multithreading. Fortunately, the default library uses a BLAS version provided by http://www.netlib.org, which is a single-threaded version. |
||
} | ||
|
||
/** | ||
* x \cdot y . | ||
*/ | ||
public static double dot(double[] x, double[] y) { | ||
assert x.length == y.length : "Array dimension mismatched."; | ||
return F2J_BLAS.ddot(x.length, x, 1, y, 1); | ||
Preconditions.checkArgument(x.length == y.length, "Array dimension mismatched."); | ||
double s = 0.; | ||
for (int i = 0; i < x.length; i++) { | ||
s += x[i] * y[i]; | ||
} | ||
return s; | ||
} | ||
|
||
/** | ||
* x \cdot y . | ||
*/ | ||
public static double dot(DenseVector x, DenseVector y) { | ||
assert x.data.length == y.data.length : "Vector dimension mismatched."; | ||
return F2J_BLAS.ddot(x.data.length, x.data, 1, y.data, 1); | ||
return dot(x.getData(), y.getData()); | ||
} | ||
|
||
/** | ||
|
@@ -109,14 +152,18 @@ public static void scal(double a, DenseMatrix x) { | |
public static void gemm(double alpha, DenseMatrix matA, boolean transA, DenseMatrix matB, boolean transB, | ||
double beta, DenseMatrix matC) { | ||
if (transA) { | ||
assert matA.numCols() == matC.numRows() : "The columns of A does not match the rows of C"; | ||
Preconditions.checkArgument(matA.numCols() == matC.numRows(), | ||
"The columns of A does not match the rows of C"); | ||
} else { | ||
assert matA.numRows() == matC.numRows() : "The rows of A does not match the rows of C"; | ||
Preconditions.checkArgument(matA.numRows() == matC.numRows(), | ||
"The rows of A does not match the rows of C"); | ||
} | ||
if (transB) { | ||
assert matB.numRows() == matC.numCols() : "The rows of B does not match the columns of C"; | ||
Preconditions.checkArgument(matB.numRows() == matC.numCols(), | ||
"The rows of B does not match the columns of C"); | ||
} else { | ||
assert matB.numCols() == matC.numCols() : "The columns of B does not match the columns of C"; | ||
Preconditions.checkArgument(matB.numCols() == matC.numCols(), | ||
"The columns of B does not match the columns of C"); | ||
} | ||
|
||
final int m = matC.numRows(); | ||
|
@@ -131,19 +178,56 @@ public static void gemm(double alpha, DenseMatrix matA, boolean transA, DenseMat | |
} | ||
|
||
/** | ||
* y := alpha * A * x + beta * y . | ||
* Check the compatibility of matrix and vector sizes in <code>gemv</code>. | ||
*/ | ||
public static void gemv(double alpha, DenseMatrix matA, boolean transA, | ||
DenseVector x, double beta, DenseVector y) { | ||
private static void gemvDimensionCheck(DenseMatrix matA, boolean transA, Vector x, Vector y) { | ||
if (transA) { | ||
assert (matA.numCols() == y.size() && matA.numRows() == x.size()) : "Matrix and vector size mismatched."; | ||
Preconditions.checkArgument(matA.numCols() == y.size() && matA.numRows() == x.size(), | ||
"Matrix and vector size mismatched."); | ||
} else { | ||
assert (matA.numRows() == y.size() && matA.numCols() == x.size()) : "Matrix and vector size mismatched."; | ||
Preconditions.checkArgument(matA.numRows() == y.size() && matA.numCols() == x.size(), | ||
"Matrix and vector size mismatched."); | ||
} | ||
} | ||
|
||
/** | ||
* y := alpha * A * x + beta * y . | ||
*/ | ||
public static void gemv(double alpha, DenseMatrix matA, boolean transA, | ||
DenseVector x, double beta, DenseVector y) { | ||
gemvDimensionCheck(matA, transA, x, y); | ||
final int m = matA.numRows(); | ||
final int n = matA.numCols(); | ||
final int lda = matA.numRows(); | ||
final String ta = transA ? "T" : "N"; | ||
NATIVE_BLAS.dgemv(ta, m, n, alpha, matA.getData(), lda, x.getData(), 1, beta, y.getData(), 1); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. any reason why I think at least on a specific level (0,1,2,3 or up) we should probably only use one specific BLAS version unless specific reason comes up (IMO it should be some very strong justifications) FYI: I am not sure whether this is related. Some suggestions on stack shows that there are some performance considerations coming from latest development from the JIT compiler There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We consistently use F2J_BLAS for level-1 routines such as scal/axpy/asum, and use NATIVE_BLAS for level-2/level-3 routines such as gemv, gemm. As for the gemv case here, we use NATIVE_BLAS for the dense case. But for the sparse case, the BLAS library is not directly applicable because it is a library for dense linear algebra. So we implement gemv for SparseVector by hand, using F2J_BLAS to do axpy(level-1 routine) during the course. |
||
} | ||
|
||
/** | ||
* y := alpha * A * x + beta * y . | ||
*/ | ||
public static void gemv(double alpha, DenseMatrix matA, boolean transA, | ||
SparseVector x, double beta, DenseVector y) { | ||
gemvDimensionCheck(matA, transA, x, y); | ||
final int m = matA.numRows(); | ||
final int n = matA.numCols(); | ||
if (transA) { | ||
int start = 0; | ||
for (int i = 0; i < n; i++) { | ||
double s = 0.; | ||
for (int j = 0; j < x.indices.length; j++) { | ||
s += x.values[j] * matA.data[start + x.indices[j]]; | ||
} | ||
y.data[i] = beta * y.data[i] + alpha * s; | ||
start += m; | ||
} | ||
} else { | ||
scal(beta, y); | ||
for (int i = 0; i < x.indices.length; i++) { | ||
int index = x.indices[i]; | ||
double value = alpha * x.values[i]; | ||
F2J_BLAS.daxpy(m, value, matA.data, index * m, 1, y.data, 0, 1); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am not familiar with BLAS internal performance. is this faster than directly coding it up ? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There are two reasons I use BLAS here.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. that makes sense. I am thinking that the assumption is for example BLAS can do some sort of SIMD/MIMD optimization based on the data locality so that it can save register/cache loadings and invalidations.
If there's any performance issue later, we can always avoid duplicate register loading by doing the multiplication the addition and variable assignment at the same line similar to
Comment on lines
+228
to
+229
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. could you add the explanation you added in this PR comment into the actual code comments? I think it helps others to understand this code in the future. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks, we did it. |
||
} | ||
} | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,188 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, | ||
* software distributed under the License is distributed on an | ||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
* KIND, either express or implied. See the License for the | ||
* specific language governing permissions and limitations | ||
* under the License. | ||
*/ | ||
|
||
package org.apache.flink.ml.common.linalg; | ||
|
||
import org.junit.Assert; | ||
import org.junit.Rule; | ||
import org.junit.Test; | ||
import org.junit.rules.ExpectedException; | ||
|
||
/** | ||
* The test cases for {@link BLAS}. | ||
*/ | ||
public class BLASTest { | ||
private static final double TOL = 1.0e-8; | ||
private DenseMatrix mat = new DenseMatrix(2, 3, new double[]{1, 4, 2, 5, 3, 6}); | ||
private DenseVector dv1 = new DenseVector(new double[]{1, 2}); | ||
private DenseVector dv2 = new DenseVector(new double[]{1, 2, 3}); | ||
private SparseVector spv1 = new SparseVector(2, new int[]{0, 1}, new double[]{1, 2}); | ||
private SparseVector spv2 = new SparseVector(3, new int[]{0, 2}, new double[]{1, 3}); | ||
|
||
@Rule | ||
public ExpectedException thrown = ExpectedException.none(); | ||
|
||
@Test | ||
public void testAsum() throws Exception { | ||
Assert.assertEquals(BLAS.asum(dv1), 3.0, TOL); | ||
Assert.assertEquals(BLAS.asum(spv1), 3.0, TOL); | ||
} | ||
|
||
@Test | ||
public void testScal() throws Exception { | ||
DenseVector v1 = dv1.clone(); | ||
BLAS.scal(0.5, v1); | ||
Assert.assertArrayEquals(v1.getData(), new double[]{0.5, 1.0}, TOL); | ||
|
||
SparseVector v2 = spv1.clone(); | ||
BLAS.scal(0.5, v2); | ||
Assert.assertArrayEquals(v2.getIndices(), spv1.getIndices()); | ||
Assert.assertArrayEquals(v2.getValues(), new double[]{0.5, 1.0}, TOL); | ||
} | ||
|
||
@Test | ||
public void testDot() throws Exception { | ||
DenseVector v = DenseVector.ones(2); | ||
Assert.assertEquals(BLAS.dot(dv1, v), 3.0, TOL); | ||
} | ||
|
||
@Test | ||
public void testAxpy() throws Exception { | ||
DenseVector v = DenseVector.ones(2); | ||
BLAS.axpy(1.0, dv1, v); | ||
Assert.assertArrayEquals(v.getData(), new double[]{2, 3}, TOL); | ||
BLAS.axpy(1.0, spv1, v); | ||
Assert.assertArrayEquals(v.getData(), new double[]{3, 5}, TOL); | ||
BLAS.axpy(1, 1.0, new double[]{1}, 0, v.getData(), 1); | ||
Assert.assertArrayEquals(v.getData(), new double[]{3, 6}, TOL); | ||
} | ||
|
||
private DenseMatrix simpleMM(DenseMatrix m1, DenseMatrix m2) { | ||
DenseMatrix mm = new DenseMatrix(m1.numRows(), m2.numCols()); | ||
for (int i = 0; i < m1.numRows(); i++) { | ||
for (int j = 0; j < m2.numCols(); j++) { | ||
double s = 0.; | ||
for (int k = 0; k < m1.numCols(); k++) { | ||
s += m1.get(i, k) * m2.get(k, j); | ||
} | ||
mm.set(i, j, s); | ||
} | ||
} | ||
return mm; | ||
} | ||
|
||
@Test | ||
public void testGemm() throws Exception { | ||
DenseMatrix m32 = DenseMatrix.rand(3, 2); | ||
DenseMatrix m24 = DenseMatrix.rand(2, 4); | ||
DenseMatrix m34 = DenseMatrix.rand(3, 4); | ||
DenseMatrix m42 = DenseMatrix.rand(4, 2); | ||
DenseMatrix m43 = DenseMatrix.rand(4, 3); | ||
|
||
DenseMatrix a34 = DenseMatrix.zeros(3, 4); | ||
BLAS.gemm(1.0, m32, false, m24, false, 0., a34); | ||
Assert.assertArrayEquals(a34.getData(), simpleMM(m32, m24).getData(), TOL); | ||
|
||
BLAS.gemm(1.0, m32, false, m42, true, 0., a34); | ||
Assert.assertArrayEquals(a34.getData(), simpleMM(m32, m42.transpose()).getData(), TOL); | ||
|
||
DenseMatrix a24 = DenseMatrix.zeros(2, 4); | ||
BLAS.gemm(1.0, m32, true, m34, false, 0., a24); | ||
Assert.assertArrayEquals(a24.getData(), simpleMM(m32.transpose(), m34).getData(), TOL); | ||
|
||
BLAS.gemm(1.0, m32, true, m43, true, 0., a24); | ||
Assert.assertArrayEquals(a24.getData(), simpleMM(m32.transpose(), m43.transpose()).getData(), TOL); | ||
} | ||
|
||
@Test | ||
public void testGemmSizeCheck() throws Exception { | ||
thrown.expect(IllegalArgumentException.class); | ||
DenseMatrix m32 = DenseMatrix.rand(3, 2); | ||
DenseMatrix m42 = DenseMatrix.rand(4, 2); | ||
DenseMatrix a34 = DenseMatrix.zeros(3, 4); | ||
BLAS.gemm(1.0, m32, false, m42, false, 0., a34); | ||
} | ||
|
||
@Test | ||
public void testGemmTransposeSizeCheck() throws Exception { | ||
thrown.expect(IllegalArgumentException.class); | ||
DenseMatrix m32 = DenseMatrix.rand(3, 2); | ||
DenseMatrix m42 = DenseMatrix.rand(4, 2); | ||
DenseMatrix a34 = DenseMatrix.zeros(3, 4); | ||
BLAS.gemm(1.0, m32, true, m42, true, 0., a34); | ||
} | ||
|
||
@Test | ||
public void testGemvDense() throws Exception { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why is the dense version not tested until now? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We just added more test cases. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. missing validator exception test cases:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks, we just added the test cases. |
||
DenseVector y1 = DenseVector.ones(2); | ||
BLAS.gemv(2.0, mat, false, dv2, 0., y1); | ||
Assert.assertArrayEquals(new double[]{28, 64}, y1.data, TOL); | ||
|
||
DenseVector y2 = DenseVector.ones(2); | ||
BLAS.gemv(2.0, mat, false, dv2, 1., y2); | ||
Assert.assertArrayEquals(new double[]{29, 65}, y2.data, TOL); | ||
} | ||
|
||
@Test | ||
public void testGemvDenseTranspose() throws Exception { | ||
DenseVector y1 = DenseVector.ones(3); | ||
BLAS.gemv(1.0, mat, true, dv1, 0., y1); | ||
Assert.assertArrayEquals(new double[]{9, 12, 15}, y1.data, TOL); | ||
|
||
DenseVector y2 = DenseVector.ones(3); | ||
BLAS.gemv(1.0, mat, true, dv1, 1., y2); | ||
Assert.assertArrayEquals(new double[]{10, 13, 16}, y2.data, TOL); | ||
} | ||
|
||
@Test | ||
public void testGemvSparse() throws Exception { | ||
DenseVector y1 = DenseVector.ones(2); | ||
BLAS.gemv(2.0, mat, false, spv2, 0., y1); | ||
Assert.assertArrayEquals(new double[]{20, 44}, y1.data, TOL); | ||
|
||
DenseVector y2 = DenseVector.ones(2); | ||
BLAS.gemv(2.0, mat, false, spv2, 1., y2); | ||
Assert.assertArrayEquals(new double[]{21, 45}, y2.data, TOL); | ||
} | ||
|
||
@Test | ||
public void testGemvSparseTranspose() throws Exception { | ||
DenseVector y1 = DenseVector.ones(3); | ||
BLAS.gemv(2.0, mat, true, spv1, 0., y1); | ||
Assert.assertArrayEquals(new double[]{18, 24, 30}, y1.data, TOL); | ||
|
||
DenseVector y2 = DenseVector.ones(3); | ||
BLAS.gemv(2.0, mat, true, spv1, 1., y2); | ||
Assert.assertArrayEquals(new double[]{19, 25, 31}, y2.data, TOL); | ||
} | ||
|
||
@Test | ||
public void testGemvSizeCheck() throws Exception { | ||
thrown.expect(IllegalArgumentException.class); | ||
DenseVector y = DenseVector.ones(2); | ||
BLAS.gemv(2.0, mat, false, dv1, 0., y); | ||
} | ||
|
||
@Test | ||
public void testGemvTransposeSizeCheck() throws Exception { | ||
thrown.expect(IllegalArgumentException.class); | ||
DenseVector y = DenseVector.ones(2); | ||
BLAS.gemv(2.0, mat, true, dv1, 0., y); | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: use
/* */
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, please refine it when merging. Thanks.