Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change conf property from str to dict in SparkSqlOperator #40527

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
10 changes: 7 additions & 3 deletions airflow/providers/apache/spark/hooks/spark_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class SparkSqlHook(BaseHook):
This hook is a wrapper around the spark-sql binary; requires the "spark-sql" binary to be in the PATH.

:param sql: The SQL query to execute
:param conf: arbitrary Spark configuration property
:param conf: arbitrary Spark configuration properties
:param conn_id: connection_id string
:param total_executor_cores: (Standalone & Mesos only) Total cores for all executors
(Default: all the available cores on the worker)
Expand Down Expand Up @@ -82,7 +82,7 @@ def get_connection_form_widgets(cls) -> dict[str, Any]:
def __init__(
self,
sql: str,
conf: str | None = None,
conf: dict[str, Any] | str | None = None,
conn_id: str = default_conn_name,
total_executor_cores: int | None = None,
executor_cores: int | None = None,
Expand Down Expand Up @@ -142,7 +142,11 @@ def _prepare_command(self, cmd: str | list[str]) -> list[str]:
:return: full command to be executed
"""
connection_cmd = ["spark-sql"]
if self._conf:
if isinstance(self._conf, dict):
for conf_key, conf_val in self._conf.items():
connection_cmd += ["--conf", f"{conf_key}={conf_val}"]
# Keep compatibility with older versions
elif isinstance(self._conf, str):
for conf_el in self._conf.split(","):
connection_cmd += ["--conf", conf_el]
if self._total_executor_cores:
Expand Down
4 changes: 2 additions & 2 deletions airflow/providers/apache/spark/operators/spark_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class SparkSqlOperator(BaseOperator):
:ref:`howto/operator:SparkSqlOperator`

:param sql: The SQL query to execute. (templated)
:param conf: arbitrary Spark configuration property
:param conf: arbitrary Spark configuration properties
:param conn_id: connection_id string
:param total_executor_cores: (Standalone & Mesos only) Total cores for all
executors (Default: all the available cores on the worker)
Expand All @@ -63,7 +63,7 @@ def __init__(
self,
*,
sql: str,
conf: str | None = None,
conf: dict[str, Any] | str | None = None,
conn_id: str = "spark_sql_default",
total_executor_cores: int | None = None,
executor_cores: int | None = None,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
.. Licensed to the Apache Software Foundation (ASF) under one
.. Licensed to the Apache Software Foundation (ASF) under one
or more contributor license agreements. See the NOTICE file
distributed with this work for additional information
regarding copyright ownership. The ASF licenses this file
Expand Down
5 changes: 2 additions & 3 deletions tests/providers/apache/spark/hooks/test_spark_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class TestSparkSqlHook:
"num_executors": 10,
"verbose": True,
"sql": " /path/to/sql/file.sql ",
"conf": "key=value,PROP=VALUE",
"conf": {"key": "value", "PROP": "VALUE"},
}

@classmethod
Expand Down Expand Up @@ -77,8 +77,7 @@ def test_build_command(self):
assert self._config["sql"].strip() == sql_path

# Check if all config settings are there
for key_value in self._config["conf"].split(","):
k, v = key_value.split("=")
for k, v in self._config["conf"].items():
assert f"--conf {k}={v}" in cmd

if self._config["verbose"]:
Expand Down
Loading