
Commit

[SPARK-3046] use executor's class loader as the default serializer classloader

The serializer is not always used in an executor thread (e.g. in the connection manager or broadcast), in which case the thread's context classloader might not have the user jar set, leading to corruption during deserialization.
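
A minimal sketch of the failure mode (the class name and thread are hypothetical illustrations, not part of the change):

// Hypothetical illustration: this runs on a connection manager thread,
// not an executor task thread, so its context classloader never had the
// user jar added.
val loader = Thread.currentThread.getContextClassLoader
// Throws ClassNotFoundException (or resolves an incompatible class),
// corrupting deserialization of task results or broadcast data.
Class.forName("com.example.UserDefinedClass", true, loader)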

https://issues.apache.org/jira/browse/SPARK-3046

https://issues.apache.org/jira/browse/SPARK-2878

Author: Reynold Xin <rxin@apache.org>

Closes #1972 from rxin/kryoBug and squashes the following commits:

c1c7bf0 [Reynold Xin] Made change to JavaSerializer.
7204c33 [Reynold Xin] Added imports back.
d879e67 [Reynold Xin] [SPARK-3046] use executor's class loader as the default serializer class loader.
rxin committed Aug 16, 2014
1 parent c703229 commit cc36487
Showing 6 changed files with 128 additions and 4 deletions.
3 changes: 3 additions & 0 deletions core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -99,6 +99,9 @@ private[spark] class Executor(
private val urlClassLoader = createClassLoader()
private val replClassLoader = addReplClassLoaderIfNeeded(urlClassLoader)

// Set the classloader for the serializer
env.serializer.setDefaultClassLoader(urlClassLoader)

// Akka's message frame size. If task result is bigger than this, we use the block manager
// to send the result back.
private val akkaFrameSize = AkkaUtils.maxFrameSizeBytes(conf)
core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala
@@ -63,7 +63,9 @@ extends DeserializationStream {
def close() { objIn.close() }
}

-private[spark] class JavaSerializerInstance(counterReset: Int) extends SerializerInstance {
+private[spark] class JavaSerializerInstance(counterReset: Int, defaultClassLoader: ClassLoader)
+  extends SerializerInstance {

def serialize[T: ClassTag](t: T): ByteBuffer = {
val bos = new ByteArrayOutputStream()
val out = serializeStream(bos)
@@ -109,7 +111,10 @@ private[spark] class JavaSerializerInstance(counterReset: Int) extends Serialize
class JavaSerializer(conf: SparkConf) extends Serializer with Externalizable {
private var counterReset = conf.getInt("spark.serializer.objectStreamReset", 100)

-  def newInstance(): SerializerInstance = new JavaSerializerInstance(counterReset)
+  override def newInstance(): SerializerInstance = {
+    val classLoader = defaultClassLoader.getOrElse(Thread.currentThread.getContextClassLoader)
+    new JavaSerializerInstance(counterReset, classLoader)
+  }

override def writeExternal(out: ObjectOutput) {
out.writeInt(counterReset)
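For context, the captured loader matters because Java deserialization resolves classes through ObjectInputStream's resolveClass hook. A self-contained sketch of honoring an explicit loader there (the class name is hypothetical):

import java.io.{InputStream, ObjectInputStream, ObjectStreamClass}

class LoaderAwareObjectInputStream(in: InputStream, loader: ClassLoader)
  extends ObjectInputStream(in) {
  // Resolve classes against the explicit loader instead of the calling
  // thread's context classloader.
  override def resolveClass(desc: ObjectStreamClass): Class[_] =
    Class.forName(desc.getName, false, loader)
}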
core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
@@ -61,7 +61,9 @@ class KryoSerializer(conf: SparkConf)
val instantiator = new EmptyScalaKryoInstantiator
val kryo = instantiator.newKryo()
kryo.setRegistrationRequired(registrationRequired)
-    val classLoader = Thread.currentThread.getContextClassLoader
+    val oldClassLoader = Thread.currentThread.getContextClassLoader
+    val classLoader = defaultClassLoader.getOrElse(Thread.currentThread.getContextClassLoader)

// Allow disabling Kryo reference tracking if user knows their object graphs don't have loops.
// Do this before we invoke the user registrator so the user registrator can override this.
@@ -84,10 +86,15 @@
try {
val reg = Class.forName(regCls, true, classLoader).newInstance()
.asInstanceOf[KryoRegistrator]

// Use the default classloader when calling the user registrator.
Thread.currentThread.setContextClassLoader(classLoader)
reg.registerClasses(kryo)
} catch {
case e: Exception =>
throw new SparkException(s"Failed to invoke $regCls", e)
} finally {
Thread.currentThread.setContextClassLoader(oldClassLoader)
}
}

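The save/swap/restore around the registrator call above is the standard context-classloader dance. As a reusable sketch (the helper name is hypothetical, not part of this change):

// Hypothetical helper: run `body` with `loader` as the thread's context
// classloader, restoring the previous loader afterwards.
def withContextClassLoader[T](loader: ClassLoader)(body: => T): T = {
  val thread = Thread.currentThread
  val old = thread.getContextClassLoader
  thread.setContextClassLoader(loader)
  try body finally thread.setContextClassLoader(old)
}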
17 changes: 17 additions & 0 deletions core/src/main/scala/org/apache/spark/serializer/Serializer.scala
@@ -44,6 +44,23 @@ import org.apache.spark.util.{ByteBufferInputStream, NextIterator}
*/
@DeveloperApi
trait Serializer {

/**
* Default ClassLoader to use in deserialization. Implementations of [[Serializer]] should
* make sure it is using this when set.
*/
@volatile protected var defaultClassLoader: Option[ClassLoader] = None

/**
* Sets a class loader for the serializer to use in deserialization.
*
* @return this Serializer object
*/
def setDefaultClassLoader(classLoader: ClassLoader): Serializer = {
defaultClassLoader = Some(classLoader)
this
}

def newInstance(): SerializerInstance
}

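Since setDefaultClassLoader returns this, callers can chain it at construction time. A usage sketch (conf and executorClassLoader are assumed to be in scope):

// Instances created after the call pick up the configured loader.
val ser: Serializer = new KryoSerializer(conf).setDefaultClassLoader(executorClassLoader)
val instance: SerializerInstance = ser.newInstance()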
71 changes: 71 additions & 0 deletions core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala
@@ -0,0 +1,71 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.serializer

import org.apache.spark.util.Utils

import com.esotericsoftware.kryo.Kryo
import org.scalatest.FunSuite

import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkEnv, TestUtils}
import org.apache.spark.SparkContext._
import org.apache.spark.serializer.KryoDistributedTest._

class KryoSerializerDistributedSuite extends FunSuite {

test("kryo objects are serialised consistently in different processes") {
val conf = new SparkConf(false)
conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
conf.set("spark.kryo.registrator", classOf[AppJarRegistrator].getName)
conf.set("spark.task.maxFailures", "1")

val jar = TestUtils.createJarWithClasses(List(AppJarRegistrator.customClassName))
conf.setJars(List(jar.getPath))

val sc = new SparkContext("local-cluster[2,1,512]", "test", conf)
val original = Thread.currentThread.getContextClassLoader
val loader = new java.net.URLClassLoader(Array(jar), Utils.getContextOrSparkClassLoader)
SparkEnv.get.serializer.setDefaultClassLoader(loader)

val cachedRDD = sc.parallelize((0 until 10).map((_, new MyCustomClass)), 3).cache()

// Randomly mix the keys so that the join below will require a shuffle with each partition
// sending data to multiple other partitions.
val shuffledRDD = cachedRDD.map { case (i, o) => (i * i * i - 10 * i * i, o)}

// Join the two RDDs, and force evaluation
assert(shuffledRDD.join(cachedRDD).collect().size == 1)

LocalSparkContext.stop(sc)
}
}

object KryoDistributedTest {
class MyCustomClass

class AppJarRegistrator extends KryoRegistrator {
override def registerClasses(k: Kryo) {
val classLoader = Thread.currentThread.getContextClassLoader
k.register(Class.forName(AppJarRegistrator.customClassName, true, classLoader))
}
}

object AppJarRegistrator {
val customClassName = "KryoSerializerDistributedSuiteCustomClass"
}
}
core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala
@@ -23,7 +23,7 @@ import scala.reflect.ClassTag
import com.esotericsoftware.kryo.Kryo
import org.scalatest.FunSuite

-import org.apache.spark.SharedSparkContext
+import org.apache.spark.{SparkConf, SharedSparkContext}
import org.apache.spark.serializer.KryoTest._

class KryoSerializerSuite extends FunSuite with SharedSparkContext {
@@ -217,8 +217,29 @@ class KryoSerializerSuite extends FunSuite with SharedSparkContext {
val thrown = intercept[SparkException](new KryoSerializer(conf).newInstance())
assert(thrown.getMessage.contains("Failed to invoke this.class.does.not.exist"))
}

test("default class loader can be set by a different thread") {
val ser = new KryoSerializer(new SparkConf)

// First serialize the object
val serInstance = ser.newInstance()
val bytes = serInstance.serialize(new ClassLoaderTestingObject)

// Deserialize the object to make sure normal deserialization works
serInstance.deserialize[ClassLoaderTestingObject](bytes)

// Set a special, broken ClassLoader and make sure we get an exception on deserialization
ser.setDefaultClassLoader(new ClassLoader() {
override def loadClass(name: String) = throw new UnsupportedOperationException
})
intercept[UnsupportedOperationException] {
ser.newInstance().deserialize[ClassLoaderTestingObject](bytes)
}
}
}

class ClassLoaderTestingObject

class KryoSerializerResizableOutputSuite extends FunSuite {
import org.apache.spark.SparkConf
import org.apache.spark.SparkContext
