[SPARK-10613] [SPARK-10624] [SQL] Reduce LocalNode tests dependency on SQLContext #8764

Closed · wants to merge 10 commits
@@ -24,7 +24,7 @@ import org.apache.spark.sql.{SQLConf, Row}
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.trees.TreeNode
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.types.StructType

/**
@@ -33,18 +33,14 @@ import org.apache.spark.sql.types.StructType
* Before consuming the iterator, open function must be called.
* After consuming the iterator, close function must be called.
*/
abstract class LocalNode(conf: SQLConf) extends TreeNode[LocalNode] with Logging {
abstract class LocalNode(conf: SQLConf) extends QueryPlan[LocalNode] with Logging {

protected val codegenEnabled: Boolean = conf.codegenEnabled

protected val unsafeEnabled: Boolean = conf.unsafeEnabled

lazy val schema: StructType = StructType.fromAttributes(output)

private[this] lazy val isTesting: Boolean = sys.props.contains("spark.testing")

def output: Seq[Attribute]

/**
* Initializes the iterator state. Must be called before calling `next()`.
*
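The switch from TreeNode[LocalNode] to QueryPlan[LocalNode] is what lets the rewritten suites further down resolve expressions without a SQLContext: QueryPlan keeps track of the expressions a node holds and exposes transformExpressions to rewrite them. Below is a sketch of what the resolveExpressions helper used by those suites presumably looks like; it is a reconstruction mirroring the SparkPlanTest attribute-resolution code later in this diff, not a verbatim part of the patch.

import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions.Attribute

// Bind the UnresolvedAttributes produced by the expression DSL to the
// attributes emitted by each node's children.
def resolveExpressions(plan: LocalNode): LocalNode = {
  plan transform {
    case node: LocalNode =>
      val inputMap: Map[String, Attribute] =
        node.children.flatMap(_.output).map(a => a.name -> a).toMap
      node transformExpressions {
        case UnresolvedAttribute(Seq(u)) =>
          inputMap.getOrElse(u, sys.error(s"Cannot resolve $u given input $inputMap"))
      }
  }
}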
@@ -17,13 +17,12 @@

package org.apache.spark.sql.execution.local

import java.util.Random

import org.apache.spark.sql.SQLConf
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.util.random.{BernoulliCellSampler, PoissonSampler}


/**
* Sample the dataset.
*
@@ -51,18 +50,15 @@ case class SampleNode(

override def open(): Unit = {
child.open()
val (sampler, _seed) = if (withReplacement) {
val random = new Random(seed)
val sampler =
if (withReplacement) {
// Disable gap sampling since the gap sampling method buffers two rows internally,
// requiring us to copy the row, which is more expensive than the random number generator.
(new PoissonSampler[InternalRow](upperBound - lowerBound, useGapSamplingIfPossible = false),
// Use the seed for partition 0 like PartitionwiseSampledRDD to generate the same result
// of DataFrame
random.nextLong())
new PoissonSampler[InternalRow](upperBound - lowerBound, useGapSamplingIfPossible = false)
} else {
(new BernoulliCellSampler[InternalRow](lowerBound, upperBound), seed)
new BernoulliCellSampler[InternalRow](lowerBound, upperBound)
}
sampler.setSeed(_seed)
sampler.setSeed(seed)
Review comment from the patch author (Contributor), attached to the setSeed line above:
@zsxwing I had to remove this to make testing deterministic. Looking at this further I still don't see the point of introducing another layer of randomness here. What change in behavior does this entail?

Reply from a reviewer (Member):
I was using DataFrame.sample to test SampleNode, so this logic mimicked the behavior of DataFrame.sample(withReplacement = true). Since you no longer use DataFrame to test it, I agree that we can remove this tricky logic.

iterator = sampler.sample(child.asIterator)
}
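For context on the comment thread above: before this patch, the with-replacement path seeded the Poisson sampler not with seed itself but with a value derived from it, so that SampleNode reproduced DataFrame.sample row-for-row. A small plain-Scala sketch of the two seeding strategies, reconstructed from the removed lines:

import java.util.Random

val seed = 42L

// Old behavior: mimic PartitionwiseSampledRDD, which seeds partition 0 with
// the first long drawn from a Random built on `seed`, so the node matched
// DataFrame.sample(withReplacement = true) exactly.
val oldPoissonSeed = new Random(seed).nextLong()

// New behavior: seed the sampler directly. Output is still deterministic for
// a fixed seed; it just no longer matches DataFrame.sample bit-for-bit.
val newPoissonSeed = seed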

@@ -50,7 +50,7 @@ case class TakeOrderedAndProjectNode(
}
// Close it eagerly since we don't need it.
child.close()
iterator = queue.iterator
iterator = queue.toArray.sorted(ord).iterator
}

override def next(): Boolean = {
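The likely reason for sorting before exposing the iterator (an inference; the diff itself does not state it): queue here is a bounded priority queue, and a priority queue's iterator walks its internal heap order rather than priority order, so handing it out directly would not return the top rows sorted by ord. A standalone illustration in plain Scala:

import scala.collection.mutable

val ord = Ordering.Int
val queue = mutable.PriorityQueue(5, 1, 4, 2, 3)(ord)

// The queue's iterator makes no ordering guarantee...
val viaIterator = queue.iterator.toSeq
// ...so an explicit sort is what guarantees ordered output, which is what
// queue.toArray.sorted(ord).iterator does in the patched code.
val viaSort = queue.toArray.sorted(ord).toSeq
assert(viaSort == Seq(1, 2, 3, 4, 5))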
@@ -238,7 +238,7 @@ object SparkPlanTest {
outputPlan transform {
case plan: SparkPlan =>
val inputMap = plan.children.flatMap(_.output).map(a => (a.name, a)).toMap
plan.transformExpressions {
plan transformExpressions {
case UnresolvedAttribute(Seq(u)) =>
inputMap.getOrElse(u,
sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap"))
@@ -0,0 +1,68 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.local

import org.apache.spark.sql.SQLConf
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

/**
* A dummy [[LocalNode]] that just returns rows from a [[LocalRelation]].
*/
private[local] case class DummyNode(
output: Seq[Attribute],
relation: LocalRelation,
conf: SQLConf)
extends LocalNode(conf) {

import DummyNode._

private var index: Int = CLOSED
private val input: Seq[InternalRow] = relation.data

def this(output: Seq[Attribute], data: Seq[Product], conf: SQLConf = new SQLConf) {
this(output, LocalRelation.fromProduct(output, data), conf)
}

def isOpen: Boolean = index != CLOSED

override def children: Seq[LocalNode] = Seq.empty

override def open(): Unit = {
index = -1
}

override def next(): Boolean = {
index += 1
index < input.size
}

override def fetch(): InternalRow = {
assert(index >= 0 && index < input.size)
input(index)
}

override def close(): Unit = {
index = CLOSED
}
}

private object DummyNode {
val CLOSED: Int = Int.MinValue
}
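A hypothetical usage sketch for the new test node, based only on the open()/next()/fetch()/close() contract above. kvIntAttributes is assumed to be a Seq[Attribute] fixture provided by LocalNodeTest, as used by the suites below.

val node = new DummyNode(kvIntAttributes, (1 to 5).map(i => (i, i * 10)))

node.open()
assert(node.isOpen)
val rows = scala.collection.mutable.ArrayBuffer.empty[(Int, Int)]
while (node.next()) {
  val row = node.fetch()
  rows += ((row.getInt(0), row.getInt(1)))
}
node.close()
assert(rows.toSeq == (1 to 5).map(i => (i, i * 10)))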
@@ -17,35 +17,33 @@

package org.apache.spark.sql.execution.local

import org.apache.spark.sql.catalyst.dsl.expressions._


class ExpandNodeSuite extends LocalNodeTest {

import testImplicits._

test("expand") {
val input = Seq((1, 1), (2, 2), (3, 3), (4, 4), (5, 5)).toDF("key", "value")
checkAnswer(
input,
node =>
ExpandNode(conf, Seq(
Seq(
input.col("key") + input.col("value"), input.col("key") - input.col("value")
).map(_.expr),
Seq(
input.col("key") * input.col("value"), input.col("key") / input.col("value")
).map(_.expr)
), node.output, node),
Seq(
(2, 0),
(1, 1),
(4, 0),
(4, 1),
(6, 0),
(9, 1),
(8, 0),
(16, 1),
(10, 0),
(25, 1)
).toDF().collect()
)
private def testExpand(inputData: Array[(Int, Int)] = Array.empty): Unit = {
val inputNode = new DummyNode(kvIntAttributes, inputData)
val projections = Seq(Seq('k + 'v, 'k - 'v), Seq('k * 'v, 'k / 'v))
val expandNode = new ExpandNode(conf, projections, inputNode.output, inputNode)
val resolvedNode = resolveExpressions(expandNode)
val expectedOutput = {
val firstHalf = inputData.map { case (k, v) => (k + v, k - v) }
val secondHalf = inputData.map { case (k, v) => (k * v, k / v) }
firstHalf ++ secondHalf
}
val actualOutput = resolvedNode.collect().map { case row =>
(row.getInt(0), row.getInt(1))
}
assert(actualOutput.toSet === expectedOutput.toSet)
}

test("empty") {
testExpand()
}

test("basic") {
testExpand((1 to 100).map { i => (i, i * 1000) }.toArray)
}

}
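Both new suites build their input through a kvIntAttributes fixture and the resolveExpressions helper, which live in LocalNodeTest and are not shown in this excerpt. As an assumption about its shape, kvIntAttributes is presumably a pair of integer attributes named to line up with the 'k and 'v DSL symbols, roughly:

import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.types.IntegerType

// Assumed LocalNodeTest member: two int columns named "k" and "v".
val kvIntAttributes: Seq[Attribute] = Seq(
  AttributeReference("k", IntegerType)(),
  AttributeReference("v", IntegerType)())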
@@ -17,25 +17,29 @@

package org.apache.spark.sql.execution.local

import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.catalyst.dsl.expressions._

class FilterNodeSuite extends LocalNodeTest with SharedSQLContext {

test("basic") {
val condition = (testData.col("key") % 2) === 0
checkAnswer(
testData,
node => FilterNode(conf, condition.expr, node),
testData.filter(condition).collect()
)
class FilterNodeSuite extends LocalNodeTest {

private def testFilter(inputData: Array[(Int, Int)] = Array.empty): Unit = {
val cond = 'k % 2 === 0
val inputNode = new DummyNode(kvIntAttributes, inputData)
val filterNode = new FilterNode(conf, cond, inputNode)
val resolvedNode = resolveExpressions(filterNode)
val expectedOutput = inputData.filter { case (k, _) => k % 2 == 0 }
val actualOutput = resolvedNode.collect().map { case row =>
(row.getInt(0), row.getInt(1))
}
assert(actualOutput === expectedOutput)
}

test("empty") {
val condition = (emptyTestData.col("key") % 2) === 0
checkAnswer(
emptyTestData,
node => FilterNode(conf, condition.expr, node),
emptyTestData.filter(condition).collect()
)
testFilter()
}

test("basic") {
testFilter((1 to 100).map { i => (i, i) }.toArray)
}

}
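One detail worth making explicit: a condition like 'k % 2 === 0 built with the catalyst expression DSL is only an unresolved expression tree; it cannot be evaluated until resolveExpressions binds it to the DummyNode's output, which is why the suite wraps the node before collecting. For illustration, assuming the same DSL import used above:

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.Expression

// 'k % 2 === 0 desugars to roughly
// EqualTo(Remainder(UnresolvedAttribute("k"), Literal(2)), Literal(0)).
val cond: Expression = 'k % 2 === 0
// The attribute "k" is still unresolved, so the predicate cannot be evaluated yet.
assert(!cond.resolved)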