Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions airflow/contrib/hooks/spark_submit_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,9 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
:type env_vars: dict
:param verbose: Whether to pass the verbose flag to spark-submit process for debugging
:type verbose: bool
:param spark_binary: The command to use for spark submit.
Some distros may use spark2-submit.
:type spark_binary: str
"""
def __init__(self,
conf=None,
Expand All @@ -107,7 +110,8 @@ def __init__(self,
num_executors=None,
application_args=None,
env_vars=None,
verbose=False):
verbose=False,
spark_binary="spark-submit"):
self._conf = conf
self._conn_id = conn_id
self._files = files
Expand All @@ -132,6 +136,7 @@ def __init__(self,
self._submit_sp = None
self._yarn_application_id = None
self._kubernetes_driver_pod = None
self._spark_binary = spark_binary

self._connection = self._resolve_connection()
self._is_yarn = 'yarn' in self._connection['master']
Expand Down Expand Up @@ -161,7 +166,7 @@ def _resolve_connection(self):
'queue': None,
'deploy_mode': None,
'spark_home': None,
'spark_binary': 'spark-submit',
'spark_binary': self._spark_binary,
'namespace': 'default'}

try:
Expand All @@ -178,7 +183,7 @@ def _resolve_connection(self):
conn_data['queue'] = extra.get('queue', None)
conn_data['deploy_mode'] = extra.get('deploy-mode', None)
conn_data['spark_home'] = extra.get('spark-home', None)
conn_data['spark_binary'] = extra.get('spark-binary', 'spark-submit')
conn_data['spark_binary'] = extra.get('spark-binary', "spark-submit")
conn_data['namespace'] = extra.get('namespace', 'default')
except AirflowException:
self.log.debug(
Expand Down
8 changes: 7 additions & 1 deletion airflow/contrib/operators/spark_submit_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,9 @@ class SparkSubmitOperator(BaseOperator):
:type env_vars: dict
:param verbose: Whether to pass the verbose flag to spark-submit process for debugging
:type verbose: bool
:param spark_binary: The command to use for spark submit.
Some distros may use spark2-submit.
:type spark_binary: str
"""
template_fields = ('_name', '_application_args', '_packages')
ui_color = WEB_COLORS['LIGHTORANGE']
Expand Down Expand Up @@ -111,6 +114,7 @@ def __init__(self,
application_args=None,
env_vars=None,
verbose=False,
spark_binary="spark-submit",
*args,
**kwargs):
super(SparkSubmitOperator, self).__init__(*args, **kwargs)
Expand All @@ -135,6 +139,7 @@ def __init__(self,
self._application_args = application_args
self._env_vars = env_vars
self._verbose = verbose
self._spark_binary = spark_binary
self._hook = None
self._conn_id = conn_id

Expand Down Expand Up @@ -163,7 +168,8 @@ def execute(self, context):
num_executors=self._num_executors,
application_args=self._application_args,
env_vars=self._env_vars,
verbose=self._verbose
verbose=self._verbose,
spark_binary=self._spark_binary
)
self._hook.submit(self._application)

Expand Down
9 changes: 7 additions & 2 deletions tests/contrib/operators/test_spark_submit_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

from datetime import timedelta


DEFAULT_DATE = timezone.datetime(2017, 1, 1)


Expand Down Expand Up @@ -73,15 +74,17 @@ def setUp(self):
self.dag = DAG('test_dag_id', default_args=args)

def test_execute(self):

# Given / When
conn_id = 'spark_default'
operator = SparkSubmitOperator(
task_id='spark_submit_job',
spark_binary="sparky",
dag=self.dag,
**self._config
)

# Then
# Then expected results
expected_dict = {
'conf': {
'parquet.compression': 'SNAPPY'
Expand Down Expand Up @@ -110,7 +113,8 @@ def test_execute(self):
'--start', '{{ macros.ds_add(ds, -1)}}',
'--end', '{{ ds }}',
'--with-spaces', 'args should keep embedded spaces',
]
],
"spark_binary": "sparky"
}

self.assertEqual(conn_id, operator._conn_id)
Expand All @@ -135,6 +139,7 @@ def test_execute(self):
self.assertEqual(expected_dict['java_class'], operator._java_class)
self.assertEqual(expected_dict['driver_memory'], operator._driver_memory)
self.assertEqual(expected_dict['application_args'], operator._application_args)
self.assertEqual(expected_dict['spark_binary'], operator._spark_binary)

def test_render_template(self):
# Given
Expand Down