4 changes: 2 additions & 2 deletions core/src/main/scala/org/apache/spark/Logging.scala
@@ -22,6 +22,7 @@ import org.slf4j.{Logger, LoggerFactory}
import org.slf4j.impl.StaticLoggerBinder

import org.apache.spark.annotation.DeveloperApi
+ import org.apache.spark.util.Utils

/**
* :: DeveloperApi ::
@@ -115,8 +116,7 @@ trait Logging {
val log4jInitialized = LogManager.getRootLogger.getAllAppenders.hasMoreElements
if (!log4jInitialized && usingLog4j) {
val defaultLogProps = "org/apache/spark/log4j-defaults.properties"
- val classLoader = this.getClass.getClassLoader
- Option(classLoader.getResource(defaultLogProps)) match {
+ Option(Utils.getSparkClassLoader.getResource(defaultLogProps)) match {
case Some(url) =>
PropertyConfigurator.configure(url)
log.info(s"Using Spark's default log4j profile: $defaultLogProps")
6 changes: 3 additions & 3 deletions core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -292,7 +292,7 @@ private[spark] class Executor(
* created by the interpreter to the search path
*/
private def createClassLoader(): MutableURLClassLoader = {
- val loader = this.getClass.getClassLoader
+ val currentLoader = Utils.getContextOrSparkClassLoader

// For each of the jars in the jarSet, add them to the class loader.
// We assume each of the files has already been fetched.
@@ -301,8 +301,8 @@
}.toArray
val userClassPathFirst = conf.getBoolean("spark.files.userClassPathFirst", false)
userClassPathFirst match {
- case true => new ChildExecutorURLClassLoader(urls, loader)
- case false => new ExecutorURLClassLoader(urls, loader)
+ case true => new ChildExecutorURLClassLoader(urls, currentLoader)
+ case false => new ExecutorURLClassLoader(urls, currentLoader)
}
}

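The executor now builds its class loader on top of Utils.getContextOrSparkClassLoader rather than whichever loader happened to load the Executor class, and spark.files.userClassPathFirst still decides whether user jars are consulted before or after that parent. Below is a minimal sketch of the child-first ("user classpath first") delegation idea; it is illustrative only and is not Spark's ChildExecutorURLClassLoader.

```scala
import java.net.{URL, URLClassLoader}

// Hypothetical child-first loader: look in the user-supplied jars before delegating
// to the fallback loader (e.g. whatever Utils.getContextOrSparkClassLoader returns).
class ChildFirstClassLoader(urls: Array[URL], fallback: ClassLoader)
  extends URLClassLoader(urls, null) { // null parent: only the bootstrap loader sits above

  override def findClass(name: String): Class[_] =
    try {
      super.findClass(name) // resolve against the user jars first
    } catch {
      case _: ClassNotFoundException => fallback.loadClass(name) // then fall back
    }
}
```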
core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala
@@ -50,21 +50,13 @@ private[spark] class MesosExecutorBackend
executorInfo: ExecutorInfo,
frameworkInfo: FrameworkInfo,
slaveInfo: SlaveInfo) {
- val cl = Thread.currentThread.getContextClassLoader
- try {
- // Work around for SPARK-1480
- Thread.currentThread.setContextClassLoader(getClass.getClassLoader)
- logInfo("Registered with Mesos as executor ID " + executorInfo.getExecutorId.getValue)
- this.driver = driver
- val properties = Utils.deserialize[Array[(String, String)]](executorInfo.getData.toByteArray)
- executor = new Executor(
- executorInfo.getExecutorId.getValue,
- slaveInfo.getHostname,
- properties)
- } finally {
- // Work around for SPARK-1480
- Thread.currentThread.setContextClassLoader(cl)
- }
+ logInfo("Registered with Mesos as executor ID " + executorInfo.getExecutorId.getValue)
+ this.driver = driver
+ val properties = Utils.deserialize[Array[(String, String)]](executorInfo.getData.toByteArray)
+ executor = new Executor(
+ executorInfo.getExecutorId.getValue,
+ slaveInfo.getHostname,
+ properties)
}

override def launchTask(d: ExecutorDriver, taskInfo: TaskInfo) {
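The try/finally deleted above was the interim fix for SPARK-1480: it forced Spark's own class loader to be the thread's context loader while the executor was being constructed, then restored the original. With the fallback now built into Utils.getContextOrSparkClassLoader, that swap is no longer needed. For reference, the removed code followed the usual save/restore pattern, sketched here as a hypothetical helper rather than anything in Spark's API.

```scala
object ContextClassLoaderSketch {
  // Run `body` with `loader` installed as the thread's context class loader,
  // restoring the previous loader afterwards. This is the pattern the removed
  // SPARK-1480 workaround applied around executor construction.
  def withContextClassLoader[T](loader: ClassLoader)(body: => T): T = {
    val previous = Thread.currentThread().getContextClassLoader
    Thread.currentThread().setContextClassLoader(loader)
    try body
    finally Thread.currentThread().setContextClassLoader(previous)
  }
}
```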
core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala
@@ -24,6 +24,7 @@ import scala.collection.mutable
import scala.util.matching.Regex

import org.apache.spark.Logging
+ import org.apache.spark.util.Utils

private[spark] class MetricsConfig(val configFile: Option[String]) extends Logging {

@@ -50,7 +51,7 @@ private[spark] class MetricsConfig(val configFile: Option[String]) extends Logging {
try {
is = configFile match {
case Some(f) => new FileInputStream(f)
- case None => getClass.getClassLoader.getResourceAsStream(METRICS_CONF)
+ case None => Utils.getSparkClassLoader.getResourceAsStream(METRICS_CONF)
}

if (is != null) {
core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
@@ -54,7 +54,6 @@ private[spark] object ResultTask {

def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], (TaskContext, Iterator[_]) => _) =
{
- val loader = Thread.currentThread.getContextClassLoader
val in = new GZIPInputStream(new ByteArrayInputStream(bytes))
val ser = SparkEnv.get.closureSerializer.newInstance()
val objIn = ser.deserializeStream(in)
core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala
@@ -23,6 +23,7 @@ import java.util.{NoSuchElementException, Properties}
import scala.xml.XML

import org.apache.spark.{Logging, SparkConf}
+ import org.apache.spark.util.Utils

/**
* An interface to build Schedulable tree
@@ -72,7 +73,7 @@ private[spark] class FairSchedulableBuilder(val rootPool: Pool, conf: SparkConf)
schedulerAllocFile.map { f =>
new FileInputStream(f)
}.getOrElse {
- getClass.getClassLoader.getResourceAsStream(DEFAULT_SCHEDULER_FILE)
+ Utils.getSparkClassLoader.getResourceAsStream(DEFAULT_SCHEDULER_FILE)
}
}

core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
@@ -85,13 +85,13 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul
try {
if (serializedData != null && serializedData.limit() > 0) {
reason = serializer.get().deserialize[TaskEndReason](
- serializedData, getClass.getClassLoader)
+ serializedData, Utils.getSparkClassLoader)
}
} catch {
case cnd: ClassNotFoundException =>
// Log an error but keep going here -- the task failed, so not catastrophic if we can't
// deserialize the reason.
- val loader = Thread.currentThread.getContextClassLoader
+ val loader = Utils.getContextOrSparkClassLoader
logError(
"Could not deserialize TaskEndReason: ClassNotFound with classloader " + loader)
case ex: Throwable => {}
core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala
@@ -23,6 +23,7 @@ import java.nio.ByteBuffer
import org.apache.spark.SparkConf
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.util.ByteBufferInputStream
+ import org.apache.spark.util.Utils

private[spark] class JavaSerializationStream(out: OutputStream, counterReset: Int)
extends SerializationStream {
@@ -86,7 +87,7 @@ private[spark] class JavaSerializerInstance(counterReset: Int) extends Serialize
}

def deserializeStream(s: InputStream): DeserializationStream = {
- new JavaDeserializationStream(s, Thread.currentThread.getContextClassLoader)
+ new JavaDeserializationStream(s, Utils.getContextOrSparkClassLoader)
}

def deserializeStream(s: InputStream, loader: ClassLoader): DeserializationStream = {
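Defaulting deserialization to Utils.getContextOrSparkClassLoader matters because a plain ObjectInputStream resolves classes against the loader of the calling code, which may not see classes shipped in user jars or defined in the REPL. A loader-aware stream along these lines (a sketch of the idea, not Spark's JavaDeserializationStream itself) resolves every class against an explicitly supplied loader.

```scala
import java.io.{InputStream, ObjectInputStream, ObjectStreamClass}

// Resolve each class named in the stream against the supplied loader instead of
// the one ObjectInputStream would pick by default.
class LoaderAwareObjectInputStream(in: InputStream, loader: ClassLoader)
  extends ObjectInputStream(in) {

  override def resolveClass(desc: ObjectStreamClass): Class[_] =
    Class.forName(desc.getName, false, loader)
}
```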
3 changes: 2 additions & 1 deletion core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
@@ -33,6 +33,7 @@ import org.json4s.JValue
import org.json4s.jackson.JsonMethods.{pretty, render}

import org.apache.spark.{Logging, SecurityManager, SparkConf}
+ import org.apache.spark.util.Utils

/**
* Utilities for launching a web server using Jetty's HTTP Server class
@@ -124,7 +125,7 @@ private[spark] object JettyUtils extends Logging {
contextHandler.setInitParameter("org.eclipse.jetty.servlet.Default.gzip", "false")
val staticHandler = new DefaultServlet
val holder = new ServletHolder(staticHandler)
- Option(getClass.getClassLoader.getResource(resourceBase)) match {
+ Option(Utils.getSparkClassLoader.getResource(resourceBase)) match {
case Some(res) =>
holder.setInitParameter("resourceBase", res.toString)
case None =>
15 changes: 15 additions & 0 deletions core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -116,6 +116,21 @@ private[spark] object Utils extends Logging {
}
}

+ /**
+ * Get the ClassLoader which loaded Spark.
+ */
+ def getSparkClassLoader = getClass.getClassLoader
+
+ /**
+ * Get the Context ClassLoader on this thread or, if not present, the ClassLoader that
+ * loaded Spark.
+ *
+ * This should be used whenever passing a ClassLoader to Class.forName or finding the currently
+ * active loader when setting up ClassLoader delegation chains.
+ */
+ def getContextOrSparkClassLoader =
+ Option(Thread.currentThread().getContextClassLoader).getOrElse(getSparkClassLoader)
+
/**
* Primitive often used when writing {@link java.nio.ByteBuffer} to {@link java.io.DataOutput}.
*/
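These two helpers centralize the fallback logic used throughout the rest of the change. Utils is private[spark], so the sketch below re-implements the same behaviour under hypothetical names purely to show how a caller would combine it with Class.forName.

```scala
object ClassLoaderFallbackSketch {
  // Equivalent of Utils.getSparkClassLoader: the loader that loaded this class.
  def sparkClassLoader: ClassLoader = getClass.getClassLoader

  // Equivalent of Utils.getContextOrSparkClassLoader: prefer the thread's context
  // loader (for example one installed by spark-submit or the REPL), else fall back.
  def contextOrSparkClassLoader: ClassLoader =
    Option(Thread.currentThread().getContextClassLoader).getOrElse(sparkClassLoader)

  def main(args: Array[String]): Unit = {
    // Pass an explicit loader to Class.forName, as the new scaladoc recommends.
    val cls = Class.forName("scala.Option", true, contextOrSparkClassLoader)
    println(s"Loaded ${cls.getName} via ${cls.getClassLoader}")
  }
}
```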
core/src/test/scala/org/apache/spark/executor/ExecutorURLClassLoaderSuite.scala
@@ -17,12 +17,12 @@

package org.apache.spark.executor

- import java.io.File
import java.net.URLClassLoader

import org.scalatest.FunSuite

- import org.apache.spark.TestUtils
+ import org.apache.spark.{LocalSparkContext, SparkContext, SparkException, TestUtils}
+ import org.apache.spark.util.Utils

class ExecutorURLClassLoaderSuite extends FunSuite {

@@ -63,5 +63,33 @@ class ExecutorURLClassLoaderSuite extends FunSuite {
}
}

test("driver sets context class loader in local mode") {
// Test the case where the driver program sets a context classloader and then runs a job
// in local mode. This is what happens when ./spark-submit is called with "local" as the
// master.
val original = Thread.currentThread().getContextClassLoader

val className = "ClassForDriverTest"
val jar = TestUtils.createJarWithClasses(Seq(className))
val contextLoader = new URLClassLoader(Array(jar), Utils.getContextOrSparkClassLoader)
Thread.currentThread().setContextClassLoader(contextLoader)

val sc = new SparkContext("local", "driverLoaderTest")

try {
sc.makeRDD(1 to 5, 2).mapPartitions { x =>
val loader = Thread.currentThread().getContextClassLoader
Class.forName(className, true, loader).newInstance()
Seq().iterator
}.count()
}
catch {
case e: SparkException if e.getMessage.contains("ClassNotFoundException") =>
fail("Local executor could not find class", e)
case t: Throwable => fail("Unexpected exception ", t)
}

sc.stop()
Thread.currentThread().setContextClassLoader(original)
}
}
7 changes: 4 additions & 3 deletions repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala
@@ -39,6 +39,7 @@ import scala.reflect.api.{Mirror, TypeCreator, Universe => ApiUniverse}
import org.apache.spark.Logging
import org.apache.spark.SparkConf
import org.apache.spark.SparkContext
+ import org.apache.spark.util.Utils

/** The Scala interactive shell. It provides a read-eval-print loop
* around the Interpreter class.
@@ -130,7 +131,7 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter,
def history = in.history

/** The context class loader at the time this object was created */
- protected val originalClassLoader = Thread.currentThread.getContextClassLoader
+ protected val originalClassLoader = Utils.getContextOrSparkClassLoader

// classpath entries added via :cp
var addedClasspath: String = ""
@@ -177,7 +178,7 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter,
override lazy val formatting = new Formatting {
def prompt = SparkILoop.this.prompt
}
override protected def parentClassLoader = SparkHelper.explicitParentLoader(settings).getOrElse(classOf[SparkILoop].getClassLoader)
}

/** Create a new interpreter. */
@@ -871,7 +872,7 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter,
}

val u: scala.reflect.runtime.universe.type = scala.reflect.runtime.universe
- val m = u.runtimeMirror(getClass.getClassLoader)
+ val m = u.runtimeMirror(Utils.getSparkClassLoader)
private def tagOfStaticClass[T: ClassTag]: u.TypeTag[T] =
u.TypeTag[T](
m,
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst

import java.io.{PrintWriter, ByteArrayOutputStream, FileInputStream, File}

+ import org.apache.spark.util.{Utils => SparkUtils}

package object util {
/**
* Returns a path to a temporary file that probably does not exist.
@@ -54,7 +56,7 @@ package object util {
def resourceToString(
resource:String,
encoding: String = "UTF-8",
- classLoader: ClassLoader = this.getClass.getClassLoader) = {
+ classLoader: ClassLoader = SparkUtils.getSparkClassLoader) = {
val inStream = classLoader.getResourceAsStream(resource)
val outStream = new ByteArrayOutputStream
try {
sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala
@@ -26,6 +26,7 @@ import scala.reflect.runtime.universe.runtimeMirror
import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
import org.apache.spark.sql.catalyst.types._
import org.apache.spark.sql.columnar._
+ import org.apache.spark.util.Utils

private[sql] case object PassThrough extends CompressionScheme {
override val typeId = 0
@@ -254,7 +255,7 @@ private[sql] case object DictionaryEncoding extends CompressionScheme {
private val dictionary = {
// TODO Can we clean up this mess? Maybe move this to `DataType`?
implicit val classTag = {
- val mirror = runtimeMirror(getClass.getClassLoader)
+ val mirror = runtimeMirror(Utils.getSparkClassLoader)
ClassTag[T#JvmType](mirror.runtimeClass(columnType.scalaTag.tpe))
}

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala
@@ -25,6 +25,7 @@ import com.esotericsoftware.kryo.{Serializer, Kryo}
import org.apache.spark.{SparkEnv, SparkConf}
import org.apache.spark.serializer.KryoSerializer
import org.apache.spark.util.MutablePair
+ import org.apache.spark.util.Utils

class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(conf) {
override def newKryo(): Kryo = {
@@ -44,7 +45,7 @@ class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(conf) {
kryo.register(classOf[scala.collection.mutable.ArrayBuffer[_]])
kryo.register(classOf[scala.math.BigDecimal], new BigDecimalSerializer)
kryo.setReferences(false)
- kryo.setClassLoader(this.getClass.getClassLoader)
+ kryo.setClassLoader(Utils.getSparkClassLoader)
kryo
}
}