From 7740f5cb05babb6f3aa074786688378d045e51a6 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Mon, 14 Sep 2015 18:37:32 -0700 Subject: [PATCH 1/8] Rewrite FilterNodeSuite to use LocalRelations This commit refactors DummyNode to take in data from LocalRelation. Then it rewrites FilterNodeSuite to make it read from DummyNode instead of from a DataFrame. Future commits will cover other LocalNode test suites. --- .../spark/sql/execution/local/LocalNode.scala | 8 +- .../spark/sql/execution/SparkPlanTest.scala | 2 +- .../spark/sql/execution/local/DummyNode.scala | 68 +++++++++++++++++ .../sql/execution/local/FilterNodeSuite.scala | 44 +++++++---- .../sql/execution/local/LocalNodeSuite.scala | 74 +++++-------------- .../sql/execution/local/LocalNodeTest.scala | 24 +++++- 6 files changed, 140 insertions(+), 80 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/local/DummyNode.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala index e540ef8555eb6..0f16944e36329 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.{SQLConf, Row} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection -import org.apache.spark.sql.catalyst.trees.TreeNode +import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.types.StructType /** @@ -33,18 +33,14 @@ import org.apache.spark.sql.types.StructType * Before consuming the iterator, open function must be called. * After consuming the iterator, close function must be called. */ -abstract class LocalNode(conf: SQLConf) extends TreeNode[LocalNode] with Logging { +abstract class LocalNode(conf: SQLConf) extends QueryPlan[LocalNode] with Logging { protected val codegenEnabled: Boolean = conf.codegenEnabled protected val unsafeEnabled: Boolean = conf.unsafeEnabled - lazy val schema: StructType = StructType.fromAttributes(output) - private[this] lazy val isTesting: Boolean = sys.props.contains("spark.testing") - def output: Seq[Attribute] - /** * Initializes the iterator state. Must be called before calling `next()`. 
* diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala index de45ae4635fb7..3d218f01c9ead 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala @@ -238,7 +238,7 @@ object SparkPlanTest { outputPlan transform { case plan: SparkPlan => val inputMap = plan.children.flatMap(_.output).map(a => (a.name, a)).toMap - plan.transformExpressions { + plan transformExpressions { case UnresolvedAttribute(Seq(u)) => inputMap.getOrElse(u, sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/DummyNode.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/DummyNode.scala new file mode 100644 index 0000000000000..efc3227dd60d8 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/DummyNode.scala @@ -0,0 +1,68 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.local + +import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical.LocalRelation + +/** + * A dummy [[LocalNode]] that just returns rows from a [[LocalRelation]]. 
+ */ +private[local] case class DummyNode( + output: Seq[Attribute], + relation: LocalRelation, + conf: SQLConf) + extends LocalNode(conf) { + + import DummyNode._ + + private var index: Int = CLOSED + private val input: Seq[InternalRow] = relation.data + + def this(output: Seq[Attribute], data: Seq[Product], conf: SQLConf = new SQLConf) { + this(output, LocalRelation.fromProduct(output, data), conf) + } + + def isOpen: Boolean = index != CLOSED + + override def children: Seq[LocalNode] = Seq.empty + + override def open(): Unit = { + index = -1 + } + + override def next(): Boolean = { + index += 1 + index < input.size + } + + override def fetch(): InternalRow = { + assert(index >= 0 && index < input.size) + input(index) + } + + override def close(): Unit = { + index = CLOSED + } +} + +private object DummyNode { + val CLOSED: Int = Int.MinValue +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/FilterNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/FilterNodeSuite.scala index a12670e347c25..b3c215bdd5b58 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/FilterNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/FilterNodeSuite.scala @@ -17,25 +17,41 @@ package org.apache.spark.sql.execution.local -import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.AttributeReference +import org.apache.spark.sql.types.IntegerType -class FilterNodeSuite extends LocalNodeTest with SharedSQLContext { + +class FilterNodeSuite extends LocalNodeTest { + private val attributes = Seq( + AttributeReference("k", IntegerType)(), + AttributeReference("v", IntegerType)()) test("basic") { - val condition = (testData.col("key") % 2) === 0 - checkAnswer( - testData, - node => FilterNode(conf, condition.expr, node), - testData.filter(condition).collect() - ) + val n = 100 + val cond = 'k % 2 === 0 + val inputData = (1 to n).map { i => (i, i) }.toArray + val inputNode = new DummyNode(attributes, inputData) + val filterNode = new FilterNode(conf, cond, inputNode) + val resolvedNode = resolveExpressions(filterNode) + val expectedOutput = inputData.filter { case (k, _) => k % 2 == 0 } + val actualOutput = resolvedNode.collect().map { case row => + (row.getInt(0), row.getInt(1)) + } + assert(actualOutput === expectedOutput) } test("empty") { - val condition = (emptyTestData.col("key") % 2) === 0 - checkAnswer( - emptyTestData, - node => FilterNode(conf, condition.expr, node), - emptyTestData.filter(condition).collect() - ) + val cond = 'k % 2 === 0 + val inputData = Array.empty[(Int, Int)] + val inputNode = new DummyNode(attributes, inputData) + val filterNode = new FilterNode(conf, cond, inputNode) + val resolvedNode = resolveExpressions(filterNode) + val expectedOutput = inputData + val actualOutput = resolvedNode.collect().map { case row => + (row.getInt(0), row.getInt(1)) + } + assert(actualOutput === expectedOutput) } + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeSuite.scala index b89fa46f8b3b4..01cee7f3dbe8d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeSuite.scala @@ -18,27 +18,30 @@ package org.apache.spark.sql.execution.local import org.apache.spark.SparkFunSuite 
-import org.apache.spark.sql.SQLConf -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.types.IntegerType + class LocalNodeSuite extends SparkFunSuite { - private val data = (1 to 100).toArray + private val attributes = Seq( + AttributeReference("k", IntegerType)(), + AttributeReference("v", IntegerType)()) + private val data = (1 to 100).map { i => (i, i) }.toArray test("basic open, next, fetch, close") { - val node = new DummyLocalNode(data) + val node = new DummyNode(attributes, data) assert(!node.isOpen) node.open() assert(node.isOpen) - data.foreach { i => + data.foreach { case (k, v) => assert(node.next()) // fetch should be idempotent val fetched = node.fetch() assert(node.fetch() === fetched) assert(node.fetch() === fetched) - assert(node.fetch().numFields === 1) - assert(node.fetch().getInt(0) === i) + assert(node.fetch().numFields === 2) + assert(node.fetch().getInt(0) === k) + assert(node.fetch().getInt(1) === v) } assert(!node.next()) node.close() @@ -46,16 +49,17 @@ class LocalNodeSuite extends SparkFunSuite { } test("asIterator") { - val node = new DummyLocalNode(data) + val node = new DummyNode(attributes, data) val iter = node.asIterator node.open() - data.foreach { i => + data.foreach { case (k, v) => // hasNext should be idempotent assert(iter.hasNext) assert(iter.hasNext) val item = iter.next() - assert(item.numFields === 1) - assert(item.getInt(0) === i) + assert(item.numFields === 2) + assert(item.getInt(0) === k) + assert(item.getInt(1) === v) } intercept[NoSuchElementException] { iter.next() @@ -64,53 +68,13 @@ class LocalNodeSuite extends SparkFunSuite { } test("collect") { - val node = new DummyLocalNode(data) + val node = new DummyNode(attributes, data) node.open() val collected = node.collect() assert(collected.size === data.size) - assert(collected.forall(_.size === 1)) - assert(collected.map(_.getInt(0)) === data) + assert(collected.forall(_.size === 2)) + assert(collected.map { case row => (row.getInt(0), row.getInt(0)) } === data) node.close() } } - -/** - * A dummy [[LocalNode]] that just returns one row per integer in the input. 
- */ -private case class DummyLocalNode(conf: SQLConf, input: Array[Int]) extends LocalNode(conf) { - private var index = Int.MinValue - - def this(input: Array[Int]) { - this(new SQLConf, input) - } - - def isOpen: Boolean = { - index != Int.MinValue - } - - override def output: Seq[Attribute] = { - Seq(AttributeReference("something", IntegerType)()) - } - - override def children: Seq[LocalNode] = Seq.empty - - override def open(): Unit = { - index = -1 - } - - override def next(): Boolean = { - index += 1 - index < input.size - } - - override def fetch(): InternalRow = { - assert(index >= 0 && index < input.size) - val values = Array(input(index).asInstanceOf[Any]) - new GenericInternalRow(values) - } - - override def close(): Unit = { - index = Int.MinValue - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala index b95d4ea7f8f2a..65c4589377063 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala @@ -22,6 +22,7 @@ import scala.util.control.NonFatal import org.apache.spark.SparkFunSuite import org.apache.spark.sql.{DataFrame, Row, SQLConf} import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute class LocalNodeTest extends SparkFunSuite with SharedSQLContext { @@ -99,6 +100,21 @@ class LocalNodeTest extends SparkFunSuite with SharedSQLContext { df.queryExecution.toRdd.map(_.copy()).collect()) } + /** + * Recursively resolve all expressions in a [[LocalNode]] using the node's attributes. + */ + protected def resolveExpressions(outputNode: LocalNode): LocalNode = { + outputNode transform { + case node: LocalNode => + val inputMap = node.output.map { a => (a.name, a) }.toMap + node transformExpressions { + case UnresolvedAttribute(Seq(u)) => + inputMap.getOrElse(u, + sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap")) + } + } + } + } /** @@ -116,10 +132,10 @@ object LocalNodeTest { * to being compared. 
*/ def checkAnswer( - input: Seq[SeqScanNode], - nodeFunction: Seq[LocalNode] => LocalNode, - expectedAnswer: Seq[Row], - sortAnswers: Boolean): Option[String] = { + input: Seq[SeqScanNode], + nodeFunction: Seq[LocalNode] => LocalNode, + expectedAnswer: Seq[Row], + sortAnswers: Boolean): Option[String] = { val outputNode = nodeFunction(input) From 10fc10972300aa6fcae9110e0b96055611a606f0 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Mon, 14 Sep 2015 23:42:16 -0700 Subject: [PATCH 2/8] Intersect, Project, and Limit --- .../sql/execution/local/FilterNodeSuite.scala | 26 ++++--------- .../execution/local/IntersectNodeSuite.scala | 24 ++++++------ .../sql/execution/local/LimitNodeSuite.scala | 28 ++++++------- .../sql/execution/local/LocalNodeSuite.scala | 15 ++----- .../sql/execution/local/LocalNodeTest.scala | 6 +++ .../execution/local/ProjectNodeSuite.scala | 39 +++++++++++-------- 6 files changed, 67 insertions(+), 71 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/FilterNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/FilterNodeSuite.scala index b3c215bdd5b58..4eadce646d379 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/FilterNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/FilterNodeSuite.scala @@ -18,20 +18,13 @@ package org.apache.spark.sql.execution.local import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.expressions.AttributeReference -import org.apache.spark.sql.types.IntegerType class FilterNodeSuite extends LocalNodeTest { - private val attributes = Seq( - AttributeReference("k", IntegerType)(), - AttributeReference("v", IntegerType)()) - test("basic") { - val n = 100 + private def testFilter(inputData: Array[(Int, Int)] = Array.empty): Unit = { val cond = 'k % 2 === 0 - val inputData = (1 to n).map { i => (i, i) }.toArray - val inputNode = new DummyNode(attributes, inputData) + val inputNode = new DummyNode(kvIntAttributes, inputData) val filterNode = new FilterNode(conf, cond, inputNode) val resolvedNode = resolveExpressions(filterNode) val expectedOutput = inputData.filter { case (k, _) => k % 2 == 0 } @@ -42,16 +35,11 @@ class FilterNodeSuite extends LocalNodeTest { } test("empty") { - val cond = 'k % 2 === 0 - val inputData = Array.empty[(Int, Int)] - val inputNode = new DummyNode(attributes, inputData) - val filterNode = new FilterNode(conf, cond, inputNode) - val resolvedNode = resolveExpressions(filterNode) - val expectedOutput = inputData - val actualOutput = resolvedNode.collect().map { case row => - (row.getInt(0), row.getInt(1)) - } - assert(actualOutput === expectedOutput) + testFilter() + } + + test("basic") { + testFilter((1 to 100).map { i => (i, i) }.toArray) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/IntersectNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/IntersectNodeSuite.scala index 7deaa375fcfc2..c0ad2021b204a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/IntersectNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/IntersectNodeSuite.scala @@ -17,19 +17,21 @@ package org.apache.spark.sql.execution.local -class IntersectNodeSuite extends LocalNodeTest { - import testImplicits._ +class IntersectNodeSuite extends LocalNodeTest { test("basic") { - val input1 = (1 to 10).map(i => (i, i.toString)).toDF("key", "value") - val input2 = (1 to 10).filter(_ % 2 == 0).map(i => 
(i, i.toString)).toDF("key", "value") - - checkAnswer2( - input1, - input2, - (node1, node2) => IntersectNode(conf, node1, node2), - input1.intersect(input2).collect() - ) + val n = 100 + val leftData = (1 to n).filter { i => i % 2 == 0 }.map { i => (i, i) }.toArray + val rightData = (1 to n).filter { i => i % 3 == 0 }.map { i => (i, i) }.toArray + val leftNode = new DummyNode(kvIntAttributes, leftData) + val rightNode = new DummyNode(kvIntAttributes, rightData) + val intersectNode = new IntersectNode(conf, leftNode, rightNode) + val expectedOutput = leftData.intersect(rightData) + val actualOutput = intersectNode.collect().map { case row => + (row.getInt(0), row.getInt(1)) + } + assert(actualOutput === expectedOutput) } + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LimitNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LimitNodeSuite.scala index 3b183902007e4..fb790636a3689 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LimitNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LimitNodeSuite.scala @@ -17,23 +17,25 @@ package org.apache.spark.sql.execution.local -import org.apache.spark.sql.test.SharedSQLContext -class LimitNodeSuite extends LocalNodeTest with SharedSQLContext { +class LimitNodeSuite extends LocalNodeTest { - test("basic") { - checkAnswer( - testData, - node => LimitNode(conf, 10, node), - testData.limit(10).collect() - ) + private def testLimit(inputData: Array[(Int, Int)] = Array.empty, limit: Int = 10): Unit = { + val inputNode = new DummyNode(kvIntAttributes, inputData) + val limitNode = new LimitNode(conf, limit, inputNode) + val expectedOutput = inputData.take(limit) + val actualOutput = limitNode.collect().map { case row => + (row.getInt(0), row.getInt(1)) + } + assert(actualOutput === expectedOutput) } test("empty") { - checkAnswer( - emptyTestData, - node => LimitNode(conf, 10, node), - emptyTestData.limit(10).collect() - ) + testLimit() } + + test("basic") { + testLimit((1 to 100).map { i => (i, i) }.toArray, 20) + } + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeSuite.scala index 01cee7f3dbe8d..0d1ed99eec6cd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeSuite.scala @@ -17,19 +17,12 @@ package org.apache.spark.sql.execution.local -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.expressions.AttributeReference -import org.apache.spark.sql.types.IntegerType - -class LocalNodeSuite extends SparkFunSuite { - private val attributes = Seq( - AttributeReference("k", IntegerType)(), - AttributeReference("v", IntegerType)()) +class LocalNodeSuite extends LocalNodeTest { private val data = (1 to 100).map { i => (i, i) }.toArray test("basic open, next, fetch, close") { - val node = new DummyNode(attributes, data) + val node = new DummyNode(kvIntAttributes, data) assert(!node.isOpen) node.open() assert(node.isOpen) @@ -49,7 +42,7 @@ class LocalNodeSuite extends SparkFunSuite { } test("asIterator") { - val node = new DummyNode(attributes, data) + val node = new DummyNode(kvIntAttributes, data) val iter = node.asIterator node.open() data.foreach { case (k, v) => @@ -68,7 +61,7 @@ class LocalNodeSuite extends SparkFunSuite { } test("collect") { - val node = new DummyNode(attributes, data) + 
val node = new DummyNode(kvIntAttributes, data) node.open() val collected = node.collect() assert(collected.size === data.size) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala index 65c4589377063..2816eb5612659 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala @@ -23,11 +23,17 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.{DataFrame, Row, SQLConf} import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.expressions.AttributeReference +import org.apache.spark.sql.types.IntegerType class LocalNodeTest extends SparkFunSuite with SharedSQLContext { def conf: SQLConf = sqlContext.conf + protected val kvIntAttributes = Seq( + AttributeReference("k", IntegerType)(), + AttributeReference("v", IntegerType)()) + /** * Runs the LocalNode and makes sure the answer matches the expected result. * @param input the input data to be used. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ProjectNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ProjectNodeSuite.scala index 38e0a230c46d8..02ecb23d34b2f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ProjectNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ProjectNodeSuite.scala @@ -17,28 +17,33 @@ package org.apache.spark.sql.execution.local -import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, NamedExpression} +import org.apache.spark.sql.types.{IntegerType, StringType} -class ProjectNodeSuite extends LocalNodeTest with SharedSQLContext { - test("basic") { - val output = testData.queryExecution.sparkPlan.output - val columns = Seq(output(1), output(0)) - checkAnswer( - testData, - node => ProjectNode(conf, columns, node), - testData.select("value", "key").collect() - ) +class ProjectNodeSuite extends LocalNodeTest { + private val pieAttributes = Seq( + AttributeReference("id", IntegerType)(), + AttributeReference("age", IntegerType)(), + AttributeReference("name", StringType)()) + + private def testProject(inputData: Array[(Int, Int, String)] = Array.empty): Unit = { + val inputNode = new DummyNode(pieAttributes, inputData) + val columns = Seq[NamedExpression](inputNode.output(0), inputNode.output(2)) + val projectNode = new ProjectNode(conf, columns, inputNode) + val expectedOutput = inputData.map { case (id, age, name) => (id, name) } + val actualOutput = projectNode.collect().map { case row => + (row.getInt(0), row.getString(1)) + } + assert(actualOutput === expectedOutput) } test("empty") { - val output = emptyTestData.queryExecution.sparkPlan.output - val columns = Seq(output(1), output(0)) - checkAnswer( - emptyTestData, - node => ProjectNode(conf, columns, node), - emptyTestData.select("value", "key").collect() - ) + testProject() + } + + test("basic") { + testProject((1 to 100).map { i => (i, i + 1, "pie" + i) }.toArray) } } From a93a2603ce19b60b13a5cd7235d22b28620caea1 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Tue, 15 Sep 2015 01:13:47 -0700 Subject: [PATCH 3/8] TakeOrderedAndProject + Sample --- .../sql/execution/local/SampleNode.scala | 18 +++---- 
.../sql/execution/local/SampleNodeSuite.scala | 35 +++++++----- .../TakeOrderedAndProjectNodeSuite.scala | 53 +++++++++---------- .../sql/execution/local/UnionNodeSuite.scala | 49 +++++++++-------- 4 files changed, 82 insertions(+), 73 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SampleNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SampleNode.scala index abf3df1c0c2af..59192de0ad8e3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SampleNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SampleNode.scala @@ -17,12 +17,11 @@ package org.apache.spark.sql.execution.local -import java.util.Random - import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.util.random.{BernoulliCellSampler, PoissonSampler} +import org.apache.spark.util.random.{BernoulliCellSampler, PoissonSampler, RandomSampler} + /** * Sample the dataset. @@ -51,18 +50,15 @@ case class SampleNode( override def open(): Unit = { child.open() - val (sampler, _seed) = if (withReplacement) { - val random = new Random(seed) + val sampler = + if (withReplacement) { // Disable gap sampling since the gap sampling method buffers two rows internally, // requiring us to copy the row, which is more expensive than the random number generator. - (new PoissonSampler[InternalRow](upperBound - lowerBound, useGapSamplingIfPossible = false), - // Use the seed for partition 0 like PartitionwiseSampledRDD to generate the same result - // of DataFrame - random.nextLong()) + new PoissonSampler[InternalRow](upperBound - lowerBound, useGapSamplingIfPossible = false) } else { - (new BernoulliCellSampler[InternalRow](lowerBound, upperBound), seed) + new BernoulliCellSampler[InternalRow](lowerBound, upperBound) } - sampler.setSeed(_seed) + sampler.setSeed(seed) iterator = sampler.sample(child.asIterator) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/SampleNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/SampleNodeSuite.scala index 87a7da453999c..a3e83bbd51457 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/SampleNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/SampleNodeSuite.scala @@ -17,21 +17,32 @@ package org.apache.spark.sql.execution.local -class SampleNodeSuite extends LocalNodeTest { +import org.apache.spark.util.random.{BernoulliCellSampler, PoissonSampler} + - import testImplicits._ +class SampleNodeSuite extends LocalNodeTest { private def testSample(withReplacement: Boolean): Unit = { - test(s"withReplacement: $withReplacement") { - val seed = 0L - val input = sqlContext.sparkContext. - parallelize((1 to 10).map(i => (i, i.toString)), 1). 
// Should be only 1 partition - toDF("key", "value") - checkAnswer( - input, - node => SampleNode(conf, 0.0, 0.3, withReplacement, seed, node), - input.sample(withReplacement, 0.3, seed).collect() - ) + val seed = 0L + val lowerb = 0.0 + val upperb = 0.3 + val maybeOut = if (withReplacement) "" else "out" + test(s"with$maybeOut replacement") { + val inputData = (1 to 1000).map { i => (i, i) }.toArray + val inputNode = new DummyNode(kvIntAttributes, inputData) + val sampleNode = new SampleNode(conf, lowerb, upperb, withReplacement, seed, inputNode) + val sampler = + if (withReplacement) { + new PoissonSampler[(Int, Int)](upperb - lowerb, useGapSamplingIfPossible = false) + } else { + new BernoulliCellSampler[(Int, Int)](lowerb, upperb) + } + sampler.setSeed(seed) + val expectedOutput = sampler.sample(inputData.iterator).toArray + val actualOutput = sampleNode.collect().map { case row => + (row.getInt(0), row.getInt(1)) + } + assert(actualOutput === expectedOutput) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNodeSuite.scala index ff28b24eeff14..cfc28da66e2f9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNodeSuite.scala @@ -17,38 +17,37 @@ package org.apache.spark.sql.execution.local -import org.apache.spark.sql.Column -import org.apache.spark.sql.catalyst.expressions.{Ascending, Expression, SortOrder} +import scala.util.Random -class TakeOrderedAndProjectNodeSuite extends LocalNodeTest { +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.SortOrder - import testImplicits._ - private def columnToSortOrder(sortExprs: Column*): Seq[SortOrder] = { - val sortOrder: Seq[SortOrder] = sortExprs.map { col => - col.expr match { - case expr: SortOrder => - expr - case expr: Expression => - SortOrder(expr, Ascending) - } - } - sortOrder - } +class TakeOrderedAndProjectNodeSuite extends LocalNodeTest { - private def testTakeOrderedAndProjectNode(desc: Boolean): Unit = { - val testCaseName = if (desc) "desc" else "asc" - test(testCaseName) { - val input = (1 to 10).map(i => (i, i.toString)).toDF("key", "value") - val sortColumn = if (desc) input.col("key").desc else input.col("key") - checkAnswer( - input, - node => TakeOrderedAndProjectNode(conf, 5, columnToSortOrder(sortColumn), None, node), - input.sort(sortColumn).limit(5).collect() - ) + private def testTakeOrderedAndProject(desc: Boolean): Unit = { + val limit = 10 + val ascOrDesc = if (desc) "desc" else "asc" + // TODO: re-enable me once TakeOrderedAndProjectNode can return things in sorted order. + // This test is ignored because the node currently just returns the items in the order + // maintained by the underlying min / max heap, but we expect sorted order. 
+ ignore(ascOrDesc) { + val inputData = Random.shuffle((1 to 100).toList).map { i => (i, i) }.toArray + val inputNode = new DummyNode(kvIntAttributes, inputData) + val firstColumn = inputNode.output(0) + val sortDirection = if (desc) Descending else Ascending + val sortOrder = SortOrder(firstColumn, sortDirection) + val takeOrderAndProjectNode = new TakeOrderedAndProjectNode( + conf, limit, Seq(sortOrder), Some(Seq(firstColumn)), inputNode) + val expectedOutput = inputData + .map { case (k, _) => k } + .sortBy { k => k * (if (desc) -1 else 1) } + .take(limit) + val actualOutput = takeOrderAndProjectNode.collect().map { row => row.getInt(0) } + assert(actualOutput === expectedOutput) } } - testTakeOrderedAndProjectNode(desc = false) - testTakeOrderedAndProjectNode(desc = true) + testTakeOrderedAndProject(desc = false) + testTakeOrderedAndProject(desc = true) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/UnionNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/UnionNodeSuite.scala index eedd7320900f9..666b0235c061d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/UnionNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/UnionNodeSuite.scala @@ -17,36 +17,39 @@ package org.apache.spark.sql.execution.local -import org.apache.spark.sql.test.SharedSQLContext -class UnionNodeSuite extends LocalNodeTest with SharedSQLContext { +class UnionNodeSuite extends LocalNodeTest { - test("basic") { - checkAnswer2( - testData, - testData, - (node1, node2) => UnionNode(conf, Seq(node1, node2)), - testData.unionAll(testData).collect() - ) + private def testUnion(inputData: Seq[Array[(Int, Int)]]): Unit = { + val inputNodes = inputData.map { data => + new DummyNode(kvIntAttributes, data) + } + val unionNode = new UnionNode(conf, inputNodes) + val expectedOutput = inputData.flatten + val actualOutput = unionNode.collect().map { case row => + (row.getInt(0), row.getInt(1)) + } + assert(actualOutput === expectedOutput) } test("empty") { - checkAnswer2( - emptyTestData, - emptyTestData, - (node1, node2) => UnionNode(conf, Seq(node1, node2)), - emptyTestData.unionAll(emptyTestData).collect() - ) + testUnion(Seq(Array.empty)) + testUnion(Seq(Array.empty, Array.empty)) + } + + test("self") { + val data = (1 to 100).map { i => (i, i) }.toArray + testUnion(Seq(data)) + testUnion(Seq(data, data)) + testUnion(Seq(data, data, data)) } - test("complicated union") { - val dfs = Seq(testData, emptyTestData, emptyTestData, testData, testData, emptyTestData, - emptyTestData, emptyTestData, testData, emptyTestData) - doCheckAnswer( - dfs, - nodes => UnionNode(conf, nodes), - dfs.reduce(_.unionAll(_)).collect() - ) + test("basic") { + val zero = Array.empty[(Int, Int)] + val one = (1 to 100).map { i => (i, i) }.toArray + val two = (50 to 150).map { i => (i, i) }.toArray + val three = (800 to 900).map { i => (i, i) }.toArray + testUnion(Seq(zero, one, two, three)) } } From 473e3eba60034d04ffd4ccb7f31ed2b148c84c64 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Tue, 15 Sep 2015 12:04:14 -0700 Subject: [PATCH 4/8] HashJoinNodeSuite --- .../execution/local/HashJoinNodeSuite.scala | 154 +++++++++--------- .../sql/execution/local/LocalNodeTest.scala | 14 +- 2 files changed, 78 insertions(+), 90 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala index 78d891351f4a9..cd79d4a5acb9c 
100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala @@ -18,99 +18,91 @@ package org.apache.spark.sql.execution.local import org.apache.spark.sql.SQLConf -import org.apache.spark.sql.execution.joins +import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide} +import org.apache.spark.sql.catalyst.expressions.AttributeReference +import org.apache.spark.sql.types.{IntegerType, StringType} + class HashJoinNodeSuite extends LocalNodeTest { + private val names = Seq( + AttributeReference("id", IntegerType)(), + AttributeReference("name", StringType)()) + private val orders = Seq( + AttributeReference("id", IntegerType)(), + AttributeReference("orders", IntegerType)()) - import testImplicits._ + /** + * Test inner hash join with varying degrees of matches. + */ + private def testJoin( + testNamePrefix: String, + buildSide: BuildSide, + unsafeAndCodegen: Boolean): Unit = { + val leftData = (1 to 100).map { i => (i, "burger" + i) }.toArray + val conf = new SQLConf + conf.setConf(SQLConf.UNSAFE_ENABLED, unsafeAndCodegen) + conf.setConf(SQLConf.CODEGEN_ENABLED, unsafeAndCodegen) - def joinSuite(suiteName: String, confPairs: (String, String)*): Unit = { - test(s"$suiteName: inner join with one match per row") { - withSQLConf(confPairs: _*) { - checkAnswer2( - upperCaseData, - lowerCaseData, - wrapForUnsafe( - (node1, node2) => HashJoinNode( - conf, - Seq(upperCaseData.col("N").expr), - Seq(lowerCaseData.col("n").expr), - joins.BuildLeft, - node1, - node2) - ), - upperCaseData.join(lowerCaseData, $"n" === $"N").collect() - ) + def runTest( + testName: String, + leftData: Array[(Int, String)], + rightData: Array[(Int, Int)]): Unit = { + test(testName) { + val rightDataMap = rightData.toMap + val leftNode = new DummyNode(names, leftData) + val rightNode = new DummyNode(orders, rightData) + val makeNode = (node1: LocalNode, node2: LocalNode) => { + new HashJoinNode( + conf, Seq(node1.output(0)), Seq(node2.output(0)), buildSide, node1, node2) + } + val makeUnsafeNode = if (unsafeAndCodegen) wrapForUnsafe(makeNode) else makeNode + val hashJoinNode = makeUnsafeNode(leftNode, rightNode) + val expectedOutput = leftData + .filter { case (k, _) => rightDataMap.contains(k) } + .map { case (k, v) => (k, v, k, rightDataMap(k)) } + val actualOutput = hashJoinNode.collect().map { row => + // (id, name, id, order) + (row.getInt(0), row.getString(1), row.getInt(2), row.getInt(3)) + } + assert(actualOutput === expectedOutput) } } - test(s"$suiteName: inner join with multiple matches") { - withSQLConf(confPairs: _*) { - val x = testData2.where($"a" === 1).as("x") - val y = testData2.where($"a" === 1).as("y") - checkAnswer2( - x, - y, - wrapForUnsafe( - (node1, node2) => HashJoinNode( - conf, - Seq(x.col("a").expr), - Seq(y.col("a").expr), - joins.BuildLeft, - node1, - node2) - ), - x.join(y).where($"x.a" === $"y.a").collect() - ) - } - } + runTest( + s"$testNamePrefix: empty", + leftData, + Array.empty[(Int, Int)]) - test(s"$suiteName: inner join, no matches") { - withSQLConf(confPairs: _*) { - val x = testData2.where($"a" === 1).as("x") - val y = testData2.where($"a" === 2).as("y") - checkAnswer2( - x, - y, - wrapForUnsafe( - (node1, node2) => HashJoinNode( - conf, - Seq(x.col("a").expr), - Seq(y.col("a").expr), - joins.BuildLeft, - node1, - node2) - ), - Nil - ) - } - } + runTest( + s"$testNamePrefix: no matches", + leftData, + (10000 to 100100).map { i => (i, i) 
}.toArray) + + runTest( + s"$testNamePrefix: one match per row", + leftData, + (50 to 100).map { i => (i, i * 1000) }.toArray) - test(s"$suiteName: big inner join, 4 matches per row") { - withSQLConf(confPairs: _*) { - val bigData = testData.unionAll(testData).unionAll(testData).unionAll(testData) - val bigDataX = bigData.as("x") - val bigDataY = bigData.as("y") + runTest( + s"$testNamePrefix: multiple matches per row", + leftData, + (1 to 100) + .flatMap { i => Seq(i, i / 2, i / 3, i / 5, i / 8) } + .distinct + .map { i => (i, i) } + .toArray) + } - checkAnswer2( - bigDataX, - bigDataY, - wrapForUnsafe( - (node1, node2) => - HashJoinNode( - conf, - Seq(bigDataX.col("key").expr), - Seq(bigDataY.col("key").expr), - joins.BuildLeft, - node1, - node2) - ), - bigDataX.join(bigDataY).where($"x.key" === $"y.key").collect()) + // Test all combinations of build sides and whether unsafe is enabled + Seq(false, true).foreach { unsafeAndCodegen => + val simpleOrUnsafe = if (unsafeAndCodegen) "unsafe" else "simple" + Seq(BuildLeft, BuildRight).foreach { buildSide => + val leftOrRight = buildSide match { + case BuildLeft => "left" + case BuildRight => "right" } + testJoin(s"$simpleOrUnsafe (build $leftOrRight)", buildSide, unsafeAndCodegen) } } - joinSuite( - "general", SQLConf.CODEGEN_ENABLED.key -> "false", SQLConf.UNSAFE_ENABLED.key -> "false") - joinSuite("tungsten", SQLConf.CODEGEN_ENABLED.key -> "true", SQLConf.UNSAFE_ENABLED.key -> "true") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala index 6b03667817695..2124ba1a038c8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala @@ -36,15 +36,11 @@ class LocalNodeTest extends SparkFunSuite with SharedSQLContext { protected def wrapForUnsafe( f: (LocalNode, LocalNode) => LocalNode): (LocalNode, LocalNode) => LocalNode = { - if (conf.unsafeEnabled) { - (left: LocalNode, right: LocalNode) => { - val _left = ConvertToUnsafeNode(conf, left) - val _right = ConvertToUnsafeNode(conf, right) - val r = f(_left, _right) - ConvertToSafeNode(conf, r) - } - } else { - f + (left: LocalNode, right: LocalNode) => { + val _left = ConvertToUnsafeNode(conf, left) + val _right = ConvertToUnsafeNode(conf, right) + val r = f(_left, _right) + ConvertToSafeNode(conf, r) } } From 8364d237ecd4101f9bf5a68c9cd4ff2bb9550363 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Tue, 15 Sep 2015 15:18:08 -0700 Subject: [PATCH 5/8] NestedLoopJoinNodeSuite --- .../execution/local/HashJoinNodeSuite.scala | 119 +++---- .../sql/execution/local/LocalNodeTest.scala | 10 +- .../local/NestedLoopJoinNodeSuite.scala | 316 ++++++------------ 3 files changed, 174 insertions(+), 271 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala index cd79d4a5acb9c..febc753886331 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala @@ -18,90 +18,79 @@ package org.apache.spark.sql.execution.local import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide} -import 
org.apache.spark.sql.catalyst.expressions.AttributeReference -import org.apache.spark.sql.types.{IntegerType, StringType} class HashJoinNodeSuite extends LocalNodeTest { - private val names = Seq( - AttributeReference("id", IntegerType)(), - AttributeReference("name", StringType)()) - private val orders = Seq( - AttributeReference("id", IntegerType)(), - AttributeReference("orders", IntegerType)()) + + // Test all combinations of the two dimensions: with/out unsafe and build sides + val maybeUnsafeAndCodegen = Seq(false, true) + val buildSides = Seq(BuildLeft, BuildRight) + maybeUnsafeAndCodegen.foreach { unsafeAndCodegen => + buildSides.foreach { buildSide => + testJoin(unsafeAndCodegen, buildSide) + } + } /** * Test inner hash join with varying degrees of matches. */ private def testJoin( - testNamePrefix: String, - buildSide: BuildSide, - unsafeAndCodegen: Boolean): Unit = { - val leftData = (1 to 100).map { i => (i, "burger" + i) }.toArray + unsafeAndCodegen: Boolean, + buildSide: BuildSide): Unit = { + val simpleOrUnsafe = if (!unsafeAndCodegen) "simple" else "unsafe" + val testNamePrefix = s"$simpleOrUnsafe / $buildSide" + val someData = (1 to 100).map { i => (i, "burger" + i) }.toArray val conf = new SQLConf conf.setConf(SQLConf.UNSAFE_ENABLED, unsafeAndCodegen) conf.setConf(SQLConf.CODEGEN_ENABLED, unsafeAndCodegen) - def runTest( - testName: String, - leftData: Array[(Int, String)], - rightData: Array[(Int, Int)]): Unit = { - test(testName) { - val rightDataMap = rightData.toMap - val leftNode = new DummyNode(names, leftData) - val rightNode = new DummyNode(orders, rightData) - val makeNode = (node1: LocalNode, node2: LocalNode) => { - new HashJoinNode( - conf, Seq(node1.output(0)), Seq(node2.output(0)), buildSide, node1, node2) - } - val makeUnsafeNode = if (unsafeAndCodegen) wrapForUnsafe(makeNode) else makeNode - val hashJoinNode = makeUnsafeNode(leftNode, rightNode) - val expectedOutput = leftData - .filter { case (k, _) => rightDataMap.contains(k) } - .map { case (k, v) => (k, v, k, rightDataMap(k)) } - val actualOutput = hashJoinNode.collect().map { row => - // (id, name, id, order) - (row.getInt(0), row.getString(1), row.getInt(2), row.getInt(3)) - } - assert(actualOutput === expectedOutput) + // Actual test body + def runTest(leftInput: Array[(Int, String)], rightInput: Array[(Int, String)]): Unit = { + val rightInputMap = rightInput.toMap + val leftNode = new DummyNode(joinNameAttributes, leftInput) + val rightNode = new DummyNode(joinNicknameAttributes, rightInput) + val makeNode = (node1: LocalNode, node2: LocalNode) => { + resolveExpressions(new HashJoinNode( + conf, Seq('id1), Seq('id2), buildSide, node1, node2)) } + val makeUnsafeNode = if (unsafeAndCodegen) wrapForUnsafe(makeNode) else makeNode + val hashJoinNode = makeUnsafeNode(leftNode, rightNode) + val expectedOutput = leftInput + .filter { case (k, _) => rightInputMap.contains(k) } + .map { case (k, v) => (k, v, k, rightInputMap(k)) } + val actualOutput = hashJoinNode.collect().map { row => + // (id, name, id, nickname) + (row.getInt(0), row.getString(1), row.getInt(2), row.getString(3)) + } + assert(actualOutput === expectedOutput) } - runTest( - s"$testNamePrefix: empty", - leftData, - Array.empty[(Int, Int)]) - - runTest( - s"$testNamePrefix: no matches", - leftData, - (10000 to 100100).map { i => (i, i) }.toArray) + test(s"$testNamePrefix: empty") { + runTest(Array.empty, Array.empty) + runTest(someData, Array.empty) + runTest(Array.empty, someData) + } - runTest( - s"$testNamePrefix: one match per row", - 
leftData, - (50 to 100).map { i => (i, i * 1000) }.toArray) + test(s"$testNamePrefix: no matches") { + val someIrrelevantData = (10000 to 100100).map { i => (i, "piper" + i) }.toArray + runTest(someData, Array.empty) + runTest(Array.empty, someData) + runTest(someData, someIrrelevantData) + runTest(someIrrelevantData, someData) + } - runTest( - s"$testNamePrefix: multiple matches per row", - leftData, - (1 to 100) - .flatMap { i => Seq(i, i / 2, i / 3, i / 5, i / 8) } - .distinct - .map { i => (i, i) } - .toArray) - } + test(s"$testNamePrefix: partial matches") { + val someOtherData = (50 to 150).map { i => (i, "finnegan" + i) }.toArray + runTest(someData, someOtherData) + runTest(someOtherData, someData) + } - // Test all combinations of build sides and whether unsafe is enabled - Seq(false, true).foreach { unsafeAndCodegen => - val simpleOrUnsafe = if (unsafeAndCodegen) "unsafe" else "simple" - Seq(BuildLeft, BuildRight).foreach { buildSide => - val leftOrRight = buildSide match { - case BuildLeft => "left" - case BuildRight => "right" - } - testJoin(s"$simpleOrUnsafe (build $leftOrRight)", buildSide, unsafeAndCodegen) + test(s"$testNamePrefix: full matches") { + val someSuperRelevantData = someData.map { case (k, v) => (k, "cooper" + v) }.toArray + runTest(someData, someSuperRelevantData) + runTest(someSuperRelevantData, someData) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala index 2124ba1a038c8..396ff376144f6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.{DataFrame, Row, SQLConf} import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.expressions.AttributeReference -import org.apache.spark.sql.types.IntegerType +import org.apache.spark.sql.types.{IntegerType, StringType} class LocalNodeTest extends SparkFunSuite with SharedSQLContext { @@ -34,6 +34,14 @@ class LocalNodeTest extends SparkFunSuite with SharedSQLContext { AttributeReference("k", IntegerType)(), AttributeReference("v", IntegerType)()) + protected val joinNameAttributes = Seq( + AttributeReference("id1", IntegerType)(), + AttributeReference("name", StringType)()) + + protected val joinNicknameAttributes = Seq( + AttributeReference("id2", IntegerType)(), + AttributeReference("nickname", StringType)()) + protected def wrapForUnsafe( f: (LocalNode, LocalNode) => LocalNode): (LocalNode, LocalNode) => LocalNode = { (left: LocalNode, right: LocalNode) => { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala index b1ef26ba82f16..0b35df89baf16 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala @@ -18,222 +18,128 @@ package org.apache.spark.sql.execution.local import org.apache.spark.sql.SQLConf -import org.apache.spark.sql.catalyst.plans.{FullOuter, LeftOuter, RightOuter} +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter} import 
org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide} + class NestedLoopJoinNodeSuite extends LocalNodeTest { - import testImplicits._ - - private def joinSuite( - suiteName: String, buildSide: BuildSide, confPairs: (String, String)*): Unit = { - test(s"$suiteName: left outer join") { - withSQLConf(confPairs: _*) { - checkAnswer2( - upperCaseData, - lowerCaseData, - wrapForUnsafe( - (node1, node2) => NestedLoopJoinNode( - conf, - node1, - node2, - buildSide, - LeftOuter, - Some((upperCaseData.col("N") === lowerCaseData.col("n")).expr)) - ), - upperCaseData.join(lowerCaseData, $"n" === $"N", "left").collect()) - - checkAnswer2( - upperCaseData, - lowerCaseData, - wrapForUnsafe( - (node1, node2) => NestedLoopJoinNode( - conf, - node1, - node2, - buildSide, - LeftOuter, - Some( - (upperCaseData.col("N") === lowerCaseData.col("n") && - lowerCaseData.col("n") > 1).expr)) - ), - upperCaseData.join(lowerCaseData, $"n" === $"N" && $"n" > 1, "left").collect()) - - checkAnswer2( - upperCaseData, - lowerCaseData, - wrapForUnsafe( - (node1, node2) => NestedLoopJoinNode( - conf, - node1, - node2, - buildSide, - LeftOuter, - Some( - (upperCaseData.col("N") === lowerCaseData.col("n") && - upperCaseData.col("N") > 1).expr)) - ), - upperCaseData.join(lowerCaseData, $"n" === $"N" && $"N" > 1, "left").collect()) - - checkAnswer2( - upperCaseData, - lowerCaseData, - wrapForUnsafe( - (node1, node2) => NestedLoopJoinNode( - conf, - node1, - node2, - buildSide, - LeftOuter, - Some( - (upperCaseData.col("N") === lowerCaseData.col("n") && - lowerCaseData.col("l") > upperCaseData.col("L")).expr)) - ), - upperCaseData.join(lowerCaseData, $"n" === $"N" && $"l" > $"L", "left").collect()) + // Test all combinations of the three dimensions: with/out unsafe, build sides, and join types + val maybeUnsafeAndCodegen = Seq(false, true) + val buildSides = Seq(BuildLeft, BuildRight) + val joinTypes = Seq(LeftOuter, RightOuter, FullOuter) + maybeUnsafeAndCodegen.foreach { unsafeAndCodegen => + buildSides.foreach { buildSide => + joinTypes.foreach { joinType => + testJoin(unsafeAndCodegen, buildSide, joinType) } } + } - test(s"$suiteName: right outer join") { - withSQLConf(confPairs: _*) { - checkAnswer2( - lowerCaseData, - upperCaseData, - wrapForUnsafe( - (node1, node2) => NestedLoopJoinNode( - conf, - node1, - node2, - buildSide, - RightOuter, - Some((lowerCaseData.col("n") === upperCaseData.col("N")).expr)) - ), - lowerCaseData.join(upperCaseData, $"n" === $"N", "right").collect()) - - checkAnswer2( - lowerCaseData, - upperCaseData, - wrapForUnsafe( - (node1, node2) => NestedLoopJoinNode( - conf, - node1, - node2, - buildSide, - RightOuter, - Some((lowerCaseData.col("n") === upperCaseData.col("N") && - lowerCaseData.col("n") > 1).expr)) - ), - lowerCaseData.join(upperCaseData, $"n" === $"N" && $"n" > 1, "right").collect()) - - checkAnswer2( - lowerCaseData, - upperCaseData, - wrapForUnsafe( - (node1, node2) => NestedLoopJoinNode( - conf, - node1, - node2, - buildSide, - RightOuter, - Some((lowerCaseData.col("n") === upperCaseData.col("N") && - upperCaseData.col("N") > 1).expr)) - ), - lowerCaseData.join(upperCaseData, $"n" === $"N" && $"N" > 1, "right").collect()) - - checkAnswer2( - lowerCaseData, - upperCaseData, - wrapForUnsafe( - (node1, node2) => NestedLoopJoinNode( - conf, - node1, - node2, - buildSide, - RightOuter, - Some((lowerCaseData.col("n") === upperCaseData.col("N") && - lowerCaseData.col("l") > upperCaseData.col("L")).expr)) - ), - lowerCaseData.join(upperCaseData, $"n" === $"N" && $"l" > 
$"L", "right").collect()) + /** + * Test outer nested loop joins with varying degrees of matches. + */ + private def testJoin( + unsafeAndCodegen: Boolean, + buildSide: BuildSide, + joinType: JoinType): Unit = { + val simpleOrUnsafe = if (!unsafeAndCodegen) "simple" else "unsafe" + val testNamePrefix = s"$simpleOrUnsafe / $buildSide / $joinType" + val someData = (1 to 100).map { i => (i, "burger" + i) }.toArray + val conf = new SQLConf + conf.setConf(SQLConf.UNSAFE_ENABLED, unsafeAndCodegen) + conf.setConf(SQLConf.CODEGEN_ENABLED, unsafeAndCodegen) + + // Actual test body + def runTest( + joinType: JoinType, + leftInput: Array[(Int, String)], + rightInput: Array[(Int, String)]): Unit = { + val leftNode = new DummyNode(joinNameAttributes, leftInput) + val rightNode = new DummyNode(joinNicknameAttributes, rightInput) + val cond = 'id1 === 'id2 + val makeNode = (node1: LocalNode, node2: LocalNode) => { + resolveExpressions( + new NestedLoopJoinNode(conf, node1, node2, buildSide, joinType, Some(cond))) } + val makeUnsafeNode = if (unsafeAndCodegen) wrapForUnsafe(makeNode) else makeNode + val hashJoinNode = makeUnsafeNode(leftNode, rightNode) + val expectedOutput = generateExpectedOutput(leftInput, rightInput, joinType) + val actualOutput = hashJoinNode.collect().map { row => + // (id, name, id, nickname) + (row.getInt(0), row.getString(1), row.getInt(2), row.getString(3)) + } + assert(actualOutput.toSet === expectedOutput.toSet) } - test(s"$suiteName: full outer join") { - withSQLConf(confPairs: _*) { - checkAnswer2( - lowerCaseData, - upperCaseData, - wrapForUnsafe( - (node1, node2) => NestedLoopJoinNode( - conf, - node1, - node2, - buildSide, - FullOuter, - Some((lowerCaseData.col("n") === upperCaseData.col("N")).expr)) - ), - lowerCaseData.join(upperCaseData, $"n" === $"N", "full").collect()) - - checkAnswer2( - lowerCaseData, - upperCaseData, - wrapForUnsafe( - (node1, node2) => NestedLoopJoinNode( - conf, - node1, - node2, - buildSide, - FullOuter, - Some((lowerCaseData.col("n") === upperCaseData.col("N") && - lowerCaseData.col("n") > 1).expr)) - ), - lowerCaseData.join(upperCaseData, $"n" === $"N" && $"n" > 1, "full").collect()) - - checkAnswer2( - lowerCaseData, - upperCaseData, - wrapForUnsafe( - (node1, node2) => NestedLoopJoinNode( - conf, - node1, - node2, - buildSide, - FullOuter, - Some((lowerCaseData.col("n") === upperCaseData.col("N") && - upperCaseData.col("N") > 1).expr)) - ), - lowerCaseData.join(upperCaseData, $"n" === $"N" && $"N" > 1, "full").collect()) - - checkAnswer2( - lowerCaseData, - upperCaseData, - wrapForUnsafe( - (node1, node2) => NestedLoopJoinNode( - conf, - node1, - node2, - buildSide, - FullOuter, - Some((lowerCaseData.col("n") === upperCaseData.col("N") && - lowerCaseData.col("l") > upperCaseData.col("L")).expr)) - ), - lowerCaseData.join(upperCaseData, $"n" === $"N" && $"l" > $"L", "full").collect()) - } + test(s"$testNamePrefix: empty") { + runTest(joinType, Array.empty, Array.empty) + } + + test(s"$testNamePrefix: no matches") { + val someIrrelevantData = (10000 to 10100).map { i => (i, "piper" + i) }.toArray + runTest(joinType, someData, Array.empty) + runTest(joinType, Array.empty, someData) + runTest(joinType, someData, someIrrelevantData) + runTest(joinType, someIrrelevantData, someData) + } + + test(s"$testNamePrefix: partial matches") { + val someOtherData = (50 to 150).map { i => (i, "finnegan" + i) }.toArray + runTest(joinType, someData, someOtherData) + runTest(joinType, someOtherData, someData) + } + + test(s"$testNamePrefix: full matches") { + 
val someSuperRelevantData = someData.map { case (k, v) => (k, "cooper" + v) } + runTest(joinType, someData, someSuperRelevantData) + runTest(joinType, someSuperRelevantData, someData) + } + } + + /** + * Helper method to generate the expected output of a test based on the join type. + */ + private def generateExpectedOutput( + leftInput: Array[(Int, String)], + rightInput: Array[(Int, String)], + joinType: JoinType): Array[(Int, String, Int, String)] = { + joinType match { + case LeftOuter => + val rightInputMap = rightInput.toMap + leftInput.map { case (k, v) => + val rightKey = rightInputMap.get(k).map { _ => k }.getOrElse(0) + val rightValue = rightInputMap.getOrElse(k, null) + (k, v, rightKey, rightValue) + } + + case RightOuter => + val leftInputMap = leftInput.toMap + rightInput.map { case (k, v) => + val leftKey = leftInputMap.get(k).map { _ => k }.getOrElse(0) + val leftValue = leftInputMap.getOrElse(k, null) + (leftKey, leftValue, k, v) + } + + case FullOuter => + val leftInputMap = leftInput.toMap + val rightInputMap = rightInput.toMap + val leftOutput = leftInput.map { case (k, v) => + val rightKey = rightInputMap.get(k).map { _ => k }.getOrElse(0) + val rightValue = rightInputMap.getOrElse(k, null) + (k, v, rightKey, rightValue) + } + val rightOutput = rightInput.map { case (k, v) => + val leftKey = leftInputMap.get(k).map { _ => k }.getOrElse(0) + val leftValue = leftInputMap.getOrElse(k, null) + (leftKey, leftValue, k, v) + } + (leftOutput ++ rightOutput).distinct + + case other => + throw new IllegalArgumentException(s"Join type $other is not applicable") } } - joinSuite( - "general-build-left", - BuildLeft, - SQLConf.CODEGEN_ENABLED.key -> "false", SQLConf.UNSAFE_ENABLED.key -> "false") - joinSuite( - "general-build-right", - BuildRight, - SQLConf.CODEGEN_ENABLED.key -> "false", SQLConf.UNSAFE_ENABLED.key -> "false") - joinSuite( - "tungsten-build-left", - BuildLeft, - SQLConf.CODEGEN_ENABLED.key -> "true", SQLConf.UNSAFE_ENABLED.key -> "true") - joinSuite( - "tungsten-build-right", - BuildRight, - SQLConf.CODEGEN_ENABLED.key -> "true", SQLConf.UNSAFE_ENABLED.key -> "true") } From 36fb03859549e6415d0573fedd770dcc3d0bf61e Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Tue, 15 Sep 2015 15:33:30 -0700 Subject: [PATCH 6/8] ExpandNode --- .../sql/execution/local/ExpandNodeSuite.scala | 54 +++++++++---------- 1 file changed, 26 insertions(+), 28 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ExpandNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ExpandNodeSuite.scala index cfa7f3f6dcb97..bbd94d8da2d11 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ExpandNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ExpandNodeSuite.scala @@ -17,35 +17,33 @@ package org.apache.spark.sql.execution.local +import org.apache.spark.sql.catalyst.dsl.expressions._ + + class ExpandNodeSuite extends LocalNodeTest { - import testImplicits._ - - test("expand") { - val input = Seq((1, 1), (2, 2), (3, 3), (4, 4), (5, 5)).toDF("key", "value") - checkAnswer( - input, - node => - ExpandNode(conf, Seq( - Seq( - input.col("key") + input.col("value"), input.col("key") - input.col("value") - ).map(_.expr), - Seq( - input.col("key") * input.col("value"), input.col("key") / input.col("value") - ).map(_.expr) - ), node.output, node), - Seq( - (2, 0), - (1, 1), - (4, 0), - (4, 1), - (6, 0), - (9, 1), - (8, 0), - (16, 1), - (10, 0), - (25, 1) - ).toDF().collect() - ) + private def 
+  private def testExpand(inputData: Array[(Int, Int)] = Array.empty): Unit = {
+    val inputNode = new DummyNode(kvIntAttributes, inputData)
+    val projections = Seq(Seq('k + 'v, 'k - 'v), Seq('k * 'v, 'k / 'v))
+    val expandNode = new ExpandNode(conf, projections, inputNode.output, inputNode)
+    val resolvedNode = resolveExpressions(expandNode)
+    val expectedOutput = {
+      val firstHalf = inputData.map { case (k, v) => (k + v, k - v) }
+      val secondHalf = inputData.map { case (k, v) => (k * v, k / v) }
+      firstHalf ++ secondHalf
+    }
+    val actualOutput = resolvedNode.collect().map { case row =>
+      (row.getInt(0), row.getInt(1))
+    }
+    assert(actualOutput.toSet === expectedOutput.toSet)
+  }
+
+  test("empty") {
+    testExpand()
+  }
+
+  test("basic") {
+    testExpand((1 to 100).map { i => (i, i * 1000) }.toArray)
+  }
+
 }

From 060e5e6db0c0cb1eaaab9affb7b6cc7a1dca5ef0 Mon Sep 17 00:00:00 2001
From: Andrew Or
Date: Tue, 15 Sep 2015 15:42:01 -0700
Subject: [PATCH 7/8] Delete all obsolete code in LocalNodeTest

---
 .../execution/local/HashJoinNodeSuite.scala   |   4 +-
 .../sql/execution/local/LocalNodeTest.scala   | 135 ++----------------
 .../local/NestedLoopJoinNodeSuite.scala       |   6 +-
 3 files changed, 13 insertions(+), 132 deletions(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala
index febc753886331..5c1bdb088eeed 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala
@@ -25,8 +25,8 @@ import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide}
 class HashJoinNodeSuite extends LocalNodeTest {
 
   // Test all combinations of the two dimensions: with/out unsafe and build sides
-  val maybeUnsafeAndCodegen = Seq(false, true)
-  val buildSides = Seq(BuildLeft, BuildRight)
+  private val maybeUnsafeAndCodegen = Seq(false, true)
+  private val buildSides = Seq(BuildLeft, BuildRight)
   maybeUnsafeAndCodegen.foreach { unsafeAndCodegen =>
     buildSides.foreach { buildSide =>
       testJoin(unsafeAndCodegen, buildSide)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala
index 396ff376144f6..098050bcd2236 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala
@@ -17,31 +17,31 @@
 
 package org.apache.spark.sql.execution.local
 
-import scala.util.control.NonFatal
-
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.{DataFrame, Row, SQLConf}
-import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils}
+import org.apache.spark.sql.SQLConf
 import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
 import org.apache.spark.sql.catalyst.expressions.AttributeReference
 import org.apache.spark.sql.types.{IntegerType, StringType}
-class LocalNodeTest extends SparkFunSuite with SharedSQLContext {
-  def conf: SQLConf = sqlContext.conf
+class LocalNodeTest extends SparkFunSuite {
+  protected val conf: SQLConf = new SQLConf
   protected val kvIntAttributes = Seq(
     AttributeReference("k", IntegerType)(),
     AttributeReference("v", IntegerType)())
-
   protected val joinNameAttributes = Seq(
     AttributeReference("id1", IntegerType)(),
     AttributeReference("name", StringType)())
-
   protected val joinNicknameAttributes = Seq(
     AttributeReference("id2", IntegerType)(),
     AttributeReference("nickname", StringType)())
+  /**
+   * Wrap a function processing two [[LocalNode]]s such that:
+   *   (1) all input rows are automatically converted to unsafe rows
+   *   (2) all output rows are automatically converted back to safe rows
+   */
   protected def wrapForUnsafe(
       f: (LocalNode, LocalNode) => LocalNode): (LocalNode, LocalNode) => LocalNode = {
     (left: LocalNode, right: LocalNode) => {
@@ -52,78 +52,6 @@ class LocalNodeTest extends SparkFunSuite with SharedSQLContext {
     }
   }
-  /**
-   * Runs the LocalNode and makes sure the answer matches the expected result.
-   * @param input the input data to be used.
-   * @param nodeFunction a function which accepts the input LocalNode and uses it to instantiate
-   *                     the local physical operator that's being tested.
-   * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
-   * @param sortAnswers if true, the answers will be sorted by their toString representations prior
-   *                    to being compared.
-   */
-  protected def checkAnswer(
-      input: DataFrame,
-      nodeFunction: LocalNode => LocalNode,
-      expectedAnswer: Seq[Row],
-      sortAnswers: Boolean = true): Unit = {
-    doCheckAnswer(
-      input :: Nil,
-      nodes => nodeFunction(nodes.head),
-      expectedAnswer,
-      sortAnswers)
-  }
-
-  /**
-   * Runs the LocalNode and makes sure the answer matches the expected result.
-   * @param left the left input data to be used.
-   * @param right the right input data to be used.
-   * @param nodeFunction a function which accepts the input LocalNode and uses it to instantiate
-   *                     the local physical operator that's being tested.
-   * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
-   * @param sortAnswers if true, the answers will be sorted by their toString representations prior
-   *                    to being compared.
-   */
-  protected def checkAnswer2(
-      left: DataFrame,
-      right: DataFrame,
-      nodeFunction: (LocalNode, LocalNode) => LocalNode,
-      expectedAnswer: Seq[Row],
-      sortAnswers: Boolean = true): Unit = {
-    doCheckAnswer(
-      left :: right :: Nil,
-      nodes => nodeFunction(nodes(0), nodes(1)),
-      expectedAnswer,
-      sortAnswers)
-  }
-
-  /**
-   * Runs the `LocalNode`s and makes sure the answer matches the expected result.
-   * @param input the input data to be used.
-   * @param nodeFunction a function which accepts a sequence of input `LocalNode`s and uses them to
-   *                     instantiate the local physical operator that's being tested.
-   * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
-   * @param sortAnswers if true, the answers will be sorted by their toString representations prior
-   *                    to being compared.
-   */
-  protected def doCheckAnswer(
-      input: Seq[DataFrame],
-      nodeFunction: Seq[LocalNode] => LocalNode,
-      expectedAnswer: Seq[Row],
-      sortAnswers: Boolean = true): Unit = {
-    LocalNodeTest.checkAnswer(
-      input.map(dataFrameToSeqScanNode), nodeFunction, expectedAnswer, sortAnswers) match {
-      case Some(errorMessage) => fail(errorMessage)
-      case None =>
-    }
-  }
-
-  protected def dataFrameToSeqScanNode(df: DataFrame): SeqScanNode = {
-    new SeqScanNode(
-      conf,
-      df.queryExecution.sparkPlan.output,
-      df.queryExecution.toRdd.map(_.copy()).collect())
-  }
-
   /**
    * Recursively resolve all expressions in a [[LocalNode]] using the node's attributes.
   */
@@ -140,50 +68,3 @@ class LocalNodeTest extends SparkFunSuite with SharedSQLContext {
   }
 }
-
-/**
- * Helper methods for writing tests of individual local physical operators.
- */
-object LocalNodeTest {
-
-  /**
-   * Runs the `LocalNode`s and makes sure the answer matches the expected result.
-   * @param input the input data to be used.
-   * @param nodeFunction a function which accepts the input `LocalNode`s and uses them to
-   *                     instantiate the local physical operator that's being tested.
-   * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
-   * @param sortAnswers if true, the answers will be sorted by their toString representations prior
-   *                    to being compared.
-   */
-  def checkAnswer(
-      input: Seq[SeqScanNode],
-      nodeFunction: Seq[LocalNode] => LocalNode,
-      expectedAnswer: Seq[Row],
-      sortAnswers: Boolean): Option[String] = {
-
-    val outputNode = nodeFunction(input)
-
-    val outputResult: Seq[Row] = try {
-      outputNode.collect()
-    } catch {
-      case NonFatal(e) =>
-        val errorMessage =
-          s"""
-            | Exception thrown while executing local plan:
-            | $outputNode
-            | == Exception ==
-            | $e
-            | ${org.apache.spark.sql.catalyst.util.stackTraceToString(e)}
-          """.stripMargin
-        return Some(errorMessage)
-    }
-
-    SQLTestUtils.compareAnswers(outputResult, expectedAnswer, sortAnswers).map { errorMessage =>
-      s"""
-        | Results do not match for local plan:
-        | $outputNode
-        | $errorMessage
-      """.stripMargin
-    }
-  }
-}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala
index 0b35df89baf16..40299d9d5ee37 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala
@@ -26,9 +26,9 @@ class NestedLoopJoinNodeSuite extends LocalNodeTest {
 
   // Test all combinations of the three dimensions: with/out unsafe, build sides, and join types
-  val maybeUnsafeAndCodegen = Seq(false, true)
-  val buildSides = Seq(BuildLeft, BuildRight)
-  val joinTypes = Seq(LeftOuter, RightOuter, FullOuter)
+  private val maybeUnsafeAndCodegen = Seq(false, true)
+  private val buildSides = Seq(BuildLeft, BuildRight)
+  private val joinTypes = Seq(LeftOuter, RightOuter, FullOuter)
   maybeUnsafeAndCodegen.foreach { unsafeAndCodegen =>
     buildSides.foreach { buildSide =>
       joinTypes.foreach { joinType =>

From 372ab5fc77043274378471d0c1dd7d39b83741a6 Mon Sep 17 00:00:00 2001
From: Andrew Or
Date: Tue, 15 Sep 2015 15:45:02 -0700
Subject: [PATCH 8/8] Fix TakeOrderedAndProjectNode

---
 .../sql/execution/local/TakeOrderedAndProjectNode.scala      | 2 +-
 .../sql/execution/local/TakeOrderedAndProjectNodeSuite.scala | 5 +----
 2 files changed, 2 insertions(+), 5 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNode.scala
index 53f1dcc65d8cf..ae672fbca8d83 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNode.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNode.scala
@@ -50,7 +50,7 @@ case class TakeOrderedAndProjectNode(
     }
     // Close it eagerly since we don't need it.
     child.close()
-    iterator = queue.iterator
+    iterator = queue.toArray.sorted(ord).iterator
   }
 
   override def next(): Boolean = {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNodeSuite.scala
index cfc28da66e2f9..42ebc7bfcaadc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNodeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNodeSuite.scala
@@ -28,10 +28,7 @@ class TakeOrderedAndProjectNodeSuite extends LocalNodeTest {
   private def testTakeOrderedAndProject(desc: Boolean): Unit = {
     val limit = 10
     val ascOrDesc = if (desc) "desc" else "asc"
-    // TODO: re-enable me once TakeOrderedAndProjectNode can return things in sorted order.
-    // This test is ignored because the node currently just returns the items in the order
-    // maintained by the underlying min / max heap, but we expect sorted order.
-    ignore(ascOrDesc) {
+    test(ascOrDesc) {
       val inputData = Random.shuffle((1 to 100).toList).map { i => (i, i) }.toArray
       val inputNode = new DummyNode(kvIntAttributes, inputData)
       val firstColumn = inputNode.output(0)