Skip to content

Commit

Permalink
Fix remove from Properties + add tests
Browse files Browse the repository at this point in the history
Because java.util.Properties' remove method takes in an Any
instead of a String, there were some issues with matching the
key's hashCode, so removing was not successful in unit tests.

Instead, this commit fixes it by manually filtering out the excluded keys
and copying only the remaining entries into the child thread's properties.
  • Loading branch information
Andrew Or committed Sep 11, 2015
1 parent 8ceae42 commit 3ec715c
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 3 deletions.
15 changes: 12 additions & 3 deletions core/src/main/scala/org/apache/spark/SparkContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import java.util.concurrent.atomic.{AtomicReference, AtomicBoolean, AtomicIntege
import java.util.UUID.randomUUID

import scala.collection.JavaConverters._
import scala.collection.JavaConversions._
import scala.collection.{Map, Set}
import scala.collection.generic.Growable
import scala.collection.mutable.{HashMap, HashSet}
Expand Down Expand Up @@ -349,9 +350,17 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
// Thread Local variable that can be used by users to pass information down the stack
//
// NOTE(review): this span is rendered from a commit diff whose +/- markers were lost. The
// first three body lines of childValue (`val p = new Properties(parent)` through the bare `p`)
// appear to be the 3 deleted lines, and the if/else that follows appears to be the added
// replacement -- confirm against the real SparkContext.scala before editing this block.
private val localProperties = new InheritableThreadLocal[Properties] {
// Produces the Properties value handed to a child thread, derived from the parent's value.
override protected def childValue(parent: Properties): Properties = {
// Pre-change approach: wrap the parent and call remove for every excluded key. According to
// the commit message, this removal was not successful in unit tests.
val p = new Properties(parent)
nonInheritedLocalProperties.foreach(p.remove)
p
// Post-change approach: only build a filtered copy when there is something to exclude.
if (nonInheritedLocalProperties.nonEmpty) {
// If there are properties that should not be inherited, filter them out
val p = new Properties
// Keep every (key, value) pair whose key is not listed in nonInheritedLocalProperties.
val filtered = parent.filter { case (k, _) =>
!nonInheritedLocalProperties.contains(k)
}
// NOTE(review): presumably `parent.filter` (through the implicit Java collection
// conversions) only sees parent's own entries, not entries reachable through a defaults
// chain such as the one the else-branch below creates -- verify that properties survive
// more than one generation of threads when this branch is taken.
p.putAll(filtered)
p
} else {
// Nothing is excluded: hand the parent to the Properties constructor unchanged.
new Properties(parent)
}
}
// Starting value when there is nothing to derive from: an empty Properties.
override protected def initialValue(): Properties = new Properties()
}
Expand Down
27 changes: 27 additions & 0 deletions core/src/test/scala/org/apache/spark/ThreadingSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ class ThreadingSuite extends SparkFunSuite with LocalSparkContext with Logging {
val threads = (1 to 5).map { i =>
new Thread() {
override def run() {
// TODO: these assertion failures don't actually fail the test...
sc.setLocalProperty("test", i.toString)
assert(sc.getLocalProperty("test") === i.toString)
sem.release()
Expand All @@ -175,6 +176,7 @@ class ThreadingSuite extends SparkFunSuite with LocalSparkContext with Logging {
val threads = (1 to 5).map { i =>
new Thread() {
override def run() {
// TODO: these assertion failures don't actually fail the test...
assert(sc.getLocalProperty("test") === "parent")
sc.setLocalProperty("test", i.toString)
assert(sc.getLocalProperty("test") === i.toString)
Expand All @@ -190,6 +192,30 @@ class ThreadingSuite extends SparkFunSuite with LocalSparkContext with Logging {
assert(sc.getLocalProperty("Foo") === null)
}

test("inheritance exclusions (SPARK-10548)") {
  sc = new SparkContext("local", "test")
  sc.nonInheritedLocalProperties.add("do-not-inherit-me")
  sc.setLocalProperty("do-inherit-me", "parent")
  sc.setLocalProperty("do-not-inherit-me", "parent")

  // An assertion failing inside a child thread would not fail the test on its own, so any
  // failure is parked here and rethrown from the test thread once every child has finished.
  var childFailure: Option[Throwable] = None

  val children = for (_ <- 1 to 5) yield new Thread() {
    override def run(): Unit = {
      try {
        // The regular property must reach the child; the excluded one must not.
        assert(sc.getLocalProperty("do-inherit-me") === "parent")
        assert(sc.getLocalProperty("do-not-inherit-me") === null)
      } catch {
        case failure: Throwable => childFailure = Some(failure)
      }
    }
  }

  children.foreach(_.start())
  children.foreach(_.join())

  // Surface whatever a child recorded (if anything) as the outcome of this test.
  for (failure <- childFailure) {
    throw failure
  }
}

test("mutations to local properties should not affect submitted jobs (SPARK-6629)") {
val jobStarted = new Semaphore(0)
val jobEnded = new Semaphore(0)
Expand All @@ -210,6 +236,7 @@ class ThreadingSuite extends SparkFunSuite with LocalSparkContext with Logging {
// Create a new thread which will inherit the current thread's properties
val thread = new Thread() {
override def run(): Unit = {
// TODO: these assertion failures don't actually fail the test...
assert(sc.getLocalProperty(SparkContext.SPARK_JOB_GROUP_ID) === "originalJobGroupId")
// Sleeps for a total of 10 seconds, but allows cancellation to interrupt the task
try {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution

import org.apache.spark.sql.test.SharedSQLContext

/**
 * Checks that the SQL execution ID, which is kept as a SparkContext local property, is not
 * handed down to threads spawned from a thread that has one set, while ordinary local
 * properties still are.
 */
class SQLExecutionSuite extends SharedSQLContext {
  import testImplicits._

  test("query execution IDs are not inherited across threads") {
    try {
      sparkContext.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, "123")
      sparkContext.setLocalProperty("do-inherit-me", "some-value")
      // Assertions that fail inside the child thread would not fail the test by themselves,
      // so the failure is captured and rethrown on the test thread after the join.
      var throwable: Option[Throwable] = None
      val thread = new Thread {
        override def run(): Unit = {
          try {
            assert(sparkContext.getLocalProperty("do-inherit-me") === "some-value")
            assert(sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) === null)
          } catch {
            case t: Throwable =>
              throwable = Some(t)
          }
        }
      }
      thread.start()
      thread.join()
      throwable.foreach { t => throw t }
    } finally {
      // The SparkContext is shared with the other tests of this suite; clear the properties set
      // above so that a fake execution ID does not linger on the test thread. Passing null as
      // the value is how a local property is unset.
      sparkContext.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, null)
      sparkContext.setLocalProperty("do-inherit-me", null)
    }
  }

  // This is the end-to-end version of the previous test.
  test("parallel query execution (SPARK-10548)") {
    (1 to 5).foreach { _ =>
      // Scala's parallel collections spawns new threads as children of the existing threads.
      // We need to run this multiple times to ensure new threads are spawned. Without the fix
      // for SPARK-10548, this usually fails on the second try.
      val df = sparkContext.parallelize(1 to 5).map { i => (i, i) }.toDF("a", "b")
      (1 to 10).par.foreach { _ => df.count() }
    }
  }
}

0 comments on commit 3ec715c

Please sign in to comment.