[SPARK-31144][SQL] Wrap Error with QueryExecutionException to notify QueryExecutionListener

### What changes were proposed in this pull request?

This PR manually reverts the changes in #25292 and instead wraps `java.lang.Error` in a `QueryExecutionException`, so that it can still be delivered to `QueryExecutionListener.onFailure`, which only accepts `Exception`.

The bug fix PR for 2.4 is #27904. It needs a separate PR because the touched code has changed a lot.
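For illustration, here is a minimal sketch of a listener observing the new behavior (the class name `ErrorAwareListener` and the `println` reporting are hypothetical; the wrapping itself is what this PR adds to `ExecutionListenerBus`, and `QueryExecutionException` is imported the same way the test below imports it):

```scala
import org.apache.spark.sql.execution.{QueryExecution, QueryExecutionException}
import org.apache.spark.sql.util.QueryExecutionListener

// Hypothetical listener: after this change a java.lang.Error thrown during
// execution still reaches onFailure, but wrapped in a QueryExecutionException,
// so the callback can keep the 2.4 signature that accepts Exception.
class ErrorAwareListener extends QueryExecutionListener {
  override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = ()

  override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {
    exception match {
      // The original fatal Error is preserved as the cause of the wrapper.
      case wrapped: QueryExecutionException if wrapped.getCause.isInstanceOf[Error] =>
        println(s"$funcName failed with a fatal error: ${wrapped.getCause}")
      case e =>
        println(s"$funcName failed: $e")
    }
  }
}

// Registration is unchanged:
//   spark.listenerManager.register(new ErrorAwareListener)
```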

### Why are the changes needed?

This avoids a breaking API change in 3.0 while still fixing the bug that `QueryExecutionListener` was not notified when a query failed with `java.lang.Error` (SPARK-28556).

### Does this PR introduce any user-facing change?

Yes. It reverts an API change introduced in 3.0, so the `QueryExecutionListener` APIs will be the same as in 2.4.
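For reference, the restored callback signature (see the `QueryExecutionListener` diff below) is:

```scala
def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit
```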

### How was this patch tested?

The newly added test.

Closes #27907 from zsxwing/SPARK-31144.

Authored-by: Shixiong Zhu <zsxwing@gmail.com>
Signed-off-by: Dongjoon Hyun <dongjoon@apache.org>
zsxwing authored and dongjoon-hyun committed Mar 13, 2020
1 parent 2a4fed0 commit 1ddf44d
Showing 13 changed files with 71 additions and 37 deletions.
4 changes: 0 additions & 4 deletions project/MimaExcludes.scala
@@ -426,10 +426,6 @@ object MimaExcludes {
     ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.streaming.ProcessingTime"),
     ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.streaming.ProcessingTime$"),
 
-    // [SPARK-28556][SQL] QueryExecutionListener should also notify Error
-    ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.util.QueryExecutionListener.onFailure"),
-    ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.util.QueryExecutionListener.onFailure"),
-
     // [SPARK-25382][SQL][PYSPARK] Remove ImageSchema.readImages in 3.0
     ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.image.ImageSchema.readImages"),
 
@@ -23,7 +23,7 @@ import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.internal.Logging
 import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent}
 import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.execution.QueryExecution
+import org.apache.spark.sql.execution.{QueryExecution, QueryExecutionException}
 import org.apache.spark.sql.execution.ui.SparkListenerSQLExecutionEnd
 import org.apache.spark.sql.internal.StaticSQLConf._
 import org.apache.spark.util.{ListenerBus, Utils}
@@ -55,12 +55,13 @@ trait QueryExecutionListener {
    * @param funcName the name of the action that triggered this query.
    * @param qe the QueryExecution object that carries detail information like logical plan,
    *           physical plan, etc.
-   * @param error the error that failed this query.
-   *
+   * @param exception the exception that failed this query. If `java.lang.Error` is thrown during
+   *                  execution, it will be wrapped with an `Exception` and it can be accessed by
+   *                  `exception.getCause`.
    * @note This can be invoked by multiple different threads.
    */
   @DeveloperApi
-  def onFailure(funcName: String, qe: QueryExecution, error: Throwable): Unit
+  def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit
 }
 
 
@@ -140,7 +141,14 @@ private[sql] class ExecutionListenerBus(session: SparkSession)
     val funcName = event.executionName.get
     event.executionFailure match {
       case Some(ex) =>
-        listener.onFailure(funcName, event.qe, ex)
+        val exception = ex match {
+          case e: Exception => e
+          case other: Throwable =>
+            val message = "Hit an error when executing a query" +
+              (if (other.getMessage == null) "" else s": ${other.getMessage}")
+            new QueryExecutionException(message, other)
+        }
+        listener.onFailure(funcName, event.qe, exception)
       case _ =>
         listener.onSuccess(funcName, event.qe, event.duration)
     }
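To make the wrapper's message format above concrete, here is a small standalone sketch (the helper name `wrapMessage` is hypothetical; it just restates the logic from the hunk above):

```scala
// Hypothetical standalone rendering of the message logic above.
def wrapMessage(t: Throwable): String =
  "Hit an error when executing a query" +
    (if (t.getMessage == null) "" else s": ${t.getMessage}")

wrapMessage(new java.lang.Error("foo"))     // "Hit an error when executing a query: foo"
wrapMessage(new java.lang.OutOfMemoryError) // "Hit an error when executing a query" (null message)
```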
@@ -71,7 +71,7 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo
         plan = qe.analyzed
 
       }
-      override def onFailure(funcName: String, qe: QueryExecution, error: Throwable): Unit = {}
+      override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {}
     }
     spark.listenerManager.register(listener)
 
@@ -135,7 +135,7 @@ class SessionStateSuite extends SparkFunSuite {
   test("fork new session and inherit listener manager") {
     class CommandCollector extends QueryExecutionListener {
       val commands: ArrayBuffer[String] = ArrayBuffer.empty[String]
-      override def onFailure(funcName: String, qe: QueryExecution, error: Throwable) : Unit = {}
+      override def onFailure(funcName: String, qe: QueryExecution, ex: Exception) : Unit = {}
       override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {
         commands += funcName
       }
@@ -28,7 +28,7 @@ class TestQueryExecutionListener extends QueryExecutionListener {
     OnSuccessCall.isOnSuccessCalled.set(true)
   }
 
-  override def onFailure(funcName: String, qe: QueryExecution, error: Throwable): Unit = { }
+  override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = { }
 }
 
 /**
@@ -341,7 +341,7 @@ class UDFSuite extends QueryTest with SharedSparkSession {
     withTempPath { path =>
       var numTotalCachedHit = 0
       val listener = new QueryExecutionListener {
-        override def onFailure(f: String, qe: QueryExecution, e: Throwable): Unit = {}
+        override def onFailure(f: String, qe: QueryExecution, e: Exception): Unit = {}
 
         override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {
           qe.withCachedData match {
@@ -141,7 +141,7 @@ class DataSourceV2DataFrameSuite
       override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
         plan = qe.analyzed
       }
-      override def onFailure(funcName: String, qe: QueryExecution, error: Throwable): Unit = {}
+      override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {}
     }
 
     try {
@@ -157,13 +157,13 @@ class FileDataSourceV2FallBackSuite extends QueryTest with SharedSparkSession {
     Seq("parquet", classOf[ParquetDataSourceV2].getCanonicalName).foreach { format =>
       withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> format) {
         val commands = ArrayBuffer.empty[(String, LogicalPlan)]
-        val errors = ArrayBuffer.empty[(String, Throwable)]
+        val exceptions = ArrayBuffer.empty[(String, Exception)]
         val listener = new QueryExecutionListener {
           override def onFailure(
               funcName: String,
               qe: QueryExecution,
-              error: Throwable): Unit = {
-            errors += funcName -> error
+              exception: Exception): Unit = {
+            exceptions += funcName -> exception
           }
 
           override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {
@@ -216,7 +216,7 @@ class SupportsCatalogOptionsSuite extends QueryTest with SharedSparkSession with
       override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
         plan = qe.analyzed
       }
-      override def onFailure(funcName: String, qe: QueryExecution, error: Throwable): Unit = {}
+      override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {}
     }
 
     spark.listenerManager.register(listener)
@@ -282,7 +282,7 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSparkSession with
         plan = qe.analyzed
 
       }
-      override def onFailure(funcName: String, qe: QueryExecution, error: Throwable): Unit = {}
+      override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {}
     }
 
     spark.listenerManager.register(listener)
@@ -20,15 +20,17 @@ package org.apache.spark.sql.util
 import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark._
-import org.apache.spark.sql.{functions, AnalysisException, QueryTest, Row}
+import org.apache.spark.sql.{functions, AnalysisException, Dataset, QueryTest, Row, SparkSession}
 import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, InsertIntoStatement, LogicalPlan, Project}
-import org.apache.spark.sql.execution.{QueryExecution, WholeStageCodegenExec}
+import org.apache.spark.sql.execution.{QueryExecution, QueryExecutionException, WholeStageCodegenExec}
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.execution.command.RunnableCommand
 import org.apache.spark.sql.execution.datasources.{CreateTable, InsertIntoHadoopFsRelationCommand}
 import org.apache.spark.sql.execution.datasources.json.JsonFileFormat
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types.StringType
 
 class DataFrameCallbackSuite extends QueryTest
   with SharedSparkSession
@@ -40,7 +42,7 @@ class DataFrameCallbackSuite extends QueryTest
     val metrics = ArrayBuffer.empty[(String, QueryExecution, Long)]
     val listener = new QueryExecutionListener {
       // Only test successful case here, so no need to implement `onFailure`
-      override def onFailure(funcName: String, qe: QueryExecution, error: Throwable): Unit = {}
+      override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {}
 
       override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {
         metrics += ((funcName, qe, duration))
@@ -67,10 +69,10 @@
   }
 
   testQuietly("execute callback functions when a DataFrame action failed") {
-    val metrics = ArrayBuffer.empty[(String, QueryExecution, Throwable)]
+    val metrics = ArrayBuffer.empty[(String, QueryExecution, Exception)]
     val listener = new QueryExecutionListener {
-      override def onFailure(funcName: String, qe: QueryExecution, error: Throwable): Unit = {
-        metrics += ((funcName, qe, error))
+      override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {
+        metrics += ((funcName, qe, exception))
       }
 
       // Only test failed case here, so no need to implement `onSuccess`
@@ -96,7 +98,7 @@
     val metrics = ArrayBuffer.empty[Long]
     val listener = new QueryExecutionListener {
       // Only test successful case here, so no need to implement `onFailure`
-      override def onFailure(funcName: String, qe: QueryExecution, error: Throwable): Unit = {}
+      override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {}
 
       override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {
         val metric = stripAQEPlan(qe.executedPlan) match {
@@ -136,7 +138,7 @@
     val metrics = ArrayBuffer.empty[Long]
     val listener = new QueryExecutionListener {
       // Only test successful case here, so no need to implement `onFailure`
-      override def onFailure(funcName: String, qe: QueryExecution, error: Throwable): Unit = {}
+      override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {}
 
       override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {
         metrics += qe.executedPlan.longMetric("dataSize").value
@@ -176,10 +178,10 @@
 
   test("execute callback functions for DataFrameWriter") {
     val commands = ArrayBuffer.empty[(String, LogicalPlan)]
-    val errors = ArrayBuffer.empty[(String, Throwable)]
+    val exceptions = ArrayBuffer.empty[(String, Exception)]
     val listener = new QueryExecutionListener {
-      override def onFailure(funcName: String, qe: QueryExecution, error: Throwable): Unit = {
-        errors += funcName -> error
+      override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {
+        exceptions += funcName -> exception
       }
 
       override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {
@@ -225,9 +227,9 @@
         spark.range(10).select($"id", $"id").write.insertInto("tab")
       }
       sparkContext.listenerBus.waitUntilEmpty()
-      assert(errors.length == 1)
-      assert(errors.head._1 == "insertInto")
-      assert(errors.head._2 == e)
+      assert(exceptions.length == 1)
+      assert(exceptions.head._1 == "insertInto")
+      assert(exceptions.head._2 == e)
     }
   }
 
@@ -238,7 +240,7 @@
         metricMaps += qe.observedMetrics
       }
 
-      override def onFailure(funcName: String, qe: QueryExecution, exception: Throwable): Unit = {
+      override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {
         // No-op
       }
     }
@@ -278,4 +280,32 @@
       spark.listenerManager.unregister(listener)
     }
   }
+
+  testQuietly("SPARK-31144: QueryExecutionListener should receive `java.lang.Error`") {
+    var e: Exception = null
+    val listener = new QueryExecutionListener {
+      override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {
+        e = exception
+      }
+      override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {}
+    }
+    spark.listenerManager.register(listener)
+
+    intercept[Error] {
+      Dataset.ofRows(spark, ErrorTestCommand("foo")).collect()
+    }
+    sparkContext.listenerBus.waitUntilEmpty()
+    assert(e != null && e.isInstanceOf[QueryExecutionException]
+      && e.getCause.isInstanceOf[Error] && e.getCause.getMessage == "foo")
+    spark.listenerManager.unregister(listener)
+  }
 }
+
+/** A test command that throws `java.lang.Error` during execution. */
+case class ErrorTestCommand(foo: String) extends RunnableCommand {
+
+  override val output: Seq[Attribute] = Seq(AttributeReference("foo", StringType)())
+
+  override def run(sparkSession: SparkSession): Seq[Row] =
+    throw new java.lang.Error(foo)
+}
@@ -57,7 +57,7 @@ private class CountingQueryExecutionListener extends QueryExecutionListener {
     CALLBACK_COUNT.incrementAndGet()
   }
 
-  override def onFailure(funcName: String, qe: QueryExecution, error: Throwable): Unit = {
+  override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {
     CALLBACK_COUNT.incrementAndGet()
   }
 
@@ -29,7 +29,7 @@ import org.apache.spark.sql.util.QueryExecutionListener
 
 class DummyQueryExecutionListener extends QueryExecutionListener {
   override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {}
-  override def onFailure(funcName: String, qe: QueryExecution, error: Throwable): Unit = {}
+  override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {}
 }
 
 class DummyStreamingQueryListener extends StreamingQueryListener {
