diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index f9f78852f032b..64ba27f34d2f1 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -32,6 +32,7 @@ import org.apache.spark._ import org.apache.spark.api.python.PythonBroadcast import org.apache.spark.broadcast.HttpBroadcast import org.apache.spark.network.nio.{GetBlock, GotBlock, PutBlock} +import org.apache.spark.network.util.ByteUnit import org.apache.spark.scheduler.{CompressedMapStatus, HighlyCompressedMapStatus} import org.apache.spark.storage._ import org.apache.spark.util.BoundedPriorityQueue @@ -51,18 +52,18 @@ class KryoSerializer(conf: SparkConf) private val bufferSizeKb = conf.getSizeAsKb("spark.kryoserializer.buffer", "64k") - if (bufferSizeKb >= 2048) { + if (bufferSizeKb >= ByteUnit.GiB.toKiB(2)) { throw new IllegalArgumentException("spark.kryoserializer.buffer must be less than " + - s"2048 mb, got: + $bufferSizeKb mb.") + s"2048 mb, got: + ${ByteUnit.KiB.toMiB(bufferSizeKb)} mb.") } - private val bufferSize = (bufferSizeKb * 1024).toInt + private val bufferSize = ByteUnit.KiB.toBytes(bufferSizeKb).toInt val maxBufferSizeMb = conf.getSizeAsMb("spark.kryoserializer.buffer.max", "64m").toInt - if (maxBufferSizeMb >= 2048) { + if (maxBufferSizeMb >= ByteUnit.GiB.toMiB(2)) { throw new IllegalArgumentException("spark.kryoserializer.buffer.max must be less than " + s"2048 mb, got: + $maxBufferSizeMb mb.") } - private val maxBufferSize = maxBufferSizeMb * 1024 * 1024 + private val maxBufferSize = ByteUnit.MiB.toBytes(maxBufferSizeMb).toInt private val referenceTracking = conf.getBoolean("spark.kryo.referenceTracking", true) private val registrationRequired = conf.getBoolean("spark.kryo.registrationRequired", false) diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala index 778a7eee73b23..c7369de24b81f 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala @@ -32,6 +32,36 @@ class KryoSerializerSuite extends FunSuite with SharedSparkContext { conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") conf.set("spark.kryo.registrator", classOf[MyRegistrator].getName) + test("configuration limits") { + val conf1 = conf.clone() + val kryoBufferProperty = "spark.kryoserializer.buffer" + val kryoBufferMaxProperty = "spark.kryoserializer.buffer.max" + conf1.set(kryoBufferProperty, "64k") + conf1.set(kryoBufferMaxProperty, "64m") + new KryoSerializer(conf1).newInstance() + // 2048m = 2097152k + conf1.set(kryoBufferProperty, "2097151k") + conf1.set(kryoBufferMaxProperty, "64m") + // should not throw exception when kryoBufferMaxProperty < kryoBufferProperty + new KryoSerializer(conf1).newInstance() + conf1.set(kryoBufferMaxProperty, "2097151k") + new KryoSerializer(conf1).newInstance() + val conf2 = conf.clone() + conf2.set(kryoBufferProperty, "2048m") + val thrown1 = intercept[IllegalArgumentException](new KryoSerializer(conf2).newInstance()) + assert(thrown1.getMessage.contains(kryoBufferProperty)) + val conf3 = conf.clone() + conf3.set(kryoBufferMaxProperty, "2048m") + val thrown2 = intercept[IllegalArgumentException](new KryoSerializer(conf3).newInstance()) + assert(thrown2.getMessage.contains(kryoBufferMaxProperty)) + val conf4 = conf.clone() + conf4.set(kryoBufferProperty, "2g") + conf4.set(kryoBufferMaxProperty, "3g") + val thrown3 = intercept[IllegalArgumentException](new KryoSerializer(conf4).newInstance()) + assert(thrown3.getMessage.contains(kryoBufferProperty)) + assert(!thrown3.getMessage.contains(kryoBufferMaxProperty)) + } + test("basic types") { val ser = new KryoSerializer(conf).newInstance() def check[T: ClassTag](t: T) {