[SPARK-34906] Refactor TreeNode's children handling methods into specialized traits

### What changes were proposed in this pull request?
The Spark query plan node hierarchy has specialized traits (or abstract classes) for handling nodes with a fixed number of children, for example `UnaryExpression`, `UnaryNode` and `UnaryExec`, which represent an expression, a logical plan and a physical plan with exactly one child, respectively. This PR refactors the `TreeNode` hierarchy by extracting the children-handling functionality into the following traits; `UnaryExpression` and the other similar classes now extend the corresponding new trait:
```
trait LeafLike[T <: TreeNode[T]] { self: TreeNode[T] =>
  override final def children: Seq[T] = Nil
}

trait UnaryLike[T <: TreeNode[T]] { self: TreeNode[T] =>
  def child: T
  @transient override final lazy val children: Seq[T] = child :: Nil
}

trait BinaryLike[T <: TreeNode[T]] { self: TreeNode[T] =>
  def left: T
  def right: T
  @transient override final lazy val children: Seq[T] = left :: right :: Nil
}

trait TernaryLike[T <: TreeNode[T]] { self: TreeNode[T] =>
  def first: T
  def second: T
  def third: T
  @transient override final lazy val children: Seq[T] = first :: second :: third :: Nil
}
```

This refactoring, which is part of a bigger effort to make tree transformations in Spark more efficient, has two benefits:
- It moves the children-handling methods into a single place instead of spreading them across specific subclasses, which will help future optimizations of tree traversals.
- It allows these traits to be mixed into concrete node types that could not extend the existing specialized classes. For example, expressions with one child that extend `AggregateFunction` cannot extend `UnaryExpression`, because `AggregateFunction` defines the `foldable` method as final while `UnaryExpression` defines it as non-final. With the new traits, such concrete classes can mix in `UnaryLike` directly (see the sketch below). Classes with more specific child handling also make tree traversal methods faster.

This PR also updates many concrete node types to extend these traits so that they benefit from the more specific child handling.
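
As an illustration of the second benefit above, here is a minimal, self-contained sketch of the mix-in pattern. It uses a simplified, non-generic toy hierarchy, not the actual Spark classes: `Node`, `AggLikeNode`, `CountLike` and `LiteralLike` are hypothetical names, and `AggLikeNode`'s final member only stands in for the `AggregateFunction`/`UnaryExpression` conflict described above.
```
// Minimal sketch only (hypothetical names; not the actual Spark TreeNode types).
abstract class Node {
  def children: Seq[Node]
}

// Child-handling mix-ins, analogous to the LeafLike/UnaryLike traits added in this PR.
trait LeafLike extends Node {
  override final def children: Seq[Node] = Nil
}

trait UnaryLike extends Node {
  def child: Node
  @transient override final lazy val children: Seq[Node] = child :: Nil
}

// A base class whose final member would clash with a specialized abstract class,
// playing the role of AggregateFunction vs. UnaryExpression in the description above.
abstract class AggLikeNode extends Node {
  final def foldable: Boolean = false
}

// The concrete node mixes in UnaryLike directly and gets `children` for free.
case class CountLike(child: Node) extends AggLikeNode with UnaryLike
case class LiteralLike(value: Int) extends Node with LeafLike

object UnaryLikeDemo extends App {
  val tree = CountLike(LiteralLike(1))
  println(tree.children) // List(LiteralLike(1))
}
```
In the actual PR the traits are generic in the node type (`T <: TreeNode[T]`) and use a self type, as shown in the snippet at the top of this description.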

### Why are the changes needed?

This refactoring centralizes the children-handling logic and is part of a larger effort to make tree transformations in Spark more efficient.
### Does this PR introduce _any_ user-facing change?

No, this is an internal refactoring.
### How was this patch tested?

This is a refactoring; it passes the existing tests.

Closes #31932 from dbaliafroozeh/FactorOutChildHandlnigIntoSeparateTraits.

Authored-by: Ali Afroozeh <ali.afroozeh@databricks.com>
Signed-off-by: herman <herman@databricks.com>
dbaliafroozeh authored and hvanhovell committed Mar 30, 2021
1 parent 9f065ff commit bd0990e
Showing 50 changed files with 352 additions and 332 deletions.
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.trees.UnaryLike

trait DynamicPruning extends Predicate

@@ -46,9 +47,10 @@ case class DynamicPruningSubquery(
exprId: ExprId = NamedExpression.newExprId)
extends SubqueryExpression(buildQuery, Seq(pruningKey), exprId)
with DynamicPruning
with Unevaluable {
with Unevaluable
with UnaryLike[Expression] {

override def children: Seq[Expression] = Seq(pruningKey)
override def child: Expression = pruningKey

override def plan: LogicalPlan = buildQuery

@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.trees.TreeNode
import org.apache.spark.sql.catalyst.trees.{BinaryLike, LeafLike, TernaryLike, TreeNode, UnaryLike}
import org.apache.spark.sql.catalyst.util.truncatedString
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.internal.SQLConf
@@ -451,21 +451,14 @@ trait Stateful extends Nondeterministic {
/**
* A leaf expression, i.e. one without any child expressions.
*/
abstract class LeafExpression extends Expression {

override final def children: Seq[Expression] = Nil
}
abstract class LeafExpression extends Expression with LeafLike[Expression]


/**
* An expression with one input and one output. The output is by default evaluated to null
* if the input is evaluated to null.
*/
abstract class UnaryExpression extends Expression {

def child: Expression

override final def children: Seq[Expression] = child :: Nil
abstract class UnaryExpression extends Expression with UnaryLike[Expression] {

override def foldable: Boolean = child.foldable
override def nullable: Boolean = child.nullable
@@ -552,12 +545,7 @@ object UnaryExpression {
* An expression with two inputs and one output. The output is by default evaluated to null
* if any input is evaluated to null.
*/
abstract class BinaryExpression extends Expression {

def left: Expression
def right: Expression

override final def children: Seq[Expression] = Seq(left, right)
abstract class BinaryExpression extends Expression with BinaryLike[Expression] {

override def foldable: Boolean = left.foldable && right.foldable

@@ -701,7 +689,7 @@ object BinaryOperator {
* An expression with three inputs and one output. The output is by default evaluated to null
* if any input is evaluated to null.
*/
abstract class TernaryExpression extends Expression {
abstract class TernaryExpression extends Expression with TernaryLike[Expression] {

override def foldable: Boolean = children.forall(_.foldable)

@@ -712,12 +700,11 @@ abstract class TernaryExpression extends Expression {
* If subclass of TernaryExpression override nullable, probably should also override this.
*/
override def eval(input: InternalRow): Any = {
val exprs = children
val value1 = exprs(0).eval(input)
val value1 = first.eval(input)
if (value1 != null) {
val value2 = exprs(1).eval(input)
val value2 = second.eval(input)
if (value2 != null) {
val value3 = exprs(2).eval(input)
val value3 = third.eval(input)
if (value3 != null) {
return nullSafeEval(value1, value2, value3)
}
@@ -17,6 +17,7 @@

package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.trees.UnaryLike
import org.apache.spark.sql.types.{DataType, IntegerType}

/**
@@ -32,7 +33,8 @@ import org.apache.spark.sql.types.{DataType, IntegerType}
* df.writeTo("catalog.db.table").partitionedBy($"category", days($"timestamp")).create()
* }}}
*/
abstract class PartitionTransformExpression extends Expression with Unevaluable {
abstract class PartitionTransformExpression extends Expression with Unevaluable
with UnaryLike[Expression] {
override def nullable: Boolean = true
}

@@ -41,37 +43,32 @@ abstract class PartitionTransformExpression extends Expression with Unevaluable
*/
case class Years(child: Expression) extends PartitionTransformExpression {
override def dataType: DataType = IntegerType
override def children: Seq[Expression] = Seq(child)
}

/**
* Expression for the v2 partition transform months.
*/
case class Months(child: Expression) extends PartitionTransformExpression {
override def dataType: DataType = IntegerType
override def children: Seq[Expression] = Seq(child)
}

/**
* Expression for the v2 partition transform days.
*/
case class Days(child: Expression) extends PartitionTransformExpression {
override def dataType: DataType = IntegerType
override def children: Seq[Expression] = Seq(child)
}

/**
* Expression for the v2 partition transform hours.
*/
case class Hours(child: Expression) extends PartitionTransformExpression {
override def dataType: DataType = IntegerType
override def children: Seq[Expression] = Seq(child)
}

/**
* Expression for the v2 partition transform bucket.
*/
case class Bucket(numBuckets: Literal, child: Expression) extends PartitionTransformExpression {
override def dataType: DataType = IntegerType
override def children: Seq[Expression] = Seq(numBuckets, child)
}
@@ -121,11 +121,10 @@ class SubExprEvaluationRuntime(cacheMaxEntries: Int) {
case class ExpressionProxy(
child: Expression,
id: Int,
runtime: SubExprEvaluationRuntime) extends Expression {
runtime: SubExprEvaluationRuntime) extends UnaryExpression {

final override def dataType: DataType = child.dataType
final override def nullable: Boolean = child.nullable
final override def children: Seq[Expression] = child :: Nil

// `ExpressionProxy` is for interpreted expression evaluation only. So cannot `doGenCode`.
final override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
@@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, GenericInternalRow}
import org.apache.spark.sql.catalyst.trees.BinaryLike
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, HyperLogLogPlusPlusHelper}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.Platform
@@ -48,7 +49,7 @@ case class ApproxCountDistinctForIntervals(
relativeSD: Double = 0.05,
mutableAggBufferOffset: Int = 0,
inputAggBufferOffset: Int = 0)
extends TypedImperativeAggregate[Array[Long]] with ExpectsInputTypes {
extends TypedImperativeAggregate[Array[Long]] with ExpectsInputTypes with BinaryLike[Expression] {

def this(child: Expression, endpointsExpression: Expression, relativeSD: Expression) = {
this(
@@ -213,7 +214,8 @@ case class ApproxCountDistinctForIntervals(
copy(inputAggBufferOffset = newInputAggBufferOffset)
}

override def children: Seq[Expression] = Seq(child, endpointsExpression)
override def left: Expression = child
override def right: Expression = endpointsExpression

override def nullable: Boolean = false

@@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile.PercentileDigest
import org.apache.spark.sql.catalyst.trees.TernaryLike
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
import org.apache.spark.sql.catalyst.util.QuantileSummaries
import org.apache.spark.sql.catalyst.util.QuantileSummaries.{defaultCompressThreshold, Stats}
@@ -76,7 +77,8 @@ case class ApproximatePercentile(
accuracyExpression: Expression,
override val mutableAggBufferOffset: Int,
override val inputAggBufferOffset: Int)
extends TypedImperativeAggregate[PercentileDigest] with ImplicitCastInputTypes {
extends TypedImperativeAggregate[PercentileDigest] with ImplicitCastInputTypes
with TernaryLike[Expression] {

def this(child: Expression, percentageExpression: Expression, accuracyExpression: Expression) = {
this(child, percentageExpression, accuracyExpression, 0, 0)
@@ -182,7 +184,9 @@ case class ApproximatePercentile(
override def withNewInputAggBufferOffset(newOffset: Int): ApproximatePercentile =
copy(inputAggBufferOffset = newOffset)

override def children: Seq[Expression] = Seq(child, percentageExpression, accuracyExpression)
override def first: Expression = child
override def second: Expression = percentageExpression
override def third: Expression = accuracyExpression

// Returns null for empty inputs
override def nullable: Boolean = true
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate
import org.apache.spark.sql.catalyst.analysis.{DecimalPrecision, FunctionRegistry, TypeCheckResult}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.trees.UnaryLike
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._

@@ -34,12 +35,11 @@ import org.apache.spark.sql.types._
""",
group = "agg_funcs",
since = "1.0.0")
case class Average(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes {
case class Average(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes
with UnaryLike[Expression] {

override def prettyName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("avg")

override def children: Seq[Expression] = child :: Nil

override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)

override def checkInputDataTypes(): TypeCheckResult =
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.trees.UnaryLike
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._

@@ -45,14 +46,13 @@ import org.apache.spark.sql.types._
* @param child to compute central moments of.
*/
abstract class CentralMomentAgg(child: Expression, nullOnDivideByZero: Boolean)
extends DeclarativeAggregate with ImplicitCastInputTypes {
extends DeclarativeAggregate with ImplicitCastInputTypes with UnaryLike[Expression] {

/**
* The central moment order to be computed.
*/
protected def momentOrder: Int

override def children: Seq[Expression] = Seq(child)
override def nullable: Boolean = true
override def dataType: DataType = DoubleType
override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType)
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.trees.BinaryLike
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._

@@ -30,9 +31,10 @@ import org.apache.spark.sql.types._
* http://en.wikipedia.org/wiki/Pearson_product-moment_correlation_coefficient
*/
abstract class PearsonCorrelation(x: Expression, y: Expression, nullOnDivideByZero: Boolean)
extends DeclarativeAggregate with ImplicitCastInputTypes {
extends DeclarativeAggregate with ImplicitCastInputTypes with BinaryLike[Expression] {

override def children: Seq[Expression] = Seq(x, y)
override def left: Expression = x
override def right: Expression = y
override def nullable: Boolean = true
override def dataType: DataType = DoubleType
override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType)
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate

import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescription, ImplicitCastInputTypes, UnevaluableAggregate}
import org.apache.spark.sql.catalyst.trees.UnaryLike
import org.apache.spark.sql.types.{AbstractDataType, BooleanType, DataType, LongType}

@ExpressionDescription(
@@ -34,10 +35,12 @@ import org.apache.spark.sql.types.{AbstractDataType, BooleanType, DataType, Long
""",
group = "agg_funcs",
since = "3.0.0")
case class CountIf(predicate: Expression) extends UnevaluableAggregate with ImplicitCastInputTypes {
case class CountIf(predicate: Expression) extends UnevaluableAggregate with ImplicitCastInputTypes
with UnaryLike[Expression] {

override def prettyName: String = "count_if"

override def children: Seq[Expression] = Seq(predicate)
override def child: Expression = predicate

override def nullable: Boolean = false

@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.trees.BinaryLike
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._

@@ -27,9 +28,10 @@ import org.apache.spark.sql.types._
* When applied on empty data (i.e., count is zero), it returns NULL.
*/
abstract class Covariance(x: Expression, y: Expression, nullOnDivideByZero: Boolean)
extends DeclarativeAggregate with ImplicitCastInputTypes {
extends DeclarativeAggregate with ImplicitCastInputTypes with BinaryLike[Expression] {

override def children: Seq[Expression] = Seq(x, y)
override def left: Expression = x
override def right: Expression = y
override def nullable: Boolean = true
override def dataType: DataType = DoubleType
override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType)
@@ -97,8 +99,8 @@ abstract class Covariance(x: Expression, y: Expression, nullOnDivideByZero: Bool
group = "agg_funcs",
since = "2.0.0")
case class CovPopulation(
left: Expression,
right: Expression,
override val left: Expression,
override val right: Expression,
nullOnDivideByZero: Boolean = !SQLConf.get.legacyStatisticalAggregate)
extends Covariance(left, right, nullOnDivideByZero) {

@@ -122,8 +124,8 @@ case class CovPopulation(
group = "agg_funcs",
since = "2.0.0")
case class CovSample(
left: Expression,
right: Expression,
override val left: Expression,
override val right: Expression,
nullOnDivideByZero: Boolean = !SQLConf.get.legacyStatisticalAggregate)
extends Covariance(left, right, nullOnDivideByZero) {

@@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckSuccess
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.trees.UnaryLike
import org.apache.spark.sql.types._

/**
@@ -51,16 +52,14 @@ import org.apache.spark.sql.types._
group = "agg_funcs",
since = "2.0.0")
case class First(child: Expression, ignoreNulls: Boolean)
extends DeclarativeAggregate with ExpectsInputTypes {
extends DeclarativeAggregate with ExpectsInputTypes with UnaryLike[Expression] {

def this(child: Expression) = this(child, false)

def this(child: Expression, ignoreNullsExpr: Expression) = {
this(child, FirstLast.validateIgnoreNullExpr(ignoreNullsExpr, "first"))
}

override def children: Seq[Expression] = child :: Nil

override def nullable: Boolean = true

// First is not a deterministic function.