diff --git a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto index 726ae5dd1c219..865a1742b7394 100644 --- a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto +++ b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto @@ -54,6 +54,8 @@ message Expression { google.protobuf.Any extension = 999; } + // (Optional) Keep the information of the origin for this expression such as stacktrace. + Origin origin = 18; // Expression for the OVER clause or WINDOW clause. message Window { @@ -405,3 +407,18 @@ message NamedArgumentExpression { // (Required) The value expression of the named argument. Expression value = 2; } + +message Origin { + // (Required) Indicate the origin type. + oneof function { + PythonOrigin python_origin = 1; + } +} + +message PythonOrigin { + // (Required) Name of the origin, for example, the name of the function + string fragment = 1; + + // (Required) Callsite to show to end users, for example, stacktrace. + string call_site = 2; +} diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index a339469e61cdf..8614c5cf0e539 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -44,7 +44,7 @@ import org.apache.spark.internal.{Logging, MDC} import org.apache.spark.internal.LogKeys.{DATAFRAME_ID, SESSION_ID} import org.apache.spark.ml.{functions => MLFunctions} import org.apache.spark.resource.{ExecutorResourceRequest, ResourceProfile, TaskResourceProfile, TaskResourceRequest} -import org.apache.spark.sql.{Column, Dataset, Encoders, ForeachWriter, Observation, RelationalGroupedDataset, SparkSession} +import org.apache.spark.sql.{withOrigin, Column, Dataset, Encoders, ForeachWriter, Observation, RelationalGroupedDataset, SparkSession} import org.apache.spark.sql.avro.{AvroDataToCatalyst, CatalystDataToAvro} import org.apache.spark.sql.catalyst.{expressions, AliasIdentifier, FunctionIdentifier, QueryPlanningTracker} import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, MultiAlias, NameParameterizedQuery, PosParameterizedQuery, UnresolvedAlias, UnresolvedAttribute, UnresolvedDataFrameStar, UnresolvedDeserializer, UnresolvedExtractValue, UnresolvedFunction, UnresolvedRegex, UnresolvedRelation, UnresolvedStar} @@ -57,6 +57,7 @@ import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, JoinType, L import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.{AppendColumns, CoGroup, CollectMetrics, CommandResult, Deduplicate, DeduplicateWithinWatermark, DeserializeToObject, Except, FlatMapGroupsWithState, Intersect, JoinWith, LocalRelation, LogicalGroupState, LogicalPlan, MapGroups, MapPartitions, Project, Sample, SerializeFromObject, Sort, SubqueryAlias, TypedFilter, Union, Unpivot, UnresolvedHint} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes +import org.apache.spark.sql.catalyst.trees.PySparkCurrentOrigin import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CharVarcharUtils} import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, 
ForeachWriterPacket, InvalidPlanInput, LiteralValueProtoConverter, StorageLevelProtoConverter, StreamingListenerPacket, UdfPacket} @@ -1471,7 +1472,20 @@ class SparkConnectPlanner( * Catalyst expression */ @DeveloperApi - def transformExpression(exp: proto.Expression): Expression = { + def transformExpression(exp: proto.Expression): Expression = if (exp.hasOrigin) { + try { + PySparkCurrentOrigin.set( + exp.getOrigin.getPythonOrigin.getFragment, + exp.getOrigin.getPythonOrigin.getCallSite) + withOrigin { doTransformExpression(exp) } + } finally { + PySparkCurrentOrigin.clear() + } + } else { + doTransformExpression(exp) + } + + private def doTransformExpression(exp: proto.Expression): Expression = { exp.getExprTypeCase match { case proto.Expression.ExprTypeCase.LITERAL => transformLiteral(exp.getLiteral) case proto.Expression.ExprTypeCase.UNRESOLVED_ATTRIBUTE => diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala index 773a97e929736..355048cf30363 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala @@ -35,7 +35,7 @@ import org.apache.commons.lang3.exception.ExceptionUtils import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods -import org.apache.spark.{SparkEnv, SparkException, SparkThrowable} +import org.apache.spark.{QueryContextType, SparkEnv, SparkException, SparkThrowable} import org.apache.spark.api.python.PythonException import org.apache.spark.connect.proto.FetchErrorDetailsResponse import org.apache.spark.internal.{Logging, MDC} @@ -118,15 +118,27 @@ private[connect] object ErrorUtils extends Logging { sparkThrowableBuilder.setErrorClass(sparkThrowable.getErrorClass) } for (queryCtx <- sparkThrowable.getQueryContext) { - sparkThrowableBuilder.addQueryContexts( - FetchErrorDetailsResponse.QueryContext - .newBuilder() + val builder = FetchErrorDetailsResponse.QueryContext + .newBuilder() + val context = if (queryCtx.contextType() == QueryContextType.SQL) { + builder + .setContextType(FetchErrorDetailsResponse.QueryContext.ContextType.SQL) .setObjectType(queryCtx.objectType()) .setObjectName(queryCtx.objectName()) .setStartIndex(queryCtx.startIndex()) .setStopIndex(queryCtx.stopIndex()) .setFragment(queryCtx.fragment()) - .build()) + .setSummary(queryCtx.summary()) + .build() + } else { + builder + .setContextType(FetchErrorDetailsResponse.QueryContext.ContextType.DATAFRAME) + .setFragment(queryCtx.fragment()) + .setCallSite(queryCtx.callSite()) + .setSummary(queryCtx.summary()) + .build() + } + sparkThrowableBuilder.addQueryContexts(context) } if (sparkThrowable.getSqlState != null) { sparkThrowableBuilder.setSqlState(sparkThrowable.getSqlState) diff --git a/python/pyspark/errors/exceptions/captured.py b/python/pyspark/errors/exceptions/captured.py index 2a30eba3fb22f..b5bb742161c06 100644 --- a/python/pyspark/errors/exceptions/captured.py +++ b/python/pyspark/errors/exceptions/captured.py @@ -166,7 +166,14 @@ def getQueryContext(self) -> List[BaseQueryContext]: if self._origin is not None and is_instance_of( gw, self._origin, "org.apache.spark.SparkThrowable" ): - return [QueryContext(q) for q in self._origin.getQueryContext()] + contexts: List[BaseQueryContext] = [] + for q in self._origin.getQueryContext(): + if q.contextType().toString() == "SQL": + 
contexts.append(SQLQueryContext(q)) + else: + contexts.append(DataFrameQueryContext(q)) + + return contexts else: return [] @@ -379,17 +386,12 @@ class UnknownException(CapturedException, BaseUnknownException): """ -class QueryContext(BaseQueryContext): +class SQLQueryContext(BaseQueryContext): def __init__(self, q: "JavaObject"): self._q = q def contextType(self) -> QueryContextType: - context_type = self._q.contextType().toString() - assert context_type in ("SQL", "DataFrame") - if context_type == "DataFrame": - return QueryContextType.DataFrame - else: - return QueryContextType.SQL + return QueryContextType.SQL def objectType(self) -> str: return str(self._q.objectType()) @@ -409,13 +411,34 @@ def fragment(self) -> str: def callSite(self) -> str: return str(self._q.callSite()) - def pysparkFragment(self) -> Optional[str]: # type: ignore[return] - if self.contextType() == QueryContextType.DataFrame: - return str(self._q.pysparkFragment()) + def summary(self) -> str: + return str(self._q.summary()) + + +class DataFrameQueryContext(BaseQueryContext): + def __init__(self, q: "JavaObject"): + self._q = q + + def contextType(self) -> QueryContextType: + return QueryContextType.DataFrame + + def objectType(self) -> str: + return str(self._q.objectType()) + + def objectName(self) -> str: + return str(self._q.objectName()) - def pysparkCallSite(self) -> Optional[str]: # type: ignore[return] - if self.contextType() == QueryContextType.DataFrame: - return str(self._q.pysparkCallSite()) + def startIndex(self) -> int: + return int(self._q.startIndex()) + + def stopIndex(self) -> int: + return int(self._q.stopIndex()) + + def fragment(self) -> str: + return str(self._q.fragment()) + + def callSite(self) -> str: + return str(self._q.callSite()) def summary(self) -> str: return str(self._q.summary()) diff --git a/python/pyspark/errors/exceptions/connect.py b/python/pyspark/errors/exceptions/connect.py index 0cffe72687539..8a95358f26975 100644 --- a/python/pyspark/errors/exceptions/connect.py +++ b/python/pyspark/errors/exceptions/connect.py @@ -91,7 +91,10 @@ def convert_exception( ) query_contexts = [] for query_context in resp.errors[resp.root_error_idx].spark_throwable.query_contexts: - query_contexts.append(QueryContext(query_context)) + if query_context.context_type == pb2.FetchErrorDetailsResponse.QueryContext.SQL: + query_contexts.append(SQLQueryContext(query_context)) + else: + query_contexts.append(DataFrameQueryContext(query_context)) if "org.apache.spark.sql.catalyst.parser.ParseException" in classes: return ParseException( @@ -430,17 +433,12 @@ class SparkNoSuchElementException(SparkConnectGrpcException, BaseNoSuchElementEx """ -class QueryContext(BaseQueryContext): +class SQLQueryContext(BaseQueryContext): def __init__(self, q: pb2.FetchErrorDetailsResponse.QueryContext): self._q = q def contextType(self) -> QueryContextType: - context_type = self._q.context_type - - if int(context_type) == QueryContextType.DataFrame.value: - return QueryContextType.DataFrame - else: - return QueryContextType.SQL + return QueryContextType.SQL def objectType(self) -> str: return str(self._q.object_type) @@ -457,6 +455,75 @@ def stopIndex(self) -> int: def fragment(self) -> str: return str(self._q.fragment) + def callSite(self) -> str: + raise UnsupportedOperationException( + "", + error_class="UNSUPPORTED_CALL.WITHOUT_SUGGESTION", + message_parameters={"className": "SQLQueryContext", "methodName": "callSite"}, + sql_state="0A000", + server_stacktrace=None, + display_server_stacktrace=False, + 
query_contexts=[], + ) + + def summary(self) -> str: + return str(self._q.summary) + + +class DataFrameQueryContext(BaseQueryContext): + def __init__(self, q: pb2.FetchErrorDetailsResponse.QueryContext): + self._q = q + + def contextType(self) -> QueryContextType: + return QueryContextType.DataFrame + + def objectType(self) -> str: + raise UnsupportedOperationException( + "", + error_class="UNSUPPORTED_CALL.WITHOUT_SUGGESTION", + message_parameters={"className": "DataFrameQueryContext", "methodName": "objectType"}, + sql_state="0A000", + server_stacktrace=None, + display_server_stacktrace=False, + query_contexts=[], + ) + + def objectName(self) -> str: + raise UnsupportedOperationException( + "", + error_class="UNSUPPORTED_CALL.WITHOUT_SUGGESTION", + message_parameters={"className": "DataFrameQueryContext", "methodName": "objectName"}, + sql_state="0A000", + server_stacktrace=None, + display_server_stacktrace=False, + query_contexts=[], + ) + + def startIndex(self) -> int: + raise UnsupportedOperationException( + "", + error_class="UNSUPPORTED_CALL.WITHOUT_SUGGESTION", + message_parameters={"className": "DataFrameQueryContext", "methodName": "startIndex"}, + sql_state="0A000", + server_stacktrace=None, + display_server_stacktrace=False, + query_contexts=[], + ) + + def stopIndex(self) -> int: + raise UnsupportedOperationException( + "", + error_class="UNSUPPORTED_CALL.WITHOUT_SUGGESTION", + message_parameters={"className": "DataFrameQueryContext", "methodName": "stopIndex"}, + sql_state="0A000", + server_stacktrace=None, + display_server_stacktrace=False, + query_contexts=[], + ) + + def fragment(self) -> str: + return str(self._q.fragment) + def callSite(self) -> str: return str(self._q.call_site) diff --git a/python/pyspark/errors/utils.py b/python/pyspark/errors/utils.py index cddec3319964e..e268ade756d3e 100644 --- a/python/pyspark/errors/utils.py +++ b/python/pyspark/errors/utils.py @@ -19,16 +19,34 @@ import functools import inspect import os -from typing import Any, Callable, Dict, Match, TypeVar, Type, TYPE_CHECKING +import threading +from typing import Any, Callable, Dict, Match, TypeVar, Type, Optional, TYPE_CHECKING from pyspark.errors.error_classes import ERROR_CLASSES_MAP - if TYPE_CHECKING: from pyspark.sql import SparkSession - from py4j.java_gateway import JavaClass T = TypeVar("T") +_current_origin = threading.local() + + +def current_origin() -> Optional[threading.local]: + global _current_origin + + if not hasattr(_current_origin, "fragment"): + _current_origin.fragment = None + if not hasattr(_current_origin, "call_site"): + _current_origin.call_site = None + return _current_origin + + +def set_current_origin(fragment: Optional[str], call_site: Optional[str]) -> None: + global _current_origin + + _current_origin.fragment = fragment + _current_origin.call_site = call_site + class ErrorClassesReader: """ @@ -130,9 +148,7 @@ def get_message_template(self, error_class: str) -> str: return message_template -def _capture_call_site( - spark_session: "SparkSession", pyspark_origin: "JavaClass", fragment: str -) -> None: +def _capture_call_site(spark_session: "SparkSession", depth: int) -> str: """ Capture the call site information including file name, line number, and function name. This function updates the thread-local storage from JVM side (PySparkCurrentOrigin) Parameters ---------- spark_session : SparkSession Current active Spark session. - pyspark_origin : py4j.JavaClass - PySparkCurrentOrigin from current active Spark session.
- fragment : str - The name of the PySpark API function being captured. Notes ----- @@ -153,14 +165,11 @@ def _capture_call_site( in the user code that led to the error. """ stack = list(reversed(inspect.stack())) - depth = int( - spark_session.conf.get("spark.sql.stackTracesInDataFrameContext") # type: ignore[arg-type] - ) selected_frames = stack[:depth] call_sites = [f"{frame.filename}:{frame.lineno}" for frame in selected_frames] call_sites_str = "\n".join(call_sites) - pyspark_origin.set(fragment, call_sites_str) + return call_sites_str def _with_origin(func: Callable[..., Any]) -> Callable[..., Any]: @@ -172,19 +181,38 @@ def _with_origin(func: Callable[..., Any]) -> Callable[..., Any]: @functools.wraps(func) def wrapper(*args: Any, **kwargs: Any) -> Any: from pyspark.sql import SparkSession + from pyspark.sql.utils import is_remote spark = SparkSession.getActiveSession() if spark is not None and hasattr(func, "__name__"): - assert spark._jvm is not None - pyspark_origin = spark._jvm.org.apache.spark.sql.catalyst.trees.PySparkCurrentOrigin + if is_remote(): + global current_origin - # Update call site when the function is called - _capture_call_site(spark, pyspark_origin, func.__name__) + # Getting the configuration requires RPC call. Uses the default value for now. + depth = 1 + set_current_origin(func.__name__, _capture_call_site(spark, depth)) - try: - return func(*args, **kwargs) - finally: - pyspark_origin.clear() + try: + return func(*args, **kwargs) + finally: + set_current_origin(None, None) + else: + assert spark._jvm is not None + jvm_pyspark_origin = ( + spark._jvm.org.apache.spark.sql.catalyst.trees.PySparkCurrentOrigin + ) + depth = int( + spark.conf.get( # type: ignore[arg-type] + "spark.sql.stackTracesInDataFrameContext" + ) + ) + # Update call site when the function is called + jvm_pyspark_origin.set(func.__name__, _capture_call_site(spark, depth)) + + try: + return func(*args, **kwargs) + finally: + jvm_pyspark_origin.clear() else: return func(*args, **kwargs) diff --git a/python/pyspark/sql/connect/column.py b/python/pyspark/sql/connect/column.py index 146c517d1b1af..5d34dcd9480c0 100644 --- a/python/pyspark/sql/connect/column.py +++ b/python/pyspark/sql/connect/column.py @@ -46,6 +46,7 @@ WithField, DropField, ) +from pyspark.errors.utils import with_origin_to_class if TYPE_CHECKING: @@ -95,6 +96,7 @@ def _unary_op(name: str, self: ParentColumn) -> ParentColumn: return Column(UnresolvedFunction(name, [self._expr])) # type: ignore[list-item] +@with_origin_to_class class Column(ParentColumn): def __new__( cls, diff --git a/python/pyspark/sql/connect/expressions.py b/python/pyspark/sql/connect/expressions.py index 8cd386ba03aea..91a4dad55f64c 100644 --- a/python/pyspark/sql/connect/expressions.py +++ b/python/pyspark/sql/connect/expressions.py @@ -15,7 +15,6 @@ # limitations under the License. 
# from pyspark.sql.connect.utils import check_dependencies -from pyspark.sql.utils import is_timestamp_ntz_preferred check_dependencies(__name__) @@ -77,6 +76,8 @@ proto_schema_to_pyspark_data_type, ) from pyspark.errors import PySparkTypeError, PySparkValueError +from pyspark.errors.utils import current_origin +from pyspark.sql.utils import is_timestamp_ntz_preferred if TYPE_CHECKING: from pyspark.sql.connect.client import SparkConnectClient @@ -89,7 +90,16 @@ class Expression: """ def __init__(self) -> None: - pass + origin = current_origin() + fragment = origin.fragment + call_site = origin.call_site + self.origin = None + if fragment is not None and call_site is not None: + self.origin = proto.Origin( + python_origin=proto.PythonOrigin( + fragment=origin.fragment, call_site=origin.call_site + ) + ) def to_plan( # type: ignore[empty-body] self, session: "SparkConnectClient" @@ -162,7 +172,7 @@ def __init__(self, parent: Expression, alias: Sequence[str], metadata: Any): def to_plan(self, session: "SparkConnectClient") -> "proto.Expression": if len(self._alias) == 1: - exp = proto.Expression() + exp = proto.Expression(origin=self.origin) exp.alias.name.append(self._alias[0]) exp.alias.expr.CopyFrom(self._parent.to_plan(session)) @@ -175,7 +185,7 @@ def to_plan(self, session: "SparkConnectClient") -> "proto.Expression": error_class="CANNOT_PROVIDE_METADATA", message_parameters={}, ) - exp = proto.Expression() + exp = proto.Expression(origin=self.origin) exp.alias.name.extend(self._alias) exp.alias.expr.CopyFrom(self._parent.to_plan(session)) return exp @@ -407,7 +417,7 @@ def _to_value( def to_plan(self, session: "SparkConnectClient") -> "proto.Expression": """Converts the literal expression to the literal in proto.""" - expr = proto.Expression() + expr = proto.Expression(origin=self.origin) if self._value is None: expr.literal.null.CopyFrom(pyspark_types_to_proto_types(self._dataType)) @@ -483,7 +493,7 @@ def name(self) -> str: def to_plan(self, session: "SparkConnectClient") -> proto.Expression: """Returns the Proto representation of the expression.""" - expr = proto.Expression() + expr = proto.Expression(origin=self.origin) expr.unresolved_attribute.unparsed_identifier = self._unparsed_identifier if self._plan_id is not None: expr.unresolved_attribute.plan_id = self._plan_id @@ -512,7 +522,7 @@ def __init__(self, unparsed_target: Optional[str], plan_id: Optional[int] = None self._plan_id = plan_id def to_plan(self, session: "SparkConnectClient") -> "proto.Expression": - expr = proto.Expression() + expr = proto.Expression(origin=self.origin) expr.unresolved_star.SetInParent() if self._unparsed_target is not None: expr.unresolved_star.unparsed_target = self._unparsed_target @@ -546,7 +556,7 @@ def __init__(self, expr: str) -> None: def to_plan(self, session: "SparkConnectClient") -> proto.Expression: """Returns the Proto representation of the SQL expression.""" - expr = proto.Expression() + expr = proto.Expression(origin=self.origin) expr.expression_string.expression = self._expr return expr @@ -572,7 +582,7 @@ def __repr__(self) -> str: ) def to_plan(self, session: "SparkConnectClient") -> proto.Expression: - sort = proto.Expression() + sort = proto.Expression(origin=self.origin) sort.sort_order.child.CopyFrom(self._child.to_plan(session)) if self._ascending: @@ -611,7 +621,7 @@ def __init__( self._is_distinct = is_distinct def to_plan(self, session: "SparkConnectClient") -> proto.Expression: - fun = proto.Expression() + fun = proto.Expression(origin=self.origin) 
fun.unresolved_function.function_name = self._name if len(self._args) > 0: fun.unresolved_function.arguments.extend([arg.to_plan(session) for arg in self._args]) @@ -708,7 +718,7 @@ def __init__( self._function = function def to_plan(self, session: "SparkConnectClient") -> "proto.Expression": - expr = proto.Expression() + expr = proto.Expression(origin=self.origin) expr.common_inline_user_defined_function.function_name = self._function_name expr.common_inline_user_defined_function.deterministic = self._deterministic if len(self._arguments) > 0: @@ -762,7 +772,7 @@ def __init__( self._valueExpr = valueExpr def to_plan(self, session: "SparkConnectClient") -> proto.Expression: - expr = proto.Expression() + expr = proto.Expression(origin=self.origin) expr.update_fields.struct_expression.CopyFrom(self._structExpr.to_plan(session)) expr.update_fields.field_name = self._fieldName expr.update_fields.value_expression.CopyFrom(self._valueExpr.to_plan(session)) @@ -787,7 +797,7 @@ def __init__( self._fieldName = fieldName def to_plan(self, session: "SparkConnectClient") -> proto.Expression: - expr = proto.Expression() + expr = proto.Expression(origin=self.origin) expr.update_fields.struct_expression.CopyFrom(self._structExpr.to_plan(session)) expr.update_fields.field_name = self._fieldName return expr @@ -811,7 +821,7 @@ def __init__( self._extraction = extraction def to_plan(self, session: "SparkConnectClient") -> proto.Expression: - expr = proto.Expression() + expr = proto.Expression(origin=self.origin) expr.unresolved_extract_value.child.CopyFrom(self._child.to_plan(session)) expr.unresolved_extract_value.extraction.CopyFrom(self._extraction.to_plan(session)) return expr @@ -831,7 +841,7 @@ def __init__(self, col_name: str, plan_id: Optional[int] = None) -> None: self._plan_id = plan_id def to_plan(self, session: "SparkConnectClient") -> proto.Expression: - expr = proto.Expression() + expr = proto.Expression(origin=self.origin) expr.unresolved_regex.col_name = self.col_name if self._plan_id is not None: expr.unresolved_regex.plan_id = self._plan_id @@ -858,7 +868,7 @@ def __init__( self._eval_mode = eval_mode def to_plan(self, session: "SparkConnectClient") -> proto.Expression: - fun = proto.Expression() + fun = proto.Expression(origin=self.origin) fun.cast.expr.CopyFrom(self._expr.to_plan(session)) if isinstance(self._data_type, str): fun.cast.type_str = self._data_type @@ -909,7 +919,7 @@ def __init__( self._name_parts = name_parts def to_plan(self, session: "SparkConnectClient") -> proto.Expression: - expr = proto.Expression() + expr = proto.Expression(origin=self.origin) expr.unresolved_named_lambda_variable.name_parts.extend(self._name_parts) return expr @@ -951,7 +961,7 @@ def __init__( self._arguments = arguments def to_plan(self, session: "SparkConnectClient") -> proto.Expression: - expr = proto.Expression() + expr = proto.Expression(origin=self.origin) expr.lambda_function.function.CopyFrom(self._function.to_plan(session)) expr.lambda_function.arguments.extend( [arg.to_plan(session).unresolved_named_lambda_variable for arg in self._arguments] @@ -981,7 +991,7 @@ def __init__( self._windowSpec = windowSpec def to_plan(self, session: "SparkConnectClient") -> proto.Expression: - expr = proto.Expression() + expr = proto.Expression(origin=self.origin) expr.window.window_function.CopyFrom(self._windowFunction.to_plan(session)) @@ -1088,7 +1098,7 @@ def __init__(self, name: str, args: Sequence["Expression"]): self._args = args def to_plan(self, session: "SparkConnectClient") -> 
"proto.Expression": - expr = proto.Expression() + expr = proto.Expression(origin=self.origin) expr.call_function.function_name = self._name if len(self._args) > 0: expr.call_function.arguments.extend([arg.to_plan(session) for arg in self._args]) @@ -1112,7 +1122,7 @@ def __init__(self, key: str, value: Expression): self._value = value def to_plan(self, session: "SparkConnectClient") -> "proto.Expression": - expr = proto.Expression() + expr = proto.Expression(origin=self.origin) expr.named_argument_expression.key = self._key expr.named_argument_expression.value.CopyFrom(self._value.to_plan(session)) return expr diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.py b/python/pyspark/sql/connect/proto/expressions_pb2.py index e42acbf49a7df..6d222e2af60c8 100644 --- a/python/pyspark/sql/connect/proto/expressions_pb2.py +++ b/python/pyspark/sql/connect/proto/expressions_pb2.py @@ -33,7 +33,7 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto"\xde.\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunction\x12Y\n\x11\x65xpression_string\x18\x04 \x01(\x0b\x32*.spark.connect.Expression.ExpressionStringH\x00R\x10\x65xpressionString\x12S\n\x0funresolved_star\x18\x05 \x01(\x0b\x32(.spark.connect.Expression.UnresolvedStarH\x00R\x0eunresolvedStar\x12\x37\n\x05\x61lias\x18\x06 \x01(\x0b\x32\x1f.spark.connect.Expression.AliasH\x00R\x05\x61lias\x12\x34\n\x04\x63\x61st\x18\x07 \x01(\x0b\x32\x1e.spark.connect.Expression.CastH\x00R\x04\x63\x61st\x12V\n\x10unresolved_regex\x18\x08 \x01(\x0b\x32).spark.connect.Expression.UnresolvedRegexH\x00R\x0funresolvedRegex\x12\x44\n\nsort_order\x18\t \x01(\x0b\x32#.spark.connect.Expression.SortOrderH\x00R\tsortOrder\x12S\n\x0flambda_function\x18\n \x01(\x0b\x32(.spark.connect.Expression.LambdaFunctionH\x00R\x0elambdaFunction\x12:\n\x06window\x18\x0b \x01(\x0b\x32 .spark.connect.Expression.WindowH\x00R\x06window\x12l\n\x18unresolved_extract_value\x18\x0c \x01(\x0b\x32\x30.spark.connect.Expression.UnresolvedExtractValueH\x00R\x16unresolvedExtractValue\x12M\n\rupdate_fields\x18\r \x01(\x0b\x32&.spark.connect.Expression.UpdateFieldsH\x00R\x0cupdateFields\x12\x82\x01\n unresolved_named_lambda_variable\x18\x0e \x01(\x0b\x32\x37.spark.connect.Expression.UnresolvedNamedLambdaVariableH\x00R\x1dunresolvedNamedLambdaVariable\x12~\n#common_inline_user_defined_function\x18\x0f \x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionH\x00R\x1f\x63ommonInlineUserDefinedFunction\x12\x42\n\rcall_function\x18\x10 \x01(\x0b\x32\x1b.spark.connect.CallFunctionH\x00R\x0c\x63\x61llFunction\x12\x64\n\x19named_argument_expression\x18\x11 \x01(\x0b\x32&.spark.connect.NamedArgumentExpressionH\x00R\x17namedArgumentExpression\x12\x35\n\textension\x18\xe7\x07 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00R\textension\x1a\x8f\x06\n\x06Window\x12\x42\n\x0fwindow_function\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x0ewindowFunction\x12@\n\x0epartition_spec\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\rpartitionSpec\x12\x42\n\norder_spec\x18\x03 \x03(\x0b\x32#.spark.connect.Expression.SortOrderR\torderSpec\x12K\n\nframe_spec\x18\x04 
\x01(\x0b\x32,.spark.connect.Expression.Window.WindowFrameR\tframeSpec\x1a\xed\x03\n\x0bWindowFrame\x12U\n\nframe_type\x18\x01 \x01(\x0e\x32\x36.spark.connect.Expression.Window.WindowFrame.FrameTypeR\tframeType\x12P\n\x05lower\x18\x02 \x01(\x0b\x32:.spark.connect.Expression.Window.WindowFrame.FrameBoundaryR\x05lower\x12P\n\x05upper\x18\x03 \x01(\x0b\x32:.spark.connect.Expression.Window.WindowFrame.FrameBoundaryR\x05upper\x1a\x91\x01\n\rFrameBoundary\x12!\n\x0b\x63urrent_row\x18\x01 \x01(\x08H\x00R\ncurrentRow\x12\x1e\n\tunbounded\x18\x02 \x01(\x08H\x00R\tunbounded\x12\x31\n\x05value\x18\x03 \x01(\x0b\x32\x19.spark.connect.ExpressionH\x00R\x05valueB\n\n\x08\x62oundary"O\n\tFrameType\x12\x18\n\x14\x46RAME_TYPE_UNDEFINED\x10\x00\x12\x12\n\x0e\x46RAME_TYPE_ROW\x10\x01\x12\x14\n\x10\x46RAME_TYPE_RANGE\x10\x02\x1a\xa9\x03\n\tSortOrder\x12/\n\x05\x63hild\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x05\x63hild\x12O\n\tdirection\x18\x02 \x01(\x0e\x32\x31.spark.connect.Expression.SortOrder.SortDirectionR\tdirection\x12U\n\rnull_ordering\x18\x03 \x01(\x0e\x32\x30.spark.connect.Expression.SortOrder.NullOrderingR\x0cnullOrdering"l\n\rSortDirection\x12\x1e\n\x1aSORT_DIRECTION_UNSPECIFIED\x10\x00\x12\x1c\n\x18SORT_DIRECTION_ASCENDING\x10\x01\x12\x1d\n\x19SORT_DIRECTION_DESCENDING\x10\x02"U\n\x0cNullOrdering\x12\x1a\n\x16SORT_NULLS_UNSPECIFIED\x10\x00\x12\x14\n\x10SORT_NULLS_FIRST\x10\x01\x12\x13\n\x0fSORT_NULLS_LAST\x10\x02\x1a\xbb\x02\n\x04\x43\x61st\x12-\n\x04\x65xpr\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x04\x65xpr\x12-\n\x04type\x18\x02 \x01(\x0b\x32\x17.spark.connect.DataTypeH\x00R\x04type\x12\x1b\n\x08type_str\x18\x03 \x01(\tH\x00R\x07typeStr\x12\x44\n\teval_mode\x18\x04 \x01(\x0e\x32\'.spark.connect.Expression.Cast.EvalModeR\x08\x65valMode"b\n\x08\x45valMode\x12\x19\n\x15\x45VAL_MODE_UNSPECIFIED\x10\x00\x12\x14\n\x10\x45VAL_MODE_LEGACY\x10\x01\x12\x12\n\x0e\x45VAL_MODE_ANSI\x10\x02\x12\x11\n\rEVAL_MODE_TRY\x10\x03\x42\x0e\n\x0c\x63\x61st_to_type\x1a\x9b\x0c\n\x07Literal\x12-\n\x04null\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeH\x00R\x04null\x12\x18\n\x06\x62inary\x18\x02 \x01(\x0cH\x00R\x06\x62inary\x12\x1a\n\x07\x62oolean\x18\x03 \x01(\x08H\x00R\x07\x62oolean\x12\x14\n\x04\x62yte\x18\x04 \x01(\x05H\x00R\x04\x62yte\x12\x16\n\x05short\x18\x05 \x01(\x05H\x00R\x05short\x12\x1a\n\x07integer\x18\x06 \x01(\x05H\x00R\x07integer\x12\x14\n\x04long\x18\x07 \x01(\x03H\x00R\x04long\x12\x16\n\x05\x66loat\x18\n \x01(\x02H\x00R\x05\x66loat\x12\x18\n\x06\x64ouble\x18\x0b \x01(\x01H\x00R\x06\x64ouble\x12\x45\n\x07\x64\x65\x63imal\x18\x0c \x01(\x0b\x32).spark.connect.Expression.Literal.DecimalH\x00R\x07\x64\x65\x63imal\x12\x18\n\x06string\x18\r \x01(\tH\x00R\x06string\x12\x14\n\x04\x64\x61te\x18\x10 \x01(\x05H\x00R\x04\x64\x61te\x12\x1e\n\ttimestamp\x18\x11 \x01(\x03H\x00R\ttimestamp\x12%\n\rtimestamp_ntz\x18\x12 \x01(\x03H\x00R\x0ctimestampNtz\x12\x61\n\x11\x63\x61lendar_interval\x18\x13 \x01(\x0b\x32\x32.spark.connect.Expression.Literal.CalendarIntervalH\x00R\x10\x63\x61lendarInterval\x12\x30\n\x13year_month_interval\x18\x14 \x01(\x05H\x00R\x11yearMonthInterval\x12,\n\x11\x64\x61y_time_interval\x18\x15 \x01(\x03H\x00R\x0f\x64\x61yTimeInterval\x12?\n\x05\x61rray\x18\x16 \x01(\x0b\x32\'.spark.connect.Expression.Literal.ArrayH\x00R\x05\x61rray\x12\x39\n\x03map\x18\x17 \x01(\x0b\x32%.spark.connect.Expression.Literal.MapH\x00R\x03map\x12\x42\n\x06struct\x18\x18 \x01(\x0b\x32(.spark.connect.Expression.Literal.StructH\x00R\x06struct\x1au\n\x07\x44\x65\x63imal\x12\x14\n\x05value\x18\x01 
\x01(\tR\x05value\x12!\n\tprecision\x18\x02 \x01(\x05H\x00R\tprecision\x88\x01\x01\x12\x19\n\x05scale\x18\x03 \x01(\x05H\x01R\x05scale\x88\x01\x01\x42\x0c\n\n_precisionB\x08\n\x06_scale\x1a\x62\n\x10\x43\x61lendarInterval\x12\x16\n\x06months\x18\x01 \x01(\x05R\x06months\x12\x12\n\x04\x64\x61ys\x18\x02 \x01(\x05R\x04\x64\x61ys\x12"\n\x0cmicroseconds\x18\x03 \x01(\x03R\x0cmicroseconds\x1a\x82\x01\n\x05\x41rray\x12:\n\x0c\x65lement_type\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x0b\x65lementType\x12=\n\x08\x65lements\x18\x02 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x08\x65lements\x1a\xe3\x01\n\x03Map\x12\x32\n\x08key_type\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x07keyType\x12\x36\n\nvalue_type\x18\x02 \x01(\x0b\x32\x17.spark.connect.DataTypeR\tvalueType\x12\x35\n\x04keys\x18\x03 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x04keys\x12\x39\n\x06values\x18\x04 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x06values\x1a\x81\x01\n\x06Struct\x12\x38\n\x0bstruct_type\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeR\nstructType\x12=\n\x08\x65lements\x18\x02 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x08\x65lementsB\x0e\n\x0cliteral_type\x1a\xba\x01\n\x13UnresolvedAttribute\x12/\n\x13unparsed_identifier\x18\x01 \x01(\tR\x12unparsedIdentifier\x12\x1c\n\x07plan_id\x18\x02 \x01(\x03H\x00R\x06planId\x88\x01\x01\x12\x31\n\x12is_metadata_column\x18\x03 \x01(\x08H\x01R\x10isMetadataColumn\x88\x01\x01\x42\n\n\x08_plan_idB\x15\n\x13_is_metadata_column\x1a\xcc\x01\n\x12UnresolvedFunction\x12#\n\rfunction_name\x18\x01 \x01(\tR\x0c\x66unctionName\x12\x37\n\targuments\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\targuments\x12\x1f\n\x0bis_distinct\x18\x03 \x01(\x08R\nisDistinct\x12\x37\n\x18is_user_defined_function\x18\x04 \x01(\x08R\x15isUserDefinedFunction\x1a\x32\n\x10\x45xpressionString\x12\x1e\n\nexpression\x18\x01 \x01(\tR\nexpression\x1a|\n\x0eUnresolvedStar\x12,\n\x0funparsed_target\x18\x01 \x01(\tH\x00R\x0eunparsedTarget\x88\x01\x01\x12\x1c\n\x07plan_id\x18\x02 \x01(\x03H\x01R\x06planId\x88\x01\x01\x42\x12\n\x10_unparsed_targetB\n\n\x08_plan_id\x1aV\n\x0fUnresolvedRegex\x12\x19\n\x08\x63ol_name\x18\x01 \x01(\tR\x07\x63olName\x12\x1c\n\x07plan_id\x18\x02 \x01(\x03H\x00R\x06planId\x88\x01\x01\x42\n\n\x08_plan_id\x1a\x84\x01\n\x16UnresolvedExtractValue\x12/\n\x05\x63hild\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x05\x63hild\x12\x39\n\nextraction\x18\x02 \x01(\x0b\x32\x19.spark.connect.ExpressionR\nextraction\x1a\xbb\x01\n\x0cUpdateFields\x12\x46\n\x11struct_expression\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x10structExpression\x12\x1d\n\nfield_name\x18\x02 \x01(\tR\tfieldName\x12\x44\n\x10value_expression\x18\x03 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x0fvalueExpression\x1ax\n\x05\x41lias\x12-\n\x04\x65xpr\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x04\x65xpr\x12\x12\n\x04name\x18\x02 \x03(\tR\x04name\x12\x1f\n\x08metadata\x18\x03 \x01(\tH\x00R\x08metadata\x88\x01\x01\x42\x0b\n\t_metadata\x1a\x9e\x01\n\x0eLambdaFunction\x12\x35\n\x08\x66unction\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x08\x66unction\x12U\n\targuments\x18\x02 \x03(\x0b\x32\x37.spark.connect.Expression.UnresolvedNamedLambdaVariableR\targuments\x1a>\n\x1dUnresolvedNamedLambdaVariable\x12\x1d\n\nname_parts\x18\x01 \x03(\tR\tnamePartsB\x0b\n\texpr_type"\xec\x02\n\x1f\x43ommonInlineUserDefinedFunction\x12#\n\rfunction_name\x18\x01 \x01(\tR\x0c\x66unctionName\x12$\n\rdeterministic\x18\x02 \x01(\x08R\rdeterministic\x12\x37\n\targuments\x18\x03 
\x03(\x0b\x32\x19.spark.connect.ExpressionR\targuments\x12\x39\n\npython_udf\x18\x04 \x01(\x0b\x32\x18.spark.connect.PythonUDFH\x00R\tpythonUdf\x12I\n\x10scalar_scala_udf\x18\x05 \x01(\x0b\x32\x1d.spark.connect.ScalarScalaUDFH\x00R\x0escalarScalaUdf\x12\x33\n\x08java_udf\x18\x06 \x01(\x0b\x32\x16.spark.connect.JavaUDFH\x00R\x07javaUdfB\n\n\x08\x66unction"\x9b\x01\n\tPythonUDF\x12\x38\n\x0boutput_type\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeR\noutputType\x12\x1b\n\teval_type\x18\x02 \x01(\x05R\x08\x65valType\x12\x18\n\x07\x63ommand\x18\x03 \x01(\x0cR\x07\x63ommand\x12\x1d\n\npython_ver\x18\x04 \x01(\tR\tpythonVer"\xb8\x01\n\x0eScalarScalaUDF\x12\x18\n\x07payload\x18\x01 \x01(\x0cR\x07payload\x12\x37\n\ninputTypes\x18\x02 \x03(\x0b\x32\x17.spark.connect.DataTypeR\ninputTypes\x12\x37\n\noutputType\x18\x03 \x01(\x0b\x32\x17.spark.connect.DataTypeR\noutputType\x12\x1a\n\x08nullable\x18\x04 \x01(\x08R\x08nullable"\x95\x01\n\x07JavaUDF\x12\x1d\n\nclass_name\x18\x01 \x01(\tR\tclassName\x12=\n\x0boutput_type\x18\x02 \x01(\x0b\x32\x17.spark.connect.DataTypeH\x00R\noutputType\x88\x01\x01\x12\x1c\n\taggregate\x18\x03 \x01(\x08R\taggregateB\x0e\n\x0c_output_type"l\n\x0c\x43\x61llFunction\x12#\n\rfunction_name\x18\x01 \x01(\tR\x0c\x66unctionName\x12\x37\n\targuments\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\targuments"\\\n\x17NamedArgumentExpression\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12/\n\x05value\x18\x02 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x05valueB6\n\x1eorg.apache.spark.connect.protoP\x01Z\x12internal/generatedb\x06proto3' + b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto"\x8d/\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunction\x12Y\n\x11\x65xpression_string\x18\x04 \x01(\x0b\x32*.spark.connect.Expression.ExpressionStringH\x00R\x10\x65xpressionString\x12S\n\x0funresolved_star\x18\x05 \x01(\x0b\x32(.spark.connect.Expression.UnresolvedStarH\x00R\x0eunresolvedStar\x12\x37\n\x05\x61lias\x18\x06 \x01(\x0b\x32\x1f.spark.connect.Expression.AliasH\x00R\x05\x61lias\x12\x34\n\x04\x63\x61st\x18\x07 \x01(\x0b\x32\x1e.spark.connect.Expression.CastH\x00R\x04\x63\x61st\x12V\n\x10unresolved_regex\x18\x08 \x01(\x0b\x32).spark.connect.Expression.UnresolvedRegexH\x00R\x0funresolvedRegex\x12\x44\n\nsort_order\x18\t \x01(\x0b\x32#.spark.connect.Expression.SortOrderH\x00R\tsortOrder\x12S\n\x0flambda_function\x18\n \x01(\x0b\x32(.spark.connect.Expression.LambdaFunctionH\x00R\x0elambdaFunction\x12:\n\x06window\x18\x0b \x01(\x0b\x32 .spark.connect.Expression.WindowH\x00R\x06window\x12l\n\x18unresolved_extract_value\x18\x0c \x01(\x0b\x32\x30.spark.connect.Expression.UnresolvedExtractValueH\x00R\x16unresolvedExtractValue\x12M\n\rupdate_fields\x18\r \x01(\x0b\x32&.spark.connect.Expression.UpdateFieldsH\x00R\x0cupdateFields\x12\x82\x01\n unresolved_named_lambda_variable\x18\x0e \x01(\x0b\x32\x37.spark.connect.Expression.UnresolvedNamedLambdaVariableH\x00R\x1dunresolvedNamedLambdaVariable\x12~\n#common_inline_user_defined_function\x18\x0f \x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionH\x00R\x1f\x63ommonInlineUserDefinedFunction\x12\x42\n\rcall_function\x18\x10 
\x01(\x0b\x32\x1b.spark.connect.CallFunctionH\x00R\x0c\x63\x61llFunction\x12\x64\n\x19named_argument_expression\x18\x11 \x01(\x0b\x32&.spark.connect.NamedArgumentExpressionH\x00R\x17namedArgumentExpression\x12\x35\n\textension\x18\xe7\x07 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00R\textension\x12-\n\x06origin\x18\x12 \x01(\x0b\x32\x15.spark.connect.OriginR\x06origin\x1a\x8f\x06\n\x06Window\x12\x42\n\x0fwindow_function\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x0ewindowFunction\x12@\n\x0epartition_spec\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\rpartitionSpec\x12\x42\n\norder_spec\x18\x03 \x03(\x0b\x32#.spark.connect.Expression.SortOrderR\torderSpec\x12K\n\nframe_spec\x18\x04 \x01(\x0b\x32,.spark.connect.Expression.Window.WindowFrameR\tframeSpec\x1a\xed\x03\n\x0bWindowFrame\x12U\n\nframe_type\x18\x01 \x01(\x0e\x32\x36.spark.connect.Expression.Window.WindowFrame.FrameTypeR\tframeType\x12P\n\x05lower\x18\x02 \x01(\x0b\x32:.spark.connect.Expression.Window.WindowFrame.FrameBoundaryR\x05lower\x12P\n\x05upper\x18\x03 \x01(\x0b\x32:.spark.connect.Expression.Window.WindowFrame.FrameBoundaryR\x05upper\x1a\x91\x01\n\rFrameBoundary\x12!\n\x0b\x63urrent_row\x18\x01 \x01(\x08H\x00R\ncurrentRow\x12\x1e\n\tunbounded\x18\x02 \x01(\x08H\x00R\tunbounded\x12\x31\n\x05value\x18\x03 \x01(\x0b\x32\x19.spark.connect.ExpressionH\x00R\x05valueB\n\n\x08\x62oundary"O\n\tFrameType\x12\x18\n\x14\x46RAME_TYPE_UNDEFINED\x10\x00\x12\x12\n\x0e\x46RAME_TYPE_ROW\x10\x01\x12\x14\n\x10\x46RAME_TYPE_RANGE\x10\x02\x1a\xa9\x03\n\tSortOrder\x12/\n\x05\x63hild\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x05\x63hild\x12O\n\tdirection\x18\x02 \x01(\x0e\x32\x31.spark.connect.Expression.SortOrder.SortDirectionR\tdirection\x12U\n\rnull_ordering\x18\x03 \x01(\x0e\x32\x30.spark.connect.Expression.SortOrder.NullOrderingR\x0cnullOrdering"l\n\rSortDirection\x12\x1e\n\x1aSORT_DIRECTION_UNSPECIFIED\x10\x00\x12\x1c\n\x18SORT_DIRECTION_ASCENDING\x10\x01\x12\x1d\n\x19SORT_DIRECTION_DESCENDING\x10\x02"U\n\x0cNullOrdering\x12\x1a\n\x16SORT_NULLS_UNSPECIFIED\x10\x00\x12\x14\n\x10SORT_NULLS_FIRST\x10\x01\x12\x13\n\x0fSORT_NULLS_LAST\x10\x02\x1a\xbb\x02\n\x04\x43\x61st\x12-\n\x04\x65xpr\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x04\x65xpr\x12-\n\x04type\x18\x02 \x01(\x0b\x32\x17.spark.connect.DataTypeH\x00R\x04type\x12\x1b\n\x08type_str\x18\x03 \x01(\tH\x00R\x07typeStr\x12\x44\n\teval_mode\x18\x04 \x01(\x0e\x32\'.spark.connect.Expression.Cast.EvalModeR\x08\x65valMode"b\n\x08\x45valMode\x12\x19\n\x15\x45VAL_MODE_UNSPECIFIED\x10\x00\x12\x14\n\x10\x45VAL_MODE_LEGACY\x10\x01\x12\x12\n\x0e\x45VAL_MODE_ANSI\x10\x02\x12\x11\n\rEVAL_MODE_TRY\x10\x03\x42\x0e\n\x0c\x63\x61st_to_type\x1a\x9b\x0c\n\x07Literal\x12-\n\x04null\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeH\x00R\x04null\x12\x18\n\x06\x62inary\x18\x02 \x01(\x0cH\x00R\x06\x62inary\x12\x1a\n\x07\x62oolean\x18\x03 \x01(\x08H\x00R\x07\x62oolean\x12\x14\n\x04\x62yte\x18\x04 \x01(\x05H\x00R\x04\x62yte\x12\x16\n\x05short\x18\x05 \x01(\x05H\x00R\x05short\x12\x1a\n\x07integer\x18\x06 \x01(\x05H\x00R\x07integer\x12\x14\n\x04long\x18\x07 \x01(\x03H\x00R\x04long\x12\x16\n\x05\x66loat\x18\n \x01(\x02H\x00R\x05\x66loat\x12\x18\n\x06\x64ouble\x18\x0b \x01(\x01H\x00R\x06\x64ouble\x12\x45\n\x07\x64\x65\x63imal\x18\x0c \x01(\x0b\x32).spark.connect.Expression.Literal.DecimalH\x00R\x07\x64\x65\x63imal\x12\x18\n\x06string\x18\r \x01(\tH\x00R\x06string\x12\x14\n\x04\x64\x61te\x18\x10 \x01(\x05H\x00R\x04\x64\x61te\x12\x1e\n\ttimestamp\x18\x11 
\x01(\x03H\x00R\ttimestamp\x12%\n\rtimestamp_ntz\x18\x12 \x01(\x03H\x00R\x0ctimestampNtz\x12\x61\n\x11\x63\x61lendar_interval\x18\x13 \x01(\x0b\x32\x32.spark.connect.Expression.Literal.CalendarIntervalH\x00R\x10\x63\x61lendarInterval\x12\x30\n\x13year_month_interval\x18\x14 \x01(\x05H\x00R\x11yearMonthInterval\x12,\n\x11\x64\x61y_time_interval\x18\x15 \x01(\x03H\x00R\x0f\x64\x61yTimeInterval\x12?\n\x05\x61rray\x18\x16 \x01(\x0b\x32\'.spark.connect.Expression.Literal.ArrayH\x00R\x05\x61rray\x12\x39\n\x03map\x18\x17 \x01(\x0b\x32%.spark.connect.Expression.Literal.MapH\x00R\x03map\x12\x42\n\x06struct\x18\x18 \x01(\x0b\x32(.spark.connect.Expression.Literal.StructH\x00R\x06struct\x1au\n\x07\x44\x65\x63imal\x12\x14\n\x05value\x18\x01 \x01(\tR\x05value\x12!\n\tprecision\x18\x02 \x01(\x05H\x00R\tprecision\x88\x01\x01\x12\x19\n\x05scale\x18\x03 \x01(\x05H\x01R\x05scale\x88\x01\x01\x42\x0c\n\n_precisionB\x08\n\x06_scale\x1a\x62\n\x10\x43\x61lendarInterval\x12\x16\n\x06months\x18\x01 \x01(\x05R\x06months\x12\x12\n\x04\x64\x61ys\x18\x02 \x01(\x05R\x04\x64\x61ys\x12"\n\x0cmicroseconds\x18\x03 \x01(\x03R\x0cmicroseconds\x1a\x82\x01\n\x05\x41rray\x12:\n\x0c\x65lement_type\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x0b\x65lementType\x12=\n\x08\x65lements\x18\x02 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x08\x65lements\x1a\xe3\x01\n\x03Map\x12\x32\n\x08key_type\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x07keyType\x12\x36\n\nvalue_type\x18\x02 \x01(\x0b\x32\x17.spark.connect.DataTypeR\tvalueType\x12\x35\n\x04keys\x18\x03 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x04keys\x12\x39\n\x06values\x18\x04 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x06values\x1a\x81\x01\n\x06Struct\x12\x38\n\x0bstruct_type\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeR\nstructType\x12=\n\x08\x65lements\x18\x02 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x08\x65lementsB\x0e\n\x0cliteral_type\x1a\xba\x01\n\x13UnresolvedAttribute\x12/\n\x13unparsed_identifier\x18\x01 \x01(\tR\x12unparsedIdentifier\x12\x1c\n\x07plan_id\x18\x02 \x01(\x03H\x00R\x06planId\x88\x01\x01\x12\x31\n\x12is_metadata_column\x18\x03 \x01(\x08H\x01R\x10isMetadataColumn\x88\x01\x01\x42\n\n\x08_plan_idB\x15\n\x13_is_metadata_column\x1a\xcc\x01\n\x12UnresolvedFunction\x12#\n\rfunction_name\x18\x01 \x01(\tR\x0c\x66unctionName\x12\x37\n\targuments\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\targuments\x12\x1f\n\x0bis_distinct\x18\x03 \x01(\x08R\nisDistinct\x12\x37\n\x18is_user_defined_function\x18\x04 \x01(\x08R\x15isUserDefinedFunction\x1a\x32\n\x10\x45xpressionString\x12\x1e\n\nexpression\x18\x01 \x01(\tR\nexpression\x1a|\n\x0eUnresolvedStar\x12,\n\x0funparsed_target\x18\x01 \x01(\tH\x00R\x0eunparsedTarget\x88\x01\x01\x12\x1c\n\x07plan_id\x18\x02 \x01(\x03H\x01R\x06planId\x88\x01\x01\x42\x12\n\x10_unparsed_targetB\n\n\x08_plan_id\x1aV\n\x0fUnresolvedRegex\x12\x19\n\x08\x63ol_name\x18\x01 \x01(\tR\x07\x63olName\x12\x1c\n\x07plan_id\x18\x02 \x01(\x03H\x00R\x06planId\x88\x01\x01\x42\n\n\x08_plan_id\x1a\x84\x01\n\x16UnresolvedExtractValue\x12/\n\x05\x63hild\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x05\x63hild\x12\x39\n\nextraction\x18\x02 \x01(\x0b\x32\x19.spark.connect.ExpressionR\nextraction\x1a\xbb\x01\n\x0cUpdateFields\x12\x46\n\x11struct_expression\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x10structExpression\x12\x1d\n\nfield_name\x18\x02 \x01(\tR\tfieldName\x12\x44\n\x10value_expression\x18\x03 
\x01(\x0b\x32\x19.spark.connect.ExpressionR\x0fvalueExpression\x1ax\n\x05\x41lias\x12-\n\x04\x65xpr\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x04\x65xpr\x12\x12\n\x04name\x18\x02 \x03(\tR\x04name\x12\x1f\n\x08metadata\x18\x03 \x01(\tH\x00R\x08metadata\x88\x01\x01\x42\x0b\n\t_metadata\x1a\x9e\x01\n\x0eLambdaFunction\x12\x35\n\x08\x66unction\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x08\x66unction\x12U\n\targuments\x18\x02 \x03(\x0b\x32\x37.spark.connect.Expression.UnresolvedNamedLambdaVariableR\targuments\x1a>\n\x1dUnresolvedNamedLambdaVariable\x12\x1d\n\nname_parts\x18\x01 \x03(\tR\tnamePartsB\x0b\n\texpr_type"\xec\x02\n\x1f\x43ommonInlineUserDefinedFunction\x12#\n\rfunction_name\x18\x01 \x01(\tR\x0c\x66unctionName\x12$\n\rdeterministic\x18\x02 \x01(\x08R\rdeterministic\x12\x37\n\targuments\x18\x03 \x03(\x0b\x32\x19.spark.connect.ExpressionR\targuments\x12\x39\n\npython_udf\x18\x04 \x01(\x0b\x32\x18.spark.connect.PythonUDFH\x00R\tpythonUdf\x12I\n\x10scalar_scala_udf\x18\x05 \x01(\x0b\x32\x1d.spark.connect.ScalarScalaUDFH\x00R\x0escalarScalaUdf\x12\x33\n\x08java_udf\x18\x06 \x01(\x0b\x32\x16.spark.connect.JavaUDFH\x00R\x07javaUdfB\n\n\x08\x66unction"\x9b\x01\n\tPythonUDF\x12\x38\n\x0boutput_type\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeR\noutputType\x12\x1b\n\teval_type\x18\x02 \x01(\x05R\x08\x65valType\x12\x18\n\x07\x63ommand\x18\x03 \x01(\x0cR\x07\x63ommand\x12\x1d\n\npython_ver\x18\x04 \x01(\tR\tpythonVer"\xb8\x01\n\x0eScalarScalaUDF\x12\x18\n\x07payload\x18\x01 \x01(\x0cR\x07payload\x12\x37\n\ninputTypes\x18\x02 \x03(\x0b\x32\x17.spark.connect.DataTypeR\ninputTypes\x12\x37\n\noutputType\x18\x03 \x01(\x0b\x32\x17.spark.connect.DataTypeR\noutputType\x12\x1a\n\x08nullable\x18\x04 \x01(\x08R\x08nullable"\x95\x01\n\x07JavaUDF\x12\x1d\n\nclass_name\x18\x01 \x01(\tR\tclassName\x12=\n\x0boutput_type\x18\x02 \x01(\x0b\x32\x17.spark.connect.DataTypeH\x00R\noutputType\x88\x01\x01\x12\x1c\n\taggregate\x18\x03 \x01(\x08R\taggregateB\x0e\n\x0c_output_type"l\n\x0c\x43\x61llFunction\x12#\n\rfunction_name\x18\x01 \x01(\tR\x0c\x66unctionName\x12\x37\n\targuments\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\targuments"\\\n\x17NamedArgumentExpression\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12/\n\x05value\x18\x02 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x05value"X\n\x06Origin\x12\x42\n\rpython_origin\x18\x01 \x01(\x0b\x32\x1b.spark.connect.PythonOriginH\x00R\x0cpythonOriginB\n\n\x08\x66unction"G\n\x0cPythonOrigin\x12\x1a\n\x08\x66ragment\x18\x01 \x01(\tR\x08\x66ragment\x12\x1b\n\tcall_site\x18\x02 \x01(\tR\x08\x63\x61llSiteB6\n\x1eorg.apache.spark.connect.protoP\x01Z\x12internal/generatedb\x06proto3' ) _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) @@ -46,67 +46,71 @@ b"\n\036org.apache.spark.connect.protoP\001Z\022internal/generated" ) _EXPRESSION._serialized_start = 105 - _EXPRESSION._serialized_end = 6087 - _EXPRESSION_WINDOW._serialized_start = 1645 - _EXPRESSION_WINDOW._serialized_end = 2428 - _EXPRESSION_WINDOW_WINDOWFRAME._serialized_start = 1935 - _EXPRESSION_WINDOW_WINDOWFRAME._serialized_end = 2428 - _EXPRESSION_WINDOW_WINDOWFRAME_FRAMEBOUNDARY._serialized_start = 2202 - _EXPRESSION_WINDOW_WINDOWFRAME_FRAMEBOUNDARY._serialized_end = 2347 - _EXPRESSION_WINDOW_WINDOWFRAME_FRAMETYPE._serialized_start = 2349 - _EXPRESSION_WINDOW_WINDOWFRAME_FRAMETYPE._serialized_end = 2428 - _EXPRESSION_SORTORDER._serialized_start = 2431 - _EXPRESSION_SORTORDER._serialized_end = 2856 - _EXPRESSION_SORTORDER_SORTDIRECTION._serialized_start = 2661 - 
_EXPRESSION_SORTORDER_SORTDIRECTION._serialized_end = 2769 - _EXPRESSION_SORTORDER_NULLORDERING._serialized_start = 2771 - _EXPRESSION_SORTORDER_NULLORDERING._serialized_end = 2856 - _EXPRESSION_CAST._serialized_start = 2859 - _EXPRESSION_CAST._serialized_end = 3174 - _EXPRESSION_CAST_EVALMODE._serialized_start = 3060 - _EXPRESSION_CAST_EVALMODE._serialized_end = 3158 - _EXPRESSION_LITERAL._serialized_start = 3177 - _EXPRESSION_LITERAL._serialized_end = 4740 - _EXPRESSION_LITERAL_DECIMAL._serialized_start = 4012 - _EXPRESSION_LITERAL_DECIMAL._serialized_end = 4129 - _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_start = 4131 - _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_end = 4229 - _EXPRESSION_LITERAL_ARRAY._serialized_start = 4232 - _EXPRESSION_LITERAL_ARRAY._serialized_end = 4362 - _EXPRESSION_LITERAL_MAP._serialized_start = 4365 - _EXPRESSION_LITERAL_MAP._serialized_end = 4592 - _EXPRESSION_LITERAL_STRUCT._serialized_start = 4595 - _EXPRESSION_LITERAL_STRUCT._serialized_end = 4724 - _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_start = 4743 - _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_end = 4929 - _EXPRESSION_UNRESOLVEDFUNCTION._serialized_start = 4932 - _EXPRESSION_UNRESOLVEDFUNCTION._serialized_end = 5136 - _EXPRESSION_EXPRESSIONSTRING._serialized_start = 5138 - _EXPRESSION_EXPRESSIONSTRING._serialized_end = 5188 - _EXPRESSION_UNRESOLVEDSTAR._serialized_start = 5190 - _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 5314 - _EXPRESSION_UNRESOLVEDREGEX._serialized_start = 5316 - _EXPRESSION_UNRESOLVEDREGEX._serialized_end = 5402 - _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_start = 5405 - _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_end = 5537 - _EXPRESSION_UPDATEFIELDS._serialized_start = 5540 - _EXPRESSION_UPDATEFIELDS._serialized_end = 5727 - _EXPRESSION_ALIAS._serialized_start = 5729 - _EXPRESSION_ALIAS._serialized_end = 5849 - _EXPRESSION_LAMBDAFUNCTION._serialized_start = 5852 - _EXPRESSION_LAMBDAFUNCTION._serialized_end = 6010 - _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_start = 6012 - _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_end = 6074 - _COMMONINLINEUSERDEFINEDFUNCTION._serialized_start = 6090 - _COMMONINLINEUSERDEFINEDFUNCTION._serialized_end = 6454 - _PYTHONUDF._serialized_start = 6457 - _PYTHONUDF._serialized_end = 6612 - _SCALARSCALAUDF._serialized_start = 6615 - _SCALARSCALAUDF._serialized_end = 6799 - _JAVAUDF._serialized_start = 6802 - _JAVAUDF._serialized_end = 6951 - _CALLFUNCTION._serialized_start = 6953 - _CALLFUNCTION._serialized_end = 7061 - _NAMEDARGUMENTEXPRESSION._serialized_start = 7063 - _NAMEDARGUMENTEXPRESSION._serialized_end = 7155 + _EXPRESSION._serialized_end = 6134 + _EXPRESSION_WINDOW._serialized_start = 1692 + _EXPRESSION_WINDOW._serialized_end = 2475 + _EXPRESSION_WINDOW_WINDOWFRAME._serialized_start = 1982 + _EXPRESSION_WINDOW_WINDOWFRAME._serialized_end = 2475 + _EXPRESSION_WINDOW_WINDOWFRAME_FRAMEBOUNDARY._serialized_start = 2249 + _EXPRESSION_WINDOW_WINDOWFRAME_FRAMEBOUNDARY._serialized_end = 2394 + _EXPRESSION_WINDOW_WINDOWFRAME_FRAMETYPE._serialized_start = 2396 + _EXPRESSION_WINDOW_WINDOWFRAME_FRAMETYPE._serialized_end = 2475 + _EXPRESSION_SORTORDER._serialized_start = 2478 + _EXPRESSION_SORTORDER._serialized_end = 2903 + _EXPRESSION_SORTORDER_SORTDIRECTION._serialized_start = 2708 + _EXPRESSION_SORTORDER_SORTDIRECTION._serialized_end = 2816 + _EXPRESSION_SORTORDER_NULLORDERING._serialized_start = 2818 + _EXPRESSION_SORTORDER_NULLORDERING._serialized_end = 2903 + _EXPRESSION_CAST._serialized_start = 2906 + 
_EXPRESSION_CAST._serialized_end = 3221 + _EXPRESSION_CAST_EVALMODE._serialized_start = 3107 + _EXPRESSION_CAST_EVALMODE._serialized_end = 3205 + _EXPRESSION_LITERAL._serialized_start = 3224 + _EXPRESSION_LITERAL._serialized_end = 4787 + _EXPRESSION_LITERAL_DECIMAL._serialized_start = 4059 + _EXPRESSION_LITERAL_DECIMAL._serialized_end = 4176 + _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_start = 4178 + _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_end = 4276 + _EXPRESSION_LITERAL_ARRAY._serialized_start = 4279 + _EXPRESSION_LITERAL_ARRAY._serialized_end = 4409 + _EXPRESSION_LITERAL_MAP._serialized_start = 4412 + _EXPRESSION_LITERAL_MAP._serialized_end = 4639 + _EXPRESSION_LITERAL_STRUCT._serialized_start = 4642 + _EXPRESSION_LITERAL_STRUCT._serialized_end = 4771 + _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_start = 4790 + _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_end = 4976 + _EXPRESSION_UNRESOLVEDFUNCTION._serialized_start = 4979 + _EXPRESSION_UNRESOLVEDFUNCTION._serialized_end = 5183 + _EXPRESSION_EXPRESSIONSTRING._serialized_start = 5185 + _EXPRESSION_EXPRESSIONSTRING._serialized_end = 5235 + _EXPRESSION_UNRESOLVEDSTAR._serialized_start = 5237 + _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 5361 + _EXPRESSION_UNRESOLVEDREGEX._serialized_start = 5363 + _EXPRESSION_UNRESOLVEDREGEX._serialized_end = 5449 + _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_start = 5452 + _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_end = 5584 + _EXPRESSION_UPDATEFIELDS._serialized_start = 5587 + _EXPRESSION_UPDATEFIELDS._serialized_end = 5774 + _EXPRESSION_ALIAS._serialized_start = 5776 + _EXPRESSION_ALIAS._serialized_end = 5896 + _EXPRESSION_LAMBDAFUNCTION._serialized_start = 5899 + _EXPRESSION_LAMBDAFUNCTION._serialized_end = 6057 + _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_start = 6059 + _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_end = 6121 + _COMMONINLINEUSERDEFINEDFUNCTION._serialized_start = 6137 + _COMMONINLINEUSERDEFINEDFUNCTION._serialized_end = 6501 + _PYTHONUDF._serialized_start = 6504 + _PYTHONUDF._serialized_end = 6659 + _SCALARSCALAUDF._serialized_start = 6662 + _SCALARSCALAUDF._serialized_end = 6846 + _JAVAUDF._serialized_start = 6849 + _JAVAUDF._serialized_end = 6998 + _CALLFUNCTION._serialized_start = 7000 + _CALLFUNCTION._serialized_end = 7108 + _NAMEDARGUMENTEXPRESSION._serialized_start = 7110 + _NAMEDARGUMENTEXPRESSION._serialized_end = 7202 + _ORIGIN._serialized_start = 7204 + _ORIGIN._serialized_end = 7292 + _PYTHONORIGIN._serialized_start = 7294 + _PYTHONORIGIN._serialized_end = 7365 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.pyi b/python/pyspark/sql/connect/proto/expressions_pb2.pyi index 183a839da9204..c183dcb38d023 100644 --- a/python/pyspark/sql/connect/proto/expressions_pb2.pyi +++ b/python/pyspark/sql/connect/proto/expressions_pb2.pyi @@ -1181,6 +1181,7 @@ class Expression(google.protobuf.message.Message): CALL_FUNCTION_FIELD_NUMBER: builtins.int NAMED_ARGUMENT_EXPRESSION_FIELD_NUMBER: builtins.int EXTENSION_FIELD_NUMBER: builtins.int + ORIGIN_FIELD_NUMBER: builtins.int @property def literal(self) -> global___Expression.Literal: ... @property @@ -1222,6 +1223,9 @@ class Expression(google.protobuf.message.Message): """This field is used to mark extensions to the protocol. When plugins generate arbitrary relations they can add them here. During the planning the correct resolution is done. 
""" + @property + def origin(self) -> global___Origin: + """(Optional) Keep the information of the origin for this expression such as stacktrace.""" def __init__( self, *, @@ -1244,6 +1248,7 @@ class Expression(google.protobuf.message.Message): call_function: global___CallFunction | None = ..., named_argument_expression: global___NamedArgumentExpression | None = ..., extension: google.protobuf.any_pb2.Any | None = ..., + origin: global___Origin | None = ..., ) -> None: ... def HasField( self, @@ -1268,6 +1273,8 @@ class Expression(google.protobuf.message.Message): b"literal", "named_argument_expression", b"named_argument_expression", + "origin", + b"origin", "sort_order", b"sort_order", "unresolved_attribute", @@ -1311,6 +1318,8 @@ class Expression(google.protobuf.message.Message): b"literal", "named_argument_expression", b"named_argument_expression", + "origin", + b"origin", "sort_order", b"sort_order", "unresolved_attribute", @@ -1619,3 +1628,54 @@ class NamedArgumentExpression(google.protobuf.message.Message): ) -> None: ... global___NamedArgumentExpression = NamedArgumentExpression + +class Origin(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + PYTHON_ORIGIN_FIELD_NUMBER: builtins.int + @property + def python_origin(self) -> global___PythonOrigin: ... + def __init__( + self, + *, + python_origin: global___PythonOrigin | None = ..., + ) -> None: ... + def HasField( + self, + field_name: typing_extensions.Literal[ + "function", b"function", "python_origin", b"python_origin" + ], + ) -> builtins.bool: ... + def ClearField( + self, + field_name: typing_extensions.Literal[ + "function", b"function", "python_origin", b"python_origin" + ], + ) -> None: ... + def WhichOneof( + self, oneof_group: typing_extensions.Literal["function", b"function"] + ) -> typing_extensions.Literal["python_origin"] | None: ... + +global___Origin = Origin + +class PythonOrigin(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + FRAGMENT_FIELD_NUMBER: builtins.int + CALL_SITE_FIELD_NUMBER: builtins.int + fragment: builtins.str + """(Required) Name of the origin, for example, the name of the function""" + call_site: builtins.str + """(Required) Callsite to show to end users, for example, stacktrace.""" + def __init__( + self, + *, + fragment: builtins.str = ..., + call_site: builtins.str = ..., + ) -> None: ... + def ClearField( + self, + field_name: typing_extensions.Literal["call_site", b"call_site", "fragment", b"fragment"], + ) -> None: ... 
diff --git a/python/pyspark/sql/tests/connect/test_parity_dataframe_query_context.py b/python/pyspark/sql/tests/connect/test_parity_dataframe_query_context.py
index 38bcd56439843..59107363571ee 100644
--- a/python/pyspark/sql/tests/connect/test_parity_dataframe_query_context.py
+++ b/python/pyspark/sql/tests/connect/test_parity_dataframe_query_context.py
@@ -21,10 +21,8 @@
 from pyspark.testing.connectutils import ReusedConnectTestCase
 
 
-class DataFrameParityTests(DataFrameQueryContextTestsMixin, ReusedConnectTestCase):
-    @unittest.skip("Spark Connect does not support DataFrameQueryContext currently.")
-    def test_dataframe_query_context(self):
-        super().test_dataframe_query_context()
+class DataFrameQueryContextParityTests(DataFrameQueryContextTestsMixin, ReusedConnectTestCase):
+    pass
 
 
 if __name__ == "__main__":
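Note: with the skip removed, the Connect parity suite now runs the full DataFrameQueryContextTestsMixin unchanged. The hunks that follow only rename the expected values to `fragment`; end to end, the behavior they assert looks roughly like the sketch below. This is illustrative only and not part of the patch: the import locations, the `contextType()`/`callSite()` accessors, and the exact call-site text are assumptions.

from pyspark.errors import ArithmeticException, QueryContextType
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()  # classic or Connect session
spark.conf.set("spark.sql.ansi.enabled", True)

df = spark.range(10)
try:
    df.withColumn("div_zero", df.id / 0).collect()
except ArithmeticException as e:
    ctx = e.getQueryContext()[0]
    assert ctx.contextType() == QueryContextType.DataFrame
    assert ctx.fragment() == "__truediv__"  # the Column operator that failed
    print(ctx.callSite())                   # points at the user's `df.id / 0` line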
diff --git a/python/pyspark/sql/tests/test_dataframe_query_context.py b/python/pyspark/sql/tests/test_dataframe_query_context.py
index 42fb0b0e452fa..e1a3e33df8593 100644
--- a/python/pyspark/sql/tests/test_dataframe_query_context.py
+++ b/python/pyspark/sql/tests/test_dataframe_query_context.py
@@ -41,7 +41,7 @@ def test_dataframe_query_context(self):
                 error_class="DIVIDE_BY_ZERO",
                 message_parameters={"config": '"spark.sql.ansi.enabled"'},
                 query_context_type=QueryContextType.DataFrame,
-                pyspark_fragment="divide",
+                fragment="__truediv__",
             )
 
             # DataFrameQueryContext with pysparkLoggingInfo - plus
@@ -57,7 +57,7 @@ def test_dataframe_query_context(self):
                     "ansiConfig": '"spark.sql.ansi.enabled"',
                 },
                 query_context_type=QueryContextType.DataFrame,
-                pyspark_fragment="plus",
+                fragment="__add__",
             )
 
             # DataFrameQueryContext with pysparkLoggingInfo - minus
@@ -73,7 +73,7 @@ def test_dataframe_query_context(self):
                     "ansiConfig": '"spark.sql.ansi.enabled"',
                 },
                 query_context_type=QueryContextType.DataFrame,
-                pyspark_fragment="minus",
+                fragment="__sub__",
             )
 
             # DataFrameQueryContext with pysparkLoggingInfo - multiply
@@ -89,7 +89,7 @@ def test_dataframe_query_context(self):
                     "ansiConfig": '"spark.sql.ansi.enabled"',
                 },
                 query_context_type=QueryContextType.DataFrame,
-                pyspark_fragment="multiply",
+                fragment="__mul__",
             )
 
             # DataFrameQueryContext with pysparkLoggingInfo - mod
@@ -105,7 +105,7 @@ def test_dataframe_query_context(self):
                     "ansiConfig": '"spark.sql.ansi.enabled"',
                 },
                 query_context_type=QueryContextType.DataFrame,
-                pyspark_fragment="mod",
+                fragment="__mod__",
             )
 
             # DataFrameQueryContext with pysparkLoggingInfo - equalTo
@@ -121,7 +121,7 @@ def test_dataframe_query_context(self):
                     "ansiConfig": '"spark.sql.ansi.enabled"',
                 },
                 query_context_type=QueryContextType.DataFrame,
-                pyspark_fragment="__eq__",
+                fragment="__eq__",
             )
 
             # DataFrameQueryContext with pysparkLoggingInfo - lt
@@ -137,7 +137,7 @@ def test_dataframe_query_context(self):
                     "ansiConfig": '"spark.sql.ansi.enabled"',
                 },
                 query_context_type=QueryContextType.DataFrame,
-                pyspark_fragment="lt",
+                fragment="__lt__",
             )
 
             # DataFrameQueryContext with pysparkLoggingInfo - leq
@@ -153,7 +153,7 @@ def test_dataframe_query_context(self):
                     "ansiConfig": '"spark.sql.ansi.enabled"',
                 },
                 query_context_type=QueryContextType.DataFrame,
-                pyspark_fragment="leq",
+                fragment="__le__",
             )
 
             # DataFrameQueryContext with pysparkLoggingInfo - geq
@@ -169,7 +169,7 @@ def test_dataframe_query_context(self):
                     "ansiConfig": '"spark.sql.ansi.enabled"',
                 },
                 query_context_type=QueryContextType.DataFrame,
-                pyspark_fragment="geq",
+                fragment="__ge__",
             )
 
             # DataFrameQueryContext with pysparkLoggingInfo - gt
@@ -185,7 +185,7 @@ def test_dataframe_query_context(self):
                     "ansiConfig": '"spark.sql.ansi.enabled"',
                 },
                 query_context_type=QueryContextType.DataFrame,
-                pyspark_fragment="gt",
+                fragment="__gt__",
             )
 
             # DataFrameQueryContext with pysparkLoggingInfo - eqNullSafe
@@ -201,7 +201,7 @@ def test_dataframe_query_context(self):
                     "ansiConfig": '"spark.sql.ansi.enabled"',
                 },
                 query_context_type=QueryContextType.DataFrame,
-                pyspark_fragment="eqNullSafe",
+                fragment="eqNullSafe",
             )
 
             # DataFrameQueryContext with pysparkLoggingInfo - bitwiseOR
@@ -217,7 +217,7 @@ def test_dataframe_query_context(self):
                     "ansiConfig": '"spark.sql.ansi.enabled"',
                 },
                 query_context_type=QueryContextType.DataFrame,
-                pyspark_fragment="bitwiseOR",
+                fragment="bitwiseOR",
             )
 
             # DataFrameQueryContext with pysparkLoggingInfo - bitwiseAND
@@ -233,7 +233,7 @@ def test_dataframe_query_context(self):
                     "ansiConfig": '"spark.sql.ansi.enabled"',
                 },
                 query_context_type=QueryContextType.DataFrame,
-                pyspark_fragment="bitwiseAND",
+                fragment="bitwiseAND",
             )
 
             # DataFrameQueryContext with pysparkLoggingInfo - bitwiseXOR
@@ -249,7 +249,7 @@ def test_dataframe_query_context(self):
                     "ansiConfig": '"spark.sql.ansi.enabled"',
                 },
                 query_context_type=QueryContextType.DataFrame,
-                pyspark_fragment="bitwiseXOR",
+                fragment="bitwiseXOR",
            )
 
             # DataFrameQueryContext with pysparkLoggingInfo - chained (`divide` is problematic)
@@ -262,7 +262,7 @@ def test_dataframe_query_context(self):
                 error_class="DIVIDE_BY_ZERO",
                 message_parameters={"config": '"spark.sql.ansi.enabled"'},
                 query_context_type=QueryContextType.DataFrame,
-                pyspark_fragment="divide",
+                fragment="__truediv__",
             )
 
             # DataFrameQueryContext with pysparkLoggingInfo - chained (`plus` is problematic)
@@ -282,7 +282,7 @@ def test_dataframe_query_context(self):
                     "ansiConfig": '"spark.sql.ansi.enabled"',
                 },
                 query_context_type=QueryContextType.DataFrame,
-                pyspark_fragment="plus",
+                fragment="__add__",
             )
 
             # DataFrameQueryContext with pysparkLoggingInfo - chained (`minus` is problematic)
@@ -302,7 +302,7 @@ def test_dataframe_query_context(self):
                     "ansiConfig": '"spark.sql.ansi.enabled"',
                 },
                 query_context_type=QueryContextType.DataFrame,
-                pyspark_fragment="minus",
+                fragment="__sub__",
             )
 
             # DataFrameQueryContext with pysparkLoggingInfo - chained (`multiply` is problematic)
@@ -320,7 +320,7 @@ def test_dataframe_query_context(self):
                     "ansiConfig": '"spark.sql.ansi.enabled"',
                 },
                 query_context_type=QueryContextType.DataFrame,
-                pyspark_fragment="multiply",
+                fragment="__mul__",
             )
 
             # Multiple expressions in df.select (`divide` is problematic)
@@ -331,7 +331,7 @@ def test_dataframe_query_context(self):
                 error_class="DIVIDE_BY_ZERO",
                 message_parameters={"config": '"spark.sql.ansi.enabled"'},
                 query_context_type=QueryContextType.DataFrame,
-                pyspark_fragment="divide",
+                fragment="__truediv__",
             )
 
             # Multiple expressions in df.select (`plus` is problematic)
@@ -347,7 +347,7 @@ def test_dataframe_query_context(self):
                     "ansiConfig": '"spark.sql.ansi.enabled"',
                 },
                 query_context_type=QueryContextType.DataFrame,
-                pyspark_fragment="plus",
+                fragment="__add__",
             )
 
             # Multiple expressions in df.select (`minus` is problematic)
@@ -363,7 +363,7 @@ def test_dataframe_query_context(self):
                     "ansiConfig": '"spark.sql.ansi.enabled"',
                 },
                 query_context_type=QueryContextType.DataFrame,
-                pyspark_fragment="minus",
+                fragment="__sub__",
             )
 
             # Multiple expressions in df.select (`multiply` is problematic)
@@ -379,7 +379,7 @@ def test_dataframe_query_context(self):
                     "ansiConfig": '"spark.sql.ansi.enabled"',
                 },
                 query_context_type=QueryContextType.DataFrame,
-                pyspark_fragment="multiply",
+                fragment="__mul__",
             )
 
             # Multiple expressions with pre-declared expressions (`divide` is problematic)
@@ -392,7 +392,7 @@ def test_dataframe_query_context(self):
                 error_class="DIVIDE_BY_ZERO",
                 message_parameters={"config": '"spark.sql.ansi.enabled"'},
                 query_context_type=QueryContextType.DataFrame,
-                pyspark_fragment="divide",
+                fragment="__truediv__",
             )
 
             # Multiple expressions with pre-declared expressions (`plus` is problematic)
@@ -410,7 +410,7 @@ def test_dataframe_query_context(self):
                     "ansiConfig": '"spark.sql.ansi.enabled"',
                 },
                 query_context_type=QueryContextType.DataFrame,
-                pyspark_fragment="plus",
+                fragment="__add__",
             )
 
             # Multiple expressions with pre-declared expressions (`minus` is problematic)
@@ -428,7 +428,7 @@ def test_dataframe_query_context(self):
                     "ansiConfig": '"spark.sql.ansi.enabled"',
                 },
                 query_context_type=QueryContextType.DataFrame,
-                pyspark_fragment="minus",
+                fragment="__sub__",
             )
 
             # Multiple expressions with pre-declared expressions (`multiply` is problematic)
@@ -446,20 +446,11 @@ def test_dataframe_query_context(self):
                     "ansiConfig": '"spark.sql.ansi.enabled"',
                 },
                 query_context_type=QueryContextType.DataFrame,
-                pyspark_fragment="multiply",
-            )
-
-            # DataFrameQueryContext without pysparkLoggingInfo
-            with self.assertRaises(AnalysisException) as pe:
-                df.select("non-existing-column")
-            self.check_error(
-                exception=pe.exception,
-                error_class="UNRESOLVED_COLUMN.WITH_SUGGESTION",
-                message_parameters={"objectName": "`non-existing-column`", "proposal": "`id`"},
-                query_context_type=QueryContextType.DataFrame,
-                pyspark_fragment="",
+                fragment="__mul__",
             )
 
+    def test_sql_query_context(self):
+        with self.sql_conf({"spark.sql.ansi.enabled": True}):
             # SQLQueryContext
             with self.assertRaises(ArithmeticException) as pe:
                 self.spark.sql("select 10/0").collect()
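Note: for reference, the fragment and call site that drive these DataFrame contexts travel with each Expression message (see the expressions_pb2.pyi stubs earlier in this change). A rough client-side sketch of populating that field follows; it is not part of the patch and the attribute and call-site values are made up for illustration.

from pyspark.sql.connect.proto import expressions_pb2 as pb2

# Build an unresolved attribute and tag it with a Python origin. The fragment is
# the name of the PySpark API that produced the expression; the call site is the
# user-facing stack information.
expr = pb2.Expression()
expr.unresolved_attribute.unparsed_identifier = "id"
expr.origin.python_origin.fragment = "__truediv__"
expr.origin.python_origin.call_site = "example.py:10 in main"

# These two strings are what later surface as the fragment and call site of the
# DataFrame query context when an error is raised for this expression.
print(expr.origin.python_origin.fragment)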
diff --git a/python/pyspark/testing/utils.py b/python/pyspark/testing/utils.py
index fa58b7286fe88..c74291524daed 100644
--- a/python/pyspark/testing/utils.py
+++ b/python/pyspark/testing/utils.py
@@ -287,7 +287,7 @@ def check_error(
         error_class: str,
         message_parameters: Optional[Dict[str, str]] = None,
         query_context_type: Optional[QueryContextType] = None,
-        pyspark_fragment: Optional[str] = None,
+        fragment: Optional[str] = None,
     ):
         query_context = exception.getQueryContext()
         assert bool(query_context) == (query_context_type is not None), (
@@ -326,10 +326,10 @@ def check_error(
             )
             if actual == QueryContextType.DataFrame:
                 assert (
-                    pyspark_fragment is not None
-                ), "`pyspark_fragment` is required when QueryContextType is DataFrame."
-                expected = pyspark_fragment
-                actual = actual_context.pysparkFragment()
+                    fragment is not None
+                ), "`fragment` is required when QueryContextType is DataFrame."
+                expected = fragment
+                actual = actual_context.fragment()
                 self.assertEqual(
                     expected,
                     actual,
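Note: with `check_error` taking `fragment` directly, the DataFrame branch compares against the unified QueryContext.fragment() accessor instead of the PySpark-specific pysparkFragment(). A condensed, self-contained rendering of that check (illustrative only, not part of the patch; the helper name is hypothetical):

from pyspark.errors import QueryContextType


def check_dataframe_fragment(exception, fragment):
    """Assert that the captured error carries a DataFrame query context for `fragment`."""
    contexts = exception.getQueryContext()
    assert contexts, "expected at least one query context"
    actual = contexts[0]
    assert actual.contextType() == QueryContextType.DataFrame
    assert actual.fragment() == fragment, f"expected {fragment!r}, got {actual.fragment()!r}"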
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/QueryContexts.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/QueryContexts.scala
index 1c2456f00bcdc..2b3f4674539e3 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/QueryContexts.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/QueryContexts.scala
@@ -145,36 +145,30 @@ case class DataFrameQueryContext(
   override def stopIndex: Int = throw SparkUnsupportedOperationException()
 
   override val fragment: String = {
-    stackTrace.headOption.map { firstElem =>
-      val methodName = firstElem.getMethodName
-      if (methodName.length > 1 && methodName(0) == '$') {
-        methodName.substring(1)
-      } else {
-        methodName
-      }
-    }.getOrElse("")
+    pysparkErrorContext.map(_._1).getOrElse {
+      stackTrace.headOption.map { firstElem =>
+        val methodName = firstElem.getMethodName
+        if (methodName.length > 1 && methodName(0) == '$') {
+          methodName.substring(1)
+        } else {
+          methodName
+        }
+      }.getOrElse("")
+    }
   }
 
-  override val callSite: String = stackTrace.tail.mkString("\n")
-
-  val pysparkFragment: String = pysparkErrorContext.map(_._1).getOrElse("")
-  val pysparkCallSite: String = pysparkErrorContext.map(_._2).getOrElse("")
-
-  val (displayedFragment, displayedCallsite) = if (pysparkErrorContext.nonEmpty) {
-    (pysparkFragment, pysparkCallSite)
-  } else {
-    (fragment, callSite)
-  }
+  override val callSite: String = pysparkErrorContext.map(
+    _._2).getOrElse(stackTrace.tail.mkString("\n"))
 
   override lazy val summary: String = {
     val builder = new StringBuilder
     builder ++= "== DataFrame ==\n"
     builder ++= "\""
-    builder ++= displayedFragment
+    builder ++= fragment
     builder ++= "\""
-    builder ++= " was called from\n"
-    builder ++= displayedCallsite
+    builder ++= " was called from\n"
+    builder ++= callSite
     builder += '\n'
     builder.result()
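Note: with the PySpark error context folded into fragment and callSite, the summary no longer needs the displayedFragment/displayedCallsite pair. A rough Python rendering of the resulting message shape, kept as a sketch rather than the actual implementation (the call-site value is illustrative):

def dataframe_summary(fragment: str, call_site: str) -> str:
    # Mirrors DataFrameQueryContext.summary above: quoted fragment, then the call site.
    return f'== DataFrame ==\n"{fragment}" was called from\n{call_site}\n'


print(dataframe_summary("__truediv__", "example.py:10 in main"))
# == DataFrame ==
# "__truediv__" was called from
# example.py:10 in main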