Commit 8c3194f

jonmio authored and hvanhovell committed
[SPARK-52348][CONNECT] Add support for Spark Connect handlers for pipeline commands
### What changes were proposed in this pull request?

- Introduces a `PipelinesHandler` which handles Spark Connect `PipelineCommand`s. This follows the pattern of `MLHandler`, where the `SparkConnectPlanner` delegates any ML commands to the `MLHandler`.
- Streams `PipelineEvent`s that are emitted during pipeline execution back to the Spark Connect client.
- Rethrows exceptions that occur during pipeline execution in the `StartRun` handler so that they are automatically propagated back to the Spark Connect client.

This PR builds off changes in a few open PRs. I have squashed those changes into a single commit at the top of this PR (49626fb2e6af0e5a2df5e3ad361f6e98b88ad297). **When reviewing, please ignore that commit and review only the commits after it.**

Misc changes:
- Convert the timestamp field in the `PipelineEvent` proto from `String` to `google.protobuf.Timestamp`.
- Remove references to `SerializedException` and `ErrorDetail` in favor of representing errors simply as `Throwable`.

### Why are the changes needed?

This change is needed to support Spark Declarative Pipelines.

### Does this PR introduce _any_ user-facing change?

Yes

### How was this patch tested?

New unit tests

### Was this patch authored or co-authored using generative AI tooling?

No

Closes #51057 from jonmio/sc-pipelines.

Authored-by: Jon Mio <jon.mio@databricks.com>
Signed-off-by: Herman van Hovell <herman@databricks.com>
1 parent 1a3ae66 commit 8c3194f
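For context before the diffs, here is a minimal sketch of the client-side flow this change enables: create a dataflow graph, start a run, and consume the `PipelineEvent`s streamed back over Spark Connect. The import path, the endpoint address, and the `None`/empty defaults are assumptions drawn from the Python changes below, not part of this commit.

```python
from pyspark.sql import SparkSession
from pyspark.pipelines.spark_connect_pipeline import (
    create_dataflow_graph,
    handle_pipeline_events,
    start_run,
)

# Connect to a Spark Connect endpoint (placeholder address).
spark = SparkSession.builder.remote("sc://localhost").create()

# Register a dataflow graph on the server; the returned id is used by later commands.
graph_id = create_dataflow_graph(
    spark,
    default_catalog=None,  # assumption: optional, mirroring the CLI's spec handling
    default_database=None,
    sql_conf={},
)

# StartRun returns an iterator: the first result acknowledges the command,
# later results each carry a streamed PipelineEvent.
result_iter = start_run(spark, graph_id)
try:
    handle_pipeline_events(result_iter)
except Exception as e:
    # Failures during pipeline execution are rethrown by the StartRun handler
    # and so surface on the client as an exception from the iterator.
    print(f"Pipeline run failed: {e}")
```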

30 files changed: +1917 −436 lines

common/utils/src/main/resources/error/error-conditions.json

Lines changed: 6 additions & 0 deletions
@@ -921,6 +921,12 @@
     },
     "sqlState" : "21S01"
   },
+  "DATAFLOW_GRAPH_NOT_FOUND" : {
+    "message" : [
+      "Dataflow graph with id <graphId> could not be found"
+    ],
+    "sqlState" : "KD011"
+  },
   "DATATYPE_MISMATCH" : {
     "message" : [
       "Cannot resolve <sqlExpr> due to data type mismatch:"

common/utils/src/main/resources/error/error-states.json

Lines changed: 6 additions & 0 deletions
@@ -7470,6 +7470,12 @@
     "standard": "N",
     "usedBy": ["Databricks"]
   },
+  "KD011": {
+    "description": "dataflow graph not found",
+    "origin": "Databricks",
+    "standard": "N",
+    "usedBy": ["Databricks"]
+  },
   "P0000": {
     "description": "procedural logic error",
     "origin": "PostgreSQL",

python/pyspark/pipelines/cli.py

Lines changed: 12 additions & 9 deletions
@@ -37,6 +37,7 @@
     GraphElementRegistry,
 )
 from pyspark.pipelines.init_cli import init
+from pyspark.pipelines.logging_utils import log_with_curr_timestamp
 from pyspark.pipelines.spark_connect_graph_element_registry import (
     SparkConnectGraphElementRegistry,
 )
@@ -163,14 +164,16 @@ def register_definitions(
     path = spec_path.parent
     with change_dir(path):
         with graph_element_registration_context(registry):
-            print(f"Loading definitions. Root directory: '{path}'.")
+            log_with_curr_timestamp(f"Loading definitions. Root directory: '{path}'.")
             for definition_glob in spec.definitions:
                 glob_expression = definition_glob.include
                 matching_files = [p for p in path.glob(glob_expression) if p.is_file()]
-                print(f"Found {len(matching_files)} files matching glob '{glob_expression}'")
+                log_with_curr_timestamp(
+                    f"Found {len(matching_files)} files matching glob '{glob_expression}'"
+                )
                 for file in matching_files:
                     if file.suffix == ".py":
-                        print(f"Importing {file}...")
+                        log_with_curr_timestamp(f"Importing {file}...")
                         module_spec = importlib.util.spec_from_file_location(file.stem, str(file))
                         assert module_spec is not None, f"Could not find module spec for {file}"
                         module = importlib.util.module_from_spec(module_spec)
@@ -179,7 +182,7 @@ def register_definitions(
                         ), f"Module spec has no loader for {file}"
                         module_spec.loader.exec_module(module)
                     elif file.suffix == ".sql":
-                        print(f"Registering SQL file {file}...")
+                        log_with_curr_timestamp(f"Registering SQL file {file}...")
                         with file.open("r") as f:
                             sql = f.read()
                         file_path_relative_to_spec = file.relative_to(spec_path.parent)
@@ -204,29 +207,29 @@ def change_dir(path: Path) -> Generator[None, None, None]:

 def run(spec_path: Path, remote: str) -> None:
     """Run the pipeline defined with the given spec."""
-    print(f"Loading pipeline spec from {spec_path}...")
+    log_with_curr_timestamp(f"Loading pipeline spec from {spec_path}...")
     spec = load_pipeline_spec(spec_path)

-    print("Creating Spark session...")
+    log_with_curr_timestamp("Creating Spark session...")
     spark_builder = SparkSession.builder.remote(remote)
     for key, value in spec.configuration.items():
         spark_builder = spark_builder.config(key, value)

     spark = spark_builder.create()

-    print("Creating dataflow graph...")
+    log_with_curr_timestamp("Creating dataflow graph...")
     dataflow_graph_id = create_dataflow_graph(
         spark,
         default_catalog=spec.catalog,
         default_database=spec.database,
         sql_conf=spec.configuration,
     )

-    print("Registering graph elements...")
+    log_with_curr_timestamp("Registering graph elements...")
     registry = SparkConnectGraphElementRegistry(spark, dataflow_graph_id)
     register_definitions(spec_path, registry, spec)

-    print("Starting run...")
+    log_with_curr_timestamp("Starting run...")
     result_iter = start_run(spark, dataflow_graph_id)
     try:
         handle_pipeline_events(result_iter)
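As a hedged usage example of the `run` entry point changed above: the spec file name and Spark Connect endpoint below are placeholders, not values from this commit.

```python
from pathlib import Path

from pyspark.pipelines.cli import run

# Drives the full flow above: load the spec, create a Connect session, register
# the dataflow graph and its elements, then start the run and stream its events.
run(spec_path=Path("pipeline.yml"), remote="sc://localhost")
```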
python/pyspark/pipelines/logging_utils.py

Lines changed: 42 additions & 0 deletions

@@ -0,0 +1,42 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from datetime import datetime
+
+
+def log_with_curr_timestamp(message: str) -> None:
+    """
+    Print a log message with a formatted timestamp corresponding to the current time. Note that
+    currently only the UTC timezone is supported, but we plan to eventually support the timezone
+    specified by the SESSION_LOCAL_TIMEZONE SQL conf.
+
+    Args:
+        message (str): The message to log
+    """
+    log_with_provided_timestamp(message, datetime.now())
+
+
+def log_with_provided_timestamp(message: str, timestamp: datetime) -> None:
+    """
+    Print a log message with a formatted timestamp prefix.
+
+    Args:
+        message (str): The message to log
+        timestamp (datetime): The timestamp to use for the log message.
+    """
+    formatted_timestamp = timestamp.strftime("%Y-%m-%d %H:%M:%S")
+    print(f"{formatted_timestamp}: {message}")

python/pyspark/pipelines/spark_connect_graph_element_registry.py

Lines changed: 1 addition & 1 deletion
@@ -102,7 +102,7 @@ def register_flow(self, flow: Flow) -> None:
         self._client.execute_command(command)

     def register_sql(self, sql_text: str, file_path: Path) -> None:
-        inner_command = pb2.DefineSqlGraphElements(
+        inner_command = pb2.PipelineCommand.DefineSqlGraphElements(
             dataflow_graph_id=self._dataflow_graph_id,
             sql_text=sql_text,
             sql_file_path=str(file_path),

python/pyspark/pipelines/spark_connect_pipeline.py

Lines changed: 6 additions & 3 deletions
@@ -14,11 +14,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+from datetime import timezone
 from typing import Any, Dict, Mapping, Iterator, Optional, cast

 import pyspark.sql.connect.proto as pb2
 from pyspark.sql import SparkSession
 from pyspark.errors.exceptions.base import PySparkValueError
+from pyspark.pipelines.logging_utils import log_with_provided_timestamp


 def create_dataflow_graph(
@@ -53,13 +55,14 @@ def handle_pipeline_events(iter: Iterator[Dict[str, Any]]) -> None:
             # We expect to get a pipeline_command_result back in response to the initial StartRun
             # command.
             continue
-        elif "pipeline_events_result" not in result.keys():
+        elif "pipeline_event_result" not in result.keys():
            raise PySparkValueError(
                "Pipeline logs stream handler received an unexpected result: " f"{result}"
            )
         else:
-            for e in result["pipeline_events_result"].events:
-                print(f"{e.timestamp}: {e.message}")
+            event = result["pipeline_event_result"].event
+            dt = event.timestamp.ToDatetime().replace(tzinfo=timezone.utc)
+            log_with_provided_timestamp(event.message, dt)


 def start_run(spark: SparkSession, dataflow_graph_id: str) -> Iterator[Dict[str, Any]]:
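The timestamp handling above relies on the `PipelineEvent` proto field now being a `google.protobuf.Timestamp`. A minimal sketch of that conversion, using only the protobuf runtime (values are illustrative): `ToDatetime()` yields a naive UTC datetime, so the code tags it with `timezone.utc` before logging.

```python
from datetime import datetime, timezone

from google.protobuf.timestamp_pb2 import Timestamp

ts = Timestamp()
ts.FromDatetime(datetime(2025, 6, 2, 14, 5, 31, tzinfo=timezone.utc))

# ToDatetime() returns a naive datetime in UTC; attach the timezone explicitly.
dt = ts.ToDatetime().replace(tzinfo=timezone.utc)
print(dt.isoformat())  # 2025-06-02T14:05:31+00:00
```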

python/pyspark/sql/connect/client/core.py

Lines changed: 4 additions & 0 deletions
@@ -1474,6 +1474,10 @@ def handle_response(
             if b.HasField("streaming_query_listener_events_result"):
                 event_result = b.streaming_query_listener_events_result
                 yield {"streaming_query_listener_events_result": event_result}
+            if b.HasField("pipeline_command_result"):
+                yield {"pipeline_command_result": b.pipeline_command_result}
+            if b.HasField("pipeline_event_result"):
+                yield {"pipeline_event_result": b.pipeline_event_result}
             if b.HasField("get_resources_command_result"):
                 resources = {}
                 for key, resource in b.get_resources_command_result.resources.items():
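A hedged sketch of how a caller can consume the dicts yielded above: the one-off `pipeline_command_result` acknowledges the initial StartRun, while each `pipeline_event_result` carries a single streamed event. The loop mirrors `handle_pipeline_events` and assumes an iterator of response dicts like the one `handle_response` produces.

```python
from typing import Any, Dict, Iterator


def consume(results: Iterator[Dict[str, Any]]) -> None:
    for result in results:
        if "pipeline_command_result" in result:
            # One-off response to the initial StartRun command; nothing to print.
            continue
        if "pipeline_event_result" in result:
            # Each streamed PipelineEvent carries a message and a proto timestamp.
            event = result["pipeline_event_result"].event
            print(event.message)
```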
