From e8b8240d389782bfc0e75cbe1797ce5aecc47092 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Tue, 18 Aug 2015 13:19:39 -0700 Subject: [PATCH 1/3] [SPARK-10093][SQL] Avoid transformation on executors. --- .../apache/spark/sql/execution/basicOperators.scala | 12 +++++++----- .../scala/org/apache/spark/sql/DataFrameSuite.scala | 9 +++++++++ 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 77b98064a9e16..3f68b05a24f44 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -75,14 +75,16 @@ case class TungstenProject(projectList: Seq[NamedExpression], child: SparkPlan) override def output: Seq[Attribute] = projectList.map(_.toAttribute) + /** Rewrite the project list to use unsafe expressions as needed. */ + protected val unsafeProjectList = projectList.map(_ transform { + case CreateStruct(children) => CreateStructUnsafe(children) + case CreateNamedStruct(children) => CreateNamedStructUnsafe(children) + }) + protected override def doExecute(): RDD[InternalRow] = { val numRows = longMetric("numRows") child.execute().mapPartitions { iter => - this.transformAllExpressions { - case CreateStruct(children) => CreateStructUnsafe(children) - case CreateNamedStruct(children) => CreateNamedStructUnsafe(children) - } - val project = UnsafeProjection.create(projectList, child.output) + val project = UnsafeProjection.create(unsafeProjectList, child.output) iter.map { row => numRows += 1 project(row) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 1e2aaae52c9e8..284fff184085a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -878,4 +878,13 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { val df = Seq(("x", (1, 1)), ("y", (2, 2))).toDF("a", "b") checkAnswer(df.groupBy("b._1").agg(sum("b._2")), Row(1, 1) :: Row(2, 2) :: Nil) } + + test("SPARK-10093: Avoid transformations on executors") { + val df = Seq((1, 1)).toDF("a", "b") + df.where($"a" === 1) + .select($"a", $"b", struct($"b")) + .orderBy("a") + .select(struct($"b")) + .collect() + } } From 836706bfe2d82e1bc4ce0019c490196ad975b72f Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 18 Aug 2015 18:37:15 -0700 Subject: [PATCH 2/3] Added DataFrameComplexTypeSuite. --- .../expressions/complexTypeCreator.scala | 8 +++- .../spark/sql/DataFrameComplexTypeSuite.scala | 43 +++++++++++++++++++ 2 files changed, 49 insertions(+), 2 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 298aee3499275..1c546719730b7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -206,7 +206,9 @@ case class CreateStructUnsafe(children: Seq[Expression]) extends Expression { override def nullable: Boolean = false - override def eval(input: InternalRow): Any = throw new UnsupportedOperationException + override def eval(input: InternalRow): Any = { + InternalRow(children.map(_.eval(input)): _*) + } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val eval = GenerateUnsafeProjection.createCode(ctx, children) @@ -244,7 +246,9 @@ case class CreateNamedStructUnsafe(children: Seq[Expression]) extends Expression override def nullable: Boolean = false - override def eval(input: InternalRow): Any = throw new UnsupportedOperationException + override def eval(input: InternalRow): Any = { + InternalRow(valExprs.map(_.eval(input)): _*) + } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val eval = GenerateUnsafeProjection.createCode(ctx, valExprs) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala new file mode 100644 index 0000000000000..4074b9e620a97 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext + +class DataFrameComplexTypeSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + test("UDF on struct") { + val f = udf((a: String) => a) + val df = sqlContext.sparkContext.parallelize(Seq((1, 1))).toDF("a", "b") + df.select(struct($"a").as("s")).select(f($"s.a")).collect() + } + + test("UDF on named_struct") { + val f = udf((a: String) => a) + val df = sqlContext.sparkContext.parallelize(Seq((1, 1))).toDF("a", "b") + df.selectExpr("named_struct('a', a) s").select(f($"s.a")).collect() + } + + test("UDF on array") { + val f = udf((a: String) => a) + val df = sqlContext.sparkContext.parallelize(Seq((1, 1))).toDF("a", "b") + df.select(array($"a").as("s")).select(f(expr("s[0]"))).collect() + } +} From a3e4758950a5e3c1f686ffecf43d24f3456caad7 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 18 Aug 2015 18:38:03 -0700 Subject: [PATCH 3/3] Updated documentation. --- .../scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala | 3 +++ 1 file changed, 3 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala index 4074b9e620a97..3c359dd840ab7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala @@ -20,6 +20,9 @@ package org.apache.spark.sql import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext +/** + * A test suite to test DataFrame/SQL functionalities with complex types (i.e. array, struct, map). + */ class DataFrameComplexTypeSuite extends QueryTest with SharedSQLContext { import testImplicits._