From 8e88eb8fa7e1fc12918dcbfcfc8ed28381008d33 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Miroslav=20=C5=A0ediv=C3=BD?= <6774676+eumiro@users.noreply.github.com>
Date: Thu, 17 Aug 2023 19:57:55 +0000
Subject: [PATCH] Consolidate import and usage of pandas (#33480)

---
 airflow/providers/amazon/aws/transfers/sql_to_s3.py  | 12 ++++++------
 airflow/providers/apache/hive/hooks/hive.py          | 10 +++++-----
 airflow/providers/google/cloud/hooks/bigquery.py     |  4 ++--
 airflow/providers/presto/hooks/presto.py             |  6 +++---
 airflow/providers/slack/transfers/sql_to_slack.py    |  4 ++--
 airflow/providers/trino/hooks/trino.py               |  6 +++---
 airflow/serialization/serializers/pandas.py          |  8 ++++----
 tests/serialization/serializers/test_serializers.py  |  4 ++--
 8 files changed, 27 insertions(+), 27 deletions(-)

diff --git a/airflow/providers/amazon/aws/transfers/sql_to_s3.py b/airflow/providers/amazon/aws/transfers/sql_to_s3.py
index 92c1906629f64..1302927bfd5fc 100644
--- a/airflow/providers/amazon/aws/transfers/sql_to_s3.py
+++ b/airflow/providers/amazon/aws/transfers/sql_to_s3.py
@@ -29,7 +29,7 @@
 from airflow.providers.common.sql.hooks.sql import DbApiHook

 if TYPE_CHECKING:
-    from pandas import DataFrame
+    import pandas as pd

     from airflow.utils.context import Context

@@ -134,7 +134,7 @@ def __init__(
             raise AirflowException(f"The argument file_format doesn't support {file_format} value.")

     @staticmethod
-    def _fix_dtypes(df: DataFrame, file_format: FILE_FORMAT) -> None:
+    def _fix_dtypes(df: pd.DataFrame, file_format: FILE_FORMAT) -> None:
         """
         Mutate DataFrame to set dtypes for float columns containing NaN values.

@@ -142,7 +142,7 @@ def _fix_dtypes(df: DataFrame, file_format: FILE_FORMAT) -> None:
         """
         try:
             import numpy as np
-            from pandas import Float64Dtype, Int64Dtype
+            import pandas as pd
         except ImportError as e:
             from airflow.exceptions import AirflowOptionalProviderFeatureException

@@ -163,13 +163,13 @@ def _fix_dtypes(df: DataFrame, file_format: FILE_FORMAT) -> None:
                     # The type ignore can be removed here if https://github.com/numpy/numpy/pull/23690
                     # is merged and released as currently NumPy does not consider None as valid for x/y.
                     df[col] = np.where(df[col].isnull(), None, df[col])  # type: ignore[call-overload]
-                    df[col] = df[col].astype(Int64Dtype())
+                    df[col] = df[col].astype(pd.Int64Dtype())
                 elif np.isclose(notna_series, notna_series.astype(int)).all():
                     # set to float dtype that retains floats and supports NaNs
                     # The type ignore can be removed here if https://github.com/numpy/numpy/pull/23690
                     # is merged and released
                     df[col] = np.where(df[col].isnull(), None, df[col])  # type: ignore[call-overload]
-                    df[col] = df[col].astype(Float64Dtype())
+                    df[col] = df[col].astype(pd.Float64Dtype())

     def execute(self, context: Context) -> None:
         sql_hook = self._get_hook()
@@ -192,7 +192,7 @@ def execute(self, context: Context) -> None:
                 filename=tmp_file.name, key=object_key, bucket_name=self.s3_bucket, replace=self.replace
             )

-    def _partition_dataframe(self, df: DataFrame) -> Iterable[tuple[str, DataFrame]]:
+    def _partition_dataframe(self, df: pd.DataFrame) -> Iterable[tuple[str, pd.DataFrame]]:
         """Partition dataframe using pandas groupby() method."""
         if not self.groupby_kwargs:
             yield "", df
diff --git a/airflow/providers/apache/hive/hooks/hive.py b/airflow/providers/apache/hive/hooks/hive.py
index 9ab3c8dd020f6..3282d45aca320 100644
--- a/airflow/providers/apache/hive/hooks/hive.py
+++ b/airflow/providers/apache/hive/hooks/hive.py
@@ -31,7 +31,7 @@
 from airflow.exceptions import AirflowProviderDeprecationWarning

 try:
-    import pandas
+    import pandas as pd
 except ImportError as e:
     from airflow.exceptions import AirflowOptionalProviderFeatureException

@@ -336,7 +336,7 @@ def test_hql(self, hql: str) -> None:

     def load_df(
         self,
-        df: pandas.DataFrame,
+        df: pd.DataFrame,
         table: str,
         field_dict: dict[Any, Any] | None = None,
         delimiter: str = ",",
@@ -361,7 +361,7 @@ def load_df(
         :param kwargs: passed to self.load_file
         """

-        def _infer_field_types_from_df(df: pandas.DataFrame) -> dict[Any, Any]:
+        def _infer_field_types_from_df(df: pd.DataFrame) -> dict[Any, Any]:
             dtype_kind_hive_type = {
                 "b": "BOOLEAN",  # boolean
                 "i": "BIGINT",  # signed integer
@@ -1037,7 +1037,7 @@ def get_pandas_df(  # type: ignore
         schema: str = "default",
         hive_conf: dict[Any, Any] | None = None,
         **kwargs,
-    ) -> pandas.DataFrame:
+    ) -> pd.DataFrame:
         """
         Get a pandas dataframe from a Hive query.

@@ -1056,5 +1056,5 @@ def get_pandas_df(  # type: ignore
         :return: pandas.DataFrame
         """
         res = self.get_results(sql, schema=schema, hive_conf=hive_conf)
-        df = pandas.DataFrame(res["data"], columns=[c[0] for c in res["header"]], **kwargs)
+        df = pd.DataFrame(res["data"], columns=[c[0] for c in res["header"]], **kwargs)
         return df
diff --git a/airflow/providers/google/cloud/hooks/bigquery.py b/airflow/providers/google/cloud/hooks/bigquery.py
index 2ac6c2645e20c..44185963b5684 100644
--- a/airflow/providers/google/cloud/hooks/bigquery.py
+++ b/airflow/providers/google/cloud/hooks/bigquery.py
@@ -30,6 +30,7 @@
 from datetime import datetime, timedelta
 from typing import Any, Iterable, Mapping, NoReturn, Sequence, Union, cast

+import pandas as pd
 from aiohttp import ClientSession as ClientSession
 from gcloud.aio.bigquery import Job, Table as Table_async
 from google.api_core.page_iterator import HTTPIterator
@@ -49,7 +50,6 @@
 from google.cloud.bigquery.table import EncryptionConfiguration, Row, RowIterator, Table, TableReference
 from google.cloud.exceptions import NotFound
 from googleapiclient.discovery import Resource, build
-from pandas import DataFrame
 from pandas_gbq import read_gbq
 from pandas_gbq.gbq import GbqConnector  # noqa
 from requests import Session
@@ -244,7 +244,7 @@ def get_pandas_df(
         parameters: Iterable | Mapping[str, Any] | None = None,
         dialect: str | None = None,
         **kwargs,
-    ) -> DataFrame:
+    ) -> pd.DataFrame:
         """Get a Pandas DataFrame for the BigQuery results.

         The DbApiHook method must be overridden because Pandas doesn't support
diff --git a/airflow/providers/presto/hooks/presto.py b/airflow/providers/presto/hooks/presto.py
index 7ce9021807893..028deb48ed1e1 100644
--- a/airflow/providers/presto/hooks/presto.py
+++ b/airflow/providers/presto/hooks/presto.py
@@ -158,7 +158,7 @@ def get_first(
             raise PrestoException(e)

     def get_pandas_df(self, sql: str = "", parameters=None, **kwargs):
-        import pandas
+        import pandas as pd

         cursor = self.get_cursor()
         try:
@@ -168,10 +168,10 @@ def get_pandas_df(self, sql: str = "", parameters=None, **kwargs):
             raise PrestoException(e)
         column_descriptions = cursor.description
         if data:
-            df = pandas.DataFrame(data, **kwargs)
+            df = pd.DataFrame(data, **kwargs)
             df.columns = [c[0] for c in column_descriptions]
         else:
-            df = pandas.DataFrame(**kwargs)
+            df = pd.DataFrame(**kwargs)
         return df

     def insert_rows(
diff --git a/airflow/providers/slack/transfers/sql_to_slack.py b/airflow/providers/slack/transfers/sql_to_slack.py
index 97017c80d7e90..ba72e689ee48f 100644
--- a/airflow/providers/slack/transfers/sql_to_slack.py
+++ b/airflow/providers/slack/transfers/sql_to_slack.py
@@ -19,7 +19,7 @@
 from tempfile import NamedTemporaryFile
 from typing import TYPE_CHECKING, Any, Iterable, Mapping, Sequence

-from pandas import DataFrame
+import pandas as pd
 from tabulate import tabulate

 from airflow.exceptions import AirflowException
@@ -70,7 +70,7 @@ def _get_hook(self) -> DbApiHook:
         )
         return hook

-    def _get_query_results(self) -> DataFrame:
+    def _get_query_results(self) -> pd.DataFrame:
         sql_hook = self._get_hook()
         self.log.info("Running SQL query: %s", self.sql)

diff --git a/airflow/providers/trino/hooks/trino.py b/airflow/providers/trino/hooks/trino.py
index 5144978dab336..14461b727d998 100644
--- a/airflow/providers/trino/hooks/trino.py
+++ b/airflow/providers/trino/hooks/trino.py
@@ -178,7 +178,7 @@ def get_first(
     def get_pandas_df(
         self, sql: str = "", parameters: Iterable | Mapping[str, Any] | None = None, **kwargs
     ):  # type: ignore[override]
-        import pandas
+        import pandas as pd

         cursor = self.get_cursor()
         try:
@@ -188,10 +188,10 @@ def get_pandas_df(
             raise TrinoException(e)
         column_descriptions = cursor.description
         if data:
-            df = pandas.DataFrame(data, **kwargs)
+            df = pd.DataFrame(data, **kwargs)
             df.columns = [c[0] for c in column_descriptions]
         else:
-            df = pandas.DataFrame(**kwargs)
+            df = pd.DataFrame(**kwargs)
         return df

     def insert_rows(
diff --git a/airflow/serialization/serializers/pandas.py b/airflow/serialization/serializers/pandas.py
index 0fd9ae04dc9c1..efdc8e11da419 100644
--- a/airflow/serialization/serializers/pandas.py
+++ b/airflow/serialization/serializers/pandas.py
@@ -28,7 +28,7 @@
 deserializers = serializers

 if TYPE_CHECKING:
-    from pandas import DataFrame
+    import pandas as pd

     from airflow.serialization.serde import U

@@ -36,11 +36,11 @@


 def serialize(o: object) -> tuple[U, str, int, bool]:
+    import pandas as pd
     import pyarrow as pa
-    from pandas import DataFrame
     from pyarrow import parquet as pq

-    if not isinstance(o, DataFrame):
+    if not isinstance(o, pd.DataFrame):
         return "", "", 0, False

     # for now, we *always* serialize into in memory
@@ -53,7 +53,7 @@ def serialize(o: object) -> tuple[U, str, int, bool]:
     return buf.getvalue().hex().decode("utf-8"), qualname(o), __version__, True


-def deserialize(classname: str, version: int, data: object) -> DataFrame:
+def deserialize(classname: str, version: int, data: object) -> pd.DataFrame:
     if version > __version__:
         raise TypeError(f"serialized {version} of {classname} > {__version__}")

diff --git a/tests/serialization/serializers/test_serializers.py b/tests/serialization/serializers/test_serializers.py
index 79000bea17d7c..e9805d4d777ff 100644
--- a/tests/serialization/serializers/test_serializers.py
+++ b/tests/serialization/serializers/test_serializers.py
@@ -20,7 +20,7 @@
 import decimal

 import numpy
-import pandas
+import pandas as pd
 import pendulum.tz
 import pytest
 from pendulum import DateTime
@@ -94,7 +94,7 @@ def test_params(self):
         assert i["x"] == d["x"]

     def test_pandas(self):
-        i = pandas.DataFrame(data={"col1": [1, 2], "col2": [3, 4]})
+        i = pd.DataFrame(data={"col1": [1, 2], "col2": [3, 4]})
         e = serialize(i)
         d = deserialize(e)
         assert i.equals(d)
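
The pattern the patch standardizes on can be sketched in isolation: a
TYPE_CHECKING-guarded "import pandas as pd" serves the annotations, while
functions that actually build a DataFrame import pandas lazily under the
same alias. This is a minimal illustration, assuming pandas is an optional
dependency as in the provider hooks above; the row_count function is
hypothetical and not part of the Airflow codebase.

    from __future__ import annotations

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Resolved only by static type checkers; pandas is not imported at
        # runtime, so this module stays importable without pandas installed.
        import pandas as pd


    def row_count(df: pd.DataFrame) -> pd.DataFrame:
        """Return a one-row DataFrame holding the row count of ``df``."""
        # Hypothetical example. Import pandas lazily, under the same "pd"
        # alias as the TYPE_CHECKING import, so the annotation and the
        # runtime usage read identically.
        import pandas as pd

        return pd.DataFrame({"rows": [len(df)]})

Using one alias in both places is what lets hunks like the sql_to_s3.py and
serializers/pandas.py changes above replace "from pandas import DataFrame"
without touching how the type is written, and the deferred runtime import
preserves the optional-dependency handling visible in the try/except
ImportError blocks.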