From 39704abb09271dc38d6afbc87652235add0ea32e Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Mon, 13 Apr 2015 14:04:04 -0700 Subject: [PATCH] Specialized serializer for Exchange. --- .../scala/org/apache/spark/sql/SQLConf.scala | 4 + .../apache/spark/sql/execution/Exchange.scala | 39 +- .../sql/execution/SparkSqlSerializer2.scala | 378 ++++++++++++++++++ .../org/apache/spark/sql/QueryTest.scala | 3 + .../execution/SparkSqlSerializer2Suite.scala | 198 +++++++++ 5 files changed, 618 insertions(+), 4 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index ee641bdfeb2d7..1070bf44abd6f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -63,6 +63,8 @@ private[spark] object SQLConf { // Set to false when debugging requires the ability to look at invalid query plans. val DATAFRAME_EAGER_ANALYSIS = "spark.sql.eagerAnalysis" + val USE_SQL_SERIALIZER2 = "spark.sql.useSerializer2" + object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" } @@ -139,6 +141,8 @@ private[sql] class SQLConf extends Serializable { */ private[spark] def codegenEnabled: Boolean = getConf(CODEGEN_ENABLED, "false").toBoolean + private[spark] def useSqlSerializer2: Boolean = getConf(USE_SQL_SERIALIZER2, "false").toBoolean + /** * Upper bound on the sizes (in bytes) of the tables qualified for the auto conversion to * a broadcast value during the physical executions of join operations. Setting this to -1 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 437408d30bfd2..b69b89d74c099 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -19,9 +19,10 @@ package org.apache.spark.sql.execution import org.apache.spark.annotation.DeveloperApi import org.apache.spark.shuffle.sort.SortShuffleManager -import org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.types.DataType import org.apache.spark.{SparkEnv, HashPartitioner, RangePartitioner, SparkConf} import org.apache.spark.rdd.{RDD, ShuffledRDD} +import org.apache.spark.serializer.Serializer import org.apache.spark.sql.{SQLContext, Row} import org.apache.spark.sql.catalyst.errors.attachTree import org.apache.spark.sql.catalyst.expressions.{Attribute, RowOrdering} @@ -45,6 +46,27 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una private val bypassMergeThreshold = child.sqlContext.sparkContext.conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200) + def serializer( + keySchema: Array[DataType], + valueSchema: Array[DataType], + numPartitions: Int): Serializer = { + val useSqlSerializer2 = + !(sortBasedShuffleOn && numPartitions > bypassMergeThreshold) && + child.sqlContext.conf.useSqlSerializer2 && + SparkSqlSerializer2.support(keySchema) && + SparkSqlSerializer2.support(valueSchema) + + val serializer = if (useSqlSerializer2) { + logInfo("Use ShuffleSerializer") + new SparkSqlSerializer2(keySchema, valueSchema) + } else { + logInfo("Use SparkSqlSerializer") + new SparkSqlSerializer(new SparkConf(false)) + } + + serializer + } + override def execute(): RDD[Row] = attachTree(this , "execute") { newPartitioning match { case HashPartitioning(expressions, numPartitions) => @@ -70,7 +92,11 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una } val part = new HashPartitioner(numPartitions) val shuffled = new ShuffledRDD[Row, Row, Row](rdd, part) - shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false))) + + val keySchema = expressions.map(_.dataType).toArray + val valueSchema = child.output.map(_.dataType).toArray + shuffled.setSerializer(serializer(keySchema, valueSchema, numPartitions)) + shuffled.map(_._2) case RangePartitioning(sortingExpressions, numPartitions) => @@ -88,7 +114,9 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una val part = new RangePartitioner(numPartitions, rdd, ascending = true) val shuffled = new ShuffledRDD[Row, Null, Null](rdd, part) - shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false))) + + val keySchema = sortingExpressions.map(_.dataType).toArray + shuffled.setSerializer(serializer(keySchema, null, numPartitions)) shuffled.map(_._1) @@ -107,7 +135,10 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una } val partitioner = new HashPartitioner(1) val shuffled = new ShuffledRDD[Null, Row, Row](rdd, partitioner) - shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false))) + + val valueSchema = child.output.map(_.dataType).toArray + shuffled.setSerializer(serializer(null, valueSchema, 1)) + shuffled.map(_._2) case _ => sys.error(s"Exchange not implemented for $newPartitioning") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala new file mode 100644 index 0000000000000..9e7b4ab63fe79 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala @@ -0,0 +1,378 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import java.io._ +import java.nio.ByteBuffer +import java.sql.Timestamp + +import scala.reflect.ClassTag + +import org.apache.spark.serializer._ +import org.apache.spark.Logging +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow +import org.apache.spark.sql.types._ + +/** + * The serialization stream for SparkSqlSerializer2. + */ +private[sql] class Serializer2SerializationStream( + keySchema: Array[DataType], + valueSchema: Array[DataType], + out: OutputStream) + extends SerializationStream with Logging { + + val rowOut = new DataOutputStream(out) + val writeKey = SparkSqlSerializer2.createSerializationFunction(keySchema, rowOut) + val writeValue = SparkSqlSerializer2.createSerializationFunction(valueSchema, rowOut) + + def writeObject[T: ClassTag](t: T): SerializationStream = { + val kv = t.asInstanceOf[Product2[Row, Row]] + writeKey(kv._1) + writeValue(kv._2) + + this + } + + def flush(): Unit = { + rowOut.flush() + } + + def close(): Unit = { + rowOut.close() + } +} + +/** + * The deserialization stream for SparkSqlSerializer2. + */ +private[sql] class Serializer2DeserializationStream( + keySchema: Array[DataType], + valueSchema: Array[DataType], + in: InputStream) + extends DeserializationStream with Logging { + + val rowIn = new DataInputStream(new BufferedInputStream(in)) + + val key = if (keySchema != null) new SpecificMutableRow(keySchema) else null + val value = if (valueSchema != null) new SpecificMutableRow(valueSchema) else null + val readKey = SparkSqlSerializer2.createDeserializationFunction(keySchema, rowIn, key) + val readValue = SparkSqlSerializer2.createDeserializationFunction(valueSchema, rowIn, value) + + def readObject[T: ClassTag](): T = { + readKey() + readValue() + + (key, value).asInstanceOf[T] + } + + def close(): Unit = { + rowIn.close() + } +} + +private[sql] class ShuffleSerializerInstance( + keySchema: Array[DataType], + valueSchema: Array[DataType]) + extends SerializerInstance { + + def serialize[T: ClassTag](t: T): ByteBuffer = + throw new UnsupportedOperationException("Not supported.") + + def deserialize[T: ClassTag](bytes: ByteBuffer): T = + throw new UnsupportedOperationException("Not supported.") + + def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T = + throw new UnsupportedOperationException("Not supported.") + + def serializeStream(s: OutputStream): SerializationStream = { + new Serializer2SerializationStream(keySchema, valueSchema, s) + } + + def deserializeStream(s: InputStream): DeserializationStream = { + new Serializer2DeserializationStream(keySchema, valueSchema, s) + } +} + +/** + * SparkSqlSerializer2 is a special serializer that creates serialization function and + * deserialization function based on the schema of data. It assumes that values passed in + * are key/value pairs and values returned from it are also key/value pairs. + * The schema of keys is represented by `keySchema` and that of values is represented by + * `valueSchema`. + */ +private[sql] class SparkSqlSerializer2(keySchema: Array[DataType], valueSchema: Array[DataType]) + extends Serializer + with Logging + with Serializable{ + + def newInstance(): SerializerInstance = new ShuffleSerializerInstance(keySchema, valueSchema) +} + +private[sql] object SparkSqlSerializer2 { + + final val NULL = 0 + final val NOT_NULL = 1 + + /** + * Check if rows with the given schema can be serialized with ShuffleSerializer. + */ + def support(schema: Array[DataType]): Boolean = { + if (schema == null) return true + + var i = 0 + while (i < schema.length) { + schema(i) match { + case udt: UserDefinedType[_] => return false + case array: ArrayType => return false + case map: MapType => return false + case struct: StructType => return false + case decimal: DecimalType => return false + case _ => + } + i += 1 + } + + return true + } + + /** + * The util function to create the serialization function based on the given schema. + */ + def createSerializationFunction(schema: Array[DataType], out: DataOutputStream): Row => Unit = { + (row: Row) => + // If the schema is null, the returned function does nothing when it get called. + if (schema != null) { + var i = 0 + while (i < schema.length) { + schema(i) match { + // When we write values to the underlying stream, we also first write the null byte + // first. Then, if the value is not null, we write the contents out. + + case NullType => // Write nothing. + + case BooleanType => + if (row.isNullAt(i)) { + out.writeByte(NULL) + } else { + out.writeByte(NOT_NULL) + out.writeBoolean(row.getBoolean(i)) + } + + case ByteType => + if (row.isNullAt(i)) { + out.writeByte(NULL) + } else { + out.writeByte(NOT_NULL) + out.writeByte(row.getByte(i)) + } + + case ShortType => + if (row.isNullAt(i)) { + out.writeByte(NULL) + } else { + out.writeByte(NOT_NULL) + out.writeShort(row.getShort(i)) + } + + case IntegerType => + if (row.isNullAt(i)) { + out.writeByte(NULL) + } else { + out.writeByte(NOT_NULL) + out.writeInt(row.getInt(i)) + } + + case LongType => + if (row.isNullAt(i)) { + out.writeByte(NULL) + } else { + out.writeByte(NOT_NULL) + out.writeLong(row.getLong(i)) + } + + case FloatType => + if (row.isNullAt(i)) { + out.writeByte(NULL) + } else { + out.writeByte(NOT_NULL) + out.writeFloat(row.getFloat(i)) + } + + case DoubleType => + if (row.isNullAt(i)) { + out.writeByte(NULL) + } else { + out.writeByte(NOT_NULL) + out.writeDouble(row.getDouble(i)) + } + + case DateType => + if (row.isNullAt(i)) { + out.writeByte(NULL) + } else { + out.writeByte(NOT_NULL) + out.writeInt(row.getInt(i)) + } + + case TimestampType => + if (row.isNullAt(i)) { + out.writeByte(NULL) + } else { + out.writeByte(NOT_NULL) + val timestamp = row.getAs[java.sql.Timestamp](i) + val time = timestamp.getTime + val nanos = timestamp.getNanos + out.writeLong(time - (nanos / 1000000)) // Write the milliseconds value. + out.writeInt(nanos) // Write the nanoseconds part. + } + + case StringType => + if (row.isNullAt(i)) { + out.writeByte(NULL) + } else { + out.writeByte(NOT_NULL) + // TODO: Update it once the string improvement is in. + out.writeUTF(row.getString(i)) + } + + case BinaryType => + if (row.isNullAt(i)) { + out.writeByte(NULL) + } else { + out.writeByte(NOT_NULL) + val bytes = row.getAs[Array[Byte]](i) + out.writeInt(bytes.length) + out.write(bytes) + } + } + i += 1 + } + } + } + + /** + * The util function to create the deserialization function based on the given schema. + */ + def createDeserializationFunction( + schema: Array[DataType], + in: DataInputStream, + mutableRow: SpecificMutableRow): () => Unit = { + () => { + // If the schema is null, the returned function does nothing when it get called. + if (schema != null) { + var i = 0 + while (i < schema.length) { + schema(i) match { + // When we read values from the underlying stream, we also first read the null byte + // first. Then, if the value is not null, we update the field of the mutable row. + + case NullType => mutableRow.setNullAt(i) // Read nothing. + + case BooleanType => + if (in.readByte() == NULL) { + mutableRow.setNullAt(i) + } else { + mutableRow.setBoolean(i, in.readBoolean()) + } + + case ByteType => + if (in.readByte() == NULL) { + mutableRow.setNullAt(i) + } else { + mutableRow.setByte(i, in.readByte()) + } + + case ShortType => + if (in.readByte() == NULL) { + mutableRow.setNullAt(i) + } else { + mutableRow.setShort(i, in.readShort()) + } + + case IntegerType => + if (in.readByte() == NULL) { + mutableRow.setNullAt(i) + } else { + mutableRow.setInt(i, in.readInt()) + } + + case LongType => + if (in.readByte() == NULL) { + mutableRow.setNullAt(i) + } else { + mutableRow.setLong(i, in.readLong()) + } + + case FloatType => + if (in.readByte() == NULL) { + mutableRow.setNullAt(i) + } else { + mutableRow.setFloat(i, in.readFloat()) + } + + case DoubleType => + if (in.readByte() == NULL) { + mutableRow.setNullAt(i) + } else { + mutableRow.setDouble(i, in.readDouble()) + } + + case DateType => + if (in.readByte() == NULL) { + mutableRow.setNullAt(i) + } else { + mutableRow.update(i, in.readInt()) + } + + case TimestampType => + if (in.readByte() == NULL) { + mutableRow.setNullAt(i) + } else { + val time = in.readLong() // Read the milliseconds value. + val nanos = in.readInt() // Read the nanoseconds part. + val timestamp = new Timestamp(time) + timestamp.setNanos(nanos) + mutableRow.update(i, timestamp) + } + + case StringType => + if (in.readByte() == NULL) { + mutableRow.setNullAt(i) + } else { + // TODO: Update it once the string improvement is in. + mutableRow.setString(i, in.readUTF()) + } + + case BinaryType => + if (in.readByte() == NULL) { + mutableRow.setNullAt(i) + } else { + val length = in.readInt() + val bytes = new Array[Byte](length) + in.readFully(bytes) + mutableRow.update(i, bytes) + } + } + i += 1 + } + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 9a81fc5d72819..59f9508444f25 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -104,9 +104,12 @@ object QueryTest { // Converts data to types that we can do equality comparison using Scala collections. // For BigDecimal type, the Scala type has a better definition of equality test (similar to // Java's java.math.BigDecimal.compareTo). + // For binary arrays, we convert it to Seq to avoid of calling java.util.Arrays.equals for + // equality test. val converted: Seq[Row] = answer.map { s => Row.fromSeq(s.toSeq.map { case d: java.math.BigDecimal => BigDecimal(d) + case b: Array[Byte] => b.toSeq case o => o }) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala new file mode 100644 index 0000000000000..335cca219931e --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala @@ -0,0 +1,198 @@ +package org.apache.spark.sql.execution + +import java.sql.{Timestamp, Date} + +import org.apache.spark.serializer.Serializer +import org.apache.spark.{SparkConf, ShuffleDependency, SparkContext} +import org.apache.spark.rdd.ShuffledRDD +import org.apache.spark.sql.types._ +import org.apache.spark.sql.Row +import org.scalatest.{FunSuite, BeforeAndAfterAll} + +import org.apache.spark.sql.{MyDenseVectorUDT, SQLContext, QueryTest} + +class SparkSqlSerializer2DataTypeSuite extends FunSuite { + // Make sure that we will not use serializer2 for unsupported data types. + def checkSupported(dataType: DataType, isSupported: Boolean): Unit = { + val testName = + s"${if (dataType == null) null else dataType.toString} is " + + s"${if (isSupported) "supported" else "unsupported"}" + + test(testName) { + assert(SparkSqlSerializer2.support(Array(dataType)) === isSupported) + } + } + + checkSupported(null, isSupported = true) + checkSupported(NullType, isSupported = true) + checkSupported(BooleanType, isSupported = true) + checkSupported(ByteType, isSupported = true) + checkSupported(ShortType, isSupported = true) + checkSupported(IntegerType, isSupported = true) + checkSupported(LongType, isSupported = true) + checkSupported(FloatType, isSupported = true) + checkSupported(DoubleType, isSupported = true) + checkSupported(DateType, isSupported = true) + checkSupported(TimestampType, isSupported = true) + checkSupported(StringType, isSupported = true) + checkSupported(BinaryType, isSupported = true) + + // Because at the runtime we accepts three kinds of Decimals + // (Java BigDecimal, Scala BigDecimal, and Spark SQL's Decimal), we do support DecimalType + // right now. We will support it once we fixed the internal type. + checkSupported(DecimalType(10, 5), isSupported = false) + checkSupported(DecimalType.Unlimited, isSupported = false) + // For now, ArrayType, MapType, and StructType are not supported. + checkSupported(ArrayType(DoubleType, true), isSupported = false) + checkSupported(ArrayType(StringType, false), isSupported = false) + checkSupported(MapType(IntegerType, StringType, true), isSupported = false) + checkSupported(MapType(IntegerType, ArrayType(DoubleType), false), isSupported = false) + checkSupported(StructType(StructField("a", IntegerType, true) :: Nil), isSupported = false) + // UDTs are not supported right now. + checkSupported(new MyDenseVectorUDT, isSupported = false) +} + +abstract class SparkSqlSerializer2Suite extends QueryTest with BeforeAndAfterAll { + + @transient var sparkContext: SparkContext = _ + @transient var sqlContext: SQLContext = _ + var allColumns: String = _ + val serializerClass: Class[Serializer] = + classOf[SparkSqlSerializer2].asInstanceOf[Class[Serializer]] + + override def beforeAll(): Unit = { + sqlContext.sql("set spark.sql.shuffle.partitions=5") + sqlContext.sql("set spark.sql.useSerializer2=true") + + val supportedTypes = + Seq(StringType, BinaryType, NullType, BooleanType, + ByteType, ShortType, IntegerType, LongType, + FloatType, DoubleType, DateType, TimestampType) + + val fields = supportedTypes.zipWithIndex.map { case (dataType, index) => + StructField(s"col$index", dataType, true) + } + allColumns = fields.map(_.name).mkString(",") + val schema = StructType(fields) + + // Create a RDD with all data types supported by SparkSqlSerializer2. + val rdd = + sparkContext.parallelize((1 to 1000), 10).map { i => + Row( + s"str${i}: test serializer2.", + s"binary${i}: test serializer2.".getBytes("UTF-8"), + null, + i % 2 == 0, + i.toByte, + i.toShort, + i, + i.toLong, + (i + 0.25).toFloat, + (i + 0.75), + new Date(i), + new Timestamp(i)) + } + + sqlContext.createDataFrame(rdd, schema).registerTempTable("shuffle") + + super.beforeAll() + } + + override def afterAll(): Unit = { + sqlContext.dropTempTable("shuffle") + sparkContext.stop() + super.afterAll() + } + + def checkSerializer[T <: Serializer]( + executedPlan: SparkPlan, + expectedSerializerClass: Class[T]): Unit = { + executedPlan.foreach { + case exchange: Exchange => + val shuffledRDD = exchange.execute().firstParent.asInstanceOf[ShuffledRDD[_, _, _]] + val dependency = shuffledRDD.getDependencies.head.asInstanceOf[ShuffleDependency[_, _, _]] + val serializerNotSetMessage = + s"Expected $expectedSerializerClass as the serializer of Exchange. " + + s"However, the serializer was not set." + val serializer = dependency.serializer.getOrElse(fail(serializerNotSetMessage)) + assert(serializer.getClass === expectedSerializerClass) + case _ => // Ignore other nodes. + } + } + + test("key schema and value schema are not nulls") { + val df = sqlContext.sql(s"SELECT DISTINCT ${allColumns} FROM shuffle") + checkSerializer(df.queryExecution.executedPlan, serializerClass) + checkAnswer( + df, + sqlContext.table("shuffle").collect()) + } + + test("value schema is null") { + val df = sqlContext.sql(s"SELECT col0 FROM shuffle ORDER BY col0") + checkSerializer(df.queryExecution.executedPlan, serializerClass) + assert( + df.map(r => r.getString(0)).collect().toSeq === + sqlContext.table("shuffle").select("col0").map(r => r.getString(0)).collect().sorted.toSeq) + } + + test("key schema is null") { + val aggregations = allColumns.split(",").map(c => s"COUNT($c)").mkString(",") + val df = sqlContext.sql(s"SELECT $aggregations FROM shuffle") + checkSerializer(df.queryExecution.executedPlan, serializerClass) + checkAnswer( + df, + Row(1000, 1000, 0, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000)) + } +} + +/** Tests SparkSqlSerializer2 with hash based shuffle. */ +class SparkSqlSerializer2HashShuffleSuite extends SparkSqlSerializer2Suite { + override def beforeAll(): Unit = { + val sparkConf = + new SparkConf() + .set("spark.sql.testkey", "true") + .set("spark.shuffle.manager", "hash") + + sparkContext = new SparkContext("local[2]", "Serializer2SQLContext", sparkConf) + sqlContext = new SQLContext(sparkContext) + super.beforeAll() + } +} + +/** Tests SparkSqlSerializer2 with sort based shuffle without sort merge. */ +class SparkSqlSerializer2SortShuffleSuite extends SparkSqlSerializer2Suite { + override def beforeAll(): Unit = { + // Since spark.sql.shuffle.partition is 5, we will not do sort merge when + // spark.shuffle.sort.bypassMergeThreshold is also 5. + val sparkConf = + new SparkConf() + .set("spark.sql.testkey", "true") + .set("spark.shuffle.manager", "sort") + .set("spark.shuffle.sort.bypassMergeThreshold", "5") + + sparkContext = new SparkContext("local[2]", "Serializer2SQLContext", sparkConf) + sqlContext = new SQLContext(sparkContext) + super.beforeAll() + } +} + +/** For now, we will use SparkSqlSerializer for sort based shuffle with sort merge. */ +class SparkSqlSerializer2SortMergeShuffleSuite extends SparkSqlSerializer2Suite { + + // We are expecting SparkSqlSerializer. + override val serializerClass: Class[Serializer] = + classOf[SparkSqlSerializer].asInstanceOf[Class[Serializer]] + + override def beforeAll(): Unit = { + val sparkConf = + new SparkConf() + .set("spark.sql.testkey", "true") + .set("spark.shuffle.manager", "sort") + .set("spark.shuffle.sort.bypassMergeThreshold", "0") // Always do sort merge. + + sparkContext = new SparkContext("local[2]", "Serializer2SQLContext", sparkConf) + sqlContext = new SQLContext(sparkContext) + super.beforeAll() + } +}