From 9d8aeed7b18d58258097526f85f07c37a10c28d8 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Mon, 1 Feb 2016 16:15:16 +0100 Subject: [PATCH 1/6] Add native collect_set/collect_list. --- .../catalyst/analysis/FunctionRegistry.scala | 2 + .../expressions/aggregate/collect.scala | 179 ++++++++++++++++++ .../org/apache/spark/sql/functions.scala | 12 +- .../spark/sql/DataFrameAggregateSuite.scala | 12 ++ .../hive/HiveDataFrameAnalyticsSuite.scala | 11 -- 5 files changed, 196 insertions(+), 20 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index d9009e3848e58..6b23c2ed4d395 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -200,6 +200,8 @@ object FunctionRegistry { expression[VarianceSamp]("var_samp"), expression[Skewness]("skewness"), expression[Kurtosis]("kurtosis"), + expression[CollectList]("collect_list"), + expression[CollectSet]("collect_set"), // string functions expression[Ascii]("ascii"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala new file mode 100644 index 0000000000000..c7438a575ea1f --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala @@ -0,0 +1,179 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.aggregate + +import scala.collection.generic.Growable +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} + +/** + * The Collect aggregate function collects all seen expression values into a list of values. + * + * The operator is bound to the slower sort based aggregation path because the number of + * elements (and their memory usage) can not be determined in advance. This also means that the + * collected elements are stored on heap, and that too many elements can cause GC pauses and + * eventually Out of Memory Errors. + */ +abstract class Collect extends ImperativeAggregate { + + val child: Expression + + override def children: Seq[Expression] = child :: Nil + + override def nullable: Boolean = true + + override def dataType: DataType = ArrayType(child.dataType) + + override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) + + // We need to keep track of the expression id of the list because the dataType of the attribute + // (and the attribute itself) will change when the dataType of the child gets resolved. + val listExprId = NamedExpression.newExprId + + override def aggBufferAttributes: Seq[AttributeReference] = { + Seq(AttributeReference("list", dataType, nullable = false)(listExprId)) + } + + override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes) + + override def inputAggBufferAttributes: Seq[AttributeReference] = { + aggBufferAttributes.map(_.newInstance()) + } + + override def update(buffer: MutableRow, input: InternalRow): Unit = { + getMutableArray(buffer) += child.eval(input) + } + + override def merge(buffer: MutableRow, input: InternalRow): Unit = { + getMutableArray(buffer) ++= input.getArray(inputAggBufferOffset) + } + + override def eval(input: InternalRow): Any = { + // TODO return null if there are no elements? + getMutableArray(input).toFastRandomAccess + } + + private def getMutableArray(buffer: InternalRow): MutableArrayData = { + buffer.getArray(mutableAggBufferOffset).asInstanceOf[MutableArrayData] + } +} + +case class CollectList( + child: Expression, + mutableAggBufferOffset: Int = 0, + inputAggBufferOffset: Int = 0) extends Collect { + + def this(child: Expression) = this(child, 0, 0) + + override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = + copy(mutableAggBufferOffset = newMutableAggBufferOffset) + + override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = + copy(inputAggBufferOffset = newInputAggBufferOffset) + + override def initialize(mutableAggBuffer: MutableRow): Unit = { + mutableAggBuffer.update(mutableAggBufferOffset, ListMutableArrayData()) + } +} + +case class CollectSet( + child: Expression, + mutableAggBufferOffset: Int = 0, + inputAggBufferOffset: Int = 0) extends Collect { + + def this(child: Expression) = this(child, 0, 0) + + override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = + copy(mutableAggBufferOffset = newMutableAggBufferOffset) + + override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = + copy(inputAggBufferOffset = newInputAggBufferOffset) + + override def initialize(mutableAggBuffer: MutableRow): Unit = { + mutableAggBuffer.update(mutableAggBufferOffset, SetMutableArrayData()) + } +} + +/** + * MutableArrayData is an implementation of ArrayData that can be updated in place. This makes + * the assumption that the buffer holding this object data keeps a reference to this object. This + * means that this approach is only valid if a GenericInternalRow or a SpecializedInternalRow is + * used as a buffer. + */ +abstract class MutableArrayData extends ArrayData { + val buffer: Growable[Any] with Iterable[Any] + + /** Add a single element to the MutableArrayData. */ + def +=(elem: Any): MutableArrayData = { + buffer += elem + this + } + + /** Add another array to the MutableArrayData. */ + def ++=(elems: ArrayData): MutableArrayData = { + elems match { + case input: MutableArrayData => buffer ++= input.buffer + case input => buffer ++= input.array + } + this + } + + /** Return an ArrayData instance with fast random access properties. */ + def toFastRandomAccess: ArrayData = this + + protected def getAs[T](ordinal: Int): T + + /* ArrayData methods. */ + override def numElements(): Int = buffer.size + override def array: Array[Any] = buffer.toArray + override def isNullAt(ordinal: Int): Boolean = getAs[AnyRef](ordinal) eq null + override def get(ordinal: Int, elementType: DataType): AnyRef = getAs(ordinal) + override def getBoolean(ordinal: Int): Boolean = getAs(ordinal) + override def getByte(ordinal: Int): Byte = getAs(ordinal) + override def getShort(ordinal: Int): Short = getAs(ordinal) + override def getInt(ordinal: Int): Int = getAs(ordinal) + override def getLong(ordinal: Int): Long = getAs(ordinal) + override def getFloat(ordinal: Int): Float = getAs(ordinal) + override def getDouble(ordinal: Int): Double = getAs(ordinal) + override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal = getAs(ordinal) + override def getUTF8String(ordinal: Int): UTF8String = getAs(ordinal) + override def getBinary(ordinal: Int): Array[Byte] = getAs(ordinal) + override def getInterval(ordinal: Int): CalendarInterval = getAs(ordinal) + override def getStruct(ordinal: Int, numFields: Int): InternalRow = getAs(ordinal) + override def getArray(ordinal: Int): ArrayData = getAs(ordinal) + override def getMap(ordinal: Int): MapData = getAs(ordinal) +} + +case class ListMutableArrayData( + val buffer: ArrayBuffer[Any] = ArrayBuffer.empty) extends MutableArrayData { + override protected def getAs[T](ordinal: Int): T = buffer(ordinal).asInstanceOf[T] + override def copy(): ListMutableArrayData = ListMutableArrayData(buffer.clone()) +} + +case class SetMutableArrayData( + val buffer: mutable.HashSet[Any] = mutable.HashSet.empty) extends MutableArrayData { + override protected def getAs[T](ordinal: Int): T = buffer.toArray.apply(ordinal).asInstanceOf[T] + override def copy(): SetMutableArrayData = SetMutableArrayData(buffer.clone()) + override def toFastRandomAccess: GenericArrayData = new GenericArrayData(buffer.toArray) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 3a27466176a20..5731c150504aa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -212,13 +212,11 @@ object functions extends LegacyFunctions { * @group agg_funcs * @since 1.6.0 */ - def collect_list(e: Column): Column = callUDF("collect_list", e) + def collect_list(e: Column): Column = withAggregateFunction { CollectList(e.expr) } /** * Aggregate function: returns a list of objects with duplicates. * - * For now this is an alias for the collect_list Hive UDAF. - * * @group agg_funcs * @since 1.6.0 */ @@ -227,17 +225,13 @@ object functions extends LegacyFunctions { /** * Aggregate function: returns a set of objects with duplicate elements eliminated. * - * For now this is an alias for the collect_set Hive UDAF. - * * @group agg_funcs * @since 1.6.0 */ - def collect_set(e: Column): Column = callUDF("collect_set", e) + def collect_set(e: Column): Column = withAggregateFunction { CollectSet(e.expr) } /** - * Aggregate function: returns a set of objects with duplicate elements eliminated. - * - * For now this is an alias for the collect_set Hive UDAF. + * Aggregate function: returns a set of objects with duplicate elements eliminated. * * @group agg_funcs * @since 1.6.0 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 08fb7c9d84c0b..8959020a4251c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -361,4 +361,16 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { expr("kurtosis(a)")), Row(null, null, null, null, null)) } + + test("collect functions") { + val df = Seq((1, 2), (2, 2), (3, 4)).toDF("a", "b") + checkAnswer( + df.select(collect_list($"a"), collect_list($"b")), + Seq(Row(Seq(1, 2, 3), Seq(2, 2, 4))) + ) + checkAnswer( + df.select(collect_set($"a"), collect_set($"b")), + Seq(Row(Seq(1, 2, 3), Seq(2, 4))) + ) + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala index 35e433964da91..af536b45f034d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala @@ -53,17 +53,6 @@ class HiveDataFrameAnalyticsSuite extends QueryTest with TestHiveSingleton with ) } - test("collect functions") { - checkAnswer( - testData.select(collect_list($"a"), collect_list($"b")), - Seq(Row(Seq(1, 2, 3), Seq(2, 2, 4))) - ) - checkAnswer( - testData.select(collect_set($"a"), collect_set($"b")), - Seq(Row(Seq(1, 2, 3), Seq(2, 4))) - ) - } - test("cube") { checkAnswer( testData.cube($"a" + $"b", $"b").agg(sum($"a" - $"b")), From 8247d8eb6416b7b5d4eb61fa585c595d87930d63 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Mon, 1 Feb 2016 16:43:24 +0100 Subject: [PATCH 2/6] Add test for struct types. --- .../apache/spark/sql/DataFrameAggregateSuite.scala | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 8959020a4251c..86f88362cef82 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -373,4 +373,18 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { Seq(Row(Seq(1, 2, 3), Seq(2, 4))) ) } + + test("collect functions structs") { + val df = Seq((1, 2, 2), (2, 2, 2), (3, 4, 1)) + .toDF("a", "x", "y") + .select($"a", struct($"x", $"y").as("b")) + checkAnswer( + df.select(collect_list($"a"), sort_array(collect_list($"b"))), + Seq(Row(Seq(1, 2, 3), Seq(Row(2, 2), Row(2, 2), Row(4, 1)))) + ) + checkAnswer( + df.select(collect_set($"a"), sort_array(collect_set($"b"))), + Seq(Row(Seq(1, 2, 3), Seq(Row(2, 2), Row(4, 1)))) + ) + } } From 326a213dc014403aef1033e9d39206f5a873b7a6 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Mon, 1 Feb 2016 19:14:10 +0100 Subject: [PATCH 3/6] Add pretty names for SQL generation. --- .../spark/sql/catalyst/expressions/aggregate/collect.scala | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala index c7438a575ea1f..cde0a963d2192 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala @@ -95,6 +95,8 @@ case class CollectList( override def initialize(mutableAggBuffer: MutableRow): Unit = { mutableAggBuffer.update(mutableAggBufferOffset, ListMutableArrayData()) } + + override def prettyName: String = "collect_list" } case class CollectSet( @@ -113,6 +115,8 @@ case class CollectSet( override def initialize(mutableAggBuffer: MutableRow): Unit = { mutableAggBuffer.update(mutableAggBufferOffset, SetMutableArrayData()) } + + override def prettyName: String = "collect_set" } /** From 8a4e7827d6b1f4c150ec29c35185e63c974762dd Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Tue, 3 May 2016 20:32:22 +0200 Subject: [PATCH 4/6] Merge remote-tracking branch 'apache-github/master' into implode # Conflicts: # sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala # sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala --- .../expressions/aggregate/collect.scala | 106 +++--------------- .../org/apache/spark/sql/functions.scala | 4 +- 2 files changed, 17 insertions(+), 93 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala index cde0a963d2192..d33fcc2310cb2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala @@ -19,13 +19,11 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import scala.collection.generic.Growable import scala.collection.mutable -import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData} +import org.apache.spark.sql.catalyst.util.GenericArrayData import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} /** * The Collect aggregate function collects all seen expression values into a list of values. @@ -47,35 +45,30 @@ abstract class Collect extends ImperativeAggregate { override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) - // We need to keep track of the expression id of the list because the dataType of the attribute - // (and the attribute itself) will change when the dataType of the child gets resolved. - val listExprId = NamedExpression.newExprId + override def supportsPartial: Boolean = false - override def aggBufferAttributes: Seq[AttributeReference] = { - Seq(AttributeReference("list", dataType, nullable = false)(listExprId)) - } + override def aggBufferAttributes: Seq[AttributeReference] = Nil override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes) - override def inputAggBufferAttributes: Seq[AttributeReference] = { - aggBufferAttributes.map(_.newInstance()) + override def inputAggBufferAttributes: Seq[AttributeReference] = Nil + + protected[this] val buffer: Growable[Any] with Iterable[Any] + + override def initialize(b: MutableRow): Unit = { + buffer.clear() } - override def update(buffer: MutableRow, input: InternalRow): Unit = { - getMutableArray(buffer) += child.eval(input) + override def update(b: MutableRow, input: InternalRow): Unit = { + buffer += child.eval(input) } override def merge(buffer: MutableRow, input: InternalRow): Unit = { - getMutableArray(buffer) ++= input.getArray(inputAggBufferOffset) + sys.error("Collect cannot be used in partial aggregations.") } override def eval(input: InternalRow): Any = { - // TODO return null if there are no elements? - getMutableArray(input).toFastRandomAccess - } - - private def getMutableArray(buffer: InternalRow): MutableArrayData = { - buffer.getArray(mutableAggBufferOffset).asInstanceOf[MutableArrayData] + new GenericArrayData(buffer.toArray) } } @@ -92,11 +85,9 @@ case class CollectList( override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = copy(inputAggBufferOffset = newInputAggBufferOffset) - override def initialize(mutableAggBuffer: MutableRow): Unit = { - mutableAggBuffer.update(mutableAggBufferOffset, ListMutableArrayData()) - } - override def prettyName: String = "collect_list" + + override protected[this] val buffer: mutable.ArrayBuffer[Any] = mutable.ArrayBuffer.empty } case class CollectSet( @@ -112,72 +103,7 @@ case class CollectSet( override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = copy(inputAggBufferOffset = newInputAggBufferOffset) - override def initialize(mutableAggBuffer: MutableRow): Unit = { - mutableAggBuffer.update(mutableAggBufferOffset, SetMutableArrayData()) - } - override def prettyName: String = "collect_set" -} - -/** - * MutableArrayData is an implementation of ArrayData that can be updated in place. This makes - * the assumption that the buffer holding this object data keeps a reference to this object. This - * means that this approach is only valid if a GenericInternalRow or a SpecializedInternalRow is - * used as a buffer. - */ -abstract class MutableArrayData extends ArrayData { - val buffer: Growable[Any] with Iterable[Any] - - /** Add a single element to the MutableArrayData. */ - def +=(elem: Any): MutableArrayData = { - buffer += elem - this - } - - /** Add another array to the MutableArrayData. */ - def ++=(elems: ArrayData): MutableArrayData = { - elems match { - case input: MutableArrayData => buffer ++= input.buffer - case input => buffer ++= input.array - } - this - } - - /** Return an ArrayData instance with fast random access properties. */ - def toFastRandomAccess: ArrayData = this - - protected def getAs[T](ordinal: Int): T - - /* ArrayData methods. */ - override def numElements(): Int = buffer.size - override def array: Array[Any] = buffer.toArray - override def isNullAt(ordinal: Int): Boolean = getAs[AnyRef](ordinal) eq null - override def get(ordinal: Int, elementType: DataType): AnyRef = getAs(ordinal) - override def getBoolean(ordinal: Int): Boolean = getAs(ordinal) - override def getByte(ordinal: Int): Byte = getAs(ordinal) - override def getShort(ordinal: Int): Short = getAs(ordinal) - override def getInt(ordinal: Int): Int = getAs(ordinal) - override def getLong(ordinal: Int): Long = getAs(ordinal) - override def getFloat(ordinal: Int): Float = getAs(ordinal) - override def getDouble(ordinal: Int): Double = getAs(ordinal) - override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal = getAs(ordinal) - override def getUTF8String(ordinal: Int): UTF8String = getAs(ordinal) - override def getBinary(ordinal: Int): Array[Byte] = getAs(ordinal) - override def getInterval(ordinal: Int): CalendarInterval = getAs(ordinal) - override def getStruct(ordinal: Int, numFields: Int): InternalRow = getAs(ordinal) - override def getArray(ordinal: Int): ArrayData = getAs(ordinal) - override def getMap(ordinal: Int): MapData = getAs(ordinal) -} - -case class ListMutableArrayData( - val buffer: ArrayBuffer[Any] = ArrayBuffer.empty) extends MutableArrayData { - override protected def getAs[T](ordinal: Int): T = buffer(ordinal).asInstanceOf[T] - override def copy(): ListMutableArrayData = ListMutableArrayData(buffer.clone()) -} -case class SetMutableArrayData( - val buffer: mutable.HashSet[Any] = mutable.HashSet.empty) extends MutableArrayData { - override protected def getAs[T](ordinal: Int): T = buffer.toArray.apply(ordinal).asInstanceOf[T] - override def copy(): SetMutableArrayData = SetMutableArrayData(buffer.clone()) - override def toFastRandomAccess: GenericArrayData = new GenericArrayData(buffer.toArray) + override protected[this] val buffer: mutable.HashSet[Any] = mutable.HashSet.empty } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 374f031ae259a..d817dd6d30380 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -195,8 +195,6 @@ object functions { /** * Aggregate function: returns a list of objects with duplicates. * - * For now this is an alias for the collect_list Hive UDAF. - * * @group agg_funcs * @since 1.6.0 */ @@ -219,7 +217,7 @@ object functions { def collect_set(e: Column): Column = withAggregateFunction { CollectSet(e.expr) } /** - * Aggregate function: returns a set of objects with duplicate elements eliminated. + * Aggregate function: returns a set of objects with duplicate elements eliminated. * * @group agg_funcs * @since 1.6.0 From 597f76bb350126bb3360b692720426a6d41c3c18 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Tue, 3 May 2016 20:48:39 +0200 Subject: [PATCH 5/6] Remove hardcoded Hive references. --- .../spark/sql/hive/HiveSessionCatalog.scala | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala index f023edbd96dbe..f847b34591263 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala @@ -219,20 +219,4 @@ private[sql] class HiveSessionCatalog( } } } - - // Pre-load a few commonly used Hive built-in functions. - HiveSessionCatalog.preloadedHiveBuiltinFunctions.foreach { - case (functionName, clazz) => - val builder = makeFunctionBuilder(functionName, clazz) - val info = new ExpressionInfo(clazz.getCanonicalName, functionName) - createTempFunction(functionName, info, builder, ignoreIfExists = false) - } -} - -private[sql] object HiveSessionCatalog { - // This is the list of Hive's built-in functions that are commonly used and we want to - // pre-load when we create the FunctionRegistry. - val preloadedHiveBuiltinFunctions = - ("collect_set", classOf[org.apache.hadoop.hive.ql.udf.generic.GenericUDAFCollectSet]) :: - ("collect_list", classOf[org.apache.hadoop.hive.ql.udf.generic.GenericUDAFCollectList]) :: Nil } From d9dedffc6d2a90a140264e89d971486c5a850dda Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Tue, 3 May 2016 22:39:30 +0200 Subject: [PATCH 6/6] Fix docs. --- .../sql/catalyst/expressions/aggregate/collect.scala | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala index d33fcc2310cb2..1f4ff9c4b184e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala @@ -72,6 +72,11 @@ abstract class Collect extends ImperativeAggregate { } } +/** + * Collect a list of elements. + */ +@ExpressionDescription( + usage = "_FUNC_(expr) - Collects and returns a list of non-unique elements.") case class CollectList( child: Expression, mutableAggBufferOffset: Int = 0, @@ -90,6 +95,11 @@ case class CollectList( override protected[this] val buffer: mutable.ArrayBuffer[Any] = mutable.ArrayBuffer.empty } +/** + * Collect a list of unique elements. + */ +@ExpressionDescription( + usage = "_FUNC_(expr) - Collects and returns a set of unique elements.") case class CollectSet( child: Expression, mutableAggBufferOffset: Int = 0,