From cf53a434b893293041f73414f50d7f0918a01d49 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 4 May 2016 17:49:27 +0800 Subject: [PATCH 1/9] Avoid extra Project when DeserializeToObject outputs an unsupported class for Project. --- .../spark/sql/catalyst/ScalaReflection.scala | 19 +++++++++++++++++++ .../sql/catalyst/optimizer/Optimizer.scala | 17 ++++++++++++----- .../EliminateSerializationSuite.scala | 2 +- .../org/apache/spark/sql/DatasetSuite.scala | 12 ++++++++++++ 4 files changed, 44 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index d158a64a85bc0..b68b48b756b84 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -62,6 +62,16 @@ object ScalaReflection extends ScalaReflection { */ def dataTypeFor[T : TypeTag]: DataType = dataTypeFor(localTypeOf[T]) + /** + * Returns the Spark SQL DataType for a given runtime class. Where this is not an exact mapping + * to a native type, an ObjectType is returned. Special handling is also used for Arrays including + * those that hold primitive types. + * + * Unlike `schemaFor`, this function doesn't do any massaging of types into the Spark SQL type + * system. As a result, ObjectType will be returned for things like boxed Integers + */ + def dataTypeFor(t: Class[_]): DataType = dataTypeFor(getTypeForClass(t)) + private def dataTypeFor(tpe: `Type`): DataType = ScalaReflectionLock.synchronized { tpe match { case t if t <:< definitions.IntTpe => IntegerType @@ -659,6 +669,15 @@ object ScalaReflection extends ScalaReflection { constructParams(t).map(_.name.toString) } + /** + * Returns the type for given runtime class. + */ + def getTypeForClass(cls: Class[_]): Type = { + val m = runtimeMirror(cls.getClassLoader) + val classSymbol = m.staticClass(cls.getName) + classSymbol.selfType + } + /* * Retrieves the runtime class corresponding to the provided type. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index a3ab89dc71145..71828739a9590 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -21,7 +21,7 @@ import scala.annotation.tailrec import scala.collection.immutable.HashSet import scala.collection.mutable.ArrayBuffer -import org.apache.spark.sql.catalyst.{CatalystConf, SimpleCatalystConf} +import org.apache.spark.sql.catalyst.{CatalystConf, ScalaReflection, SimpleCatalystConf} import org.apache.spark.sql.catalyst.analysis.{CleanupAliases, DistinctAggregationRewriter, EliminateSubqueryAliases, EmptyFunctionRegistry} import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} import org.apache.spark.sql.catalyst.expressions._ @@ -163,14 +163,21 @@ object EliminateSerialization extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case d @ DeserializeToObject(_, _, s: SerializeFromObject) if d.outputObjectType == s.inputObjectType => - // A workaround for SPARK-14803. Remove this after it is fixed. 
- if (d.outputObjectType.isInstanceOf[ObjectType] && - d.outputObjectType.asInstanceOf[ObjectType].cls == classOf[org.apache.spark.sql.Row]) { - s.child + val addProject = if (d.outputObjectType.isInstanceOf[ObjectType]) { + ScalaReflection.dataTypeFor(d.outputObjectType.asInstanceOf[ObjectType].cls) match { + case o: ObjectType => false + case _ => true + } } else { + true + } + + if (addProject) { // Adds an extra Project here, to preserve the output expr id of `DeserializeToObject`. val objAttr = Alias(s.child.output.head, "obj")(exprId = d.output.head.exprId) Project(objAttr :: Nil, s.child) + } else { + s.child } case a @ AppendColumns(_, _, _, s: SerializeFromObject) if a.deserializer.dataType == s.inputObjectType => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSerializationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSerializationSuite.scala index 3c033ddc374cf..79835edf3f854 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSerializationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSerializationSuite.scala @@ -42,7 +42,7 @@ class EliminateSerializationSuite extends PlanTest { val input = LocalRelation('obj.obj(classOf[(Int, Int)])) val plan = input.serialize[(Int, Int)].deserialize[(Int, Int)].analyze val optimized = Optimize.execute(plan) - val expected = input.select('obj.as("obj")).analyze + val expected = input comparePlans(optimized, expected) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 68a12b062249e..0cfbb3e1caae5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -658,8 +658,20 @@ class DatasetSuite extends QueryTest with SharedSQLContext { val dataset = Seq(1, 2, 3).toDS() checkDataset(DatasetTransform.addOne(dataset), 2, 3, 4) } + + test("dataset.rdd with generic case class") { + val ds = Seq(Generic(1, 1.0), Generic(2, 2.0)).toDS + val ds2 = ds.map(g => Generic(g.id, g.value)) + ds.rdd.map(r => r.id).count + ds2.rdd.map(r => r.id).count + + val ds3 = ds.map(g => new java.lang.Long(g.id)) + ds3.rdd.map(r => r).count + } } +case class Generic[T](id: T, value: Double) + case class OtherTuple(_1: String, _2: Int) case class TupleClass(data: (Int, String)) From 48e6b6d3bc4d41d808db43b888e6b17a17a77d1f Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 5 May 2016 07:38:06 +0000 Subject: [PATCH 2/9] Add ObjectProject. 
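
Instead of wrapping the deserialized object in a row-based Project (which cannot
handle ObjectType columns such as generic case classes or boxed primitives), this
adds a dedicated ObjectProject node that simply forwards the child's single
object-typed column under the output attribute of the eliminated
DeserializeToObject, preserving its expr id.

Illustrative sketch of the intended collapse; it mirrors the updated
EliminateSerializationSuite in this patch and relies on that suite's Catalyst test
DSL and PlanTest helpers (LocalRelation, serialize/deserialize, Optimize,
comparePlans):

    val input = LocalRelation('obj.obj(classOf[(Int, Int)]))
    val plan = input.serialize[(Int, Int)].deserialize[(Int, Int)].analyze
    val optimized = Optimize.execute(plan)
    // The back-to-back serialize/deserialize pair now becomes
    //   ObjectProject(obj#N, LocalRelation [obj])
    // rather than a Project over an ObjectType column.
    comparePlans(optimized, ObjectProject(input.output.head.withNullability(false), input))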
--- .../sql/catalyst/optimizer/Optimizer.scala | 14 +++++----- .../sql/catalyst/plans/logical/object.scala | 7 +++++ .../EliminateSerializationSuite.scala | 2 +- .../spark/sql/execution/SparkStrategies.scala | 2 ++ .../apache/spark/sql/execution/objects.scala | 26 +++++++++++++++++++ 5 files changed, 43 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 71828739a9590..d96ecb69c7e08 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -163,21 +163,21 @@ object EliminateSerialization extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case d @ DeserializeToObject(_, _, s: SerializeFromObject) if d.outputObjectType == s.inputObjectType => - val addProject = if (d.outputObjectType.isInstanceOf[ObjectType]) { + val outputObject = if (d.outputObjectType.isInstanceOf[ObjectType]) { ScalaReflection.dataTypeFor(d.outputObjectType.asInstanceOf[ObjectType].cls) match { - case o: ObjectType => false - case _ => true + case o: ObjectType => true + case _ => false } } else { - true + false } - if (addProject) { + if (outputObject) { + ObjectProject(d.output.head, s.child) + } else { // Adds an extra Project here, to preserve the output expr id of `DeserializeToObject`. val objAttr = Alias(s.child.output.head, "obj")(exprId = d.output.head.exprId) Project(objAttr :: Nil, s.child) - } else { - s.child } case a @ AppendColumns(_, _, _, s: SerializeFromObject) if a.deserializer.dataType == s.inputObjectType => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala index 84339f439a666..fa01a460fca01 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala @@ -80,6 +80,13 @@ trait ObjectConsumer extends UnaryNode { def inputObjectType: DataType = child.output.head.dataType } +/** + * Takes the object from child and projects it as new attribute. + */ +case class ObjectProject( + outputObjAttr: Attribute, + child: LogicalPlan) extends UnaryNode with ObjectProducer + /** * Takes the input row from child and turns it into object using the given deserializer expression. 
*/ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSerializationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSerializationSuite.scala index 79835edf3f854..19037e160aa67 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSerializationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSerializationSuite.scala @@ -42,7 +42,7 @@ class EliminateSerializationSuite extends PlanTest { val input = LocalRelation('obj.obj(classOf[(Int, Int)])) val plan = input.serialize[(Int, Int)].deserialize[(Int, Int)].analyze val optimized = Optimize.execute(plan) - val expected = input + val expected = ObjectProject(input.output.head.withNullability(false), input) comparePlans(optimized, expected) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 9747e58f43717..558fa9a3a0ce5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -302,6 +302,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { throw new IllegalStateException( "logical except operator should have been replaced by anti-join in the optimizer") + case logical.ObjectProject(objAttr, child) => + execution.ObjectProject(objAttr, planLater(child)) :: Nil case logical.DeserializeToObject(deserializer, objAttr, child) => execution.DeserializeToObject(deserializer, objAttr, planLater(child)) :: Nil case logical.SerializeFromObject(serializer, child) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala index 56a39069511d7..a0eda51808535 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala @@ -27,6 +27,32 @@ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.types.{DataType, ObjectType} +/** + * Simply takes the object from child and projects it as new attribute. + * The output of this operator is a single-field safe row containing the input object. + */ +case class ObjectProject( + outputObjAttr: Attribute, + child: SparkPlan) extends UnaryExecNode with CodegenSupport { + + override def output: Seq[Attribute] = outputObjAttr :: Nil + override def producedAttributes: AttributeSet = AttributeSet(outputObjAttr) + + override def inputRDDs(): Seq[RDD[InternalRow]] = { + child.asInstanceOf[CodegenSupport].inputRDDs() + } + + protected override def doProduce(ctx: CodegenContext): String = { + child.asInstanceOf[CodegenSupport].produce(ctx, this) + } + + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { + consume(ctx, input) + } + + override protected def doExecute(): RDD[InternalRow] = child.execute() +} + /** * Takes the input row from child and turns it into object using the given deserializer expression. * The output of this operator is a single-field safe row containing the deserialized object. From 737c5187f7e9db1a08407246416cbc967fec7d30 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 5 May 2016 09:36:52 +0000 Subject: [PATCH 3/9] Fix test. 
--- .../optimizer/TypedFilterOptimizationSuite.scala | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TypedFilterOptimizationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TypedFilterOptimizationSuite.scala index 1fae64e3bc6b1..00ad535d2223f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TypedFilterOptimizationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TypedFilterOptimizationSuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} import org.apache.spark.sql.catalyst.plans.PlanTest -import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, ObjectProject} import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.types.BooleanType @@ -47,10 +47,9 @@ class TypedFilterOptimizationSuite extends PlanTest { val query = input.filter(f1).filter(f2).analyze val optimized = Optimize.execute(query) - - val expected = input.deserialize[(Int, Int)] - .where(callFunction(f1, BooleanType, 'obj)) - .select('obj.as("obj")) + val deserialized = input.deserialize[(Int, Int)] + val expected = ObjectProject(deserialized.output.head, deserialized + .where(callFunction(f1, BooleanType, 'obj))) .where(callFunction(f2, BooleanType, 'obj)) .serialize[(Int, Int)].analyze From 4b0773adcad8b6d6f0be6f5ca2287fb877e502c9 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 9 May 2016 07:07:30 +0000 Subject: [PATCH 4/9] Replace projection that is used to preserve expr id in eliminate serialization. --- .../spark/sql/catalyst/ScalaReflection.scala | 19 -------- .../sql/catalyst/optimizer/Optimizer.scala | 46 +++++++++++-------- .../sql/catalyst/plans/logical/object.scala | 2 + .../spark/sql/execution/SparkStrategies.scala | 2 - .../apache/spark/sql/execution/objects.scala | 26 ----------- 5 files changed, 30 insertions(+), 65 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index b68b48b756b84..d158a64a85bc0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -62,16 +62,6 @@ object ScalaReflection extends ScalaReflection { */ def dataTypeFor[T : TypeTag]: DataType = dataTypeFor(localTypeOf[T]) - /** - * Returns the Spark SQL DataType for a given runtime class. Where this is not an exact mapping - * to a native type, an ObjectType is returned. Special handling is also used for Arrays including - * those that hold primitive types. - * - * Unlike `schemaFor`, this function doesn't do any massaging of types into the Spark SQL type - * system. 
As a result, ObjectType will be returned for things like boxed Integers - */ - def dataTypeFor(t: Class[_]): DataType = dataTypeFor(getTypeForClass(t)) - private def dataTypeFor(tpe: `Type`): DataType = ScalaReflectionLock.synchronized { tpe match { case t if t <:< definitions.IntTpe => IntegerType @@ -669,15 +659,6 @@ object ScalaReflection extends ScalaReflection { constructParams(t).map(_.name.toString) } - /** - * Returns the type for given runtime class. - */ - def getTypeForClass(cls: Class[_]): Type = { - val m = runtimeMirror(cls.getClassLoader) - val classSymbol = m.staticClass(cls.getName) - classSymbol.selfType - } - /* * Retrieves the runtime class corresponding to the provided type. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index d96ecb69c7e08..f1482924c1671 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -21,7 +21,7 @@ import scala.annotation.tailrec import scala.collection.immutable.HashSet import scala.collection.mutable.ArrayBuffer -import org.apache.spark.sql.catalyst.{CatalystConf, ScalaReflection, SimpleCatalystConf} +import org.apache.spark.sql.catalyst.{CatalystConf, SimpleCatalystConf} import org.apache.spark.sql.catalyst.analysis.{CleanupAliases, DistinctAggregationRewriter, EliminateSubqueryAliases, EmptyFunctionRegistry} import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} import org.apache.spark.sql.catalyst.expressions._ @@ -102,7 +102,8 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf) SimplifyCasts, SimplifyCaseConversionExpressions, RewriteCorrelatedScalarSubquery, - EliminateSerialization) :: + EliminateSerialization, + RemoveExtraProjectForSerialization) :: Batch("Decimal Optimizations", fixedPoint, DecimalAggregates) :: Batch("Typed Filter Optimization", fixedPoint, @@ -155,6 +156,28 @@ object SamplePushDown extends Rule[LogicalPlan] { } } +/** + * Removes extra Project added in EliminateSerialization rule. + */ +object RemoveExtraProjectForSerialization extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = { + val objectProject = plan.find(_.isInstanceOf[ObjectProject]).map { case o: ObjectProject => + val replaceFrom = o.outputObjAttr + val replaceTo = o.child.output.head + plan.transformAllExpressions { + case a: Attribute if a.equals(replaceFrom) => replaceTo + }.transform { + case op: ObjectProject if o == op => op.child + } + } + if (objectProject.isDefined) { + objectProject.get + } else { + plan + } + } +} + /** * Removes cases where we are unnecessarily going between the object and serialized (InternalRow) * representation of data item. For example back to back map operations. 
@@ -163,22 +186,9 @@ object EliminateSerialization extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case d @ DeserializeToObject(_, _, s: SerializeFromObject) if d.outputObjectType == s.inputObjectType => - val outputObject = if (d.outputObjectType.isInstanceOf[ObjectType]) { - ScalaReflection.dataTypeFor(d.outputObjectType.asInstanceOf[ObjectType].cls) match { - case o: ObjectType => true - case _ => false - } - } else { - false - } - - if (outputObject) { - ObjectProject(d.output.head, s.child) - } else { - // Adds an extra Project here, to preserve the output expr id of `DeserializeToObject`. - val objAttr = Alias(s.child.output.head, "obj")(exprId = d.output.head.exprId) - Project(objAttr :: Nil, s.child) - } + // Adds an extra ObjectProject here, to preserve the output expr id of `DeserializeToObject`. + // We will remove it later. + ObjectProject(d.output.head, s.child) case a @ AppendColumns(_, _, _, s: SerializeFromObject) if a.deserializer.dataType == s.inputObjectType => AppendColumnsWithObject(a.func, s.serializer, a.serializer, s.child) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala index fa01a460fca01..0e0a6d9ede39f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala @@ -82,6 +82,8 @@ trait ObjectConsumer extends UnaryNode { /** * Takes the object from child and projects it as new attribute. + * This logical plan is just used to preserve expr id temporarily and will be removed before + * the end of optimization phase. */ case class ObjectProject( outputObjAttr: Attribute, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 558fa9a3a0ce5..9747e58f43717 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -302,8 +302,6 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { throw new IllegalStateException( "logical except operator should have been replaced by anti-join in the optimizer") - case logical.ObjectProject(objAttr, child) => - execution.ObjectProject(objAttr, planLater(child)) :: Nil case logical.DeserializeToObject(deserializer, objAttr, child) => execution.DeserializeToObject(deserializer, objAttr, planLater(child)) :: Nil case logical.SerializeFromObject(serializer, child) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala index 5ba24c4fe1768..bafbbdf65724d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala @@ -27,32 +27,6 @@ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.types.{DataType, ObjectType} -/** - * Simply takes the object from child and projects it as new attribute. - * The output of this operator is a single-field safe row containing the input object. 
- */ -case class ObjectProject( - outputObjAttr: Attribute, - child: SparkPlan) extends UnaryExecNode with CodegenSupport { - - override def output: Seq[Attribute] = outputObjAttr :: Nil - override def producedAttributes: AttributeSet = AttributeSet(outputObjAttr) - - override def inputRDDs(): Seq[RDD[InternalRow]] = { - child.asInstanceOf[CodegenSupport].inputRDDs() - } - - protected override def doProduce(ctx: CodegenContext): String = { - child.asInstanceOf[CodegenSupport].produce(ctx, this) - } - - override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { - consume(ctx, input) - } - - override protected def doExecute(): RDD[InternalRow] = child.execute() -} - /** * Takes the input row from child and turns it into object using the given deserializer expression. * The output of this operator is a single-field safe row containing the deserialized object. From 29a0c70488fc8d3f7157679d8f41f7ceb5af9bc4 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 11 May 2016 05:13:01 +0000 Subject: [PATCH 5/9] Address comment. --- .../sql/catalyst/optimizer/Optimizer.scala | 58 ++++++++++++++----- .../sql/catalyst/plans/logical/object.scala | 9 --- .../EliminateSerializationSuite.scala | 2 +- .../TypedFilterOptimizationSuite.scala | 9 +-- 4 files changed, 51 insertions(+), 27 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index f1482924c1671..bcdb0f0959ec8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -103,7 +103,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf) SimplifyCaseConversionExpressions, RewriteCorrelatedScalarSubquery, EliminateSerialization, - RemoveExtraProjectForSerialization) :: + RemoveAliasOnlyProject) :: Batch("Decimal Optimizations", fixedPoint, DecimalAggregates) :: Batch("Typed Filter Optimization", fixedPoint, @@ -157,21 +157,52 @@ object SamplePushDown extends Rule[LogicalPlan] { } /** - * Removes extra Project added in EliminateSerialization rule. + * Removes the Project only conducting Alias of its child node. + * It is created mainly for removing extra Project added in EliminateSerialization rule, + * but can also benefit other operators. */ -object RemoveExtraProjectForSerialization extends Rule[LogicalPlan] { +object RemoveAliasOnlyProject extends Rule[LogicalPlan] { + // Check if projectList in the Project node has the same attribute names and ordering + // as its child node. 
+ private def checkAliasOnly( + projectList: Seq[NamedExpression], + childOutput: Seq[Attribute]): Boolean = { + if (!projectList.forall(_.isInstanceOf[Alias]) || projectList.length != childOutput.length) { + return false + } else { + projectList.map(_.asInstanceOf[Alias]).zip(childOutput).forall { case (a, o) => + a.child match { + case attr: Attribute + if a.name == attr.name && attr.name == o.name && attr.dataType == o.dataType + && attr.exprId == o.exprId => + true + case _ => false + } + } + } + } + def apply(plan: LogicalPlan): LogicalPlan = { - val objectProject = plan.find(_.isInstanceOf[ObjectProject]).map { case o: ObjectProject => - val replaceFrom = o.outputObjAttr - val replaceTo = o.child.output.head + val processedPlan = plan.find { p => + p match { + case Project(pList, child) if checkAliasOnly(pList, child.output) => true + case _ => false + } + }.map { case p: Project => + val attrMap = p.projectList.map { a => + val alias = a.asInstanceOf[Alias] + val replaceFrom = alias.toAttribute + val replaceTo = alias.child.asInstanceOf[Attribute] + (replaceFrom, replaceTo) + }.toMap plan.transformAllExpressions { - case a: Attribute if a.equals(replaceFrom) => replaceTo + case a: Attribute if attrMap.contains(a) => attrMap(a) }.transform { - case op: ObjectProject if o == op => op.child + case op: Project if op == p => op.child } } - if (objectProject.isDefined) { - objectProject.get + if (processedPlan.isDefined) { + processedPlan.get } else { plan } @@ -186,9 +217,10 @@ object EliminateSerialization extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case d @ DeserializeToObject(_, _, s: SerializeFromObject) if d.outputObjectType == s.inputObjectType => - // Adds an extra ObjectProject here, to preserve the output expr id of `DeserializeToObject`. - // We will remove it later. - ObjectProject(d.output.head, s.child) + // Adds an extra Project here, to preserve the output expr id of `DeserializeToObject`. + // We will remove it later in RemoveAliasOnlyProject rule. + val objAttr = Alias(s.child.output.head, "obj")(exprId = d.output.head.exprId) + Project(objAttr :: Nil, s.child) case a @ AppendColumns(_, _, _, s: SerializeFromObject) if a.deserializer.dataType == s.inputObjectType => AppendColumnsWithObject(a.func, s.serializer, a.serializer, s.child) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala index 0e0a6d9ede39f..84339f439a666 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala @@ -80,15 +80,6 @@ trait ObjectConsumer extends UnaryNode { def inputObjectType: DataType = child.output.head.dataType } -/** - * Takes the object from child and projects it as new attribute. - * This logical plan is just used to preserve expr id temporarily and will be removed before - * the end of optimization phase. - */ -case class ObjectProject( - outputObjAttr: Attribute, - child: LogicalPlan) extends UnaryNode with ObjectProducer - /** * Takes the input row from child and turns it into object using the given deserializer expression. 
*/ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSerializationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSerializationSuite.scala index 19037e160aa67..3c033ddc374cf 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSerializationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSerializationSuite.scala @@ -42,7 +42,7 @@ class EliminateSerializationSuite extends PlanTest { val input = LocalRelation('obj.obj(classOf[(Int, Int)])) val plan = input.serialize[(Int, Int)].deserialize[(Int, Int)].analyze val optimized = Optimize.execute(plan) - val expected = ObjectProject(input.output.head.withNullability(false), input) + val expected = input.select('obj.as("obj")).analyze comparePlans(optimized, expected) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TypedFilterOptimizationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TypedFilterOptimizationSuite.scala index 00ad535d2223f..1fae64e3bc6b1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TypedFilterOptimizationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TypedFilterOptimizationSuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} import org.apache.spark.sql.catalyst.plans.PlanTest -import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, ObjectProject} +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.types.BooleanType @@ -47,9 +47,10 @@ class TypedFilterOptimizationSuite extends PlanTest { val query = input.filter(f1).filter(f2).analyze val optimized = Optimize.execute(query) - val deserialized = input.deserialize[(Int, Int)] - val expected = ObjectProject(deserialized.output.head, deserialized - .where(callFunction(f1, BooleanType, 'obj))) + + val expected = input.deserialize[(Int, Int)] + .where(callFunction(f1, BooleanType, 'obj)) + .select('obj.as("obj")) .where(callFunction(f2, BooleanType, 'obj)) .serialize[(Int, Int)].analyze From 85fba173b871a1c8bc24f1a781ddbf59f77db645 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 11 May 2016 08:16:20 +0000 Subject: [PATCH 6/9] Fix test. 
--- .../org/apache/spark/sql/catalyst/optimizer/Optimizer.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index bcdb0f0959ec8..377cce6caaddf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -191,12 +191,12 @@ object RemoveAliasOnlyProject extends Rule[LogicalPlan] { }.map { case p: Project => val attrMap = p.projectList.map { a => val alias = a.asInstanceOf[Alias] - val replaceFrom = alias.toAttribute + val replaceFrom = alias.toAttribute.exprId val replaceTo = alias.child.asInstanceOf[Attribute] (replaceFrom, replaceTo) }.toMap plan.transformAllExpressions { - case a: Attribute if attrMap.contains(a) => attrMap(a) + case a: Attribute if attrMap.contains(a.exprId) => attrMap(a.exprId) }.transform { case op: Project if op == p => op.child } From ea553983006f744e3c6563f9e04139ad1371f65e Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 11 May 2016 08:35:15 +0000 Subject: [PATCH 7/9] Address comment. --- .../org/apache/spark/sql/catalyst/optimizer/Optimizer.scala | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 377cce6caaddf..50225d8b37060 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -172,10 +172,7 @@ object RemoveAliasOnlyProject extends Rule[LogicalPlan] { } else { projectList.map(_.asInstanceOf[Alias]).zip(childOutput).forall { case (a, o) => a.child match { - case attr: Attribute - if a.name == attr.name && attr.name == o.name && attr.dataType == o.dataType - && attr.exprId == o.exprId => - true + case attr: Attribute if a.name == attr.name && attr.semanticEquals(o) => true case _ => false } } From c3748bac348e30dc87cb41fcdb3ae9086acec66f Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 12 May 2016 03:21:46 +0000 Subject: [PATCH 8/9] Address comments. --- .../sql/catalyst/optimizer/Optimizer.scala | 29 +++++++------------ 1 file changed, 11 insertions(+), 18 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 50225d8b37060..ad047c45844fa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -164,7 +164,7 @@ object SamplePushDown extends Rule[LogicalPlan] { object RemoveAliasOnlyProject extends Rule[LogicalPlan] { // Check if projectList in the Project node has the same attribute names and ordering // as its child node. 
- private def checkAliasOnly( + private def isAliasOnly( projectList: Seq[NamedExpression], childOutput: Seq[Attribute]): Boolean = { if (!projectList.forall(_.isInstanceOf[Alias]) || projectList.length != childOutput.length) { @@ -180,29 +180,22 @@ object RemoveAliasOnlyProject extends Rule[LogicalPlan] { } def apply(plan: LogicalPlan): LogicalPlan = { - val processedPlan = plan.find { p => + val aliasOnlyProject = plan.find { p => p match { - case Project(pList, child) if checkAliasOnly(pList, child.output) => true + case Project(pList, child) if isAliasOnly(pList, child.output) => true case _ => false } - }.map { case p: Project => - val attrMap = p.projectList.map { a => - val alias = a.asInstanceOf[Alias] - val replaceFrom = alias.toAttribute.exprId - val replaceTo = alias.child.asInstanceOf[Attribute] - (replaceFrom, replaceTo) - }.toMap + } + + aliasOnlyProject.map { case p: Project => + val aliases = p.projectList.map(_.asInstanceOf[Alias]) + val attrMap = AttributeMap(aliases.map(a => (a.toAttribute, a.child))) plan.transformAllExpressions { - case a: Attribute if attrMap.contains(a.exprId) => attrMap(a.exprId) + case a: Attribute if attrMap.contains(a) => attrMap(a) }.transform { - case op: Project if op == p => op.child + case op: Project if op.eq(p) => op.child } - } - if (processedPlan.isDefined) { - processedPlan.get - } else { - plan - } + }.getOrElse(plan) } } From 882fc666c1efb2d8313d5f3b944b779651045d59 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 12 May 2016 10:02:47 +0000 Subject: [PATCH 9/9] Address comments. --- .../org/apache/spark/sql/catalyst/optimizer/Optimizer.scala | 3 ++- .../src/test/scala/org/apache/spark/sql/DatasetSuite.scala | 6 +++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index ad047c45844fa..a8901fb39105a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -209,7 +209,8 @@ object EliminateSerialization extends Rule[LogicalPlan] { if d.outputObjectType == s.inputObjectType => // Adds an extra Project here, to preserve the output expr id of `DeserializeToObject`. // We will remove it later in RemoveAliasOnlyProject rule. 
- val objAttr = Alias(s.child.output.head, "obj")(exprId = d.output.head.exprId) + val objAttr = + Alias(s.child.output.head, s.child.output.head.name)(exprId = d.output.head.exprId) Project(objAttr :: Nil, s.child) case a @ AppendColumns(_, _, _, s: SerializeFromObject) if a.deserializer.dataType == s.inputObjectType => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 337bef32417bf..ff69e52c1400f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -662,11 +662,11 @@ class DatasetSuite extends QueryTest with SharedSQLContext { test("dataset.rdd with generic case class") { val ds = Seq(Generic(1, 1.0), Generic(2, 2.0)).toDS val ds2 = ds.map(g => Generic(g.id, g.value)) - ds.rdd.map(r => r.id).count - ds2.rdd.map(r => r.id).count + assert(ds.rdd.map(r => r.id).count === 2) + assert(ds2.rdd.map(r => r.id).count === 2) val ds3 = ds.map(g => new java.lang.Long(g.id)) - ds3.rdd.map(r => r).count + assert(ds3.rdd.map(r => r).count === 2) } test("runtime null check for RowEncoder") {
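
For context, a minimal end-to-end reproduction of the scenario this series fixes.
This is a sketch only, assuming a local Spark 2.x SparkSession; the case class and
assertions mirror the "dataset.rdd with generic case class" test added to
DatasetSuite above:

    import org.apache.spark.sql.SparkSession

    // Defined at the top level, as in DatasetSuite, so the product encoder can be derived.
    case class Generic[T](id: T, value: Double)

    object GenericDatasetRddRepro {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[2]").appName("repro").getOrCreate()
        import spark.implicits._

        // Calling .rdd on a Dataset of a generic case class used to plan an extra
        // row-based Project over an ObjectType column and fail at runtime; with the
        // serialize/deserialize pair eliminated, both counts succeed.
        val ds = Seq(Generic(1, 1.0), Generic(2, 2.0)).toDS()
        val ds2 = ds.map(g => Generic(g.id, g.value))
        assert(ds.rdd.map(_.id).count() == 2)
        assert(ds2.rdd.map(_.id).count() == 2)

        spark.stop()
      }
    }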