
Commit c481bdf

Davies Liu authored and committed
[SPARK-13329] [SQL] Consider output for statistics of logical plan
The current implementation of statistics for UnaryNode does not consider the node's output (for example, a Project may produce far fewer columns than its child), so we should take the output into account to get a better estimate. We usually join on only a few columns of a Parquet table, so the size of the projected plan can be much smaller than the original Parquet files. A better size estimate helps us choose between a broadcast join and a sort-merge join. After this PR, I saw a few queries choose a broadcast join rather than a sort-merge join without tuning spark.sql.autoBroadcastJoinThreshold for every query, ending up with roughly 6-8X improvements in end-to-end time.

We use the `defaultSize` of a DataType to estimate the size of a column. Currently, DecimalType, StringType, BinaryType, and UDTs are over-estimated by far too much (4096 bytes each), so this PR changes them to more reasonable values. The new defaultSize for each is:

DecimalType: 8 or 16 bytes, based on the precision
StringType: 20 bytes
BinaryType: 100 bytes
UDT: the default size of its underlying SQL type

These numbers are not perfect (it is hard to pick a perfect number for them), but they should be better than 4096.

Author: Davies Liu <davies@databricks.com>

Closes #11210 from davies/statics.
1 parent c5bfe5d commit c481bdf
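
For context, here is a minimal sketch of the scenario this change targets (the paths and column names are hypothetical, not part of this patch): only a couple of columns are selected out of a wide Parquet table before a join, so the projected plan should now be estimated as much smaller than the raw files and may fall under the broadcast threshold.

// Hypothetical example; assumes a SQLContext named sqlContext and Parquet data at these paths.
val wide = sqlContext.read.parquet("/data/wide_table")    // wide table with many columns on disk
val dim = sqlContext.read.parquet("/data/dim_table")
val joined = wide.select("id", "amount").join(dim, "id")  // join on a narrow projection
// With the improved size estimate this plan may now use BroadcastHashJoin instead of
// SortMergeJoin, without lowering spark.sql.autoBroadcastJoinThreshold for this query.
joined.explain()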

File tree: 11 files changed (+100, -56 lines)


python/pyspark/sql/dataframe.py

Lines changed: 2 additions & 2 deletions
@@ -551,8 +551,8 @@ def alias(self, alias):
         >>> df_as1 = df.alias("df_as1")
         >>> df_as2 = df.alias("df_as2")
         >>> joined_df = df_as1.join(df_as2, col("df_as1.name") == col("df_as2.name"), 'inner')
-        >>> joined_df.select(col("df_as1.name"), col("df_as2.name"), col("df_as2.age")).collect()
-        [Row(name=u'Bob', name=u'Bob', age=5), Row(name=u'Alice', name=u'Alice', age=2)]
+        >>> joined_df.select("df_as1.name", "df_as2.name", "df_as2.age").collect()
+        [Row(name=u'Alice', name=u'Alice', age=2), Row(name=u'Bob', name=u'Bob', age=5)]
         """
         assert isinstance(alias, basestring), "alias should be a string"
         return DataFrame(getattr(self._jdf, "as")(alias), self.sql_ctx)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala

Lines changed: 15 additions & 0 deletions
@@ -316,6 +316,21 @@ abstract class UnaryNode extends LogicalPlan {
   override def children: Seq[LogicalPlan] = child :: Nil
 
   override protected def validConstraints: Set[Expression] = child.constraints
+
+  override def statistics: Statistics = {
+    // There should be some overhead in a Row object; the size should not be zero when there
+    // are no columns, which also helps prevent a divide-by-zero error.
+    val childRowSize = child.output.map(_.dataType.defaultSize).sum + 8
+    val outputRowSize = output.map(_.dataType.defaultSize).sum + 8
+    // Assume there will be the same number of rows as the child has.
+    var sizeInBytes = (child.statistics.sizeInBytes * outputRowSize) / childRowSize
+    if (sizeInBytes == 0) {
+      // sizeInBytes can't be zero, or sizeInBytes of BinaryNode will also be zero
+      // (product of children).
+      sizeInBytes = 1
+    }
+    Statistics(sizeInBytes = sizeInBytes)
+  }
 }
 
 /**
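
As a rough worked example of the new UnaryNode estimate (the schema and numbers are illustrative, not from this patch): if a child scan reports sizeInBytes = 1,000,000,000 for columns (a: Int, b: String, c: Binary, d: String), then childRowSize = 4 + 20 + 100 + 20 + 8 = 152. A Project keeping only column a has outputRowSize = 4 + 8 = 12, so its estimated size becomes 1,000,000,000 * 12 / 152 ≈ 79 MB, which can fall under spark.sql.autoBroadcastJoinThreshold where the old estimate (the full child size) would not.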

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala

Lines changed: 31 additions & 3 deletions
@@ -176,6 +176,13 @@ case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) {
       Some(children.flatMap(_.maxRows).min)
     }
   }
+
+  override def statistics: Statistics = {
+    val leftSize = left.statistics.sizeInBytes
+    val rightSize = right.statistics.sizeInBytes
+    val sizeInBytes = if (leftSize < rightSize) leftSize else rightSize
+    Statistics(sizeInBytes = sizeInBytes)
+  }
 }
 
 case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) {
@@ -188,6 +195,10 @@ case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) {
     childrenResolved &&
       left.output.length == right.output.length &&
       left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType }
+
+  override def statistics: Statistics = {
+    Statistics(sizeInBytes = left.statistics.sizeInBytes)
+  }
 }
 
 /** Factory for constructing new `Union` nodes. */
@@ -426,6 +437,14 @@ case class Aggregate(
 
   override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute)
   override def maxRows: Option[Long] = child.maxRows
+
+  override def statistics: Statistics = {
+    if (groupingExpressions.isEmpty) {
+      Statistics(sizeInBytes = 1)
+    } else {
+      super.statistics
+    }
+  }
 }
 
 case class Window(
@@ -521,9 +540,7 @@ case class Expand(
     AttributeSet(projections.flatten.flatMap(_.references))
 
   override def statistics: Statistics = {
-    // TODO shouldn't we factor in the size of the projection versus the size of the backing child
-    // row?
-    val sizeInBytes = child.statistics.sizeInBytes * projections.length
+    val sizeInBytes = super.statistics.sizeInBytes * projections.length
     Statistics(sizeInBytes = sizeInBytes)
   }
 }
@@ -648,6 +665,17 @@ case class Sample(
     val isTableSample: java.lang.Boolean = false) extends UnaryNode {
 
   override def output: Seq[Attribute] = child.output
+
+  override def statistics: Statistics = {
+    val ratio = upperBound - lowerBound
+    // BigInt can't multiply with Double
+    var sizeInBytes = child.statistics.sizeInBytes * (ratio * 100).toInt / 100
+    if (sizeInBytes == 0) {
+      sizeInBytes = 1
+    }
+    Statistics(sizeInBytes = sizeInBytes)
+  }
+
   override protected def otherCopyArgs: Seq[AnyRef] = isTableSample :: Nil
 }
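
A few notes on the estimates above, with illustrative numbers: Intersect returns at most the smaller input, hence the minimum of the two child sizes; Except returns at most the left input; an Aggregate with no grouping expressions produces a single row, so a 1-byte floor is a safer guess than the child's size. In Sample, the (ratio * 100).toInt / 100 detour exists only because a BigInt cannot be multiplied by a Double directly; for example, with lowerBound = 0.0 and upperBound = 0.1, a child estimated at 500,000,000 bytes becomes 500,000,000 * 10 / 100 = 50,000,000 bytes. Expand now multiplies its own projected row size (super.statistics, i.e. the UnaryNode estimate above) by the number of projections instead of using the raw child size.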

sql/catalyst/src/main/scala/org/apache/spark/sql/types/BinaryType.scala

Lines changed: 2 additions & 2 deletions
@@ -47,9 +47,9 @@ class BinaryType private() extends AtomicType {
   }
 
   /**
-   * The default size of a value of the BinaryType is 4096 bytes.
+   * The default size of a value of the BinaryType is 100 bytes.
    */
-  override def defaultSize: Int = 4096
+  override def defaultSize: Int = 100
 
   private[spark] override def asNullable: BinaryType = this
 }

sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala

Lines changed: 2 additions & 2 deletions
@@ -91,9 +91,9 @@ case class DecimalType(precision: Int, scale: Int) extends FractionalType {
   }
 
   /**
-   * The default size of a value of the DecimalType is 4096 bytes.
+   * The default size of a value of the DecimalType is 8 bytes (precision <= 18) or 16 bytes.
    */
-  override def defaultSize: Int = 4096
+  override def defaultSize: Int = if (precision <= Decimal.MAX_LONG_DIGITS) 8 else 16
 
   override def simpleString: String = s"decimal($precision,$scale)"

sql/catalyst/src/main/scala/org/apache/spark/sql/types/StringType.scala

Lines changed: 2 additions & 2 deletions
@@ -38,9 +38,9 @@ class StringType private() extends AtomicType {
   private[sql] val ordering = implicitly[Ordering[InternalType]]
 
   /**
-   * The default size of a value of the StringType is 4096 bytes.
+   * The default size of a value of the StringType is 20 bytes.
    */
-  override def defaultSize: Int = 4096
+  override def defaultSize: Int = 20
 
   private[spark] override def asNullable: StringType = this
 }

sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala

Lines changed: 1 addition & 4 deletions
@@ -71,10 +71,7 @@ abstract class UserDefinedType[UserType] extends DataType with Serializable {
    */
   def userClass: java.lang.Class[UserType]
 
-  /**
-   * The default size of a value of the UserDefinedType is 4096 bytes.
-   */
-  override def defaultSize: Int = 4096
+  override def defaultSize: Int = sqlType.defaultSize
 
   /**
    * For UDT, asNullable will not change the nullability of its internal sqlType and just returns
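
For example (illustrative, not from this patch), a UDT whose sqlType is ArrayType(DoubleType) is now estimated at ArrayType(DoubleType).defaultSize = 800 bytes per value instead of a flat 4096, so columns of UDTs inherit whatever estimate their underlying SQL type provides.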

sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala

Lines changed: 6 additions & 6 deletions
@@ -248,15 +248,15 @@ class DataTypeSuite extends SparkFunSuite {
     checkDefaultSize(LongType, 8)
     checkDefaultSize(FloatType, 4)
     checkDefaultSize(DoubleType, 8)
-    checkDefaultSize(DecimalType(10, 5), 4096)
-    checkDefaultSize(DecimalType.SYSTEM_DEFAULT, 4096)
+    checkDefaultSize(DecimalType(10, 5), 8)
+    checkDefaultSize(DecimalType.SYSTEM_DEFAULT, 16)
     checkDefaultSize(DateType, 4)
     checkDefaultSize(TimestampType, 8)
-    checkDefaultSize(StringType, 4096)
-    checkDefaultSize(BinaryType, 4096)
+    checkDefaultSize(StringType, 20)
+    checkDefaultSize(BinaryType, 100)
     checkDefaultSize(ArrayType(DoubleType, true), 800)
-    checkDefaultSize(ArrayType(StringType, false), 409600)
-    checkDefaultSize(MapType(IntegerType, StringType, true), 410000)
+    checkDefaultSize(ArrayType(StringType, false), 2000)
+    checkDefaultSize(MapType(IntegerType, StringType, true), 2400)
     checkDefaultSize(MapType(IntegerType, ArrayType(DoubleType), false), 80400)
     checkDefaultSize(structType, 812)
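
The updated composite expectations follow from the container rule (unchanged by this patch) that an array or map is assumed to hold 100 entries: ArrayType(StringType) becomes 100 * 20 = 2000 and MapType(IntegerType, StringType) becomes 100 * (4 + 20) = 2400, just as the old values were 100 * 4096 = 409600 and 100 * (4 + 4096) = 410000.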

sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala

Lines changed: 0 additions & 2 deletions
@@ -17,8 +17,6 @@
 
 package org.apache.spark.sql.execution
 
-import scala.collection.immutable.IndexedSeq
-
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.errors._

sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala

Lines changed: 36 additions & 31 deletions
@@ -63,36 +63,40 @@ class JoinSuite extends QueryTest with SharedSQLContext {
   test("join operator selection") {
     sqlContext.cacheManager.clearCache()
 
-    Seq(
-      ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]),
-      ("SELECT * FROM testData LEFT SEMI JOIN testData2", classOf[LeftSemiJoinBNL]),
-      ("SELECT * FROM testData JOIN testData2", classOf[CartesianProduct]),
-      ("SELECT * FROM testData JOIN testData2 WHERE key = 2", classOf[CartesianProduct]),
-      ("SELECT * FROM testData LEFT JOIN testData2", classOf[CartesianProduct]),
-      ("SELECT * FROM testData RIGHT JOIN testData2", classOf[CartesianProduct]),
-      ("SELECT * FROM testData FULL OUTER JOIN testData2", classOf[CartesianProduct]),
-      ("SELECT * FROM testData LEFT JOIN testData2 WHERE key = 2", classOf[CartesianProduct]),
-      ("SELECT * FROM testData RIGHT JOIN testData2 WHERE key = 2", classOf[CartesianProduct]),
-      ("SELECT * FROM testData FULL OUTER JOIN testData2 WHERE key = 2", classOf[CartesianProduct]),
-      ("SELECT * FROM testData JOIN testData2 WHERE key > a", classOf[CartesianProduct]),
-      ("SELECT * FROM testData FULL OUTER JOIN testData2 WHERE key > a", classOf[CartesianProduct]),
-      ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[SortMergeJoin]),
-      ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[SortMergeJoin]),
-      ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[SortMergeJoin]),
-      ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[SortMergeOuterJoin]),
-      ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2",
-        classOf[SortMergeJoin]), // converted from Right Outer to Inner
-      ("SELECT * FROM testData right join testData2 ON key = a and key = 2",
-        classOf[SortMergeOuterJoin]),
-      ("SELECT * FROM testData full outer join testData2 ON key = a",
-        classOf[SortMergeOuterJoin]),
-      ("SELECT * FROM testData left JOIN testData2 ON (key * a != key + a)",
-        classOf[BroadcastNestedLoopJoin]),
-      ("SELECT * FROM testData right JOIN testData2 ON (key * a != key + a)",
-        classOf[BroadcastNestedLoopJoin]),
-      ("SELECT * FROM testData full JOIN testData2 ON (key * a != key + a)",
-        classOf[BroadcastNestedLoopJoin])
-    ).foreach { case (query, joinClass) => assertJoin(query, joinClass) }
+    withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "0") {
+      Seq(
+        ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]),
+        ("SELECT * FROM testData LEFT SEMI JOIN testData2", classOf[LeftSemiJoinBNL]),
+        ("SELECT * FROM testData JOIN testData2", classOf[CartesianProduct]),
+        ("SELECT * FROM testData JOIN testData2 WHERE key = 2", classOf[CartesianProduct]),
+        ("SELECT * FROM testData LEFT JOIN testData2", classOf[CartesianProduct]),
+        ("SELECT * FROM testData RIGHT JOIN testData2", classOf[CartesianProduct]),
+        ("SELECT * FROM testData FULL OUTER JOIN testData2", classOf[CartesianProduct]),
+        ("SELECT * FROM testData LEFT JOIN testData2 WHERE key = 2", classOf[CartesianProduct]),
+        ("SELECT * FROM testData RIGHT JOIN testData2 WHERE key = 2", classOf[CartesianProduct]),
+        ("SELECT * FROM testData FULL OUTER JOIN testData2 WHERE key = 2",
+          classOf[CartesianProduct]),
+        ("SELECT * FROM testData JOIN testData2 WHERE key > a", classOf[CartesianProduct]),
+        ("SELECT * FROM testData FULL OUTER JOIN testData2 WHERE key > a",
+          classOf[CartesianProduct]),
+        ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[SortMergeJoin]),
+        ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[SortMergeJoin]),
+        ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[SortMergeJoin]),
+        ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[SortMergeOuterJoin]),
+        ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2",
+          classOf[SortMergeJoin]),
+        ("SELECT * FROM testData right join testData2 ON key = a and key = 2",
+          classOf[SortMergeOuterJoin]),
+        ("SELECT * FROM testData full outer join testData2 ON key = a",
+          classOf[SortMergeOuterJoin]),
+        ("SELECT * FROM testData left JOIN testData2 ON (key * a != key + a)",
+          classOf[BroadcastNestedLoopJoin]),
+        ("SELECT * FROM testData right JOIN testData2 ON (key * a != key + a)",
+          classOf[BroadcastNestedLoopJoin]),
+        ("SELECT * FROM testData full JOIN testData2 ON (key * a != key + a)",
+          classOf[BroadcastNestedLoopJoin])
+      ).foreach { case (query, joinClass) => assertJoin(query, joinClass) }
+    }
   }
 
   // ignore("SortMergeJoin shouldn't work on unsortable columns") {
@@ -118,9 +122,10 @@ class JoinSuite extends QueryTest with SharedSQLContext {
   test("broadcasted hash outer join operator selection") {
     sqlContext.cacheManager.clearCache()
     sql("CACHE TABLE testData")
+    sql("CACHE TABLE testData2")
     Seq(
       ("SELECT * FROM testData LEFT JOIN testData2 ON key = a",
-        classOf[SortMergeOuterJoin]),
+        classOf[BroadcastHashJoin]),
       ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2",
         classOf[BroadcastHashJoin]),
       ("SELECT * FROM testData right join testData2 ON key = a and key = 2",

0 commit comments
