From bd0990e3e813d17065c593fc74f383b494fe8146 Mon Sep 17 00:00:00 2001
From: Ali Afroozeh
Date: Tue, 30 Mar 2021 20:43:18 +0200
Subject: [PATCH] [SPARK-34906] Refactor TreeNode's children handling methods into specialized traits

### What changes were proposed in this pull request?

The Spark query plan node hierarchy has specialized traits (or abstract classes) for handling nodes with a fixed number of children, for example `UnaryExpression`, `UnaryNode` and `UnaryExec` for representing an expression, a logical plan and a physical plan with only one child, respectively. This PR refactors the `TreeNode` hierarchy by extracting the children handling functionality into the following traits. `UnaryExpression` and other similar classes now extend the corresponding new trait:

```
trait LeafLike[T <: TreeNode[T]] { self: TreeNode[T] =>
  override final def children: Seq[T] = Nil
}

trait UnaryLike[T <: TreeNode[T]] { self: TreeNode[T] =>
  def child: T
  @transient override final lazy val children: Seq[T] = child :: Nil
}

trait BinaryLike[T <: TreeNode[T]] { self: TreeNode[T] =>
  def left: T
  def right: T
  @transient override final lazy val children: Seq[T] = left :: right :: Nil
}

trait TernaryLike[T <: TreeNode[T]] { self: TreeNode[T] =>
  def first: T
  def second: T
  def third: T
  @transient override final lazy val children: Seq[T] = first :: second :: third :: Nil
}
```

This refactoring, which is part of a bigger effort to make tree transformations in Spark more efficient, has two benefits:

- It moves the children handling methods to a single place instead of spreading them across specific subclasses, which will help future optimizations of tree traversals.
- It allows mixing these traits into concrete node types that could not extend the previous classes. For example, expressions with one child that extend `AggregateFunction` cannot extend `UnaryExpression`, because `AggregateFunction` declares the `foldable` method final while `UnaryExpression` declares it non-final. With the new traits, such classes can extend `UnaryLike` directly.

Classes with more specific child handling make tree traversal methods faster, so this PR also updates many concrete node types to extend these traits.

### Why are the changes needed?

Centralizing children handling in these traits removes duplicated `children` definitions from individual node types and enables the planned optimizations of tree traversals described above.

### Does this PR introduce _any_ user-facing change?

No. This is an internal refactoring.

### How was this patch tested?

This is a refactoring; it passes existing tests.

Closes #31932 from dbaliafroozeh/FactorOutChildHandlnigIntoSeparateTraits.
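To make the mix-in pattern concrete, here is a minimal, self-contained Scala sketch of the same idea. `MiniNode`, `Expr`, `Lit`, `Negate`, and `foreachUp` are illustrative stand-ins invented for this sketch, not Spark's real `TreeNode` hierarchy:

```
// A minimal stand-in for TreeNode: nodes expose their children and
// support a simple bottom-up traversal.
trait MiniNode[T <: MiniNode[T]] { self: T =>
  def children: Seq[T]

  // Bottom-up traversal. A specialized `children` (a cached lazy val or
  // Nil) avoids rebuilding a Seq in every subclass on each visit.
  def foreachUp(f: T => Unit): Unit = {
    children.foreach(_.foreachUp(f))
    f(this)
  }
}

trait LeafLike[T <: MiniNode[T]] { self: MiniNode[T] =>
  override final def children: Seq[T] = Nil
}

trait UnaryLike[T <: MiniNode[T]] { self: MiniNode[T] =>
  def child: T
  @transient override final lazy val children: Seq[T] = child :: Nil
}

// Concrete node types only declare their named children; `children`
// is supplied by the mixed-in trait.
sealed abstract class Expr extends MiniNode[Expr]
case class Lit(value: Int) extends Expr with LeafLike[Expr]
case class Negate(child: Expr) extends Expr with UnaryLike[Expr]

object Demo extends App {
  Negate(Negate(Lit(1))).foreachUp(println)
  // prints: Lit(1), then Negate(Lit(1)), then Negate(Negate(Lit(1)))
}
```

Because the traits carry the `children` implementation through a self-type rather than an abstract base class, a node type that cannot extend a class like `UnaryExpression` (such as the `AggregateFunction` case above) can still mix in `UnaryLike` and inherit the final `children` definition.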
Authored-by: Ali Afroozeh Signed-off-by: herman --- .../catalyst/expressions/DynamicPruning.scala | 6 +- .../sql/catalyst/expressions/Expression.scala | 29 +-- .../expressions/PartitionTransforms.scala | 9 +- .../SubExprEvaluationRuntime.scala | 3 +- .../ApproxCountDistinctForIntervals.scala | 6 +- .../aggregate/ApproximatePercentile.scala | 8 +- .../expressions/aggregate/Average.scala | 6 +- .../aggregate/CentralMomentAgg.scala | 4 +- .../catalyst/expressions/aggregate/Corr.scala | 6 +- .../expressions/aggregate/CountIf.scala | 7 +- .../expressions/aggregate/Covariance.scala | 14 +- .../expressions/aggregate/First.scala | 5 +- .../aggregate/HyperLogLogPlusPlus.scala | 5 +- .../catalyst/expressions/aggregate/Last.scala | 5 +- .../catalyst/expressions/aggregate/Max.scala | 5 +- .../expressions/aggregate/MaxByAndMinBy.scala | 6 +- .../catalyst/expressions/aggregate/Min.scala | 5 +- .../expressions/aggregate/Percentile.scala | 10 +- .../expressions/aggregate/PivotFirst.scala | 6 +- .../expressions/aggregate/Product.scala | 5 +- .../catalyst/expressions/aggregate/Sum.scala | 6 +- .../aggregate/UnevaluableAggs.scala | 5 +- .../aggregate/bitwiseAggregates.scala | 6 +- .../expressions/aggregate/collect.scala | 6 +- .../expressions/aggregate/interfaces.scala | 3 +- .../expressions/collectionOperations.scala | 4 +- .../expressions/complexTypeCreator.scala | 9 +- .../expressions/conditionalExpressions.scala | 7 +- .../expressions/datetimeExpressions.scala | 8 +- .../sql/catalyst/expressions/grouping.scala | 5 +- .../expressions/mathExpressions.scala | 4 +- .../expressions/objects/objects.scala | 8 +- .../expressions/regexpExpressions.scala | 8 +- .../expressions/stringExpressions.scala | 36 ++- .../expressions/windowExpressions.scala | 51 +++-- .../sql/catalyst/plans/logical/Command.scala | 6 +- .../catalyst/plans/logical/LogicalPlan.scala | 17 +- .../catalyst/plans/logical/statements.scala | 45 ++-- .../catalyst/plans/logical/v2Commands.scala | 214 +++++++----------- .../catalyst/streaming/WriteToStream.scala | 7 +- .../streaming/WriteToStreamStatement.scala | 6 +- .../spark/sql/catalyst/trees/TreeNode.scala | 22 ++ .../analysis/UnsupportedOperationsSuite.scala | 2 +- .../catalog/SessionCatalogSuite.scala | 4 +- .../spark/sql/execution/SparkPlan.scala | 15 +- .../sql/execution/command/commands.scala | 4 +- .../v2/WriteToDataSourceV2Exec.scala | 6 +- .../WriteToContinuousDataSource.scala | 6 +- .../sources/WriteToMicroBatchDataSource.scala | 6 +- .../apache/spark/sql/execution/subquery.scala | 8 +- 50 files changed, 352 insertions(+), 332 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/DynamicPruning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/DynamicPruning.scala index 7065d27517e52..550fa4c3f73e4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/DynamicPruning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/DynamicPruning.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.trees.UnaryLike trait DynamicPruning extends Predicate @@ -46,9 +47,10 @@ case class DynamicPruningSubquery( exprId: ExprId = NamedExpression.newExprId) extends SubqueryExpression(buildQuery, Seq(pruningKey), exprId) with 
DynamicPruning - with Unevaluable { + with Unevaluable + with UnaryLike[Expression] { - override def children: Seq[Expression] = Seq(pruningKey) + override def child: Expression = pruningKey override def plan: LogicalPlan = buildQuery diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 83f77a6abd490..42892e25fa0e8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ -import org.apache.spark.sql.catalyst.trees.TreeNode +import org.apache.spark.sql.catalyst.trees.{BinaryLike, LeafLike, TernaryLike, TreeNode, UnaryLike} import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.internal.SQLConf @@ -451,21 +451,14 @@ trait Stateful extends Nondeterministic { /** * A leaf expression, i.e. one without any child expressions. */ -abstract class LeafExpression extends Expression { - - override final def children: Seq[Expression] = Nil -} +abstract class LeafExpression extends Expression with LeafLike[Expression] /** * An expression with one input and one output. The output is by default evaluated to null * if the input is evaluated to null. */ -abstract class UnaryExpression extends Expression { - - def child: Expression - - override final def children: Seq[Expression] = child :: Nil +abstract class UnaryExpression extends Expression with UnaryLike[Expression] { override def foldable: Boolean = child.foldable override def nullable: Boolean = child.nullable @@ -552,12 +545,7 @@ object UnaryExpression { * An expression with two inputs and one output. The output is by default evaluated to null * if any input is evaluated to null. */ -abstract class BinaryExpression extends Expression { - - def left: Expression - def right: Expression - - override final def children: Seq[Expression] = Seq(left, right) +abstract class BinaryExpression extends Expression with BinaryLike[Expression] { override def foldable: Boolean = left.foldable && right.foldable @@ -701,7 +689,7 @@ object BinaryOperator { * An expression with three inputs and one output. The output is by default evaluated to null * if any input is evaluated to null. */ -abstract class TernaryExpression extends Expression { +abstract class TernaryExpression extends Expression with TernaryLike[Expression] { override def foldable: Boolean = children.forall(_.foldable) @@ -712,12 +700,11 @@ abstract class TernaryExpression extends Expression { * If subclass of TernaryExpression override nullable, probably should also override this. 
*/ override def eval(input: InternalRow): Any = { - val exprs = children - val value1 = exprs(0).eval(input) + val value1 = first.eval(input) if (value1 != null) { - val value2 = exprs(1).eval(input) + val value2 = second.eval(input) if (value2 != null) { - val value3 = exprs(2).eval(input) + val value3 = third.eval(input) if (value3 != null) { return nullSafeEval(value1, value2, value3) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PartitionTransforms.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PartitionTransforms.scala index e48fd8adaef09..9d34368b6c541 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PartitionTransforms.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PartitionTransforms.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.catalyst.trees.UnaryLike import org.apache.spark.sql.types.{DataType, IntegerType} /** @@ -32,7 +33,8 @@ import org.apache.spark.sql.types.{DataType, IntegerType} * df.writeTo("catalog.db.table").partitionedBy($"category", days($"timestamp")).create() * }}} */ -abstract class PartitionTransformExpression extends Expression with Unevaluable { +abstract class PartitionTransformExpression extends Expression with Unevaluable + with UnaryLike[Expression] { override def nullable: Boolean = true } @@ -41,7 +43,6 @@ abstract class PartitionTransformExpression extends Expression with Unevaluable */ case class Years(child: Expression) extends PartitionTransformExpression { override def dataType: DataType = IntegerType - override def children: Seq[Expression] = Seq(child) } /** @@ -49,7 +50,6 @@ case class Years(child: Expression) extends PartitionTransformExpression { */ case class Months(child: Expression) extends PartitionTransformExpression { override def dataType: DataType = IntegerType - override def children: Seq[Expression] = Seq(child) } /** @@ -57,7 +57,6 @@ case class Months(child: Expression) extends PartitionTransformExpression { */ case class Days(child: Expression) extends PartitionTransformExpression { override def dataType: DataType = IntegerType - override def children: Seq[Expression] = Seq(child) } /** @@ -65,7 +64,6 @@ case class Days(child: Expression) extends PartitionTransformExpression { */ case class Hours(child: Expression) extends PartitionTransformExpression { override def dataType: DataType = IntegerType - override def children: Seq[Expression] = Seq(child) } /** @@ -73,5 +71,4 @@ case class Hours(child: Expression) extends PartitionTransformExpression { */ case class Bucket(numBuckets: Literal, child: Expression) extends PartitionTransformExpression { override def dataType: DataType = IntegerType - override def children: Seq[Expression] = Seq(numBuckets, child) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SubExprEvaluationRuntime.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SubExprEvaluationRuntime.scala index 1f239b696d5ff..a1f7ba3008775 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SubExprEvaluationRuntime.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SubExprEvaluationRuntime.scala @@ -121,11 +121,10 @@ class SubExprEvaluationRuntime(cacheMaxEntries: Int) { case class ExpressionProxy( child: Expression, id: Int, - runtime: SubExprEvaluationRuntime) extends Expression { + runtime: 
SubExprEvaluationRuntime) extends UnaryExpression { final override def dataType: DataType = child.dataType final override def nullable: Boolean = child.nullable - final override def children: Seq[Expression] = child :: Nil // `ExpressionProxy` is for interpreted expression evaluation only. So cannot `doGenCode`. final override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervals.scala index 103f55e58febd..42dc6f6b200d0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervals.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, GenericInternalRow} +import org.apache.spark.sql.catalyst.trees.BinaryLike import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, HyperLogLogPlusPlusHelper} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.Platform @@ -48,7 +49,7 @@ case class ApproxCountDistinctForIntervals( relativeSD: Double = 0.05, mutableAggBufferOffset: Int = 0, inputAggBufferOffset: Int = 0) - extends TypedImperativeAggregate[Array[Long]] with ExpectsInputTypes { + extends TypedImperativeAggregate[Array[Long]] with ExpectsInputTypes with BinaryLike[Expression] { def this(child: Expression, endpointsExpression: Expression, relativeSD: Expression) = { this( @@ -213,7 +214,8 @@ case class ApproxCountDistinctForIntervals( copy(inputAggBufferOffset = newInputAggBufferOffset) } - override def children: Seq[Expression] = Seq(child, endpointsExpression) + override def left: Expression = child + override def right: Expression = endpointsExpression override def nullable: Boolean = false diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala index 9406a97a82208..4e4a06a628453 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile.PercentileDigest +import org.apache.spark.sql.catalyst.trees.TernaryLike import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData} import org.apache.spark.sql.catalyst.util.QuantileSummaries import org.apache.spark.sql.catalyst.util.QuantileSummaries.{defaultCompressThreshold, Stats} @@ -76,7 +77,8 @@ case class ApproximatePercentile( accuracyExpression: Expression, override val mutableAggBufferOffset: Int, override val inputAggBufferOffset: Int) - 
extends TypedImperativeAggregate[PercentileDigest] with ImplicitCastInputTypes { + extends TypedImperativeAggregate[PercentileDigest] with ImplicitCastInputTypes + with TernaryLike[Expression] { def this(child: Expression, percentageExpression: Expression, accuracyExpression: Expression) = { this(child, percentageExpression, accuracyExpression, 0, 0) @@ -182,7 +184,9 @@ case class ApproximatePercentile( override def withNewInputAggBufferOffset(newOffset: Int): ApproximatePercentile = copy(inputAggBufferOffset = newOffset) - override def children: Seq[Expression] = Seq(child, percentageExpression, accuracyExpression) + override def first: Expression = child + override def second: Expression = percentageExpression + override def third: Expression = accuracyExpression // Returns null for empty inputs override def nullable: Boolean = true diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala index 13f38ac7c9ae5..490e14afe992b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.sql.catalyst.analysis.{DecimalPrecision, FunctionRegistry, TypeCheckResult} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.trees.UnaryLike import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ @@ -34,12 +35,11 @@ import org.apache.spark.sql.types._ """, group = "agg_funcs", since = "1.0.0") -case class Average(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes { +case class Average(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes + with UnaryLike[Expression] { override def prettyName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("avg") - override def children: Seq[Expression] = child :: Nil - override def inputTypes: Seq[AbstractDataType] = Seq(NumericType) override def checkInputDataTypes(): TypeCheckResult = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala index 2cc9adb5aa06e..4ca933ff45d02 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.sql.catalyst.analysis.FunctionRegistry import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.trees.UnaryLike import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -45,14 +46,13 @@ import org.apache.spark.sql.types._ * @param child to compute central moments of. 
*/ abstract class CentralMomentAgg(child: Expression, nullOnDivideByZero: Boolean) - extends DeclarativeAggregate with ImplicitCastInputTypes { + extends DeclarativeAggregate with ImplicitCastInputTypes with UnaryLike[Expression] { /** * The central moment order to be computed. */ protected def momentOrder: Int - override def children: Seq[Expression] = Seq(child) override def nullable: Boolean = true override def dataType: DataType = DoubleType override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala index 737e8cd3ffa41..d819971478ecf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.trees.BinaryLike import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -30,9 +31,10 @@ import org.apache.spark.sql.types._ * http://en.wikipedia.org/wiki/Pearson_product-moment_correlation_coefficient */ abstract class PearsonCorrelation(x: Expression, y: Expression, nullOnDivideByZero: Boolean) - extends DeclarativeAggregate with ImplicitCastInputTypes { + extends DeclarativeAggregate with ImplicitCastInputTypes with BinaryLike[Expression] { - override def children: Seq[Expression] = Seq(x, y) + override def left: Expression = x + override def right: Expression = y override def nullable: Boolean = true override def dataType: DataType = DoubleType override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountIf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountIf.scala index 5bb95ead3f715..53a3fd6b6c23d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountIf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountIf.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescription, ImplicitCastInputTypes, UnevaluableAggregate} +import org.apache.spark.sql.catalyst.trees.UnaryLike import org.apache.spark.sql.types.{AbstractDataType, BooleanType, DataType, LongType} @ExpressionDescription( @@ -34,10 +35,12 @@ import org.apache.spark.sql.types.{AbstractDataType, BooleanType, DataType, Long """, group = "agg_funcs", since = "3.0.0") -case class CountIf(predicate: Expression) extends UnevaluableAggregate with ImplicitCastInputTypes { +case class CountIf(predicate: Expression) extends UnevaluableAggregate with ImplicitCastInputTypes + with UnaryLike[Expression] { + override def prettyName: String = "count_if" - override def children: Seq[Expression] = Seq(predicate) + override def child: Expression = predicate override def nullable: Boolean = false diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala index 7c4d6ded6559e..160ee92b00447 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.trees.BinaryLike import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -27,9 +28,10 @@ import org.apache.spark.sql.types._ * When applied on empty data (i.e., count is zero), it returns NULL. */ abstract class Covariance(x: Expression, y: Expression, nullOnDivideByZero: Boolean) - extends DeclarativeAggregate with ImplicitCastInputTypes { + extends DeclarativeAggregate with ImplicitCastInputTypes with BinaryLike[Expression] { - override def children: Seq[Expression] = Seq(x, y) + override def left: Expression = x + override def right: Expression = y override def nullable: Boolean = true override def dataType: DataType = DoubleType override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType) @@ -97,8 +99,8 @@ abstract class Covariance(x: Expression, y: Expression, nullOnDivideByZero: Bool group = "agg_funcs", since = "2.0.0") case class CovPopulation( - left: Expression, - right: Expression, + override val left: Expression, + override val right: Expression, nullOnDivideByZero: Boolean = !SQLConf.get.legacyStatisticalAggregate) extends Covariance(left, right, nullOnDivideByZero) { @@ -122,8 +124,8 @@ case class CovPopulation( group = "agg_funcs", since = "2.0.0") case class CovSample( - left: Expression, - right: Expression, + override val left: Expression, + override val right: Expression, nullOnDivideByZero: Boolean = !SQLConf.get.legacyStatisticalAggregate) extends Covariance(left, right, nullOnDivideByZero) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala index 65fd43c924d08..accd15a711503 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckSuccess import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.trees.UnaryLike import org.apache.spark.sql.types._ /** @@ -51,7 +52,7 @@ import org.apache.spark.sql.types._ group = "agg_funcs", since = "2.0.0") case class First(child: Expression, ignoreNulls: Boolean) - extends DeclarativeAggregate with ExpectsInputTypes { + extends DeclarativeAggregate with ExpectsInputTypes with UnaryLike[Expression] { def this(child: Expression) = this(child, false) @@ -59,8 +60,6 @@ case class First(child: Expression, ignoreNulls: Boolean) this(child, FirstLast.validateIgnoreNullExpr(ignoreNullsExpr, "first")) } - override def children: Seq[Expression] = child :: Nil - override def nullable: Boolean = true // First is not a deterministic function. 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala index 1d20387606f61..430c25cee2a93 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.trees.UnaryLike import org.apache.spark.sql.catalyst.util.HyperLogLogPlusPlusHelper import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.types._ @@ -60,7 +61,7 @@ case class HyperLogLogPlusPlus( relativeSD: Double = 0.05, mutableAggBufferOffset: Int = 0, inputAggBufferOffset: Int = 0) - extends ImperativeAggregate { + extends ImperativeAggregate with UnaryLike[Expression] { def this(child: Expression) = { this(child = child, relativeSD = 0.05, mutableAggBufferOffset = 0, inputAggBufferOffset = 0) @@ -82,8 +83,6 @@ case class HyperLogLogPlusPlus( override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = copy(inputAggBufferOffset = newInputAggBufferOffset) - override def children: Seq[Expression] = Seq(child) - override def nullable: Boolean = false override def dataType: DataType = LongType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala index 8d17a48a69f6f..e3c427d584489 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckSuccess import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.trees.UnaryLike import org.apache.spark.sql.types._ /** @@ -50,7 +51,7 @@ import org.apache.spark.sql.types._ group = "agg_funcs", since = "2.0.0") case class Last(child: Expression, ignoreNulls: Boolean) - extends DeclarativeAggregate with ExpectsInputTypes { + extends DeclarativeAggregate with ExpectsInputTypes with UnaryLike[Expression] { def this(child: Expression) = this(child, false) @@ -58,8 +59,6 @@ case class Last(child: Expression, ignoreNulls: Boolean) this(child, FirstLast.validateIgnoreNullExpr(ignoreNullsExpr, "last")) } - override def children: Seq[Expression] = child :: Nil - override def nullable: Boolean = true // Last is not a deterministic function. 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala index 9bba6604c84ac..42721ea48c7ca 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.trees.UnaryLike import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ @@ -32,9 +33,7 @@ import org.apache.spark.sql.types._ """, group = "agg_funcs", since = "1.0.0") -case class Max(child: Expression) extends DeclarativeAggregate { - - override def children: Seq[Expression] = child :: Nil +case class Max(child: Expression) extends DeclarativeAggregate with UnaryLike[Expression] { override def nullable: Boolean = true diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/MaxByAndMinBy.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/MaxByAndMinBy.scala index 6d3d3dafe16e4..e402bcae144ad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/MaxByAndMinBy.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/MaxByAndMinBy.scala @@ -20,13 +20,14 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.trees.BinaryLike import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ /** * The shared abstract superclass for `MaxBy` and `MinBy` SQL aggregate functions. */ -abstract class MaxMinBy extends DeclarativeAggregate { +abstract class MaxMinBy extends DeclarativeAggregate with BinaryLike[Expression] { def valueExpr: Expression def orderingExpr: Expression @@ -37,7 +38,8 @@ abstract class MaxMinBy extends DeclarativeAggregate { // Used to pick up updated ordering value. 
protected def orderingUpdater(oldExpr: Expression, newExpr: Expression): Expression - override def children: Seq[Expression] = valueExpr :: orderingExpr :: Nil + override def left: Expression = valueExpr + override def right: Expression = orderingExpr override def nullable: Boolean = true diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala index 1d861aa0dd8cf..84410c7de3229 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.trees.UnaryLike import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ @@ -32,9 +33,7 @@ import org.apache.spark.sql.types._ """, group = "agg_funcs", since = "1.0.0") -case class Min(child: Expression) extends DeclarativeAggregate { - - override def children: Seq[Expression] = child :: Nil +case class Min(child: Expression) extends DeclarativeAggregate with UnaryLike[Expression] { override def nullable: Boolean = true diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala index 1cd22422938b9..b81c523ce32ba 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.trees.TernaryLike import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.types._ @@ -70,7 +71,8 @@ case class Percentile( frequencyExpression : Expression, mutableAggBufferOffset: Int = 0, inputAggBufferOffset: Int = 0) - extends TypedImperativeAggregate[OpenHashMap[AnyRef, Long]] with ImplicitCastInputTypes { + extends TypedImperativeAggregate[OpenHashMap[AnyRef, Long]] with ImplicitCastInputTypes + with TernaryLike[Expression] { def this(child: Expression, percentageExpression: Expression) = { this(child, percentageExpression, Literal(1L), 0, 0) @@ -99,9 +101,9 @@ case class Percentile( case arrayData: ArrayData => arrayData.toDoubleArray() } - override def children: Seq[Expression] = { - child :: percentageExpression :: frequencyExpression :: Nil - } + override def first: Expression = child + override def second: Expression = percentageExpression + override def third: Expression = frequencyExpression // Returns null for empty inputs override def nullable: Boolean = true diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala index 17471535873fc..422fcab5bf890 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala @@ -21,6 +21,7 @@ import scala.collection.immutable.{HashMap, TreeMap} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.trees.BinaryLike import org.apache.spark.sql.catalyst.util.{GenericArrayData, TypeUtils} import org.apache.spark.sql.types._ @@ -73,9 +74,10 @@ case class PivotFirst( valueColumn: Expression, pivotColumnValues: Seq[Any], mutableAggBufferOffset: Int = 0, - inputAggBufferOffset: Int = 0) extends ImperativeAggregate { + inputAggBufferOffset: Int = 0) extends ImperativeAggregate with BinaryLike[Expression] { - override val children: Seq[Expression] = pivotColumn :: valueColumn :: Nil + override val left: Expression = pivotColumn + override val right: Expression = valueColumn override val nullable: Boolean = false diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Product.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Product.scala index c28ec86ab0b03..50c74f1c49a99 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Product.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Product.scala @@ -19,14 +19,13 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, ImplicitCastInputTypes, Literal} +import org.apache.spark.sql.catalyst.trees.UnaryLike import org.apache.spark.sql.types.{AbstractDataType, DataType, DoubleType} /** Multiply numerical values within an aggregation group */ case class Product(child: Expression) - extends DeclarativeAggregate with ImplicitCastInputTypes { - - override def children: Seq[Expression] = child :: Nil + extends DeclarativeAggregate with ImplicitCastInputTypes with UnaryLike[Expression] { override def nullable: Boolean = true diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala index a29ae2c8b65a1..9cb8097041fed 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.trees.UnaryLike import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -37,9 +38,8 @@ import org.apache.spark.sql.types._ """, group = "agg_funcs", since = "1.0.0") -case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes { - - override def children: Seq[Expression] = child :: Nil +case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes + with UnaryLike[Expression] { override def nullable: Boolean = true diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/UnevaluableAggs.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/UnevaluableAggs.scala index cb77ded3372a2..5b914c4333687 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/UnevaluableAggs.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/UnevaluableAggs.scala @@ -19,12 +19,13 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult} import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.trees.UnaryLike import org.apache.spark.sql.types._ abstract class UnevaluableBooleanAggBase(arg: Expression) - extends UnevaluableAggregate with ImplicitCastInputTypes { + extends UnevaluableAggregate with ImplicitCastInputTypes with UnaryLike[Expression] { - override def children: Seq[Expression] = arg :: Nil + override def child: Expression = arg override def dataType: DataType = BooleanType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/bitwiseAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/bitwiseAggregates.scala index 573dbd6c3f8c6..25c099525ef81 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/bitwiseAggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/bitwiseAggregates.scala @@ -18,16 +18,16 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.sql.catalyst.expressions.{AttributeReference, BinaryArithmetic, BitwiseAnd, BitwiseOr, BitwiseXor, ExpectsInputTypes, Expression, ExpressionDescription, If, IsNull, Literal} +import org.apache.spark.sql.catalyst.trees.UnaryLike import org.apache.spark.sql.types.{AbstractDataType, DataType, IntegralType} -abstract class BitAggregate extends DeclarativeAggregate with ExpectsInputTypes { +abstract class BitAggregate extends DeclarativeAggregate with ExpectsInputTypes + with UnaryLike[Expression] { val child: Expression def bitOperator(left: Expression, right: Expression): BinaryArithmetic - override def children: Seq[Expression] = child :: Nil - override def nullable: Boolean = true override def dataType: DataType = child.dataType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala index f1b9630312d55..5f1d03264fa74 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala @@ -23,6 +23,7 @@ import scala.collection.mutable import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.trees.UnaryLike import org.apache.spark.sql.catalyst.util.ArrayData import org.apache.spark.sql.catalyst.util.GenericArrayData import org.apache.spark.sql.types._ @@ -33,12 +34,11 @@ import org.apache.spark.sql.types._ * We have to store all the collected elements in memory, and so notice that too many elements * can cause GC paused and eventually OutOfMemory Errors. 
*/ -abstract class Collect[T <: Growable[Any] with Iterable[Any]] extends TypedImperativeAggregate[T] { +abstract class Collect[T <: Growable[Any] with Iterable[Any]] extends TypedImperativeAggregate[T] + with UnaryLike[Expression] { val child: Expression - override def children: Seq[Expression] = child :: Nil - override def nullable: Boolean = false override def dataType: DataType = ArrayType(child.dataType, false) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala index 11c810d2dd497..e0c6ce7208c94 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala @@ -62,10 +62,9 @@ case object Complete extends AggregateMode * A place holder expressions used in code-gen, it does not change the corresponding value * in the row. */ -case object NoOp extends Expression with Unevaluable { +case object NoOp extends LeafExpression with Unevaluable { override def nullable: Boolean = true override def dataType: DataType = NullType - override def children: Seq[Expression] = Nil } object AggregateExpression { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 17b45bc44a28e..d3fad8cb329c2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1428,7 +1428,9 @@ case class Slice(x: Expression, start: Expression, length: Expression) override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, IntegerType, IntegerType) - @transient override lazy val children: Seq[Expression] = Seq(x, start, length) // called from eval + override def first: Expression = x + override def second: Expression = start + override def third: Expression = length @transient private lazy val elementType: DataType = x.dataType.asInstanceOf[ArrayType].elementType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 152ed04b013e4..3c016a7a54995 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.{FUNC_ALIAS, Func import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.parser.CatalystSqlParser +import org.apache.spark.sql.catalyst.trees.UnaryLike import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -521,7 +522,9 @@ case class StringToMap(text: Expression, pairDelim: Expression, keyValueDelim: E this(child, Literal(","), Literal(":")) } - override def children: Seq[Expression] = Seq(text, pairDelim, keyValueDelim) + override def first: Expression = text + override def second: Expression = pairDelim + override def third: Expression = keyValueDelim 
override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType, StringType) @@ -597,7 +600,7 @@ trait StructFieldsOperation { * children, and thereby enable the analyzer to resolve and transform valExpr as necessary. */ case class WithField(name: String, valExpr: Expression) - extends Unevaluable with StructFieldsOperation { + extends Unevaluable with StructFieldsOperation with UnaryLike[Expression] { override def apply(values: Seq[(StructField, Expression)]): Seq[(StructField, Expression)] = { val newFieldExpr = (StructField(name, valExpr.dataType, valExpr.nullable), valExpr) @@ -615,7 +618,7 @@ case class WithField(name: String, valExpr: Expression) result.toSeq } - override def children: Seq[Expression] = valExpr :: Nil + override def child: Expression = valExpr override def dataType: DataType = throw new IllegalStateException( "WithField.dataType should not be called.") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index 7b0be8eb24097..a062dd49a3c92 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ +import org.apache.spark.sql.catalyst.trees.TernaryLike import org.apache.spark.sql.types._ // scalastyle:off line.size.limit @@ -35,14 +36,16 @@ import org.apache.spark.sql.types._ group = "conditional_funcs") // scalastyle:on line.size.limit case class If(predicate: Expression, trueValue: Expression, falseValue: Expression) - extends ComplexTypeMergingExpression { + extends ComplexTypeMergingExpression with TernaryLike[Expression] { @transient override lazy val inputTypesForMerging: Seq[DataType] = { Seq(trueValue.dataType, falseValue.dataType) } - override def children: Seq[Expression] = predicate :: trueValue :: falseValue :: Nil + override def first: Expression = predicate + override def second: Expression = trueValue + override def third: Expression = falseValue override def nullable: Boolean = trueValue.nullable || falseValue.nullable override def checkInputDataTypes(): TypeCheckResult = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index 2aaa3aa68f764..0e422c8f0a89b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -1603,7 +1603,9 @@ case class MonthsBetween( def this(date1: Expression, date2: Expression, roundOff: Expression) = this(date1, date2, roundOff, None) - override def children: Seq[Expression] = Seq(date1, date2, roundOff) + override def first: Expression = date1 + override def second: Expression = date2 + override def third: Expression = roundOff override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, TimestampType, BooleanType) @@ -2000,7 +2002,9 @@ case class MakeDate( def this(year: Expression, month: Expression, day: Expression) = 
this(year, month, day, SQLConf.get.ansiEnabled) - override def children: Seq[Expression] = Seq(year, month, day) + override def first: Expression = year + override def second: Expression = month + override def third: Expression = day override def inputTypes: Seq[AbstractDataType] = Seq(IntegerType, IntegerType, IntegerType) override def dataType: DataType = DateType override def nullable: Boolean = if (failOnError) children.exists(_.nullable) else true diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala index 17f28da4ad037..c6b67d62d181c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.trees.UnaryLike import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -157,11 +158,11 @@ object GroupingSets { since = "2.0.0", group = "agg_funcs") // scalastyle:on line.size.limit line.contains.tab -case class Grouping(child: Expression) extends Expression with Unevaluable { +case class Grouping(child: Expression) extends Expression with Unevaluable + with UnaryLike[Expression] { @transient override lazy val references: AttributeSet = AttributeSet(VirtualColumn.groupingIdAttribute :: Nil) - override def children: Seq[Expression] = child :: Nil override def dataType: DataType = ByteType override def nullable: Boolean = false } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index 43281c2dc3c2f..7ddb00b62b89c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -347,7 +347,9 @@ case class Acosh(child: Expression) case class Conv(numExpr: Expression, fromBaseExpr: Expression, toBaseExpr: Expression) extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant { - override def children: Seq[Expression] = Seq(numExpr, fromBaseExpr, toBaseExpr) + override def first: Expression = numExpr + override def second: Expression = fromBaseExpr + override def third: Expression = toBaseExpr override def inputTypes: Seq[AbstractDataType] = Seq(StringType, IntegerType, IntegerType) override def dataType: DataType = StringType override def nullable: Boolean = true diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 3f360e62d6ecc..5be521683381d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ +import org.apache.spark.sql.catalyst.trees.TernaryLike 
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData} import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.types._ @@ -714,11 +715,14 @@ case class MapObjects private( loopVar: LambdaVariable, lambdaFunction: Expression, inputData: Expression, - customCollectionCls: Option[Class[_]]) extends Expression with NonSQLExpression { + customCollectionCls: Option[Class[_]]) extends Expression with NonSQLExpression + with TernaryLike[Expression] { override def nullable: Boolean = inputData.nullable - override def children: Seq[Expression] = Seq(loopVar, lambdaFunction, inputData) + override def first: Expression = loopVar + override def second: Expression = lambdaFunction + override def third: Expression = inputData // The data with UserDefinedType are actually stored with the data type of its sqlType. // When we want to apply MapObjects on it, we have to use it. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index 416a6c1ad2c55..bd2d8375782d5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -446,7 +446,9 @@ case class StringSplit(str: Expression, regex: Expression, limit: Expression) override def dataType: DataType = ArrayType(StringType) override def inputTypes: Seq[DataType] = Seq(StringType, StringType, IntegerType) - override def children: Seq[Expression] = str :: regex :: limit :: Nil + override def first: Expression = str + override def second: Expression = regex + override def third: Expression = limit def this(exp: Expression, regex: Expression) = this(exp, regex, Literal(-1)); @@ -646,7 +648,9 @@ abstract class RegExpExtractBase @transient private var pattern: Pattern = _ override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType, IntegerType) - override def children: Seq[Expression] = subject :: regexp :: idx :: Nil + override def first: Expression = subject + override def second: Expression = regexp + override def third: Expression = idx protected def getLastMatcher(s: Any, p: Any): Matcher = { if (p != lastRegex) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 40c28f3878128..c6b7738f8c24d 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -517,7 +517,10 @@ case class StringReplace(srcExpr: Expression, searchExpr: Expression, replaceExp override def dataType: DataType = StringType override def inputTypes: Seq[DataType] = Seq(StringType, StringType, StringType) - override def children: Seq[Expression] = srcExpr :: searchExpr :: replaceExpr :: Nil + override def first: Expression = srcExpr + override def second: Expression = searchExpr + override def third: Expression = replaceExpr + override def prettyName: String = "replace" } @@ -721,7 +724,9 @@ case class StringTranslate(srcExpr: Expression, matchingExpr: Expression, replac override def dataType: DataType = StringType override def inputTypes: Seq[DataType] = Seq(StringType, StringType, StringType) - override def 
children: Seq[Expression] = srcExpr :: matchingExpr :: replaceExpr :: Nil + override def first: Expression = srcExpr + override def second: Expression = matchingExpr + override def third: Expression = replaceExpr override def prettyName: String = "translate" } @@ -1142,7 +1147,9 @@ case class SubstringIndex(strExpr: Expression, delimExpr: Expression, countExpr: override def dataType: DataType = StringType override def inputTypes: Seq[DataType] = Seq(StringType, StringType, IntegerType) - override def children: Seq[Expression] = Seq(strExpr, delimExpr, countExpr) + override def first: Expression = strExpr + override def second: Expression = delimExpr + override def third: Expression = countExpr override def prettyName: String = "substring_index" override def nullSafeEval(str: Any, delim: Any, count: Any): Any = { @@ -1185,7 +1192,9 @@ case class StringLocate(substr: Expression, str: Expression, start: Expression) this(substr, str, Literal(1)) } - override def children: Seq[Expression] = substr :: str :: start :: Nil + override def first: Expression = substr + override def second: Expression = str + override def third: Expression = start override def nullable: Boolean = substr.nullable || str.nullable override def dataType: DataType = IntegerType override def inputTypes: Seq[DataType] = Seq(StringType, StringType, IntegerType) @@ -1275,7 +1284,9 @@ case class StringLPad(str: Expression, len: Expression, pad: Expression = Litera this(str, len, Literal(" ")) } - override def children: Seq[Expression] = str :: len :: pad :: Nil + override def first: Expression = str + override def second: Expression = len + override def third: Expression = pad override def dataType: DataType = StringType override def inputTypes: Seq[DataType] = Seq(StringType, IntegerType, StringType) @@ -1317,7 +1328,10 @@ case class StringRPad(str: Expression, len: Expression, pad: Expression = Litera this(str, len, Literal(" ")) } - override def children: Seq[Expression] = str :: len :: pad :: Nil + override def first: Expression = str + override def second: Expression = len + override def third: Expression = pad + override def dataType: DataType = StringType override def inputTypes: Seq[DataType] = Seq(StringType, IntegerType, StringType) @@ -1728,7 +1742,9 @@ case class Substring(str: Expression, pos: Expression, len: Expression) override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, BinaryType), IntegerType, IntegerType) - override def children: Seq[Expression] = str :: pos :: len :: Nil + override def first: Expression = str + override def second: Expression = pos + override def third: Expression = len override def nullSafeEval(string: Any, pos: Any, len: Any): Any = { str.dataType match { @@ -2439,7 +2455,7 @@ case class Sentences( str: Expression, language: Expression = Literal(""), country: Expression = Literal("")) - extends Expression with ImplicitCastInputTypes with CodegenFallback { + extends TernaryExpression with ImplicitCastInputTypes with CodegenFallback { def this(str: Expression) = this(str, Literal(""), Literal("")) def this(str: Expression, language: Expression) = this(str, language, Literal("")) @@ -2448,7 +2464,9 @@ case class Sentences( override def dataType: DataType = ArrayType(ArrayType(StringType, containsNull = false), containsNull = false) override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType, StringType) - override def children: Seq[Expression] = str :: language :: country :: Nil + override def first: Expression = str + override def second: 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
index 6e6d99a472219..d45614fa292c5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
@@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedExcept
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateFunction, DeclarativeAggregate, NoOp}
+import org.apache.spark.sql.catalyst.trees.{BinaryLike, LeafLike, TernaryLike, UnaryLike}
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.types._
 
@@ -141,8 +142,7 @@ case object RangeFrame extends FrameType {
 /**
  * The trait used to represent special boundaries used in a window frame.
  */
-sealed trait SpecialFrameBoundary extends Expression with Unevaluable {
-  override def children: Seq[Expression] = Nil
+sealed trait SpecialFrameBoundary extends LeafExpression with Unevaluable {
   override def dataType: DataType = NullType
   override def nullable: Boolean = false
 }
@@ -165,13 +165,12 @@ case object CurrentRow extends SpecialFrameBoundary {
  * Represents a window frame.
  */
 sealed trait WindowFrame extends Expression with Unevaluable {
-  override def children: Seq[Expression] = Nil
   override def dataType: DataType = throw QueryExecutionErrors.dataTypeOperationUnsupportedError
   override def nullable: Boolean = false
 }
 
 /** Used as a placeholder when a frame specification is not defined. */
-case object UnspecifiedFrame extends WindowFrame
+case object UnspecifiedFrame extends WindowFrame with LeafLike[Expression]
 
 /**
  * A specified Window Frame. The val lower/upper can be either a foldable [[Expression]] or a
@@ -181,9 +180,10 @@ case class SpecifiedWindowFrame(
     frameType: FrameType,
     lower: Expression,
     upper: Expression)
-  extends WindowFrame {
+  extends WindowFrame with BinaryLike[Expression] {
 
-  override def children: Seq[Expression] = lower :: upper :: Nil
+  override def left: Expression = lower
+  override def right: Expression = upper
 
   lazy val valueBoundary: Seq[Expression] =
     children.filterNot(_.isInstanceOf[SpecialFrameBoundary])
@@ -279,9 +279,11 @@ case class UnresolvedWindowExpression(
 
 case class WindowExpression(
     windowFunction: Expression,
-    windowSpec: WindowSpecDefinition) extends Expression with Unevaluable {
+    windowSpec: WindowSpecDefinition) extends Expression with Unevaluable
+  with BinaryLike[Expression] {
 
-  override def children: Seq[Expression] = windowFunction :: windowSpec :: Nil
+  override def left: Expression = windowFunction
+  override def right: Expression = windowSpec
 
   override def dataType: DataType = windowFunction.dataType
   override def nullable: Boolean = windowFunction.nullable
@@ -373,8 +375,6 @@ trait OffsetWindowFunction extends WindowFunction {
 sealed abstract class FrameLessOffsetWindowFunction
   extends OffsetWindowFunction with Unevaluable with ImplicitCastInputTypes {
 
-  override def children: Seq[Expression] = Seq(input, offset, default)
-
   /*
    * The result of an OffsetWindowFunction is dependent on the frame in which the
    * OffsetWindowFunction is executed, the input expression and the default expression. Even when
@@ -444,7 +444,7 @@ sealed abstract class FrameLessOffsetWindowFunction
 // scalastyle:on line.size.limit line.contains.tab
 case class Lead(
     input: Expression, offset: Expression, default: Expression, ignoreNulls: Boolean)
-  extends FrameLessOffsetWindowFunction {
+  extends FrameLessOffsetWindowFunction with TernaryLike[Expression] {
 
   def this(input: Expression, offset: Expression, default: Expression) =
     this(input, offset, default, false)
@@ -454,6 +454,10 @@ case class Lead(
   def this(input: Expression) = this(input, Literal(1))
 
   def this() = this(Literal(null))
+
+  override def first: Expression = input
+  override def second: Expression = offset
+  override def third: Expression = default
 }
 
 /**
@@ -490,7 +494,7 @@ case class Lead(
 // scalastyle:on line.size.limit line.contains.tab
 case class Lag(
     input: Expression, inputOffset: Expression, default: Expression, ignoreNulls: Boolean)
-  extends FrameLessOffsetWindowFunction {
+  extends FrameLessOffsetWindowFunction with TernaryLike[Expression] {
 
   def this(input: Expression, inputOffset: Expression, default: Expression) =
     this(input, inputOffset, default, false)
@@ -501,12 +505,14 @@ case class Lag(
 
   def this() = this(Literal(null))
 
-  override def children: Seq[Expression] = Seq(input, inputOffset, default)
-
   override val offset: Expression = UnaryMinus(inputOffset) match {
     case e: Expression if e.foldable => Literal.create(e.eval(EmptyRow), e.dataType)
     case o => o
   }
+
+  override def first: Expression = input
+  override def second: Expression = inputOffset
+  override def third: Expression = default
 }
 
 abstract class AggregateWindowFunction extends DeclarativeAggregate with WindowFunction {
@@ -519,7 +525,6 @@ abstract class AggregateWindowFunction extends DeclarativeAggregate with WindowF
 }
 
 abstract class RowNumberLike extends AggregateWindowFunction {
-  override def children: Seq[Expression] = Nil
   protected val zero = Literal(0)
   protected val one = Literal(1)
   protected val rowNumber = AttributeReference("rowNumber", IntegerType, nullable = false)()
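The frame and window-expression nodes follow the binary variant of the same pattern: `left` and `right` are the declared slots and `BinaryLike` materializes `children` once. A small illustrative check (sketch only, against a build with this patch):

```
import org.apache.spark.sql.catalyst.expressions.{CurrentRow, RowFrame, SpecifiedWindowFrame, UnboundedPreceding}

// left = lower bound, right = upper bound; children is the derived pair.
val frame = SpecifiedWindowFrame(RowFrame, UnboundedPreceding, CurrentRow)
assert(frame.children == Seq(frame.lower, frame.upper))
```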
@@ -567,7 +572,7 @@ object SizeBasedWindowFunction {
   since = "2.0.0",
   group = "window_funcs")
 // scalastyle:on line.size.limit line.contains.tab
-case class RowNumber() extends RowNumberLike {
+case class RowNumber() extends RowNumberLike with LeafLike[Expression] {
   override val evaluateExpression = rowNumber
   override def prettyName: String = "row_number"
 }
@@ -596,7 +601,7 @@ case class RowNumber() extends RowNumberLike {
   since = "2.0.0",
   group = "window_funcs")
 // scalastyle:on line.size.limit line.contains.tab
-case class CumeDist() extends RowNumberLike with SizeBasedWindowFunction {
+case class CumeDist() extends RowNumberLike with SizeBasedWindowFunction with LeafLike[Expression] {
   override def dataType: DataType = DoubleType
   // The frame for CUME_DIST is Range based instead of Row based, because CUME_DIST must
   // return the same value for equal values in the partition.
@@ -634,13 +639,15 @@ case class CumeDist() extends RowNumberLike with SizeBasedWindowFunction {
   group = "window_funcs")
 // scalastyle:on line.size.limit line.contains.tab
 case class NthValue(input: Expression, offset: Expression, ignoreNulls: Boolean)
-  extends AggregateWindowFunction with OffsetWindowFunction with ImplicitCastInputTypes {
+  extends AggregateWindowFunction with OffsetWindowFunction with ImplicitCastInputTypes
+  with BinaryLike[Expression] {
 
   def this(child: Expression, offset: Expression) = this(child, offset, false)
 
   override lazy val default = Literal.create(null, input.dataType)
 
-  override def children: Seq[Expression] = input :: offset :: Nil
+  override def left: Expression = input
+  override def right: Expression = offset
 
   override val frame: WindowFrame = UnspecifiedFrame
 
@@ -734,10 +741,12 @@ case class NthValue(input: Expression, offset: Expression, ignoreNulls: Boolean)
   since = "2.0.0",
   group = "window_funcs")
 // scalastyle:on line.size.limit line.contains.tab
-case class NTile(buckets: Expression) extends RowNumberLike with SizeBasedWindowFunction {
+case class NTile(buckets: Expression) extends RowNumberLike with SizeBasedWindowFunction
+  with UnaryLike[Expression] {
+
   def this() = this(Literal(1))
 
-  override def children: Seq[Expression] = Seq(buckets)
+  override def child: Expression = buckets
 
   // Validate buckets. Note that this could be relaxed, the bucket value only needs to constant
   // for each partition.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Command.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Command.scala
index 89bd865391b5a..94ead5e3edee9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Command.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Command.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.catalyst.plans.logical
 
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet}
+import org.apache.spark.sql.catalyst.trees.{BinaryLike, LeafLike, UnaryLike}
 
 /**
  * A logical node that represents a non-query command to be executed by the system. For example,
@@ -27,9 +28,12 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet}
 trait Command extends LogicalPlan {
   override def output: Seq[Attribute] = Seq.empty
   override def producedAttributes: AttributeSet = outputSet
-  override def children: Seq[LogicalPlan] = Seq.empty
   // Commands are eagerly executed. They will be converted to LocalRelation after the DataFrame
   // is created. That said, the statistics of a command is useless. Here we just return a dummy
   // statistics to avoid unnecessary statistics calculation of command's children.
   override def stats: Statistics = Statistics.DUMMY
 }
+
+trait LeafCommand extends Command with LeafLike[LogicalPlan]
+trait UnaryCommand extends Command with UnaryLike[LogicalPlan]
+trait BinaryCommand extends Command with BinaryLike[LogicalPlan]
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index 0eff5558627b6..7129c6984cf3f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.LogicalPlanStats
+import org.apache.spark.sql.catalyst.trees.{BinaryLike, LeafLike, UnaryLike}
 import org.apache.spark.sql.types.StructType
 
@@ -158,8 +159,7 @@ abstract class LogicalPlan
 /**
  * A logical plan node with no children.
  */
-abstract class LeafNode extends LogicalPlan {
-  override final def children: Seq[LogicalPlan] = Nil
+trait LeafNode extends LogicalPlan with LeafLike[LogicalPlan] {
   override def producedAttributes: AttributeSet = outputSet
 
   /** Leaf nodes that can survive analysis must define their own statistics. */
@@ -169,11 +169,7 @@ abstract class LeafNode extends LogicalPlan {
 /**
  * A logical plan node with single child.
  */
-abstract class UnaryNode extends LogicalPlan {
-  def child: LogicalPlan
-
-  override final def children: Seq[LogicalPlan] = child :: Nil
-
+trait UnaryNode extends LogicalPlan with UnaryLike[LogicalPlan] {
   /**
    * Generates all valid constraints including an set of aliased constraints by replacing the
    * original constraint expressions with the corresponding alias
@@ -202,12 +198,7 @@ abstract class UnaryNode extends LogicalPlan {
 /**
  * A logical plan node with a left and right child.
  */
-abstract class BinaryNode extends LogicalPlan {
-  def left: LogicalPlan
-  def right: LogicalPlan
-
-  override final def children: Seq[LogicalPlan] = Seq(left, right)
-}
+trait BinaryNode extends LogicalPlan with BinaryLike[LogicalPlan]
 
 abstract class OrderPreservingUnaryNode extends UnaryNode {
   override final def outputOrdering: Seq[SortOrder] = child.outputOrdering
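Because `LeafNode`, `UnaryNode` and `BinaryNode` are traits now, and commands get the matching `LeafCommand`/`UnaryCommand`/`BinaryCommand` shorthands, node types that already extend another class can still mix them in. A hypothetical sketch, not part of this patch:

```
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, UnaryCommand}

// Hypothetical command: naming `child` is all that is needed; `children`,
// `output` and the dummy stats come from UnaryCommand's parents.
case class AuditCommand(child: LogicalPlan, reason: String) extends UnaryCommand
```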
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala
index cc6e387d0f600..d600c15004d1e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.plans.logical
 import org.apache.spark.sql.catalyst.analysis.ViewType
 import org.apache.spark.sql.catalyst.catalog.{BucketSpec, FunctionResource}
 import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.trees.{LeafLike, UnaryLike}
 import org.apache.spark.sql.connector.catalog.TableChange.ColumnPosition
 import org.apache.spark.sql.connector.expressions.Transform
 import org.apache.spark.sql.types.{DataType, StructType}
@@ -47,11 +48,12 @@ abstract class ParsedStatement extends LogicalPlan {
 
   override def output: Seq[Attribute] = Seq.empty
 
-  override def children: Seq[LogicalPlan] = Seq.empty
-
   final override lazy val resolved = false
 }
 
+trait LeafParsedStatement extends ParsedStatement with LeafLike[LogicalPlan]
+trait UnaryParsedStatement extends ParsedStatement with UnaryLike[LogicalPlan]
+
 /**
  * Type to keep track of Hive serde info
  */
@@ -144,7 +146,7 @@ case class CreateTableStatement(
     comment: Option[String],
     serde: Option[SerdeInfo],
     external: Boolean,
-    ifNotExists: Boolean) extends ParsedStatement
+    ifNotExists: Boolean) extends LeafParsedStatement
 
 /**
  * A CREATE TABLE AS SELECT command, as parsed from SQL.
@@ -162,9 +164,9 @@ case class CreateTableAsSelectStatement(
     writeOptions: Map[String, String],
     serde: Option[SerdeInfo],
     external: Boolean,
-    ifNotExists: Boolean) extends ParsedStatement {
+    ifNotExists: Boolean) extends UnaryParsedStatement {
 
-  override def children: Seq[LogicalPlan] = Seq(asSelect)
+  override def child: LogicalPlan = asSelect
 }
 
 /**
@@ -179,10 +181,7 @@ case class CreateViewStatement(
     child: LogicalPlan,
     allowExisting: Boolean,
     replace: Boolean,
-    viewType: ViewType) extends ParsedStatement {
-
-  override def children: Seq[LogicalPlan] = Seq(child)
-}
+    viewType: ViewType) extends UnaryParsedStatement
 
 /**
  * A REPLACE TABLE command, as parsed from SQL.
@@ -201,7 +200,7 @@ case class ReplaceTableStatement(
     location: Option[String],
     comment: Option[String],
     serde: Option[SerdeInfo],
-    orCreate: Boolean) extends ParsedStatement
+    orCreate: Boolean) extends LeafParsedStatement
 
 /**
  * A REPLACE TABLE AS SELECT command, as parsed from SQL.
@@ -218,9 +217,9 @@ case class ReplaceTableAsSelectStatement(
     comment: Option[String],
     writeOptions: Map[String, String],
     serde: Option[SerdeInfo],
-    orCreate: Boolean) extends ParsedStatement {
+    orCreate: Boolean) extends UnaryParsedStatement {
 
-  override def children: Seq[LogicalPlan] = Seq(asSelect)
+  override def child: LogicalPlan = asSelect
 }
 
@@ -239,11 +238,11 @@ case class QualifiedColType(
  */
 case class AlterTableAddColumnsStatement(
     tableName: Seq[String],
-    columnsToAdd: Seq[QualifiedColType]) extends ParsedStatement
+    columnsToAdd: Seq[QualifiedColType]) extends LeafParsedStatement
 
 case class AlterTableReplaceColumnsStatement(
     tableName: Seq[String],
-    columnsToAdd: Seq[QualifiedColType]) extends ParsedStatement
+    columnsToAdd: Seq[QualifiedColType]) extends LeafParsedStatement
 
 /**
  * ALTER TABLE ... CHANGE COLUMN command, as parsed from SQL.
@@ -254,7 +253,7 @@ case class AlterTableAlterColumnStatement(
     dataType: Option[DataType],
     nullable: Option[Boolean],
     comment: Option[String],
-    position: Option[ColumnPosition]) extends ParsedStatement
+    position: Option[ColumnPosition]) extends LeafParsedStatement
 
 /**
  * ALTER TABLE ... RENAME COLUMN command, as parsed from SQL.
@@ -262,14 +261,14 @@ case class AlterTableAlterColumnStatement(
 case class AlterTableRenameColumnStatement(
     tableName: Seq[String],
     column: Seq[String],
-    newName: String) extends ParsedStatement
+    newName: String) extends LeafParsedStatement
 
 /**
  * ALTER TABLE ... DROP COLUMNS command, as parsed from SQL.
 */
 case class AlterTableDropColumnsStatement(
     tableName: Seq[String],
-    columnsToDrop: Seq[Seq[String]]) extends ParsedStatement
+    columnsToDrop: Seq[Seq[String]]) extends LeafParsedStatement
 
 /**
  * An INSERT INTO statement, as parsed from SQL.
@@ -293,14 +292,14 @@ case class InsertIntoStatement(
     userSpecifiedCols: Seq[String],
     query: LogicalPlan,
     overwrite: Boolean,
-    ifPartitionNotExists: Boolean) extends ParsedStatement {
+    ifPartitionNotExists: Boolean) extends UnaryParsedStatement {
 
   require(overwrite || !ifPartitionNotExists,
     "IF NOT EXISTS is only valid in INSERT OVERWRITE")
   require(partitionSpec.values.forall(_.nonEmpty) || !ifPartitionNotExists,
     "IF NOT EXISTS is only valid with static partitions")
 
-  override def children: Seq[LogicalPlan] = query :: Nil
+  override def child: LogicalPlan = query
 }
 
 /**
@@ -309,17 +308,17 @@ case class InsertIntoStatement(
 case class CreateNamespaceStatement(
     namespace: Seq[String],
     ifNotExists: Boolean,
-    properties: Map[String, String]) extends ParsedStatement
+    properties: Map[String, String]) extends LeafParsedStatement
 
 /**
  * A USE statement, as parsed from SQL.
 */
-case class UseStatement(isNamespaceSet: Boolean, nameParts: Seq[String]) extends ParsedStatement
+case class UseStatement(isNamespaceSet: Boolean, nameParts: Seq[String]) extends LeafParsedStatement
 
 /**
  * A SHOW CURRENT NAMESPACE statement, as parsed from SQL
 */
-case class ShowCurrentNamespaceStatement() extends ParsedStatement
+case class ShowCurrentNamespaceStatement() extends LeafParsedStatement
 
 /**
  * CREATE FUNCTION statement, as parsed from SQL
@@ -330,4 +329,4 @@ case class CreateFunctionStatement(
     resources: Seq[FunctionResource],
     isTemp: Boolean,
     ignoreIfExists: Boolean,
-    replace: Boolean) extends ParsedStatement
+    replace: Boolean) extends LeafParsedStatement
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
index 938c23a51128e..3c3d642f7d36d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
@@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.analysis.{NamedRelation, PartitionSpec, Unr
 import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSet, Expression, Unevaluable}
 import org.apache.spark.sql.catalyst.plans.DescribeCommandSchema
+import org.apache.spark.sql.catalyst.trees.BinaryLike
 import org.apache.spark.sql.catalyst.util.CharVarcharUtils
 import org.apache.spark.sql.connector.catalog._
 import org.apache.spark.sql.connector.catalog.TableChange.{AddColumn, ColumnChange}
@@ -31,12 +32,12 @@ import org.apache.spark.sql.types.{BooleanType, DataType, MetadataBuilder, Strin
 /**
  * Base trait for DataSourceV2 write commands
 */
-trait V2WriteCommand extends Command {
+trait V2WriteCommand extends UnaryCommand {
   def table: NamedRelation
   def query: LogicalPlan
   def isByName: Boolean
 
-  override def children: Seq[LogicalPlan] = Seq(query)
+  override def child: LogicalPlan = query
 
   override lazy val resolved: Boolean = table.resolved && query.resolved && outputResolved
 
@@ -59,9 +60,10 @@ trait V2WriteCommand {
   def withNewTable(newTable: NamedRelation): V2WriteCommand
 }
 
-trait V2PartitionCommand extends Command {
+trait V2PartitionCommand extends UnaryCommand {
   def table: LogicalPlan
   def allowPartialPartitionSpec: Boolean = false
+  override def child: LogicalPlan = table
 }
 
 /**
@@ -189,7 +191,7 @@ case class CreateV2Table(
     tableSchema: StructType,
     partitioning: Seq[Transform],
     properties: Map[String, String],
-    ignoreIfExists: Boolean) extends Command with V2CreateTablePlan {
+    ignoreIfExists: Boolean) extends LeafCommand with V2CreateTablePlan {
   override def withPartitioning(rewritten: Seq[Transform]): V2CreateTablePlan = {
     this.copy(partitioning = rewritten)
   }
@@ -205,10 +207,10 @@ case class CreateTableAsSelect(
     query: LogicalPlan,
     properties: Map[String, String],
     writeOptions: Map[String, String],
-    ignoreIfExists: Boolean) extends Command with V2CreateTablePlan {
+    ignoreIfExists: Boolean) extends UnaryCommand with V2CreateTablePlan {
 
   override def tableSchema: StructType = query.schema
-  override def children: Seq[LogicalPlan] = Seq(query)
+  override def child: LogicalPlan = query
 
   override lazy val resolved: Boolean = childrenResolved && {
     // the table schema is created from the query schema, so the only resolution needed is to check
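`V2WriteCommand` subclasses inherit the `child`-to-`query` wiring from the trait, so a new write command only declares its own fields. A hypothetical sketch (the class and the `withNewQuery`/`withNewTable` bodies are illustrative, not part of this patch):

```
import org.apache.spark.sql.catalyst.analysis.NamedRelation
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, V2WriteCommand}

// Hypothetical write command; `child` is already fixed to `query` by the trait.
case class MyWrite(table: NamedRelation, query: LogicalPlan, isByName: Boolean)
  extends V2WriteCommand {
  override def withNewQuery(newQuery: LogicalPlan): MyWrite = copy(query = newQuery)
  override def withNewTable(newTable: NamedRelation): MyWrite = copy(table = newTable)
}
```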
@@ -236,7 +238,7 @@ case class ReplaceTable(
     tableSchema: StructType,
     partitioning: Seq[Transform],
     properties: Map[String, String],
-    orCreate: Boolean) extends Command with V2CreateTablePlan {
+    orCreate: Boolean) extends LeafCommand with V2CreateTablePlan {
   override def withPartitioning(rewritten: Seq[Transform]): V2CreateTablePlan = {
     this.copy(partitioning = rewritten)
   }
@@ -255,10 +257,10 @@ case class ReplaceTableAsSelect(
     query: LogicalPlan,
     properties: Map[String, String],
     writeOptions: Map[String, String],
-    orCreate: Boolean) extends Command with V2CreateTablePlan {
+    orCreate: Boolean) extends UnaryCommand with V2CreateTablePlan {
 
   override def tableSchema: StructType = query.schema
-  override def children: Seq[LogicalPlan] = Seq(query)
+  override def child: LogicalPlan = query
 
   override lazy val resolved: Boolean = childrenResolved && {
     // the table schema is created from the query schema, so the only resolution needed is to check
@@ -279,7 +281,7 @@ case class CreateNamespace(
     catalog: SupportsNamespaces,
     namespace: Seq[String],
     ifNotExists: Boolean,
-    properties: Map[String, String]) extends Command
+    properties: Map[String, String]) extends LeafCommand
 
 /**
  * The logical plan of the DROP NAMESPACE command.
@@ -287,8 +289,8 @@ case class CreateNamespace(
 case class DropNamespace(
     namespace: LogicalPlan,
     ifExists: Boolean,
-    cascade: Boolean) extends Command {
-  override def children: Seq[LogicalPlan] = Seq(namespace)
+    cascade: Boolean) extends UnaryCommand {
+  override def child: LogicalPlan = namespace
 }
 
 /**
@@ -297,9 +299,8 @@ case class DropNamespace(
 case class DescribeNamespace(
     namespace: LogicalPlan,
     extended: Boolean,
-    override val output: Seq[Attribute] = DescribeNamespace.getOutputAttrs) extends Command {
-  override def children: Seq[LogicalPlan] = Seq(namespace)
-
+    override val output: Seq[Attribute] = DescribeNamespace.getOutputAttrs) extends UnaryCommand {
+  override def child: LogicalPlan = namespace
 }
 
 object DescribeNamespace {
@@ -316,8 +317,8 @@ object DescribeNamespace {
 */
 case class SetNamespaceProperties(
     namespace: LogicalPlan,
-    properties: Map[String, String]) extends Command {
-  override def children: Seq[LogicalPlan] = Seq(namespace)
+    properties: Map[String, String]) extends UnaryCommand {
+  override def child: LogicalPlan = namespace
 }
 
 /**
@@ -325,8 +326,8 @@ case class SetNamespaceProperties(
 */
 case class SetNamespaceLocation(
     namespace: LogicalPlan,
-    location: String) extends Command {
-  override def children: Seq[LogicalPlan] = Seq(namespace)
+    location: String) extends UnaryCommand {
+  override def child: LogicalPlan = namespace
 }
 
 /**
@@ -335,8 +336,8 @@ case class SetNamespaceLocation(
 case class ShowNamespaces(
     namespace: LogicalPlan,
     pattern: Option[String],
-    override val output: Seq[Attribute] = ShowNamespaces.getOutputAttrs) extends Command {
-  override def children: Seq[LogicalPlan] = Seq(namespace)
+    override val output: Seq[Attribute] = ShowNamespaces.getOutputAttrs) extends UnaryCommand {
+  override def child: LogicalPlan = namespace
 }
 
 object ShowNamespaces {
@@ -352,8 +353,8 @@ case class DescribeRelation(
     relation: LogicalPlan,
     partitionSpec: TablePartitionSpec,
     isExtended: Boolean,
-    override val output: Seq[Attribute] = DescribeRelation.getOutputAttrs) extends Command {
-  override def children: Seq[LogicalPlan] = Seq(relation)
+    override val output: Seq[Attribute] = DescribeRelation.getOutputAttrs) extends UnaryCommand {
+  override def child: LogicalPlan = relation
 }
 
 object DescribeRelation {
@@ -367,8 +368,8 @@ case class DescribeColumn(
     relation: LogicalPlan,
     column: Expression,
     isExtended: Boolean,
-    override val output: Seq[Attribute] = DescribeColumn.getOutputAttrs) extends Command {
-  override def children: Seq[LogicalPlan] = Seq(relation)
+    override val output: Seq[Attribute] = DescribeColumn.getOutputAttrs) extends UnaryCommand {
+  override def child: LogicalPlan = relation
 }
 
 object DescribeColumn {
@@ -380,8 +381,8 @@ object DescribeColumn {
 */
 case class DeleteFromTable(
     table: LogicalPlan,
-    condition: Option[Expression]) extends Command with SupportsSubquery {
-  override def children: Seq[LogicalPlan] = table :: Nil
+    condition: Option[Expression]) extends UnaryCommand with SupportsSubquery {
+  override def child: LogicalPlan = table
 }
 
 /**
@@ -390,8 +391,8 @@ case class DeleteFromTable(
 case class UpdateTable(
     table: LogicalPlan,
     assignments: Seq[Assignment],
-    condition: Option[Expression]) extends Command with SupportsSubquery {
-  override def children: Seq[LogicalPlan] = table :: Nil
+    condition: Option[Expression]) extends UnaryCommand with SupportsSubquery {
+  override def child: LogicalPlan = table
 }
 
 /**
@@ -402,9 +403,10 @@ case class MergeIntoTable(
     sourceTable: LogicalPlan,
     mergeCondition: Expression,
     matchedActions: Seq[MergeAction],
-    notMatchedActions: Seq[MergeAction]) extends Command with SupportsSubquery {
-  override def children: Seq[LogicalPlan] = Seq(targetTable, sourceTable)
+    notMatchedActions: Seq[MergeAction]) extends BinaryCommand with SupportsSubquery {
   def duplicateResolved: Boolean = targetTable.outputSet.intersect(sourceTable.outputSet).isEmpty
+  override def left: LogicalPlan = targetTable
+  override def right: LogicalPlan = sourceTable
 }
 
 sealed abstract class MergeAction extends Expression with Unevaluable {
@@ -428,10 +430,12 @@ case class InsertAction(
   override def children: Seq[Expression] = condition.toSeq ++ assignments
 }
 
-case class Assignment(key: Expression, value: Expression) extends Expression with Unevaluable {
+case class Assignment(key: Expression, value: Expression) extends Expression
+  with Unevaluable with BinaryLike[Expression] {
   override def nullable: Boolean = false
   override def dataType: DataType = throw new UnresolvedException("nullable")
-  override def children: Seq[Expression] = key :: value :: Nil
+  override def left: Expression = key
+  override def right: Expression = value
 }
 
 /**
@@ -448,16 +452,14 @@ case class Assignment(key: Expression, value: Expression) extends Expression wit
 case class DropTable(
     child: LogicalPlan,
     ifExists: Boolean,
-    purge: Boolean) extends Command {
-  override def children: Seq[LogicalPlan] = child :: Nil
-}
+    purge: Boolean) extends UnaryCommand
 
 /**
  * The logical plan for no-op command handling non-existing table.
 */
 case class NoopCommand(
     commandName: String,
-    multipartIdentifier: Seq[String]) extends Command
+    multipartIdentifier: Seq[String]) extends LeafCommand
 
 /**
  * The logical plan of the ALTER TABLE command.
@@ -466,7 +468,7 @@ case class AlterTable(
     catalog: TableCatalog,
     ident: Identifier,
     table: NamedRelation,
-    changes: Seq[TableChange]) extends Command {
+    changes: Seq[TableChange]) extends LeafCommand {
 
   override lazy val resolved: Boolean = table.resolved && {
     changes.forall {
@@ -497,9 +499,7 @@ case class AlterTable(
 case class RenameTable(
     child: LogicalPlan,
     newName: Seq[String],
-    isView: Boolean) extends Command {
-  override def children: Seq[LogicalPlan] = child :: Nil
-}
+    isView: Boolean) extends UnaryCommand
 
 /**
  * The logical plan of the SHOW TABLES command.
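`MergeIntoTable` and `Assignment` show the two-child case: the pair is exposed as `left`/`right` and `children` stays consistent with the old hand-written Seq. An illustrative check (sketch only, against a build with this patch):

```
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.plans.logical.Assignment

// After this patch, Assignment's children come from BinaryLike over (key, value).
val assignment = Assignment(Literal(1), Literal(2))
assert(assignment.children == Seq(assignment.key, assignment.value))
```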
@@ -507,8 +507,8 @@ case class RenameTable(
 case class ShowTables(
     namespace: LogicalPlan,
     pattern: Option[String],
-    override val output: Seq[Attribute] = ShowTables.getOutputAttrs) extends Command {
-  override def children: Seq[LogicalPlan] = Seq(namespace)
+    override val output: Seq[Attribute] = ShowTables.getOutputAttrs) extends UnaryCommand {
+  override def child: LogicalPlan = namespace
 }
 
 object ShowTables {
@@ -525,8 +525,8 @@ case class ShowTableExtended(
     namespace: LogicalPlan,
     pattern: String,
     partitionSpec: Option[PartitionSpec],
-    override val output: Seq[Attribute] = ShowTableExtended.getOutputAttrs) extends Command {
-  override def children: Seq[LogicalPlan] = namespace :: Nil
+    override val output: Seq[Attribute] = ShowTableExtended.getOutputAttrs) extends UnaryCommand {
+  override def child: LogicalPlan = namespace
 }
 
 object ShowTableExtended {
@@ -546,8 +546,8 @@ object ShowTableExtended {
 case class ShowViews(
     namespace: LogicalPlan,
     pattern: Option[String],
-    override val output: Seq[Attribute] = ShowViews.getOutputAttrs) extends Command {
-  override def children: Seq[LogicalPlan] = Seq(namespace)
+    override val output: Seq[Attribute] = ShowViews.getOutputAttrs) extends UnaryCommand {
+  override def child: LogicalPlan = namespace
 }
 
 object ShowViews {
@@ -563,19 +563,17 @@ object ShowViews {
 case class SetCatalogAndNamespace(
     catalogManager: CatalogManager,
     catalogName: Option[String],
-    namespace: Option[Seq[String]]) extends Command
+    namespace: Option[Seq[String]]) extends LeafCommand
 
 /**
  * The logical plan of the REFRESH TABLE command.
 */
-case class RefreshTable(child: LogicalPlan) extends Command {
-  override def children: Seq[LogicalPlan] = child :: Nil
-}
+case class RefreshTable(child: LogicalPlan) extends UnaryCommand
 
 /**
  * The logical plan of the SHOW CURRENT NAMESPACE command.
 */
-case class ShowCurrentNamespace(catalogManager: CatalogManager) extends Command {
+case class ShowCurrentNamespace(catalogManager: CatalogManager) extends LeafCommand {
   override val output: Seq[Attribute] = Seq(
     AttributeReference("catalog", StringType, nullable = false)(),
     AttributeReference("namespace", StringType, nullable = false)())
@@ -587,8 +585,8 @@ case class ShowCurrentNamespace(catalogManager: CatalogManager) extends Command
 case class ShowTableProperties(
     table: LogicalPlan,
     propertyKey: Option[String],
-    override val output: Seq[Attribute] = ShowTableProperties.getOutputAttrs) extends Command {
-  override def children: Seq[LogicalPlan] = table :: Nil
+    override val output: Seq[Attribute] = ShowTableProperties.getOutputAttrs) extends UnaryCommand {
+  override def child: LogicalPlan = table
 }
 
 object ShowTableProperties {
@@ -607,9 +605,7 @@ object ShowTableProperties {
  * where the `text` is the new comment written as a string literal; or `NULL` to drop the comment.
  *
  */
-case class CommentOnNamespace(child: LogicalPlan, comment: String) extends Command {
-  override def children: Seq[LogicalPlan] = child :: Nil
-}
+case class CommentOnNamespace(child: LogicalPlan, comment: String) extends UnaryCommand
 
 /**
  * The logical plan that defines or changes the comment of an TABLE for v2 catalogs.
@@ -621,23 +617,17 @@ case class CommentOnNamespace(child: LogicalPlan, comment: String) extends Comma
  * where the `text` is the new comment written as a string literal; or `NULL` to drop the comment.
  *
  */
-case class CommentOnTable(child: LogicalPlan, comment: String) extends Command {
-  override def children: Seq[LogicalPlan] = child :: Nil
-}
+case class CommentOnTable(child: LogicalPlan, comment: String) extends UnaryCommand
 
 /**
  * The logical plan of the REFRESH FUNCTION command.
 */
-case class RefreshFunction(child: LogicalPlan) extends Command {
-  override def children: Seq[LogicalPlan] = child :: Nil
-}
+case class RefreshFunction(child: LogicalPlan) extends UnaryCommand
 
 /**
  * The logical plan of the DESCRIBE FUNCTION command.
 */
-case class DescribeFunction(child: LogicalPlan, isExtended: Boolean) extends Command {
-  override def children: Seq[LogicalPlan] = child :: Nil
-}
+case class DescribeFunction(child: LogicalPlan, isExtended: Boolean) extends UnaryCommand
 
 /**
  * The logical plan of the DROP FUNCTION command.
@@ -645,9 +635,7 @@ case class DescribeFunction(child: LogicalPlan, isExtended: Boolean) extends Com
 case class DropFunction(
     child: LogicalPlan,
     ifExists: Boolean,
-    isTemp: Boolean) extends Command {
-  override def children: Seq[LogicalPlan] = child :: Nil
-}
+    isTemp: Boolean) extends UnaryCommand
 
 /**
  * The logical plan of the SHOW FUNCTIONS command.
@@ -673,17 +661,15 @@ object ShowFunctions {
 case class AnalyzeTable(
     child: LogicalPlan,
     partitionSpec: Map[String, Option[String]],
-    noScan: Boolean) extends Command {
-  override def children: Seq[LogicalPlan] = child :: Nil
-}
+    noScan: Boolean) extends UnaryCommand
 
 /**
  * The logical plan of the ANALYZE TABLES command.
 */
 case class AnalyzeTables(
     namespace: LogicalPlan,
-    noScan: Boolean) extends Command {
-  override def children: Seq[LogicalPlan] = Seq(namespace)
+    noScan: Boolean) extends UnaryCommand {
+  override def child: LogicalPlan = namespace
 }
 
 /**
@@ -692,10 +678,9 @@ case class AnalyzeTables(
 case class AnalyzeColumn(
     child: LogicalPlan,
     columnNames: Option[Seq[String]],
-    allColumns: Boolean) extends Command {
+    allColumns: Boolean) extends UnaryCommand {
   require(columnNames.isDefined ^ allColumns, "Parameter `columnNames` or `allColumns` are " +
     "mutually exclusive. Only one of them should be specified.")
-  override def children: Seq[LogicalPlan] = child :: Nil
 }
 
 /**
@@ -710,9 +695,7 @@ case class AnalyzeColumn(
 case class AddPartitions(
     table: LogicalPlan,
     parts: Seq[PartitionSpec],
-    ifNotExists: Boolean) extends V2PartitionCommand {
-  override def children: Seq[LogicalPlan] = table :: Nil
-}
+    ifNotExists: Boolean) extends V2PartitionCommand
 
 /**
  * The logical plan of the ALTER TABLE DROP PARTITION command.
@@ -730,9 +713,7 @@ case class DropPartitions(
     table: LogicalPlan,
     parts: Seq[PartitionSpec],
     ifExists: Boolean,
-    purge: Boolean) extends V2PartitionCommand {
-  override def children: Seq[LogicalPlan] = table :: Nil
-}
+    purge: Boolean) extends V2PartitionCommand
 
 /**
  * The logical plan of the ALTER TABLE ... RENAME TO PARTITION command.
 */
@@ -740,16 +721,12 @@ case class DropPartitions(
 case class RenamePartitions(
     table: LogicalPlan,
     from: PartitionSpec,
-    to: PartitionSpec) extends V2PartitionCommand {
-  override def children: Seq[LogicalPlan] = table :: Nil
-}
+    to: PartitionSpec) extends V2PartitionCommand
 
 /**
  * The logical plan of the ALTER TABLE ... RECOVER PARTITIONS command.
 */
-case class RecoverPartitions(child: LogicalPlan) extends Command {
-  override def children: Seq[LogicalPlan] = child :: Nil
-}
+case class RecoverPartitions(child: LogicalPlan) extends UnaryCommand
 
 /**
  * The logical plan of the LOAD DATA INTO TABLE command.
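For the many commands that collapsed to `UnaryCommand`, traversal behavior is unchanged; the `Seq` is simply derived (and cached) from `child`. An illustrative check (sketch only, against a build with this patch):

```
import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, RecoverPartitions}

// The hand-written `children` override is gone, but the tree looks the same.
val cmd = RecoverPartitions(OneRowRelation())
assert(cmd.children == Seq(cmd.child))
```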
@@ -759,9 +736,7 @@ case class LoadData(
     path: String,
     isLocal: Boolean,
     isOverwrite: Boolean,
-    partition: Option[TablePartitionSpec]) extends Command {
-  override def children: Seq[LogicalPlan] = child :: Nil
-}
+    partition: Option[TablePartitionSpec]) extends UnaryCommand
 
 /**
  * The logical plan of the SHOW CREATE TABLE command.
@@ -769,9 +744,7 @@ case class LoadData(
 case class ShowCreateTable(
     child: LogicalPlan,
     asSerde: Boolean = false,
-    override val output: Seq[Attribute] = ShowCreateTable.getoutputAttrs) extends Command {
-  override def children: Seq[LogicalPlan] = child :: Nil
-}
+    override val output: Seq[Attribute] = ShowCreateTable.getoutputAttrs) extends UnaryCommand
 
 object ShowCreateTable {
   def getoutputAttrs: Seq[Attribute] = {
@@ -785,9 +758,7 @@ object ShowCreateTable {
 case class ShowColumns(
     child: LogicalPlan,
     namespace: Option[Seq[String]],
-    override val output: Seq[Attribute] = ShowColumns.getOutputAttrs) extends Command {
-  override def children: Seq[LogicalPlan] = child :: Nil
-}
+    override val output: Seq[Attribute] = ShowColumns.getOutputAttrs) extends UnaryCommand
 
 object ShowColumns {
   def getOutputAttrs: Seq[Attribute] = {
@@ -798,8 +769,8 @@ object ShowColumns {
 /**
  * The logical plan of the TRUNCATE TABLE command.
 */
-case class TruncateTable(table: LogicalPlan) extends Command {
-  override def children: Seq[LogicalPlan] = table :: Nil
+case class TruncateTable(table: LogicalPlan) extends UnaryCommand {
+  override def child: LogicalPlan = table
 }
 
 /**
@@ -808,7 +779,6 @@ case class TruncateTable(table: LogicalPlan) extends Command {
 case class TruncatePartition(
     table: LogicalPlan,
     partitionSpec: PartitionSpec) extends V2PartitionCommand {
-  override def children: Seq[LogicalPlan] = table :: Nil
   override def allowPartialPartitionSpec: Boolean = true
 }
 
@@ -820,7 +790,6 @@ case class ShowPartitions(
     pattern: Option[PartitionSpec],
     override val output: Seq[Attribute] =
       ShowPartitions.getOutputAttrs) extends V2PartitionCommand {
-  override def children: Seq[LogicalPlan] = table :: Nil
   override def allowPartialPartitionSpec: Boolean = true
 }
 
@@ -835,9 +804,7 @@ object ShowPartitions {
 */
 case class DropView(
     child: LogicalPlan,
-    ifExists: Boolean) extends Command {
-  override def children: Seq[LogicalPlan] = child :: Nil
-}
+    ifExists: Boolean) extends UnaryCommand
 
 /**
  * The logical plan of the MSCK REPAIR TABLE command.
@@ -845,9 +812,7 @@ case class DropView(
 case class RepairTable(
     child: LogicalPlan,
     enableAddPartitions: Boolean,
-    enableDropPartitions: Boolean) extends Command {
-  override def children: Seq[LogicalPlan] = child :: Nil
-}
+    enableDropPartitions: Boolean) extends UnaryCommand
 
 /**
  * The logical plan of the ALTER VIEW ... AS command.
@@ -855,8 +820,9 @@ case class RepairTable(
 case class AlterViewAs(
     child: LogicalPlan,
     originalText: String,
-    query: LogicalPlan) extends Command {
-  override def children: Seq[LogicalPlan] = child :: query :: Nil
+    query: LogicalPlan) extends BinaryCommand {
+  override def left: LogicalPlan = child
+  override def right: LogicalPlan = query
 }
 
 /**
@@ -864,9 +830,7 @@ case class AlterViewAs(
 case class SetViewProperties(
     child: LogicalPlan,
-    properties: Map[String, String]) extends Command {
-  override def children: Seq[LogicalPlan] = child :: Nil
-}
+    properties: Map[String, String]) extends UnaryCommand
 
 /**
  * The logical plan of the ALTER VIEW ... UNSET TBLPROPERTIES command.
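`V2PartitionCommand` now wires `child` to `table` once in the trait, so partition commands only override the parts that actually differ, such as `allowPartialPartitionSpec`. A hypothetical sketch, not part of this patch:

```
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, V2PartitionCommand}

// Hypothetical partition command: the table is the single child by default.
case class MyPartitionCommand(table: LogicalPlan) extends V2PartitionCommand {
  override def allowPartialPartitionSpec: Boolean = true
}
```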
@@ -874,9 +838,7 @@ case class SetViewProperties(
 case class UnsetViewProperties(
     child: LogicalPlan,
     propertyKeys: Seq[String],
-    ifExists: Boolean) extends Command {
-  override def children: Seq[LogicalPlan] = child :: Nil
-}
+    ifExists: Boolean) extends UnaryCommand
 
 /**
  * The logical plan of the ALTER TABLE ... SET [SERDE|SERDEPROPERTIES] command.
@@ -885,9 +847,7 @@ case class SetTableSerDeProperties(
     child: LogicalPlan,
     serdeClassName: Option[String],
     serdeProperties: Option[Map[String, String]],
-    partitionSpec: Option[TablePartitionSpec]) extends Command {
-  override def children: Seq[LogicalPlan] = child :: Nil
-}
+    partitionSpec: Option[TablePartitionSpec]) extends UnaryCommand
 
 /**
  * The logical plan of the CACHE TABLE command.
@@ -896,7 +856,7 @@ case class CacheTable(
     table: LogicalPlan,
     multipartIdentifier: Seq[String],
     isLazy: Boolean,
-    options: Map[String, String]) extends Command
+    options: Map[String, String]) extends LeafCommand
 
 /**
  * The logical plan of the CACHE TABLE ... AS SELECT command.
@@ -906,7 +866,7 @@ case class CacheTableAsSelect(
     plan: LogicalPlan,
     originalText: String,
     isLazy: Boolean,
-    options: Map[String, String]) extends Command
+    options: Map[String, String]) extends LeafCommand
 
 /**
  * The logical plan of the UNCACHE TABLE command.
@@ -914,7 +874,7 @@ case class UncacheTable(
     table: LogicalPlan,
     ifExists: Boolean,
-    isTempView: Boolean = false) extends Command
+    isTempView: Boolean = false) extends LeafCommand
 
 /**
  * The logical plan of the ALTER TABLE ... SET LOCATION command.
@@ -922,8 +882,8 @@ case class SetTableLocation(
     table: LogicalPlan,
     partitionSpec: Option[TablePartitionSpec],
-    location: String) extends Command {
-  override def children: Seq[LogicalPlan] = table :: Nil
+    location: String) extends UnaryCommand {
+  override def child: LogicalPlan = table
 }
 
 /**
@@ -931,8 +891,8 @@ case class SetTableLocation(
 */
 case class SetTableProperties(
     table: LogicalPlan,
-    properties: Map[String, String]) extends Command {
-  override def children: Seq[LogicalPlan] = table :: Nil
+    properties: Map[String, String]) extends UnaryCommand {
+  override def child: LogicalPlan = table
 }
 
 /**
@@ -941,6 +901,6 @@ case class SetTableProperties(
 case class UnsetTableProperties(
     table: LogicalPlan,
     propertyKeys: Seq[String],
-    ifExists: Boolean) extends Command {
-  override def children: Seq[LogicalPlan] = table :: Nil
+    ifExists: Boolean) extends UnaryCommand {
+  override def child: LogicalPlan = table
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/streaming/WriteToStream.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/streaming/WriteToStream.scala
index 9571cf4618737..990ae302dbbee 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/streaming/WriteToStream.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/streaming/WriteToStream.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.catalyst.streaming
 
 import org.apache.spark.sql.catalyst.expressions.Attribute
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, UnaryNode}
 import org.apache.spark.sql.connector.catalog.Table
 import org.apache.spark.sql.streaming.OutputMode
 
@@ -31,12 +31,13 @@ case class WriteToStream(
     sink: Table,
     outputMode: OutputMode,
     deleteCheckpointOnStop: Boolean,
-    inputQuery: LogicalPlan) extends LogicalPlan {
+    inputQuery: LogicalPlan) extends UnaryNode {
 
   override def isStreaming: Boolean = true
 
   override def output: Seq[Attribute] = Nil
 
-  override def children: Seq[LogicalPlan] = inputQuery :: Nil
+  override def child: LogicalPlan = inputQuery
+
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/streaming/WriteToStreamStatement.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/streaming/WriteToStreamStatement.scala
index c1e2f017cc92f..34a4c13efb62e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/streaming/WriteToStreamStatement.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/streaming/WriteToStreamStatement.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.streaming
 import org.apache.hadoop.conf.Configuration
 
 import org.apache.spark.sql.catalyst.expressions.Attribute
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, UnaryNode}
 import org.apache.spark.sql.connector.catalog.Table
 import org.apache.spark.sql.streaming.OutputMode
 
@@ -50,12 +50,12 @@ case class WriteToStreamStatement(
     outputMode: OutputMode,
     hadoopConf: Configuration,
     isContinuousTrigger: Boolean,
-    inputQuery: LogicalPlan) extends LogicalPlan {
+    inputQuery: LogicalPlan) extends UnaryNode {
 
   override def isStreaming: Boolean = true
 
   override def output: Seq[Attribute] = Nil
 
-  override def children: Seq[LogicalPlan] = inputQuery :: Nil
+  override def child: LogicalPlan = inputQuery
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index 00cd4e9077109..06bb7baed9ce5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -828,3 +828,25 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
     case _ => false
   }
 }
+
+trait LeafLike[T <: TreeNode[T]] { self: TreeNode[T] =>
+  override final def children: Seq[T] = Nil
+}
+
+trait UnaryLike[T <: TreeNode[T]] { self: TreeNode[T] =>
+  def child: T
+  @transient override final lazy val children: Seq[T] = child :: Nil
+}
+
+trait BinaryLike[T <: TreeNode[T]] { self: TreeNode[T] =>
+  def left: T
+  def right: T
+  @transient override final lazy val children: Seq[T] = left :: right :: Nil
+}
+
+trait TernaryLike[T <: TreeNode[T]] { self: TreeNode[T] =>
+  def first: T
+  def second: T
+  def third: T
+  @transient override final lazy val children: Seq[T] = first :: second :: third :: Nil
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala
index e3f7815a63cc5..71993e1a369ec 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala
@@ -34,7 +34,7 @@ import org.apache.spark.sql.streaming.OutputMode
 import org.apache.spark.sql.types.{IntegerType, LongType, MetadataBuilder}
 
 /** A dummy command for testing unsupported operations. */
-case class DummyCommand() extends Command
+case class DummyCommand() extends LeafCommand
 
 class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper {
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala
index d3124a4bb5002..8096062f71a23 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.{AliasIdentifier, FunctionIdentifier, Quali
 import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
-import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan, Project, Range, SubqueryAlias, View}
+import org.apache.spark.sql.catalyst.plans.logical.{LeafCommand, LogicalPlan, Project, Range, SubqueryAlias, View}
 import org.apache.spark.sql.connector.catalog.CatalogManager
 import org.apache.spark.sql.connector.catalog.SupportsNamespaces.PROP_OWNER
 import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf}
@@ -1675,7 +1675,7 @@ abstract class SessionCatalogSuite extends AnalysisTest with Eventually {
   }
 
   test("expire table relation cache if TTL is configured") {
-    case class TestCommand() extends Command
+    case class TestCommand() extends LeafCommand
 
     val conf = new SQLConf()
     conf.setConf(StaticSQLConf.METADATA_CACHE_TTL_SECONDS, 1L)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index 7308fe30f3579..40bf094856bca 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.catalyst.trees.TreeNodeTag
+import org.apache.spark.sql.catalyst.trees.{BinaryLike, LeafLike, TreeNodeTag, UnaryLike}
 import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.vectorized.ColumnarBatch
 
@@ -513,8 +513,8 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
   }
 }
 
-trait LeafExecNode extends SparkPlan {
-  override final def children: Seq[SparkPlan] = Nil
+trait LeafExecNode extends SparkPlan with LeafLike[SparkPlan] {
+
   override def producedAttributes: AttributeSet = outputSet
   override def verboseStringWithOperatorId(): String = {
     val argumentString = argString(conf.maxToStringFields)
@@ -542,10 +542,8 @@ object UnaryExecNode {
   }
 }
 
-trait UnaryExecNode extends SparkPlan {
-  def child: SparkPlan
+trait UnaryExecNode extends SparkPlan with UnaryLike[SparkPlan] {
 
-  override final def children: Seq[SparkPlan] = child :: Nil
   override def verboseStringWithOperatorId(): String = {
     val argumentString = argString(conf.maxToStringFields)
     val inputStr = s"${ExplainUtils.generateFieldString("Input", child.output)}"
@@ -565,11 +563,8 @@ trait UnaryExecNode {
   }
 }
 
-trait BinaryExecNode extends SparkPlan {
-  def left: SparkPlan
-  def right: SparkPlan
+trait BinaryExecNode extends SparkPlan with BinaryLike[SparkPlan] {
 
-  override final def children: Seq[SparkPlan] = Seq(left, right)
   override def verboseStringWithOperatorId(): String = {
     val argumentString = argString(conf.maxToStringFields)
     val leftOutputStr = s"${ExplainUtils.generateFieldString("Left output", left.output)}"
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala
index 57df4e4614fd7..ac6e2ba9eba4f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.{Row, SparkSession}
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
 import org.apache.spark.sql.catalyst.plans.QueryPlan
-import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan}
+import org.apache.spark.sql.catalyst.plans.logical.{LeafCommand, LogicalPlan}
 import org.apache.spark.sql.connector.ExternalCommandRunner
 import org.apache.spark.sql.execution.{ExplainMode, LeafExecNode, SparkPlan, UnaryExecNode}
 import org.apache.spark.sql.execution.metric.SQLMetric
@@ -37,7 +37,7 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap
 * A logical command that is executed for its side-effects. `RunnableCommand`s are
 * wrapped in `ExecutedCommand` during execution.
 */
-trait RunnableCommand extends Command {
+trait RunnableCommand extends LeafCommand {
 
   // The map used to record the metrics of running the command. This will be passed to
   // `ExecutedCommand` during query planning.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
index abb2125841b59..2ed0e06807bf0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
@@ -28,7 +28,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.NoSuchTableException
 import org.apache.spark.sql.catalyst.expressions.Attribute
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, UnaryNode}
 import org.apache.spark.sql.catalyst.util.CharVarcharUtils
 import org.apache.spark.sql.connector.catalog.{Identifier, StagedTable, StagingTableCatalog, SupportsWrite, Table, TableCatalog}
 import org.apache.spark.sql.connector.expressions.Transform
@@ -44,8 +44,8 @@ import org.apache.spark.util.{LongAccumulator, Utils}
 */
 @deprecated("Use specific logical plans like AppendData instead", "2.4.0")
 case class WriteToDataSourceV2(batchWrite: BatchWrite, query: LogicalPlan)
-  extends LogicalPlan {
-  override def children: Seq[LogicalPlan] = Seq(query)
+  extends UnaryNode {
+  override def child: LogicalPlan = query
   override def output: Seq[Attribute] = Nil
 }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSource.scala
index cecb2843fc3b0..1923fc969801e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSource.scala
@@ -18,14 +18,14 @@
 package org.apache.spark.sql.execution.streaming.continuous
 
 import org.apache.spark.sql.catalyst.expressions.Attribute
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, UnaryNode}
 import org.apache.spark.sql.connector.write.streaming.StreamingWrite
 
 /**
 * The logical plan for writing data in a continuous stream.
 */
 case class WriteToContinuousDataSource(write: StreamingWrite, query: LogicalPlan)
-  extends LogicalPlan {
-  override def children: Seq[LogicalPlan] = Seq(query)
+  extends UnaryNode {
+  override def child: LogicalPlan = query
   override def output: Seq[Attribute] = Nil
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/WriteToMicroBatchDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/WriteToMicroBatchDataSource.scala
index ef1115e6d9e01..4bacd71a55ec1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/WriteToMicroBatchDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/WriteToMicroBatchDataSource.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.execution.streaming.sources
 
 import org.apache.spark.sql.catalyst.expressions.Attribute
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, UnaryNode}
 import org.apache.spark.sql.connector.write.streaming.StreamingWrite
 import org.apache.spark.sql.execution.datasources.v2.WriteToDataSourceV2
 
@@ -29,8 +29,8 @@ import org.apache.spark.sql.execution.datasources.v2.WriteToDataSourceV2
 * to [[WriteToDataSourceV2]] with [[MicroBatchWrite]] before execution.
 */
 case class WriteToMicroBatchDataSource(write: StreamingWrite, query: LogicalPlan)
-  extends LogicalPlan {
-  override def children: Seq[LogicalPlan] = Seq(query)
+  extends UnaryNode {
+  override def child: LogicalPlan = query
   override def output: Seq[Attribute] = Nil
 
   def createPlan(batchId: Long): WriteToDataSourceV2 = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
index 0080b73575de1..9c950fd8033a7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
@@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.{expressions, InternalRow}
 import org.apache.spark.sql.catalyst.expressions.{CreateNamedStruct, Expression, ExprId, InSet, ListQuery, Literal, PlanExpression}
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
 import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.trees.{LeafLike, UnaryLike}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{BooleanType, DataType, StructType}
 
@@ -62,10 +63,9 @@ object ExecSubqueryExpression {
 case class ScalarSubquery(
     plan: BaseSubqueryExec,
     exprId: ExprId)
-  extends ExecSubqueryExpression {
+  extends ExecSubqueryExpression with LeafLike[Expression] {
 
   override def dataType: DataType = plan.schema.fields.head.dataType
-  override def children: Seq[Expression] = Nil
   override def nullable: Boolean = true
   override def toString: String = plan.simpleString(SQLConf.get.maxToStringFields)
   override def withNewPlan(query: BaseSubqueryExec): ScalarSubquery = copy(plan = query)
@@ -114,13 +114,13 @@ case class InSubqueryExec(
     child: Expression,
     plan: BaseSubqueryExec,
     exprId: ExprId,
-    private var resultBroadcast: Broadcast[Array[Any]] = null) extends ExecSubqueryExpression {
+    private var resultBroadcast: Broadcast[Array[Any]] = null)
+  extends ExecSubqueryExpression with UnaryLike[Expression] {
 
   @transient private var result: Array[Any] = _
   @transient private lazy val inSet = InSet(child, result.toSet)
 
   override def dataType: DataType = BooleanType
-  override def children: Seq[Expression] = child :: Nil
   override def nullable: Boolean = child.nullable
   override def toString: String = s"$child IN ${plan.name}"
   override def withNewPlan(plan: BaseSubqueryExec): InSubqueryExec = copy(plan = plan)
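One practical effect worth calling out across all of these files: the `UnaryLike`/`BinaryLike`/`TernaryLike` traits hold `children` in a `@transient` lazy val, so repeated traversals reuse one cached `Seq` instead of rebuilding it on every call. An illustrative check for a binary expression (a sketch, assuming `BinaryExpression` picks up `BinaryLike` in the parts of this patch not shown here):

```
import org.apache.spark.sql.catalyst.expressions.{Add, Literal}

// Before: `def children` allocated a fresh Seq on every call.
// After: the lazy val returns the same cached instance.
val expr = Add(Literal(1), Literal(2))
assert(expr.children eq expr.children)
```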