Skip to content

Commit

Permalink
Add param proxy user for hive (#36221)
Browse files Browse the repository at this point in the history
  • Loading branch information
romsharon98 committed Dec 19, 2023
1 parent 36cb20a commit 135265d
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 6 deletions.
6 changes: 5 additions & 1 deletion airflow/providers/apache/hive/hooks/hive.py
Expand Up @@ -80,6 +80,7 @@ class HiveCliHook(BaseHook):
This can make monitoring easier.
:param hive_cli_params: Space separated list of hive command parameters to add to the
hive command.
:param proxy_user: Run HQL code as this user.
"""

conn_name_attr = "hive_cli_conn_id"
Expand All @@ -96,6 +97,7 @@ def __init__(
mapred_job_name: str | None = None,
hive_cli_params: str = "",
auth: str | None = None,
proxy_user: str | None = None,
) -> None:
super().__init__()
conn = self.get_connection(hive_cli_conn_id)
Expand All @@ -105,7 +107,6 @@ def __init__(
self.conn = conn
self.run_as = run_as
self.sub_process: Any = None

if mapred_queue_priority:
mapred_queue_priority = mapred_queue_priority.upper()
if mapred_queue_priority not in HIVE_QUEUE_PRIORITIES:
Expand All @@ -116,6 +117,7 @@ def __init__(
self.mapred_queue = mapred_queue or conf.get("hive", "default_hive_mapred_queue")
self.mapred_queue_priority = mapred_queue_priority
self.mapred_job_name = mapred_job_name
self.proxy_user = proxy_user

def _get_proxy_user(self) -> str:
"""Set the proper proxy_user value in case the user overwrite the default."""
Expand All @@ -126,6 +128,8 @@ def _get_proxy_user(self) -> str:
return f"hive.server2.proxy.user={conn.login}"
if proxy_user_value == "owner" and self.run_as:
return f"hive.server2.proxy.user={self.run_as}"
if proxy_user_value == "as_param" and self.proxy_user:
return f"hive.server2.proxy.user={self.proxy_user}"
if proxy_user_value != "": # There is a custom proxy user
return f"hive.server2.proxy.user={proxy_user_value}"
return proxy_user_value # The default proxy user (undefined)
Expand Down
6 changes: 5 additions & 1 deletion airflow/providers/apache/hive/operators/hive.py
Expand Up @@ -62,6 +62,7 @@ class HiveOperator(BaseOperator):
This can make monitoring easier.
:param hive_cli_params: parameters passed to hive CLO
:param auth: optional authentication option passed for the Hive connection
:param proxy_user: Run HQL code as this user.
"""

template_fields: Sequence[str] = (
Expand All @@ -72,6 +73,7 @@ class HiveOperator(BaseOperator):
"hiveconfs",
"mapred_job_name",
"mapred_queue_priority",
"proxy_user",
)
template_ext: Sequence[str] = (
".hql",
Expand All @@ -95,6 +97,7 @@ def __init__(
mapred_job_name: str | None = None,
hive_cli_params: str = "",
auth: str | None = None,
proxy_user: str | None = None,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
Expand All @@ -112,7 +115,7 @@ def __init__(
self.mapred_job_name = mapred_job_name
self.hive_cli_params = hive_cli_params
self.auth = auth

self.proxy_user = proxy_user
job_name_template = conf.get_mandatory_value(
"hive",
"mapred_job_name_template",
Expand All @@ -131,6 +134,7 @@ def hook(self) -> HiveCliHook:
mapred_job_name=self.mapred_job_name,
hive_cli_params=self.hive_cli_params,
auth=self.auth,
proxy_user=self.proxy_user,
)

@deprecated(reason="use `hook` property instead.", category=AirflowProviderDeprecationWarning)
Expand Down
Expand Up @@ -71,8 +71,10 @@ Extra (optional)
* ``use_beeline``
Specify as ``True`` if using the Beeline CLI. Default is ``False``.
* ``proxy_user``
Specify a proxy user as an ``owner`` or ``login`` or keep blank if using a
Specify a proxy user as an ``owner`` or ``login`` or ``as_param`` keep blank if using a
custom proxy user.
When using ``owner`` you will want to pass the operator ``run_as_owner=True`` if you don't you will run the hql as user="owner"
When using ``as_param`` you will want to pass the operator ``proxy_user=<some_user>`` if you don't you will run the hql as user="as_param"
* ``principal``
Specify the JDBC Hive principal to be used with Hive Beeline.

Expand Down
23 changes: 20 additions & 3 deletions tests/providers/apache/hive/hooks/test_hive.py
Expand Up @@ -879,18 +879,35 @@ class TestHiveCli:
def setup_method(self):
self.nondefault_schema = "nondefault"

def test_get_proxy_user_value(self):
@pytest.mark.parametrize(
"extra_dejson, correct_proxy_user, run_as, proxy_user",
[
({"proxy_user": "a_user_proxy"}, "hive.server2.proxy.user=a_user_proxy", None, None),
({"proxy_user": "owner"}, "hive.server2.proxy.user=dummy_dag_owner", "dummy_dag_owner", None),
({"proxy_user": "login"}, "hive.server2.proxy.user=admin", None, None),
(
{"proxy_user": "as_param"},
"hive.server2.proxy.user=param_proxy_user",
None,
"param_proxy_user",
),
],
)
def test_get_proxy_user_value(self, extra_dejson, correct_proxy_user, run_as, proxy_user):
hook = MockHiveCliHook()
returner = mock.MagicMock()
returner.extra_dejson = {"proxy_user": "a_user_proxy"}
returner.extra_dejson = extra_dejson
returner.login = "admin"
hook.use_beeline = True
hook.conn = returner
hook.proxy_user = proxy_user
hook.run_as = run_as

# Run
result = hook._prepare_cli_cmd()

# Verify
assert "hive.server2.proxy.user=a_user_proxy" in result[2]
assert correct_proxy_user in result[2]

def test_get_wrong_principal(self):
hook = MockHiveCliHook()
Expand Down

0 comments on commit 135265d

Please sign in to comment.