diff --git a/airflow/example_dags/plugins/workday.py b/airflow/example_dags/plugins/workday.py
index 2c3299c960bbb..848ef1405193d 100644
--- a/airflow/example_dags/plugins/workday.py
+++ b/airflow/example_dags/plugins/workday.py
@@ -38,7 +38,7 @@
     holiday_calendar = USFederalHolidayCalendar()
 except ImportError:
     log.warning("Could not import pandas. Holidays will not be considered.")
-    holiday_calendar = None
+    holiday_calendar = None  # type: ignore[assignment]
 
 
 class AfterWorkdayTimetable(Timetable):
diff --git a/airflow/providers/amazon/aws/transfers/sql_to_s3.py b/airflow/providers/amazon/aws/transfers/sql_to_s3.py
index b71554b2a32ed..c00784ad4a71f 100644
--- a/airflow/providers/amazon/aws/transfers/sql_to_s3.py
+++ b/airflow/providers/amazon/aws/transfers/sql_to_s3.py
@@ -20,7 +20,7 @@
 import enum
 from collections import namedtuple
 from tempfile import NamedTemporaryFile
-from typing import TYPE_CHECKING, Iterable, Mapping, Sequence
+from typing import TYPE_CHECKING, Any, Iterable, Mapping, Sequence, cast
 
 from typing_extensions import Literal
 
@@ -105,7 +105,7 @@ def __init__(
         s3_key: str,
         sql_conn_id: str,
         sql_hook_params: dict | None = None,
-        parameters: None | Mapping | Iterable = None,
+        parameters: None | Mapping[str, Any] | list | tuple = None,
         replace: bool = False,
         aws_conn_id: str = "aws_default",
         verify: bool | str | None = None,
@@ -158,7 +158,7 @@ def _fix_dtypes(df: pd.DataFrame, file_format: FILE_FORMAT) -> None:
 
             if "float" in df[col].dtype.name and df[col].hasnans:
                 # inspect values to determine if dtype of non-null values is int or float
-                notna_series = df[col].dropna().values
+                notna_series: Any = df[col].dropna().values
                 if np.equal(notna_series, notna_series.astype(int)).all():
                     # set to dtype that retains integers and supports NaNs
                     # The type ignore can be removed here if https://github.com/numpy/numpy/pull/23690
@@ -196,10 +196,12 @@ def _partition_dataframe(self, df: pd.DataFrame) -> Iterable[tuple[str, pd.DataF
         """Partition dataframe using pandas groupby() method."""
         if not self.groupby_kwargs:
             yield "", df
-        else:
-            grouped_df = df.groupby(**self.groupby_kwargs)
-            for group_label in grouped_df.groups:
-                yield group_label, grouped_df.get_group(group_label).reset_index(drop=True)
+            return
+        for group_label in (grouped_df := df.groupby(**self.groupby_kwargs)).groups:
+            yield (
+                cast(str, group_label),
+                cast("pd.DataFrame", grouped_df.get_group(group_label).reset_index(drop=True)),
+            )
 
     def _get_hook(self) -> DbApiHook:
         self.log.debug("Get connection for %s", self.sql_conn_id)
diff --git a/airflow/providers/common/sql/hooks/sql.py b/airflow/providers/common/sql/hooks/sql.py
index adbb95d784d64..950ac7ee19e8b 100644
--- a/airflow/providers/common/sql/hooks/sql.py
+++ b/airflow/providers/common/sql/hooks/sql.py
@@ -22,6 +22,7 @@
     TYPE_CHECKING,
     Any,
     Callable,
+    Generator,
     Iterable,
     Mapping,
     Protocol,
@@ -41,6 +42,8 @@
 from airflow.version import version
 
 if TYPE_CHECKING:
+    from pandas import DataFrame
+
     from airflow.providers.openlineage.extractors import OperatorLineage
     from airflow.providers.openlineage.sqlparser import DatabaseInfo
 
@@ -198,7 +201,12 @@ def get_sqlalchemy_engine(self, engine_kwargs=None):
             engine_kwargs = {}
         return create_engine(self.get_uri(), **engine_kwargs)
 
-    def get_pandas_df(self, sql, parameters: Iterable | Mapping[str, Any] | None = None, **kwargs):
+    def get_pandas_df(
+        self,
+        sql,
+        parameters: list | tuple | Mapping[str, Any] | None = None,
+        **kwargs,
+    ) -> DataFrame:
         """
         Execute the sql and returns a pandas dataframe.
 
@@ -218,14 +226,19 @@ def get_pandas_df(self, sql, parameters: Iterable | Mapping[str, Any] | None = N
         return psql.read_sql(sql, con=conn, params=parameters, **kwargs)
 
     def get_pandas_df_by_chunks(
-        self, sql, parameters: Iterable | Mapping[str, Any] | None = None, *, chunksize: int | None, **kwargs
-    ):
+        self,
+        sql,
+        parameters: list | tuple | Mapping[str, Any] | None = None,
+        *,
+        chunksize: int,
+        **kwargs,
+    ) -> Generator[DataFrame, None, None]:
         """
         Execute the sql and return a generator.
 
         :param sql: the sql statement to be executed (str) or a list of sql statements to execute
         :param parameters: The parameters to render the SQL query with
-        :param chunksize: number of rows to include in each chunk
+        :param chunksize: number of rows to include in each chunk
         :param kwargs: (optional) passed into pandas.io.sql.read_sql method
         """
         try:
diff --git a/airflow/providers/presto/hooks/presto.py b/airflow/providers/presto/hooks/presto.py
index 87dd0c68bca81..4cfb2b76019e4 100644
--- a/airflow/providers/presto/hooks/presto.py
+++ b/airflow/providers/presto/hooks/presto.py
@@ -171,7 +171,7 @@ def get_pandas_df(self, sql: str = "", parameters=None, **kwargs):
         column_descriptions = cursor.description
         if data:
             df = pd.DataFrame(data, **kwargs)
-            df.columns = [c[0] for c in column_descriptions]
+            df.rename(columns={n: c[0] for n, c in zip(df.columns, column_descriptions)}, inplace=True)
         else:
             df = pd.DataFrame(**kwargs)
         return df
diff --git a/airflow/providers/salesforce/hooks/salesforce.py b/airflow/providers/salesforce/hooks/salesforce.py
index 5bdbb6eaa815f..b8dc1af10f9d6 100644
--- a/airflow/providers/salesforce/hooks/salesforce.py
+++ b/airflow/providers/salesforce/hooks/salesforce.py
@@ -367,8 +367,7 @@ def object_to_df(
             # that's because None/np.nan cannot exist in an integer column
             # we should write all of our timestamps as FLOATS in our final schema
             df = pd.DataFrame.from_records(query_results, exclude=["attributes"])
-
-            df.columns = [column.lower() for column in df.columns]
+            df.rename(columns=str.lower, inplace=True)
 
             # convert columns with datetime strings to datetimes
             # not all strings will be datetimes, so we ignore any errors that occur
diff --git a/airflow/providers/slack/transfers/base_sql_to_slack.py b/airflow/providers/slack/transfers/base_sql_to_slack.py
index cdc5b4cc7a562..70e48e72bf103 100644
--- a/airflow/providers/slack/transfers/base_sql_to_slack.py
+++ b/airflow/providers/slack/transfers/base_sql_to_slack.py
@@ -16,7 +16,7 @@
 # under the License.
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Any, Iterable, Mapping
+from typing import TYPE_CHECKING, Any, Mapping
 
 from airflow.exceptions import AirflowException
 from airflow.hooks.base import BaseHook
@@ -50,7 +50,7 @@ def __init__(
         sql: str,
         sql_conn_id: str,
         sql_hook_params: dict | None = None,
-        parameters: Iterable | Mapping[str, Any] | None = None,
+        parameters: list | tuple | Mapping[str, Any] | None = None,
         slack_proxy: str | None = None,
         slack_timeout: int | None = None,
         slack_retry_handlers: list[RetryHandler] | None = None,
diff --git a/airflow/providers/slack/transfers/sql_to_slack.py b/airflow/providers/slack/transfers/sql_to_slack.py
index 2bc249c7c1ecb..6f2ab24f853ae 100644
--- a/airflow/providers/slack/transfers/sql_to_slack.py
+++ b/airflow/providers/slack/transfers/sql_to_slack.py
@@ -18,7 +18,7 @@
 
 import warnings
 from tempfile import NamedTemporaryFile
-from typing import TYPE_CHECKING, Any, Iterable, Mapping, Sequence
+from typing import TYPE_CHECKING, Any, Mapping, Sequence
 
 from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
 from airflow.providers.slack.hooks.slack import SlackHook
@@ -74,7 +74,7 @@ def __init__(
         sql: str,
         sql_conn_id: str,
         sql_hook_params: dict | None = None,
-        parameters: Iterable | Mapping[str, Any] | None = None,
+        parameters: list | tuple | Mapping[str, Any] | None = None,
         slack_conn_id: str = SlackHook.default_conn_name,
         slack_filename: str,
         slack_channels: str | Sequence[str] | None = None,
diff --git a/airflow/providers/slack/transfers/sql_to_slack_webhook.py b/airflow/providers/slack/transfers/sql_to_slack_webhook.py
index 31700ed8b01e0..0293b684aec51 100644
--- a/airflow/providers/slack/transfers/sql_to_slack_webhook.py
+++ b/airflow/providers/slack/transfers/sql_to_slack_webhook.py
@@ -85,7 +85,7 @@ def __init__(
         slack_channel: str | None = None,
         slack_message: str,
         results_df_name: str = "results_df",
-        parameters: Iterable | Mapping[str, Any] | None = None,
+        parameters: list | tuple | Mapping[str, Any] | None = None,
         **kwargs,
     ) -> None:
         if slack_conn_id := kwargs.pop("slack_conn_id", None):
diff --git a/airflow/providers/trino/hooks/trino.py b/airflow/providers/trino/hooks/trino.py
index 1d1bbd1cd2729..03195fe4524e7 100644
--- a/airflow/providers/trino/hooks/trino.py
+++ b/airflow/providers/trino/hooks/trino.py
@@ -189,7 +189,7 @@ def get_pandas_df(self, sql: str = "", parameters: Iterable | Mapping[str, Any]
         column_descriptions = cursor.description
         if data:
             df = pd.DataFrame(data, **kwargs)
-            df.columns = [c[0] for c in column_descriptions]
+            df.rename(columns={n: c[0] for n, c in zip(df.columns, column_descriptions)}, inplace=True)
         else:
             df = pd.DataFrame(**kwargs)
         return df
diff --git a/tests/plugins/workday.py b/tests/plugins/workday.py
index 20363a69e7a4b..72202fcc5c75b 100644
--- a/tests/plugins/workday.py
+++ b/tests/plugins/workday.py
@@ -34,7 +34,7 @@
     holiday_calendar = USFederalHolidayCalendar()
 except ImportError:
     log.warning("Could not import pandas. Holidays will not be considered.")
-    holiday_calendar = None
+    holiday_calendar = None  # type: ignore[assignment]
 
 
 class AfterWorkdayTimetable(Timetable):
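
Note (illustration only, not part of the patch): several hunks above swap direct assignment to df.columns for df.rename(columns=..., inplace=True). The standalone sketch below shows that rename pattern in isolation; the cursor metadata and rows are made up for demonstration and do not come from the Airflow codebase.

import pandas as pd

# Fake DB-API cursor.description: 7-item tuples whose first element is the column name.
column_descriptions = [
    ("id", "bigint", None, None, None, None, None),
    ("name", "varchar", None, None, None, None, None),
]
data = [(1, "alpha"), (2, "beta")]

df = pd.DataFrame(data)  # column labels default to 0, 1
# Map each existing label to the name reported by the cursor, rather than
# assigning a plain list to df.columns.
df.rename(columns={n: c[0] for n, c in zip(df.columns, column_descriptions)}, inplace=True)
print(df.columns.tolist())  # ['id', 'name']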