diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala index 56caeac05c0c1..43f48befd014f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala @@ -27,6 +27,7 @@ import breeze.numerics.{sqrt => brzSqrt} import org.apache.spark.annotation.Since import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.MAX_RESULT_SIZE import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.stat.{MultivariateOnlineSummarizer, MultivariateStatisticalSummary} import org.apache.spark.rdd.RDD @@ -117,6 +118,7 @@ class RowMatrix @Since("1.0.0") ( // Computes n*(n+1)/2, avoiding overflow in the multiplication. // This succeeds when n <= 65535, which is checked above val nt = if (n % 2 == 0) ((n / 2) * (n + 1)) else (n * ((n + 1) / 2)) + val gramianSizeInBytes = nt * 8L // Compute the upper triangular part of the gram matrix. val GU = rows.treeAggregate(null.asInstanceOf[BDV[Double]])( @@ -136,7 +138,8 @@ class RowMatrix @Since("1.0.0") ( U1 } else { U1 += U2 - } + }, + depth = getTreeAggregateIdealDepth(gramianSizeInBytes) ) RowMatrix.triuToFull(n, GU.data) @@ -775,6 +778,35 @@ class RowMatrix @Since("1.0.0") ( s"The number of rows $m is different from what specified or previously computed: ${nRows}.") } } + + /** + * Computing desired tree aggregate depth necessary to avoid exceeding + * driver.MaxResultSize during aggregation. 
+ * Based on the formula: (numPartitions)^(1/depth) * objectSize <= DriverMaxResultSize
+ * @param aggregatedObjectSizeInBytes the size, in bytes, of the object being tree aggregated
+ */
+ private[spark] def getTreeAggregateIdealDepth(aggregatedObjectSizeInBytes: Long) = {
+ require(aggregatedObjectSizeInBytes > 0,
+ "Cannot compute aggregate depth heuristic based on a zero-size object to aggregate")
+
+ val maxDriverResultSizeInBytes = rows.conf.get[Long](MAX_RESULT_SIZE)
+
+ require(maxDriverResultSizeInBytes > aggregatedObjectSizeInBytes,
+ s"Cannot aggregate object of size $aggregatedObjectSizeInBytes Bytes, " +
+ s"as it's bigger than maxResultSize ($maxDriverResultSizeInBytes Bytes)")
+
+ val numerator = math.log(rows.getNumPartitions)
+ val denominator = math.log(maxDriverResultSizeInBytes) - math.log(aggregatedObjectSizeInBytes)
+ val desiredTreeDepth = math.ceil(numerator / denominator)
+
+ if (desiredTreeDepth > 4) {
+ logWarning(
+ s"Desired tree depth for treeAggregation is big ($desiredTreeDepth). "
+ + "Consider increasing driver max result size or reducing number of partitions") + } + + math.min(math.max(1, desiredTreeDepth), 10).toInt + } } @Since("1.0.0") diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala index a4ca4f0a80faa..a0c4c68243e67 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala @@ -101,6 +101,26 @@ class RowMatrixSuite extends SparkFunSuite with MLlibTestSparkContext { } } + test("getTreeAggregateIdealDepth") { + val nbPartitions = 100 + val vectors = sc.emptyRDD[Vector] + .repartition(nbPartitions) + val rowMat = new RowMatrix(vectors) + + assert(rowMat.getTreeAggregateIdealDepth(100 * 1024 * 1024) === 2) + assert(rowMat.getTreeAggregateIdealDepth(110 * 1024 * 1024) === 3) + assert(rowMat.getTreeAggregateIdealDepth(700 * 1024 * 1024) === 10) + + val zeroSizeException = intercept[Exception]{ + rowMat.getTreeAggregateIdealDepth(0) + } + assert(zeroSizeException.getMessage.contains("zero-size object to aggregate")) + val objectBiggerThanResultSize = intercept[Exception]{ + rowMat.getTreeAggregateIdealDepth(1100 * 1024 * 1024) + } + assert(objectBiggerThanResultSize.getMessage.contains("it's bigger than maxResultSize")) + } + test("similar columns") { val colMags = Vectors.dense(math.sqrt(126), math.sqrt(66), math.sqrt(94)) val expected = BDM(