[SPARK-45022][SQL] Provide context for dataset API errors
peter-toth committed Sep 14, 2023
1 parent 8e3e600 commit 7adf30e
Showing 54 changed files with 748 additions and 296 deletions.
12 changes: 12 additions & 0 deletions common/utils/src/main/java/org/apache/spark/QueryContext.java
@@ -27,6 +27,9 @@
*/
@Evolving
public interface QueryContext {
// The type of this query context.
QueryContextType contextType();

// The object type of the query which throws the exception.
// If the exception is directly from the main query, it should be an empty string.
// Otherwise, it should be the exact object type in upper case. For example, a "VIEW".
@@ -45,4 +48,13 @@ public interface QueryContext {

// The corresponding fragment of the query which throws the exception.
String fragment();

// The Spark code (API) that caused throwing the exception.
String code();

// The user code (call site of the API) that caused throwing the exception.
String callSite();

// Summary of the exception cause.
String summary();
}
31 changes: 31 additions & 0 deletions common/utils/src/main/java/org/apache/spark/QueryContextType.java
@@ -0,0 +1,31 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark;

import org.apache.spark.annotation.Evolving;

/**
* The type of {@link QueryContext}.
*
* @since 3.5.0
*/
@Evolving
public enum QueryContextType {
SQL,
Dataset
}
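
Together, the new contextType() accessor and the QueryContextType enum let callers dispatch on the kind of context before touching the type-specific accessors, since each implementation further down only supports its own subset (the rest throw UnsupportedOperationException). A minimal, hypothetical Scala consumer, not part of this commit:

import org.apache.spark.{QueryContext, QueryContextType}

// Render a one-line description of a context; dispatch on contextType()
// so only the accessors supported by the concrete type are called.
def describe(c: QueryContext): String = c.contextType() match {
  case QueryContextType.SQL =>
    s"SQL fragment '${c.fragment()}' in ${c.objectType()} ${c.objectName()}"
  case QueryContextType.Dataset =>
    s"Dataset API '${c.code()}' called from ${c.callSite()}"
}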
@@ -104,13 +104,19 @@ private[spark] object SparkThrowableHelper {
g.writeArrayFieldStart("queryContext")
e.getQueryContext.foreach { c =>
g.writeStartObject()
g.writeStringField("objectType", c.objectType())
g.writeStringField("objectName", c.objectName())
val startIndex = c.startIndex() + 1
if (startIndex > 0) g.writeNumberField("startIndex", startIndex)
val stopIndex = c.stopIndex() + 1
if (stopIndex > 0) g.writeNumberField("stopIndex", stopIndex)
g.writeStringField("fragment", c.fragment())
c.contextType() match {
case QueryContextType.SQL =>
g.writeStringField("objectType", c.objectType())
g.writeStringField("objectName", c.objectName())
val startIndex = c.startIndex() + 1
if (startIndex > 0) g.writeNumberField("startIndex", startIndex)
val stopIndex = c.stopIndex() + 1
if (stopIndex > 0) g.writeNumberField("stopIndex", stopIndex)
g.writeStringField("fragment", c.fragment())
case QueryContextType.Dataset =>
g.writeStringField("code", c.code())
g.writeStringField("callSite", c.callSite())
}
g.writeEndObject()
}
g.writeEndArray()
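The SQL branch keeps the existing index convention: internal positions are 0-based with -1 meaning "unset", so the writer shifts them to the 1-based positions users see and omits fields that are unset. A small hypothetical illustration:

// Internal indices are 0-based; -1 means the position is unknown.
val knownStart = 2   // (2) + 1 == 3, written as "startIndex" : 3
val unsetStart = -1  // (-1) + 1 == 0 fails the `> 0` guard, field omitted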
61 changes: 44 additions & 17 deletions core/src/test/scala/org/apache/spark/SparkFunSuite.scala
@@ -342,7 +342,7 @@ abstract class SparkFunSuite
sqlState: Option[String] = None,
parameters: Map[String, String] = Map.empty,
matchPVals: Boolean = false,
queryContext: Array[QueryContext] = Array.empty): Unit = {
queryContext: Array[ExpectedContext] = Array.empty): Unit = {
assert(exception.getErrorClass === errorClass)
sqlState.foreach(state => assert(exception.getSqlState === state))
val expectedParameters = exception.getMessageParameters.asScala
@@ -364,16 +364,25 @@
val actualQueryContext = exception.getQueryContext()
assert(actualQueryContext.length === queryContext.length, "Invalid length of the query context")
actualQueryContext.zip(queryContext).foreach { case (actual, expected) =>
assert(actual.objectType() === expected.objectType(),
"Invalid objectType of a query context Actual:" + actual.toString)
assert(actual.objectName() === expected.objectName(),
"Invalid objectName of a query context. Actual:" + actual.toString)
assert(actual.startIndex() === expected.startIndex(),
"Invalid startIndex of a query context. Actual:" + actual.toString)
assert(actual.stopIndex() === expected.stopIndex(),
"Invalid stopIndex of a query context. Actual:" + actual.toString)
assert(actual.fragment() === expected.fragment(),
"Invalid fragment of a query context. Actual:" + actual.toString)
assert(actual.contextType() === expected.contextType,
"Invalid contextType of a query context Actual:" + actual.toString)
if (actual.contextType() == QueryContextType.SQL) {
assert(actual.objectType() === expected.objectType,
"Invalid objectType of a query context Actual:" + actual.toString)
assert(actual.objectName() === expected.objectName,
"Invalid objectName of a query context. Actual:" + actual.toString)
assert(actual.startIndex() === expected.startIndex,
"Invalid startIndex of a query context. Actual:" + actual.toString)
assert(actual.stopIndex() === expected.stopIndex,
"Invalid stopIndex of a query context. Actual:" + actual.toString)
assert(actual.fragment() === expected.fragment,
"Invalid fragment of a query context. Actual:" + actual.toString)
} else if (actual.contextType() == QueryContextType.Dataset) {
assert(actual.code() === expected.code,
"Invalid code of a query context. Actual:" + actual.toString)
assert(actual.callSite().matches(expected.callSitePattern),
"Invalid callSite of a query context. Actual:" + actual.toString)
}
}
}

@@ -389,29 +398,29 @@
errorClass: String,
sqlState: String,
parameters: Map[String, String],
context: QueryContext): Unit =
context: ExpectedContext): Unit =
checkError(exception, errorClass, Some(sqlState), parameters, false, Array(context))

protected def checkError(
exception: SparkThrowable,
errorClass: String,
parameters: Map[String, String],
context: QueryContext): Unit =
context: ExpectedContext): Unit =
checkError(exception, errorClass, None, parameters, false, Array(context))

protected def checkError(
exception: SparkThrowable,
errorClass: String,
sqlState: String,
context: QueryContext): Unit =
context: ExpectedContext): Unit =
checkError(exception, errorClass, None, Map.empty, false, Array(context))

protected def checkError(
exception: SparkThrowable,
errorClass: String,
sqlState: Option[String],
parameters: Map[String, String],
context: QueryContext): Unit =
context: ExpectedContext): Unit =
checkError(exception, errorClass, sqlState, parameters,
false, Array(context))

@@ -426,7 +435,7 @@
errorClass: String,
sqlState: Option[String],
parameters: Map[String, String],
context: QueryContext): Unit =
context: ExpectedContext): Unit =
checkError(exception, errorClass, sqlState, parameters,
matchPVals = true, Array(context))

@@ -453,16 +462,34 @@
parameters = Map("relationName" -> tableName))

case class ExpectedContext(
contextType: QueryContextType,
objectType: String,
objectName: String,
startIndex: Int,
stopIndex: Int,
fragment: String) extends QueryContext
fragment: String,
code: String,
callSitePattern: String
)

object ExpectedContext {
def apply(fragment: String, start: Int, stop: Int): ExpectedContext = {
ExpectedContext("", "", start, stop, fragment)
}

def apply(
objectType: String,
objectName: String,
startIndex: Int,
stopIndex: Int,
fragment: String): ExpectedContext = {
new ExpectedContext(QueryContextType.SQL, objectType, objectName, startIndex, stopIndex,
fragment, "", "")
}

def apply(code: String, callSitePattern: String): ExpectedContext = {
new ExpectedContext(QueryContextType.Dataset, "", "", -1, -1, "", code, callSitePattern)
}
}
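
With the new Dataset-flavored apply, a test can assert on a Dataset context without spelling out the SQL-only fields, and the call site is matched as a regular expression via callSitePattern. A hypothetical usage, assuming e is a SparkThrowable raised by a Dataset div call:

checkError(
  exception = e,
  errorClass = "DIVIDE_BY_ZERO",
  parameters = Map("config" -> "\"spark.sql.ansi.enabled\""),  // illustrative value
  context = ExpectedContext(
    code = "div",
    // A regex, not a literal: line numbers are allowed to drift.
    callSitePattern = """SimpleApp\$\.main\(SimpleApp\.scala:\d+\)"""))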

class LogAppender(msg: String = "", maxEvents: Int = 1000)
53 changes: 53 additions & 0 deletions core/src/test/scala/org/apache/spark/SparkThrowableSuite.scala
@@ -460,11 +460,15 @@
test("Get message in the specified format") {
import ErrorMessageFormat._
class TestQueryContext extends QueryContext {
override val contextType = QueryContextType.SQL
override val objectName = "v1"
override val objectType = "VIEW"
override val startIndex = 2
override val stopIndex = -1
override val fragment = "1 / 0"
override def code: String = throw new UnsupportedOperationException
override def callSite: String = throw new UnsupportedOperationException
override val summary = ""
}
val e = new SparkArithmeticException(
errorClass = "DIVIDE_BY_ZERO",
@@ -532,6 +536,55 @@
| "message" : "Test message"
| }
|}""".stripMargin)

class TestQueryContext2 extends QueryContext {
override val contextType = QueryContextType.Dataset
override def objectName: String = throw new UnsupportedOperationException
override def objectType: String = throw new UnsupportedOperationException
override def startIndex: Int = throw new UnsupportedOperationException
override def stopIndex: Int = throw new UnsupportedOperationException
override def fragment: String = throw new UnsupportedOperationException
override val code: String = "div"
override val callSite: String = "SimpleApp$.main(SimpleApp.scala:9)"
override val summary = ""
}
val e4 = new SparkArithmeticException(
errorClass = "DIVIDE_BY_ZERO",
messageParameters = Map("config" -> "CONFIG"),
context = Array(new TestQueryContext2),
summary = "Query summary")

assert(SparkThrowableHelper.getMessage(e4, PRETTY) ===
"[DIVIDE_BY_ZERO] Division by zero. Use `try_divide` to tolerate divisor being 0 " +
"and return NULL instead. If necessary set CONFIG to \"false\" to bypass this error." +
"\nQuery summary")
// scalastyle:off line.size.limit
assert(SparkThrowableHelper.getMessage(e4, MINIMAL) ===
"""{
| "errorClass" : "DIVIDE_BY_ZERO",
| "sqlState" : "22012",
| "messageParameters" : {
| "config" : "CONFIG"
| },
| "queryContext" : [ {
| "code" : "div",
| "callSite" : "SimpleApp$.main(SimpleApp.scala:9)"
| } ]
|}""".stripMargin)
assert(SparkThrowableHelper.getMessage(e4, STANDARD) ===
"""{
| "errorClass" : "DIVIDE_BY_ZERO",
| "messageTemplate" : "Division by zero. Use `try_divide` to tolerate divisor being 0 and return NULL instead. If necessary set <config> to \"false\" to bypass this error.",
| "sqlState" : "22012",
| "messageParameters" : {
| "config" : "CONFIG"
| },
| "queryContext" : [ {
| "code" : "div",
| "callSite" : "SimpleApp$.main(SimpleApp.scala:9)"
| } ]
|}""".stripMargin)
// scalastyle:on line.size.limit
}

test("overwrite error classes") {
@@ -26,7 +26,7 @@ import org.antlr.v4.runtime.tree.TerminalNodeImpl
import org.apache.spark.{QueryContext, SparkThrowable, SparkThrowableHelper}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin, WithOrigin}
import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin, SQLQueryContext, WithOrigin}
import org.apache.spark.sql.catalyst.util.SparkParserUtils
import org.apache.spark.sql.errors.QueryParsingErrors
import org.apache.spark.sql.internal.SqlApiConf
@@ -229,7 +229,7 @@ class ParseException(
val builder = new StringBuilder
builder ++= "\n" ++= message
start match {
case Origin(Some(l), Some(p), _, _, _, _, _) =>
case Origin(Some(l), Some(p), _, _, _, _, _, _) =>
builder ++= s"(line $l, pos $p)\n"
command.foreach { cmd =>
val (above, below) = cmd.split("\n").splitAt(l)
@@ -262,8 +262,7 @@

object ParseException {
def getQueryContext(): Array[QueryContext] = {
val context = CurrentOrigin.get.context
if (context.isValid) Array(context) else Array.empty
Some(CurrentOrigin.get.context).collect { case b: SQLQueryContext if b.isValid => b }.toArray
}
}
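
The rewritten getQueryContext keeps the result empty both when the current context is invalid and when it is not a SQLQueryContext at all; a DatasetQueryContext can never describe a parse error. The Option/collect combination acts as a type filter and validity guard in one step; a hypothetical illustration:

def pick(ctx: QueryContext): Array[QueryContext] =
  Some(ctx).collect { case c: SQLQueryContext if c.isValid => c }.toArray

// pick(validSqlContext)   == Array(validSqlContext)
// pick(invalidSqlContext) == Array.empty  (guard fails)
// pick(datasetContext)    == Array.empty  (type mismatch)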

@@ -17,7 +17,7 @@

package org.apache.spark.sql.catalyst.trees

import org.apache.spark.QueryContext
import org.apache.spark.{QueryContext, QueryContextType}

/** The class represents error context of a SQL query. */
case class SQLQueryContext(
@@ -28,19 +28,20 @@ case class SQLQueryContext(
sqlText: Option[String],
originObjectType: Option[String],
originObjectName: Option[String]) extends QueryContext {
override val contextType = QueryContextType.SQL

override val objectType = originObjectType.getOrElse("")
override val objectName = originObjectName.getOrElse("")
override val startIndex = originStartIndex.getOrElse(-1)
override val stopIndex = originStopIndex.getOrElse(-1)
val objectType = originObjectType.getOrElse("")
val objectName = originObjectName.getOrElse("")
val startIndex = originStartIndex.getOrElse(-1)
val stopIndex = originStopIndex.getOrElse(-1)

/**
* The SQL query context of current node. For example:
* == SQL of VIEW v1(line 1, position 25) ==
* SELECT '' AS five, i.f1, i.f1 - int('2') AS x FROM INT4_TBL i
* ^^^^^^^^^^^^^^^
*/
lazy val summary: String = {
override lazy val summary: String = {
// If the query context is missing or incorrect, simply return an empty string.
if (!isValid) {
""
@@ -116,7 +117,7 @@ case class SQLQueryContext(
}

/** Gets the textual fragment of a SQL query. */
override lazy val fragment: String = {
lazy val fragment: String = {
if (!isValid) {
""
} else {
@@ -128,6 +129,47 @@
sqlText.isDefined && originStartIndex.isDefined && originStopIndex.isDefined &&
originStartIndex.get >= 0 && originStopIndex.get < sqlText.get.length &&
originStartIndex.get <= originStopIndex.get
}

override def code: String = throw new UnsupportedOperationException
override def callSite: String = throw new UnsupportedOperationException
}

case class DatasetQueryContext(
override val code: String,
override val callSite: String) extends QueryContext {
override val contextType = QueryContextType.Dataset

override def objectType: String = throw new UnsupportedOperationException
override def objectName: String = throw new UnsupportedOperationException
override def startIndex: Int = throw new UnsupportedOperationException
override def stopIndex: Int = throw new UnsupportedOperationException
override def fragment: String = throw new UnsupportedOperationException

override lazy val summary: String = {
val builder = new StringBuilder
builder ++= "== Dataset ==\n"
builder ++= "\""

builder ++= code
builder ++= "\""
builder ++= " was called from "
builder ++= callSite
builder += '\n'
builder.result()
}
}
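
For the values exercised in TestQueryContext2 above, this builder produces:

val summary = DatasetQueryContext("div", "SimpleApp$.main(SimpleApp.scala:9)").summary
// == Dataset ==
// "div" was called from SimpleApp$.main(SimpleApp.scala:9)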

object DatasetQueryContext {
def apply(elements: Array[StackTraceElement]): DatasetQueryContext = {
val methodName = elements(0).getMethodName
val code = if (methodName.length > 1 && methodName(0) == '$') {
methodName.substring(1)
} else {
methodName
}
val callSite = elements(1).toString

DatasetQueryContext(code, callSite)
}
}
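
This apply assumes a pre-trimmed stack: elements(0) is the Dataset API method itself, whose Scala-generated name may carry a leading '$' that gets stripped, and elements(1) is the user's call site. A hypothetical trace showing the mapping:

val elements = Array(
  new StackTraceElement("org.apache.spark.sql.Column", "$div", "Column.scala", 123),
  new StackTraceElement("SimpleApp$", "main", "SimpleApp.scala", 9))

val ctx = DatasetQueryContext(elements)
// ctx.code     == "div"  (leading '$' stripped)
// ctx.callSite == "SimpleApp$.main(SimpleApp.scala:9)"  (StackTraceElement.toString)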