Update Spark submit operator for Spark 3 support (#8730)
In Spark 3 the exit code is logged with a lowercase "e" ("exit code: ..."), whereas Spark 2 used an uppercase "E" ("Exit code: ...").

Also made the exception message a bit clearer when running on Kubernetes.
stijndehaes committed Jul 22, 2020
1 parent 040fb1d commit 1427e4a
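For illustration only (a minimal sketch, not part of the commit): the updated pattern r'\s*[eE]xit code: (\d+)' accepts both spellings of the log line, so the hook keeps working on Spark 2 while gaining Spark 3 support.

import re

# Pattern taken from the updated hook: matches "Exit code" (Spark 2) and "exit code" (Spark 3).
exit_code_pattern = re.compile(r'\s*[eE]xit code: (\d+)')

for line in ['Exit code: 999', 'exit code: 999']:  # Spark 2 vs. Spark 3 style log lines
    match = exit_code_pattern.search(line)
    if match:
        print(int(match.groups()[0]))  # prints 999 for both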
Showing 2 changed files with 25 additions and 5 deletions.
17 changes: 12 additions & 5 deletions airflow/providers/apache/spark/hooks/spark_submit.py
@@ -423,11 +423,18 @@ def submit(self, application: str = "", **kwargs: Any) -> None:
         # Check spark-submit return code. In Kubernetes mode, also check the value
         # of exit code in the log, as it may differ.
         if returncode or (self._is_kubernetes and self._spark_exit_code != 0):
-            raise AirflowException(
-                "Cannot execute: {}. Error code is: {}.".format(
-                    self._mask_cmd(spark_submit_cmd), returncode
+            if self._is_kubernetes:
+                raise AirflowException(
+                    "Cannot execute: {}. Error code is: {}. Kubernetes spark exit code is: {}".format(
+                        self._mask_cmd(spark_submit_cmd), returncode, self._spark_exit_code
+                    )
+                )
+            else:
+                raise AirflowException(
+                    "Cannot execute: {}. Error code is: {}.".format(
+                        self._mask_cmd(spark_submit_cmd), returncode
+                    )
                 )
-            )
 
         self.log.debug("Should track driver: %s", self._should_track_driver_status)

@@ -485,7 +492,7 @@ def _process_spark_submit_log(self, itr: Iterator[Any]) -> None:
                                   self._kubernetes_driver_pod)

                 # Store the Spark Exit code
-                match_exit_code = re.search(r'\s*Exit code: (\d+)', line)
+                match_exit_code = re.search(r'\s*[eE]xit code: (\d+)', line)
                 if match_exit_code:
                     self._spark_exit_code = int(match_exit_code.groups()[0])

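For context, a minimal sketch (not part of the commit, using made-up values) of why the Kubernetes branch reports a second code: in Kubernetes mode spark-submit itself can exit with 0 even though the driver pod failed, so the hook also surfaces the exit code it parsed from the log.

# Hypothetical values; the names mirror the hook's internals.
returncode = 0          # spark-submit process exited cleanly
spark_exit_code = 999   # but the driver reported a failure in its log
is_kubernetes = True

if returncode or (is_kubernetes and spark_exit_code != 0):
    if is_kubernetes:
        message = "Cannot execute: {}. Error code is: {}. Kubernetes spark exit code is: {}".format(
            "spark-submit <masked command>", returncode, spark_exit_code
        )
    else:
        message = "Cannot execute: {}. Error code is: {}.".format(
            "spark-submit <masked command>", returncode
        )
    print(message)  # the hook raises AirflowException with this message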
13 changes: 13 additions & 0 deletions tests/providers/apache/spark/hooks/test_spark_submit.py
@@ -612,6 +612,19 @@ def test_process_spark_submit_log_k8s(self):
                          'spark-pi-edf2ace37be7353a958b38733a12f8e6-driver')
         self.assertEqual(hook._spark_exit_code, 999)

+    def test_process_spark_submit_log_k8s_spark_3(self):
+        # Given
+        hook = SparkSubmitHook(conn_id='spark_k8s_cluster')
+        log_lines = [
+            'exit code: 999'
+        ]
+
+        # When
+        hook._process_spark_submit_log(log_lines)
+
+        # Then
+        self.assertEqual(hook._spark_exit_code, 999)
+
     def test_process_spark_submit_log_standalone_cluster(self):
         # Given
         hook = SparkSubmitHook(conn_id='spark_standalone_cluster')
