[SPARK-52348][CONNECT] Add support for Spark Connect handlers for pipeline commands #51057

Closed · wants to merge 14 commits
6 changes: 6 additions & 0 deletions common/utils/src/main/resources/error/error-conditions.json
@@ -921,6 +921,12 @@
},
"sqlState" : "21S01"
},
"DATAFLOW_GRAPH_NOT_FOUND" : {
"message" : [
"Dataflow graph with id <graphId> could not be found"
],
"sqlState" : "KD011"
},
"DATATYPE_MISMATCH" : {
"message" : [
"Cannot resolve <sqlExpr> due to data type mismatch:"
6 changes: 6 additions & 0 deletions common/utils/src/main/resources/error/error-states.json
@@ -7470,6 +7470,12 @@
"standard": "N",
"usedBy": ["Databricks"]
},
"KD011": {
"description": "dataflow graph not found",
"origin": "Databricks",
"standard": "N",
"usedBy": ["Databricks"]
},
"P0000": {
"description": "procedural logic error",
"origin": "PostgreSQL",
21 changes: 12 additions & 9 deletions python/pyspark/pipelines/cli.py
@@ -37,6 +37,7 @@
GraphElementRegistry,
)
from pyspark.pipelines.init_cli import init
from pyspark.pipelines.logging_utils import log_with_curr_timestamp
from pyspark.pipelines.spark_connect_graph_element_registry import (
SparkConnectGraphElementRegistry,
)
@@ -163,14 +164,16 @@ def register_definitions(
path = spec_path.parent
with change_dir(path):
with graph_element_registration_context(registry):
print(f"Loading definitions. Root directory: '{path}'.")
log_with_curr_timestamp(f"Loading definitions. Root directory: '{path}'.")
for definition_glob in spec.definitions:
glob_expression = definition_glob.include
matching_files = [p for p in path.glob(glob_expression) if p.is_file()]
print(f"Found {len(matching_files)} files matching glob '{glob_expression}'")
log_with_curr_timestamp(
f"Found {len(matching_files)} files matching glob '{glob_expression}'"
)
for file in matching_files:
if file.suffix == ".py":
print(f"Importing {file}...")
log_with_curr_timestamp(f"Importing {file}...")
module_spec = importlib.util.spec_from_file_location(file.stem, str(file))
assert module_spec is not None, f"Could not find module spec for {file}"
module = importlib.util.module_from_spec(module_spec)
@@ -179,7 +182,7 @@ def register_definitions(
), f"Module spec has no loader for {file}"
module_spec.loader.exec_module(module)
elif file.suffix == ".sql":
print(f"Registering SQL file {file}...")
log_with_curr_timestamp(f"Registering SQL file {file}...")
with file.open("r") as f:
sql = f.read()
file_path_relative_to_spec = file.relative_to(spec_path.parent)
@@ -204,29 +207,29 @@ def change_dir(path: Path) -> Generator[None, None, None]:

def run(spec_path: Path, remote: str) -> None:
"""Run the pipeline defined with the given spec."""
print(f"Loading pipeline spec from {spec_path}...")
log_with_curr_timestamp(f"Loading pipeline spec from {spec_path}...")
spec = load_pipeline_spec(spec_path)

print("Creating Spark session...")
log_with_curr_timestamp("Creating Spark session...")
spark_builder = SparkSession.builder.remote(remote)
for key, value in spec.configuration.items():
spark_builder = spark_builder.config(key, value)

spark = spark_builder.create()

print("Creating dataflow graph...")
log_with_curr_timestamp("Creating dataflow graph...")
dataflow_graph_id = create_dataflow_graph(
spark,
default_catalog=spec.catalog,
default_database=spec.database,
sql_conf=spec.configuration,
)

print("Registering graph elements...")
log_with_curr_timestamp("Registering graph elements...")
registry = SparkConnectGraphElementRegistry(spark, dataflow_graph_id)
register_definitions(spec_path, registry, spec)

print("Starting run...")
log_with_curr_timestamp("Starting run...")
result_iter = start_run(spark, dataflow_graph_id)
try:
handle_pipeline_events(result_iter)
42 changes: 42 additions & 0 deletions python/pyspark/pipelines/logging_utils.py
@@ -0,0 +1,42 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from datetime import datetime


def log_with_curr_timestamp(message: str) -> None:
"""
Print a log message with a formatted timestamp corresponding to the current time. Note that
currently only the UTC timezone is supported, but we plan to eventually support the timezone
specified by the SESSION_LOCAL_TIMEZONE SQL conf.

Args:
message (str): The message to log
"""
log_with_provided_timestamp(message, datetime.now())


def log_with_provided_timestamp(message: str, timestamp: datetime) -> None:
"""
Print a log message with a formatted timestamp prefix.

Args:
message (str): The message to log
timestamp (datetime): The timestamp to use for the log message.
"""
formatted_timestamp = timestamp.strftime("%Y-%m-%d %H:%M:%S")
print(f"{formatted_timestamp}: {message}")
2 changes: 1 addition & 1 deletion python/pyspark/pipelines/spark_connect_graph_element_registry.py
@@ -102,7 +102,7 @@ def register_flow(self, flow: Flow) -> None:
self._client.execute_command(command)

def register_sql(self, sql_text: str, file_path: Path) -> None:
inner_command = pb2.DefineSqlGraphElements(
inner_command = pb2.PipelineCommand.DefineSqlGraphElements(
dataflow_graph_id=self._dataflow_graph_id,
sql_text=sql_text,
sql_file_path=str(file_path),
9 changes: 6 additions & 3 deletions python/pyspark/pipelines/spark_connect_pipeline.py
@@ -14,11 +14,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from datetime import timezone
from typing import Any, Dict, Mapping, Iterator, Optional, cast

import pyspark.sql.connect.proto as pb2
from pyspark.sql import SparkSession
from pyspark.errors.exceptions.base import PySparkValueError
from pyspark.pipelines.logging_utils import log_with_provided_timestamp


def create_dataflow_graph(
@@ -53,13 +55,14 @@ def handle_pipeline_events(iter: Iterator[Dict[str, Any]]) -> None:
# We expect to get a pipeline_command_result back in response to the initial StartRun
# command.
continue
elif "pipeline_events_result" not in result.keys():
elif "pipeline_event_result" not in result.keys():
raise PySparkValueError(
"Pipeline logs stream handler received an unexpected result: " f"{result}"
)
else:
for e in result["pipeline_events_result"].events:
print(f"{e.timestamp}: {e.message}")
event = result["pipeline_event_result"].event
dt = event.timestamp.ToDatetime().replace(tzinfo=timezone.utc)
log_with_provided_timestamp(event.message, dt)


def start_run(spark: SparkSession, dataflow_graph_id: str) -> Iterator[Dict[str, Any]]:
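
For reference, a standalone sketch of the timestamp handling in handle_pipeline_events above: protobuf's Timestamp.ToDatetime() returns a naive datetime in UTC, so the handler attaches timezone.utc before formatting. The event message and timestamp value below are illustrative only.

```python
from datetime import timezone

from google.protobuf.timestamp_pb2 import Timestamp

from pyspark.pipelines.logging_utils import log_with_provided_timestamp

# Build an event-style timestamp (illustrative value).
ts = Timestamp()
ts.FromJsonString("2025-06-02T14:03:07Z")

# ToDatetime() yields a naive UTC datetime; tag it with timezone.utc explicitly.
dt = ts.ToDatetime().replace(tzinfo=timezone.utc)
log_with_provided_timestamp("Flow 'events' is RUNNING.", dt)
# prints: 2025-06-02 14:03:07: Flow 'events' is RUNNING.
```
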
4 changes: 4 additions & 0 deletions python/pyspark/sql/connect/client/core.py
@@ -1474,6 +1474,10 @@ def handle_response(
if b.HasField("streaming_query_listener_events_result"):
event_result = b.streaming_query_listener_events_result
yield {"streaming_query_listener_events_result": event_result}
if b.HasField("pipeline_command_result"):
yield {"pipeline_command_result": b.pipeline_command_result}
if b.HasField("pipeline_event_result"):
yield {"pipeline_event_result": b.pipeline_event_result}
if b.HasField("get_resources_command_result"):
resources = {}
for key, resource in b.get_resources_command_result.resources.items():