Skip to content

Commit

Permalink
Consolidate hook management in HiveOperator (#34430)
Browse files Browse the repository at this point in the history
* Consolidate hook management in HiveOperator

* use AirflowProviderDeprecationWarning
  • Loading branch information
hussein-awala committed Sep 18, 2023
1 parent 76628ae commit 169ce92
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 15 deletions.
20 changes: 11 additions & 9 deletions airflow/providers/apache/hive/operators/hive.py
Expand Up @@ -19,9 +19,13 @@

import os
import re
from functools import cached_property
from typing import TYPE_CHECKING, Any, Sequence

from deprecated.classic import deprecated

from airflow.configuration import conf
from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.models import BaseOperator
from airflow.providers.apache.hive.hooks.hive import HiveCliHook
from airflow.utils import operator_helpers
Expand Down Expand Up @@ -116,13 +120,8 @@ def __init__(
)
self.mapred_job_name_template: str = job_name_template

# assigned lazily - just for consistency we can create the attribute with a
# `None` initial value, later it will be populated by the execute method.
# This also makes `on_kill` implementation consistent since it assumes `self.hook`
# is defined.
self.hook: HiveCliHook | None = None

def get_hook(self) -> HiveCliHook:
@cached_property
def hook(self) -> HiveCliHook:
"""Get Hive cli hook."""
return HiveCliHook(
hive_cli_conn_id=self.hive_cli_conn_id,
Expand All @@ -134,6 +133,11 @@ def get_hook(self) -> HiveCliHook:
auth=self.auth,
)

@deprecated(reason="use `hook` property instead.", category=AirflowProviderDeprecationWarning)
def get_hook(self) -> HiveCliHook:
"""Get Hive cli hook."""
return self.hook

def prepare_template(self) -> None:
if self.hiveconf_jinja_translate:
self.hql = re.sub(r"(\$\{(hiveconf:)?([ a-zA-Z0-9_]*)\})", r"{{ \g<3> }}", self.hql)
Expand All @@ -142,7 +146,6 @@ def prepare_template(self) -> None:

def execute(self, context: Context) -> None:
self.log.info("Executing: %s", self.hql)
self.hook = self.get_hook()

# set the mapred_job_name if it's not set with dag, task, execution time info
if not self.mapred_job_name:
Expand All @@ -169,7 +172,6 @@ def dry_run(self) -> None:
# existing env vars from impacting behavior.
self.clear_airflow_vars()

self.hook = self.get_hook()
self.hook.test_hql(hql=self.hql)

def on_kill(self) -> None:
Expand Down
10 changes: 4 additions & 6 deletions tests/providers/apache/hive/operators/test_hive.py
Expand Up @@ -41,7 +41,7 @@ def test_hive_airflow_default_config_queue(self):

# just check that the correct default value in test_default.cfg is used
test_config_hive_mapred_queue = conf.get("hive", "default_hive_mapred_queue")
assert op.get_hook().mapred_queue == test_config_hive_mapred_queue
assert op.hook.mapred_queue == test_config_hive_mapred_queue

def test_hive_airflow_default_config_queue_override(self):
specific_mapred_queue = "default"
Expand All @@ -54,7 +54,7 @@ def test_hive_airflow_default_config_queue_override(self):
dag=self.dag,
)

assert op.get_hook().mapred_queue == specific_mapred_queue
assert op.hook.mapred_queue == specific_mapred_queue


class HiveOperatorTest(TestHiveEnvironment):
Expand All @@ -75,10 +75,8 @@ def test_hiveconf(self):
op.prepare_template()
assert op.hql == "SELECT * FROM ${hiveconf:table} PARTITION (${hiveconf:day});"

@mock.patch("airflow.providers.apache.hive.operators.hive.HiveOperator.get_hook")
def test_mapred_job_name(self, mock_get_hook):
mock_hook = mock.MagicMock()
mock_get_hook.return_value = mock_hook
@mock.patch("airflow.providers.apache.hive.operators.hive.HiveOperator.hook", mock.MagicMock())
def test_mapred_job_name(self, mock_hook):
op = HiveOperator(task_id="test_mapred_job_name", hql=self.hql, dag=self.dag)

fake_run_id = "test_mapred_job_name"
Expand Down

0 comments on commit 169ce92

Please sign in to comment.