Skip to content

Commit

Permalink
[SPARK-28112][TEST] Fix Kryo exception perf. bottleneck in tests due …
Browse files Browse the repository at this point in the history
…to absence of ML/MLlib classes

## What changes were proposed in this pull request?
In a nutshell, it looks like the absence of ML / MLlib classes on the classpath causes code in KryoSerializer to throw and catch ClassNotFoundExceptions whenever instantiating a new serializer in newInstance(). This isn't a performance problem in production (since MLlib is on the classpath there) but it's a huge issue in tests and appears to account for an enormous amount of test time

We can address this problem by reducing the total number of ClassNotFoundExceptions by performing the class existence checks once and storing the results in KryoSerializer instances rather than repeating the checks on each newInstance() call.

## How was this patch tested?
The existing tests.

Authored-by: Josh Rosen <joshrosendatabricks.com>

Closes #24916 from gatorsmile/kryoException.

Lead-authored-by: Josh Rosen <rosenville@gmail.com>
Co-authored-by: gatorsmile <gatorsmile@gmail.com>
Signed-off-by: Josh Rosen <rosenville@gmail.com>
  • Loading branch information
JoshRosen and gatorsmile committed Jun 20, 2019
1 parent 6b27ad5 commit ec032ce
Showing 1 changed file with 45 additions and 33 deletions.
Expand Up @@ -212,40 +212,8 @@ class KryoSerializer(conf: SparkConf)

// We can't load those class directly in order to avoid unnecessary jar dependencies.
// We load them safely, ignore it if the class not found.
Seq(
"org.apache.spark.sql.catalyst.expressions.UnsafeRow",
"org.apache.spark.sql.catalyst.expressions.UnsafeArrayData",
"org.apache.spark.sql.catalyst.expressions.UnsafeMapData",

"org.apache.spark.ml.attribute.Attribute",
"org.apache.spark.ml.attribute.AttributeGroup",
"org.apache.spark.ml.attribute.BinaryAttribute",
"org.apache.spark.ml.attribute.NominalAttribute",
"org.apache.spark.ml.attribute.NumericAttribute",

"org.apache.spark.ml.feature.Instance",
"org.apache.spark.ml.feature.LabeledPoint",
"org.apache.spark.ml.feature.OffsetInstance",
"org.apache.spark.ml.linalg.DenseMatrix",
"org.apache.spark.ml.linalg.DenseVector",
"org.apache.spark.ml.linalg.Matrix",
"org.apache.spark.ml.linalg.SparseMatrix",
"org.apache.spark.ml.linalg.SparseVector",
"org.apache.spark.ml.linalg.Vector",
"org.apache.spark.ml.stat.distribution.MultivariateGaussian",
"org.apache.spark.ml.tree.impl.TreePoint",
"org.apache.spark.mllib.clustering.VectorWithNorm",
"org.apache.spark.mllib.linalg.DenseMatrix",
"org.apache.spark.mllib.linalg.DenseVector",
"org.apache.spark.mllib.linalg.Matrix",
"org.apache.spark.mllib.linalg.SparseMatrix",
"org.apache.spark.mllib.linalg.SparseVector",
"org.apache.spark.mllib.linalg.Vector",
"org.apache.spark.mllib.regression.LabeledPoint",
"org.apache.spark.mllib.stat.distribution.MultivariateGaussian"
).foreach { name =>
KryoSerializer.loadableSparkClasses.foreach { clazz =>
try {
val clazz = Utils.classForName(name)
kryo.register(clazz)
} catch {
case NonFatal(_) => // do nothing
Expand Down Expand Up @@ -516,6 +484,50 @@ private[serializer] object KryoSerializer {
}
}
)

// classForName() is expensive in case the class is not found, so we filter the list of
// SQL / ML / MLlib classes once and then re-use that filtered list in newInstance() calls.
private lazy val loadableSparkClasses: Seq[Class[_]] = {
Seq(
"org.apache.spark.sql.catalyst.expressions.UnsafeRow",
"org.apache.spark.sql.catalyst.expressions.UnsafeArrayData",
"org.apache.spark.sql.catalyst.expressions.UnsafeMapData",

"org.apache.spark.ml.attribute.Attribute",
"org.apache.spark.ml.attribute.AttributeGroup",
"org.apache.spark.ml.attribute.BinaryAttribute",
"org.apache.spark.ml.attribute.NominalAttribute",
"org.apache.spark.ml.attribute.NumericAttribute",

"org.apache.spark.ml.feature.Instance",
"org.apache.spark.ml.feature.LabeledPoint",
"org.apache.spark.ml.feature.OffsetInstance",
"org.apache.spark.ml.linalg.DenseMatrix",
"org.apache.spark.ml.linalg.DenseVector",
"org.apache.spark.ml.linalg.Matrix",
"org.apache.spark.ml.linalg.SparseMatrix",
"org.apache.spark.ml.linalg.SparseVector",
"org.apache.spark.ml.linalg.Vector",
"org.apache.spark.ml.stat.distribution.MultivariateGaussian",
"org.apache.spark.ml.tree.impl.TreePoint",
"org.apache.spark.mllib.clustering.VectorWithNorm",
"org.apache.spark.mllib.linalg.DenseMatrix",
"org.apache.spark.mllib.linalg.DenseVector",
"org.apache.spark.mllib.linalg.Matrix",
"org.apache.spark.mllib.linalg.SparseMatrix",
"org.apache.spark.mllib.linalg.SparseVector",
"org.apache.spark.mllib.linalg.Vector",
"org.apache.spark.mllib.regression.LabeledPoint",
"org.apache.spark.mllib.stat.distribution.MultivariateGaussian"
).flatMap { name =>
try {
Some[Class[_]](Utils.classForName(name))
} catch {
case NonFatal(_) => None // do nothing
case _: NoClassDefFoundError if Utils.isTesting => None // See SPARK-23422.
}
}
}
}

/**
Expand Down

0 comments on commit ec032ce

Please sign in to comment.