diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala
index 6428b0bf16f3c..a2c275db9b35d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala
@@ -69,6 +69,11 @@ abstract class LocalNode(conf: SQLConf) extends TreeNode[LocalNode] with Logging
    */
   def close(): Unit
 
+  /**
+   * Returns the content through the [[Iterator]] interface.
+   */
+  final def asIterator: Iterator[InternalRow] = new LocalNodeIterator(this)
+
   /**
    * Returns the content of the iterator from the beginning to the end in the form of a Scala Seq.
    */
@@ -108,7 +113,6 @@ abstract class LocalNode(conf: SQLConf) extends TreeNode[LocalNode] with Logging
     }
   }
 
-  def toIterator: Iterator[InternalRow] = new LocalNodeIterator(this)
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SampleNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SampleNode.scala
index 56a91258fe5cf..abf3df1c0c2af 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SampleNode.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SampleNode.scala
@@ -63,7 +63,7 @@ case class SampleNode(
       (new BernoulliCellSampler[InternalRow](lowerBound, upperBound), seed)
     }
     sampler.setSeed(_seed)
-    iterator = sampler.sample(child.toIterator)
+    iterator = sampler.sample(child.asIterator)
   }
 
   override def next(): Boolean = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNode.scala
index 1de6b6f69c0c6..53f1dcc65d8cf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNode.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNode.scala
@@ -29,23 +29,21 @@ case class TakeOrderedAndProjectNode(
     projectList: Option[Seq[NamedExpression]],
     child: LocalNode) extends UnaryLocalNode(conf) {
 
-  override def output: Seq[Attribute] = {
-    val projectOutput = projectList.map(_.map(_.toAttribute))
-    projectOutput.getOrElse(child.output)
-  }
-
   private[this] var projection: Option[Projection] = _
   private[this] var ord: InterpretedOrdering = _
   private[this] var iterator: Iterator[InternalRow] = _
   private[this] var currentRow: InternalRow = _
 
+  override def output: Seq[Attribute] = {
+    val projectOutput = projectList.map(_.map(_.toAttribute))
+    projectOutput.getOrElse(child.output)
+  }
+
   override def open(): Unit = {
     child.open()
     projection = projectList.map(new InterpretedProjection(_, child.output))
     ord = new InterpretedOrdering(sortOrder, child.output)
+    // The priority queue keeps the largest elements, so let's reverse the ordering.
     val queue = new BoundedPriorityQueue[InternalRow](limit)(ord.reverse)
     while (child.next()) {
       queue += child.fetch()
@@ -58,7 +56,10 @@ case class TakeOrderedAndProjectNode(
 
   override def next(): Boolean = {
     if (iterator.hasNext) {
       val _currentRow = iterator.next()
-      currentRow = projection.map(p => p(_currentRow)).getOrElse(_currentRow)
+      currentRow = projection match {
+        case Some(p) => p(_currentRow)
+        case None => _currentRow
+      }
       true
     } else {
       false
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/SampleNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/SampleNodeSuite.scala
index d62aeb76b45e4..87a7da453999c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/SampleNodeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/SampleNodeSuite.scala
@@ -17,46 +17,24 @@
 
 package org.apache.spark.sql.execution.local
 
-import org.apache.spark.sql.Column
-import org.apache.spark.sql.catalyst.expressions.{Ascending, Expression, SortOrder}
-
 class SampleNodeSuite extends LocalNodeTest {
 
   import testImplicits._
 
-  def columnToSortOrder(sortExprs: Column*): Seq[SortOrder] = {
-    val sortOrder: Seq[SortOrder] = sortExprs.map { col =>
-      col.expr match {
-        case expr: SortOrder =>
-          expr
-        case expr: Expression =>
-          SortOrder(expr, Ascending)
-      }
+  private def testSample(withReplacement: Boolean): Unit = {
+    test(s"withReplacement: $withReplacement") {
+      val seed = 0L
+      val input = sqlContext.sparkContext.
+        parallelize((1 to 10).map(i => (i, i.toString)), 1). // Should be only 1 partition
+        toDF("key", "value")
+      checkAnswer(
+        input,
+        node => SampleNode(conf, 0.0, 0.3, withReplacement, seed, node),
+        input.sample(withReplacement, 0.3, seed).collect()
+      )
     }
-    sortOrder
-  }
-
-  test("withReplacement: true") {
-    val seed = 0L
-    val input = sqlContext.sparkContext.
-      parallelize((1 to 10).map(i => (i, i.toString)), 1). // Should be only 1 partition
-      toDF("key", "value")
-    checkAnswer(
-      input,
-      node => SampleNode(conf, 0.0, 0.3, true, seed, node),
-      input.sample(true, 0.3, seed).collect()
-    )
   }
 
-  test("withReplacement: false") {
-    val seed = 0L
-    val input = sqlContext.sparkContext.
-      parallelize((1 to 10).map(i => (i, i.toString)), 1). // Should be only 1 partition
-      toDF("key", "value")
-    checkAnswer(
-      input,
-      node => SampleNode(conf, 0.0, 0.3, false, seed, node),
-      input.sample(false, 0.3, seed).collect()
-    )
-  }
+  testSample(withReplacement = true)
+  testSample(withReplacement = false)
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNodeSuite.scala
index 8dfa98001f7c5..ff28b24eeff14 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNodeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNodeSuite.scala
@@ -24,7 +24,7 @@ class TakeOrderedAndProjectNodeSuite extends LocalNodeTest {
 
   import testImplicits._
 
-  def columnToSortOrder(sortExprs: Column*): Seq[SortOrder] = {
+  private def columnToSortOrder(sortExprs: Column*): Seq[SortOrder] = {
     val sortOrder: Seq[SortOrder] = sortExprs.map { col =>
       col.expr match {
         case expr: SortOrder =>
@@ -36,22 +36,19 @@ class TakeOrderedAndProjectNodeSuite extends LocalNodeTest {
     sortOrder
   }
 
-  test("asc") {
-    val input = (1 to 10).map(i => (i, i.toString)).toDF("key", "value")
-    checkAnswer(
-      input,
-      node => TakeOrderedAndProjectNode(conf, 5, columnToSortOrder(input.col("key")), None, node),
-      input.sort(input.col("key")).limit(5).collect()
-    )
+  private def testTakeOrderedAndProjectNode(desc: Boolean): Unit = {
+    val testCaseName = if (desc) "desc" else "asc"
+    test(testCaseName) {
+      val input = (1 to 10).map(i => (i, i.toString)).toDF("key", "value")
+      val sortColumn = if (desc) input.col("key").desc else input.col("key")
+      checkAnswer(
+        input,
+        node => TakeOrderedAndProjectNode(conf, 5, columnToSortOrder(sortColumn), None, node),
+        input.sort(sortColumn).limit(5).collect()
+      )
+    }
   }
 
-  test("desc") {
-    val input = (1 to 10).map(i => (i, i.toString)).toDF("key", "value")
-    checkAnswer(
-      input,
-      node =>
-        TakeOrderedAndProjectNode(conf, 5, columnToSortOrder(input.col("key").desc), None, node),
-      input.sort(input.col("key").desc).limit(5).collect()
-    )
-  }
+  testTakeOrderedAndProjectNode(desc = false)
+  testTakeOrderedAndProjectNode(desc = true)
 }
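
Note on the first file: `asIterator` delegates to `LocalNodeIterator`, which adapts a `LocalNode`'s open/next/fetch/close cursor protocol to `scala.collection.Iterator`. The sketch below shows one way such an adapter can be written; `CursorNode` and `CursorIterator` are hypothetical names, and this is not Spark's actual `LocalNodeIterator` (which differs in details such as resource handling).

// Hypothetical sketch of a cursor-to-Iterator adapter; not Spark's LocalNodeIterator.
trait CursorNode[T] {
  def open(): Unit      // prepare the node for consumption
  def next(): Boolean   // advance; true if a row is available
  def fetch(): T        // return the row the cursor currently points at
  def close(): Unit     // release resources (omitted below for brevity)
}

class CursorIterator[T](node: CursorNode[T]) extends Iterator[T] {
  node.open()
  // Buffer one row so that hasNext can probe the cursor without losing data.
  private var buffered: Option[T] = None

  override def hasNext: Boolean = {
    if (buffered.isEmpty && node.next()) {
      buffered = Some(node.fetch())
    }
    buffered.isDefined
  }

  override def next(): T = {
    if (!hasNext) throw new NoSuchElementException("end of input")
    val row = buffered.get
    buffered = None
    row
  }
}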
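On the `ord.reverse` trick in `TakeOrderedAndProjectNode.open()`: as the added comment says, the bounded priority queue retains the largest elements under the ordering it is given, so passing the reversed ordering makes it retain the smallest `limit` rows under `ord`, i.e. the head of the ascending sort. A standalone sketch of the same idea using `scala.collection.mutable.PriorityQueue` (the `topKSmallest` helper is hypothetical, not Spark code):

import scala.collection.mutable

// Keep the k smallest elements of `input` under `ord`, returned in ascending order.
// mutable.PriorityQueue dequeues the LARGEST element first, so under the natural
// ordering the head of the heap is the worst (largest) of the k elements kept,
// which is exactly the one to evict when a smaller element arrives.
def topKSmallest[T](input: Iterator[T], k: Int)(implicit ord: Ordering[T]): Seq[T] = {
  require(k > 0, "k must be positive")
  val heap = mutable.PriorityQueue.empty[T](ord) // max-heap under ord
  input.foreach { x =>
    if (heap.size < k) {
      heap.enqueue(x)
    } else if (ord.lt(x, heap.head)) {
      heap.dequeue() // evict the current largest
      heap.enqueue(x)
    }
  }
  heap.dequeueAll.reverse // dequeue order is descending; reverse to ascending
}

// Example: topKSmallest(Iterator(5, 1, 4, 2, 3), 2) == Seq(1, 2)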