Skip to content

Commit

Permalink
Address Andrew's comments
Browse files Browse the repository at this point in the history
  • Loading branch information
zsxwing committed Sep 11, 2015
1 parent 4090902 commit a3270b0
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 63 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,11 @@ abstract class LocalNode(conf: SQLConf) extends TreeNode[LocalNode] with Logging
*/
def close(): Unit

/**
* Returns the content through the [[Iterator]] interface.
*/
final def asIterator: Iterator[InternalRow] = new LocalNodeIterator(this)

/**
* Returns the content of the iterator from the beginning to the end in the form of a Scala Seq.
*/
Expand Down Expand Up @@ -108,7 +113,6 @@ abstract class LocalNode(conf: SQLConf) extends TreeNode[LocalNode] with Logging
}
}

def toIterator: Iterator[InternalRow] = new LocalNodeIterator(this)
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ case class SampleNode(
(new BernoulliCellSampler[InternalRow](lowerBound, upperBound), seed)
}
sampler.setSeed(_seed)
iterator = sampler.sample(child.toIterator)
iterator = sampler.sample(child.asIterator)
}

override def next(): Boolean = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,23 +29,21 @@ case class TakeOrderedAndProjectNode(
projectList: Option[Seq[NamedExpression]],
child: LocalNode) extends UnaryLocalNode(conf) {

override def output: Seq[Attribute] = {
val projectOutput = projectList.map(_.map(_.toAttribute))
projectOutput.getOrElse(child.output)
}

private[this] var projection: Option[Projection] = _

private[this] var ord: InterpretedOrdering = _

private[this] var iterator: Iterator[InternalRow] = _

private[this] var currentRow: InternalRow = _

override def output: Seq[Attribute] = {
val projectOutput = projectList.map(_.map(_.toAttribute))
projectOutput.getOrElse(child.output)
}

override def open(): Unit = {
child.open()
projection = projectList.map(new InterpretedProjection(_, child.output))
ord = new InterpretedOrdering(sortOrder, child.output)
// Priority keeps the largest elements, so let's reverse the ordering.
val queue = new BoundedPriorityQueue[InternalRow](limit)(ord.reverse)
while (child.next()) {
queue += child.fetch()
Expand All @@ -58,7 +56,10 @@ case class TakeOrderedAndProjectNode(
override def next(): Boolean = {
if (iterator.hasNext) {
val _currentRow = iterator.next()
currentRow = projection.map(p => p(_currentRow)).getOrElse(_currentRow)
currentRow = projection match {
case Some(p) => p(_currentRow)
case None => _currentRow
}
true
} else {
false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,46 +17,24 @@

package org.apache.spark.sql.execution.local

import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.expressions.{Ascending, Expression, SortOrder}

class SampleNodeSuite extends LocalNodeTest {

import testImplicits._

def columnToSortOrder(sortExprs: Column*): Seq[SortOrder] = {
val sortOrder: Seq[SortOrder] = sortExprs.map { col =>
col.expr match {
case expr: SortOrder =>
expr
case expr: Expression =>
SortOrder(expr, Ascending)
}
private def testSample(withReplacement: Boolean): Unit = {
test(s"withReplacement: $withReplacement") {
val seed = 0L
val input = sqlContext.sparkContext.
parallelize((1 to 10).map(i => (i, i.toString)), 1). // Should be only 1 partition
toDF("key", "value")
checkAnswer(
input,
node => SampleNode(conf, 0.0, 0.3, withReplacement, seed, node),
input.sample(withReplacement, 0.3, seed).collect()
)
}
sortOrder
}

test("withReplacement: true") {
val seed = 0L
val input = sqlContext.sparkContext.
parallelize((1 to 10).map(i => (i, i.toString)), 1). // Should be only 1 partition
toDF("key", "value")
checkAnswer(
input,
node => SampleNode(conf, 0.0, 0.3, true, seed, node),
input.sample(true, 0.3, seed).collect()
)
}

test("withReplacement: false") {
val seed = 0L
val input = sqlContext.sparkContext.
parallelize((1 to 10).map(i => (i, i.toString)), 1). // Should be only 1 partition
toDF("key", "value")
checkAnswer(
input,
node => SampleNode(conf, 0.0, 0.3, false, seed, node),
input.sample(false, 0.3, seed).collect()
)
}
testSample(withReplacement = true)
testSample(withReplacement = false)
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class TakeOrderedAndProjectNodeSuite extends LocalNodeTest {

import testImplicits._

def columnToSortOrder(sortExprs: Column*): Seq[SortOrder] = {
private def columnToSortOrder(sortExprs: Column*): Seq[SortOrder] = {
val sortOrder: Seq[SortOrder] = sortExprs.map { col =>
col.expr match {
case expr: SortOrder =>
Expand All @@ -36,22 +36,19 @@ class TakeOrderedAndProjectNodeSuite extends LocalNodeTest {
sortOrder
}

test("asc") {
val input = (1 to 10).map(i => (i, i.toString)).toDF("key", "value")
checkAnswer(
input,
node => TakeOrderedAndProjectNode(conf, 5, columnToSortOrder(input.col("key")), None, node),
input.sort(input.col("key")).limit(5).collect()
)
private def testTakeOrderedAndProjectNode(desc: Boolean): Unit = {
val testCaseName = if (desc) "desc" else "asc"
test(testCaseName) {
val input = (1 to 10).map(i => (i, i.toString)).toDF("key", "value")
val sortColumn = if (desc) input.col("key").desc else input.col("key")
checkAnswer(
input,
node => TakeOrderedAndProjectNode(conf, 5, columnToSortOrder(sortColumn), None, node),
input.sort(sortColumn).limit(5).collect()
)
}
}

test("desc") {
val input = (1 to 10).map(i => (i, i.toString)).toDF("key", "value")
checkAnswer(
input,
node =>
TakeOrderedAndProjectNode(conf, 5, columnToSortOrder(input.col("key").desc), None, node),
input.sort(input.col("key").desc).limit(5).collect()
)
}
testTakeOrderedAndProjectNode(desc = false)
testTakeOrderedAndProjectNode(desc = true)
}

0 comments on commit a3270b0

Please sign in to comment.