From 13a827d80fef738e25f30ea20c095ad4dbd401f6 Mon Sep 17 00:00:00 2001
From: Aneesh Joseph
Date: Thu, 9 Jul 2020 15:09:16 +0530
Subject: [PATCH] Ensure Kerberos token is valid in SparkSubmitOperator before
 running `yarn kill` (#9044)

Do a kinit before `yarn kill` if a keytab and principal are provided.
---
 .../apache/spark/hooks/spark_submit.py      | 16 ++++++++++---
 airflow/security/kerberos.py                | 16 +++++++++----
 .../apache/spark/hooks/test_spark_submit.py | 23 +++++++++++++++++--
 3 files changed, 46 insertions(+), 9 deletions(-)

diff --git a/airflow/providers/apache/spark/hooks/spark_submit.py b/airflow/providers/apache/spark/hooks/spark_submit.py
index d94be2cc023cc..f7cf839ef63a5 100644
--- a/airflow/providers/apache/spark/hooks/spark_submit.py
+++ b/airflow/providers/apache/spark/hooks/spark_submit.py
@@ -21,8 +21,10 @@
 import subprocess
 import time
 
+from airflow.configuration import conf as airflow_conf
 from airflow.exceptions import AirflowException
 from airflow.hooks.base_hook import BaseHook
+from airflow.security.kerberos import renew_from_kt
 from airflow.utils.log.logging_mixin import LoggingMixin
 
 try:
@@ -617,15 +619,23 @@ def on_kill(self):
             self._submit_sp.kill()
 
             if self._yarn_application_id:
-                self.log.info('Killing application %s on YARN', self._yarn_application_id)
-
                 kill_cmd = "yarn application -kill {}" \
                     .format(self._yarn_application_id).split()
+                env = None
+                if self._keytab is not None and self._principal is not None:
+                    # we are ignoring renewal failures from renew_from_kt
+                    # here as the failure could just be due to a non-renewable ticket,
+                    # we still attempt to kill the yarn application
+                    renew_from_kt(self._principal, self._keytab, exit_on_fail=False)
+                    env = os.environ.copy()
+                    env["KRB5CCNAME"] = airflow_conf.get('kerberos', 'ccache')
+
                 yarn_kill = subprocess.Popen(kill_cmd,
+                                             env=env,
                                              stdout=subprocess.PIPE,
                                              stderr=subprocess.PIPE)
 
-                self.log.info("YARN killed with return code: %s", yarn_kill.wait())
+                self.log.info("YARN app killed with return code: %s", yarn_kill.wait())
 
             if self._kubernetes_driver_pod:
                 self.log.info('Killing pod %s on Kubernetes', self._kubernetes_driver_pod)
diff --git a/airflow/security/kerberos.py b/airflow/security/kerberos.py
index 2886b6161ef6e..72aff08c2f886 100644
--- a/airflow/security/kerberos.py
+++ b/airflow/security/kerberos.py
@@ -46,7 +46,7 @@
 log = logging.getLogger(__name__)
 
 
-def renew_from_kt(principal: str, keytab: str):
+def renew_from_kt(principal: str, keytab: str, exit_on_fail: bool = True):
     """
     Renew kerberos token from keytab
 
@@ -86,7 +86,10 @@ def renew_from_kt(principal: str, keytab: str):
             "\n".join(subp.stdout.readlines() if subp.stdout else []),
             "\n".join(subp.stderr.readlines() if subp.stderr else [])
         )
-        sys.exit(subp.returncode)
+        if exit_on_fail:
+            sys.exit(subp.returncode)
+        else:
+            return subp.returncode
 
     global NEED_KRB181_WORKAROUND  # pylint: disable=global-statement
     if NEED_KRB181_WORKAROUND is None:
@@ -95,7 +98,12 @@ def renew_from_kt(principal: str, keytab: str):
         # (From: HUE-640). Kerberos clock have seconds level granularity. Make sure we
         # renew the ticket after the initial valid time.
         time.sleep(1.5)
-        perform_krb181_workaround(principal)
+        ret = perform_krb181_workaround(principal)
+        if exit_on_fail and ret != 0:
+            sys.exit(ret)
+        else:
+            return ret
+    return 0
 
 
 def perform_krb181_workaround(principal: str):
@@ -127,7 +135,7 @@ def perform_krb181_workaround(principal: str):
         "configuration, and the ticket renewal policy (maxrenewlife) for the '%s' and `krbtgt' "
         "principals.", princ, ccache, princ
     )
-    sys.exit(ret)
+    return ret
 
 
 def detect_conf_var() -> bool:
diff --git a/tests/providers/apache/spark/hooks/test_spark_submit.py b/tests/providers/apache/spark/hooks/test_spark_submit.py
index df22a61e40b41..570e54cfd1723 100644
--- a/tests/providers/apache/spark/hooks/test_spark_submit.py
+++ b/tests/providers/apache/spark/hooks/test_spark_submit.py
@@ -17,6 +17,7 @@
 # under the License.
 
 import io
+import os
 import unittest
 from unittest.mock import call, patch
 
@@ -651,8 +652,9 @@ def test_process_spark_driver_status_log(self):
 
         self.assertEqual(hook._driver_status, 'RUNNING')
 
+    @patch('airflow.providers.apache.spark.hooks.spark_submit.renew_from_kt')
     @patch('airflow.providers.apache.spark.hooks.spark_submit.subprocess.Popen')
-    def test_yarn_process_on_kill(self, mock_popen):
+    def test_yarn_process_on_kill(self, mock_popen, mock_renew_from_kt):
         # Given
         mock_popen.return_value.stdout = io.StringIO('stdout')
         mock_popen.return_value.stderr = io.StringIO('stderr')
@@ -679,7 +681,24 @@ def test_yarn_process_on_kill(self, mock_popen):
         # Then
         self.assertIn(call(['yarn', 'application', '-kill',
                             'application_1486558679801_1820'],
-                           stderr=-1, stdout=-1),
+                           env=None, stderr=-1, stdout=-1),
                       mock_popen.mock_calls)
+        # resetting the mock to test kill with keytab & principal
+        mock_popen.reset_mock()
+        # Given
+        hook = SparkSubmitHook(conn_id='spark_yarn_cluster', keytab='privileged_user.keytab',
+                               principal='user/spark@airflow.org')
+        hook._process_spark_submit_log(log_lines)
+        hook.submit()
+
+        # When
+        hook.on_kill()
+        # Then
+        expected_env = os.environ.copy()
+        expected_env["KRB5CCNAME"] = '/tmp/airflow_krb5_ccache'
+        self.assertIn(call(['yarn', 'application', '-kill',
+                            'application_1486558679801_1820'],
+                           env=expected_env, stderr=-1, stdout=-1),
+                      mock_popen.mock_calls)
 
     def test_standalone_cluster_process_on_kill(self):
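
Note: for context, below is a minimal, self-contained sketch of the pattern this
patch wires into SparkSubmitHook.on_kill: do a best-effort `kinit` from a keytab,
then point the `yarn` CLI at the resulting credential cache via KRB5CCNAME when
issuing the kill. The function name `kill_yarn_application`, the `kinit` flags,
and the default cache path are illustrative assumptions, not part of this patch;
the actual renewal logic lives in airflow.security.kerberos.renew_from_kt, and
`kinit`/`yarn` are assumed to be on PATH.

    import os
    import subprocess
    from typing import Optional


    def kill_yarn_application(application_id: str,
                              principal: Optional[str] = None,
                              keytab: Optional[str] = None,
                              ccache: str = '/tmp/airflow_krb5_ccache') -> int:
        """Hypothetical sketch: refresh the Kerberos ticket, then kill the YARN app."""
        env = None
        if principal is not None and keytab is not None:
            # Best-effort kinit: a failure may simply mean the ticket is
            # non-renewable, so the kill is attempted regardless (check=False).
            subprocess.run(['kinit', '-k', '-t', keytab, '-c', ccache, principal],
                           check=False)
            # Hand `yarn` a copy of the environment that points at the
            # freshly initialised credential cache.
            env = os.environ.copy()
            env['KRB5CCNAME'] = ccache

        yarn_kill = subprocess.Popen(['yarn', 'application', '-kill', application_id],
                                     env=env,
                                     stdout=subprocess.PIPE,
                                     stderr=subprocess.PIPE)
        return yarn_kill.wait()

Leaving env=None when no keytab is configured preserves the previous behaviour:
the `yarn` subprocess simply inherits the parent environment unchanged.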