Port test to SparkPlanTest
JoshRosen committed Jul 6, 2015
1 parent d468a88 commit 7eafecf
Showing 3 changed files with 29 additions and 80 deletions.
@@ -269,7 +269,7 @@ case class UnsafeExternalSort(
     if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil
 
   protected override def doExecute(): RDD[InternalRow] = attachTree(this, "sort") {
-    assert (codegenEnabled)
+    assert(codegenEnabled, "UnsafeExternalSort requires code generation to be enabled")
     def doSort(iterator: Iterator[InternalRow]): Iterator[InternalRow] = {
       val ordering = newOrdering(sortOrder, child.output)
       val prefixComparator = new PrefixComparator {

This file was deleted.

@@ -17,47 +17,48 @@
 
 package org.apache.spark.sql.execution
 
-import org.scalatest.Matchers
-
-import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.test.TestSQLContext
+import org.scalatest.BeforeAndAfterAll
 
-import org.apache.spark.sql.{SQLConf, SQLContext, Row}
+import org.apache.spark.sql.{SQLConf, Row}
 import org.apache.spark.sql.catalyst.CatalystTypeConverters
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.types._
+import org.apache.spark.sql.test.TestSQLContext
 
-class UnsafeExternalSortSuite extends SparkFunSuite with Matchers {
+import scala.util.Random
+
+class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll {
 
+  override def beforeAll(): Unit = {
+    TestSQLContext.conf.setConf(SQLConf.CODEGEN_ENABLED, true)
+  }
+
+  override def afterAll(): Unit = {
+    TestSQLContext.conf.setConf(SQLConf.CODEGEN_ENABLED, SQLConf.CODEGEN_ENABLED.defaultValue.get)
+  }
+
-  private def createRow(values: Any*): Row = {
-    new GenericRow(values.map(CatalystTypeConverters.convertToCatalyst).toArray)
-  }
-
   test("basic sorting") {
-    val sc = TestSQLContext.sparkContext
-    val sqlContext = new SQLContext(sc)
-    sqlContext.conf.setConf(SQLConf.CODEGEN_ENABLED, "true")
-
-    val schema: StructType = StructType(
-      StructField("word", StringType, nullable = false) ::
-      StructField("number", IntegerType, nullable = false) :: Nil)
+    val inputData = Seq(
+      ("Hello", 9),
+      ("World", 4),
+      ("Hello", 7),
+      ("Skinny", 0),
+      ("Constantinople", 9)
+    )
+
     val sortOrder: Seq[SortOrder] = Seq(
       SortOrder(BoundReference(0, StringType, nullable = false), Ascending),
       SortOrder(BoundReference(1, IntegerType, nullable = false), Descending))
-    val rowsToSort: Seq[Row] = Seq(
-      createRow("Hello", 9),
-      createRow("World", 4),
-      createRow("Hello", 7),
-      createRow("Skinny", 0),
-      createRow("Constantinople", 9))
-    SparkPlan.currentContext.set(sqlContext)
-    val input =
-      new PhysicalRDD(schema.toAttributes.map(_.toAttribute), sc.parallelize(rowsToSort, 1))
-    // Treat the existing sort operators as the source-of-truth for this test
-    val defaultSorted = new Sort(sortOrder, global = false, input).executeCollect()
-    val externalSorted = new ExternalSort(sortOrder, global = false, input).executeCollect()
-    val unsafeSorted = new UnsafeExternalSort(sortOrder, global = false, input).executeCollect()
-    assert (defaultSorted === externalSorted)
-    assert (unsafeSorted === externalSorted)
 
+    checkAnswer(
+      Random.shuffle(inputData),
+      (input: SparkPlan) => new UnsafeExternalSort(sortOrder, global = false, input),
+      inputData
+    )
   }
 }
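
For context, SparkPlanTest itself is not shown in this diff. The following is a rough, hypothetical sketch of the checkAnswer pattern the ported test relies on, inferred only from the call site above; the class name SparkPlanTestSketch, the exact signature, and the row-conversion and comparison details are assumptions, not the actual Spark implementation.

package org.apache.spark.sql.execution

import scala.reflect.runtime.universe.TypeTag

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.Row
import org.apache.spark.sql.test.TestSQLContext

// Hypothetical stand-in for SparkPlanTest, reconstructed from the checkAnswer call above.
abstract class SparkPlanTestSketch extends SparkFunSuite {

  /**
   * Runs the physical operator produced by `planFunction` on top of a plan that emits
   * `input`, then compares the collected result against `expectedAnswer`. The comparison
   * is assumed to be order-insensitive, which is why the test above can pass the
   * unsorted inputData as its expected answer.
   */
  protected def checkAnswer[A <: Product : TypeTag](
      input: Seq[A],
      planFunction: SparkPlan => SparkPlan,
      expectedAnswer: Seq[A]): Unit = {
    // Mirror the old test: directly-constructed operators need an active SQLContext.
    SparkPlan.currentContext.set(TestSQLContext)
    // Materialize the tuples as a DataFrame so there is a child physical plan to wrap.
    val inputDf = TestSQLContext.createDataFrame(input)
    // Wrap that child plan in the operator under test and execute it.
    val outputPlan = planFunction(inputDf.queryExecution.sparkPlan)
    val sparkAnswer = outputPlan.executeCollect().toSeq
    // Order-insensitive comparison via a canonical string form of each row (assumption).
    val expectedRows = expectedAnswer.map(Row.fromTuple)
    assert(
      sparkAnswer.map(_.toString).sorted === expectedRows.map(_.toString).sorted,
      s"Plan $outputPlan did not produce the expected rows")
  }
}

Under this reading, porting the suite removes the hand-built PhysicalRDD, schema, and createRow plumbing from each test, leaving only the input data, the operator under test, and the expected answer.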
