Add use_krb5ccache option to SparkSubmitOperator (#35331)
* Add use_krb5ccache option to SparkSubmitOperator
zeotuan committed Nov 1, 2023
1 parent 83b082d commit 880a85b
Showing 5 changed files with 25 additions and 1 deletion.
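Before the per-file diffs, a minimal usage sketch (not part of this commit) showing the new flag on SparkSubmitOperator. The DAG id, application path, and connection id are hypothetical placeholders, and KRB5CCNAME is assumed to point at a valid Kerberos ticket cache on the worker:

```python
from datetime import datetime

from airflow import DAG
from airflow.providers.apache.spark.operators.spark_submit import SparkSubmitOperator

with DAG(
    dag_id="spark_krb5ccache_example",  # hypothetical DAG
    start_date=datetime(2023, 11, 1),
    schedule=None,
) as dag:
    submit = SparkSubmitOperator(
        task_id="submit_with_ticket_cache",
        application="/opt/jobs/etl_job.py",  # hypothetical Spark application
        conn_id="spark_default",
        use_krb5ccache=True,  # use the Kerberos ticket cache instead of a keytab
    )
```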
4 changes: 4 additions & 0 deletions airflow/providers/apache/spark/hooks/spark_jdbc.py
@@ -83,6 +83,8 @@ class SparkJDBCHook(SparkSubmitHook):
(e.g: "name CHAR(64), comments VARCHAR(1024)").
The specified types should be valid spark sql data
types.
:param use_krb5ccache: if True, configure Spark to use the ticket cache instead of relying
on a keytab for Kerberos login
"""

conn_name_attr = "spark_conn_id"
@@ -121,6 +123,7 @@ def __init__(
upper_bound: str | None = None,
create_table_column_types: str | None = None,
*args: Any,
use_krb5ccache: bool = False,
**kwargs: Any,
):
super().__init__(*args, **kwargs)
@@ -153,6 +156,7 @@ def __init__(
self._upper_bound = upper_bound
self._create_table_column_types = create_table_column_types
self._jdbc_connection = self._resolve_jdbc_connection()
self._use_krb5ccache = use_krb5ccache

def _resolve_jdbc_connection(self) -> dict[str, Any]:
conn_data = {"url": "", "schema": "", "conn_prefix": "", "user": "", "password": ""}
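In the hook, the new parameter is declared after `*args`, so it is keyword-only and must be passed by name. A hedged construction sketch; the connection ids, table, and driver class are hypothetical:

```python
from airflow.providers.apache.spark.hooks.spark_jdbc import SparkJDBCHook

hook = SparkJDBCHook(
    spark_conn_id="spark_default",
    jdbc_conn_id="jdbc_default",
    jdbc_table="orders",                  # hypothetical JDBC table
    jdbc_driver="org.postgresql.Driver",  # hypothetical driver class
    use_krb5ccache=True,                  # keyword-only: declared after *args
)
```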
6 changes: 6 additions & 0 deletions airflow/providers/apache/spark/operators/spark_jdbc.py
@@ -91,6 +91,9 @@ class SparkJDBCOperator(SparkSubmitOperator):
(e.g: "name CHAR(64), comments VARCHAR(1024)").
The specified types should be valid spark sql data
types.
:param use_krb5ccache: if True, configure Spark to use the ticket cache instead of relying
on a keytab for Kerberos login
"""

def __init__(
@@ -124,6 +127,7 @@ def __init__(
lower_bound: str | None = None,
upper_bound: str | None = None,
create_table_column_types: str | None = None,
use_krb5ccache: bool = False,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
@@ -156,6 +160,7 @@ def __init__(
self._upper_bound = upper_bound
self._create_table_column_types = create_table_column_types
self._hook: SparkJDBCHook | None = None
self._use_krb5ccache = use_krb5ccache

def execute(self, context: Context) -> None:
"""Call the SparkSubmitHook to run the provided spark job."""
@@ -198,4 +203,5 @@ def _get_hook(self) -> SparkJDBCHook:
lower_bound=self._lower_bound,
upper_bound=self._upper_bound,
create_table_column_types=self._create_table_column_types,
use_krb5ccache=self._use_krb5ccache,
)
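As with the hook, a hedged sketch of SparkJDBCOperator with the new flag; the table and metastore names are hypothetical placeholders:

```python
from airflow.providers.apache.spark.operators.spark_jdbc import SparkJDBCOperator

load_orders = SparkJDBCOperator(
    task_id="jdbc_to_spark_orders",
    spark_conn_id="spark_default",
    jdbc_conn_id="jdbc_default",
    cmd_type="jdbc_to_spark",            # copy from the JDBC source into Spark
    jdbc_table="orders",                 # hypothetical source table
    metastore_table="analytics.orders",  # hypothetical target metastore table
    use_krb5ccache=True,
)
```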
5 changes: 5 additions & 0 deletions airflow/providers/apache/spark/operators/spark_submit.py
@@ -69,6 +69,8 @@ class SparkSubmitOperator(BaseOperator):
:param verbose: Whether to pass the verbose flag to spark-submit process for debugging
:param spark_binary: The command to use for spark submit.
Some distros may use spark2-submit or spark3-submit.
:param use_krb5ccache: if True, configure Spark to use the ticket cache instead of relying
on a keytab for Kerberos login
"""

template_fields: Sequence[str] = (
@@ -118,6 +120,7 @@ def __init__(
env_vars: dict[str, Any] | None = None,
verbose: bool = False,
spark_binary: str | None = None,
use_krb5ccache: bool = False,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
@@ -148,6 +151,7 @@ def __init__(
self._spark_binary = spark_binary
self._hook: SparkSubmitHook | None = None
self._conn_id = conn_id
self._use_krb5ccache = use_krb5ccache

def execute(self, context: Context) -> None:
"""Call the SparkSubmitHook to run the provided spark job."""
@@ -187,4 +191,5 @@ def _get_hook(self) -> SparkSubmitHook:
env_vars=self._env_vars,
verbose=self._verbose,
spark_binary=self._spark_binary,
use_krb5ccache=self._use_krb5ccache,
)
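The operator diff above only threads the flag through to SparkSubmitHook; the hook-side handling is not shown on this commit page. A plausible sketch of what such handling could look like, assuming Spark's standard spark.kerberos.renewal.credentials setting and a KRB5CCNAME environment variable pointing at the ticket cache (both assumptions, not code from this diff):

```python
import os


def kerberos_submit_args(use_krb5ccache: bool) -> list[str]:
    # Sketch only: when the ticket cache is requested, tell spark-submit to
    # renew credentials from the cache rather than from a principal/keytab.
    if use_krb5ccache:
        if not os.getenv("KRB5CCNAME"):
            raise RuntimeError("KRB5CCNAME is not set; no Kerberos ticket cache to use")
        return ["--conf", "spark.kerberos.renewal.credentials=ccache"]
    return []
```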
3 changes: 3 additions & 0 deletions tests/providers/apache/spark/operators/test_spark_jdbc.py
@@ -53,6 +53,7 @@ class TestSparkJDBCOperator:
"upper_bound": "20",
"create_table_column_types": "columnMcColumnFace INTEGER(100), name CHAR(64),"
"comments VARCHAR(1024)",
"use_krb5ccache": True,
}

def setup_method(self):
@@ -95,6 +96,7 @@ def test_execute(self):
"upper_bound": "20",
"create_table_column_types": "columnMcColumnFace INTEGER(100), name CHAR(64),"
"comments VARCHAR(1024)",
"use_krb5ccache": True,
}

assert spark_conn_id == operator._spark_conn_id
@@ -125,3 +127,4 @@ def test_execute(self):
assert expected_dict["lower_bound"] == operator._lower_bound
assert expected_dict["upper_bound"] == operator._upper_bound
assert expected_dict["create_table_column_types"] == operator._create_table_column_types
assert expected_dict["use_krb5ccache"] == operator._use_krb5ccache
8 changes: 7 additions & 1 deletion tests/providers/apache/spark/operators/test_spark_submit.py
@@ -65,6 +65,7 @@ class TestSparkSubmitOperator:
"--with-spaces",
"args should keep embedded spaces",
],
"use_krb5ccache": True,
}

def setup_method(self):
@@ -75,7 +76,10 @@ def test_execute(self):
# Given / When
conn_id = "spark_default"
operator = SparkSubmitOperator(
task_id="spark_submit_job", spark_binary="sparky", dag=self.dag, **self._config
task_id="spark_submit_job",
spark_binary="sparky",
dag=self.dag,
**self._config,
)

# Then expected results
@@ -115,6 +119,7 @@ def test_execute(self):
"args should keep embedded spaces",
],
"spark_binary": "sparky",
"use_krb5ccache": True,
}

assert conn_id == operator._conn_id
@@ -142,6 +147,7 @@ def test_execute(self):
assert expected_dict["driver_memory"] == operator._driver_memory
assert expected_dict["application_args"] == operator._application_args
assert expected_dict["spark_binary"] == operator._spark_binary
assert expected_dict["use_krb5ccache"] == operator._use_krb5ccache

@pytest.mark.db_test
def test_render_template(self):
