Skip to content

Commit

Permalink
[SPARK-9992] [SPARK-9994] [SPARK-9998] [SQL] Implement the local TopK…
Browse files Browse the repository at this point in the history
…, sample and intersect operators

This PR is in conflict with #8535. I will update this one when #8535 gets merged.

Author: zsxwing <zsxwing@gmail.com>

Closes #8573 from zsxwing/more-local-operators.
  • Loading branch information
zsxwing authored and Andrew Or committed Sep 11, 2015
1 parent 1eede3b commit e626ac5
Show file tree
Hide file tree
Showing 8 changed files with 353 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode {
* will be ub - lb.
* @param withReplacement Whether to sample with replacement.
* @param seed the random seed
* @param child the QueryPlan
* @param child the SparkPlan
*/
@DeveloperApi
case class Sample(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.local

import scala.collection.mutable

import org.apache.spark.sql.SQLConf
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute

case class IntersectNode(conf: SQLConf, left: LocalNode, right: LocalNode)
extends BinaryLocalNode(conf) {

override def output: Seq[Attribute] = left.output

private[this] var leftRows: mutable.HashSet[InternalRow] = _

private[this] var currentRow: InternalRow = _

override def open(): Unit = {
left.open()
leftRows = mutable.HashSet[InternalRow]()
while (left.next()) {
leftRows += left.fetch().copy()
}
left.close()
right.open()
}

override def next(): Boolean = {
currentRow = null
while (currentRow == null && right.next()) {
currentRow = right.fetch()
if (!leftRows.contains(currentRow)) {
currentRow = null
}
}
currentRow != null
}

override def fetch(): InternalRow = currentRow

override def close(): Unit = {
left.close()
right.close()
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,11 @@ abstract class LocalNode(conf: SQLConf) extends TreeNode[LocalNode] with Logging
*/
def close(): Unit

/**
* Returns the content through the [[Iterator]] interface.
*/
final def asIterator: Iterator[InternalRow] = new LocalNodeIterator(this)

/**
* Returns the content of the iterator from the beginning to the end in the form of a Scala Seq.
*/
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.local

import java.util.Random

import org.apache.spark.sql.SQLConf
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.util.random.{BernoulliCellSampler, PoissonSampler}

/**
* Sample the dataset.
*
* @param conf the SQLConf
* @param lowerBound Lower-bound of the sampling probability (usually 0.0)
* @param upperBound Upper-bound of the sampling probability. The expected fraction sampled
* will be ub - lb.
* @param withReplacement Whether to sample with replacement.
* @param seed the random seed
* @param child the LocalNode
*/
case class SampleNode(
conf: SQLConf,
lowerBound: Double,
upperBound: Double,
withReplacement: Boolean,
seed: Long,
child: LocalNode) extends UnaryLocalNode(conf) {

override def output: Seq[Attribute] = child.output

private[this] var iterator: Iterator[InternalRow] = _

private[this] var currentRow: InternalRow = _

override def open(): Unit = {
child.open()
val (sampler, _seed) = if (withReplacement) {
val random = new Random(seed)
// Disable gap sampling since the gap sampling method buffers two rows internally,
// requiring us to copy the row, which is more expensive than the random number generator.
(new PoissonSampler[InternalRow](upperBound - lowerBound, useGapSamplingIfPossible = false),
// Use the seed for partition 0 like PartitionwiseSampledRDD to generate the same result
// of DataFrame
random.nextLong())
} else {
(new BernoulliCellSampler[InternalRow](lowerBound, upperBound), seed)
}
sampler.setSeed(_seed)
iterator = sampler.sample(child.asIterator)
}

override def next(): Boolean = {
if (iterator.hasNext) {
currentRow = iterator.next()
true
} else {
false
}
}

override def fetch(): InternalRow = currentRow

override def close(): Unit = child.close()

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.local

import org.apache.spark.sql.SQLConf
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.util.BoundedPriorityQueue

case class TakeOrderedAndProjectNode(
conf: SQLConf,
limit: Int,
sortOrder: Seq[SortOrder],
projectList: Option[Seq[NamedExpression]],
child: LocalNode) extends UnaryLocalNode(conf) {

private[this] var projection: Option[Projection] = _
private[this] var ord: InterpretedOrdering = _
private[this] var iterator: Iterator[InternalRow] = _
private[this] var currentRow: InternalRow = _

override def output: Seq[Attribute] = {
val projectOutput = projectList.map(_.map(_.toAttribute))
projectOutput.getOrElse(child.output)
}

override def open(): Unit = {
child.open()
projection = projectList.map(new InterpretedProjection(_, child.output))
ord = new InterpretedOrdering(sortOrder, child.output)
// Priority keeps the largest elements, so let's reverse the ordering.
val queue = new BoundedPriorityQueue[InternalRow](limit)(ord.reverse)
while (child.next()) {
queue += child.fetch()
}
// Close it eagerly since we don't need it.
child.close()
iterator = queue.iterator
}

override def next(): Boolean = {
if (iterator.hasNext) {
val _currentRow = iterator.next()
currentRow = projection match {
case Some(p) => p(_currentRow)
case None => _currentRow
}
true
} else {
false
}
}

override def fetch(): InternalRow = currentRow

override def close(): Unit = child.close()

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.local

class IntersectNodeSuite extends LocalNodeTest {

import testImplicits._

test("basic") {
val input1 = (1 to 10).map(i => (i, i.toString)).toDF("key", "value")
val input2 = (1 to 10).filter(_ % 2 == 0).map(i => (i, i.toString)).toDF("key", "value")

checkAnswer2(
input1,
input2,
(node1, node2) => IntersectNode(conf, node1, node2),
input1.intersect(input2).collect()
)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.local

class SampleNodeSuite extends LocalNodeTest {

import testImplicits._

private def testSample(withReplacement: Boolean): Unit = {
test(s"withReplacement: $withReplacement") {
val seed = 0L
val input = sqlContext.sparkContext.
parallelize((1 to 10).map(i => (i, i.toString)), 1). // Should be only 1 partition
toDF("key", "value")
checkAnswer(
input,
node => SampleNode(conf, 0.0, 0.3, withReplacement, seed, node),
input.sample(withReplacement, 0.3, seed).collect()
)
}
}

testSample(withReplacement = true)
testSample(withReplacement = false)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.local

import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.expressions.{Ascending, Expression, SortOrder}

class TakeOrderedAndProjectNodeSuite extends LocalNodeTest {

import testImplicits._

private def columnToSortOrder(sortExprs: Column*): Seq[SortOrder] = {
val sortOrder: Seq[SortOrder] = sortExprs.map { col =>
col.expr match {
case expr: SortOrder =>
expr
case expr: Expression =>
SortOrder(expr, Ascending)
}
}
sortOrder
}

private def testTakeOrderedAndProjectNode(desc: Boolean): Unit = {
val testCaseName = if (desc) "desc" else "asc"
test(testCaseName) {
val input = (1 to 10).map(i => (i, i.toString)).toDF("key", "value")
val sortColumn = if (desc) input.col("key").desc else input.col("key")
checkAnswer(
input,
node => TakeOrderedAndProjectNode(conf, 5, columnToSortOrder(sortColumn), None, node),
input.sort(sortColumn).limit(5).collect()
)
}
}

testTakeOrderedAndProjectNode(desc = false)
testTakeOrderedAndProjectNode(desc = true)
}

0 comments on commit e626ac5

Please sign in to comment.