diff --git a/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala
index 9f5c11deeb4dd..9e152ca472e51 100644
--- a/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala
@@ -20,6 +20,8 @@ package org.apache.spark.serializer
 import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
 import java.nio.ByteBuffer
 
+import org.apache.spark.io.CompressionCodec
+
 import scala.collection.mutable
 
 import org.apache.commons.io.IOUtils
@@ -39,7 +41,7 @@ import org.apache.avro.io._
  * Actions like parsing or compressing schemas are computationally expensive so the serializer
  * caches all previously seen values as to reduce the amount of work needed to do.
  */
-private[serializer] class GenericAvroSerializer(schemas: Map[Long, String])
+private[serializer] class GenericAvroSerializer(schemas: Map[Long, String], codec: CompressionCodec)
   extends KSerializer[GenericRecord] {
 
   /** Used to reduce the amount of effort to compress the schema */
@@ -61,7 +63,7 @@ private[serializer] class GenericAvroSerializer(schemas: Map[Long, String])
    */
   def compress(schema: Schema): Array[Byte] = compressCache.getOrElseUpdate(schema, {
     val bos = new ByteArrayOutputStream()
-    val out = new SnappyOutputStream(bos)
+    val out = codec.compressedOutputStream(bos)
     out.write(schema.toString.getBytes("UTF-8"))
     out.close()
     bos.toByteArray
@@ -73,7 +75,7 @@ private[serializer] class GenericAvroSerializer(schemas: Map[Long, String])
    */
   def decompress(schemaBytes: ByteBuffer): Schema = decompressCache.getOrElseUpdate(schemaBytes, {
     val bis = new ByteArrayInputStream(schemaBytes.array())
-    val bytes = IOUtils.toByteArray(new SnappyInputStream(bis))
+    val bytes = IOUtils.toByteArray(codec.compressedInputStream(bis))
     new Schema.Parser().parse(new String(bytes, "UTF-8"))
   })
 
diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
index 3107a735e2e53..19a6077268332 100644
--- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
@@ -21,6 +21,8 @@ import java.io.{EOFException, IOException, InputStream, OutputStream}
 import java.nio.ByteBuffer
 import javax.annotation.Nullable
 
+import org.apache.spark.io.CompressionCodec
+
 import scala.reflect.ClassTag
 
 import org.apache.avro.generic.{GenericData, GenericRecord}
@@ -76,6 +78,7 @@ class KryoSerializer(conf: SparkConf)
     .filter(!_.isEmpty)
 
   private val avroSchemas = conf.getAvroSchema
+  private val codec = CompressionCodec.createCodec(conf)
 
   def newKryoOutput(): KryoOutput = new KryoOutput(bufferSize, math.max(bufferSize, maxBufferSize))
 
@@ -103,8 +106,8 @@ class KryoSerializer(conf: SparkConf)
     kryo.register(classOf[HttpBroadcast[_]], new KryoJavaSerializer())
     kryo.register(classOf[PythonBroadcast], new KryoJavaSerializer())
 
-    kryo.register(classOf[GenericRecord], new GenericAvroSerializer(avroSchemas))
-    kryo.register(classOf[GenericData.Record], new GenericAvroSerializer(avroSchemas))
+    kryo.register(classOf[GenericRecord], new GenericAvroSerializer(avroSchemas, codec))
+    kryo.register(classOf[GenericData.Record], new GenericAvroSerializer(avroSchemas, codec))
 
     try {
       // Use the default classloader when calling the user registrator.
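
The net effect of the change above is that schema compression now goes through whatever codec `spark.io.compression.codec` resolves to, instead of hardcoded `SnappyOutputStream`/`SnappyInputStream`. The sketch below shows that round trip in isolation; it is a minimal illustration, not part of the patch. Note that `CompressionCodec` is `private[spark]`, so the snippet assumes it is compiled somewhere under the `org.apache.spark` package tree; the package and object names are hypothetical.

```scala
package org.apache.spark.example // hypothetical; must sit under org.apache.spark

import java.io.{ByteArrayInputStream, ByteArrayOutputStream}

import org.apache.commons.io.IOUtils

import org.apache.spark.SparkConf
import org.apache.spark.io.CompressionCodec

object SchemaCompressionSketch {
  def main(args: Array[String]): Unit = {
    // Resolve the codec the same way the patched KryoSerializer does;
    // "lz4", "lzf" and "snappy" are the short names Spark ships with.
    val conf = new SparkConf().set("spark.io.compression.codec", "lz4")
    val codec = CompressionCodec.createCodec(conf)

    val schemaJson =
      """{"type":"record","name":"test","fields":[{"name":"data","type":"string"}]}"""

    // Compress, mirroring GenericAvroSerializer.compress.
    val bos = new ByteArrayOutputStream()
    val out = codec.compressedOutputStream(bos)
    out.write(schemaJson.getBytes("UTF-8"))
    out.close()
    val compressed = bos.toByteArray

    // Decompress, mirroring GenericAvroSerializer.decompress.
    val bis = new ByteArrayInputStream(compressed)
    val roundTripped =
      new String(IOUtils.toByteArray(codec.compressedInputStream(bis)), "UTF-8")
    assert(roundTripped == schemaJson)
  }
}
```

One consequence worth noting: the compressed schema bytes travel inside the Kryo stream between driver and executors, so both sides must decompress with the same codec. Since each side builds its codec from the shared application `SparkConf`, that should hold by construction.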
diff --git a/core/src/test/scala/org/apache/spark/serializer/GenericAvroSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/GenericAvroSerializerSuite.scala
index bc9f3708ed69d..9dc1df4f15165 100644
--- a/core/src/test/scala/org/apache/spark/serializer/GenericAvroSerializerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/serializer/GenericAvroSerializerSuite.scala
@@ -23,6 +23,7 @@ import java.nio.ByteBuffer
 import com.esotericsoftware.kryo.io.{Output, Input}
 import org.apache.avro.{SchemaBuilder, Schema}
 import org.apache.avro.generic.GenericData.Record
+import org.apache.spark.io.CompressionCodec
 
 import org.apache.spark.{SparkFunSuite, SharedSparkContext}
 
@@ -37,12 +38,14 @@ class GenericAvroSerializerSuite extends SparkFunSuite with SharedSparkContext {
   record.put("data", "test data")
 
   test("schema compression and decompression") {
-    val genericSer = new GenericAvroSerializer(conf.getAvroSchema)
+    val genericSer = new GenericAvroSerializer(conf.getAvroSchema,
+      CompressionCodec.createCodec(conf))
     assert(schema === genericSer.decompress(ByteBuffer.wrap(genericSer.compress(schema))))
   }
 
   test("record serialization and deserialization") {
-    val genericSer = new GenericAvroSerializer(conf.getAvroSchema)
+    val genericSer = new GenericAvroSerializer(conf.getAvroSchema,
+      CompressionCodec.createCodec(conf))
 
     val outputStream = new ByteArrayOutputStream()
     val output = new Output(outputStream)
@@ -55,7 +58,8 @@ class GenericAvroSerializerSuite extends SparkFunSuite with SharedSparkContext {
   }
 
   test("uses schema fingerprint to decrease message size") {
-    val genericSerFull = new GenericAvroSerializer(conf.getAvroSchema)
+    val genericSerFull = new GenericAvroSerializer(conf.getAvroSchema,
+      CompressionCodec.createCodec(conf))
 
     val output = new Output(new ByteArrayOutputStream())
 
@@ -65,7 +69,8 @@ class GenericAvroSerializerSuite extends SparkFunSuite with SharedSparkContext {
     val normalLength = output.total - beginningNormalPosition
 
     conf.registerAvroSchemas(schema)
-    val genericSerFinger = new GenericAvroSerializer(conf.getAvroSchema)
+    val genericSerFinger = new GenericAvroSerializer(conf.getAvroSchema,
+      CompressionCodec.createCodec(conf))
     val beginningFingerprintPosition = output.total()
     genericSerFinger.serializeDatum(record, output)
     val fingerprintLength = output.total - beginningFingerprintPosition
@@ -74,7 +79,8 @@ class GenericAvroSerializerSuite extends SparkFunSuite with SharedSparkContext {
   }
 
   test("caches previously seen schemas") {
-    val genericSer = new GenericAvroSerializer(conf.getAvroSchema)
+    val genericSer = new GenericAvroSerializer(conf.getAvroSchema,
+      CompressionCodec.createCodec(conf))
     val compressedSchema = genericSer.compress(schema)
     val decompressedScheam = genericSer.decompress(ByteBuffer.wrap(compressedSchema))
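
The updated tests only thread the suite's default codec through the new constructor parameter; none of them exercises a non-default codec. Below is a hedged sketch of an additional case one might add to this suite (the test name and the choice of "lzf" are illustrative, not from the patch):

```scala
test("schema round trip with an explicitly configured codec") {
  // Exercise compress/decompress under a codec other than Snappy; this would
  // catch a regression where only one of the two paths honours the configured
  // codec. "lzf" is one of Spark's built-in codec short names.
  val lzfConf = conf.clone().set("spark.io.compression.codec", "lzf")
  val genericSer = new GenericAvroSerializer(conf.getAvroSchema,
    CompressionCodec.createCodec(lzfConf))
  assert(schema === genericSer.decompress(ByteBuffer.wrap(genericSer.compress(schema))))
}
```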