Implement DataFrameQueryContext in Spark Connect
HyukjinKwon committed May 29, 2024
1 parent 47c55f4 commit 4c0d755
Showing 14 changed files with 422 additions and 202 deletions.
@@ -54,6 +54,8 @@ message Expression {
google.protobuf.Any extension = 999;
}

// (Optional) Keeps the origin information for this expression, such as the stacktrace.
Origin origin = 18;

// Expression for the OVER clause or WINDOW clause.
message Window {
@@ -405,3 +407,18 @@ message NamedArgumentExpression {
// (Required) The value expression of the named argument.
Expression value = 2;
}

message Origin {
// (Required) Indicate the origin type.
oneof function {
PythonOrigin python_origin = 1;
}
}

message PythonOrigin {
// (Required) Name of the origin, for example, the name of the function.
string fragment = 1;

// (Required) Callsite to show to end users, for example, stacktrace.
string call_site = 2;
}
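
For orientation, here is a minimal sketch (not part of this commit) of how a Spark Connect client could populate the new Origin message on an Expression it sends to the server. The generated module path `pyspark.sql.connect.proto.expressions_pb2` and the `unresolved_function` payload are assumptions for illustration.

```python
# Illustration only: attaching a Python origin to a Spark Connect Expression proto.
# Assumes the classes generated from expressions.proto are importable as below.
from pyspark.sql.connect.proto import expressions_pb2 as pb2

expr = pb2.Expression()
expr.unresolved_function.function_name = "divide"      # hypothetical expression payload
expr.origin.python_origin.fragment = "divide"          # name of the originating API call
expr.origin.python_origin.call_site = "example.py:10"  # user call site surfaced in errors

print(expr.origin.python_origin.call_site)  # example.py:10
```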
@@ -44,7 +44,7 @@ import org.apache.spark.internal.{Logging, MDC}
import org.apache.spark.internal.LogKeys.{DATAFRAME_ID, SESSION_ID}
import org.apache.spark.ml.{functions => MLFunctions}
import org.apache.spark.resource.{ExecutorResourceRequest, ResourceProfile, TaskResourceProfile, TaskResourceRequest}
import org.apache.spark.sql.{Column, Dataset, Encoders, ForeachWriter, Observation, RelationalGroupedDataset, SparkSession}
import org.apache.spark.sql.{withOrigin, Column, Dataset, Encoders, ForeachWriter, Observation, RelationalGroupedDataset, SparkSession}
import org.apache.spark.sql.avro.{AvroDataToCatalyst, CatalystDataToAvro}
import org.apache.spark.sql.catalyst.{expressions, AliasIdentifier, FunctionIdentifier, QueryPlanningTracker}
import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, MultiAlias, NameParameterizedQuery, PosParameterizedQuery, UnresolvedAlias, UnresolvedAttribute, UnresolvedDataFrameStar, UnresolvedDeserializer, UnresolvedExtractValue, UnresolvedFunction, UnresolvedRegex, UnresolvedRelation, UnresolvedStar}
@@ -57,6 +57,7 @@ import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, JoinType, L
import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.plans.logical.{AppendColumns, CoGroup, CollectMetrics, CommandResult, Deduplicate, DeduplicateWithinWatermark, DeserializeToObject, Except, FlatMapGroupsWithState, Intersect, JoinWith, LocalRelation, LogicalGroupState, LogicalPlan, MapGroups, MapPartitions, Project, Sample, SerializeFromObject, Sort, SubqueryAlias, TypedFilter, Union, Unpivot, UnresolvedHint}
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes
import org.apache.spark.sql.catalyst.trees.PySparkCurrentOrigin
import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CharVarcharUtils}
import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, ForeachWriterPacket, InvalidPlanInput, LiteralValueProtoConverter, StorageLevelProtoConverter, StreamingListenerPacket, UdfPacket}
@@ -1471,7 +1472,20 @@ class SparkConnectPlanner(
* Catalyst expression
*/
@DeveloperApi
def transformExpression(exp: proto.Expression): Expression = {
def transformExpression(exp: proto.Expression): Expression = if (exp.hasOrigin) {
try {
PySparkCurrentOrigin.set(
exp.getOrigin.getPythonOrigin.getFragment,
exp.getOrigin.getPythonOrigin.getCallSite)
withOrigin { doTransformExpression(exp) }
} finally {
PySparkCurrentOrigin.clear()
}
} else {
doTransformExpression(exp)
}

private def doTransformExpression(exp: proto.Expression): Expression = {
exp.getExprTypeCase match {
case proto.Expression.ExprTypeCase.LITERAL => transformLiteral(exp.getLiteral)
case proto.Expression.ExprTypeCase.UNRESOLVED_ATTRIBUTE =>
@@ -35,7 +35,7 @@ import org.apache.commons.lang3.exception.ExceptionUtils
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods

import org.apache.spark.{SparkEnv, SparkException, SparkThrowable}
import org.apache.spark.{QueryContextType, SparkEnv, SparkException, SparkThrowable}
import org.apache.spark.api.python.PythonException
import org.apache.spark.connect.proto.FetchErrorDetailsResponse
import org.apache.spark.internal.{Logging, MDC}
@@ -118,15 +118,27 @@ private[connect] object ErrorUtils extends Logging {
sparkThrowableBuilder.setErrorClass(sparkThrowable.getErrorClass)
}
for (queryCtx <- sparkThrowable.getQueryContext) {
sparkThrowableBuilder.addQueryContexts(
FetchErrorDetailsResponse.QueryContext
.newBuilder()
val builder = FetchErrorDetailsResponse.QueryContext
.newBuilder()
val context = if (queryCtx.contextType() == QueryContextType.SQL) {
builder
.setContextType(FetchErrorDetailsResponse.QueryContext.ContextType.SQL)
.setObjectType(queryCtx.objectType())
.setObjectName(queryCtx.objectName())
.setStartIndex(queryCtx.startIndex())
.setStopIndex(queryCtx.stopIndex())
.setFragment(queryCtx.fragment())
.build())
.setSummary(queryCtx.summary())
.build()
} else {
builder
.setContextType(FetchErrorDetailsResponse.QueryContext.ContextType.DATAFRAME)
.setFragment(queryCtx.fragment())
.setCallSite(queryCtx.callSite())
.setSummary(queryCtx.summary())
.build()
}
sparkThrowableBuilder.addQueryContexts(context)
}
if (sparkThrowable.getSqlState != null) {
sparkThrowableBuilder.setSqlState(sparkThrowable.getSqlState)
51 changes: 37 additions & 14 deletions python/pyspark/errors/exceptions/captured.py
@@ -166,7 +166,14 @@ def getQueryContext(self) -> List[BaseQueryContext]:
if self._origin is not None and is_instance_of(
gw, self._origin, "org.apache.spark.SparkThrowable"
):
return [QueryContext(q) for q in self._origin.getQueryContext()]
contexts: List[BaseQueryContext] = []
for q in self._origin.getQueryContext():
if q.contextType().toString() == "SQL":
contexts.append(SQLQueryContext(q))
else:
contexts.append(DataFrameQueryContext(q))

return contexts
else:
return []

@@ -379,17 +386,12 @@ class UnknownException(CapturedException, BaseUnknownException):
"""


class QueryContext(BaseQueryContext):
class SQLQueryContext(BaseQueryContext):
def __init__(self, q: "JavaObject"):
self._q = q

def contextType(self) -> QueryContextType:
context_type = self._q.contextType().toString()
assert context_type in ("SQL", "DataFrame")
if context_type == "DataFrame":
return QueryContextType.DataFrame
else:
return QueryContextType.SQL
return QueryContextType.SQL

def objectType(self) -> str:
return str(self._q.objectType())
@@ -409,13 +411,34 @@ def fragment(self) -> str:
def callSite(self) -> str:
return str(self._q.callSite())

def pysparkFragment(self) -> Optional[str]: # type: ignore[return]
if self.contextType() == QueryContextType.DataFrame:
return str(self._q.pysparkFragment())
def summary(self) -> str:
return str(self._q.summary())


class DataFrameQueryContext(BaseQueryContext):
def __init__(self, q: "JavaObject"):
self._q = q

def contextType(self) -> QueryContextType:
return QueryContextType.DataFrame

def objectType(self) -> str:
return str(self._q.objectType())

def objectName(self) -> str:
return str(self._q.objectName())

def pysparkCallSite(self) -> Optional[str]: # type: ignore[return]
if self.contextType() == QueryContextType.DataFrame:
return str(self._q.pysparkCallSite())
def startIndex(self) -> int:
return int(self._q.startIndex())

def stopIndex(self) -> int:
return int(self._q.stopIndex())

def fragment(self) -> str:
return str(self._q.fragment())

def callSite(self) -> str:
return str(self._q.callSite())

def summary(self) -> str:
return str(self._q.summary())
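
As a rough usage sketch (not from this diff), the split into SQLQueryContext and DataFrameQueryContext lets callers branch on the context type. The snippet below assumes an active SparkSession named `spark` and an ANSI-mode arithmetic failure that carries a DataFrame query context.

```python
# Hypothetical usage: inspecting both kinds of query context on a captured error.
from pyspark.errors.exceptions.base import QueryContextType
from pyspark.errors.exceptions.captured import CapturedException
from pyspark.sql import functions as sf

spark.conf.set("spark.sql.ansi.enabled", "true")
try:
    spark.range(1).select(sf.col("id") / sf.lit(0)).collect()
except CapturedException as e:
    for ctx in e.getQueryContext():
        if ctx.contextType() == QueryContextType.DataFrame:
            # DataFrame contexts expose the API fragment and the user call site.
            print(ctx.fragment(), ctx.callSite())
        else:
            # SQL contexts expose positions within the SQL text instead.
            print(ctx.fragment(), ctx.startIndex(), ctx.stopIndex())
        print(ctx.summary())
```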
83 changes: 75 additions & 8 deletions python/pyspark/errors/exceptions/connect.py
@@ -91,7 +91,10 @@ def convert_exception(
)
query_contexts = []
for query_context in resp.errors[resp.root_error_idx].spark_throwable.query_contexts:
query_contexts.append(QueryContext(query_context))
if query_context.context_type == pb2.FetchErrorDetailsResponse.QueryContext.SQL:
query_contexts.append(SQLQueryContext(query_context))
else:
query_contexts.append(DataFrameQueryContext(query_context))

if "org.apache.spark.sql.catalyst.parser.ParseException" in classes:
return ParseException(
@@ -430,17 +433,12 @@ class SparkNoSuchElementException(SparkConnectGrpcException, BaseNoSuchElementEx
"""


class QueryContext(BaseQueryContext):
class SQLQueryContext(BaseQueryContext):
def __init__(self, q: pb2.FetchErrorDetailsResponse.QueryContext):
self._q = q

def contextType(self) -> QueryContextType:
context_type = self._q.context_type

if int(context_type) == QueryContextType.DataFrame.value:
return QueryContextType.DataFrame
else:
return QueryContextType.SQL
return QueryContextType.SQL

def objectType(self) -> str:
return str(self._q.object_type)
@@ -457,6 +455,75 @@ def stopIndex(self) -> int:
def fragment(self) -> str:
return str(self._q.fragment)

def callSite(self) -> str:
raise UnsupportedOperationException(
"",
error_class="UNSUPPORTED_CALL.WITHOUT_SUGGESTION",
message_parameters={"className": "SQLQueryContext", "methodName": "callSite"},
sql_state="0A000",
server_stacktrace=None,
display_server_stacktrace=False,
query_contexts=[],
)

def summary(self) -> str:
return str(self._q.summary)


class DataFrameQueryContext(BaseQueryContext):
def __init__(self, q: pb2.FetchErrorDetailsResponse.QueryContext):
self._q = q

def contextType(self) -> QueryContextType:
return QueryContextType.DataFrame

def objectType(self) -> str:
raise UnsupportedOperationException(
"",
error_class="UNSUPPORTED_CALL.WITHOUT_SUGGESTION",
message_parameters={"className": "DataFrameQueryContext", "methodName": "objectType"},
sql_state="0A000",
server_stacktrace=None,
display_server_stacktrace=False,
query_contexts=[],
)

def objectName(self) -> str:
raise UnsupportedOperationException(
"",
error_class="UNSUPPORTED_CALL.WITHOUT_SUGGESTION",
message_parameters={"className": "DataFrameQueryContext", "methodName": "objectName"},
sql_state="0A000",
server_stacktrace=None,
display_server_stacktrace=False,
query_contexts=[],
)

def startIndex(self) -> int:
raise UnsupportedOperationException(
"",
error_class="UNSUPPORTED_CALL.WITHOUT_SUGGESTION",
message_parameters={"className": "DataFrameQueryContext", "methodName": "startIndex"},
sql_state="0A000",
server_stacktrace=None,
display_server_stacktrace=False,
query_contexts=[],
)

def stopIndex(self) -> int:
raise UnsupportedOperationException(
"",
error_class="UNSUPPORTED_CALL.WITHOUT_SUGGESTION",
message_parameters={"className": "DataFrameQueryContext", "methodName": "stopIndex"},
sql_state="0A000",
server_stacktrace=None,
display_server_stacktrace=False,
query_contexts=[],
)

def fragment(self) -> str:
return str(self._q.fragment)

def callSite(self) -> str:
return str(self._q.call_site)

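
And a minimal sketch (again not part of the commit) of how a DATAFRAME-typed QueryContext from FetchErrorDetailsResponse maps onto the new wrapper class on the Spark Connect client; the field values are made up and the proto import path is an assumption.

```python
# Illustration only: wrapping a server-provided DATAFRAME query context.
from pyspark.sql.connect import proto as pb2
from pyspark.errors.exceptions.connect import DataFrameQueryContext

ctx_proto = pb2.FetchErrorDetailsResponse.QueryContext(
    context_type=pb2.FetchErrorDetailsResponse.QueryContext.DATAFRAME,
    fragment="divide",                 # made-up example values
    call_site="example.py:10",
    summary='"divide" was called from example.py:10',
)
ctx = DataFrameQueryContext(ctx_proto)
assert ctx.callSite() == "example.py:10"
# SQL-only accessors such as startIndex() raise UNSUPPORTED_CALL for DataFrame contexts.
```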