Ensure Kerberos token is valid in SparkSubmitOperator before running `yarn kill` (#9044)

Do a `kinit` before `yarn kill` if a keytab and principal are provided.
aneesh-joseph committed Jul 9, 2020
1 parent 8b94ace commit 13a827d
Showing 3 changed files with 46 additions and 9 deletions.
16 changes: 13 additions & 3 deletions airflow/providers/apache/spark/hooks/spark_submit.py
@@ -21,8 +21,10 @@
 import subprocess
 import time
 
+from airflow.configuration import conf as airflow_conf
 from airflow.exceptions import AirflowException
 from airflow.hooks.base_hook import BaseHook
+from airflow.security.kerberos import renew_from_kt
 from airflow.utils.log.logging_mixin import LoggingMixin
 
 try:
@@ -617,15 +619,23 @@ def on_kill(self):
             self._submit_sp.kill()
 
             if self._yarn_application_id:
                 self.log.info('Killing application %s on YARN', self._yarn_application_id)
 
                 kill_cmd = "yarn application -kill {}" \
                     .format(self._yarn_application_id).split()
+                env = None
+                if self._keytab is not None and self._principal is not None:
+                    # we are ignoring renewal failures from renew_from_kt
+                    # here as the failure could just be due to a non-renewable ticket,
+                    # we still attempt to kill the yarn application
+                    renew_from_kt(self._principal, self._keytab, exit_on_fail=False)
+                    env = os.environ.copy()
+                    env["KRB5CCNAME"] = airflow_conf.get('kerberos', 'ccache')
+
                 yarn_kill = subprocess.Popen(kill_cmd,
+                                             env=env,
                                              stdout=subprocess.PIPE,
                                              stderr=subprocess.PIPE)
 
-                self.log.info("YARN killed with return code: %s", yarn_kill.wait())
+                self.log.info("YARN app killed with return code: %s", yarn_kill.wait())
 
             if self._kubernetes_driver_pod:
                 self.log.info('Killing pod %s on Kubernetes', self._kubernetes_driver_pod)
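In isolation, the flow this hunk adds boils down to: refresh the Kerberos ticket cache from the keytab (best effort), point the yarn CLI at that cache via KRB5CCNAME, then issue the kill. A minimal standalone sketch of that sequence, assuming the MIT Kerberos client tools — the kill_yarn_app helper, keytab path, principal, and ccache location are illustrative placeholders, and the plain kinit call stands in for Airflow's fuller renew_from_kt:

import os
import subprocess


def kill_yarn_app(app_id, keytab=None, principal=None, ccache='/tmp/airflow_krb5_ccache'):
    """Best-effort kinit, then yarn kill (sketch of the on_kill flow above)."""
    env = None
    if keytab is not None and principal is not None:
        # Ignore renewal failures here: the ticket may simply be
        # non-renewable, and we still want to attempt the kill.
        subprocess.run(['kinit', '-kt', keytab, '-c', ccache, principal], check=False)
        env = os.environ.copy()
        env['KRB5CCNAME'] = ccache
    proc = subprocess.Popen(['yarn', 'application', '-kill', app_id],
                            env=env, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    return proc.wait()
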
16 changes: 12 additions & 4 deletions airflow/security/kerberos.py
@@ -46,7 +46,7 @@
 log = logging.getLogger(__name__)
 
 
-def renew_from_kt(principal: str, keytab: str):
+def renew_from_kt(principal: str, keytab: str, exit_on_fail: bool = True):
     """
     Renew kerberos token from keytab
@@ -86,7 +86,10 @@ def renew_from_kt(principal: str, keytab: str):
"\n".join(subp.stdout.readlines() if subp.stdout else []),
"\n".join(subp.stderr.readlines() if subp.stderr else [])
)
sys.exit(subp.returncode)
if exit_on_fail:
sys.exit(subp.returncode)
else:
return subp.returncode

global NEED_KRB181_WORKAROUND # pylint: disable=global-statement
if NEED_KRB181_WORKAROUND is None:
@@ -95,7 +98,12 @@ def renew_from_kt(principal: str, keytab: str):
         # (From: HUE-640). Kerberos clock have seconds level granularity. Make sure we
         # renew the ticket after the initial valid time.
         time.sleep(1.5)
-        perform_krb181_workaround(principal)
+        ret = perform_krb181_workaround(principal)
+        if exit_on_fail and ret != 0:
+            sys.exit(ret)
+        else:
+            return ret
+    return 0
 
 
 def perform_krb181_workaround(principal: str):
@@ -127,7 +135,7 @@ def perform_krb181_workaround(principal: str):
"configuration, and the ticket renewal policy (maxrenewlife) for the '%s' and `krbtgt' "
"principals.", princ, ccache, princ
)
sys.exit(ret)
return ret


def detect_conf_var() -> bool:
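The caller-facing contract after this change: with the default exit_on_fail=True a failed renewal still terminates the process via sys.exit(), while exit_on_fail=False hands the kinit (or krb181-workaround) return code back to the caller. A short usage sketch — the keytab path is a placeholder (the principal is the one used in the tests below):

from airflow.security.kerberos import renew_from_kt

# Placeholder credentials, for illustration only.
ret = renew_from_kt('user/spark@airflow.org', '/etc/security/airflow.keytab',
                    exit_on_fail=False)
if ret != 0:
    # e.g. a non-renewable ticket; log and continue instead of aborting
    print('kerberos ticket renewal failed with return code %s' % ret)
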
23 changes: 21 additions & 2 deletions tests/providers/apache/spark/hooks/test_spark_submit.py
@@ -17,6 +17,7 @@
 # under the License.
 
 import io
+import os
 import unittest
 from unittest.mock import call, patch
 
@@ -651,8 +652,9 @@ def test_process_spark_driver_status_log(self):

         self.assertEqual(hook._driver_status, 'RUNNING')
 
+    @patch('airflow.providers.apache.spark.hooks.spark_submit.renew_from_kt')
     @patch('airflow.providers.apache.spark.hooks.spark_submit.subprocess.Popen')
-    def test_yarn_process_on_kill(self, mock_popen):
+    def test_yarn_process_on_kill(self, mock_popen, mock_renew_from_kt):
         # Given
         mock_popen.return_value.stdout = io.StringIO('stdout')
         mock_popen.return_value.stderr = io.StringIO('stderr')
@@ -679,7 +681,24 @@ def test_yarn_process_on_kill(self, mock_popen):
         # Then
         self.assertIn(call(['yarn', 'application', '-kill',
                             'application_1486558679801_1820'],
-                           stderr=-1, stdout=-1),
+                           env=None, stderr=-1, stdout=-1),
                       mock_popen.mock_calls)
+        # resetting the mock to test kill with keytab & principal
+        mock_popen.reset_mock()
+        # Given
+        hook = SparkSubmitHook(conn_id='spark_yarn_cluster', keytab='privileged_user.keytab',
+                               principal='user/spark@airflow.org')
+        hook._process_spark_submit_log(log_lines)
+        hook.submit()
+
+        # When
+        hook.on_kill()
+        # Then
+        expected_env = os.environ.copy()
+        expected_env["KRB5CCNAME"] = '/tmp/airflow_krb5_ccache'
+        self.assertIn(call(['yarn', 'application', '-kill',
+                            'application_1486558679801_1820'],
+                           env=expected_env, stderr=-1, stdout=-1),
+                      mock_popen.mock_calls)
 
     def test_standalone_cluster_process_on_kill(self):
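The '/tmp/airflow_krb5_ccache' value asserted above matches the [kerberos] ccache option in the test's Airflow configuration — the same option the hook reads when building the kill environment. A one-line sketch of that lookup (the resolved value depends on the deployment's airflow.cfg):

from airflow.configuration import conf as airflow_conf

# on_kill copies the current environment and points KRB5CCNAME at the
# configured ticket cache before spawning `yarn application -kill`.
ccache = airflow_conf.get('kerberos', 'ccache')  # e.g. '/tmp/airflow_krb5_ccache'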
