Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,26 +19,59 @@

package org.apache.flink.ml.common.linalg;

import org.apache.flink.util.Preconditions;

/**
* A utility class that provides BLAS routines over matrices and vectors.
*/
public class BLAS {

// For level-1 routines, we use Java implementation.
private static final com.github.fommil.netlib.BLAS NATIVE_BLAS = com.github.fommil.netlib.BLAS.getInstance();

// For level-2 and level-3 routines, we use the native BLAS.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: use /* */

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, please refine it when merging. Thanks.

// The NATIVE_BLAS instance tries to load BLAS implementations in the order:
// 1) optimized system libraries such as Intel MKL,
// 2) self-contained native builds using the reference Fortran from netlib.org,
// 3) F2J implementation.
// If to use optimized system libraries, it is important to turn of their multi-thread support.
// Otherwise, it will conflict with Flink's executor and leads to performance loss.
private static final com.github.fommil.netlib.BLAS F2J_BLAS = com.github.fommil.netlib.F2jBLAS.getInstance();

/**
 * \sum_i |x_i| : sum of the absolute values of the n entries of x starting
 * at position offset. Delegates to the F2J (pure-Java) BLAS, which is the
 * preferred implementation for level-1 routines.
 */
public static double asum(int n, double[] x, int offset) {
	final double sum = F2J_BLAS.dasum(n, x, offset, 1);
	return sum;
}

/**
 * \sum_i |x_i| over all entries of the dense vector x.
 */
public static double asum(DenseVector x) {
	final double[] data = x.data;
	return asum(data.length, data, 0);
}

/**
 * \sum_i |x_i| over the stored (non-zero) entries of the sparse vector x.
 */
public static double asum(SparseVector x) {
	final double[] values = x.values;
	return asum(values.length, values, 0);
}

/**
 * y += a * x .
 *
 * @throws IllegalArgumentException if x and y have different lengths.
 */
public static void axpy(double a, double[] x, double[] y) {
	// Diff residue removed: the old `assert` duplicate of this check is dropped;
	// Preconditions fires regardless of the -ea flag.
	Preconditions.checkArgument(x.length == y.length, "Array dimension mismatched.");
	F2J_BLAS.daxpy(x.length, a, x, 1, y, 1);
}

/**
 * y += a * x .
 *
 * @throws IllegalArgumentException if x and y have different sizes.
 */
public static void axpy(double a, DenseVector x, DenseVector y) {
	// Diff residue removed: the superseded `assert` line is dropped in favor of Preconditions.
	Preconditions.checkArgument(x.data.length == y.data.length, "Vector dimension mismatched.");
	F2J_BLAS.daxpy(x.data.length, a, x.data, 1, y.data, 1);
}

Expand All @@ -55,24 +88,34 @@ public static void axpy(double a, SparseVector x, DenseVector y) {
* y += a * x .
*/
public static void axpy(double a, DenseMatrix x, DenseMatrix y) {
assert x.m == y.m && x.n == y.n : "Matrix dimension mismatched.";
Preconditions.checkArgument(x.m == y.m && x.n == y.n, "Matrix dimension mismatched.");
F2J_BLAS.daxpy(x.data.length, a, x.data, 1, y.data, 1);
}

/**
* y[yOffset:yOffset+n] += a * x[xOffset:xOffset+n] .
*/
public static void axpy(int n, double a, double[] x, int xOffset, double[] y, int yOffset) {
F2J_BLAS.daxpy(n, a, x, xOffset, 1, y, yOffset, 1);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A higher level question:
Maybe it is good to clarify when to use F2J_BLAS vs NATIVE_BLAS.

I found a very interesting read from SPARK's mailing list and seems like there are some considerations regarding this.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, I have clarified in the inline doc that we should use F2J_BLAS for level-1 routines and NATIVE_BLAS for level-2 and level-3 routines. This is also the practices adopted by SparkML.

The read from SPARK's mailing list you give here indeed shows the pitfalls when using NATIVE_BLAS. It makes clear that the underlying native BLAS library should not use multithreading. Fortunately, the default library uses a BLAS version provided by http://www.netlib.org, which is a single-threaded version.

}

/**
 * x \cdot y .
 *
 * <p>Implemented as a plain Java loop: for a simple dot product the JIT-compiled
 * loop avoids the call overhead of the BLAS wrapper.
 *
 * @throws IllegalArgumentException if x and y have different lengths.
 */
public static double dot(double[] x, double[] y) {
	// Diff residue removed: the superseded `assert` and `F2J_BLAS.ddot` lines are dropped;
	// the new loop implementation is kept.
	Preconditions.checkArgument(x.length == y.length, "Array dimension mismatched.");
	double s = 0.;
	for (int i = 0; i < x.length; i++) {
		s += x[i] * y[i];
	}
	return s;
}

/**
 * x \cdot y .
 *
 * <p>Delegates to {@link #dot(double[], double[])}, which also performs the
 * dimension check.
 */
public static double dot(DenseVector x, DenseVector y) {
	// Diff residue removed: the superseded `assert` and direct `ddot` lines are dropped;
	// the delegate form introduced by the PR is kept.
	return dot(x.getData(), y.getData());
}

/**
Expand Down Expand Up @@ -109,14 +152,18 @@ public static void scal(double a, DenseMatrix x) {
public static void gemm(double alpha, DenseMatrix matA, boolean transA, DenseMatrix matB, boolean transB,
double beta, DenseMatrix matC) {
if (transA) {
assert matA.numCols() == matC.numRows() : "The columns of A does not match the rows of C";
Preconditions.checkArgument(matA.numCols() == matC.numRows(),
"The columns of A does not match the rows of C");
} else {
assert matA.numRows() == matC.numRows() : "The rows of A does not match the rows of C";
Preconditions.checkArgument(matA.numRows() == matC.numRows(),
"The rows of A does not match the rows of C");
}
if (transB) {
assert matB.numRows() == matC.numCols() : "The rows of B does not match the columns of C";
Preconditions.checkArgument(matB.numRows() == matC.numCols(),
"The rows of B does not match the columns of C");
} else {
assert matB.numCols() == matC.numCols() : "The columns of B does not match the columns of C";
Preconditions.checkArgument(matB.numCols() == matC.numCols(),
"The columns of B does not match the columns of C");
}

final int m = matC.numRows();
Expand All @@ -131,19 +178,56 @@ public static void gemm(double alpha, DenseMatrix matA, boolean transA, DenseMat
}

/**
* y := alpha * A * x + beta * y .
* Check the compatibility of matrix and vector sizes in <code>gemv</code>.
*/
public static void gemv(double alpha, DenseMatrix matA, boolean transA,
DenseVector x, double beta, DenseVector y) {
private static void gemvDimensionCheck(DenseMatrix matA, boolean transA, Vector x, Vector y) {
if (transA) {
assert (matA.numCols() == y.size() && matA.numRows() == x.size()) : "Matrix and vector size mismatched.";
Preconditions.checkArgument(matA.numCols() == y.size() && matA.numRows() == x.size(),
"Matrix and vector size mismatched.");
} else {
assert (matA.numRows() == y.size() && matA.numCols() == x.size()) : "Matrix and vector size mismatched.";
Preconditions.checkArgument(matA.numRows() == y.size() && matA.numCols() == x.size(),
"Matrix and vector size mismatched.");
}
}

/**
* y := alpha * A * x + beta * y .
*/
public static void gemv(double alpha, DenseMatrix matA, boolean transA,
DenseVector x, double beta, DenseVector y) {
gemvDimensionCheck(matA, transA, x, y);
final int m = matA.numRows();
final int n = matA.numCols();
final int lda = matA.numRows();
final String ta = transA ? "T" : "N";
NATIVE_BLAS.dgemv(ta, m, n, alpha, matA.getData(), lda, x.getData(), 1, beta, y.getData(), 1);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

any reason why DenseVector uses NATIVE_BLAS while the SparseVector uses F2J_BLAS

I think at least on a specific level (0,1,2,3 or up) we should probably only use one specific BLAS version unless specific reason comes up (IMO it should be some very strong justifications)

FYI: I am not sure whether this is related. Some suggestions on stack shows that there are some performance considerations coming from latest development from the JIT compiler

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We consistently use F2J_BLAS for level-1 routines such as scal/axpy/asum, and use NATIVE_BLAS for level-2/level-3 routines such as gemv, gemm. As for the gemv case here, we use NATIVE_BLAS for the dense case. But for the sparse case, the BLAS library is not directly applicable because it is a library for dense linear algebra. So we implement gemv for SparseVector by hand, using F2J_BLAS to do axpy(level-1 routine) during the course.

}

/**
* y := alpha * A * x + beta * y .
*/
public static void gemv(double alpha, DenseMatrix matA, boolean transA,
SparseVector x, double beta, DenseVector y) {
gemvDimensionCheck(matA, transA, x, y);
final int m = matA.numRows();
final int n = matA.numCols();
if (transA) {
int start = 0;
for (int i = 0; i < n; i++) {
double s = 0.;
for (int j = 0; j < x.indices.length; j++) {
s += x.values[j] * matA.data[start + x.indices[j]];
}
y.data[i] = beta * y.data[i] + alpha * s;
start += m;
}
} else {
scal(beta, y);
for (int i = 0; i < x.indices.length; i++) {
int index = x.indices[i];
double value = alpha * x.values[i];
F2J_BLAS.daxpy(m, value, matA.data, index * m, 1, y.data, 0, 1);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not familiar with BLAS internal performance. is this faster than directly coding it up ?
the reason why I am asking is that: there's two step involved. (scal and daxpy). may have duplicate mem access ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are two reasons I use BLAS here.

  1. BLAS is a mature linear algebra library, which pays a lot of attention to data locality, thus it is more cache friendly than a naive implementation. Usually we gain a lot of performance improvement on level-2/level-3 BLAS routines through calling native (JNI) BLAS, while F2J BLAS is better in level-1 BLAS routines.
  2. In the case here, y is first scaled by b, then each columns of Ax is added to y. It is inevitable that y would be visited more than once.

Copy link
Contributor

@walterddr walterddr Oct 18, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that makes sense. I am thinking that the assumption is for example BLAS can do some sort of SIMD/MIMD optimization based on the data locality so that it can save register/cache loadings and invalidations.
It is good to call out the rationale here by having some sort of inline comment:

// relying on the native implementation of BLAS  for performance.

If there's any performance issue later, we can always avoid duplicate register loading by doing the multiplication the addition and variable assignment at the same line similar to

y.data[i] = beta * y.data[i] + alpha * s;

Comment on lines +228 to +229
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could you add the explanation you added in this PR comment into the actual code comments? I think it helps others to understand this code in the future.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, we did it.

}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.flink.ml.common.linalg;

import org.junit.Assert;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;

/**
* The test cases for {@link BLAS}.
*/
public class BLASTest {
private static final double TOL = 1.0e-8;
private DenseMatrix mat = new DenseMatrix(2, 3, new double[]{1, 4, 2, 5, 3, 6});
private DenseVector dv1 = new DenseVector(new double[]{1, 2});
private DenseVector dv2 = new DenseVector(new double[]{1, 2, 3});
private SparseVector spv1 = new SparseVector(2, new int[]{0, 1}, new double[]{1, 2});
private SparseVector spv2 = new SparseVector(3, new int[]{0, 2}, new double[]{1, 3});

@Rule
public ExpectedException thrown = ExpectedException.none();

@Test
public void testAsum() throws Exception {
Assert.assertEquals(BLAS.asum(dv1), 3.0, TOL);
Assert.assertEquals(BLAS.asum(spv1), 3.0, TOL);
}

@Test
public void testScal() throws Exception {
DenseVector v1 = dv1.clone();
BLAS.scal(0.5, v1);
Assert.assertArrayEquals(v1.getData(), new double[]{0.5, 1.0}, TOL);

SparseVector v2 = spv1.clone();
BLAS.scal(0.5, v2);
Assert.assertArrayEquals(v2.getIndices(), spv1.getIndices());
Assert.assertArrayEquals(v2.getValues(), new double[]{0.5, 1.0}, TOL);
}

@Test
public void testDot() throws Exception {
DenseVector v = DenseVector.ones(2);
Assert.assertEquals(BLAS.dot(dv1, v), 3.0, TOL);
}

@Test
public void testAxpy() throws Exception {
DenseVector v = DenseVector.ones(2);
BLAS.axpy(1.0, dv1, v);
Assert.assertArrayEquals(v.getData(), new double[]{2, 3}, TOL);
BLAS.axpy(1.0, spv1, v);
Assert.assertArrayEquals(v.getData(), new double[]{3, 5}, TOL);
BLAS.axpy(1, 1.0, new double[]{1}, 0, v.getData(), 1);
Assert.assertArrayEquals(v.getData(), new double[]{3, 6}, TOL);
}

private DenseMatrix simpleMM(DenseMatrix m1, DenseMatrix m2) {
DenseMatrix mm = new DenseMatrix(m1.numRows(), m2.numCols());
for (int i = 0; i < m1.numRows(); i++) {
for (int j = 0; j < m2.numCols(); j++) {
double s = 0.;
for (int k = 0; k < m1.numCols(); k++) {
s += m1.get(i, k) * m2.get(k, j);
}
mm.set(i, j, s);
}
}
return mm;
}

@Test
public void testGemm() throws Exception {
DenseMatrix m32 = DenseMatrix.rand(3, 2);
DenseMatrix m24 = DenseMatrix.rand(2, 4);
DenseMatrix m34 = DenseMatrix.rand(3, 4);
DenseMatrix m42 = DenseMatrix.rand(4, 2);
DenseMatrix m43 = DenseMatrix.rand(4, 3);

DenseMatrix a34 = DenseMatrix.zeros(3, 4);
BLAS.gemm(1.0, m32, false, m24, false, 0., a34);
Assert.assertArrayEquals(a34.getData(), simpleMM(m32, m24).getData(), TOL);

BLAS.gemm(1.0, m32, false, m42, true, 0., a34);
Assert.assertArrayEquals(a34.getData(), simpleMM(m32, m42.transpose()).getData(), TOL);

DenseMatrix a24 = DenseMatrix.zeros(2, 4);
BLAS.gemm(1.0, m32, true, m34, false, 0., a24);
Assert.assertArrayEquals(a24.getData(), simpleMM(m32.transpose(), m34).getData(), TOL);

BLAS.gemm(1.0, m32, true, m43, true, 0., a24);
Assert.assertArrayEquals(a24.getData(), simpleMM(m32.transpose(), m43.transpose()).getData(), TOL);
}

@Test
public void testGemmSizeCheck() throws Exception {
thrown.expect(IllegalArgumentException.class);
DenseMatrix m32 = DenseMatrix.rand(3, 2);
DenseMatrix m42 = DenseMatrix.rand(4, 2);
DenseMatrix a34 = DenseMatrix.zeros(3, 4);
BLAS.gemm(1.0, m32, false, m42, false, 0., a34);
}

@Test
public void testGemmTransposeSizeCheck() throws Exception {
thrown.expect(IllegalArgumentException.class);
DenseMatrix m32 = DenseMatrix.rand(3, 2);
DenseMatrix m42 = DenseMatrix.rand(4, 2);
DenseMatrix a34 = DenseMatrix.zeros(3, 4);
BLAS.gemm(1.0, m32, true, m42, true, 0., a34);
}

@Test
public void testGemvDense() throws Exception {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is the dense version not tested until now?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We just added more test cases.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

missing validator exception test cases:

  1. invalid dimension
  2. invalid dimension after transpose

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, we just added the test cases.

DenseVector y1 = DenseVector.ones(2);
BLAS.gemv(2.0, mat, false, dv2, 0., y1);
Assert.assertArrayEquals(new double[]{28, 64}, y1.data, TOL);

DenseVector y2 = DenseVector.ones(2);
BLAS.gemv(2.0, mat, false, dv2, 1., y2);
Assert.assertArrayEquals(new double[]{29, 65}, y2.data, TOL);
}

@Test
public void testGemvDenseTranspose() throws Exception {
DenseVector y1 = DenseVector.ones(3);
BLAS.gemv(1.0, mat, true, dv1, 0., y1);
Assert.assertArrayEquals(new double[]{9, 12, 15}, y1.data, TOL);

DenseVector y2 = DenseVector.ones(3);
BLAS.gemv(1.0, mat, true, dv1, 1., y2);
Assert.assertArrayEquals(new double[]{10, 13, 16}, y2.data, TOL);
}

@Test
public void testGemvSparse() throws Exception {
DenseVector y1 = DenseVector.ones(2);
BLAS.gemv(2.0, mat, false, spv2, 0., y1);
Assert.assertArrayEquals(new double[]{20, 44}, y1.data, TOL);

DenseVector y2 = DenseVector.ones(2);
BLAS.gemv(2.0, mat, false, spv2, 1., y2);
Assert.assertArrayEquals(new double[]{21, 45}, y2.data, TOL);
}

@Test
public void testGemvSparseTranspose() throws Exception {
DenseVector y1 = DenseVector.ones(3);
BLAS.gemv(2.0, mat, true, spv1, 0., y1);
Assert.assertArrayEquals(new double[]{18, 24, 30}, y1.data, TOL);

DenseVector y2 = DenseVector.ones(3);
BLAS.gemv(2.0, mat, true, spv1, 1., y2);
Assert.assertArrayEquals(new double[]{19, 25, 31}, y2.data, TOL);
}

@Test
public void testGemvSizeCheck() throws Exception {
thrown.expect(IllegalArgumentException.class);
DenseVector y = DenseVector.ones(2);
BLAS.gemv(2.0, mat, false, dv1, 0., y);
}

@Test
public void testGemvTransposeSizeCheck() throws Exception {
thrown.expect(IllegalArgumentException.class);
DenseVector y = DenseVector.ones(2);
BLAS.gemv(2.0, mat, true, dv1, 0., y);
}
}