Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,8 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
@transient
private val runningSubqueries = new ArrayBuffer[ExecSubqueryExpression]

@transient private val prepareLock = new Object()

/**
* Finds scalar subquery expressions in this plan node and starts evaluating them.
*/
Expand All @@ -293,7 +295,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
/**
* Blocks the thread until all subqueries finish evaluation and update the results.
*/
protected def waitForSubqueries(): Unit = synchronized {
protected def waitForSubqueries(): Unit = prepareLock.synchronized {
// fill in the result of subqueries
runningSubqueries.foreach { sub =>
sub.updateResult()
Expand All @@ -312,7 +314,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
final def prepare(): Unit = {
// doPrepare() may depend on its children, we should call prepare() on all the children first.
children.foreach(_.prepare())
synchronized {
prepareLock.synchronized {
if (!prepared) {
prepareSubqueries()
doPrepare()
Expand All @@ -329,7 +331,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
* @note `prepare` method has already walked down the tree, so the implementation doesn't have
* to call children's `prepare` methods.
*
* This will only be called once, protected by `this`.
* This will only be called once, protected by [[prepareLock]].
*/
protected def doPrepare(): Unit = {}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,26 @@

package org.apache.spark.sql.execution

import java.lang.management.ManagementFactory
import java.util.concurrent.{CountDownLatch, TimeUnit}

import scala.concurrent.{ExecutionContext, Future}
import scala.concurrent.duration.Duration

import org.apache.spark.{SparkEnv, SparkException, SparkUnsupportedOperationException}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Literal}
import org.apache.spark.sql.catalyst.expressions.{
Attribute, AttributeReference, Expression, ExprId, Literal}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.plans.logical.Deduplicate
import org.apache.spark.sql.catalyst.trees.LeafLike
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.IntegerType
import org.apache.spark.sql.types.{DataType, IntegerType}
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.util.ThreadUtils

class SparkPlanSuite extends SharedSparkSession {

Expand Down Expand Up @@ -168,6 +178,60 @@ class SparkPlanSuite extends SharedSparkSession {
}
}
}

test("SPARK-57041: waitForSubqueries must not hold the plan's monitor " +
  "while awaiting subquery results") {
  // Scenario: thread A parks inside BlockingSubquery.updateResult(), which it reaches through
  // waitForSubqueries(). Thread B then tries to enter `plan.synchronized`. If thread B is ever
  // observed in the BLOCKED state, the waiting thread is holding the plan's own monitor.
  val enteredLatch = new CountDownLatch(1)
  val releaseLatch = new CountDownLatch(1)

  val subqueryExec = TestSubqueryExec(LocalTableScanExec(Nil, Nil, None))
  val subqueryExpr = BlockingSubquery(subqueryExec, ExprId(0), enteredLatch, releaseLatch)
  val plan = TestPlanWithSubquery(subqueryExpr)

  val executor = ThreadUtils.newDaemonSingleThreadExecutor("test-wait-for-subqueries")
  implicit val ec: ExecutionContext = ExecutionContext.fromExecutor(executor)

  // Everything that runs after the executor exists lives inside try/finally, so a failure in
  // prepare() or while submitting the future still releases the latch and stops the executor.
  try {
    plan.testPrepare()
    val futureA = Future { plan.testWaitForSubqueries() }

    assert(enteredLatch.await(10, TimeUnit.SECONDS),
      "Thread A did not enter updateResult() within 10s")

    val threadB = new Thread(() => plan.synchronized {})
    threadB.setDaemon(true)
    threadB.start()

    // Poll thread B for at most 5s. Stop as soon as it terminates (the monitor was free) or it
    // is seen BLOCKED (the monitor is held by someone else).
    val bean = ManagementFactory.getThreadMXBean
    val deadline = System.currentTimeMillis() + 5000L
    var threadBBlocked = false
    while (!threadBBlocked && threadB.isAlive && System.currentTimeMillis() <= deadline) {
      Option(bean.getThreadInfo(threadB.getId)).map(_.getThreadState) match {
        case Some(Thread.State.BLOCKED) => threadBBlocked = true
        case Some(_) => Thread.sleep(1)
        case None => // no thread info available; the loop condition re-checks liveness
      }
    }

    // Let thread A finish before asserting, so both threads are wound down even when the
    // assertion below fails.
    releaseLatch.countDown()
    ThreadUtils.awaitResult(futureA, Duration(10, "seconds"))
    threadB.join(5000L)

    assert(!threadBBlocked,
      "Deadlock: plan.this.synchronized could not be acquired while waitForSubqueries() was " +
        "blocking on a subquery future. waitForSubqueries() must not hold the plan's monitor.")
  } finally {
    // countDown() on an already-released latch is a no-op, so this is safe on every path.
    releaseLatch.countDown()
    executor.shutdown()
  }
}
}

case class ColumnarOp(child: SparkPlan) extends UnaryExecNode {
Expand All @@ -179,3 +243,41 @@ case class ColumnarOp(child: SparkPlan) extends UnaryExecNode {
override protected def withNewChildInternal(newChild: SparkPlan): ColumnarOp =
copy(child = newChild)
}

/**
 * Test-only [[BaseSubqueryExec]] that wraps a single child plan and forwards execution to it.
 */
private case class TestSubqueryExec(child: SparkPlan) extends BaseSubqueryExec {

  // Fixed display name for this test node.
  override def name: String = "TestSubqueryExec"

  // The wrapped plan is the only child.
  override def children: Seq[SparkPlan] = child :: Nil

  // Execution is delegated straight to the wrapped plan.
  override protected def doExecute(): RDD[InternalRow] = {
    child.execute()
  }

  // Rebuilds this node around the first of the replacement children.
  override protected def withNewChildrenInternal(
      newChildren: IndexedSeq[SparkPlan]): TestSubqueryExec = {
    val replacement = newChildren.head
    copy(child = replacement)
  }
}

/**
 * Test-only subquery expression whose `updateResult()` first signals `enteredLatch` and then
 * waits on `releaseLatch` for at most 30 seconds. Evaluation always yields null and code
 * generation is not supported.
 */
private case class BlockingSubquery(
    plan: BaseSubqueryExec,
    exprId: ExprId,
    enteredLatch: CountDownLatch,
    releaseLatch: CountDownLatch)
  extends ExecSubqueryExpression with LeafLike[Expression] {

  override def nullable: Boolean = true

  override def dataType: DataType = IntegerType

  // The expression never produces a real value.
  override def eval(input: InternalRow): Any = null

  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    throw new UnsupportedOperationException("test only")
  }

  override def withNewPlan(plan: BaseSubqueryExec): ExecSubqueryExpression = {
    this.copy(plan = plan)
  }

  override def updateResult(): Unit = {
    // Tell the observer that this method has been entered ...
    enteredLatch.countDown()
    // ... then park here until released, or until the 30s timeout elapses.
    releaseLatch.await(30, TimeUnit.SECONDS)
  }
}

/**
 * Test-only leaf plan that carries a single subquery expression and exposes `prepare()` and
 * the protected `waitForSubqueries()` through public wrappers. It has no output attributes
 * and cannot be executed.
 */
private case class TestPlanWithSubquery(subqueryExpr: ExecSubqueryExpression)
  extends LeafExecNode {

  override def output: Seq[Attribute] = Seq.empty

  override protected def doExecute(): RDD[InternalRow] = {
    throw new UnsupportedOperationException("test only")
  }

  /** Public entry point to `prepare()`. */
  def testPrepare(): Unit = {
    prepare()
  }

  /** Public entry point to the protected `waitForSubqueries()`. */
  def testWaitForSubqueries(): Unit = {
    waitForSubqueries()
  }
}