From 4aaef15844848e16dc68d4be4d0a013c8ddcd7d7 Mon Sep 17 00:00:00 2001 From: Michal Senkyr Date: Wed, 11 Jan 2017 00:00:34 +0100 Subject: [PATCH 1/6] Rewrote sequence deserialization implementation to use builders --- .../spark/sql/catalyst/ScalaReflection.scala | 50 +----- .../expressions/objects/objects.scala | 166 ++++++++++++++++++ .../sql/catalyst/ScalaReflectionSuite.scala | 8 - 3 files changed, 168 insertions(+), 56 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 7f7dd51aa2650..8c3c2286f04f3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -307,54 +307,8 @@ object ScalaReflection extends ScalaReflection { } } - val array = Invoke( - MapObjects(mapFunction, getPath, dataType), - "array", - ObjectType(classOf[Array[Any]])) - - val wrappedArray = StaticInvoke( - scala.collection.mutable.WrappedArray.getClass, - ObjectType(classOf[Seq[_]]), - "make", - array :: Nil) - - if (localTypeOf[scala.collection.mutable.WrappedArray[_]] <:< t.erasure) { - wrappedArray - } else { - // Convert to another type using `to` - val cls = mirror.runtimeClass(t.typeSymbol.asClass) - import scala.collection.generic.CanBuildFrom - import scala.reflect.ClassTag - - // Some canBuildFrom methods take an implicit ClassTag parameter - val cbfParams = try { - cls.getDeclaredMethod("canBuildFrom", classOf[ClassTag[_]]) - StaticInvoke( - ClassTag.getClass, - ObjectType(classOf[ClassTag[_]]), - "apply", - StaticInvoke( - cls, - ObjectType(classOf[Class[_]]), - "getClass" - ) :: Nil - ) :: Nil - } catch { - case _: NoSuchMethodException => Nil - } - - Invoke( - wrappedArray, - "to", - ObjectType(cls), - StaticInvoke( - cls, - ObjectType(classOf[CanBuildFrom[_, _, _]]), - "canBuildFrom", - cbfParams - ) :: Nil - ) - } + val cls = mirror.runtimeClass(t.typeSymbol.asClass) + CollectObjects(mapFunction, getPath, dataType, cls) case t if t <:< localTypeOf[Map[_, _]] => // TODO: add walked type path for map diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 36bf3017d4cdb..abd67f033795c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.objects import java.lang.reflect.Modifier +import scala.collection.mutable.Builder import scala.language.existentials import scala.reflect.ClassTag @@ -589,6 +590,171 @@ case class MapObjects private( } } +object CollectObjects { + private val curId = new java.util.concurrent.atomic.AtomicInteger() + + /** + * Construct an instance of CollectObjects case class. + * + * @param function The function applied on the collection elements. + * @param inputData An expression that when evaluated returns a collection object. + * @param elementType The data type of elements in the collection. + * @param collClass The type of the resulting collection. + */ + def apply( + function: Expression => Expression, + inputData: Expression, + elementType: DataType, + collClass: Class[_]): CollectObjects = { + val loopValue = "CollectObjects_loopValue" + curId.getAndIncrement() + val loopIsNull = "CollectObjects_loopIsNull" + curId.getAndIncrement() + val loopVar = LambdaVariable(loopValue, loopIsNull, elementType) + val builderValue = "CollectObjects_builderValue" + curId.getAndIncrement() + CollectObjects(loopValue, loopIsNull, elementType, function(loopVar), inputData, + collClass, builderValue) + } +} + +/** + * An equivalent to the [[MapObjects]] case class but returning an ObjectType containing + * a Scala collection constructed using the associated builder, obtained by calling `newBuilder` + * on the collection's companion object. + * + * @param loopValue the name of the loop variable that used when iterate the collection, and used + * as input for the `lambdaFunction` + * @param loopIsNull the nullity of the loop variable that used when iterate the collection, and + * used as input for the `lambdaFunction` + * @param loopVarDataType the data type of the loop variable that used when iterate the collection, + * and used as input for the `lambdaFunction` + * @param lambdaFunction A function that take the `loopVar` as input, and used as lambda function + * to handle collection elements. + * @param inputData An expression that when evaluated returns a collection object. + * @param collClass The type of the resulting collection. + * @param builderValue The name of the builder variable used to construct the resulting collection. + */ +case class CollectObjects private( + loopValue: String, + loopIsNull: String, + loopVarDataType: DataType, + lambdaFunction: Expression, + inputData: Expression, + collClass: Class[_], + builderValue: String) extends Expression with NonSQLExpression { + + override def nullable: Boolean = inputData.nullable + + override def children: Seq[Expression] = lambdaFunction :: inputData :: Nil + + override def eval(input: InternalRow): Any = + throw new UnsupportedOperationException("Only code-generated evaluation is supported") + + override def dataType: DataType = ObjectType(collClass) + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val collObjectName = s"${collClass.getName}$$.MODULE$$" + val getBuilderVar = s"$collObjectName.newBuilder()" + val elementJavaType = ctx.javaType(loopVarDataType) + ctx.addMutableState("boolean", loopIsNull, "") + ctx.addMutableState(elementJavaType, loopValue, "") + val genInputData = inputData.genCode(ctx) + val genFunction = lambdaFunction.genCode(ctx) + val dataLength = ctx.freshName("dataLength") + val convertedArray = ctx.freshName("convertedArray") + val loopIndex = ctx.freshName("loopIndex") + + val convertedType = ctx.boxedType(lambdaFunction.dataType) + + // In RowEncoder, we use `Object` to represent Array or Seq, so we need to determine the type + // of input collection at runtime for this case. + val seq = ctx.freshName("seq") + val array = ctx.freshName("array") + val determineCollectionType = inputData.dataType match { + case ObjectType(cls) if cls == classOf[Object] => + val seqClass = classOf[Seq[_]].getName + s""" + $seqClass $seq = null; + $elementJavaType[] $array = null; + if (${genInputData.value}.getClass().isArray()) { + $array = ($elementJavaType[]) ${genInputData.value}; + } else { + $seq = ($seqClass) ${genInputData.value}; + } + """ + case _ => "" + } + + // The data with PythonUserDefinedType are actually stored with the data type of its sqlType. + // When we want to apply MapObjects on it, we have to use it. + val inputDataType = inputData.dataType match { + case p: PythonUserDefinedType => p.sqlType + case _ => inputData.dataType + } + + val (getLength, getLoopVar) = inputDataType match { + case ObjectType(cls) if classOf[Seq[_]].isAssignableFrom(cls) => + s"${genInputData.value}.size()" -> s"${genInputData.value}.apply($loopIndex)" + case ObjectType(cls) if cls.isArray => + s"${genInputData.value}.length" -> s"${genInputData.value}[$loopIndex]" + case ObjectType(cls) if classOf[java.util.List[_]].isAssignableFrom(cls) => + s"${genInputData.value}.size()" -> s"${genInputData.value}.get($loopIndex)" + case ArrayType(et, _) => + s"${genInputData.value}.numElements()" -> ctx.getValue(genInputData.value, et, loopIndex) + case ObjectType(cls) if cls == classOf[Object] => + s"$seq == null ? $array.length : $seq.size()" -> + s"$seq == null ? $array[$loopIndex] : $seq.apply($loopIndex)" + } + + // Make a copy of the data if it's unsafe-backed + def makeCopyIfInstanceOf(clazz: Class[_ <: Any], value: String) = + s"$value instanceof ${clazz.getSimpleName}? ${value}.copy() : $value" + val genFunctionValue = lambdaFunction.dataType match { + case StructType(_) => makeCopyIfInstanceOf(classOf[UnsafeRow], genFunction.value) + case ArrayType(_, _) => makeCopyIfInstanceOf(classOf[UnsafeArrayData], genFunction.value) + case MapType(_, _, _) => makeCopyIfInstanceOf(classOf[UnsafeMapData], genFunction.value) + case _ => genFunction.value + } + + val loopNullCheck = inputDataType match { + case _: ArrayType => s"$loopIsNull = ${genInputData.value}.isNullAt($loopIndex);" + // The element of primitive array will never be null. + case ObjectType(cls) if cls.isArray && cls.getComponentType.isPrimitive => + s"$loopIsNull = false" + case _ => s"$loopIsNull = $loopValue == null;" + } + + val code = s""" + ${genInputData.code} + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + + if (!${genInputData.isNull}) { + $determineCollectionType + $convertedType[] $convertedArray = null; + int $dataLength = $getLength; + ${classOf[Builder[_, _]].getName} $builderValue = $getBuilderVar; + $builderValue.sizeHint($dataLength); + + int $loopIndex = 0; + while ($loopIndex < $dataLength) { + $loopValue = ($elementJavaType) ($getLoopVar); + $loopNullCheck + + ${genFunction.code} + if (${genFunction.isNull}) { + $builderValue.$$plus$$eq(null); + } else { + $builderValue.$$plus$$eq($genFunctionValue); + } + + $loopIndex += 1; + } + + ${ev.value} = (${collClass.getName}) $builderValue.result(); + } + """ + ev.copy(code = code, isNull = genInputData.isNull) + } +} + object ExternalMapToCatalyst { private val curId = new java.util.concurrent.atomic.AtomicInteger() diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index 650a35398f3e8..70ad064f93ebc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -312,14 +312,6 @@ class ScalaReflectionSuite extends SparkFunSuite { ArrayType(IntegerType, containsNull = false)) val arrayBufferDeserializer = deserializerFor[ArrayBuffer[Int]] assert(arrayBufferDeserializer.dataType == ObjectType(classOf[ArrayBuffer[_]])) - - // Check whether conversion is skipped when using WrappedArray[_] supertype - // (would otherwise needlessly add overhead) - import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke - val seqDeserializer = deserializerFor[Seq[Int]] - assert(seqDeserializer.asInstanceOf[StaticInvoke].staticObject == - scala.collection.mutable.WrappedArray.getClass) - assert(seqDeserializer.asInstanceOf[StaticInvoke].functionName == "make") } private val dataTypeForComplexData = dataTypeFor[ComplexData] From a330d5f54adba4e1589a6066d64414212b9e5b81 Mon Sep 17 00:00:00 2001 From: Michal Senkyr Date: Wed, 11 Jan 2017 22:25:49 +0100 Subject: [PATCH 2/6] Removed unused codegen line --- .../apache/spark/sql/catalyst/expressions/objects/objects.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index abd67f033795c..d39fad240f726 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -728,7 +728,6 @@ case class CollectObjects private( if (!${genInputData.isNull}) { $determineCollectionType - $convertedType[] $convertedArray = null; int $dataLength = $getLength; ${classOf[Builder[_, _]].getName} $builderValue = $getBuilderVar; $builderValue.sizeHint($dataLength); From 10923754a7eceeaa2bab75e97d2a8a7d04bdd592 Mon Sep 17 00:00:00 2001 From: Michal Senkyr Date: Fri, 13 Jan 2017 02:05:48 +0100 Subject: [PATCH 3/6] Fallback to Seq builder if no builder found (e.g., for Range) --- .../org/apache/spark/sql/catalyst/ScalaReflection.scala | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 8c3c2286f04f3..31ae1c390dc9d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -307,7 +307,10 @@ object ScalaReflection extends ScalaReflection { } } - val cls = mirror.runtimeClass(t.typeSymbol.asClass) + val cls = t.companion.decl(TermName("newBuilder")) match { + case NoSymbol => classOf[Seq[_]] + case _ => mirror.runtimeClass(t.typeSymbol.asClass) + } CollectObjects(mapFunction, getPath, dataType, cls) case t if t <:< localTypeOf[Map[_, _]] => From 85edddd577c3fab85572689f35fe5f2774046341 Mon Sep 17 00:00:00 2001 From: Michal Senkyr Date: Thu, 19 Jan 2017 01:11:55 +0100 Subject: [PATCH 4/6] Add benchmarks --- .../benchmark/SequenceBenchmark.scala | 74 +++++++++++++++++++ 1 file changed, 74 insertions(+) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SequenceBenchmark.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SequenceBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SequenceBenchmark.scala new file mode 100644 index 0000000000000..1d813c0838b6c --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SequenceBenchmark.scala @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.benchmark + +import org.apache.spark.util.Benchmark + +/** + * Benchmark [[Seq]], [[List]] and [[scala.collection.mutable.Queue]] serialization + * performance. + * To run this: + * 1. replace ignore(...) with test(...) + * 2. build/sbt "sql/test-only *benchmark.SequenceBenchmark" + * + * Benchmarks in this file are skipped in normal builds. + */ +class SequenceBenchmark extends BenchmarkBase { + val partitions = 20 + val rows = 1000 + val size = 100 + + def generate[T <: Seq[Int]](generator: Int => ( => Int) => T): Seq[T] = { + Seq.fill(rows)(generator(size)(1)) + } + + ignore("Collect sequence types") { + import sparkSession.implicits._ + + val sc = sparkSession.sparkContext + + val benchmark = new Benchmark(s"collect", rows) + + val seq = generate(Seq.fill(_)) + benchmark.addCase("Seq") { _ => + sc.parallelize(seq, partitions).toDS().map(identity).queryExecution.toRdd.collect().length + } + + val list = generate(List.fill(_)) + benchmark.addCase("List") { _ => + sc.parallelize(list, partitions).toDS().map(identity).queryExecution.toRdd.collect().length + } + + val queue = generate(scala.collection.mutable.Queue.fill(_)) + benchmark.addCase("mutable.Queue") { _ => + sc.parallelize(queue, partitions).toDS().map(identity).queryExecution.toRdd.collect().length + } + + benchmark.run() + + /* + OpenJDK 64-Bit Server VM 1.8.0_112-b15 on Linux 4.8.13-1-ARCH + AMD A10-4600M APU with Radeon(tm) HD Graphics + collect: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Seq 255 / 316 0.0 254697.3 1.0X + List 152 / 177 0.0 152410.0 1.7X + mutable.Queue 213 / 235 0.0 213470.0 1.2X + */ + } +} From b5f87bd37bf3dadda31bc38ddcdd3f6e524b9bf0 Mon Sep 17 00:00:00 2001 From: Michal Senkyr Date: Sun, 19 Mar 2017 17:04:05 +0100 Subject: [PATCH 5/6] Changes based on code review Merge CollectObjects with MapObjects Remove SequenceBenchmark --- .../spark/sql/catalyst/ScalaReflection.scala | 2 +- .../expressions/objects/objects.scala | 220 ++++-------------- .../benchmark/SequenceBenchmark.scala | 74 ------ 3 files changed, 45 insertions(+), 251 deletions(-) delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SequenceBenchmark.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 31ae1c390dc9d..278223db4e11b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -311,7 +311,7 @@ object ScalaReflection extends ScalaReflection { case NoSymbol => classOf[Seq[_]] case _ => mirror.runtimeClass(t.typeSymbol.asClass) } - CollectObjects(mapFunction, getPath, dataType, cls) + MapObjects(mapFunction, getPath, dataType, cls) case t if t <:< localTypeOf[Map[_, _]] => // TODO: add walked type path for map diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index d39fad240f726..b6b2c754fa677 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -430,24 +430,33 @@ object MapObjects { * @param function The function applied on the collection elements. * @param inputData An expression that when evaluated returns a collection object. * @param elementType The data type of elements in the collection. + * @param collClass The class of the resulting collection */ def apply( function: Expression => Expression, inputData: Expression, - elementType: DataType): MapObjects = { - val loopValue = "MapObjects_loopValue" + curId.getAndIncrement() - val loopIsNull = "MapObjects_loopIsNull" + curId.getAndIncrement() + elementType: DataType, + collClass: Class[_] = classOf[Array[_]]): MapObjects = { + val id = curId.getAndIncrement() + val loopValue = s"MapObjects_loopValue$id" + val loopIsNull = s"MapObjects_loopIsNull$id" val loopVar = LambdaVariable(loopValue, loopIsNull, elementType) - MapObjects(loopValue, loopIsNull, elementType, function(loopVar), inputData) + val builderValue = s"MapObjects_builderValue$id" + MapObjects(loopValue, loopIsNull, elementType, function(loopVar), inputData, + collClass, builderValue) } } /** * Applies the given expression to every element of a collection of items, returning the result - * as an ArrayType. This is similar to a typical map operation, but where the lambda function - * is expressed using catalyst expressions. + * as an ArrayType or ObjectType. This is similar to a typical map operation, but where the lambda + * function is expressed using catalyst expressions. + * + * The type of the result is determined as follows: + * - ArrayType - when collClass is an array class + * - ObjectType(collClass) - when collClass is a collection class * - * The following collection ObjectTypes are currently supported: + * The following collection ObjectTypes are currently supported on input: * Seq, Array, ArrayData, java.util.List * * @param loopValue the name of the loop variable that used when iterate the collection, and used @@ -459,13 +468,18 @@ object MapObjects { * @param lambdaFunction A function that take the `loopVar` as input, and used as lambda function * to handle collection elements. * @param inputData An expression that when evaluated returns a collection object. + * @param collClass The class of the resulting collection + * @param builderValue The name of the builder variable used to construct the resulting collection + * (used only when returning ObjectType) */ case class MapObjects private( loopValue: String, loopIsNull: String, loopVarDataType: DataType, lambdaFunction: Expression, - inputData: Expression) extends Expression with NonSQLExpression { + inputData: Expression, + collClass: Class[_], + builderValue: String) extends Expression with NonSQLExpression { override def nullable: Boolean = inputData.nullable @@ -475,7 +489,8 @@ case class MapObjects private( throw new UnsupportedOperationException("Only code-generated evaluation is supported") override def dataType: DataType = - ArrayType(lambdaFunction.dataType, containsNull = lambdaFunction.nullable) + if (!collClass.isArray) ObjectType(collClass) + else ArrayType(lambdaFunction.dataType, containsNull = lambdaFunction.nullable) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val elementJavaType = ctx.javaType(loopVarDataType) @@ -558,169 +573,23 @@ case class MapObjects private( case _ => s"$loopIsNull = $loopValue == null;" } - val code = s""" - ${genInputData.code} - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - - if (!${genInputData.isNull}) { - $determineCollectionType - $convertedType[] $convertedArray = null; - int $dataLength = $getLength; - $convertedArray = $arrayConstructor; - - int $loopIndex = 0; - while ($loopIndex < $dataLength) { - $loopValue = ($elementJavaType) ($getLoopVar); - $loopNullCheck - - ${genFunction.code} - if (${genFunction.isNull}) { - $convertedArray[$loopIndex] = null; - } else { - $convertedArray[$loopIndex] = $genFunctionValue; - } - - $loopIndex += 1; - } - - ${ev.value} = new ${classOf[GenericArrayData].getName}($convertedArray); + val (genInit, genAssign, genResult): (String, String => String, String) = + if (collClass.isArray) { + // array + (s"""$convertedType[] $convertedArray = null; + $convertedArray = $arrayConstructor;""", + genValue => s"$convertedArray[$loopIndex] = $genValue;", + s"new ${classOf[GenericArrayData].getName}($convertedArray);") + } else { + // collection + val collObjectName = s"${collClass.getName}$$.MODULE$$" + val getBuilderVar = s"$collObjectName.newBuilder()" + + (s"""${classOf[Builder[_, _]].getName} $builderValue = $getBuilderVar; + $builderValue.sizeHint($dataLength);""", + genValue => s"$builderValue.$$plus$$eq($genValue);", + s"(${collClass.getName}) $builderValue.result();") } - """ - ev.copy(code = code, isNull = genInputData.isNull) - } -} - -object CollectObjects { - private val curId = new java.util.concurrent.atomic.AtomicInteger() - - /** - * Construct an instance of CollectObjects case class. - * - * @param function The function applied on the collection elements. - * @param inputData An expression that when evaluated returns a collection object. - * @param elementType The data type of elements in the collection. - * @param collClass The type of the resulting collection. - */ - def apply( - function: Expression => Expression, - inputData: Expression, - elementType: DataType, - collClass: Class[_]): CollectObjects = { - val loopValue = "CollectObjects_loopValue" + curId.getAndIncrement() - val loopIsNull = "CollectObjects_loopIsNull" + curId.getAndIncrement() - val loopVar = LambdaVariable(loopValue, loopIsNull, elementType) - val builderValue = "CollectObjects_builderValue" + curId.getAndIncrement() - CollectObjects(loopValue, loopIsNull, elementType, function(loopVar), inputData, - collClass, builderValue) - } -} - -/** - * An equivalent to the [[MapObjects]] case class but returning an ObjectType containing - * a Scala collection constructed using the associated builder, obtained by calling `newBuilder` - * on the collection's companion object. - * - * @param loopValue the name of the loop variable that used when iterate the collection, and used - * as input for the `lambdaFunction` - * @param loopIsNull the nullity of the loop variable that used when iterate the collection, and - * used as input for the `lambdaFunction` - * @param loopVarDataType the data type of the loop variable that used when iterate the collection, - * and used as input for the `lambdaFunction` - * @param lambdaFunction A function that take the `loopVar` as input, and used as lambda function - * to handle collection elements. - * @param inputData An expression that when evaluated returns a collection object. - * @param collClass The type of the resulting collection. - * @param builderValue The name of the builder variable used to construct the resulting collection. - */ -case class CollectObjects private( - loopValue: String, - loopIsNull: String, - loopVarDataType: DataType, - lambdaFunction: Expression, - inputData: Expression, - collClass: Class[_], - builderValue: String) extends Expression with NonSQLExpression { - - override def nullable: Boolean = inputData.nullable - - override def children: Seq[Expression] = lambdaFunction :: inputData :: Nil - - override def eval(input: InternalRow): Any = - throw new UnsupportedOperationException("Only code-generated evaluation is supported") - - override def dataType: DataType = ObjectType(collClass) - - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val collObjectName = s"${collClass.getName}$$.MODULE$$" - val getBuilderVar = s"$collObjectName.newBuilder()" - val elementJavaType = ctx.javaType(loopVarDataType) - ctx.addMutableState("boolean", loopIsNull, "") - ctx.addMutableState(elementJavaType, loopValue, "") - val genInputData = inputData.genCode(ctx) - val genFunction = lambdaFunction.genCode(ctx) - val dataLength = ctx.freshName("dataLength") - val convertedArray = ctx.freshName("convertedArray") - val loopIndex = ctx.freshName("loopIndex") - - val convertedType = ctx.boxedType(lambdaFunction.dataType) - - // In RowEncoder, we use `Object` to represent Array or Seq, so we need to determine the type - // of input collection at runtime for this case. - val seq = ctx.freshName("seq") - val array = ctx.freshName("array") - val determineCollectionType = inputData.dataType match { - case ObjectType(cls) if cls == classOf[Object] => - val seqClass = classOf[Seq[_]].getName - s""" - $seqClass $seq = null; - $elementJavaType[] $array = null; - if (${genInputData.value}.getClass().isArray()) { - $array = ($elementJavaType[]) ${genInputData.value}; - } else { - $seq = ($seqClass) ${genInputData.value}; - } - """ - case _ => "" - } - - // The data with PythonUserDefinedType are actually stored with the data type of its sqlType. - // When we want to apply MapObjects on it, we have to use it. - val inputDataType = inputData.dataType match { - case p: PythonUserDefinedType => p.sqlType - case _ => inputData.dataType - } - - val (getLength, getLoopVar) = inputDataType match { - case ObjectType(cls) if classOf[Seq[_]].isAssignableFrom(cls) => - s"${genInputData.value}.size()" -> s"${genInputData.value}.apply($loopIndex)" - case ObjectType(cls) if cls.isArray => - s"${genInputData.value}.length" -> s"${genInputData.value}[$loopIndex]" - case ObjectType(cls) if classOf[java.util.List[_]].isAssignableFrom(cls) => - s"${genInputData.value}.size()" -> s"${genInputData.value}.get($loopIndex)" - case ArrayType(et, _) => - s"${genInputData.value}.numElements()" -> ctx.getValue(genInputData.value, et, loopIndex) - case ObjectType(cls) if cls == classOf[Object] => - s"$seq == null ? $array.length : $seq.size()" -> - s"$seq == null ? $array[$loopIndex] : $seq.apply($loopIndex)" - } - - // Make a copy of the data if it's unsafe-backed - def makeCopyIfInstanceOf(clazz: Class[_ <: Any], value: String) = - s"$value instanceof ${clazz.getSimpleName}? ${value}.copy() : $value" - val genFunctionValue = lambdaFunction.dataType match { - case StructType(_) => makeCopyIfInstanceOf(classOf[UnsafeRow], genFunction.value) - case ArrayType(_, _) => makeCopyIfInstanceOf(classOf[UnsafeArrayData], genFunction.value) - case MapType(_, _, _) => makeCopyIfInstanceOf(classOf[UnsafeMapData], genFunction.value) - case _ => genFunction.value - } - - val loopNullCheck = inputDataType match { - case _: ArrayType => s"$loopIsNull = ${genInputData.value}.isNullAt($loopIndex);" - // The element of primitive array will never be null. - case ObjectType(cls) if cls.isArray && cls.getComponentType.isPrimitive => - s"$loopIsNull = false" - case _ => s"$loopIsNull = $loopValue == null;" - } val code = s""" ${genInputData.code} @@ -729,8 +598,7 @@ case class CollectObjects private( if (!${genInputData.isNull}) { $determineCollectionType int $dataLength = $getLength; - ${classOf[Builder[_, _]].getName} $builderValue = $getBuilderVar; - $builderValue.sizeHint($dataLength); + $genInit int $loopIndex = 0; while ($loopIndex < $dataLength) { @@ -739,15 +607,15 @@ case class CollectObjects private( ${genFunction.code} if (${genFunction.isNull}) { - $builderValue.$$plus$$eq(null); + ${genAssign("null")} } else { - $builderValue.$$plus$$eq($genFunctionValue); + ${genAssign(genFunctionValue)} } $loopIndex += 1; } - ${ev.value} = (${collClass.getName}) $builderValue.result(); + ${ev.value} = $genResult } """ ev.copy(code = code, isNull = genInputData.isNull) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SequenceBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SequenceBenchmark.scala deleted file mode 100644 index 1d813c0838b6c..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SequenceBenchmark.scala +++ /dev/null @@ -1,74 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.benchmark - -import org.apache.spark.util.Benchmark - -/** - * Benchmark [[Seq]], [[List]] and [[scala.collection.mutable.Queue]] serialization - * performance. - * To run this: - * 1. replace ignore(...) with test(...) - * 2. build/sbt "sql/test-only *benchmark.SequenceBenchmark" - * - * Benchmarks in this file are skipped in normal builds. - */ -class SequenceBenchmark extends BenchmarkBase { - val partitions = 20 - val rows = 1000 - val size = 100 - - def generate[T <: Seq[Int]](generator: Int => ( => Int) => T): Seq[T] = { - Seq.fill(rows)(generator(size)(1)) - } - - ignore("Collect sequence types") { - import sparkSession.implicits._ - - val sc = sparkSession.sparkContext - - val benchmark = new Benchmark(s"collect", rows) - - val seq = generate(Seq.fill(_)) - benchmark.addCase("Seq") { _ => - sc.parallelize(seq, partitions).toDS().map(identity).queryExecution.toRdd.collect().length - } - - val list = generate(List.fill(_)) - benchmark.addCase("List") { _ => - sc.parallelize(list, partitions).toDS().map(identity).queryExecution.toRdd.collect().length - } - - val queue = generate(scala.collection.mutable.Queue.fill(_)) - benchmark.addCase("mutable.Queue") { _ => - sc.parallelize(queue, partitions).toDS().map(identity).queryExecution.toRdd.collect().length - } - - benchmark.run() - - /* - OpenJDK 64-Bit Server VM 1.8.0_112-b15 on Linux 4.8.13-1-ARCH - AMD A10-4600M APU with Radeon(tm) HD Graphics - collect: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Seq 255 / 316 0.0 254697.3 1.0X - List 152 / 177 0.0 152410.0 1.7X - mutable.Queue 213 / 235 0.0 213470.0 1.2X - */ - } -} From d04e043fcd00204531553cb0a8ac1148d85436f4 Mon Sep 17 00:00:00 2001 From: Michal Senkyr Date: Sun, 26 Mar 2017 14:11:18 +0200 Subject: [PATCH 6/6] Type alias bug fix & changes based on code review Dealias collection type before obtaining its companion object Change collClass to Option Rename variables --- .../spark/sql/catalyst/ScalaReflection.scala | 4 +- .../expressions/objects/objects.scala | 57 ++++++++++--------- 2 files changed, 32 insertions(+), 29 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 278223db4e11b..96acf894d191b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -307,11 +307,11 @@ object ScalaReflection extends ScalaReflection { } } - val cls = t.companion.decl(TermName("newBuilder")) match { + val cls = t.dealias.companion.decl(TermName("newBuilder")) match { case NoSymbol => classOf[Seq[_]] case _ => mirror.runtimeClass(t.typeSymbol.asClass) } - MapObjects(mapFunction, getPath, dataType, cls) + MapObjects(mapFunction, getPath, dataType, Some(cls)) case t if t <:< localTypeOf[Map[_, _]] => // TODO: add walked type path for map diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index b6b2c754fa677..ab09c9ad925c0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -430,20 +430,21 @@ object MapObjects { * @param function The function applied on the collection elements. * @param inputData An expression that when evaluated returns a collection object. * @param elementType The data type of elements in the collection. - * @param collClass The class of the resulting collection + * @param customCollectionCls Class of the resulting collection (returning ObjectType) + * or None (returning ArrayType) */ def apply( function: Expression => Expression, inputData: Expression, elementType: DataType, - collClass: Class[_] = classOf[Array[_]]): MapObjects = { + customCollectionCls: Option[Class[_]] = None): MapObjects = { val id = curId.getAndIncrement() val loopValue = s"MapObjects_loopValue$id" val loopIsNull = s"MapObjects_loopIsNull$id" val loopVar = LambdaVariable(loopValue, loopIsNull, elementType) val builderValue = s"MapObjects_builderValue$id" MapObjects(loopValue, loopIsNull, elementType, function(loopVar), inputData, - collClass, builderValue) + customCollectionCls, builderValue) } } @@ -453,8 +454,8 @@ object MapObjects { * function is expressed using catalyst expressions. * * The type of the result is determined as follows: - * - ArrayType - when collClass is an array class - * - ObjectType(collClass) - when collClass is a collection class + * - ArrayType - when customCollectionCls is None + * - ObjectType(collection) - when customCollectionCls contains a collection class * * The following collection ObjectTypes are currently supported on input: * Seq, Array, ArrayData, java.util.List @@ -468,7 +469,8 @@ object MapObjects { * @param lambdaFunction A function that take the `loopVar` as input, and used as lambda function * to handle collection elements. * @param inputData An expression that when evaluated returns a collection object. - * @param collClass The class of the resulting collection + * @param customCollectionCls Class of the resulting collection (returning ObjectType) + * or None (returning ArrayType) * @param builderValue The name of the builder variable used to construct the resulting collection * (used only when returning ObjectType) */ @@ -478,7 +480,7 @@ case class MapObjects private( loopVarDataType: DataType, lambdaFunction: Expression, inputData: Expression, - collClass: Class[_], + customCollectionCls: Option[Class[_]], builderValue: String) extends Expression with NonSQLExpression { override def nullable: Boolean = inputData.nullable @@ -489,8 +491,8 @@ case class MapObjects private( throw new UnsupportedOperationException("Only code-generated evaluation is supported") override def dataType: DataType = - if (!collClass.isArray) ObjectType(collClass) - else ArrayType(lambdaFunction.dataType, containsNull = lambdaFunction.nullable) + customCollectionCls.map(ObjectType.apply).getOrElse( + ArrayType(lambdaFunction.dataType, containsNull = lambdaFunction.nullable)) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val elementJavaType = ctx.javaType(loopVarDataType) @@ -573,22 +575,23 @@ case class MapObjects private( case _ => s"$loopIsNull = $loopValue == null;" } - val (genInit, genAssign, genResult): (String, String => String, String) = - if (collClass.isArray) { - // array - (s"""$convertedType[] $convertedArray = null; - $convertedArray = $arrayConstructor;""", - genValue => s"$convertedArray[$loopIndex] = $genValue;", - s"new ${classOf[GenericArrayData].getName}($convertedArray);") - } else { - // collection - val collObjectName = s"${collClass.getName}$$.MODULE$$" - val getBuilderVar = s"$collObjectName.newBuilder()" + val (initCollection, addElement, getResult): (String, String => String, String) = + customCollectionCls match { + case Some(cls) => + // collection + val collObjectName = s"${cls.getName}$$.MODULE$$" + val getBuilderVar = s"$collObjectName.newBuilder()" - (s"""${classOf[Builder[_, _]].getName} $builderValue = $getBuilderVar; + (s"""${classOf[Builder[_, _]].getName} $builderValue = $getBuilderVar; $builderValue.sizeHint($dataLength);""", - genValue => s"$builderValue.$$plus$$eq($genValue);", - s"(${collClass.getName}) $builderValue.result();") + genValue => s"$builderValue.$$plus$$eq($genValue);", + s"(${cls.getName}) $builderValue.result();") + case None => + // array + (s"""$convertedType[] $convertedArray = null; + $convertedArray = $arrayConstructor;""", + genValue => s"$convertedArray[$loopIndex] = $genValue;", + s"new ${classOf[GenericArrayData].getName}($convertedArray);") } val code = s""" @@ -598,7 +601,7 @@ case class MapObjects private( if (!${genInputData.isNull}) { $determineCollectionType int $dataLength = $getLength; - $genInit + $initCollection int $loopIndex = 0; while ($loopIndex < $dataLength) { @@ -607,15 +610,15 @@ case class MapObjects private( ${genFunction.code} if (${genFunction.isNull}) { - ${genAssign("null")} + ${addElement("null")} } else { - ${genAssign(genFunctionValue)} + ${addElement(genFunctionValue)} } $loopIndex += 1; } - ${ev.value} = $genResult + ${ev.value} = $getResult } """ ev.copy(code = code, isNull = genInputData.isNull)