Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-8962] Add Scalastyle rule to ban direct use of Class.forName; fix existing uses #7350

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion core/src/main/scala/org/apache/spark/Logging.scala
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ private object Logging {
try {
// We use reflection here to handle the case where users remove the
// slf4j-to-jul bridge order to route their logs to JUL.
val bridgeClass = Class.forName("org.slf4j.bridge.SLF4JBridgeHandler")
val bridgeClass = Utils.classForName("org.slf4j.bridge.SLF4JBridgeHandler")
bridgeClass.getMethod("removeHandlersForRootLogger").invoke(null)
val installed = bridgeClass.getMethod("isInstalled").invoke(null).asInstanceOf[Boolean]
if (!installed) {
Expand Down
11 changes: 5 additions & 6 deletions core/src/main/scala/org/apache/spark/SparkContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1968,7 +1968,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
for (className <- listenerClassNames) {
// Use reflection to find the right constructor
val constructors = {
val listenerClass = Class.forName(className)
val listenerClass = Utils.classForName(className)
listenerClass.getConstructors.asInstanceOf[Array[Constructor[_ <: SparkListener]]]
}
val constructorTakingSparkConf = constructors.find { c =>
Expand Down Expand Up @@ -2503,7 +2503,7 @@ object SparkContext extends Logging {
"\"yarn-standalone\" is deprecated as of Spark 1.0. Use \"yarn-cluster\" instead.")
}
val scheduler = try {
val clazz = Class.forName("org.apache.spark.scheduler.cluster.YarnClusterScheduler")
val clazz = Utils.classForName("org.apache.spark.scheduler.cluster.YarnClusterScheduler")
val cons = clazz.getConstructor(classOf[SparkContext])
cons.newInstance(sc).asInstanceOf[TaskSchedulerImpl]
} catch {
Expand All @@ -2515,7 +2515,7 @@ object SparkContext extends Logging {
}
val backend = try {
val clazz =
Class.forName("org.apache.spark.scheduler.cluster.YarnClusterSchedulerBackend")
Utils.classForName("org.apache.spark.scheduler.cluster.YarnClusterSchedulerBackend")
val cons = clazz.getConstructor(classOf[TaskSchedulerImpl], classOf[SparkContext])
cons.newInstance(scheduler, sc).asInstanceOf[CoarseGrainedSchedulerBackend]
} catch {
Expand All @@ -2528,8 +2528,7 @@ object SparkContext extends Logging {

case "yarn-client" =>
val scheduler = try {
val clazz =
Class.forName("org.apache.spark.scheduler.cluster.YarnScheduler")
val clazz = Utils.classForName("org.apache.spark.scheduler.cluster.YarnScheduler")
val cons = clazz.getConstructor(classOf[SparkContext])
cons.newInstance(sc).asInstanceOf[TaskSchedulerImpl]

Expand All @@ -2541,7 +2540,7 @@ object SparkContext extends Logging {

val backend = try {
val clazz =
Class.forName("org.apache.spark.scheduler.cluster.YarnClientSchedulerBackend")
Utils.classForName("org.apache.spark.scheduler.cluster.YarnClientSchedulerBackend")
val cons = clazz.getConstructor(classOf[TaskSchedulerImpl], classOf[SparkContext])
cons.newInstance(scheduler, sc).asInstanceOf[CoarseGrainedSchedulerBackend]
} catch {
Expand Down
2 changes: 1 addition & 1 deletion core/src/main/scala/org/apache/spark/SparkEnv.scala
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ object SparkEnv extends Logging {

// Create an instance of the class with the given name, possibly initializing it with our conf
def instantiateClass[T](className: String): T = {
val cls = Class.forName(className, true, Utils.getContextOrSparkClassLoader)
val cls = Utils.classForName(className)
// Look for a constructor taking a SparkConf and a boolean isDriver, then one taking just
// SparkConf, then one taking no arguments
try {
Expand Down
18 changes: 2 additions & 16 deletions core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler}

import org.apache.spark.Logging
import org.apache.spark.api.r.SerDe._
import org.apache.spark.util.Utils

/**
* Handler for RBackend
Expand Down Expand Up @@ -88,21 +89,6 @@ private[r] class RBackendHandler(server: RBackend)
ctx.close()
}

// Looks up a class given a class name. This function first checks the
// current class loader and if a class is not found, it looks up the class
// in the context class loader. Address [SPARK-5185]
def getStaticClass(objId: String): Class[_] = {
try {
val clsCurrent = Class.forName(objId)
clsCurrent
} catch {
// Use contextLoader if we can't find the JAR in the system class loader
case e: ClassNotFoundException =>
val clsContext = Class.forName(objId, true, Thread.currentThread().getContextClassLoader)
clsContext
}
}

def handleMethodCall(
isStatic: Boolean,
objId: String,
Expand All @@ -113,7 +99,7 @@ private[r] class RBackendHandler(server: RBackend)
var obj: Object = null
try {
val cls = if (isStatic) {
getStaticClass(objId)
Utils.classForName(objId)
} else {
JVMObjectTracker.get(objId) match {
case None => throw new IllegalArgumentException("Object not found " + objId)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import java.util.concurrent.atomic.AtomicLong
import scala.reflect.ClassTag

import org.apache.spark._
import org.apache.spark.util.Utils

private[spark] class BroadcastManager(
val isDriver: Boolean,
Expand All @@ -42,7 +43,7 @@ private[spark] class BroadcastManager(
conf.get("spark.broadcast.factory", "org.apache.spark.broadcast.TorrentBroadcastFactory")

broadcastFactory =
Class.forName(broadcastFactoryClass).newInstance.asInstanceOf[BroadcastFactory]
Utils.classForName(broadcastFactoryClass).newInstance.asInstanceOf[BroadcastFactory]

// Initialize appropriate BroadcastFactory and BroadcastObject
broadcastFactory.initialize(isDriver, conf, securityManager)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ class SparkHadoopUtil extends Logging {

private def getFileSystemThreadStatisticsMethod(methodName: String): Method = {
val statisticsDataClass =
Class.forName("org.apache.hadoop.fs.FileSystem$Statistics$StatisticsData")
Utils.classForName("org.apache.hadoop.fs.FileSystem$Statistics$StatisticsData")
statisticsDataClass.getDeclaredMethod(methodName)
}

Expand Down Expand Up @@ -356,7 +356,7 @@ object SparkHadoopUtil {
System.getProperty("SPARK_YARN_MODE", System.getenv("SPARK_YARN_MODE")))
if (yarnMode) {
try {
Class.forName("org.apache.spark.deploy.yarn.YarnSparkHadoopUtil")
Utils.classForName("org.apache.spark.deploy.yarn.YarnSparkHadoopUtil")
.newInstance()
.asInstanceOf[SparkHadoopUtil]
} catch {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -624,7 +624,7 @@ object SparkSubmit {
var mainClass: Class[_] = null

try {
mainClass = Class.forName(childMainClass, true, loader)
mainClass = Utils.classForName(childMainClass)
} catch {
case e: ClassNotFoundException =>
e.printStackTrace(printStream)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -576,7 +576,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S
System.setSecurityManager(sm)

try {
Class.forName(mainClass).getMethod("main", classOf[Array[String]])
Utils.classForName(mainClass).getMethod("main", classOf[Array[String]])
.invoke(null, Array(HELP))
} catch {
case e: InvocationTargetException =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ object HistoryServer extends Logging {

val providerName = conf.getOption("spark.history.provider")
.getOrElse(classOf[FsHistoryProvider].getName())
val provider = Class.forName(providerName)
val provider = Utils.classForName(providerName)
.getConstructor(classOf[SparkConf])
.newInstance(conf)
.asInstanceOf[ApplicationHistoryProvider]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ private[master] class Master(
new FileSystemRecoveryModeFactory(conf, SerializationExtension(actorSystem))
(fsFactory.createPersistenceEngine(), fsFactory.createLeaderElectionAgent(this))
case "CUSTOM" =>
val clazz = Class.forName(conf.get("spark.deploy.recoveryMode.factory"))
val clazz = Utils.classForName(conf.get("spark.deploy.recoveryMode.factory"))
val factory = clazz.getConstructor(classOf[SparkConf], classOf[Serialization])
.newInstance(conf, SerializationExtension(actorSystem))
.asInstanceOf[StandaloneRecoveryModeFactory]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ private[spark] object SubmitRestProtocolMessage {
*/
def fromJson(json: String): SubmitRestProtocolMessage = {
val className = parseAction(json)
val clazz = Class.forName(packagePrefix + "." + className)
val clazz = Utils.classForName(packagePrefix + "." + className)
.asSubclass[SubmitRestProtocolMessage](classOf[SubmitRestProtocolMessage])
fromJson(json, clazz)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ object DriverWrapper {
Thread.currentThread.setContextClassLoader(loader)

// Delegate to supplied main class
val clazz = Class.forName(mainClass, true, loader)
val clazz = Utils.classForName(mainClass)
val mainMethod = clazz.getMethod("main", classOf[Array[String]])
mainMethod.invoke(null, extraArgs.toArray[String])

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ private[worker] class WorkerArguments(args: Array[String], conf: SparkConf) {
val ibmVendor = System.getProperty("java.vendor").contains("IBM")
var totalMb = 0
try {
// scalastyle:off classforname
val bean = ManagementFactory.getOperatingSystemMXBean()
if (ibmVendor) {
val beanClass = Class.forName("com.ibm.lang.management.OperatingSystemMXBean")
Expand All @@ -159,6 +160,7 @@ private[worker] class WorkerArguments(args: Array[String], conf: SparkConf) {
val method = beanClass.getDeclaredMethod("getTotalPhysicalMemorySize")
totalMb = (method.invoke(bean).asInstanceOf[Long] / 1024 / 1024).toInt
}
// scalastyle:on classforname
} catch {
case e: Exception => {
totalMb = 2*1024
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,7 @@ private[spark] class Executor(
logInfo("Using REPL class URI: " + classUri)
try {
val _userClassPathFirst: java.lang.Boolean = userClassPathFirst
val klass = Class.forName("org.apache.spark.repl.ExecutorClassLoader")
val klass = Utils.classForName("org.apache.spark.repl.ExecutorClassLoader")
.asInstanceOf[Class[_ <: ClassLoader]]
val constructor = klass.getConstructor(classOf[SparkConf], classOf[String],
classOf[ClassLoader], classOf[Boolean])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,7 @@ private[spark] object CompressionCodec {
def createCodec(conf: SparkConf, codecName: String): CompressionCodec = {
val codecClass = shortCompressionCodecNames.getOrElse(codecName.toLowerCase, codecName)
val codec = try {
val ctor = Class.forName(codecClass, true, Utils.getContextOrSparkClassLoader)
.getConstructor(classOf[SparkConf])
val ctor = Utils.classForName(codecClass).getConstructor(classOf[SparkConf])
Some(ctor.newInstance(conf).asInstanceOf[CompressionCodec])
} catch {
case e: ClassNotFoundException => None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import org.apache.hadoop.mapreduce.{OutputCommitter => MapReduceOutputCommitter}

import org.apache.spark.executor.CommitDeniedException
import org.apache.spark.{Logging, SparkEnv, TaskContext}
import org.apache.spark.util.{Utils => SparkUtils}

private[spark]
trait SparkHadoopMapRedUtil {
Expand Down Expand Up @@ -64,10 +65,10 @@ trait SparkHadoopMapRedUtil {

private def firstAvailableClass(first: String, second: String): Class[_] = {
try {
Class.forName(first)
SparkUtils.classForName(first)
} catch {
case e: ClassNotFoundException =>
Class.forName(second)
SparkUtils.classForName(second)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import java.lang.{Boolean => JBoolean, Integer => JInteger}

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.mapreduce.{JobContext, JobID, TaskAttemptContext, TaskAttemptID}
import org.apache.spark.util.Utils

private[spark]
trait SparkHadoopMapReduceUtil {
Expand All @@ -46,7 +47,7 @@ trait SparkHadoopMapReduceUtil {
isMap: Boolean,
taskId: Int,
attemptId: Int): TaskAttemptID = {
val klass = Class.forName("org.apache.hadoop.mapreduce.TaskAttemptID")
val klass = Utils.classForName("org.apache.hadoop.mapreduce.TaskAttemptID")
try {
// First, attempt to use the old-style constructor that takes a boolean isMap
// (not available in YARN)
Expand All @@ -57,7 +58,7 @@ trait SparkHadoopMapReduceUtil {
} catch {
case exc: NoSuchMethodException => {
// If that failed, look for the new constructor that takes a TaskType (not available in 1.x)
val taskTypeClass = Class.forName("org.apache.hadoop.mapreduce.TaskType")
val taskTypeClass = Utils.classForName("org.apache.hadoop.mapreduce.TaskType")
.asInstanceOf[Class[Enum[_]]]
val taskType = taskTypeClass.getMethod("valueOf", classOf[String]).invoke(
taskTypeClass, if (isMap) "MAP" else "REDUCE")
Expand All @@ -71,10 +72,10 @@ trait SparkHadoopMapReduceUtil {

private def firstAvailableClass(first: String, second: String): Class[_] = {
try {
Class.forName(first)
Utils.classForName(first)
} catch {
case e: ClassNotFoundException =>
Class.forName(second)
Utils.classForName(second)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ package org.apache.spark.metrics
import java.util.Properties
import java.util.concurrent.TimeUnit

import org.apache.spark.util.Utils

import scala.collection.mutable

import com.codahale.metrics.{Metric, MetricFilter, MetricRegistry}
Expand Down Expand Up @@ -166,7 +168,7 @@ private[spark] class MetricsSystem private (
sourceConfigs.foreach { kv =>
val classPath = kv._2.getProperty("class")
try {
val source = Class.forName(classPath).newInstance()
val source = Utils.classForName(classPath).newInstance()
registerSource(source.asInstanceOf[Source])
} catch {
case e: Exception => logError("Source class " + classPath + " cannot be instantiated", e)
Expand All @@ -182,7 +184,7 @@ private[spark] class MetricsSystem private (
val classPath = kv._2.getProperty("class")
if (null != classPath) {
try {
val sink = Class.forName(classPath)
val sink = Utils.classForName(classPath)
.getConstructor(classOf[Properties], classOf[MetricRegistry], classOf[SecurityManager])
.newInstance(kv._2, registry, securityMgr)
if (kv._1 == "servlet") {
Expand Down
6 changes: 3 additions & 3 deletions core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
Original file line number Diff line number Diff line change
Expand Up @@ -383,11 +383,11 @@ private[spark] object HadoopRDD extends Logging {

private[spark] class SplitInfoReflections {
val inputSplitWithLocationInfo =
Class.forName("org.apache.hadoop.mapred.InputSplitWithLocationInfo")
Utils.classForName("org.apache.hadoop.mapred.InputSplitWithLocationInfo")
val getLocationInfo = inputSplitWithLocationInfo.getMethod("getLocationInfo")
val newInputSplit = Class.forName("org.apache.hadoop.mapreduce.InputSplit")
val newInputSplit = Utils.classForName("org.apache.hadoop.mapreduce.InputSplit")
val newGetLocationInfo = newInputSplit.getMethod("getLocationInfo")
val splitLocationInfo = Class.forName("org.apache.hadoop.mapred.SplitLocationInfo")
val splitLocationInfo = Utils.classForName("org.apache.hadoop.mapred.SplitLocationInfo")
val isInMemory = splitLocationInfo.getMethod("isInMemory")
val getLocation = splitLocationInfo.getMethod("getLocation")
}
Expand Down
3 changes: 1 addition & 2 deletions core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,7 @@ private[spark] object RpcEnv {
val rpcEnvNames = Map("akka" -> "org.apache.spark.rpc.akka.AkkaRpcEnvFactory")
val rpcEnvName = conf.get("spark.rpc", "akka")
val rpcEnvFactoryClassName = rpcEnvNames.getOrElse(rpcEnvName.toLowerCase, rpcEnvName)
Class.forName(rpcEnvFactoryClassName, true, Utils.getContextOrSparkClassLoader).
newInstance().asInstanceOf[RpcEnvFactory]
Utils.classForName(rpcEnvFactoryClassName).newInstance().asInstanceOf[RpcEnvFactory]
}

def create(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,11 @@ private[spark] class JavaDeserializationStream(in: InputStream, loader: ClassLoa
extends DeserializationStream {

private val objIn = new ObjectInputStream(in) {
override def resolveClass(desc: ObjectStreamClass): Class[_] =
override def resolveClass(desc: ObjectStreamClass): Class[_] = {
// scalastyle:off classforname
Class.forName(desc.getName, false, loader)
// scalastyle:on classforname
}
}

def readObject[T: ClassTag](): T = objIn.readObject().asInstanceOf[T]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ class KryoSerializer(conf: SparkConf)
kryo.register(classOf[PythonBroadcast], new KryoJavaSerializer())

try {
// scalastyle:off classforname
// Use the default classloader when calling the user registrator.
Thread.currentThread.setContextClassLoader(classLoader)
// Register classes given through spark.kryo.classesToRegister.
Expand All @@ -111,6 +112,7 @@ class KryoSerializer(conf: SparkConf)
userRegistrator
.map(Class.forName(_, true, classLoader).newInstance().asInstanceOf[KryoRegistrator])
.foreach { reg => reg.registerClasses(kryo) }
// scalastyle:on classforname
} catch {
case e: Exception =>
throw new SparkException(s"Failed to register classes with Kryo", e)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,9 @@ private[spark] object SerializationDebugger extends Logging {

/** ObjectStreamClass$ClassDataSlot.desc field */
val DescField: Field = {
// scalastyle:off classforname
val f = Class.forName("java.io.ObjectStreamClass$ClassDataSlot").getDeclaredField("desc")
// scalastyle:on classforname
f.setAccessible(true)
f
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ private[spark] class ExternalBlockStore(blockManager: BlockManager, executorId:
.getOrElse(ExternalBlockStore.DEFAULT_BLOCK_MANAGER_NAME)

try {
val instance = Class.forName(clsName)
val instance = Utils.classForName(clsName)
.newInstance()
.asInstanceOf[ExternalBlockManager]
instance.init(blockManager, executorId)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -448,10 +448,12 @@ private class InnerClosureFinder(output: Set[Class[_]]) extends ClassVisitor(ASM
if (op == INVOKESPECIAL && name == "<init>" && argTypes.length > 0
&& argTypes(0).toString.startsWith("L") // is it an object?
&& argTypes(0).getInternalName == myName) {
// scalastyle:off classforname
output += Class.forName(
owner.replace('/', '.'),
false,
Thread.currentThread.getContextClassLoader)
// scalastyle:on classforname
}
}
}
Expand Down
Loading