diff --git a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala index 077b035f3d079..9d76611049b76 100644 --- a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala @@ -28,8 +28,10 @@ import org.apache.spark.internal.config._ import org.apache.spark.util.{ByteBufferInputStream, ByteBufferOutputStream, Utils} private[spark] class JavaSerializationStream( - out: OutputStream, counterReset: Int, extraDebugInfo: Boolean) - extends SerializationStream { + out: OutputStream, + counterReset: Int, + extraDebugInfo: Boolean) + extends SerializationStream { private val objOut = new ObjectOutputStream(out) private var counter = 0 @@ -59,9 +61,10 @@ private[spark] class JavaSerializationStream( } private[spark] class JavaDeserializationStream(in: InputStream, loader: ClassLoader) - extends DeserializationStream { + extends DeserializationStream { private val objIn = new ObjectInputStream(in) { + override def resolveClass(desc: ObjectStreamClass): Class[_] = try { // scalastyle:off classforname @@ -71,6 +74,14 @@ private[spark] class JavaDeserializationStream(in: InputStream, loader: ClassLoa case e: ClassNotFoundException => JavaDeserializationStream.primitiveMappings.getOrElse(desc.getName, throw e) } + + override def resolveProxyClass(ifaces: Array[String]): Class[_] = { + // scalastyle:off classforname + val resolved = ifaces.map(iface => Class.forName(iface, false, loader)) + // scalastyle:on classforname + java.lang.reflect.Proxy.getProxyClass(loader, resolved: _*) + } + } def readObject[T: ClassTag](): T = objIn.readObject().asInstanceOf[T] @@ -78,6 +89,7 @@ private[spark] class JavaDeserializationStream(in: InputStream, loader: ClassLoa } private object JavaDeserializationStream { + val primitiveMappings = Map[String, Class[_]]( "boolean" -> classOf[Boolean], "byte" -> classOf[Byte], @@ -87,13 +99,15 @@ private object JavaDeserializationStream { "long" -> classOf[Long], "float" -> classOf[Float], "double" -> classOf[Double], - "void" -> classOf[Void] - ) + "void" -> classOf[Void]) + } private[spark] class JavaSerializerInstance( - counterReset: Int, extraDebugInfo: Boolean, defaultClassLoader: ClassLoader) - extends SerializerInstance { + counterReset: Int, + extraDebugInfo: Boolean, + defaultClassLoader: ClassLoader) + extends SerializerInstance { override def serialize[T: ClassTag](t: T): ByteBuffer = { val bos = new ByteBufferOutputStream() @@ -126,6 +140,7 @@ private[spark] class JavaSerializerInstance( def deserializeStream(s: InputStream, loader: ClassLoader): DeserializationStream = { new JavaDeserializationStream(s, loader) } + } /** @@ -141,20 +156,23 @@ class JavaSerializer(conf: SparkConf) extends Serializer with Externalizable { private var counterReset = conf.get(SERIALIZER_OBJECT_STREAM_RESET) private var extraDebugInfo = conf.get(SERIALIZER_EXTRA_DEBUG_INFO) - protected def this() = this(new SparkConf()) // For deserialization only + protected def this() = this(new SparkConf()) // For deserialization only override def newInstance(): SerializerInstance = { val classLoader = defaultClassLoader.getOrElse(Thread.currentThread.getContextClassLoader) new JavaSerializerInstance(counterReset, extraDebugInfo, classLoader) } - override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException { - out.writeInt(counterReset) - out.writeBoolean(extraDebugInfo) - } + override def writeExternal(out: ObjectOutput): Unit = + Utils.tryOrIOException { + out.writeInt(counterReset) + out.writeBoolean(extraDebugInfo) + } + + override def readExternal(in: ObjectInput): Unit = + Utils.tryOrIOException { + counterReset = in.readInt() + extraDebugInfo = in.readBoolean() + } - override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException { - counterReset = in.readInt() - extraDebugInfo = in.readBoolean() - } } diff --git a/core/src/test/java/org/apache/spark/serializer/ContainsProxyClass.java b/core/src/test/java/org/apache/spark/serializer/ContainsProxyClass.java new file mode 100644 index 0000000000000..66b2ba41cd324 --- /dev/null +++ b/core/src/test/java/org/apache/spark/serializer/ContainsProxyClass.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.serializer; + +import java.io.Serializable; +import java.lang.reflect.InvocationHandler; +import java.lang.reflect.Method; +import java.lang.reflect.Proxy; + +class ContainsProxyClass implements Serializable { + final MyInterface proxy = (MyInterface) Proxy.newProxyInstance( + MyInterface.class.getClassLoader(), + new Class[]{MyInterface.class}, + new MyInvocationHandler()); + + // Interface needs to be public as classloaders will mismatch. + // See ObjectInputStream#resolveProxyClass for details. + public interface MyInterface { + void myMethod(); + } + + static class MyClass implements MyInterface, Serializable { + @Override + public void myMethod() {} + } + + class MyInvocationHandler implements InvocationHandler, Serializable { + private final MyClass real = new MyClass(); + + @Override + public Object invoke(Object proxy, Method method, Object[] args) throws Throwable { + return method.invoke(real, args); + } + } +} diff --git a/core/src/test/scala/org/apache/spark/serializer/JavaSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/JavaSerializerSuite.scala index 6a6ea42797fb6..77226afc4eff7 100644 --- a/core/src/test/scala/org/apache/spark/serializer/JavaSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/JavaSerializerSuite.scala @@ -31,11 +31,33 @@ class JavaSerializerSuite extends SparkFunSuite { test("Deserialize object containing a primitive Class as attribute") { val serializer = new JavaSerializer(new SparkConf()) val instance = serializer.newInstance() - val obj = instance.deserialize[ContainsPrimitiveClass](instance.serialize( - new ContainsPrimitiveClass())) + val obj = instance.deserialize[ContainsPrimitiveClass]( + instance.serialize(new ContainsPrimitiveClass())) // enforce class cast obj.getClass } + + test("SPARK-36627: Deserialize object containing a proxy Class as attribute") { + var classesLoaded = Set[String]() + val outer = Thread.currentThread.getContextClassLoader + val inner = new ClassLoader() { + override def loadClass(name: String): Class[_] = { + classesLoaded = classesLoaded + name + outer.loadClass(name) + } + } + Thread.currentThread.setContextClassLoader(inner) + + val serializer = new JavaSerializer(new SparkConf()) + val instance = serializer.newInstance() + val obj = + instance.deserialize[ContainsProxyClass](instance.serialize(new ContainsProxyClass())) + // enforce class cast + obj.getClass + + // check that serializer's loader is used to resolve proxied interface. + assert(classesLoaded.exists(klass => klass.contains("MyInterface"))) + } } private class ContainsPrimitiveClass extends Serializable {