
[SPARK-6685][MLLIB]Use DSYRK to compute AtA in ALS #13891

Closed
wants to merge 12 commits into from

Conversation

hqzizania
Contributor

What changes were proposed in this pull request?

jira: https://issues.apache.org/jira/browse/SPARK-6685
This switches from DSPR to DSYRK so that native BLAS can accelerate the computation of AtA in ALS. A buffer is allocated to stack vectors so that a Level 3 BLAS routine can be used.
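For context (an editor's sketch, not part of the PR): accumulating AtA one observation at a time with rank-1 updates (what dspr does) is mathematically equivalent to stacking the observations into a matrix and performing a single rank-n update (what dsyrk does). A minimal pure-Python illustration of that equivalence, using plain lists instead of BLAS:

```python
# Sketch (not from the PR): rank-1 updates vs. one stacked rank-n update.
# Both accumulate the upper triangle of A^T * A for k-dimensional rows.

def ata_rank1(rows, k):
    """Accumulate A^T*A one row at a time (what dspr does per observation)."""
    ata = [[0.0] * k for _ in range(k)]
    for a in rows:
        for i in range(k):
            for j in range(i, k):
                ata[i][j] += a[i] * a[j]
    return ata

def ata_stacked(rows, k):
    """One rank-n update over the whole stack (what dsyrk does)."""
    ata = [[0.0] * k for _ in range(k)]
    for i in range(k):
        for j in range(i, k):
            ata[i][j] = sum(a[i] * a[j] for a in rows)
    return ata

rows = [[1.0, 2.0], [3.0, 4.0]]
print(ata_rank1(rows, 2))    # [[10.0, 14.0], [0.0, 20.0]]
print(ata_stacked(rows, 2))  # identical result
```

The performance question debated in this thread is precisely whether doing the stacked version in one native dsyrk call beats many small dspr calls.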

How was this patch tested?

Java and Scala unit tests.

@srowen
Member

srowen commented Jun 24, 2016

Is this actually faster though?

@hqzizania
Contributor Author

hqzizania commented Jun 24, 2016

This is a prototype. Admittedly, whether it is actually faster is the critical question!
I have done a simple test; the effect depends on the "number of users for each product", which equals the range of i in each loop. If a considerable number of vectors can be stacked, native BLAS is more than 3x faster than the original native BLAS path, but it is still slower than the original F2JBLAS (maybe my test data is not big enough). In my test of the original ALS, native BLAS is much slower than F2JBLAS.
Anyway, I don't know whether the "number of users for each product" can be very large in a real case.

@SparkQA

SparkQA commented Jun 24, 2016

Test build #61171 has finished for PR 13891 at commit 8fb4a82.

  • This patch fails Spark unit tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@mengxr
Contributor

mengxr commented Jun 24, 2016

@hqzizania This could be tested with benchmarks without ALS. I guess even with a correct implementation, we need a large rank to see improvement.

@hqzizania
Contributor Author

@mengxr Do you mean to test only add() and addStack(), without ALS?

@SparkQA

SparkQA commented Jun 24, 2016

Test build #61184 has finished for PR 13891 at commit 7e3d238.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@hqzizania
Contributor Author

Code for testing:

  import java.{util => ju}
  import scala.collection.mutable
  import scala.util.Random
  import com.github.fommil.netlib.BLAS.{getInstance => blas}

  // `rank` is the process rank, used only for logging; `a` sets the problem
  // size: m = 2^a observations, n = 2^(a-1) features, stacks of 2^(a-2) rows.
  def run(rank: Int, a: Int): Unit = {
    println(s"blas.getClass = ${blas.getClass.toString} on process $rank")

    val m = 1 << a
    val n = 1 << (a - 1)
    val stack = 1 << (a - 2)
    val matrix = Array.fill(m)(Array.fill(n)(Random.nextFloat()))
    val bVector = Array.fill(m)(Random.nextDouble())
    val ls = new NormalEquation(n)

    for (u <- 0 to 3) {
      // Baseline: one dspr rank-1 update per observation.
      ls.reset()
      val t0 = System.nanoTime()
      for (i <- 0 until m) {
        ls.add(matrix(i), bVector(i))
      }
      val t1 = System.nanoTime()
      println("nostack elapsed time: " + (t1 - t0) / 1000000 + s"ms on process $rank")

      // Stacked: copy `stack` observations into a buffer, then one dsyrk call.
      ls.reset()
      val t2 = System.nanoTime()
      var i = 0
      while (i < m) {
        val matrixBuffer = mutable.ArrayBuilder.make[Double]
        val bBuffer = mutable.ArrayBuilder.make[Double]
        for (s <- 0 until stack) {
          for (j <- 0 until n) {
            matrixBuffer += matrix(i + s)(j)
          }
          bBuffer += bVector(i + s)
        }
        i += stack
        ls.addStack(matrixBuffer.result(), bBuffer.result(), stack)
      }
      val t3 = System.nanoTime()
      println("stack elapsed time: " + (t3 - t2) / 1000000 + s"ms on process $rank")
    }
  }

  class NormalEquation(val k: Int) extends Serializable {

    /** Number of entries in the upper triangular part of a k-by-k matrix. */
    val triK = k * (k + 1) / 2
    /** A^T^ * A */
    val ata = new Array[Double](triK)
    /** A^T^ * b */
    val atb = new Array[Double](k)

    private val da = new Array[Double](k)
    private val ata2 = new Array[Double](k * k)
    private val upper = "U"

    private def copyToDouble(a: Array[Float]): Unit = {
      var i = 0
      while (i < k) {
        da(i) = a(i)
        i += 1
      }
    }

    /** Folds the full k-by-k buffer `ata2` into the packed upper triangle `ata`, then clears it. */
    private def copyToTri(): Unit = {
      var ii = 0
      for (i <- 0 until k) {
        for (j <- 0 to i) {
          ata(ii) += ata2(i * k + j)
          ata2(i * k + j) = 0.0
          ii += 1
        }
      }
    }

    /** Adds an observation. */
    def add(a: Array[Float], b: Double, c: Double = 1.0): this.type = {
      require(c >= 0.0)
      require(a.length == k)
      copyToDouble(a)
      blas.dspr(upper, k, c, da, 1, ata)
      if (b != 0.0) {
        blas.daxpy(k, c * b, da, 1, atb, 1)
      }
      this
    }

    /** Adds a stack of observations. */
    def addStack(a: Array[Double], b: Array[Double], n: Int): this.type = {
      require(a.length == n * k)
      blas.dsyrk(upper, "N", k, n, 1.0, a, k, 1.0, ata2, k)
      copyToTri()
      blas.dgemv("N", k, n, 1.0, a, k, b, 1, 1.0, atb, 1)
      this
    }

    /** Merges another normal equation object. */
    def merge(other: NormalEquation): this.type = {
      require(other.k == k)
      blas.daxpy(ata.length, 1.0, other.ata, 1, ata, 1)
      blas.daxpy(atb.length, 1.0, other.atb, 1, atb, 1)
      this
    }

    /** Resets everything to zero, which should be called after each solve. */
    def reset(): Unit = {
      ju.Arrays.fill(ata, 0.0)
      ju.Arrays.fill(ata2, 0.0)
      ju.Arrays.fill(atb, 0.0)
    }
  }
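An editor's note on the copyToTri step above: dspr writes into packed upper-triangular storage (triK = k*(k+1)/2 entries), while dsyrk writes into a full k-by-k column-major buffer, so the results have to be folded back together. A standalone sketch of that index mapping (in Python for readability, not part of the patch):

```python
def packed_upper_index(j, i):
    """Packed (column-major, upper) index of element (row j, col i), j <= i."""
    return i * (i + 1) // 2 + j

def copy_to_tri(ata2, k):
    """Fold a full k-by-k column-major buffer into packed upper-triangular form,
    mirroring what copyToTri does with `ata2` and `ata`."""
    tri = [0.0] * (k * (k + 1) // 2)
    ii = 0
    for i in range(k):          # column
        for j in range(i + 1):  # row, upper triangle only
            tri[ii] = ata2[i * k + j]
            ii += 1
    return tri

# For k = 2, the column-major buffer [a00, a10, a01, a11] packs to [a00, a01, a11].
print(copy_to_tri([1.0, 9.0, 2.0, 4.0], 2))  # [1.0, 2.0, 4.0]
```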

results:

[two benchmark result images, not reproduced in this transcript]

@hqzizania
Contributor Author

hqzizania commented Jun 28, 2016

@mengxr This is a simple imitation of the loop that computeFactors[ID]() in ALS uses. It runs on a bare-metal node with 4 cores. All tests use all cores via RDD multi-partitioning. Does this make sense for this patch?

@@ -1296,6 +1316,9 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
}
var i = srcPtrs(j)
var numExplicits = 0
val doStack = if (srcPtrs(j + 1) - srcPtrs(j) > 10) true else false
Member


if (...) true else false is redundant

@srowen
Member

srowen commented Jun 28, 2016

It may not matter much, but your test code is a little different from the patch, e.g. for copyToTri().
It's optional, but a few comments explaining what addStack does might help readers.

set stack size > 128 and comments added
@hqzizania
Contributor Author

hqzizania commented Jun 28, 2016

@srowen Oops! The copyToTri() is indeed a little different in the test code. I changed it to:

    private def copyToTri(): Unit = {
      var i = 0
      var j = 0
      var ii = 0
      while (i < k) {
        val temp = i * k
        j = 0
        while (j <= i) {
          ata(ii) += ata2(temp + j)
          j += 1
          ii += 1
        }
        i += 1
      }
    }

And ls.reset() is added into the doStack loop as follows (likewise for the other loop):

      val t2 = System.nanoTime()
      var i = 0
      while (i < m) {
        ls.reset()
        val matrixBuffer = mutable.ArrayBuilder.make[Double]
        val bBuffer = mutable.ArrayBuilder.make[Double]
        for (s <- 0 until stack) {
          for (j <- 0 until n) {
            matrixBuffer += matrix(i + s)(j)
          }
          bBuffer += bVector(i + s)
        }
        i += stack
        ls.addStack(matrixBuffer.result(), bBuffer.result(), stack)
      }
      val t3 = System.nanoTime()

The results are basically the same. Actually, ls.reset() is also called in each inner loop that computeFactors in ALS uses.

@hqzizania
Contributor Author

hqzizania commented Jun 28, 2016

I set the threshold size for stacking to 128 based on some more test results, though 128 may be a conservative choice. However, this change will bypass the existing unit tests, as doStack is then always false. This patch passed all unit tests on my local machine when the threshold was set to 10 (so doStack is sometimes true). Thus, I added a unit test for it.

@SparkQA

SparkQA commented Jun 28, 2016

Test build #61381 has finished for PR 13891 at commit 3607bdc.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA

SparkQA commented Jun 29, 2016

Test build #61465 has finished for PR 13891 at commit 56194eb.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@hqzizania
Contributor Author

cc @mengxr @yanboliang Is this patch okay?

@yanboliang
Contributor

yanboliang commented Sep 21, 2016

@hqzizania Could you share the regression performance test result? I have time to get this in if it's ready. Thanks.

@hqzizania
Contributor Author

@yanboliang Sorry, I'm on a business trip and will upload the test results ASAP.

@hqzizania
Contributor Author

hqzizania commented Oct 21, 2016

@yanboliang So sorry for my late response.

Some regression performance test results:
Datasets: generated with genExplicitTestData, with numUsers = 20000 and numItems = 2000
Single-node cluster: 16 physical cores, 100 GB memory
ALS: numUserBlocks = 30, numItemBlocks = 30
It runs computeFactors with 30 partitions in parallel.

ALS: rank = 1024
Computing time and used memory for computeFactors:
[benchmark chart]

ALS: rank = 129
Computing time for computeFactors:
[benchmark chart]

ALS: rank = 512
Computing time for computeFactors:
[benchmark chart]

The results show that this patch gives a large speedup when the rank is large, but we should reset the two threshold values for "doStack" to 1024 rather than 128.
However, a consequent problem is that the unit test for this patch will take a long time, as the rank must be larger than 1024. Should I just remove the unit test?

@mengxr
Contributor

mengxr commented Oct 21, 2016

@hqzizania Thanks for the performance tests! This matches my guess. I'm not sure how often people use a rank greater than 1000 or even 250. But I think it is good to use BLAS level-3 routines. We can make the threshold a param, set a small threshold, and test both code paths.
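An editor's sketch of the threshold-as-a-param idea (hypothetical names, not the PR's actual code): dispatch between the level-2 and level-3 paths based on how many observations a block has, so tests can force either path by shrinking the threshold.

```python
def solve_block(observations, threshold, add_one, add_stack):
    """Dispatch to the stacked (level-3) path only when enough rows accumulate.

    `add_one` / `add_stack` stand in for NormalEquation.add / addStack;
    `threshold` stands in for the new ALS param (128 or 1024 in the tests above).
    """
    if len(observations) > threshold:
        add_stack(observations)   # one dsyrk-style call over the whole stack
        return "stacked"
    for a in observations:
        add_one(a)                # dspr-style rank-1 update per observation
    return "one-by-one"
```

Setting the param low in unit tests exercises the stacked path even on tiny data, which is what makes both code paths testable without a rank above 1024.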

@hqzizania
Contributor Author

@mengxr I see. I will add a param for it. :)

@SparkQA

SparkQA commented Oct 21, 2016

Test build #67323 has finished for PR 13891 at commit dc4f4ba.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA

SparkQA commented Oct 24, 2016

Test build #67422 has finished for PR 13891 at commit d29fd67.

  • This patch fails Scala style tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA

SparkQA commented Oct 24, 2016

Test build #67423 has finished for PR 13891 at commit 294164d.

  • This patch fails Scala style tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA

SparkQA commented Oct 24, 2016

Test build #67424 has finished for PR 13891 at commit 513e791.

  • This patch fails MiMa tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA

SparkQA commented Oct 24, 2016

Test build #67432 has finished for PR 13891 at commit 1f3ff96.

  • This patch fails MiMa tests.
  • This patch does not merge cleanly.
  • This patch adds no public classes.

@SparkQA

SparkQA commented Oct 24, 2016

Test build #67434 has finished for PR 13891 at commit 1081e64.

  • This patch fails MiMa tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA

SparkQA commented Oct 24, 2016

Test build #67435 has finished for PR 13891 at commit a6b5a16.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@hqzizania
Contributor Author

@mengxr @srowen @yanboliang A threshold param is added for unit tests. Does it look okay now?

@@ -864,6 +864,9 @@ object MimaExcludes {
// [SPARK-12221] Add CPU time to metrics
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.TaskMetrics.this"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.TaskMetricDistributions.this")
) ++ Seq(
// SPARK-6685
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.recommendation.ALS.train")
Member


Does this actually remove a method? That shouldn't be necessary, I imagine.

Contributor Author


Actually, I don't know exactly how to write the MiMa exclude. It may not be a proper solution to the MiMa failure, which may be caused by the modification to def train.

Member


I see. It's a developer API, so more reasonable to change, though it's still ideal to not change these APIs unless necessary. Try putting the new param at the end? I don't think that changes the situation but makes it at least source-compatible with any current invocations.

@SparkQA

SparkQA commented Oct 25, 2016

Test build #67523 has finished for PR 13891 at commit 2077457.

  • This patch fails MiMa tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@hqzizania
Contributor Author

@srowen It seems the MiMa test still fails when the new param is put at the end of the train method. :(

@srowen
Member

srowen commented Oct 26, 2016

@hqzizania OK thanks for checking that. That may be an issue for this change.

@hqzizania
Contributor Author

@srowen I am really not familiar with MiMa, so what should I do now? Or should I just go back to the previous commit and create a JIRA for this issue?

@HyukjinKwon
Member

@hqzizania If you check the log, there are some guides on how to do this. Should we maybe rebase this and check the logs?

@HyukjinKwon
Member

I propose to close this, assuming it is inactive.
