Commit
Merge in a sketch of a unit test for the new sorter (now failing).
JoshRosen committed Jul 6, 2015
1 parent 2bd8c9a commit 58f36d0
Showing 6 changed files with 348 additions and 3 deletions.
UnsafeExternalSorter.java
@@ -232,4 +232,4 @@ public UnsafeSorterIterator getSortedIterator() throws IOException {
spillMerger.addSpill(sorter.getSortedIterator());
return spillMerger.getSortedIterator();
}
}
}
UnsafeSorterSpillMerger.java
@@ -26,8 +26,8 @@ final class UnsafeSorterSpillMerger {
private final PriorityQueue<UnsafeSorterIterator> priorityQueue;

public UnsafeSorterSpillMerger(
final RecordComparator recordComparator,
final PrefixComparator prefixComparator) {
final RecordComparator recordComparator,
final PrefixComparator prefixComparator) {
final Comparator<UnsafeSorterIterator> comparator = new Comparator<UnsafeSorterIterator>() {

@Override
UnsafeSortMergeJoin.scala
@@ -0,0 +1,190 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.joins

import java.util.NoSuchElementException

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.{UnsafeExternalSort, BinaryNode, SparkPlan}
import org.apache.spark.util.collection.CompactBuffer

/**
* :: DeveloperApi ::
* Performs a sort merge join of two child relations.
* TODO(josh): Document
*/
@DeveloperApi
case class UnsafeSortMergeJoin(
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
left: SparkPlan,
right: SparkPlan) extends BinaryNode {

override def output: Seq[Attribute] = left.output ++ right.output

override def outputPartitioning: Partitioning = left.outputPartitioning

override def requiredChildDistribution: Seq[Distribution] =
ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil

// Manually construct an ordering that can be used to compare keys from both sides.
private val keyOrdering: RowOrdering = RowOrdering.forSchema(leftKeys.map(_.dataType))

override def outputOrdering: Seq[SortOrder] = requiredOrders(leftKeys)

@transient protected lazy val leftKeyGenerator = newProjection(leftKeys, left.output)
@transient protected lazy val rightKeyGenerator = newProjection(rightKeys, right.output)

private def requiredOrders(keys: Seq[Expression]): Seq[SortOrder] =
keys.map(SortOrder(_, Ascending))

protected override def doExecute(): RDD[Row] = {
// Note that we purposely do not require our input to be sorted. Instead, we'll sort it
// ourselves using UnsafeExternalSorter. Not requiring the input to be sorted prevents the
// Exchange from pushing the sort into the shuffle, which allows the shuffle to benefit from
// Project Tungsten's shuffle optimizations, which currently cannot be applied to shuffles
// that specify a key ordering.
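// For contrast (a sketch of the alternative, not code in this commit): a conventional
// sort merge join would instead override requiredChildOrdering, e.g.
//   override def requiredChildOrdering: Seq[Seq[SortOrder]] =
//     requiredOrders(leftKeys) :: requiredOrders(rightKeys) :: Nil
// which would let the planner push the sorts into the Exchange; omitting it keeps the
// shuffle eligible for Tungsten's optimized path.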

// Only sort if necessary:
val leftOrder = requiredOrders(leftKeys)
val leftResults = {
if (left.outputOrdering == leftOrder) {
left.execute().map(_.copy())
} else {
new UnsafeExternalSort(leftOrder, global = false, left).execute()
}
}
val rightOrder = requiredOrders(rightKeys)
val rightResults = {
if (right.outputOrdering == rightOrder) {
right.execute().map(_.copy())
} else {
new UnsafeExternalSort(rightOrder, global = false, right).execute()
}
}

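// Both sides are clustered on the join keys and sorted within each partition, so we can
// zip the partitions pairwise and merge each pair with a streaming join iterator.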
leftResults.zipPartitions(rightResults) { (leftIter, rightIter) =>
new Iterator[Row] {
// Mutable per-row state for the streaming merge.
private[this] val joinRow = new JoinedRow5
private[this] var leftElement: Row = _ // current row from the left iterator
private[this] var rightElement: Row = _ // current row from the right iterator
private[this] var leftKey: Row = _ // join key projected from leftElement
private[this] var rightKey: Row = _ // join key projected from rightElement
private[this] var rightMatches: CompactBuffer[Row] = _ // buffered right rows for the current match key
private[this] var rightPosition: Int = -1 // cursor into rightMatches
private[this] var stop: Boolean = false // search/termination flag (see nextMatchingPair)
private[this] var matchKey: Row = _ // key for which rightMatches was buffered

// initialize iterator
initialize()
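
// Note: hasNext below does the real work. It advances both inputs to the next key match
// and buffers the matching right-side rows; repeated calls are safe because the stop flag
// and the existing rightMatches buffer short-circuit the search.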

override final def hasNext: Boolean = nextMatchingPair()

override final def next(): Row = {
if (hasNext) {
// join the current left row with the next buffered right row; once the buffer is
// exhausted, advance the left iterator
val joinedRow = joinRow(leftElement, rightMatches(rightPosition))
rightPosition += 1
if (rightPosition >= rightMatches.size) {
rightPosition = 0
fetchLeft()
if (leftElement == null || keyOrdering.compare(leftKey, matchKey) != 0) {
stop = false
rightMatches = null
}
}
joinedRow
} else {
// no more results
throw new NoSuchElementException
}
}

private def fetchLeft() = {
if (leftIter.hasNext) {
leftElement = leftIter.next()
leftKey = leftKeyGenerator(leftElement)
} else {
leftElement = null
}
}

private def fetchRight() = {
if (rightIter.hasNext) {
rightElement = rightIter.next()
rightKey = rightKeyGenerator(rightElement)
} else {
rightElement = null
}
}

private def initialize() = {
fetchLeft()
fetchRight()
}

/**
* Searches the right iterator for the next rows that match a row on the left side, and
* stores them in a buffer.
*
* @return true if the search is successful, and false if the right iterator runs out of
* tuples.
*/
private def nextMatchingPair(): Boolean = {
if (!stop && rightElement != null) {
// advance both sides to find the first matching pair
while (!stop && leftElement != null && rightElement != null) {
val comparing = keyOrdering.compare(leftKey, rightKey)
// for inner join, we need to filter out rows with null join keys
stop = comparing == 0 && !leftKey.anyNull
if (comparing > 0 || rightKey.anyNull) {
fetchRight()
} else if (comparing < 0 || leftKey.anyNull) {
fetchLeft()
}
}
rightMatches = new CompactBuffer[Row]()
if (stop) {
stop = false
// iterate over the right side to buffer all rows that match;
// since the records are ordered, exit when we meet the first row that does not match
while (!stop && rightElement != null) {
rightMatches += rightElement
fetchRight()
stop = keyOrdering.compare(leftKey, rightKey) != 0
}
if (rightMatches.size > 0) {
rightPosition = 0
matchKey = leftKey
}
}
}
rightMatches != null && rightMatches.size > 0
}
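
// Worked example (hypothetical data): with left keys [1, 2, 2, 3] and right keys
// [2, 2, 4], nextMatchingPair() skips left key 1, stops at the first matching pair of
// 2s, and buffers rightMatches = [2, 2]. next() then joins each of the two left rows
// with key 2 against the buffer (four output rows) before the search resumes with left
// key 3 and right key 4, which never match, so hasNext returns false.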
}
}
}
}
UnsafeSortMergeJoinSuite.scala
@@ -0,0 +1,52 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql

import org.apache.spark.sql.TestData._
import org.apache.spark.sql.execution.UnsafeExternalSort
import org.apache.spark.sql.execution.joins._
import org.apache.spark.sql.test.TestSQLContext._
import org.apache.spark.sql.test.TestSQLContext.implicits._
import org.scalatest.BeforeAndAfterEach

class UnsafeSortMergeJoinSuite extends QueryTest with BeforeAndAfterEach {
// Ensures tables are loaded.
TestData

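// Force the planner down the unsafe sort-merge join path: enable sort merge join,
// codegen, unsafe operations, and external sort, and disable broadcast joins by setting
// the auto-broadcast threshold to -1.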
conf.setConf(SQLConf.SORTMERGE_JOIN, "true")
conf.setConf(SQLConf.CODEGEN_ENABLED, "true")
conf.setConf(SQLConf.UNSAFE_ENABLED, "true")
conf.setConf(SQLConf.EXTERNAL_SORT, "true")
conf.setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, "-1")

test("basic sort merge join test") {
val df = upperCaseData.join(lowerCaseData, $"n" === $"N")
print(df.queryExecution.optimizedPlan)
assert(df.queryExecution.sparkPlan.collect {
case smj: UnsafeSortMergeJoin => smj
}.nonEmpty)
checkAnswer(
df,
Seq(
Row(1, "A", 1, "a"),
Row(2, "B", 2, "b"),
Row(3, "C", 3, "c"),
Row(4, "D", 4, "d")
))
}
}
UnsafeExternalSortSuite.scala
@@ -0,0 +1,62 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution

import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.scalatest.{FunSuite, Matchers}

import org.apache.spark.sql.{SQLConf, SQLContext, Row}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.TestSQLContext.implicits._

class UnsafeExternalSortSuite extends FunSuite with Matchers {

private def createRow(values: Any*): Row = {
new GenericRow(values.map(CatalystTypeConverters.convertToCatalyst).toArray)
}

test("basic sorting") {
val sc = TestSQLContext.sparkContext
val sqlContext = new SQLContext(sc)
sqlContext.conf.setConf(SQLConf.CODEGEN_ENABLED, "true")

val schema: StructType = StructType(
StructField("word", StringType, nullable = false) ::
StructField("number", IntegerType, nullable = false) :: Nil)
val sortOrder: Seq[SortOrder] = Seq(
SortOrder(BoundReference(0, StringType, nullable = false), Ascending),
SortOrder(BoundReference(1, IntegerType, nullable = false), Descending))
val rowsToSort: Seq[Row] = Seq(
createRow("Hello", 9),
createRow("World", 4),
createRow("Hello", 7),
createRow("Skinny", 0),
createRow("Constantinople", 9))
SparkPlan.currentContext.set(sqlContext)
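// SparkPlan.currentContext is a thread-local that supplies the SQLContext to operators
// constructed directly, outside of normal query planning, as done below.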
val input =
new PhysicalRDD(schema.toAttributes.map(_.toAttribute), sc.parallelize(rowsToSort, 1))
// Treat the existing sort operators as the source-of-truth for this test
val defaultSorted = new Sort(sortOrder, global = false, input).executeCollect()
val externalSorted = new ExternalSort(sortOrder, global = false, input).executeCollect()
val unsafeSorted = new UnsafeExternalSort(sortOrder, global = false, input).executeCollect()
assert(defaultSorted === externalSorted)
assert(unsafeSorted === externalSorted)
}
}
UnsafeSortMergeCompatibilitySuite.scala
@@ -0,0 +1,41 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.hive.execution

import org.apache.spark.sql.SQLConf
import org.apache.spark.sql.hive.test.TestHive

/**
* Runs the test cases that are included in the Hive distribution, with sort merge join
* and unsafe external sort enabled.
*/
class UnsafeSortMergeCompatibilitySuite extends SortMergeCompatibilitySuite {
override def beforeAll() {
super.beforeAll()
TestHive.setConf(SQLConf.CODEGEN_ENABLED, "true")
TestHive.setConf(SQLConf.UNSAFE_ENABLED, "true")
TestHive.setConf(SQLConf.EXTERNAL_SORT, "true")
}

override def afterAll() {
TestHive.setConf(SQLConf.CODEGEN_ENABLED, "false")
TestHive.setConf(SQLConf.UNSAFE_ENABLED, "false")
TestHive.setConf(SQLConf.EXTERNAL_SORT, "false")
super.afterAll()
}
}
