From 8163966a65e18f64ce371383fe31075bd8f1281d Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Thu, 12 Nov 2015 01:51:02 +0800
Subject: [PATCH] simplify encoder framework

---
 .../spark/sql/catalyst/ScalaReflection.scala       |   2 +-
 .../catalyst/encoders/ExpressionEncoder.scala      | 143 +++---
 .../sql/catalyst/encoders/FlatEncoder.scala        |  50 ++
 .../catalyst/encoders/ProductEncoder.scala         | 463 ++++++++++++++++++
 .../sql/catalyst/encoders/RowEncoder.scala         |  58 +--
 .../sql/catalyst/expressions/objects.scala         |   3 +-
 .../plans/logical/basicOperators.scala             |   8 +-
 .../encoders/ExpressionEncoderSuite.scala          |   2 +-
 .../catalyst/encoders/FlatEncoderSuite.scala       |  48 ++
 .../encoders/ProductEncoderSuite.scala             |  76 +++
 .../scala/org/apache/spark/sql/Dataset.scala       |  48 +-
 .../org/apache/spark/sql/GroupedDataset.scala      |  42 +-
 .../org/apache/spark/sql/SQLImplicits.scala        |  20 +-
 .../spark/sql/execution/SparkStrategies.scala      |   2 +-
 .../aggregate/TypedAggregateExpression.scala       |   2 +-
 .../spark/sql/execution/basicOperators.scala       |  27 +-
 .../org/apache/spark/sql/QueryTest.scala           |   6 +-
 17 files changed, 784 insertions(+), 216 deletions(-)
 create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoder.scala
 create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala
 create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala
 create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 0b8a8abd02d67..ce5330be4a4e0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -75,7 +75,7 @@ trait ScalaReflection {
    *
    * @see SPARK-5281
    */
-  private def localTypeOf[T: TypeTag]: `Type` = typeTag[T].in(mirror).tpe
+  def localTypeOf[T: TypeTag]: `Type` = typeTag[T].in(mirror).tpe
 
   /**
    * Returns the Spark SQL DataType for a given scala type.  Where this is not an exact mapping
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index 005c0627f56b8..5af55a3d58599 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -61,34 +61,89 @@ object ExpressionEncoder {
 
   /**
   * Given a set of N encoders, constructs a new encoder that produces objects as items in an
-   * N-tuple.  Note that these encoders should first be bound correctly to the combined input
-   * schema.
+   * N-tuple.  Note that these encoders should be unresolved before calling this method.
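+   *
+   * For example, composing two of the flat encoders introduced later in this patch
+   * (an illustrative sketch, not part of the patch itself):
+   * {{{
+   *   val pairEnc = ExpressionEncoder.tuple(FlatEncoder[String], FlatEncoder[String])
+   *   val row  = pairEnc.toRow(("a", "b"))    // InternalRow with schema (_1 string, _2 string)
+   *   val back = pairEnc.bind().fromRow(row)  // ("a", "b")
+   * }}}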
*/ def tuple(encoders: Seq[ExpressionEncoder[_]]): ExpressionEncoder[_] = { - val schema = - StructType( - encoders.zipWithIndex.map { case (e, i) => StructField(s"_${i + 1}", e.schema)}) + assert(encoders.length > 1) + + val schema = StructType(encoders.zipWithIndex.map { + case (e, i) => StructField(s"_${i + 1}", if (e.flat) e.schema.head.dataType else e.schema) + }) + val cls = Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}") + val extractExpressions = encoders.map { case e if e.flat => e.extractExpressions.head case other => CreateStruct(other.extractExpressions) + }.zipWithIndex.map { case (expr, index) => + expr.transformUp { + case BoundReference(0, t: ObjectType, _) => + Invoke( + BoundReference(0, ObjectType(cls), true), + s"_${index + 1}", + t) + } } + + val constructExpressions = encoders.zipWithIndex.map { case (enc, index) => + if (enc.flat) { + enc.constructExpression.transform { + case b: BoundReference => b.copy(ordinal = index) + } + } else { + enc.constructExpression.transformUp { + case UnresolvedAttribute(nameParts) => + assert(nameParts.length == 1) + UnresolvedExtractValue( + BoundReference(index, enc.schema, true), + Literal(nameParts.head)) + + case BoundReference(ordinal, dataType, _) => + GetInternalRowField(BoundReference(index, enc.schema, true), ordinal, dataType) + } + } + } + val constructExpression = - NewInstance(cls, encoders.map(_.constructExpression), false, ObjectType(cls)) + NewInstance(cls, constructExpressions, false, ObjectType(cls)) new ExpressionEncoder[Any]( schema, - false, + flat = false, extractExpressions, constructExpression, ClassTag.apply(cls)) } - /** A helper for producing encoders of Tuple2 from other encoders. */ def tuple[T1, T2]( - e1: ExpressionEncoder[T1], - e2: ExpressionEncoder[T2]): ExpressionEncoder[(T1, T2)] = - tuple(e1 :: e2 :: Nil).asInstanceOf[ExpressionEncoder[(T1, T2)]] + enc1: ExpressionEncoder[T1], + enc2: ExpressionEncoder[T2]): ExpressionEncoder[(T1, T2)] = { + tuple(Seq(enc1, enc2)).asInstanceOf[ExpressionEncoder[(T1, T2)]] + } + + def tuple[T1, T2, T3]( + enc1: ExpressionEncoder[T1], + enc2: ExpressionEncoder[T2], + enc3: ExpressionEncoder[T3]): ExpressionEncoder[(T1, T2, T3)] = { + tuple(Seq(enc1, enc2, enc3)).asInstanceOf[ExpressionEncoder[(T1, T2, T3)]] + } + + def tuple[T1, T2, T3, T4]( + enc1: ExpressionEncoder[T1], + enc2: ExpressionEncoder[T2], + enc3: ExpressionEncoder[T3], + enc4: ExpressionEncoder[T4]): ExpressionEncoder[(T1, T2, T3, T4)] = { + tuple(Seq(enc1, enc2, enc3, enc4)).asInstanceOf[ExpressionEncoder[(T1, T2, T3, T4)]] + } + + def tuple[T1, T2, T3, T4, T5]( + enc1: ExpressionEncoder[T1], + enc2: ExpressionEncoder[T2], + enc3: ExpressionEncoder[T3], + enc4: ExpressionEncoder[T4], + enc5: ExpressionEncoder[T5]): ExpressionEncoder[(T1, T2, T3, T4, T5)] = { + tuple(Seq(enc1, enc2, enc3, enc4, enc5)).asInstanceOf[ExpressionEncoder[(T1, T2, T3, T4, T5)]] + } } /** @@ -128,7 +183,7 @@ case class ExpressionEncoder[T]( /** * Returns an object of type `T`, extracting the required values from the provided row. Note that - * you must `resolve` and `bind` an encoder to a specific schema before you can call this + * you must `bind` an encoder to a specific schema before you can call this * function. 
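   *
   * For example (an illustrative sketch; `FlatEncoder` is defined elsewhere in this patch):
   * {{{
   *   val enc = FlatEncoder[String].bind()
   *   enc.fromRow(enc.toRow("hello"))  // "hello"
   * }}}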
*/ def fromRow(row: InternalRow): T = try { @@ -138,51 +193,14 @@ case class ExpressionEncoder[T]( throw new RuntimeException(s"Error while decoding: $e\n${constructExpression.treeString}", e) } - /** - * Returns a new copy of this encoder, where the expressions used by `fromRow` are resolved to the - * given schema. - */ - def resolve(schema: Seq[Attribute]): ExpressionEncoder[T] = { - val plan = Project(Alias(constructExpression, "")() :: Nil, LocalRelation(schema)) - val analyzedPlan = SimpleAnalyzer.execute(plan) - copy(constructExpression = analyzedPlan.expressions.head.children.head) - } - /** * Returns a copy of this encoder where the expressions used to construct an object from an input - * row have been bound to the ordinals of the given schema. Note that you need to first call - * resolve before bind. + * row have been bound to the ordinals of the given schema. */ - def bind(schema: Seq[Attribute]): ExpressionEncoder[T] = { - copy(constructExpression = BindReferences.bindReference(constructExpression, schema)) - } - - /** - * Replaces any bound references in the schema with the attributes at the corresponding ordinal - * in the provided schema. This can be used to "relocate" a given encoder to pull values from - * a different schema than it was initially bound to. It can also be used to assign attributes - * to ordinal based extraction (i.e. because the input data was a tuple). - */ - def unbind(schema: Seq[Attribute]): ExpressionEncoder[T] = { - val positionToAttribute = AttributeMap.toIndex(schema) - copy(constructExpression = constructExpression transform { - case b: BoundReference => positionToAttribute(b.ordinal) - }) - } - - /** - * Given an encoder that has already been bound to a given schema, returns a new encoder - * where the positions are mapped from `oldSchema` to `newSchema`. This can be used, for example, - * when you are trying to use an encoder on grouping keys that were originally part of a larger - * row, but now you have projected out only the key expressions. - */ - def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): ExpressionEncoder[T] = { - val positionToAttribute = AttributeMap.toIndex(oldSchema) - val attributeToNewPosition = AttributeMap.byIndex(newSchema) - copy(constructExpression = constructExpression transform { - case r: BoundReference => - r.copy(ordinal = attributeToNewPosition(positionToAttribute(r.ordinal))) - }) + def bind(attrs: Seq[Attribute] = schema.toAttributes): ExpressionEncoder[T] = { + val plan = Project(Alias(constructExpression, "")() :: Nil, LocalRelation(attrs)) + val resolved = SimpleAnalyzer.execute(plan).expressions.head.children.head + copy(constructExpression = BindReferences.bindReference(resolved, attrs)) } def shift(delta: Int): ExpressionEncoder[T] = { @@ -191,23 +209,6 @@ case class ExpressionEncoder[T]( }) } - /** - * Returns a copy of this encoder where the expressions used to create an object given an - * input row have been modified to pull the object out from a nested struct, instead of the - * top level fields. 
- */ - def nested(input: Expression = BoundReference(0, schema, true)): ExpressionEncoder[T] = { - copy(constructExpression = constructExpression transform { - case u: Attribute if u != input => - UnresolvedExtractValue(input, Literal(u.name)) - case b: BoundReference if b != input => - GetStructField( - input, - StructField(s"i[${b.ordinal}]", b.dataType), - b.ordinal) - }) - } - protected val attrs = extractExpressions.flatMap(_.collect { case _: UnresolvedAttribute => "" case a: Attribute => s"#${a.exprId}" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoder.scala new file mode 100644 index 0000000000000..9a0830ae70244 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoder.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.encoders + +import scala.reflect.ClassTag +import scala.reflect.runtime.universe.{typeTag, TypeTag} + +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.catalyst.expressions.{Literal, CreateNamedStruct, BoundReference} +import org.apache.spark.sql.catalyst.ScalaReflection + +object FlatEncoder { + import ScalaReflection.schemaFor + import ScalaReflection.dataTypeFor + + def apply[T : TypeTag]: ExpressionEncoder[T] = { + // We convert the not-serializable TypeTag into StructType and ClassTag. + val tpe = typeTag[T].tpe + val mirror = typeTag[T].mirror + val cls = mirror.runtimeClass(tpe) + assert(!schemaFor(tpe).dataType.isInstanceOf[StructType]) + + val input = BoundReference(0, dataTypeFor(tpe), nullable = true) + val extractExpression = CreateNamedStruct( + Literal("value") :: ProductEncoder.extractorFor(input, tpe) :: Nil) + val constructExpression = ProductEncoder.constructorFor(tpe) + + new ExpressionEncoder[T]( + extractExpression.dataType, + flat = true, + extractExpression.flatten, + constructExpression, + ClassTag[T](cls)) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala new file mode 100644 index 0000000000000..adafeacd82b15 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala @@ -0,0 +1,463 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.encoders + +import org.apache.spark.util.Utils +import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.sql.types._ +import org.apache.spark.sql.catalyst.ScalaReflectionLock +import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedExtractValue} +import org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.{DateTimeUtils, ArrayBasedMapData, ArrayData, GenericArrayData} + +import scala.reflect.ClassTag + +object ProductEncoder { + import ScalaReflection.universe._ + import ScalaReflection.localTypeOf + import ScalaReflection.dataTypeFor + import ScalaReflection.Schema + import ScalaReflection.schemaFor + import ScalaReflection.arrayClassFor + + def apply[T <: Product : TypeTag]: ExpressionEncoder[T] = { + // We convert the not-serializable TypeTag into StructType and ClassTag. + val tpe = typeTag[T].tpe + val mirror = typeTag[T].mirror + val cls = mirror.runtimeClass(tpe) + + val inputObject = BoundReference(0, ObjectType(cls), nullable = true) + val extractExpression = extractorFor(inputObject, tpe).asInstanceOf[CreateNamedStruct] + val constructExpression = constructorFor(tpe) + + new ExpressionEncoder[T]( + extractExpression.dataType, + flat = false, + extractExpression.flatten, + constructExpression, + ClassTag[T](cls)) + } + + def extractorFor( + inputObject: Expression, + tpe: `Type`): Expression = ScalaReflectionLock.synchronized { + if (!inputObject.dataType.isInstanceOf[ObjectType]) { + inputObject + } else { + tpe match { + case t if t <:< localTypeOf[Option[_]] => + val TypeRef(_, _, Seq(optType)) = t + optType match { + // For primitive types we must manually unbox the value of the object. 
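+            // (A plain-Scala sketch of what the generated unboxing is equivalent to,
+            //  for Option[Int]: `val boxed = opt.orNull; if (boxed == null) null else
+            //  boxed.intValue` -- i.e. the value is stored as a nullable int column
+            //  rather than as a nested object.)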
+ case t if t <:< definitions.IntTpe => + Invoke( + UnwrapOption(ObjectType(classOf[java.lang.Integer]), inputObject), + "intValue", + IntegerType) + case t if t <:< definitions.LongTpe => + Invoke( + UnwrapOption(ObjectType(classOf[java.lang.Long]), inputObject), + "longValue", + LongType) + case t if t <:< definitions.DoubleTpe => + Invoke( + UnwrapOption(ObjectType(classOf[java.lang.Double]), inputObject), + "doubleValue", + DoubleType) + case t if t <:< definitions.FloatTpe => + Invoke( + UnwrapOption(ObjectType(classOf[java.lang.Float]), inputObject), + "floatValue", + FloatType) + case t if t <:< definitions.ShortTpe => + Invoke( + UnwrapOption(ObjectType(classOf[java.lang.Short]), inputObject), + "shortValue", + ShortType) + case t if t <:< definitions.ByteTpe => + Invoke( + UnwrapOption(ObjectType(classOf[java.lang.Byte]), inputObject), + "byteValue", + ByteType) + case t if t <:< definitions.BooleanTpe => + Invoke( + UnwrapOption(ObjectType(classOf[java.lang.Boolean]), inputObject), + "booleanValue", + BooleanType) + + // For non-primitives, we can just extract the object from the Option and then recurse. + case other => + val className: String = optType.erasure.typeSymbol.asClass.fullName + val classObj = Utils.classForName(className) + val optionObjectType = ObjectType(classObj) + + val unwrapped = UnwrapOption(optionObjectType, inputObject) + expressions.If( + IsNull(unwrapped), + expressions.Literal.create(null, schemaFor(optType).dataType), + extractorFor(unwrapped, optType)) + } + + case t if t <:< localTypeOf[Product] => + val formalTypeArgs = t.typeSymbol.asClass.typeParams + val TypeRef(_, _, actualTypeArgs) = t + val constructorSymbol = t.member(nme.CONSTRUCTOR) + val params = if (constructorSymbol.isMethod) { + constructorSymbol.asMethod.paramss + } else { + // Find the primary constructor, and use its parameter ordering. 
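+                // (A case class always has exactly one primary constructor; its first
+                //  parameter list, `paramss.head` below, fixes the field order that the
+                //  resulting schema will use.)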
+ val primaryConstructorSymbol: Option[Symbol] = + constructorSymbol.asTerm.alternatives.find(s => + s.isMethod && s.asMethod.isPrimaryConstructor) + + if (primaryConstructorSymbol.isEmpty) { + sys.error("Internal SQL error: Product object did not have a primary constructor.") + } else { + primaryConstructorSymbol.get.asMethod.paramss + } + } + + CreateNamedStruct(params.head.flatMap { p => + val fieldName = p.name.toString + val fieldType = p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs) + val fieldValue = Invoke(inputObject, fieldName, dataTypeFor(fieldType)) + expressions.Literal(fieldName) :: extractorFor(fieldValue, fieldType) :: Nil + }) + + case t if t <:< localTypeOf[Array[_]] => + val TypeRef(_, _, Seq(elementType)) = t + val elementDataType = dataTypeFor(elementType) + val Schema(dataType, nullable) = schemaFor(elementType) + + if (RowEncoder.isNativeType(dataType)) { + NewInstance( + classOf[GenericArrayData], + inputObject :: Nil, + dataType = ArrayType(dataType, nullable)) + } else { + MapObjects(extractorFor(_, elementType), inputObject, elementDataType) + } + + case t if t <:< localTypeOf[Seq[_]] => + val TypeRef(_, _, Seq(elementType)) = t + val elementDataType = dataTypeFor(elementType) + val Schema(dataType, nullable) = schemaFor(elementType) + + if (RowEncoder.isNativeType(dataType)) { + NewInstance( + classOf[GenericArrayData], + inputObject :: Nil, + dataType = ArrayType(dataType, nullable)) + } else { + MapObjects(extractorFor(_, elementType), inputObject, elementDataType) + } + + case t if t <:< localTypeOf[Map[_, _]] => + val TypeRef(_, _, Seq(keyType, valueType)) = t + val Schema(keyDataType, _) = schemaFor(keyType) + val Schema(valueDataType, valueNullable) = schemaFor(valueType) + + val rawMap = inputObject + val keys = + NewInstance( + classOf[GenericArrayData], + Invoke(rawMap, "keys", ObjectType(classOf[scala.collection.GenIterable[_]])) :: Nil, + dataType = ObjectType(classOf[ArrayData])) + val values = + NewInstance( + classOf[GenericArrayData], + Invoke(rawMap, "values", ObjectType(classOf[scala.collection.GenIterable[_]])) :: Nil, + dataType = ObjectType(classOf[ArrayData])) + NewInstance( + classOf[ArrayBasedMapData], + keys :: values :: Nil, + dataType = MapType(keyDataType, valueDataType, valueNullable)) + + case t if t <:< localTypeOf[String] => + StaticInvoke( + classOf[UTF8String], + StringType, + "fromString", + inputObject :: Nil) + + case t if t <:< localTypeOf[java.sql.Timestamp] => + StaticInvoke( + DateTimeUtils, + TimestampType, + "fromJavaTimestamp", + inputObject :: Nil) + + case t if t <:< localTypeOf[java.sql.Date] => + StaticInvoke( + DateTimeUtils, + DateType, + "fromJavaDate", + inputObject :: Nil) + + case t if t <:< localTypeOf[BigDecimal] => + StaticInvoke( + Decimal, + DecimalType.SYSTEM_DEFAULT, + "apply", + inputObject :: Nil) + + case t if t <:< localTypeOf[java.math.BigDecimal] => + StaticInvoke( + Decimal, + DecimalType.SYSTEM_DEFAULT, + "apply", + inputObject :: Nil) + + case t if t <:< localTypeOf[java.lang.Integer] => + Invoke(inputObject, "intValue", IntegerType) + case t if t <:< localTypeOf[java.lang.Long] => + Invoke(inputObject, "longValue", LongType) + case t if t <:< localTypeOf[java.lang.Double] => + Invoke(inputObject, "doubleValue", DoubleType) + case t if t <:< localTypeOf[java.lang.Float] => + Invoke(inputObject, "floatValue", FloatType) + case t if t <:< localTypeOf[java.lang.Short] => + Invoke(inputObject, "shortValue", ShortType) + case t if t <:< localTypeOf[java.lang.Byte] => + 
Invoke(inputObject, "byteValue", ByteType) + case t if t <:< localTypeOf[java.lang.Boolean] => + Invoke(inputObject, "booleanValue", BooleanType) + + case other => + throw new UnsupportedOperationException(s"Extractor for type $other is not supported") + } + } + } + + def constructorFor( + tpe: `Type`, + path: Option[Expression] = None): Expression = ScalaReflectionLock.synchronized { + + /** Returns the current path with a sub-field extracted. */ + def addToPath(part: String): Expression = path + .map(p => UnresolvedExtractValue(p, expressions.Literal(part))) + .getOrElse(UnresolvedAttribute(part)) + + /** Returns the current path with a field at ordinal extracted. */ + def addToPathOrdinal(ordinal: Int, dataType: DataType) = path + .map(p => GetInternalRowField(p, ordinal, dataType)) + .getOrElse(BoundReference(ordinal, dataType, false)) + + /** Returns the current path or `BoundReference`. */ + def getPath = path.getOrElse(BoundReference(0, schemaFor(tpe).dataType, true)) + + tpe match { + case t if !dataTypeFor(t).isInstanceOf[ObjectType] => getPath + + case t if t <:< localTypeOf[Option[_]] => + val TypeRef(_, _, Seq(optType)) = t + WrapOption(null, constructorFor(optType, path)) + + case t if t <:< localTypeOf[java.lang.Integer] => + val boxedType = classOf[java.lang.Integer] + val objectType = ObjectType(boxedType) + NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) + + case t if t <:< localTypeOf[java.lang.Long] => + val boxedType = classOf[java.lang.Long] + val objectType = ObjectType(boxedType) + NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) + + case t if t <:< localTypeOf[java.lang.Double] => + val boxedType = classOf[java.lang.Double] + val objectType = ObjectType(boxedType) + NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) + + case t if t <:< localTypeOf[java.lang.Float] => + val boxedType = classOf[java.lang.Float] + val objectType = ObjectType(boxedType) + NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) + + case t if t <:< localTypeOf[java.lang.Short] => + val boxedType = classOf[java.lang.Short] + val objectType = ObjectType(boxedType) + NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) + + case t if t <:< localTypeOf[java.lang.Byte] => + val boxedType = classOf[java.lang.Byte] + val objectType = ObjectType(boxedType) + NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) + + case t if t <:< localTypeOf[java.lang.Boolean] => + val boxedType = classOf[java.lang.Boolean] + val objectType = ObjectType(boxedType) + NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) + + case t if t <:< localTypeOf[java.sql.Date] => + StaticInvoke( + DateTimeUtils, + ObjectType(classOf[java.sql.Date]), + "toJavaDate", + getPath :: Nil, + propagateNull = true) + + case t if t <:< localTypeOf[java.sql.Timestamp] => + StaticInvoke( + DateTimeUtils, + ObjectType(classOf[java.sql.Timestamp]), + "toJavaTimestamp", + getPath :: Nil, + propagateNull = true) + + case t if t <:< localTypeOf[java.lang.String] => + Invoke(getPath, "toString", ObjectType(classOf[String])) + + case t if t <:< localTypeOf[java.math.BigDecimal] => + Invoke(getPath, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal])) + + case t if t <:< localTypeOf[Array[_]] => + val TypeRef(_, _, Seq(elementType)) = t + + val primitiveMethod = elementType match { + case t if t <:< definitions.IntTpe => Some("toIntArray") + case t if t <:< definitions.LongTpe => 
Some("toLongArray")
+          case t if t <:< definitions.DoubleTpe => Some("toDoubleArray")
+          case t if t <:< definitions.FloatTpe => Some("toFloatArray")
+          case t if t <:< definitions.ShortTpe => Some("toShortArray")
+          case t if t <:< definitions.ByteTpe => Some("toByteArray")
+          case t if t <:< definitions.BooleanTpe => Some("toBooleanArray")
+          case _ => None
+        }
+
+        primitiveMethod.map { method =>
+          Invoke(getPath, method, arrayClassFor(elementType))
+        }.getOrElse {
+          Invoke(
+            MapObjects(
+              p => constructorFor(elementType, Some(p)),
+              getPath,
+              schemaFor(elementType).dataType),
+            "array",
+            arrayClassFor(elementType))
+        }
+
+      case t if t <:< localTypeOf[Seq[_]] =>
+        val TypeRef(_, _, Seq(elementType)) = t
+
+        val primitiveMethod = elementType match {
+          case t if t <:< definitions.IntTpe => Some("toIntArray")
+          case t if t <:< definitions.LongTpe => Some("toLongArray")
+          case t if t <:< definitions.DoubleTpe => Some("toDoubleArray")
+          case t if t <:< definitions.FloatTpe => Some("toFloatArray")
+          case t if t <:< definitions.ShortTpe => Some("toShortArray")
+          case t if t <:< definitions.ByteTpe => Some("toByteArray")
+          case t if t <:< definitions.BooleanTpe => Some("toBooleanArray")
+          case _ => None
+        }
+
+        val arrayData = primitiveMethod.map { method =>
+          Invoke(getPath, method, arrayClassFor(elementType))
+        }.getOrElse {
+          Invoke(
+            MapObjects(
+              p => constructorFor(elementType, Some(p)),
+              getPath,
+              schemaFor(elementType).dataType),
+            "array",
+            arrayClassFor(elementType))
+        }
+
+        StaticInvoke(
+          scala.collection.mutable.WrappedArray,
+          ObjectType(classOf[Seq[_]]),
+          "make",
+          arrayData :: Nil)
+
+      case t if t <:< localTypeOf[Map[_, _]] =>
+        val TypeRef(_, _, Seq(keyType, valueType)) = t
+
+        val keyData =
+          Invoke(
+            MapObjects(
+              p => constructorFor(keyType, Some(p)),
+              Invoke(getPath, "keyArray", ArrayType(schemaFor(keyType).dataType)),
+              schemaFor(keyType).dataType),
+            "array",
+            ObjectType(classOf[Array[Any]]))
+
+        val valueData =
+          Invoke(
+            MapObjects(
+              p => constructorFor(valueType, Some(p)),
+              Invoke(getPath, "valueArray", ArrayType(schemaFor(valueType).dataType)),
+              schemaFor(valueType).dataType),
+            "array",
+            ObjectType(classOf[Array[Any]]))
+
+        StaticInvoke(
+          ArrayBasedMapData,
+          ObjectType(classOf[Map[_, _]]),
+          "toScalaMap",
+          keyData :: valueData :: Nil)
+
+      case t if t <:< localTypeOf[Product] =>
+        val formalTypeArgs = t.typeSymbol.asClass.typeParams
+        val TypeRef(_, _, actualTypeArgs) = t
+        val constructorSymbol = t.member(nme.CONSTRUCTOR)
+        val params = if (constructorSymbol.isMethod) {
+          constructorSymbol.asMethod.paramss
+        } else {
+          // Find the primary constructor, and use its parameter ordering.
+          val primaryConstructorSymbol: Option[Symbol] =
+            constructorSymbol.asTerm.alternatives.find(s =>
+              s.isMethod && s.asMethod.isPrimaryConstructor)
+
+          if (primaryConstructorSymbol.isEmpty) {
+            sys.error("Internal SQL error: Product object did not have a primary constructor.")
+          } else {
+            primaryConstructorSymbol.get.asMethod.paramss
+          }
+        }
+
+        val className: String = t.erasure.typeSymbol.asClass.fullName
+        val cls = Utils.classForName(className)
+
+        val arguments = params.head.zipWithIndex.map { case (p, i) =>
+          val fieldName = p.name.toString
+          val fieldType = p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs)
+          val dataType = schemaFor(fieldType).dataType
+
+          // For tuples, we grab the inner fields by ordinal instead of name.
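+          // (Tuple field names are purely positional (`_1`, `_2`, ...) and may have been
+          //  aliased away by operations such as `select`, so the ordinal is the only
+          //  reliable handle; ordinary case-class fields resolve by name via `addToPath`.)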
+          if (className startsWith "scala.Tuple") {
+            constructorFor(fieldType, Some(addToPathOrdinal(i, dataType)))
+          } else {
+            constructorFor(fieldType, Some(addToPath(fieldName)))
+          }
+        }
+
+        val newInstance = NewInstance(cls, arguments, propagateNull = false, ObjectType(cls))
+
+        if (path.nonEmpty) {
+          expressions.If(
+            IsNull(getPath),
+            expressions.Literal.create(null, ObjectType(cls)),
+            newInstance
+          )
+        } else {
+          newInstance
+        }
+    }
+  }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index 0b42130a013b2..e0be896bb3548 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -119,9 +119,17 @@ object RowEncoder {
     CreateStruct(convertedFields)
   }
 
-  private def externalDataTypeFor(dt: DataType): DataType = dt match {
+  /**
+   * Returns true if the value of this data type is the same between the internal and
+   * external representations.
+   */
+  def isNativeType(dt: DataType): Boolean = dt match {
     case BooleanType | ByteType | ShortType | IntegerType | LongType |
-         FloatType | DoubleType | BinaryType => dt
+         FloatType | DoubleType | BinaryType => true
+    case _ => false
+  }
+
+  private def externalDataTypeFor(dt: DataType): DataType = dt match {
+    case _ if isNativeType(dt) => dt
     case TimestampType => ObjectType(classOf[java.sql.Timestamp])
     case DateType => ObjectType(classOf[java.sql.Date])
     case _: DecimalType => ObjectType(classOf[java.math.BigDecimal])
@@ -137,13 +145,13 @@ object RowEncoder {
       If(
         IsNull(field),
         Literal.create(null, externalDataTypeFor(f.dataType)),
-        constructorFor(BoundReference(i, f.dataType, f.nullable), f.dataType)
+        constructorFor(BoundReference(i, f.dataType, f.nullable))
       )
     }
     CreateExternalRow(fields)
   }
 
-  private def constructorFor(input: Expression, dataType: DataType): Expression = dataType match {
+  private def constructorFor(input: Expression): Expression = input.dataType match {
     case BooleanType | ByteType | ShortType | IntegerType | LongType |
          FloatType | DoubleType | BinaryType => input
 
@@ -170,7 +178,7 @@ object RowEncoder {
     case ArrayType(et, nullable) =>
       val arrayData =
         Invoke(
-          MapObjects(constructorFor(_, et), input, et),
+          MapObjects(constructorFor, input, et),
           "array",
           ObjectType(classOf[Array[_]]))
       StaticInvoke(
@@ -181,10 +189,10 @@
     case MapType(kt, vt, valueNullable) =>
       val keyArrayType = ArrayType(kt, false)
-      val keyData = constructorFor(Invoke(input, "keyArray", keyArrayType), keyArrayType)
+      val keyData = constructorFor(Invoke(input, "keyArray", keyArrayType))
 
       val valueArrayType = ArrayType(vt, valueNullable)
-      val valueData = constructorFor(Invoke(input, "valueArray", valueArrayType), valueArrayType)
+      val valueData = constructorFor(Invoke(input, "valueArray", valueArrayType))
 
       StaticInvoke(
         ArrayBasedMapData,
@@ -197,42 +205,8 @@
         If(
           Invoke(input, "isNullAt", BooleanType, Literal(i) :: Nil),
           Literal.create(null, externalDataTypeFor(f.dataType)),
-          constructorFor(getField(input, i, f.dataType), f.dataType))
+          constructorFor(GetInternalRowField(input, i, f.dataType)))
       }
       CreateExternalRow(convertedFields)
   }
-
-  private def getField(
-      row: Expression,
-      ordinal: Int,
-      dataType: DataType): Expression = dataType match {
-    case BooleanType =>
-      Invoke(row, "getBoolean", dataType, Literal(ordinal) :: Nil)
-    case ByteType =>
-      Invoke(row, "getByte", dataType, Literal(ordinal) :: Nil)
-    case ShortType =>
-      Invoke(row, "getShort", dataType, Literal(ordinal) :: Nil)
-    case IntegerType | DateType =>
-      Invoke(row, "getInt", dataType, Literal(ordinal) :: Nil)
-    case LongType | TimestampType =>
-      Invoke(row, "getLong", dataType, Literal(ordinal) :: Nil)
-    case FloatType =>
-      Invoke(row, "getFloat", dataType, Literal(ordinal) :: Nil)
-    case DoubleType =>
-      Invoke(row, "getDouble", dataType, Literal(ordinal) :: Nil)
-    case t: DecimalType =>
-      Invoke(row, "getDecimal", dataType, Seq(ordinal, t.precision, t.scale).map(Literal(_)))
-    case StringType =>
-      Invoke(row, "getUTF8String", dataType, Literal(ordinal) :: Nil)
-    case BinaryType =>
-      Invoke(row, "getBinary", dataType, Literal(ordinal) :: Nil)
-    case CalendarIntervalType =>
-      Invoke(row, "getInterval", dataType, Literal(ordinal) :: Nil)
-    case t: StructType =>
-      Invoke(row, "getStruct", dataType, Literal(ordinal) :: Literal(t.size) :: Nil)
-    case _: ArrayType =>
-      Invoke(row, "getArray", dataType, Literal(ordinal) :: Nil)
-    case _: MapType =>
-      Invoke(row, "getMap", dataType, Literal(ordinal) :: Nil)
-  }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
index 4f58464221b4b..f9e7cb81ddd60 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
@@ -288,14 +288,13 @@ case class WrapOption(optionType: DataType, child: Expression)
     throw new UnsupportedOperationException("Only code-generated evaluation is supported")
 
   override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
-    val javaType = ctx.javaType(optionType)
     val inputObject = child.gen(ctx)
 
     s"""
       ${inputObject.code}
 
       boolean ${ev.isNull} = false;
-      scala.Option<$javaType> ${ev.value} =
+      scala.Option ${ev.value} =
        ${inputObject.isNull} ?
        scala.Option$$.MODULE$$.apply(null) : new scala.Some(${inputObject.value});
     """
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 764f8aaebddf1..e7c4bcd3167ee 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -467,10 +467,10 @@ case class MapPartitions[T, U](
 }
 
 /** Factory for constructing new `AppendColumns` nodes. */
-object AppendColumn {
-  def apply[T : Encoder, U : Encoder](func: T => U, child: LogicalPlan): AppendColumn[T, U] = {
+object AppendColumns {
+  def apply[T : Encoder, U : Encoder](func: T => U, child: LogicalPlan): AppendColumns[T, U] = {
     val attrs = encoderFor[U].schema.toAttributes
-    new AppendColumn[T, U](func, encoderFor[T], encoderFor[U], attrs, child)
+    new AppendColumns[T, U](func, encoderFor[T], encoderFor[U], attrs, child)
   }
 }
 
@@ -479,7 +479,7 @@ object AppendColumn {
 * A relation produced by applying `func` to each partition of the `child`, concatenating the
 * resulting columns at the end of the input row.
 * tEncoder/uEncoder are used respectively to
 * decode/encode from the JVM object representation expected by `func`.
 */
-case class AppendColumn[T, U](
+case class AppendColumns[T, U](
     func: T => U,
     tEncoder: ExpressionEncoder[T],
     uEncoder: ExpressionEncoder[U],
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
index b0dacf7f555e0..d8b2dc5de20d4 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
@@ -237,7 +237,7 @@ class ExpressionEncoderSuite extends SparkFunSuite {
     }
     val convertedData = encoder.toRow(inputData)
     val schema = encoder.schema.toAttributes
-    val boundEncoder = encoder.resolve(schema).bind(schema)
+    val boundEncoder = encoder.bind()
     val convertedBack = try boundEncoder.fromRow(convertedData) catch {
       case e: Exception =>
         fail(
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala
new file mode 100644
index 0000000000000..b2a02d8619456
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.encoders
+
+import org.apache.spark.SparkFunSuite
+
+class FlatEncoderSuite extends SparkFunSuite {
+  test("int") {
+    val enc = FlatEncoder[Int]
+    val row = enc.toRow(3)
+    val back = enc.fromRow(row)
+    assert(3 === back)
+  }
+
+  test("string") {
+    val enc = FlatEncoder[String]
+    val row = enc.toRow("abc")
+    val back = enc.fromRow(row)
+    assert("abc" === back)
+  }
+
+  test("seq") {
+    val enc = FlatEncoder[Seq[String]]
+    val input = Seq("abc", "xyz")
+    val back = enc.fromRow(enc.toRow(input))
+    assert(back.length == 2)
+    input.zip(back).foreach {
+      case (a, b) => assert(a == b)
+    }
+  }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala
new file mode 100644
index 0000000000000..d9738671eb252
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.encoders
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.PrimitiveData
+
+class ProductEncoderSuite extends SparkFunSuite {
+  // The test case classes live in the companion object and are not in scope here
+  // without this import.
+  import ProductEncoderSuite._
+
+  test("primitive data") {
+    val input = PrimitiveData(1, 2L, 3.1, 4.1f, 5, 6, true)
+    val encoder = ProductEncoder[PrimitiveData]
+    val boundEnc = encoder.bind()
+    val row = boundEnc.toRow(input)
+    val output = boundEnc.fromRow(row)
+    assert(input == output)
+  }
+
+  test("boxed data") {
+    val input = BoxedData(1, 2L, 3.1, 4.1f, 5.toShort, 6.toByte, true)
+    val encoder = ProductEncoder[BoxedData]
+    val boundEnc = encoder.bind()
+    val row = boundEnc.toRow(input)
+    val output = boundEnc.fromRow(row)
+    assert(input == output)
+  }
+
+  test("repeated struct") {
+    val input = RepeatedStruct(Seq(
+      PrimitiveData(1, 2L, 3.1, 4.1f, 5, 6, true),
+      PrimitiveData(2, 3L, 4.1, 5.1f, 6, 7, false)))
+    val encoder = ProductEncoder[RepeatedStruct]
+    val boundEnc = encoder.bind()
+    val row = boundEnc.toRow(input)
+    val output = boundEnc.fromRow(row)
+    assert(input == output)
+  }
+}
+
+object ProductEncoderSuite {
+  case class RepeatedStruct(s: Seq[PrimitiveData])
+
+  case class NestedArray(a: Array[Array[Int]])
+
+  case class BoxedData(
+      intField: java.lang.Integer,
+      longField: java.lang.Long,
+      doubleField: java.lang.Double,
+      floatField: java.lang.Float,
+      shortField: java.lang.Short,
+      byteField: java.lang.Byte,
+      booleanField: java.lang.Boolean)
+
+  case class RepeatedData(
+      arrayField: Seq[Int],
+      arrayFieldContainsNull: Seq[java.lang.Integer],
+      mapField: scala.collection.Map[Int, Long],
+      mapFieldNull: scala.collection.Map[Int, java.lang.Long],
+      structField: PrimitiveData)
+
+  case class SpecificCollection(l: List[Int])
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index a7e5ab19bf846..a1d07486927fe 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -62,15 +62,11 @@ import org.apache.spark.sql.types.StructType
 class Dataset[T] private[sql](
     @transient val sqlContext: SQLContext,
     @transient val queryExecution: QueryExecution,
-    unresolvedEncoder: Encoder[T]) extends Queryable with Serializable {
+    encoder: Encoder[T]) extends Queryable with Serializable {
 
-  /** The encoder for this [[Dataset]] that has been resolved to its output schema.
*/ - private[sql] implicit val encoder: ExpressionEncoder[T] = unresolvedEncoder match { - case e: ExpressionEncoder[T] => e.resolve(queryExecution.analyzed.output) - case _ => throw new IllegalArgumentException("Only expression encoders are currently supported") - } + implicit val enc: ExpressionEncoder[T] = encoderFor(encoder) - private implicit def classTag = encoder.clsTag + private implicit def classTag = enc.clsTag private[sql] def this(sqlContext: SQLContext, plan: LogicalPlan)(implicit encoder: Encoder[T]) = this(sqlContext, new QueryExecution(sqlContext, plan), encoder) @@ -80,7 +76,7 @@ class Dataset[T] private[sql]( * * @since 1.6.0 */ - def schema: StructType = encoder.schema + def schema: StructType = enc.schema /* ************* * * Conversions * @@ -133,10 +129,8 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def rdd: RDD[T] = { - val tEnc = encoderFor[T] - val input = queryExecution.analyzed.output + val bound = enc.bind(queryExecution.analyzed.output) queryExecution.toRdd.mapPartitions { iter => - val bound = tEnc.bind(input) iter.map(bound.fromRow) } } @@ -294,12 +288,10 @@ class Dataset[T] private[sql]( */ def groupBy[K : Encoder](func: T => K): GroupedDataset[K, T] = { val inputPlan = queryExecution.analyzed - val withGroupingKey = AppendColumn(func, inputPlan) + val withGroupingKey = AppendColumns(func, inputPlan) val executed = sqlContext.executePlan(withGroupingKey) new GroupedDataset( - encoderFor[K].resolve(withGroupingKey.newColumns), - encoderFor[T].bind(inputPlan.output), executed, inputPlan.output, withGroupingKey.newColumns) @@ -319,11 +311,9 @@ class Dataset[T] private[sql]( val keyAttributes = executed.analyzed.output.takeRight(cols.size) new GroupedDataset( - RowEncoder(keyAttributes.toStructType), - encoderFor[T], executed, dataAttributes, - keyAttributes) + keyAttributes)(RowEncoder(keyAttributes.toStructType), enc) } /** @@ -371,13 +361,7 @@ class Dataset[T] private[sql]( val aliases = columns.zipWithIndex.map { case (c, i) => Alias(c.expr, s"_${i + 1}")() } val unresolvedPlan = Project(aliases, logicalPlan) val execution = new QueryExecution(sqlContext, unresolvedPlan) - // Rebind the encoders to the nested schema that will be produced by the select. 
- val encoders = columns.map(_.encoder.asInstanceOf[ExpressionEncoder[_]]).zip(aliases).map { - case (e: ExpressionEncoder[_], a) if !e.flat => - e.nested(a.toAttribute).resolve(execution.analyzed.output) - case (e, a) => - e.unbind(a.toAttribute :: Nil).resolve(execution.analyzed.output) - } + val encoders = columns.map(_.encoder.asInstanceOf[ExpressionEncoder[_]]) new Dataset(sqlContext, execution, ExpressionEncoder.tuple(encoders)) } @@ -485,28 +469,20 @@ class Dataset[T] private[sql]( val left = this.logicalPlan val right = other.logicalPlan - val leftData = this.encoder match { + val leftData = this.enc match { case e if e.flat => Alias(left.output.head, "_1")() case _ => Alias(CreateStruct(left.output), "_1")() } - val rightData = other.encoder match { + val rightData = other.enc match { case e if e.flat => Alias(right.output.head, "_2")() case _ => Alias(CreateStruct(right.output), "_2")() } - val leftEncoder = - if (encoder.flat) encoder else encoder.nested(leftData.toAttribute) - val rightEncoder = - if (other.encoder.flat) other.encoder else other.encoder.nested(rightData.toAttribute) - implicit val tuple2Encoder: Encoder[(T, U)] = - ExpressionEncoder.tuple( - leftEncoder, - rightEncoder.rebind(right.output, left.output ++ right.output)) withPlan[(T, U)](other) { (left, right) => Project( leftData :: rightData :: Nil, Join(left, right, Inner, Some(condition.expr))) - } + }(ExpressionEncoder.tuple(enc, other.enc)) } /* ************************** * @@ -568,7 +544,7 @@ class Dataset[T] private[sql]( private[sql] def logicalPlan = queryExecution.analyzed private[sql] def withPlan(f: LogicalPlan => LogicalPlan): Dataset[T] = - new Dataset[T](sqlContext, sqlContext.executePlan(f(logicalPlan)), encoder) + new Dataset[T](sqlContext, sqlContext.executePlan(f(logicalPlan)), enc) private[sql] def withPlan[R : Encoder]( other: Dataset[_])( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala index db61499229284..eeed97f6d615e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala @@ -17,14 +17,12 @@ package org.apache.spark.sql -import java.util.{Iterator => JIterator} - import scala.collection.JavaConverters._ import org.apache.spark.annotation.Experimental -import org.apache.spark.api.java.function.{Function2 => JFunction2, Function3 => JFunction3, _} +import org.apache.spark.api.java.function._ import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttribute} -import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor, Encoder} +import org.apache.spark.sql.catalyst.encoders.{FlatEncoder, ExpressionEncoder, encoderFor, Encoder} import org.apache.spark.sql.catalyst.expressions.{Expression, NamedExpression, Alias, Attribute} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression @@ -43,27 +41,14 @@ import org.apache.spark.sql.execution.QueryExecution * after Spark 1.6. 
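 *
 * For example (an illustrative sketch, assuming `sqlContext.implicits._` is in scope):
 * {{{
 *   val ds: Dataset[String] = sqlContext.createDataset(Seq("a", "bb", "cc"))
 *   val counts: Dataset[(Int, Long)] = ds.groupBy(_.length).count()  // (1, 1), (2, 2)
 * }}}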
*/ @Experimental -class GroupedDataset[K, T] private[sql]( - private val kEncoder: Encoder[K], - private val tEncoder: Encoder[T], +class GroupedDataset[K : Encoder, T : Encoder] private[sql]( queryExecution: QueryExecution, private val dataAttributes: Seq[Attribute], private val groupingAttributes: Seq[Attribute]) extends Serializable { - private implicit val kEnc = kEncoder match { - case e: ExpressionEncoder[K] => e.unbind(groupingAttributes).resolve(groupingAttributes) - case other => - throw new UnsupportedOperationException("Only expression encoders are currently supported") - } - - private implicit val tEnc = tEncoder match { - case e: ExpressionEncoder[T] => e.resolve(dataAttributes) - case other => - throw new UnsupportedOperationException("Only expression encoders are currently supported") - } + private implicit val kEnc = encoderFor[K] - /** Encoders for built in aggregations. */ - private implicit def newLongEncoder: Encoder[Long] = ExpressionEncoder[Long](flat = true) + private implicit val tEnc = encoderFor[T] private def logicalPlan = queryExecution.analyzed private def sqlContext = queryExecution.sqlContext @@ -78,8 +63,6 @@ class GroupedDataset[K, T] private[sql]( */ def asKey[L : Encoder]: GroupedDataset[L, T] = new GroupedDataset( - encoderFor[L], - tEncoder, queryExecution, dataAttributes, groupingAttributes) @@ -144,7 +127,6 @@ class GroupedDataset[K, T] private[sql]( * Internal helper function for building typed aggregations that return tuples. For simplicity * and code reuse, we do this without the help of the type system and then use helper functions * that cast appropriately for the user facing interface. - * TODO: does not handle aggrecations that return nonflat results, */ protected def aggUntyped(columns: TypedColumn[_, _]*): Dataset[_] = { val aliases = (groupingAttributes ++ columns.map(_.expr)).map { @@ -166,18 +148,10 @@ class GroupedDataset[K, T] private[sql]( val columnEncoders = columns.map(_.encoder.asInstanceOf[ExpressionEncoder[_]]) - // Rebind the encoders to the nested schema that will be produced by the aggregation. - val encoders = (kEnc +: columnEncoders).zip(execution.analyzed.output).map { - case (e: ExpressionEncoder[_], a) if !e.flat => - e.nested(a).resolve(execution.analyzed.output) - case (e, a) => - e.unbind(a :: Nil).resolve(execution.analyzed.output) - } - new Dataset( sqlContext, execution, - ExpressionEncoder.tuple(encoders)) + ExpressionEncoder.tuple(columnEncoders)) } /** @@ -219,7 +193,7 @@ class GroupedDataset[K, T] private[sql]( * Returns a [[Dataset]] that contains a tuple with each key and the number of items present * for that key. */ - def count(): Dataset[(K, Long)] = agg(functions.count("*").as[Long]) + def count(): Dataset[(K, Long)] = agg(functions.count("*").as(FlatEncoder[Long])) /** * Applies the given function to each cogrouped data. 
For each unique group, the function will @@ -230,7 +204,7 @@ class GroupedDataset[K, T] private[sql]( def cogroup[U, R : Encoder]( other: GroupedDataset[K, U])( f: (K, Iterator[T], Iterator[U]) => TraversableOnce[R]): Dataset[R] = { - implicit def uEnc: Encoder[U] = other.tEncoder + implicit def uEnc: Encoder[U] = other.tEnc new Dataset[R]( sqlContext, CoGroup( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala index 6da46a5f7ef9a..8552a78c8f7a3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala @@ -37,16 +37,16 @@ import org.apache.spark.unsafe.types.UTF8String abstract class SQLImplicits { protected def _sqlContext: SQLContext - implicit def newProductEncoder[T <: Product : TypeTag]: Encoder[T] = ExpressionEncoder[T]() - - implicit def newIntEncoder: Encoder[Int] = ExpressionEncoder[Int](flat = true) - implicit def newLongEncoder: Encoder[Long] = ExpressionEncoder[Long](flat = true) - implicit def newDoubleEncoder: Encoder[Double] = ExpressionEncoder[Double](flat = true) - implicit def newFloatEncoder: Encoder[Float] = ExpressionEncoder[Float](flat = true) - implicit def newByteEncoder: Encoder[Byte] = ExpressionEncoder[Byte](flat = true) - implicit def newShortEncoder: Encoder[Short] = ExpressionEncoder[Short](flat = true) - implicit def newBooleanEncoder: Encoder[Boolean] = ExpressionEncoder[Boolean](flat = true) - implicit def newStringEncoder: Encoder[String] = ExpressionEncoder[String](flat = true) + implicit def newProductEncoder[T <: Product : TypeTag]: Encoder[T] = ProductEncoder[T] + + implicit def newIntEncoder: Encoder[Int] = FlatEncoder[Int] + implicit def newLongEncoder: Encoder[Long] = FlatEncoder[Long] + implicit def newDoubleEncoder: Encoder[Double] = FlatEncoder[Double] + implicit def newFloatEncoder: Encoder[Float] = FlatEncoder[Float] + implicit def newByteEncoder: Encoder[Byte] = FlatEncoder[Byte] + implicit def newShortEncoder: Encoder[Short] = FlatEncoder[Short] + implicit def newBooleanEncoder: Encoder[Boolean] = FlatEncoder[Boolean] + implicit def newStringEncoder: Encoder[String] = FlatEncoder[String] implicit def rddToDatasetHolder[T : Encoder](rdd: RDD[T]): DatasetHolder[T] = { DatasetHolder(_sqlContext.createDataset(rdd)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index d65cb1bae7fb5..13cbcbfbe4c3e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -344,7 +344,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.MapPartitions(f, tEnc, uEnc, output, child) => execution.MapPartitions(f, tEnc, uEnc, output, planLater(child)) :: Nil - case logical.AppendColumn(f, tEnc, uEnc, newCol, child) => + case logical.AppendColumns(f, tEnc, uEnc, newCol, child) => execution.AppendColumns(f, tEnc, uEnc, newCol, planLater(child)) :: Nil case logical.MapGroups(f, kEnc, tEnc, uEnc, grouping, output, child) => execution.MapGroups(f, kEnc, tEnc, uEnc, grouping, output, planLater(child)) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala index 
0e5bc1f9abf28..4dac7fdb8057a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
@@ -97,7 +97,7 @@ case class TypedAggregateExpression(
     })
 
   val bAttributes = bEncoder.schema.toAttributes
-  lazy val boundB = bEncoder.resolve(bAttributes).bind(bAttributes)
+  lazy val boundB = bEncoder.bind(bAttributes)
 
   private def updateBuffer(buffer: MutableRow, value: InternalRow): Unit = {
     // todo: need a more neat way to assign the value.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 303d636164adb..e9da7e2f7e779 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -320,9 +320,10 @@ case class MapPartitions[T, U](
     child: SparkPlan) extends UnaryNode {
 
   override protected def doExecute(): RDD[InternalRow] = {
+    val tBound = tEncoder.bind(child.output)
     child.execute().mapPartitions { iter =>
-      val tBoundEncoder = tEncoder.bind(child.output)
-      func(iter.map(tBoundEncoder.fromRow)).map(uEncoder.toRow)
+      // `uEncoder` is only used to encode a user object into a row, so it does not need to be bound.
+      func(iter.map(tBound.fromRow)).map(uEncoder.toRow)
     }
   }
 }
@@ -344,7 +345,8 @@ case class AppendColumns[T, U](
     val tBoundEncoder = tEncoder.bind(child.output)
     val combiner = GenerateUnsafeRowJoiner.create(tEncoder.schema, uEncoder.schema)
     iter.map { row =>
-      val newColumns = uEncoder.toRow(func(tBoundEncoder.fromRow(row)))
+      // `uEncoder` is only used to encode a user object into a row, so it does not need to be bound.
+      val newColumns = uEncoder.toRow(func(tBoundEncoder.fromRow(row)))
       combiner.join(row.asInstanceOf[UnsafeRow], newColumns.asInstanceOf[UnsafeRow]): InternalRow
     }
   }
@@ -374,12 +376,14 @@ case class MapGroups[K, T, U](
   override protected def doExecute(): RDD[InternalRow] = {
     child.execute().mapPartitions { iter =>
       val grouped = GroupedIterator(iter, groupingAttributes, child.output)
-      val groupKeyEncoder = kEncoder.bind(groupingAttributes)
+      val kBound = kEncoder.bind(groupingAttributes)
+      val tBound = tEncoder.bind(child.output)
 
       grouped.flatMap { case (key, rowIter) =>
         val result = func(
-          groupKeyEncoder.fromRow(key),
-          rowIter.map(tEncoder.fromRow))
+          kBound.fromRow(key),
+          rowIter.map(tBound.fromRow))
+        // `uEncoder` is only used to encode a user object into a row, so it does not need to be bound.
         result.map(uEncoder.toRow)
       }
     }
@@ -413,14 +417,17 @@ case class CoGroup[K, Left, Right, R](
     left.execute().zipPartitions(right.execute()) { (leftData, rightData) =>
       val leftGrouped = GroupedIterator(leftData, leftGroup, left.output)
       val rightGrouped = GroupedIterator(rightData, rightGroup, right.output)
-      val groupKeyEncoder = kEncoder.bind(leftGroup)
+      val kBound = kEncoder.bind(leftGroup)
+      val leftBound = leftEnc.bind(left.output)
+      val rightBound = rightEnc.bind(right.output)
 
       new CoGroupedIterator(leftGrouped, rightGrouped, leftGroup).flatMap {
         case (key, leftResult, rightResult) =>
           val result = func(
-            groupKeyEncoder.fromRow(key),
-            leftResult.map(leftEnc.fromRow),
-            rightResult.map(rightEnc.fromRow))
+            kBound.fromRow(key),
+            leftResult.map(leftBound.fromRow),
+            rightResult.map(rightBound.fromRow))
+          // `rEncoder` is only used to encode a user object into a row, so it does not need to be bound.
          result.map(rEncoder.toRow)
       }
     }
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index 3c174efe73ffe..8f1a56ab9719e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -83,8 +83,8 @@ abstract class QueryTest extends PlanTest {
       fail(
         s"""
            |Exception collecting dataset as objects
-           |${ds.encoder}
-           |${ds.encoder.constructExpression.treeString}
+           |${ds.enc}
+           |${ds.enc.constructExpression.treeString}
            |${ds.queryExecution}
          """.stripMargin, e)
   }
 
@@ -94,7 +94,7 @@ abstract class QueryTest extends PlanTest {
       s"""Decoded objects do not match expected objects:
          |Expected: ${expectedAnswer.toSet.toSeq.map((a: Any) => a.toString).sorted}
          |Actual:   ${decoded.toSet.toSeq.map((a: Any) => a.toString).sorted}
-         |${ds.encoder.constructExpression.treeString}
+         |${ds.enc.constructExpression.treeString}
        """.stripMargin)
   }
 }
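
To summarize the encoder lifecycle this patch introduces (an illustrative sketch, not
part of the patch; `Point` is a hypothetical case class):

    import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, FlatEncoder, ProductEncoder}

    case class Point(x: Int, y: Double)

    // Encoders are created unresolved. The extract side (object -> InternalRow) is
    // already bound to the input object and works immediately; the construct side
    // (InternalRow -> object) must first be resolved and bound via bind(), which
    // replaces the old resolve-then-bind pair.
    val pointEnc = ProductEncoder[Point]
    val row  = pointEnc.toRow(Point(1, 2.5))
    val back = pointEnc.bind().fromRow(row)   // Point(1, 2.5)

    // Unresolved encoders compose into tuple encoders:
    val pairEnc = ExpressionEncoder.tuple(FlatEncoder[String], ProductEncoder[Point])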