[SPARK-26881][mllib] Heuristic for tree aggregate depth #23983
```diff
@@ -27,6 +27,7 @@ import breeze.numerics.{sqrt => brzSqrt}
 import org.apache.spark.annotation.Since
 import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config.MAX_RESULT_SIZE
 import org.apache.spark.mllib.linalg._
 import org.apache.spark.mllib.stat.{MultivariateOnlineSummarizer, MultivariateStatisticalSummary}
 import org.apache.spark.rdd.RDD
```
|
|
```diff
@@ -117,6 +118,7 @@ class RowMatrix @Since("1.0.0") (
     // Computes n*(n+1)/2, avoiding overflow in the multiplication.
     // This succeeds when n <= 65535, which is checked above
     val nt = if (n % 2 == 0) ((n / 2) * (n + 1)) else (n * ((n + 1) / 2))
+    val gramianSizeInBytes = nt * 8L

     // Compute the upper triangular part of the gram matrix.
     val GU = rows.treeAggregate(null.asInstanceOf[BDV[Double]])(
```
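For a sense of scale of `gramianSizeInBytes`, here is a quick back-of-the-envelope check using the same arithmetic; the column count `n = 10000` is an illustrative assumption, not a value from the patch.

```scala
// Rough size of the packed upper-triangular Gramian for an assumed n = 10000 columns.
val n = 10000L
val nt = if (n % 2 == 0) (n / 2) * (n + 1) else n * ((n + 1) / 2) // 50005000 entries
val gramianSizeInBytes = nt * 8L // 400040000 bytes, i.e. ~400 MB of Doubles
```

At that size, the default 1g `spark.driver.maxResultSize` leaves room for only two or three partial results at a time, which is the situation the new depth heuristic targets.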
|
|
```diff
@@ -136,7 +138,8 @@ class RowMatrix @Since("1.0.0") (
           U1
         } else {
           U1 += U2
-        }
+        },
+      depth = getTreeAggregateIdealDepth(gramianSizeInBytes)
     )

     RowMatrix.triuToFull(n, GU.data)
```
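For readers unfamiliar with the `depth` argument being passed above: `RDD.treeAggregate` reduces the partitions in several rounds, and roughly `numPartitions^(1/depth)` partially aggregated results reach the driver in the final step, which is the quantity the heuristic bounds. A minimal usage sketch, assuming a live `SparkContext` named `sc` (not part of the patch):

```scala
// Minimal sketch of treeAggregate with an explicit depth; `sc` and the numbers
// below are illustrative assumptions, not taken from the patch.
val rdd = sc.parallelize(1L to 1000000L, numSlices = 400)
val sum = rdd.treeAggregate(0L)(
  seqOp = (acc, x) => acc + x,
  combOp = (a, b) => a + b,
  depth = 3) // roughly 400^(1/3) ≈ 7 partial results arrive at the driver at the end
```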
|
|
```diff
@@ -775,6 +778,35 @@ class RowMatrix @Since("1.0.0") (
         s"The number of rows $m is different from what specified or previously computed: ${nRows}.")
     }
   }
+
+  /**
+   * Computes the desired tree aggregate depth necessary to avoid exceeding
+   * spark.driver.maxResultSize during aggregation.
+   * Based on the formula: (numPartitions)^(1/depth) * objectSize <= DriverMaxResultSize
+   * @param aggregatedObjectSizeInBytes the size, in bytes, of the object being tree aggregated
+   */
+  private[spark] def getTreeAggregateIdealDepth(aggregatedObjectSizeInBytes: Long) = {
+    require(aggregatedObjectSizeInBytes > 0,
+      "Cannot compute aggregate depth heuristic based on a zero-size object to aggregate")
+
+    val maxDriverResultSizeInBytes = rows.conf.get[Long](MAX_RESULT_SIZE)
+
+    require(maxDriverResultSizeInBytes > aggregatedObjectSizeInBytes,
+      s"Cannot aggregate object of size $aggregatedObjectSizeInBytes bytes, "
+        + s"as it's bigger than maxResultSize ($maxDriverResultSizeInBytes bytes)")
+
+    val numerator = math.log(rows.getNumPartitions)
+    val denominator = math.log(maxDriverResultSizeInBytes) - math.log(aggregatedObjectSizeInBytes)
+    val desiredTreeDepth = math.ceil(numerator / denominator)
+
+    if (desiredTreeDepth > 4) {
+      logWarning(
+        s"Desired tree depth for treeAggregation is big ($desiredTreeDepth). "
+          + "Consider increasing driver max result size or reducing the number of partitions")
+    }
+
+    math.min(math.max(1, desiredTreeDepth), 10).toInt
+  }
 }

 @Since("1.0.0")
```
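Taking logarithms of the scaladoc's inequality, `numPartitions^(1/depth) * objectSize <= maxResultSize`, gives `depth >= log(numPartitions) / (log(maxResultSize) - log(objectSize))`, which is exactly the numerator and denominator computed above. Below is a standalone sketch of the same heuristic that can be run without a Spark cluster; the partition count, object size, and result-size limit are illustrative assumptions, not values from the patch.

```scala
// Standalone sketch of the depth heuristic with made-up inputs.
object TreeAggregateDepthSketch {
  def idealDepth(numPartitions: Int, objectSizeInBytes: Long, maxResultSizeInBytes: Long): Int = {
    require(objectSizeInBytes > 0 && maxResultSizeInBytes > objectSizeInBytes)
    // numPartitions^(1/depth) * objectSize <= maxResultSize
    //   <=>  depth >= log(numPartitions) / (log(maxResultSize) - log(objectSize))
    val desired = math.ceil(
      math.log(numPartitions) / (math.log(maxResultSizeInBytes) - math.log(objectSizeInBytes)))
    math.min(math.max(1, desired), 10).toInt
  }

  def main(args: Array[String]): Unit = {
    // 4000 partitions, a ~400 MB Gramian (400L << 20 bytes), and a 1 GiB result-size limit:
    // ln(4000) / (ln(1 GiB) - ln(400 MiB)) ≈ 8.29 / 0.94 ≈ 8.8, so depth 9.
    println(idealDepth(numPartitions = 4000, objectSizeInBytes = 400L << 20,
      maxResultSizeInBytes = 1L << 30))
  }
}
```

With fewer partitions or a smaller aggregated object the computed depth drops to 1 or 2, so the extra tree levels are only paid for when the driver is actually at risk of exceeding its result-size limit.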
|
|
Review comment: I don't think this can happen, but it's an OK check.