
Test.
yhuai committed Apr 14, 2015
1 parent 8297732 commit 43b9fb4
Showing 1 changed file with 32 additions and 66 deletions.
@@ -19,14 +19,15 @@ package org.apache.spark.sql.execution
 
 import java.sql.{Timestamp, Date}
 
-import org.apache.spark.serializer.Serializer
-import org.apache.spark.{SparkEnv, SparkConf, ShuffleDependency, SparkContext}
+import org.scalatest.{FunSuite, BeforeAndAfterAll}
+
 import org.apache.spark.rdd.ShuffledRDD
+import org.apache.spark.serializer.Serializer
+import org.apache.spark.ShuffleDependency
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.Row
-import org.scalatest.{FunSuite, BeforeAndAfterAll}
-
-import org.apache.spark.sql.{MyDenseVectorUDT, SQLContext, QueryTest}
+import org.apache.spark.sql.test.TestSQLContext._
+import org.apache.spark.sql.{MyDenseVectorUDT, QueryTest}
 
 class SparkSqlSerializer2DataTypeSuite extends FunSuite {
   // Make sure that we will not use serializer2 for unsupported data types.
@@ -67,18 +68,17 @@ class SparkSqlSerializer2DataTypeSuite extends FunSuite {
 }
 
 abstract class SparkSqlSerializer2Suite extends QueryTest with BeforeAndAfterAll {
-
-  @transient var sparkContext: SparkContext = _
-  @transient var sqlContext: SQLContext = _
-  // We may have an existing SparkEnv (e.g. the one used by TestSQLContext).
-  @transient val existingSparkEnv = SparkEnv.get
   var allColumns: String = _
   val serializerClass: Class[Serializer] =
     classOf[SparkSqlSerializer2].asInstanceOf[Class[Serializer]]
+  var numShufflePartitions: Int = _
+  var useSerializer2: Boolean = _
 
   override def beforeAll(): Unit = {
-    sqlContext.sql("set spark.sql.shuffle.partitions=5")
-    sqlContext.sql("set spark.sql.useSerializer2=true")
+    numShufflePartitions = conf.numShufflePartitions
+    useSerializer2 = conf.useSqlSerializer2
+
+    sql("set spark.sql.useSerializer2=true")
 
     val supportedTypes =
       Seq(StringType, BinaryType, NullType, BooleanType,
@@ -112,18 +112,15 @@ abstract class SparkSqlSerializer2Suite extends QueryTest with BeforeAndAfterAll
         new Timestamp(i))
     }
 
-    sqlContext.createDataFrame(rdd, schema).registerTempTable("shuffle")
+    createDataFrame(rdd, schema).registerTempTable("shuffle")
 
     super.beforeAll()
   }
 
   override def afterAll(): Unit = {
-    sqlContext.dropTempTable("shuffle")
-    sparkContext.stop()
-    sqlContext = null
-    sparkContext = null
-    // Set the existing SparkEnv back.
-    SparkEnv.set(existingSparkEnv)
+    dropTempTable("shuffle")
+    sql(s"set spark.sql.shuffle.partitions=$numShufflePartitions")
+    sql(s"set spark.sql.useSerializer2=$useSerializer2")
     super.afterAll()
   }
 
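The beforeAll/afterAll pair above is the heart of this change: instead of each suite owning a private SparkContext, the suites now mutate the shared TestSQLContext and put it back the way they found it. beforeAll records the current spark.sql.shuffle.partitions and spark.sql.useSerializer2 values, and afterAll restores them. A minimal sketch of that save-and-restore discipline; the names RestoresSetting, getSetting, setSetting, and suiteValue are hypothetical illustrations, not part of this commit:

import org.scalatest.{BeforeAndAfterAll, Suite}

// Hypothetical trait illustrating the pattern in the hunk above: capture a
// shared, mutable setting before the suite changes it, and restore it in
// afterAll so later suites sharing the same context see the expected value.
trait RestoresSetting extends BeforeAndAfterAll { self: Suite =>
  // Stand-ins for the shared context's conf accessors.
  protected def getSetting(key: String): String
  protected def setSetting(key: String, value: String): Unit

  protected def settingKey: String
  protected def suiteValue: String

  private var saved: String = _

  override protected def beforeAll(): Unit = {
    saved = getSetting(settingKey)     // capture the current value first
    setSetting(settingKey, suiteValue) // then apply this suite's value
    super.beforeAll()
  }

  override protected def afterAll(): Unit = {
    try super.afterAll()
    finally setSetting(settingKey, saved) // always restore on the way out
  }
}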
@@ -144,64 +141,40 @@ abstract class SparkSqlSerializer2Suite extends QueryTest with BeforeAndAfterAll
   }
 
   test("key schema and value schema are not nulls") {
-    val df = sqlContext.sql(s"SELECT DISTINCT ${allColumns} FROM shuffle")
+    val df = sql(s"SELECT DISTINCT ${allColumns} FROM shuffle")
     checkSerializer(df.queryExecution.executedPlan, serializerClass)
     checkAnswer(
       df,
-      sqlContext.table("shuffle").collect())
+      table("shuffle").collect())
   }
 
   test("value schema is null") {
-    val df = sqlContext.sql(s"SELECT col0 FROM shuffle ORDER BY col0")
+    val df = sql(s"SELECT col0 FROM shuffle ORDER BY col0")
     checkSerializer(df.queryExecution.executedPlan, serializerClass)
     assert(
       df.map(r => r.getString(0)).collect().toSeq ===
-      sqlContext.table("shuffle").select("col0").map(r => r.getString(0)).collect().sorted.toSeq)
+      table("shuffle").select("col0").map(r => r.getString(0)).collect().sorted.toSeq)
   }
+}
 
+/** Tests SparkSqlSerializer2 with sort based shuffle without sort merge. */
+class SparkSqlSerializer2SortShuffleSuite extends SparkSqlSerializer2Suite {
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    // Sort merge will not be triggered.
+    sql("set spark.sql.shuffle.partitions = 200")
+  }
+
   test("key schema is null") {
     val aggregations = allColumns.split(",").map(c => s"COUNT($c)").mkString(",")
-    val df = sqlContext.sql(s"SELECT $aggregations FROM shuffle")
+    val df = sql(s"SELECT $aggregations FROM shuffle")
     checkSerializer(df.queryExecution.executedPlan, serializerClass)
     checkAnswer(
       df,
       Row(1000, 1000, 0, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000))
   }
 }
 
-/** Tests SparkSqlSerializer2 with hash based shuffle. */
-class SparkSqlSerializer2HashShuffleSuite extends SparkSqlSerializer2Suite {
-  override def beforeAll(): Unit = {
-    val sparkConf =
-      new SparkConf()
-        .set("spark.driver.allowMultipleContexts", "true")
-        .set("spark.sql.testkey", "true")
-        .set("spark.shuffle.manager", "hash")
-
-    sparkContext = new SparkContext("local[2]", "Serializer2SQLContext", sparkConf)
-    sqlContext = new SQLContext(sparkContext)
-    super.beforeAll()
-  }
-}
-
-/** Tests SparkSqlSerializer2 with sort based shuffle without sort merge. */
-class SparkSqlSerializer2SortShuffleSuite extends SparkSqlSerializer2Suite {
-  override def beforeAll(): Unit = {
-    // Since spark.sql.shuffle.partition is 5, we will not do sort merge when
-    // spark.shuffle.sort.bypassMergeThreshold is also 5.
-    val sparkConf =
-      new SparkConf()
-        .set("spark.driver.allowMultipleContexts", "true")
-        .set("spark.sql.testkey", "true")
-        .set("spark.shuffle.manager", "sort")
-        .set("spark.shuffle.sort.bypassMergeThreshold", "5")
-
-    sparkContext = new SparkContext("local[2]", "Serializer2SQLContext", sparkConf)
-    sqlContext = new SQLContext(sparkContext)
-    super.beforeAll()
-  }
-}
-
 /** For now, we will use SparkSqlSerializer for sort based shuffle with sort merge. */
 class SparkSqlSerializer2SortMergeShuffleSuite extends SparkSqlSerializer2Suite {
 
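The magic numbers in the new suites are chosen against spark.shuffle.sort.bypassMergeThreshold, whose default is 200: in Spark 1.x's sort-based shuffle, a shuffle with no map-side aggregation and at most that many reduce partitions takes the bypass path and never sort-merges. So the 200 partitions above keep sort merge off, while the 201 in the next hunk turn it on. A tiny sketch of that rule (the object and helper names are illustrative, not Spark API):

object BypassRule extends App {
  // Illustrative restatement of Spark 1.x's bypass rule: sort merge only kicks
  // in once the partition count exceeds spark.shuffle.sort.bypassMergeThreshold.
  def usesSortMerge(numPartitions: Int, bypassMergeThreshold: Int = 200): Boolean =
    numPartitions > bypassMergeThreshold

  assert(!usesSortMerge(200)) // SortShuffleSuite above: stays on the bypass path
  assert(usesSortMerge(201))  // SortMergeShuffleSuite below: triggers sort merge
}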
@@ -210,15 +183,8 @@ class SparkSqlSerializer2SortMergeShuffleSuite extends SparkSqlSerializer2Suite
     classOf[SparkSqlSerializer].asInstanceOf[Class[Serializer]]
 
   override def beforeAll(): Unit = {
-    val sparkConf =
-      new SparkConf()
-        .set("spark.driver.allowMultipleContexts", "true")
-        .set("spark.sql.testkey", "true")
-        .set("spark.shuffle.manager", "sort")
-        .set("spark.shuffle.sort.bypassMergeThreshold", "0") // Always do sort merge.
-
-    sparkContext = new SparkContext("local[2]", "Serializer2SQLContext", sparkConf)
-    sqlContext = new SQLContext(sparkContext)
     super.beforeAll()
+    // To trigger the sort merge.
+    sql("set spark.sql.shuffle.partitions = 201")
   }
 }
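For reference, the checkSerializer call used throughout these tests is a helper defined earlier in this same file, outside the hunks shown here; it digs the ShuffledRDD out of the executed plan and asserts on the serializer attached to its shuffle dependency. A rough standalone sketch of that kind of check, written against plain RDDs rather than the suite's actual SparkPlan-based helper, and assuming the Spark 1.x API where ShuffleDependency.serializer is an Option:

import org.apache.spark.ShuffleDependency
import org.apache.spark.rdd.{RDD, ShuffledRDD}
import org.apache.spark.serializer.Serializer

object SerializerCheck {
  // Sketch only: walk an RDD's lineage and check that every shuffle was
  // planned with the expected serializer. The helper these suites really use
  // inspects the executed SparkPlan instead of a raw RDD.
  def assertShuffleSerializer[T <: Serializer](
      rdd: RDD[_],
      expectedSerializerClass: Class[T]): Unit = rdd match {
    case shuffled: ShuffledRDD[_, _, _] =>
      val dep = shuffled.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]]
      // Spark 1.x: serializer is optional; None means the shuffle manager default.
      dep.serializer.foreach { serializer =>
        assert(serializer.getClass == expectedSerializerClass,
          s"expected $expectedSerializerClass, got ${serializer.getClass}")
      }
    case _ =>
      rdd.dependencies.foreach(d => assertShuffleSerializer(d.rdd, expectedSerializerClass))
  }
}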