diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
index 5caf032fdf8a8..613b87ca98d97 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
@@ -203,15 +203,6 @@ case class CollectHashSetFunction(
   @transient
   val distinctValue = new InterpretedProjection(expr)
 
-/*
-  override def merge(other: MergableAggregateFunction): MergableAggregateFunction = {
-    val otherSetIterator = other.asInstanceOf[CountDistinctFunction].seen.iterator
-    while(otherSetIterator.hasNext) {
-      seen.add(otherSetIterator.next())
-    }
-    this
-  }*/
-
   override def update(input: Row): Unit = {
     val evaluatedExpr = distinctValue(input)
     if (!evaluatedExpr.anyNull) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index d9784882de805..923a9b1445d6b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -97,7 +97,22 @@ case class MaxOf(left: Expression, right: Expression) extends Expression {
 
   override def dataType = left.dataType
 
-  override def eval(input: Row): Any = ???
+  override def eval(input: Row): Any = {
+    val leftEval = left.eval(input)
+    val rightEval = right.eval(input)
+    // Null-safe max: if either input evaluates to null, return the other.
+    if (leftEval == null) {
+      rightEval
+    } else if (rightEval == null) {
+      leftEval
+    } else {
+      val ordering = left.dataType match {
+        case n: NativeType => n.ordering.asInstanceOf[Ordering[Any]]
+        case other => sys.error(s"Type $other does not support ordered operations")
+      }
+      if (ordering.compare(leftEval, rightEval) < 0) rightEval else leftEval
+    }
+  }
 
   override def toString = s"MaxOf($left, $right)"
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala
index 14a95328c0a3a..b7433b48b82af 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala
@@ -20,6 +20,9 @@ package org.apache.spark.sql.catalyst.expressions
 import org.apache.spark.sql.catalyst.types._
 import org.apache.spark.util.collection.OpenHashSet
 
+/**
+ * Creates a new set of the specified type
+ */
 case class NewSet(elementType: DataType) extends LeafExpression {
   type EvaluatedType = Any
 
@@ -27,7 +30,8 @@ case class NewSet(elementType: DataType) extends LeafExpression {
 
   def nullable = false
 
-  // This is not completely accurate..
+  // We are currently only using these Expressions internally for aggregation. However, if we ever
+  // expose these to users we'll want to create a proper type instead of hijacking ArrayType.
   def dataType = ArrayType(elementType)
 
   def eval(input: Row): Any = {
@@ -37,7 +41,10 @@ case class NewSet(elementType: DataType) extends LeafExpression {
   override def toString = s"new Set($dataType)"
 }
 
-// THIS MUTATES ITS ARUGMENTS
+/**
+ * Adds an item to a set.
+ * For performance, this expression mutates its input during evaluation.
+ */
 case class AddItemToSet(item: Expression, set: Expression) extends Expression {
   type EvaluatedType = Any
 
@@ -67,7 +74,10 @@ case class AddItemToSet(item: Expression, set: Expression) extends Expression {
   override def toString = s"$set += $item"
 }
 
-// THIS MUTATES ITS ARUGMENTS
+/**
+ * Combines the elements of two sets.
+ * For performance, this expression mutates its left input set during evaluation.
+ */
 case class CombineSets(left: Expression, right: Expression) extends BinaryExpression {
   type EvaluatedType = Any
 
@@ -97,6 +107,9 @@ case class CombineSets(left: Expression, right: Expression) extends BinaryExpres
   }
 }
 
+/**
+ * Returns the number of elements in the input set.
+ */
 case class CountSet(child: Expression) extends UnaryExpression {
   type EvaluatedType = Any
 
@@ -112,4 +125,4 @@
   }
 
   override def toString = s"$child.count()"
-}
\ No newline at end of file
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
index 39a96b2dd218f..17418726aa625 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
@@ -108,7 +108,6 @@ case class GeneratedAggregate(
         val currentMax = AttributeReference("currentMax", expr.dataType, nullable = true)()
         val initialValue = Literal(null, expr.dataType)
         val updateMax = MaxOf(currentMax, expr)
-        //If(IsNull(currentMax), expr, If(GreaterThan(currentMax, expr), currentMax, expr))
 
         AggregateEvaluation(
           currentMax :: Nil,
@@ -128,8 +127,9 @@
           set)
 
       case CombineSetsAndCount(inputSet) =>
+        val ArrayType(inputType) = inputSet.dataType
         val set = AttributeReference("hashSet", inputSet.dataType, nullable = false)()
-        val initialValue = NewSet(IntegerType) // NOT TRUE
+        val initialValue = NewSet(inputType)
         val collectSets = CombineSets(set, inputSet)
 
         AggregateEvaluation(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala
index 1c5995f82d1e1..4956948e83914 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala
@@ -45,10 +45,13 @@ private[sql] class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(co
     kryo.register(classOf[com.clearspring.analytics.stream.cardinality.HyperLogLog],
                   new HyperLogLogSerializer)
     kryo.register(classOf[scala.math.BigDecimal], new BigDecimalSerializer)
-    // Specific hashset must come first
+
+    // Specific hashsets must come first
     kryo.register(classOf[IntegerHashSet], new IntegerHashSetSerializer)
     kryo.register(classOf[LongHashSet], new LongHashSetSerializer)
-    kryo.register(classOf[org.apache.spark.util.collection.OpenHashSet[_]], new OpenHashSetSerializer)
+    kryo.register(classOf[org.apache.spark.util.collection.OpenHashSet[_]],
+      new OpenHashSetSerializer)
+
     kryo.setReferences(false)
     kryo.setClassLoader(Utils.getSparkClassLoader)
     new AllScalaRegistrar().apply(kryo)
@@ -188,4 +191,4 @@ private[sql] class LongHashSetSerializer extends Serializer[LongHashSet] {
     }
     set
   }
-}
\ No newline at end of file
+}