diff --git a/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala
index 9e152ca472e51..5d90ff10ed9c8 100644
--- a/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala
@@ -20,6 +20,7 @@ package org.apache.spark.serializer
 
 import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
 import java.nio.ByteBuffer
 
+import org.apache.spark.SparkConf
 import org.apache.spark.io.CompressionCodec
 import scala.collection.mutable
@@ -41,7 +42,7 @@ import org.apache.avro.io._
  * Actions like parsing or compressing schemas are computationally expensive so the serializer
  * caches all previously seen values as to reduce the amount of work needed to do.
  */
-private[serializer] class GenericAvroSerializer(schemas: Map[Long, String], codec: CompressionCodec)
+private[serializer] class GenericAvroSerializer(schemas: Map[Long, String])
   extends KSerializer[GenericRecord] {
 
   /** Used to reduce the amount of effort to compress the schema */
@@ -56,6 +57,8 @@ private[serializer] class GenericAvroSerializer(schemas: Map[Long, String], code
   private val fingerprintCache = new mutable.HashMap[Schema, Long]()
   private val schemaCache = new mutable.HashMap[Long, Schema]()
 
+  private val codec = CompressionCodec.createCodec(new SparkConf())
+
   /**
    * Used to compress Schemas when they are being sent over the wire.
    * The compression results are memoized to reduce the compression time since the
diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
index 19a6077268332..3107a735e2e53 100644
--- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
@@ -21,8 +21,6 @@ import java.io.{EOFException, IOException, InputStream, OutputStream}
 import java.nio.ByteBuffer
 import javax.annotation.Nullable
 
-import org.apache.spark.io.CompressionCodec
-
 import scala.reflect.ClassTag
 
 import org.apache.avro.generic.{GenericData, GenericRecord}
@@ -78,7 +76,6 @@ class KryoSerializer(conf: SparkConf)
     .filter(!_.isEmpty)
 
   private val avroSchemas = conf.getAvroSchema
-  private val codec = CompressionCodec.createCodec(conf)
 
   def newKryoOutput(): KryoOutput = new KryoOutput(bufferSize, math.max(bufferSize, maxBufferSize))
 
@@ -106,8 +103,8 @@ class KryoSerializer(conf: SparkConf)
     kryo.register(classOf[HttpBroadcast[_]], new KryoJavaSerializer())
     kryo.register(classOf[PythonBroadcast], new KryoJavaSerializer())
 
-    kryo.register(classOf[GenericRecord], new GenericAvroSerializer(avroSchemas, codec))
-    kryo.register(classOf[GenericData.Record], new GenericAvroSerializer(avroSchemas, codec))
+    kryo.register(classOf[GenericRecord], new GenericAvroSerializer(avroSchemas))
+    kryo.register(classOf[GenericData.Record], new GenericAvroSerializer(avroSchemas))
 
     try {
       // Use the default classloader when calling the user registrator.
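Note on the change above: the codec is no longer injected into GenericAvroSerializer; it is now built internally via CompressionCodec.createCodec(new SparkConf()). A minimal sketch of the resulting call site (hypothetical, not part of this patch; it mirrors the test suite below and assumes Avro's SchemaBuilder):

    import java.nio.ByteBuffer
    import org.apache.avro.SchemaBuilder
    import org.apache.spark.SparkConf

    // Hypothetical sketch: the serializer is now constructed from the schema
    // map alone; the compression codec is created internally from a SparkConf.
    val conf = new SparkConf()
    val schema = SchemaBuilder.record("testRecord").fields().requiredString("data").endRecord()
    conf.registerAvroSchemas(schema)

    val genericSer = new GenericAvroSerializer(conf.getAvroSchema)
    // compress/decompress results are memoized, so repeated calls are cheap.
    assert(schema === genericSer.decompress(ByteBuffer.wrap(genericSer.compress(schema))))

One trade-off worth flagging: new SparkConf() only picks up system properties, so a spark.io.compression.codec value set programmatically on the application's SparkConf will not reach this serializer.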
diff --git a/core/src/test/scala/org/apache/spark/serializer/GenericAvroSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/GenericAvroSerializerSuite.scala
index 9dc1df4f15165..b7f72da666fef 100644
--- a/core/src/test/scala/org/apache/spark/serializer/GenericAvroSerializerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/serializer/GenericAvroSerializerSuite.scala
@@ -38,14 +38,12 @@ class GenericAvroSerializerSuite extends SparkFunSuite with SharedSparkContext {
   record.put("data", "test data")
 
   test("schema compression and decompression") {
-    val genericSer = new GenericAvroSerializer(conf.getAvroSchema,
-      CompressionCodec.createCodec(conf))
+    val genericSer = new GenericAvroSerializer(conf.getAvroSchema)
     assert(schema === genericSer.decompress(ByteBuffer.wrap(genericSer.compress(schema))))
   }
 
   test("record serialization and deserialization") {
-    val genericSer = new GenericAvroSerializer(conf.getAvroSchema,
-      CompressionCodec.createCodec(conf))
+    val genericSer = new GenericAvroSerializer(conf.getAvroSchema)
 
     val outputStream = new ByteArrayOutputStream()
     val output = new Output(outputStream)
@@ -58,8 +56,7 @@ class GenericAvroSerializerSuite extends SparkFunSuite with SharedSparkContext {
   }
 
   test("uses schema fingerprint to decrease message size") {
-    val genericSerFull = new GenericAvroSerializer(conf.getAvroSchema,
-      CompressionCodec.createCodec(conf))
+    val genericSerFull = new GenericAvroSerializer(conf.getAvroSchema)
 
     val output = new Output(new ByteArrayOutputStream())
@@ -69,8 +66,7 @@ class GenericAvroSerializerSuite extends SparkFunSuite with SharedSparkContext {
     val normalLength = output.total - beginningNormalPosition
 
     conf.registerAvroSchemas(schema)
-    val genericSerFinger = new GenericAvroSerializer(conf.getAvroSchema,
-      CompressionCodec.createCodec(conf))
+    val genericSerFinger = new GenericAvroSerializer(conf.getAvroSchema)
     val beginningFingerprintPosition = output.total()
     genericSerFinger.serializeDatum(record, output)
     val fingerprintLength = output.total - beginningFingerprintPosition
@@ -79,8 +75,7 @@ class GenericAvroSerializerSuite extends SparkFunSuite with SharedSparkContext {
   }
 
   test("caches previously seen schemas") {
-    val genericSer = new GenericAvroSerializer(conf.getAvroSchema,
-      CompressionCodec.createCodec(conf))
+    val genericSer = new GenericAvroSerializer(conf.getAvroSchema)
     val compressedSchema = genericSer.compress(schema)
     val decompressedScheam = genericSer.decompress(ByteBuffer.wrap(compressedSchema))
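For completeness, a sketch of the record round trip that the unchanged body of "record serialization and deserialization" performs with the new constructor (assumed from the surrounding suite; deserializeDatum is taken to be the counterpart of serializeDatum on GenericAvroSerializer and does not appear in the hunks above):

    import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
    import com.esotericsoftware.kryo.io.{Input, Output}

    val genericSer = new GenericAvroSerializer(conf.getAvroSchema)

    // Serialize the record: a schema fingerprint (or the full compressed
    // schema, if the schema was not registered) is written first, followed
    // by the Avro-encoded datum.
    val outputStream = new ByteArrayOutputStream()
    val output = new Output(outputStream)
    genericSer.serializeDatum(record, output)
    output.flush()

    // Deserialize from the same bytes and check the round trip.
    val input = new Input(new ByteArrayInputStream(outputStream.toByteArray))
    assert(genericSer.deserializeDatum(input) === record)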