From 95ad039689f2bd667b4d930a7d774401aff6c757 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 27 May 2016 15:39:17 -0700 Subject: [PATCH 1/2] [SPARK-15636][SQL] Make aggregate expressions more concise in explain --- .../sql/catalyst/expressions/Expression.scala | 2 +- .../expressions/aggregate/interfaces.scala | 31 +++++++++++++++---- 2 files changed, 26 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index b4fe151f277a2..2ec46216e1cdb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -185,7 +185,7 @@ abstract class Expression extends TreeNode[Expression] { */ def prettyName: String = nodeName.toLowerCase - private def flatArguments = productIterator.flatMap { + protected def flatArguments = productIterator.flatMap { case t: Traversable[_] => t case single => single :: Nil } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala index d31ccf9985360..f0f7883d782d1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala @@ -24,14 +24,19 @@ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.types._ /** The mode of an [[AggregateFunction]]. */ -private[sql] sealed trait AggregateMode +private[sql] sealed trait AggregateMode { + /** Prefix used in explain to indicate the aggregate mode. */ + def prefix: String +} /** * An [[AggregateFunction]] with [[Partial]] mode is used for partial aggregation. * This function updates the given aggregation buffer with the original input of this * function. When it has processed all input rows, the aggregation buffer is returned. */ -private[sql] case object Partial extends AggregateMode +private[sql] case object Partial extends AggregateMode { + override def prefix: String = "partial_" +} /** * An [[AggregateFunction]] with [[PartialMerge]] mode is used to merge aggregation buffers @@ -39,7 +44,9 @@ private[sql] case object Partial extends AggregateMode * This function updates the given aggregation buffer by merging multiple aggregation buffers. * When it has processed all input rows, the aggregation buffer is returned. */ -private[sql] case object PartialMerge extends AggregateMode +private[sql] case object PartialMerge extends AggregateMode { + override def prefix: String = "merge_" +} /** * An [[AggregateFunction]] with [[Final]] mode is used to merge aggregation buffers @@ -47,7 +54,9 @@ private[sql] case object PartialMerge extends AggregateMode * This function updates the given aggregation buffer by merging multiple aggregation buffers. * When it has processed all input rows, the final result of this function is returned. */ -private[sql] case object Final extends AggregateMode +private[sql] case object Final extends AggregateMode { + override def prefix: String = "" +} /** * An [[AggregateFunction]] with [[Complete]] mode is used to evaluate this function directly @@ -55,7 +64,9 @@ private[sql] case object Final extends AggregateMode * This function updates the given aggregation buffer with the original input of this * function. When it has processed all input rows, the final result of this function is returned. */ -private[sql] case object Complete extends AggregateMode +private[sql] case object Complete extends AggregateMode { + override def prefix: String = "" +} /** * A place holder expressions used in code-gen, it does not change the corresponding value @@ -126,7 +137,9 @@ private[sql] case class AggregateExpression( AttributeSet(childReferences) } - override def toString: String = s"($aggregateFunction,mode=$mode,isDistinct=$isDistinct)" + override def toString: String = { + mode.prefix + aggregateFunction.toAggString(isDistinct) + } override def sql: String = aggregateFunction.sql(isDistinct) } @@ -203,6 +216,12 @@ sealed abstract class AggregateFunction extends Expression with ImplicitCastInpu val distinct = if (isDistinct) "DISTINCT " else "" s"$prettyName($distinct${children.map(_.sql).mkString(", ")})" } + + /** String representation used in explain plans. */ + def toAggString(isDistinct: Boolean): String = { + val start = if (isDistinct) "(distinct " else "(" + prettyName + flatArguments.mkString(start, ", ", ")") + } } /** From 88f84608628332cf4dd04e281875483743d7adbe Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 27 May 2016 18:16:45 -0700 Subject: [PATCH 2/2] Code review --- .../expressions/aggregate/interfaces.scala | 28 ++++++++----------- 1 file changed, 11 insertions(+), 17 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala index f0f7883d782d1..504cea52797de 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala @@ -24,19 +24,14 @@ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.types._ /** The mode of an [[AggregateFunction]]. */ -private[sql] sealed trait AggregateMode { - /** Prefix used in explain to indicate the aggregate mode. */ - def prefix: String -} +private[sql] sealed trait AggregateMode /** * An [[AggregateFunction]] with [[Partial]] mode is used for partial aggregation. * This function updates the given aggregation buffer with the original input of this * function. When it has processed all input rows, the aggregation buffer is returned. */ -private[sql] case object Partial extends AggregateMode { - override def prefix: String = "partial_" -} +private[sql] case object Partial extends AggregateMode /** * An [[AggregateFunction]] with [[PartialMerge]] mode is used to merge aggregation buffers @@ -44,9 +39,7 @@ private[sql] case object Partial extends AggregateMode { * This function updates the given aggregation buffer by merging multiple aggregation buffers. * When it has processed all input rows, the aggregation buffer is returned. */ -private[sql] case object PartialMerge extends AggregateMode { - override def prefix: String = "merge_" -} +private[sql] case object PartialMerge extends AggregateMode /** * An [[AggregateFunction]] with [[Final]] mode is used to merge aggregation buffers @@ -54,9 +47,7 @@ private[sql] case object PartialMerge extends AggregateMode { * This function updates the given aggregation buffer by merging multiple aggregation buffers. * When it has processed all input rows, the final result of this function is returned. */ -private[sql] case object Final extends AggregateMode { - override def prefix: String = "" -} +private[sql] case object Final extends AggregateMode /** * An [[AggregateFunction]] with [[Complete]] mode is used to evaluate this function directly @@ -64,9 +55,7 @@ private[sql] case object Final extends AggregateMode { * This function updates the given aggregation buffer with the original input of this * function. When it has processed all input rows, the final result of this function is returned. */ -private[sql] case object Complete extends AggregateMode { - override def prefix: String = "" -} +private[sql] case object Complete extends AggregateMode /** * A place holder expressions used in code-gen, it does not change the corresponding value @@ -138,7 +127,12 @@ private[sql] case class AggregateExpression( } override def toString: String = { - mode.prefix + aggregateFunction.toAggString(isDistinct) + val prefix = mode match { + case Partial => "partial_" + case PartialMerge => "merge_" + case Final | Complete => "" + } + prefix + aggregateFunction.toAggString(isDistinct) } override def sql: String = aggregateFunction.sql(isDistinct)