Skip to content

Commit

Permalink
Fix deprecated apache.hive operators arguments in MappedOperator (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
Taragolis committed Mar 22, 2024
1 parent c893cb3 commit 72c0911
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 5 deletions.
10 changes: 6 additions & 4 deletions airflow/providers/apache/hive/operators/hive_stats.py
Expand Up @@ -21,11 +21,12 @@
import warnings
from typing import TYPE_CHECKING, Any, Callable, Sequence

from airflow.exceptions import AirflowException
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.models import BaseOperator
from airflow.providers.apache.hive.hooks.hive import HiveMetastoreHook
from airflow.providers.mysql.hooks.mysql import MySqlHook
from airflow.providers.presto.hooks.presto import PrestoHook
from airflow.utils.types import NOTSET, ArgNotSet

if TYPE_CHECKING:
from airflow.utils.context import Context
Expand Down Expand Up @@ -78,17 +79,18 @@ def __init__(
mysql_conn_id: str = "airflow_db",
ds: str = "{{ ds }}",
dttm: str = "{{ logical_date.isoformat() }}",
col_blacklist: list[str] | None | ArgNotSet = NOTSET,
**kwargs: Any,
) -> None:
if "col_blacklist" in kwargs:
if col_blacklist is not NOTSET:
warnings.warn(
f"col_blacklist kwarg passed to {self.__class__.__name__} "
f"(task_id: {kwargs.get('task_id')}) is deprecated, "
f"please rename it to excluded_columns instead",
category=FutureWarning,
category=AirflowProviderDeprecationWarning,
stacklevel=2,
)
excluded_columns = kwargs.pop("col_blacklist")
excluded_columns = col_blacklist # type: ignore[assignment]
super().__init__(**kwargs)
self.table = table
self.partition = partition
Expand Down
42 changes: 41 additions & 1 deletion tests/providers/apache/hive/operators/test_hive_stats.py
Expand Up @@ -23,9 +23,11 @@

import pytest

from airflow.exceptions import AirflowException
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.providers.apache.hive.operators.hive_stats import HiveStatsCollectionOperator
from airflow.providers.presto.hooks.presto import PrestoHook
from airflow.utils import timezone
from airflow.utils.task_instance_session import set_current_task_instance_session
from tests.providers.apache.hive import (
DEFAULT_DATE,
DEFAULT_DATE_DS,
Expand Down Expand Up @@ -370,3 +372,41 @@ def test_runs_for_hive_stats(self, mock_hive_metastore_hook):
"value",
],
)

def test_col_blacklist_deprecation(self):
    """Passing ``col_blacklist`` directly to the operator emits the provider deprecation warning."""
    # Regex checked against the warning text; it must mention the task id.
    expected_pattern = "col_blacklist kwarg passed to.*task_id: fake-task-id.*is deprecated"
    operator_kwargs = {
        "task_id": "fake-task-id",
        "table": "airflow.static_babynames_partitioned",
        "partition": {"ds": DEFAULT_DATE_DS},
        # Deprecated argument under test.
        "col_blacklist": ["foo", "bar"],
    }
    with pytest.warns(AirflowProviderDeprecationWarning, match=expected_pattern):
        HiveStatsCollectionOperator(**operator_kwargs)

# Mapped-operator variant of the col_blacklist deprecation check: the operator is
# declared via .partial(...).expand(...) with BOTH the deprecated col_blacklist and
# excluded_columns supplied, then templates are rendered for the task instances.
# Parametrized over col_blacklist=None and col_blacklist=["foo", "bar"].
# NOTE(review): this listing has lost its leading indentation, so the nesting of the
# statements below cannot be read from here — confirm against the real test module.
@pytest.mark.db_test
@pytest.mark.parametrize(
"col_blacklist",
[pytest.param(None, id="none"), pytest.param(["foo", "bar"], id="list")],
)
def test_partial_col_blacklist_deprecation(self, col_blacklist, dag_maker, session):
with dag_maker(
dag_id="test_partial_col_blacklist_deprecation",
start_date=timezone.datetime(2024, 1, 1),
session=session,
):
# Expanded over two table names; col_blacklist is passed alongside excluded_columns.
HiveStatsCollectionOperator.partial(
task_id="fake-task-id",
partition={"ds": DEFAULT_DATE_DS},
col_blacklist=col_blacklist,
excluded_columns=["spam", "egg"],
).expand(table=["airflow.table1", "airflow.table2"])

# Create a DAG run and fetch its task instances using the same session.
dr = dag_maker.create_dagrun(execution_date=None)
tis = dr.get_task_instances(session=session)
with set_current_task_instance_session(session=session):
warn_message = "col_blacklist kwarg passed to.*task_id: fake-task-id.*is deprecated"
for ti in tis:
# The deprecation warning is expected while render_templates() runs.
with pytest.warns(AirflowProviderDeprecationWarning, match=warn_message):
ti.render_templates()
# Expected value is derived from col_blacklist ([] when it is None), not from the
# excluded_columns=["spam", "egg"] passed above.
expected = col_blacklist or []
assert ti.task.excluded_columns == expected

0 comments on commit 72c0911

Please sign in to comment.