From d25226573b47257dc6e3d79d024e84fe205d104a Mon Sep 17 00:00:00 2001 From: zsxwing Date: Wed, 2 Sep 2015 21:48:47 +0800 Subject: [PATCH 1/4] Implement the local sample operator --- .../spark/sql/execution/basicOperators.scala | 2 +- .../spark/sql/execution/local/LocalNode.scala | 33 ++++++++ .../sql/execution/local/SampleNode.scala | 79 +++++++++++++++++++ .../sql/execution/local/LocalNodeTest.scala | 22 +++++- .../sql/execution/local/SampleNodeSuite.scala | 60 ++++++++++++++ 5 files changed, 192 insertions(+), 4 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/local/SampleNode.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/local/SampleNodeSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 3f68b05a24f44..bf6d44c098ee3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -138,7 +138,7 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode { * will be ub - lb. * @param withReplacement Whether to sample with replacement. * @param seed the random seed - * @param child the QueryPlan + * @param child the SparkPlan */ @DeveloperApi case class Sample( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala index 1c4469acbf264..610fabe8ec6a8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala @@ -73,6 +73,30 @@ abstract class LocalNode extends TreeNode[LocalNode] { } result } + + def toIterator: Iterator[InternalRow] = new Iterator[InternalRow] { + + private var currentRow: InternalRow = null + + override def hasNext: Boolean = { + if (currentRow == null) { + if (LocalNode.this.next()) { + currentRow = fetch() + true + } else { + false + } + } else { + true + } + } + + override def next(): InternalRow = { + val r = currentRow + currentRow = null + r + } + } } @@ -87,3 +111,12 @@ abstract class UnaryLocalNode extends LocalNode { override def children: Seq[LocalNode] = Seq(child) } + +abstract class BinaryLocalNode extends LocalNode { + + def left: LocalNode + + def right: LocalNode + + override def children: Seq[LocalNode] = Seq(left, right) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SampleNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SampleNode.scala new file mode 100644 index 0000000000000..d2ba91a47151c --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SampleNode.scala @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.local
+
+import java.util.Random
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.util.random.{BernoulliCellSampler, PoissonSampler}
+
+/**
+ * Sample the dataset.
+ *
+ * @param lowerBound Lower-bound of the sampling probability (usually 0.0)
+ * @param upperBound Upper-bound of the sampling probability. The expected fraction sampled
+ *                   will be ub - lb.
+ * @param withReplacement Whether to sample with replacement.
+ * @param seed the random seed
+ * @param child the LocalNode
+ */
+case class SampleNode(
+    lowerBound: Double,
+    upperBound: Double,
+    withReplacement: Boolean,
+    seed: Long,
+    child: LocalNode) extends UnaryLocalNode {
+
+  override def output: Seq[Attribute] = child.output
+
+  private[this] var iterator: Iterator[InternalRow] = _
+
+  private[this] var currentRow: InternalRow = _
+
+  override def open(): Unit = {
+    child.open()
+    val (sampler, _seed) = if (withReplacement) {
+      val random = new Random(seed)
+      // Disable gap sampling since the gap sampling method buffers two rows internally,
+      // requiring us to copy the row, which costs more than calling the random number generator.
+      (new PoissonSampler[InternalRow](upperBound - lowerBound, useGapSamplingIfPossible = false),
+        // Use the seed for partition 0, like PartitionwiseSampledRDD, to generate the same
+        // result as the DataFrame-based Sample operator.
+        random.nextLong())
+    } else {
+      (new BernoulliCellSampler[InternalRow](lowerBound, upperBound), seed)
+    }
+    sampler.setSeed(_seed)
+    iterator = sampler.sample(child.toIterator)
+  }
+
+  override def next(): Boolean = {
+    if (iterator.hasNext) {
+      currentRow = iterator.next()
+      true
+    } else {
+      false
+    }
+  }
+
+  override def fetch(): InternalRow = currentRow
+
+  override def close(): Unit = child.close()
+
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala
index 95f06081bd0a8..41afaec570883 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala
@@ -17,13 +17,29 @@
 
 package org.apache.spark.sql.execution.local
 
+import scala.reflect.runtime.universe.TypeTag
 import scala.util.control.NonFatal
 
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.{DataFrame, Row}
-import org.apache.spark.sql.test.SQLTestUtils
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{DataFrame, DataFrameHolder, Row}
+import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils}
 
-class LocalNodeTest extends SparkFunSuite {
+class LocalNodeTest extends SparkFunSuite with SharedSQLContext {
+
+  /**
+   * Creates a DataFrame from an RDD of Product (e.g. case classes, tuples).
+   */
+  implicit def rddToDataFrameHolder[A <: Product : TypeTag](rdd: RDD[A]): DataFrameHolder = {
+    DataFrameHolder(_sqlContext.createDataFrame(rdd))
+  }
+
+  /**
+   * Creates a DataFrame from a local Seq of Product.
+ */ + implicit def localSeqToDataFrameHolder[A <: Product : TypeTag](data: Seq[A]): DataFrameHolder = { + sqlContext.implicits.localSeqToDataFrameHolder(data) + } /** * Runs the LocalNode and makes sure the answer matches the expected result. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/SampleNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/SampleNodeSuite.scala new file mode 100644 index 0000000000000..a7cb518362c1a --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/SampleNodeSuite.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.local + +import org.apache.spark.sql.Column +import org.apache.spark.sql.catalyst.expressions.{Ascending, Expression, SortOrder} + +class SampleNodeSuite extends LocalNodeTest { + + def columnToSortOrder(sortExprs: Column*): Seq[SortOrder] = { + val sortOrder: Seq[SortOrder] = sortExprs.map { col => + col.expr match { + case expr: SortOrder => + expr + case expr: Expression => + SortOrder(expr, Ascending) + } + } + sortOrder + } + + test("withReplacement: true") { + val seed = 0L + val input = sqlContext.sparkContext. + parallelize((1 to 10).map(i => (i, i.toString)), 1). // Should be only 1 partition + toDF("key", "value") + checkAnswer( + input, + node => SampleNode(0.0, 0.3, true, seed, node), + input.sample(true, 0.3, seed).collect() + ) + } + + test("withReplacement: false") { + val seed = 0L + val input = sqlContext.sparkContext. + parallelize((1 to 10).map(i => (i, i.toString)), 1). 
// Should be only 1 partition + toDF("key", "value") + checkAnswer( + input, + node => SampleNode(0.0, 0.3, false, seed, node), + input.sample(false, 0.3, seed).collect() + ) + } +} From d1acc2a264a74e579c51c26f8d9943a5c710f472 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Wed, 2 Sep 2015 21:49:11 +0800 Subject: [PATCH 2/4] Implement the local TopK operator --- .../local/TakeOrderedAndProjectNode.scala | 70 +++++++++++++++++++ .../TakeOrderedAndProjectNodeSuite.scala | 54 ++++++++++++++ 2 files changed, 124 insertions(+) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNode.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNodeSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNode.scala new file mode 100644 index 0000000000000..9c90d834e6a29 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNode.scala @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.local + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.util.BoundedPriorityQueue + +case class TakeOrderedAndProjectNode( + limit: Int, + sortOrder: Seq[SortOrder], + projectList: Option[Seq[NamedExpression]], + child: LocalNode) extends UnaryLocalNode { + + override def output: Seq[Attribute] = { + val projectOutput = projectList.map(_.map(_.toAttribute)) + projectOutput.getOrElse(child.output) + } + + private[this] var projection: Option[Projection] = _ + + private[this] var ord: InterpretedOrdering = _ + + private[this] var iterator: Iterator[InternalRow] = _ + + private[this] var currentRow: InternalRow = _ + + override def open(): Unit = { + child.open() + projection = projectList.map(new InterpretedProjection(_, child.output)) + ord = new InterpretedOrdering(sortOrder, child.output) + val queue = new BoundedPriorityQueue[InternalRow](limit)(ord.reverse) + while (child.next()) { + queue += child.fetch() + } + // Close it eagerly since we don't need it. 
+ child.close() + iterator = queue.iterator + } + + override def next(): Boolean = { + if (iterator.hasNext) { + val _currentRow = iterator.next() + currentRow = projection.map(p => p(_currentRow)).getOrElse(_currentRow) + true + } else { + false + } + } + + override def fetch(): InternalRow = currentRow + + override def close(): Unit = child.close() + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNodeSuite.scala new file mode 100644 index 0000000000000..c4481aa4cdbac --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNodeSuite.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.local + +import org.apache.spark.sql.Column +import org.apache.spark.sql.catalyst.expressions.{Ascending, Expression, SortOrder} + +class TakeOrderedAndProjectNodeSuite extends LocalNodeTest { + + def columnToSortOrder(sortExprs: Column*): Seq[SortOrder] = { + val sortOrder: Seq[SortOrder] = sortExprs.map { col => + col.expr match { + case expr: SortOrder => + expr + case expr: Expression => + SortOrder(expr, Ascending) + } + } + sortOrder + } + + test("asc") { + val input = (1 to 10).map(i => (i, i.toString)).toDF("key", "value") + checkAnswer( + input, + node => TakeOrderedAndProjectNode(5, columnToSortOrder(input.col("key")), None, node), + input.sort(input.col("key")).limit(5).collect() + ) + } + + test("desc") { + val input = (1 to 10).map(i => (i, i.toString)).toDF("key", "value") + checkAnswer( + input, + node => TakeOrderedAndProjectNode(5, columnToSortOrder(input.col("key").desc), None, node), + input.sort(input.col("key").desc).limit(5).collect() + ) + } +} From 4ccca2ab05ce1508b64c2970379eaf2116b1e303 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Wed, 2 Sep 2015 21:49:46 +0800 Subject: [PATCH 3/4] Implement the local intersect operator --- .../sql/execution/local/IntersectNode.scala | 61 +++++++++++++++++++ .../execution/local/IntersectNodeSuite.scala | 33 ++++++++++ 2 files changed, 94 insertions(+) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/local/IntersectNode.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/local/IntersectNodeSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/IntersectNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/IntersectNode.scala new file mode 100644 index 0000000000000..bb619f8f8c156 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/IntersectNode.scala @@ -0,0 +1,61 @@ +/* +* Licensed to the Apache Software Foundation 
(ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.local + +import scala.collection.mutable + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute + +case class IntersectNode(left: LocalNode, right: LocalNode) extends BinaryLocalNode { + + override def output: Seq[Attribute] = left.output + + private[this] var leftRows: mutable.HashSet[InternalRow] = _ + + private[this] var currentRow: InternalRow = _ + + override def open(): Unit = { + left.open() + leftRows = mutable.HashSet[InternalRow]() + while (left.next()) { + leftRows += left.fetch().copy() + } + left.close() + right.open() + } + + override def next(): Boolean = { + currentRow = null + while (currentRow == null && right.next()) { + currentRow = right.fetch() + if (!leftRows.contains(currentRow)) { + currentRow = null + } + } + currentRow != null + } + + override def fetch(): InternalRow = currentRow + + override def close(): Unit = { + left.close() + right.close() + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/IntersectNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/IntersectNodeSuite.scala new file mode 100644 index 0000000000000..dfcb764046207 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/IntersectNodeSuite.scala @@ -0,0 +1,33 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +package org.apache.spark.sql.execution.local + +class IntersectNodeSuite extends LocalNodeTest { + + test("basic") { + val input1 = (1 to 10).map(i => (i, i.toString)).toDF("key", "value") + val input2 = (1 to 10).filter(_ % 2 == 0).map(i => (i, i.toString)).toDF("key", "value") + + checkAnswer2( + input1, + input2, + (node1, node2) => IntersectNode(node1, node2), + input1.intersect(input2).collect() + ) + } +} From a3270b0e8470e09cafffcc18579e8b0febdc0ef6 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Fri, 11 Sep 2015 17:34:46 +0800 Subject: [PATCH 4/4] Address Andrew's comments --- .../spark/sql/execution/local/LocalNode.scala | 6 ++- .../sql/execution/local/SampleNode.scala | 2 +- .../local/TakeOrderedAndProjectNode.scala | 19 ++++---- .../sql/execution/local/SampleNodeSuite.scala | 48 +++++-------------- .../TakeOrderedAndProjectNodeSuite.scala | 31 ++++++------ 5 files changed, 43 insertions(+), 63 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala index 6428b0bf16f3c..a2c275db9b35d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala @@ -69,6 +69,11 @@ abstract class LocalNode(conf: SQLConf) extends TreeNode[LocalNode] with Logging */ def close(): Unit + /** + * Returns the content through the [[Iterator]] interface. + */ + final def asIterator: Iterator[InternalRow] = new LocalNodeIterator(this) + /** * Returns the content of the iterator from the beginning to the end in the form of a Scala Seq. */ @@ -108,7 +113,6 @@ abstract class LocalNode(conf: SQLConf) extends TreeNode[LocalNode] with Logging } } - def toIterator: Iterator[InternalRow] = new LocalNodeIterator(this) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SampleNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SampleNode.scala index 56a91258fe5cf..abf3df1c0c2af 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SampleNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SampleNode.scala @@ -63,7 +63,7 @@ case class SampleNode( (new BernoulliCellSampler[InternalRow](lowerBound, upperBound), seed) } sampler.setSeed(_seed) - iterator = sampler.sample(child.toIterator) + iterator = sampler.sample(child.asIterator) } override def next(): Boolean = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNode.scala index 1de6b6f69c0c6..53f1dcc65d8cf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNode.scala @@ -29,23 +29,21 @@ case class TakeOrderedAndProjectNode( projectList: Option[Seq[NamedExpression]], child: LocalNode) extends UnaryLocalNode(conf) { - override def output: Seq[Attribute] = { - val projectOutput = projectList.map(_.map(_.toAttribute)) - projectOutput.getOrElse(child.output) - } - private[this] var projection: Option[Projection] = _ - private[this] var ord: InterpretedOrdering = _ - private[this] var iterator: Iterator[InternalRow] = _ - private[this] var currentRow: InternalRow = _ + override def output: Seq[Attribute] = { + val projectOutput = projectList.map(_.map(_.toAttribute)) + 
projectOutput.getOrElse(child.output) + } + override def open(): Unit = { child.open() projection = projectList.map(new InterpretedProjection(_, child.output)) ord = new InterpretedOrdering(sortOrder, child.output) + // Priority keeps the largest elements, so let's reverse the ordering. val queue = new BoundedPriorityQueue[InternalRow](limit)(ord.reverse) while (child.next()) { queue += child.fetch() @@ -58,7 +56,10 @@ case class TakeOrderedAndProjectNode( override def next(): Boolean = { if (iterator.hasNext) { val _currentRow = iterator.next() - currentRow = projection.map(p => p(_currentRow)).getOrElse(_currentRow) + currentRow = projection match { + case Some(p) => p(_currentRow) + case None => _currentRow + } true } else { false diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/SampleNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/SampleNodeSuite.scala index d62aeb76b45e4..87a7da453999c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/SampleNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/SampleNodeSuite.scala @@ -17,46 +17,24 @@ package org.apache.spark.sql.execution.local -import org.apache.spark.sql.Column -import org.apache.spark.sql.catalyst.expressions.{Ascending, Expression, SortOrder} - class SampleNodeSuite extends LocalNodeTest { import testImplicits._ - def columnToSortOrder(sortExprs: Column*): Seq[SortOrder] = { - val sortOrder: Seq[SortOrder] = sortExprs.map { col => - col.expr match { - case expr: SortOrder => - expr - case expr: Expression => - SortOrder(expr, Ascending) - } + private def testSample(withReplacement: Boolean): Unit = { + test(s"withReplacement: $withReplacement") { + val seed = 0L + val input = sqlContext.sparkContext. + parallelize((1 to 10).map(i => (i, i.toString)), 1). // Should be only 1 partition + toDF("key", "value") + checkAnswer( + input, + node => SampleNode(conf, 0.0, 0.3, withReplacement, seed, node), + input.sample(withReplacement, 0.3, seed).collect() + ) } - sortOrder - } - - test("withReplacement: true") { - val seed = 0L - val input = sqlContext.sparkContext. - parallelize((1 to 10).map(i => (i, i.toString)), 1). // Should be only 1 partition - toDF("key", "value") - checkAnswer( - input, - node => SampleNode(conf, 0.0, 0.3, true, seed, node), - input.sample(true, 0.3, seed).collect() - ) } - test("withReplacement: false") { - val seed = 0L - val input = sqlContext.sparkContext. - parallelize((1 to 10).map(i => (i, i.toString)), 1). 
// Should be only 1 partition - toDF("key", "value") - checkAnswer( - input, - node => SampleNode(conf, 0.0, 0.3, false, seed, node), - input.sample(false, 0.3, seed).collect() - ) - } + testSample(withReplacement = true) + testSample(withReplacement = false) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNodeSuite.scala index 8dfa98001f7c5..ff28b24eeff14 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNodeSuite.scala @@ -24,7 +24,7 @@ class TakeOrderedAndProjectNodeSuite extends LocalNodeTest { import testImplicits._ - def columnToSortOrder(sortExprs: Column*): Seq[SortOrder] = { + private def columnToSortOrder(sortExprs: Column*): Seq[SortOrder] = { val sortOrder: Seq[SortOrder] = sortExprs.map { col => col.expr match { case expr: SortOrder => @@ -36,22 +36,19 @@ class TakeOrderedAndProjectNodeSuite extends LocalNodeTest { sortOrder } - test("asc") { - val input = (1 to 10).map(i => (i, i.toString)).toDF("key", "value") - checkAnswer( - input, - node => TakeOrderedAndProjectNode(conf, 5, columnToSortOrder(input.col("key")), None, node), - input.sort(input.col("key")).limit(5).collect() - ) + private def testTakeOrderedAndProjectNode(desc: Boolean): Unit = { + val testCaseName = if (desc) "desc" else "asc" + test(testCaseName) { + val input = (1 to 10).map(i => (i, i.toString)).toDF("key", "value") + val sortColumn = if (desc) input.col("key").desc else input.col("key") + checkAnswer( + input, + node => TakeOrderedAndProjectNode(conf, 5, columnToSortOrder(sortColumn), None, node), + input.sort(sortColumn).limit(5).collect() + ) + } } - test("desc") { - val input = (1 to 10).map(i => (i, i.toString)).toDF("key", "value") - checkAnswer( - input, - node => - TakeOrderedAndProjectNode(conf, 5, columnToSortOrder(input.col("key").desc), None, node), - input.sort(input.col("key").desc).limit(5).collect() - ) - } + testTakeOrderedAndProjectNode(desc = false) + testTakeOrderedAndProjectNode(desc = true) }
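
Note on the shared execution model (illustration only, not part of the patches):
each local node follows the same pull-based protocol. open() prepares state,
next() advances to the following row and reports whether one is available,
fetch() returns the current row, and close() releases resources. The toIterator
helper from patch 1 (renamed asIterator in patch 4) adapts that protocol to
Scala's Iterator, which is what lets SampleNode feed its child straight into the
Poisson and Bernoulli samplers. The sketch below is a self-contained Scala model
of this protocol: Node, SeqScanNode, Intersect, and IntersectDemo are
hypothetical names invented for this example, and rows are immutable Seq[Any]
values rather than InternalRow, so the defensive copy() that the real
IntersectNode performs while buffering its left input is unnecessary here.

import scala.collection.mutable

// Pull-based operator protocol, modeled loosely on LocalNode.
trait Node {
  def open(): Unit
  def next(): Boolean    // advance to the next row; false once exhausted
  def fetch(): Seq[Any]  // current row; valid only after next() returned true
  def close(): Unit

  // Iterator adapter in the spirit of LocalNode.toIterator / LocalNodeIterator.
  final def asIterator: Iterator[Seq[Any]] = new Iterator[Seq[Any]] {
    private var current: Seq[Any] = null
    override def hasNext: Boolean = {
      if (current == null && Node.this.next()) {
        current = fetch()
      }
      current != null
    }
    override def next(): Seq[Any] = {
      if (!hasNext) throw new NoSuchElementException("end of input")
      val row = current
      current = null
      row
    }
  }
}

// Hypothetical leaf node over an in-memory collection, standing in for a scan.
class SeqScanNode(data: Seq[Seq[Any]]) extends Node {
  private var input: Iterator[Seq[Any]] = _
  private var current: Seq[Any] = _
  override def open(): Unit = { input = data.iterator }
  override def next(): Boolean = {
    if (input.hasNext) { current = input.next(); true } else false
  }
  override def fetch(): Seq[Any] = current
  override def close(): Unit = { input = null }
}

// Same shape as IntersectNode: buffer the left side in a set, stream the right.
class Intersect(left: Node, right: Node) extends Node {
  private var leftRows: mutable.HashSet[Seq[Any]] = _
  private var current: Seq[Any] = _

  override def open(): Unit = {
    left.open()
    leftRows = mutable.HashSet[Seq[Any]]()
    while (left.next()) {
      leftRows += left.fetch()  // the real IntersectNode must copy() each mutable row
    }
    left.close()                // left side fully consumed, so close it eagerly
    right.open()
  }

  override def next(): Boolean = {
    current = null
    while (current == null && right.next()) {
      current = right.fetch()
      if (!leftRows.contains(current)) {
        current = null          // not present on the left side, keep scanning
      }
    }
    current != null
  }

  override def fetch(): Seq[Any] = current
  override def close(): Unit = right.close()
}

object IntersectDemo extends App {
  val left = new SeqScanNode((1 to 10).map(i => Seq(i, i.toString)))
  val right = new SeqScanNode((2 to 10 by 2).map(i => Seq(i, i.toString)))
  val node = new Intersect(left, right)
  node.open()
  println(node.asIterator.map(_.head).mkString(", "))  // prints: 2, 4, 6, 8, 10
  node.close()
}

As in IntersectNode, the left input is materialized into a hash set and closed
eagerly while the right input is streamed one row at a time, so peak memory is
bounded by the size of the left side. TakeOrderedAndProjectNode uses the same
eager pattern, draining its child into a size-limited priority queue inside
open() and closing it before iteration begins.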