Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes SPARK-8730 - Deser objects containing a primitive class attribute #7122

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -62,17 +62,34 @@ private[spark] class JavaDeserializationStream(in: InputStream, loader: ClassLoa
extends DeserializationStream {

private val objIn = new ObjectInputStream(in) {
override def resolveClass(desc: ObjectStreamClass): Class[_] = {
// scalastyle:off classforname
Class.forName(desc.getName, false, loader)
// scalastyle:on classforname
}
override def resolveClass(desc: ObjectStreamClass): Class[_] =
try {
// scalastyle:off classforname
Class.forName(desc.getName, false, loader)
// scalastyle:on classforname
} catch {
case e: ClassNotFoundException =>
JavaDeserializationStream.primitiveMappings.get(desc.getName).getOrElse(throw e)
}
}

def readObject[T: ClassTag](): T = objIn.readObject().asInstanceOf[T]
def close() { objIn.close() }
}

private object JavaDeserializationStream {
val primitiveMappings = Map[String, Class[_]](
"boolean" -> classOf[Boolean],
"byte" -> classOf[Byte],
"char" -> classOf[Char],
"short" -> classOf[Short],
"int" -> classOf[Int],
"long" -> classOf[Long],
"float" -> classOf[Float],
"double" -> classOf[Double],
"void" -> classOf[Void]
)
}

private[spark] class JavaSerializerInstance(
counterReset: Int, extraDebugInfo: Boolean, defaultClassLoader: ClassLoader)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,22 @@ class JavaSerializerSuite extends SparkFunSuite {
val instance = serializer.newInstance()
instance.deserialize[JavaSerializer](instance.serialize(serializer))
}

test("Deserialize object containing a primitive Class as attribute") {
val serializer = new JavaSerializer(new SparkConf())
val instance = serializer.newInstance()
instance.deserialize[JavaSerializer](instance.serialize(new ContainsPrimitiveClass()))
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd like to have a test that we still get the original ClassNotFound in other cases, but I'm not sure how to do that ... perhaps you could write a test where you serialize with a special classloader which loads an extra class, and then deserialize with a different classloader?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would require the other classloader to not have access to that class (so wouldn't work with the existing classes). The only way I see would be to generate some class at runtime (via asm or whatever) and make it available only to the classloader used during serialization. But that sounds kind of overkill to me.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@squito took care of your comments except this one. wdyt?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, you're right, I don't see a good way around this. I'm OK with this in that case, but lets leave this open for a bit to see if anybody else has any ideas

}

private class ContainsPrimitiveClass extends Serializable {
val intClass = classOf[Int]
val longClass = classOf[Long]
val shortClass = classOf[Short]
val charClass = classOf[Char]
val doubleClass = classOf[Double]
val floatClass = classOf[Float]
val booleanClass = classOf[Boolean]
val byteClass = classOf[Byte]
val voidClass = classOf[Void]
}