Add spark3-submit to list of allowed spark-binary values (#30068)
* Add spark3-submit to list of allowed spark-binary values

The list of allowed values for spark-binary was restricted in
#27646.  Add spark3-submit to this list to allow for distributions
of Spark 3 that install the binary this way.

See also #30065.

* Fix lint errors in spark.rst and test_spark_submit.py
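With this change in place, a connection can point the hook at a Spark 3 distribution's binary. A minimal sketch, mirroring the test setup added in this commit (the conn_id "spark3_yarn" is an illustrative name, not from the commit):

    from airflow.models import Connection
    from airflow.utils import db

    # Register a Spark connection whose submit binary is spark3-submit.
    # Any value outside ALLOWED_SPARK_BINARIES is rejected by the hook.
    db.merge_conn(
        Connection(
            conn_id="spark3_yarn",  # illustrative connection id
            conn_type="spark",
            host="yarn",
            extra='{"spark-binary": "spark3-submit"}',
        )
    )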
ottomata committed Mar 15, 2023
1 parent 297f4b3 commit b325987
Showing 4 changed files with 38 additions and 10 deletions.
11 changes: 6 additions & 5 deletions airflow/providers/apache/spark/hooks/spark_submit.py
@@ -33,7 +33,7 @@
 with contextlib.suppress(ImportError, NameError):
     from airflow.kubernetes import kube_client

-ALLOWED_SPARK_BINARIES = ["spark-submit", "spark2-submit"]
+ALLOWED_SPARK_BINARIES = ["spark-submit", "spark2-submit", "spark3-submit"]


 class SparkSubmitHook(BaseHook, LoggingMixin):
@@ -78,7 +78,7 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
         supports yarn and k8s mode too.
     :param verbose: Whether to pass the verbose flag to spark-submit process for debugging
     :param spark_binary: The command to use for spark submit.
-        Some distros may use spark2-submit.
+        Some distros may use spark2-submit or spark3-submit.
     """

     conn_name_attr = "conn_id"
@@ -206,15 +206,16 @@ def _resolve_connection(self) -> dict[str, Any]:
             spark_binary = self._spark_binary or extra.get("spark-binary", "spark-submit")
             if spark_binary not in ALLOWED_SPARK_BINARIES:
                 raise RuntimeError(
-                    f"The `spark-binary` extra can be on of {ALLOWED_SPARK_BINARIES} and it"
+                    f"The `spark-binary` extra can be one of {ALLOWED_SPARK_BINARIES} and it"
                     f" was `{spark_binary}`. Please make sure your spark binary is one of the"
                     " allowed ones and that it is available on the PATH"
                 )
             conn_spark_home = extra.get("spark-home")
             if conn_spark_home:
                 raise RuntimeError(
-                    "The `spark-home` extra is not allowed any more. Please make sure your `spark-submit` or"
-                    " `spark2-submit` are available on the PATH."
+                    "The `spark-home` extra is not allowed any more. Please make sure one of"
+                    f" {ALLOWED_SPARK_BINARIES} is available on the PATH, and set `spark-binary`"
+                    " if needed."
                 )
             conn_data["spark_binary"] = spark_binary
             conn_data["namespace"] = extra.get("namespace")
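The allow-list is enforced while the hook resolves its connection, which happens in the hook's constructor, so a disallowed binary fails fast (the tests below exercise exactly this). A minimal sketch, assuming the illustrative "spark3_yarn" connection from the earlier example:

    from airflow.providers.apache.spark.hooks.spark_submit import SparkSubmitHook

    # spark3-submit now passes the ALLOWED_SPARK_BINARIES check.
    hook = SparkSubmitHook(conn_id="spark3_yarn", spark_binary="spark3-submit")

    # Anything outside the allow-list still raises RuntimeError.
    try:
        SparkSubmitHook(conn_id="spark3_yarn", spark_binary="my-custom-submit")
    except RuntimeError as err:
        print(err)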
2 changes: 1 addition & 1 deletion airflow/providers/apache/spark/operators/spark_submit.py
@@ -69,7 +69,7 @@ class SparkSubmitOperator(BaseOperator):
     :param env_vars: Environment variables for spark-submit. It supports yarn and k8s mode too. (templated)
     :param verbose: Whether to pass the verbose flag to spark-submit process for debugging
     :param spark_binary: The command to use for spark submit.
-        Some distros may use spark2-submit.
+        Some distros may use spark2-submit or spark3-submit.
     """

     template_fields: Sequence[str] = (
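At the DAG level, the new value is passed through the operator's spark_binary parameter. A minimal sketch (task id, application path, and connection id are illustrative):

    from airflow.providers.apache.spark.operators.spark_submit import SparkSubmitOperator

    submit_job = SparkSubmitOperator(
        task_id="submit_spark3_job",    # illustrative task id
        application="/path/to/app.py",  # illustrative application path
        conn_id="spark3_yarn",          # illustrative connection id
        spark_binary="spark3-submit",   # accepted after this change
    )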
2 changes: 1 addition & 1 deletion docs/apache-airflow-providers-apache-spark/connections/spark.rst
@@ -42,7 +42,7 @@ Extra (optional)

 * ``queue`` - The name of the YARN queue to which the application is submitted.
 * ``deploy-mode`` - Whether to deploy your driver on the worker nodes (cluster) or locally as an external client (client).
-* ``spark-binary`` - The command to use for Spark submit. Some distros may use ``spark2-submit``. Default ``spark-submit``. Only ``spark-submit`` and ``spark2-submit`` are allowed as value.
+* ``spark-binary`` - The command to use for Spark submit. Some distros may use ``spark2-submit``. Default ``spark-submit``. Only ``spark-submit``, ``spark2-submit`` or ``spark3-submit`` are allowed as value.
 * ``namespace`` - Kubernetes namespace (``spark.kubernetes.namespace``) to divide cluster resources between multiple users (via resource quota).

 When specifying the connection in environment variable you should specify
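For the environment-variable form mentioned in the docs above, the extra travels as a URI query parameter. A hedged sketch in Python (the variable name and URI are illustrative; the spark-binary key matches the extra documented here):

    import os

    # Equivalent to a Connection whose extra is '{"spark-binary": "spark3-submit"}'.
    os.environ["AIRFLOW_CONN_SPARK3_YARN"] = "spark://yarn?spark-binary=spark3-submit"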
33 changes: 30 additions & 3 deletions tests/providers/apache/spark/hooks/test_spark_submit.py
@@ -102,6 +102,14 @@ def setup_method(self):
                 extra='{"spark-binary": "spark2-submit"}',
             )
         )
+        db.merge_conn(
+            Connection(
+                conn_id="spark_binary_set_spark3_submit",
+                conn_type="spark",
+                host="yarn",
+                extra='{"spark-binary": "spark3-submit"}',
+            )
+        )
         db.merge_conn(
             Connection(
                 conn_id="spark_custom_binary_set",
@@ -434,6 +442,25 @@ def test_resolve_connection_spark_binary_set_connection(self):
         assert connection == expected_spark_connection
         assert cmd[0] == "spark2-submit"

+    def test_resolve_connection_spark_binary_spark3_submit_set_connection(self):
+        # Given
+        hook = SparkSubmitHook(conn_id="spark_binary_set_spark3_submit")
+
+        # When
+        connection = hook._resolve_connection()
+        cmd = hook._build_spark_submit_command(self._spark_job_file)
+
+        # Then
+        expected_spark_connection = {
+            "master": "yarn",
+            "spark_binary": "spark3-submit",
+            "deploy_mode": None,
+            "queue": None,
+            "namespace": None,
+        }
+        assert connection == expected_spark_connection
+        assert cmd[0] == "spark3-submit"
+
     def test_resolve_connection_custom_spark_binary_not_allowed_runtime_error(self):
         with pytest.raises(RuntimeError):
             SparkSubmitHook(conn_id="spark_binary_set", spark_binary="another-custom-spark-submit")
@@ -448,7 +475,7 @@ def test_resolve_connection_spark_home_not_allowed_runtime_error(self):

     def test_resolve_connection_spark_binary_default_value_override(self):
         # Given
-        hook = SparkSubmitHook(conn_id="spark_binary_set", spark_binary="spark2-submit")
+        hook = SparkSubmitHook(conn_id="spark_binary_set", spark_binary="spark3-submit")

         # When
         connection = hook._resolve_connection()
@@ -457,13 +484,13 @@ def test_resolve_connection_spark_binary_default_value_override(self):
         # Then
         expected_spark_connection = {
             "master": "yarn",
-            "spark_binary": "spark2-submit",
+            "spark_binary": "spark3-submit",
             "deploy_mode": None,
             "queue": None,
             "namespace": None,
         }
         assert connection == expected_spark_connection
-        assert cmd[0] == "spark2-submit"
+        assert cmd[0] == "spark3-submit"

     def test_resolve_connection_spark_binary_default_value(self):
         # Given
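The new and updated tests can be exercised directly. A hedged sketch using pytest's Python API (the path assumes the repository layout shown above; running plain pytest from the shell works the same way):

    import pytest

    # Run only the spark3-related tests in the spark-submit hook test module.
    pytest.main(["tests/providers/apache/spark/hooks/test_spark_submit.py", "-k", "spark3"])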
