diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java
index 50b2a173153ef..332ca3405a58a 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java
@@ -21,6 +21,7 @@ import java.util.Iterator;
 
 import org.apache.spark.sql.Row;
+import org.apache.spark.sql.types.StructField;
 import org.apache.spark.sql.types.StructType;
 import org.apache.spark.unsafe.PlatformDependent;
 import org.apache.spark.unsafe.map.BytesToBytesMap;
@@ -65,6 +66,32 @@ public final class UnsafeFixedWidthAggregationMap {
    */
   private long[] groupingKeyConversionScratchSpace = new long[1024 / 8];
 
+  /**
+   * @return true if UnsafeFixedWidthAggregationMap supports grouping keys with the given schema,
+   *         false otherwise.
+   */
+  public static boolean supportsGroupKeySchema(StructType schema) {
+    for (StructField field: schema.fields()) {
+      if (!UnsafeRow.readableFieldTypes.contains(field.dataType())) {
+        return false;
+      }
+    }
+    return true;
+  }
+
+  /**
+   * @return true if UnsafeFixedWidthAggregationMap supports aggregation buffers with the given
+   *         schema, false otherwise.
+   */
+  public static boolean supportsAggregationBufferSchema(StructType schema) {
+    for (StructField field: schema.fields()) {
+      if (!UnsafeRow.settableFieldTypes.contains(field.dataType())) {
+        return false;
+      }
+    }
+    return true;
+  }
+
   /**
    * Create a new UnsafeFixedWidthAggregationMap.
    *
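The two guards above are the gatekeepers for the new code path, so it is worth seeing them in action. A minimal sketch of how a caller might probe schema support (the schemas below are made up for illustration; only the two static methods come from this patch):

```scala
import org.apache.spark.sql.catalyst.expressions.UnsafeFixedWidthAggregationMap
import org.apache.spark.sql.types._

// Hypothetical schemas, chosen to exercise both guards.
val keySchema = StructType(StructField("word", StringType) :: Nil)
val bufferSchema = StructType(StructField("count", LongType) :: Nil)

// Grouping keys only need to be readable, so StringType qualifies...
assert(UnsafeFixedWidthAggregationMap.supportsGroupKeySchema(keySchema))
// ...but aggregation buffers must be updatable in place, so LongType is
// accepted while StringType is rejected.
assert(UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema(bufferSchema))
assert(!UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema(keySchema))
```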
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index ce78ca6e73041..fe3d4d29b2204 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -22,6 +22,7 @@
 import org.apache.spark.sql.types.DataType;
 import static org.apache.spark.sql.types.DataTypes.*;
+import org.apache.spark.sql.types.StructField;
 import org.apache.spark.sql.types.StructType;
 import org.apache.spark.sql.types.UTF8String;
 import org.apache.spark.unsafe.PlatformDependent;
@@ -34,7 +35,10 @@
 import javax.annotation.Nullable;
 import java.math.BigDecimal;
 import java.sql.Date;
+import java.util.Arrays;
+import java.util.HashSet;
 import java.util.List;
+import java.util.Set;
 
 // TODO: pick a better name for this class, since this is potentially confusing.
 
@@ -71,6 +75,34 @@ public static int calculateBitSetWidthInBytes(int numFields) {
     return ((numFields / 64) + ((numFields % 64 == 0 ? 0 : 1))) * 8;
   }
 
+  /**
+   * Field types that can be updated in place in UnsafeRows (e.g. we support set() for these types)
+   */
+  public static final Set<DataType> settableFieldTypes;
+
+  /**
+   * Field types that can be read (but not set; set() will throw UnsupportedOperationException).
+   */
+  public static final Set<DataType> readableFieldTypes;
+
+  static {
+    settableFieldTypes = new HashSet<DataType>(Arrays.asList(new DataType[] {
+      IntegerType,
+      LongType,
+      DoubleType,
+      BooleanType,
+      ShortType,
+      ByteType,
+      FloatType
+    }));
+
+    // We support get() on a superset of the types for which we support set():
+    readableFieldTypes = new HashSet<DataType>(Arrays.asList(new DataType[] {
+      StringType
+    }));
+    readableFieldTypes.addAll(settableFieldTypes);
+  }
+
   public UnsafeRow() { }
 
   public void set(Object baseObject, long baseOffset, int numFields, StructType schema) {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala
index 45695569c9cf2..956a80ade2f02 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala
@@ -24,10 +24,22 @@ import org.apache.spark.sql.types._
 
 class UnsafeFixedWidthAggregationMapSuite extends FunSuite with Matchers {
 
+  import UnsafeFixedWidthAggregationMap._
+
   private val groupKeySchema = StructType(StructField("product", StringType) :: Nil)
   private val aggBufferSchema = StructType(StructField("salePrice", IntegerType) :: Nil)
   private def emptyAggregationBuffer: Row = new GenericRow(Array[Any](0))
 
+  test("supported schemas") {
+    assert(!supportsAggregationBufferSchema(StructType(StructField("x", StringType) :: Nil)))
+    assert(supportsGroupKeySchema(StructType(StructField("x", StringType) :: Nil)))
+
+    assert(
+      !supportsAggregationBufferSchema(StructType(StructField("x", ArrayType(IntegerType)) :: Nil)))
+    assert(
+      !supportsGroupKeySchema(StructType(StructField("x", ArrayType(IntegerType)) :: Nil)))
+  }
+
   test("empty map") {
     val map = new UnsafeFixedWidthAggregationMap(
       emptyAggregationBuffer,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
index 4fc5de7e824fe..361483a431e78 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
@@ -30,6 +30,7 @@ private[spark] object SQLConf {
   val DEFAULT_SIZE_IN_BYTES = "spark.sql.defaultSizeInBytes"
   val SHUFFLE_PARTITIONS = "spark.sql.shuffle.partitions"
   val CODEGEN_ENABLED = "spark.sql.codegen"
+  val UNSAFE_ENABLED = "spark.sql.unsafe"
   val DIALECT = "spark.sql.dialect"
 
   val PARQUET_BINARY_AS_STRING = "spark.sql.parquet.binaryAsString"
@@ -149,6 +150,8 @@ private[sql] class SQLConf extends Serializable {
    */
   private[spark] def codegenEnabled: Boolean = getConf(CODEGEN_ENABLED, "false").toBoolean
 
+  private[spark] def unsafeEnabled: Boolean = getConf(UNSAFE_ENABLED, "false").toBoolean
+
   private[spark] def useSqlSerializer2: Boolean = getConf(USE_SQL_SERIALIZER2, "true").toBoolean
 
   /**
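Since both flags default to "false", the new path is strictly opt-in. A sketch of how a user would enable it (assuming a `sqlContext` in scope; note that the unsafe branch is only reachable from code-generated aggregates, so codegen has to be on as well):

```scala
// Opt in to code generation and to Unsafe-based aggregation buffers.
sqlContext.setConf("spark.sql.codegen", "true")
sqlContext.setConf("spark.sql.unsafe", "true")

// The same flag can be flipped from SQL:
sqlContext.sql("SET spark.sql.unsafe=true")
```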
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index bcd20c06c6dca..04a8538c763c8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -1011,6 +1011,8 @@ class SQLContext(@transient val sparkContext: SparkContext)
 
     def codegenEnabled: Boolean = self.conf.codegenEnabled
 
+    def unsafeEnabled: Boolean = self.conf.unsafeEnabled
+
     def numPartitions: Int = self.conf.numShufflePartitions
 
     def strategies: Seq[Strategy] =
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
index b1ef6556de1e9..fd50693f265d3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
@@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.trees._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.memory.MemoryAllocator
 
 case class AggregateEvaluation(
     schema: Seq[Attribute],
@@ -41,13 +42,15 @@ case class AggregateEvaluation(
  * @param groupingExpressions expressions that are evaluated to determine grouping.
  * @param aggregateExpressions expressions that are computed for each group.
  * @param child the input data source.
+ * @param unsafeEnabled whether to allow Unsafe-based aggregation buffers to be used.
  */
 @DeveloperApi
 case class GeneratedAggregate(
     partial: Boolean,
     groupingExpressions: Seq[Expression],
     aggregateExpressions: Seq[NamedExpression],
-    child: SparkPlan)
+    child: SparkPlan,
+    unsafeEnabled: Boolean)
   extends UnaryNode {
 
   override def requiredChildDistribution: Seq[Distribution] =
@@ -225,6 +228,21 @@ case class GeneratedAggregate(
       case e: Expression if groupMap.contains(e) => groupMap(e)
     })
 
+    val aggregationBufferSchema: StructType = StructType.fromAttributes(computationSchema)
+
+    val groupKeySchema: StructType = {
+      val fields = groupingExpressions.zipWithIndex.map { case (expr, idx) =>
+        // This is a dummy field name
+        StructField(idx.toString, expr.dataType, expr.nullable)
+      }
+      StructType(fields)
+    }
+
+    val schemaSupportsUnsafe: Boolean = {
+      UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema(aggregationBufferSchema) &&
+        UnsafeFixedWidthAggregationMap.supportsGroupKeySchema(groupKeySchema)
+    }
+
     child.execute().mapPartitions { iter =>
       // Builds a new custom class for holding the results of aggregation for a group.
       val initialValues = computeFunctions.flatMap(_.initialValues)
@@ -265,7 +283,48 @@ case class GeneratedAggregate(
 
         val resultProjection = resultProjectionBuilder()
         Iterator(resultProjection(buffer))
+      } else if (unsafeEnabled && schemaSupportsUnsafe) {
+        log.info("Using Unsafe-based aggregator")
+        val aggregationMap = new UnsafeFixedWidthAggregationMap(
+          newAggregationBuffer(EmptyRow),
+          aggregationBufferSchema,
+          groupKeySchema,
+          MemoryAllocator.UNSAFE,
+          1024
+        )
+
+        while (iter.hasNext) {
+          val currentRow: Row = iter.next()
+          val groupKey: Row = groupProjection(currentRow)
+          val aggregationBuffer = aggregationMap.getAggregationBuffer(groupKey)
+          updateProjection.target(aggregationBuffer)(joinedRow(aggregationBuffer, currentRow))
+        }
+
+        new Iterator[Row] {
+          private[this] val mapIterator = aggregationMap.iterator()
+          private[this] val resultProjection = resultProjectionBuilder()
+
+          def hasNext: Boolean = mapIterator.hasNext
+
+          def next(): Row = {
+            val entry = mapIterator.next()
+            val result = resultProjection(joinedRow(entry.key, entry.value))
+            if (hasNext) {
+              result
+            } else {
+              // This is the last element in the iterator, so let's free the buffer. Before we do,
+              // though, we need to make a defensive copy of the result so that we don't return an
+              // object that might contain dangling pointers to the freed memory.
+              val resultCopy = result.copy()
+              aggregationMap.free()
+              resultCopy
+            }
+          }
+        }
       } else {
+        if (unsafeEnabled) {
+          log.info("Not using Unsafe-based aggregator because it is not supported for this schema")
+        }
         val buffers = new java.util.HashMap[Row, MutableRow]()
 
         var currentRow: Row = null
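The copy-before-free handshake in `next()` is the subtlest part of this change: rows coming out of the map may point into off-heap memory, so the final row must be defensively copied before `free()` releases that memory. A distilled, self-contained version of the pattern (the class and its parameters are illustrative, not Spark APIs):

```scala
// An iterator over a resource that must be freed exactly once, after the
// final element has been produced. Elements may alias the resource's
// memory, so the last one is copied before the memory is released.
class CopyOnLastIterator[T](
    underlying: Iterator[T],
    copy: T => T,
    free: () => Unit)
  extends Iterator[T] {

  def hasNext: Boolean = underlying.hasNext

  def next(): T = {
    val result = underlying.next()
    if (underlying.hasNext) {
      result
    } else {
      val detached = copy(result)  // no dangling pointers after free()
      free()
      detached
    }
  }
}
```

As in the patch itself, nothing frees the buffer if the consumer abandons the iterator before exhausting it; that caveat applies equally to this sketch.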
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 4c132f11ba7f1..4c0369f0dbde4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -131,18 +131,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
           if canBeCodeGened(
             allAggregates(partialComputation) ++
             allAggregates(rewrittenAggregateExpressions)) &&
-            codegenEnabled => {
-          if (self.sqlContext.getConf("spark.sql.unsafe.enabled", "false") == "true") {
-            execution.UnsafeGeneratedAggregate(
-              partial = false,
-              namedGroupingAttributes,
-              rewrittenAggregateExpressions,
-              execution.UnsafeGeneratedAggregate(
-                partial = true,
-                groupingExpressions,
-                partialComputation,
-                planLater(child))) :: Nil
-          } else {
+            codegenEnabled =>
           execution.GeneratedAggregate(
             partial = false,
             namedGroupingAttributes,
@@ -151,9 +140,9 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
               partial = true,
               groupingExpressions,
               partialComputation,
-              planLater(child))) :: Nil
-          }
-        }
+              planLater(child),
+              unsafeEnabled),
+            unsafeEnabled) :: Nil
 
       // Cases where some aggregate can not be codegened
       case PartialAggregation(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeGeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeGeneratedAggregate.scala
deleted file mode 100644
index 80c67da24ec8f..0000000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeGeneratedAggregate.scala
+++ /dev/null
@@ -1,305 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution
-
-import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.trees._
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.memory.MemoryAllocator
-
-// TODO: finish cleaning up documentation instead of just copying it
-
-/**
- * TODO: copy of GeneratedAggregate that uses unsafe / offheap row implementations + hashtables.
- */
-@DeveloperApi
-case class UnsafeGeneratedAggregate(
-    partial: Boolean,
-    groupingExpressions: Seq[Expression],
-    aggregateExpressions: Seq[NamedExpression],
-    child: SparkPlan)
-  extends UnaryNode {
-
-  override def requiredChildDistribution: Seq[Distribution] =
-    if (partial) {
-      UnspecifiedDistribution :: Nil
-    } else {
-      if (groupingExpressions == Nil) {
-        AllTuples :: Nil
-      } else {
-        ClusteredDistribution(groupingExpressions) :: Nil
-      }
-    }
-
-  override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute)
-
-  override def execute(): RDD[Row] = {
-    val aggregatesToCompute = aggregateExpressions.flatMap { a =>
-      a.collect { case agg: AggregateExpression => agg}
-    }
-
-    // If you add any new function support, please add tests in org.apache.spark.sql.SQLQuerySuite
-    // (in test "aggregation with codegen").
-    val computeFunctions = aggregatesToCompute.map {
-      case c @ Count(expr) =>
-        // If we're evaluating UnscaledValue(x), we can do Count on x directly, since its
-        // UnscaledValue will be null if and only if x is null; helps with Average on decimals
-        val toCount = expr match {
-          case UnscaledValue(e) => e
-          case _ => expr
-        }
-        val currentCount = AttributeReference("currentCount", LongType, nullable = false)()
-        val initialValue = Literal(0L)
-        val updateFunction = If(IsNotNull(toCount), Add(currentCount, Literal(1L)), currentCount)
-        val result = currentCount
-
-        AggregateEvaluation(currentCount :: Nil, initialValue :: Nil, updateFunction :: Nil, result)
-
-      case s @ Sum(expr) =>
-        val calcType =
-          expr.dataType match {
-            case DecimalType.Fixed(_, _) =>
-              DecimalType.Unlimited
-            case _ =>
-              expr.dataType
-          }
-
-        val currentSum = AttributeReference("currentSum", calcType, nullable = true)()
-        val initialValue = Literal.create(null, calcType)
-
-        // Coalesce avoids double calculation...
-        // but really, common sub expression elimination would be better....
-        val zero = Cast(Literal(0), calcType)
-        val updateFunction = Coalesce(
-          Add(
-            Coalesce(currentSum :: zero :: Nil),
-            Cast(expr, calcType)
-          ) :: currentSum :: zero :: Nil)
-        val result =
-          expr.dataType match {
-            case DecimalType.Fixed(_, _) =>
-              Cast(currentSum, s.dataType)
-            case _ => currentSum
-          }
-
-        AggregateEvaluation(currentSum :: Nil, initialValue :: Nil, updateFunction :: Nil, result)
-
-      case cs @ CombineSum(expr) =>
-        val calcType = expr.dataType
-          expr.dataType match {
-            case DecimalType.Fixed(_, _) =>
-              DecimalType.Unlimited
-            case _ =>
-              expr.dataType
-          }
-
-        val currentSum = AttributeReference("currentSum", calcType, nullable = true)()
-        val initialValue = Literal.create(null, calcType)
-
-        // Coalasce avoids double calculation...
-        // but really, common sub expression elimination would be better....
-        val zero = Cast(Literal(0), calcType)
-        // If we're evaluating UnscaledValue(x), we can do Count on x directly, since its
-        // UnscaledValue will be null if and only if x is null; helps with Average on decimals
-        val actualExpr = expr match {
-          case UnscaledValue(e) => e
-          case _ => expr
-        }
-        // partial sum result can be null only when no input rows present
-        val updateFunction = If(
-          IsNotNull(actualExpr),
-          Coalesce(
-            Add(
-              Coalesce(currentSum :: zero :: Nil),
-              Cast(expr, calcType)) :: currentSum :: zero :: Nil),
-          currentSum)
-
-        val result =
-          expr.dataType match {
-            case DecimalType.Fixed(_, _) =>
-              Cast(currentSum, cs.dataType)
-            case _ => currentSum
-          }
-
-        AggregateEvaluation(currentSum :: Nil, initialValue :: Nil, updateFunction :: Nil, result)
-
-      case m @ Max(expr) =>
-        val currentMax = AttributeReference("currentMax", expr.dataType, nullable = true)()
-        val initialValue = Literal.create(null, expr.dataType)
-        val updateMax = MaxOf(currentMax, expr)
-
-        AggregateEvaluation(
-          currentMax :: Nil,
-          initialValue :: Nil,
-          updateMax :: Nil,
-          currentMax)
-
-      case m @ Min(expr) =>
-        val currentMin = AttributeReference("currentMin", expr.dataType, nullable = true)()
-        val initialValue = Literal.create(null, expr.dataType)
-        val updateMin = MinOf(currentMin, expr)
-
-        AggregateEvaluation(
-          currentMin :: Nil,
-          initialValue :: Nil,
-          updateMin :: Nil,
-          currentMin)
-
-      case CollectHashSet(Seq(expr)) =>
-        val set =
-          AttributeReference("hashSet", new OpenHashSetUDT(expr.dataType), nullable = false)()
-        val initialValue = NewSet(expr.dataType)
-        val addToSet = AddItemToSet(expr, set)
-
-        AggregateEvaluation(
-          set :: Nil,
-          initialValue :: Nil,
-          addToSet :: Nil,
-          set)
-
-      case CombineSetsAndCount(inputSet) =>
-        val elementType = inputSet.dataType.asInstanceOf[OpenHashSetUDT].elementType
-        val set =
-          AttributeReference("hashSet", new OpenHashSetUDT(elementType), nullable = false)()
-        val initialValue = NewSet(elementType)
-        val collectSets = CombineSets(set, inputSet)
-
-        AggregateEvaluation(
-          set :: Nil,
-          initialValue :: Nil,
-          collectSets :: Nil,
-          CountSet(set))
-
-      case o => sys.error(s"$o can't be codegened.")
-    }
-
-    val computationSchema: Seq[Attribute] = computeFunctions.flatMap(_.schema)
-
-    val resultMap: Map[TreeNodeRef, Expression] =
-      aggregatesToCompute.zip(computeFunctions).map {
-        case (agg, func) => new TreeNodeRef(agg) -> func.result
-      }.toMap
-
-    val namedGroups = groupingExpressions.zipWithIndex.map {
-      case (ne: NamedExpression, _) => (ne, ne)
-      case (e, i) => (e, Alias(e, s"GroupingExpr$i")())
-    }
-
-    val groupMap: Map[Expression, Attribute] =
-      namedGroups.map { case (k, v) => k -> v.toAttribute}.toMap
-
-    // The set of expressions that produce the final output given the aggregation buffer and the
-    // grouping expressions.
-    val resultExpressions = aggregateExpressions.map(_.transform {
-      case e: Expression if resultMap.contains(new TreeNodeRef(e)) => resultMap(new TreeNodeRef(e))
-      case e: Expression if groupMap.contains(e) => groupMap(e)
-    })
-
-    child.execute().mapPartitions { iter =>
-      // Builds a new custom class for holding the results of aggregation for a group.
-      val initialValues = computeFunctions.flatMap(_.initialValues)
-      val newAggregationBuffer = newProjection(initialValues, child.output)
-      log.info(s"Initial values: ${initialValues.mkString(",")}")
-
-      // A projection that computes the group given an input tuple.
-      val groupProjection = newProjection(groupingExpressions, child.output)
-      log.info(s"Grouping Projection: ${groupingExpressions.mkString(",")}")
-
-      // A projection that is used to update the aggregate values for a group given a new tuple.
-      // This projection should be targeted at the current values for the group and then applied
-      // to a joined row of the current values with the new input row.
-      val updateExpressions = computeFunctions.flatMap(_.update)
-      val updateSchema = computationSchema ++ child.output
-      val updateProjection = newMutableProjection(updateExpressions, updateSchema)()
-      log.info(s"Update Expressions: ${updateExpressions.mkString(",")}")
-
-      // A projection that produces the final result, given a computation.
-      val resultProjectionBuilder =
-        newMutableProjection(
-          resultExpressions,
-          (namedGroups.map(_._2.toAttribute) ++ computationSchema).toSeq)
-      log.info(s"Result Projection: ${resultExpressions.mkString(",")}")
-
-      val joinedRow = new JoinedRow3
-
-      if (groupingExpressions.isEmpty) {
-        // TODO: Codegening anything other than the updateProjection is probably over kill.
-        val buffer = newAggregationBuffer(EmptyRow).asInstanceOf[MutableRow]
-        var currentRow: Row = null
-        updateProjection.target(buffer)
-
-        while (iter.hasNext) {
-          currentRow = iter.next()
-          updateProjection(joinedRow(buffer, currentRow))
-        }
-
-        val resultProjection = resultProjectionBuilder()
-        Iterator(resultProjection(buffer))
-      } else {
-        val aggregationBufferSchema = StructType.fromAttributes(computationSchema)
-
-        val groupKeySchema: StructType = {
-          val fields = groupingExpressions.zipWithIndex.map { case (expr, idx) =>
-            StructField(idx.toString, expr.dataType, expr.nullable)
-          }
-          StructType(fields)
-        }
-
-        val aggregationMap = new UnsafeFixedWidthAggregationMap(
-          newAggregationBuffer(EmptyRow),
-          aggregationBufferSchema,
-          groupKeySchema,
-          MemoryAllocator.UNSAFE,
-          1024
-        )
-
-        while (iter.hasNext) {
-          val currentRow: Row = iter.next()
-          val groupKey: Row = groupProjection(currentRow)
-          val aggregationBuffer = aggregationMap.getAggregationBuffer(groupKey)
-          updateProjection.target(aggregationBuffer)(joinedRow(aggregationBuffer, currentRow))
-        }
-
-        new Iterator[Row] {
-          private[this] val mapIterator = aggregationMap.iterator()
-          private[this] val resultProjection = resultProjectionBuilder()
-
-          def hasNext: Boolean = mapIterator.hasNext
-
-          def next(): Row = {
-            val entry = mapIterator.next()
-            val result = resultProjection(joinedRow(entry.key, entry.value))
-            if (hasNext) {
-              result
-            } else {
-              // This is the last element in the iterator, so let's free the buffer. Before we do,
-              // though, we need to make a defensive copy of the result so that we don't return an
-              // object that might contain dangling pointers to the freed memory
-              val resultCopy = result.copy()
-              aggregationMap.free()
-              resultCopy
-            }
-          }
-        }
-      }
-    }
-  }
-}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index b5ff228f45d35..61a1d3f268b12 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql
 
 import org.scalatest.BeforeAndAfterAll
 
-import org.apache.spark.sql.execution.{UnsafeGeneratedAggregate, GeneratedAggregate}
+import org.apache.spark.sql.execution.GeneratedAggregate
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.TestData._
 import org.apache.spark.sql.test.TestSQLContext
@@ -114,7 +114,6 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
     // First, check if we have GeneratedAggregate.
     var hasGeneratedAgg = false
     df.queryExecution.executedPlan.foreach {
-      case generatedAgg: UnsafeGeneratedAggregate => hasGeneratedAgg = true
       case generatedAgg: GeneratedAggregate => hasGeneratedAgg = true
       case _ =>
     }
@@ -137,16 +136,16 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
     testCodeGen(
       "SELECT key, count(value) FROM testData3x GROUP BY key",
       (1 to 100).map(i => Row(i, 3)))
-//    testCodeGen(
-//      "SELECT count(key) FROM testData3x",
-//      Row(300) :: Nil)
-//    // COUNT DISTINCT ON int
-//    testCodeGen(
-//      "SELECT value, count(distinct key) FROM testData3x GROUP BY value",
-//      (1 to 100).map(i => Row(i.toString, 1)))
-//    testCodeGen(
-//      "SELECT count(distinct key) FROM testData3x",
-//      Row(100) :: Nil)
+    testCodeGen(
+      "SELECT count(key) FROM testData3x",
+      Row(300) :: Nil)
+    // COUNT DISTINCT ON int
+    testCodeGen(
+      "SELECT value, count(distinct key) FROM testData3x GROUP BY value",
+      (1 to 100).map(i => Row(i.toString, 1)))
+    testCodeGen(
+      "SELECT count(distinct key) FROM testData3x",
+      Row(100) :: Nil)
     // SUM
     testCodeGen(
       "SELECT value, sum(key) FROM testData3x GROUP BY value",
@@ -176,23 +175,23 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
       "SELECT min(key) FROM testData3x",
       Row(1) :: Nil)
     // Some combinations.
-//    testCodeGen(
-//      """
-//        |SELECT
-//        |  value,
-//        |  sum(key),
-//        |  max(key),
-//        |  min(key),
-//        |  avg(key),
-//        |  count(key),
-//        |  count(distinct key)
-//        |FROM testData3x
-//        |GROUP BY value
-//      """.stripMargin,
-//      (1 to 100).map(i => Row(i.toString, i*3, i, i, i, 3, 1)))
-//    testCodeGen(
-//      "SELECT max(key), min(key), avg(key), count(key), count(distinct key) FROM testData3x",
-//      Row(100, 1, 50.5, 300, 100) :: Nil)
+    testCodeGen(
+      """
+        |SELECT
+        |  value,
+        |  sum(key),
+        |  max(key),
+        |  min(key),
+        |  avg(key),
+        |  count(key),
+        |  count(distinct key)
+        |FROM testData3x
+        |GROUP BY value
+      """.stripMargin,
+      (1 to 100).map(i => Row(i.toString, i*3, i, i, i, 3, 1)))
+    testCodeGen(
+      "SELECT max(key), min(key), avg(key), count(key), count(distinct key) FROM testData3x",
+      Row(100, 1, 50.5, 300, 100) :: Nil)
     // Aggregate with Code generation handling all null values
     testCodeGen(
       "SELECT sum('a'), avg('a'), count(null) FROM testData",
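Finally, an end-to-end sanity check. Assuming the `testData3x` table registered in SQLQuerySuite, something like the following should route a fixed-width aggregation through the new Unsafe-based branch (a sketch, not part of the patch):

```scala
import org.apache.spark.sql.test.TestSQLContext

TestSQLContext.setConf("spark.sql.codegen", "true")
TestSQLContext.setConf("spark.sql.unsafe", "true")

// sum(key) over int keys yields a fixed-width aggregation buffer, so both
// schema guards pass and GeneratedAggregate takes the unsafe branch
// (look for "Using Unsafe-based aggregator" in the logs).
val rows = TestSQLContext
  .sql("SELECT value, sum(key) FROM testData3x GROUP BY value")
  .collect()
```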