
Commit

Consolidate import and usage of pandas (#33480)
eumiro committed Aug 17, 2023
1 parent 47187ce commit 8e88eb8
Showing 8 changed files with 27 additions and 27 deletions.
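Every hunk in this commit follows the same pattern: the per-name imports (`from pandas import DataFrame`, `import pandas`) are replaced with the conventional alias `import pandas as pd`, and annotations and constructors switch to the `pd.` prefix. Where pandas is optional, the alias lands under `if TYPE_CHECKING:` or inside the existing `try:` block, so runtime import behaviour is unchanged. A minimal sketch of the convention being standardized (the helper below is hypothetical, not part of the commit):

```python
import pandas as pd  # single canonical alias used across all touched modules


def rows_to_df(rows: list[tuple[int, str]], columns: list[str]) -> pd.DataFrame:
    """Hypothetical helper: both the annotation and the constructor use pd."""
    return pd.DataFrame(rows, columns=columns)


df = rows_to_df([(1, "alpha"), (2, "beta")], ["id", "name"])
print(df.dtypes)  # id: int64, name: object
```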
12 changes: 6 additions & 6 deletions airflow/providers/amazon/aws/transfers/sql_to_s3.py
@@ -29,7 +29,7 @@
from airflow.providers.common.sql.hooks.sql import DbApiHook

if TYPE_CHECKING:
-from pandas import DataFrame
+import pandas as pd

from airflow.utils.context import Context

@@ -134,15 +134,15 @@ def __init__(
raise AirflowException(f"The argument file_format doesn't support {file_format} value.")

@staticmethod
-def _fix_dtypes(df: DataFrame, file_format: FILE_FORMAT) -> None:
+def _fix_dtypes(df: pd.DataFrame, file_format: FILE_FORMAT) -> None:
"""
Mutate DataFrame to set dtypes for float columns containing NaN values.
Set dtype of object to str to allow for downstream transformations.
"""
try:
import numpy as np
-from pandas import Float64Dtype, Int64Dtype
+import pandas as pd
except ImportError as e:
from airflow.exceptions import AirflowOptionalProviderFeatureException

@@ -163,13 +163,13 @@ def _fix_dtypes(df: DataFrame, file_format: FILE_FORMAT) -> None:
# The type ignore can be removed here if https://github.com/numpy/numpy/pull/23690
# is merged and released as currently NumPy does not consider None as valid for x/y.
df[col] = np.where(df[col].isnull(), None, df[col]) # type: ignore[call-overload]
-df[col] = df[col].astype(Int64Dtype())
+df[col] = df[col].astype(pd.Int64Dtype())
elif np.isclose(notna_series, notna_series.astype(int)).all():
# set to float dtype that retains floats and supports NaNs
# The type ignore can be removed here if https://github.com/numpy/numpy/pull/23690
# is merged and released
df[col] = np.where(df[col].isnull(), None, df[col]) # type: ignore[call-overload]
-df[col] = df[col].astype(Float64Dtype())
+df[col] = df[col].astype(pd.Float64Dtype())

def execute(self, context: Context) -> None:
sql_hook = self._get_hook()
Expand All @@ -192,7 +192,7 @@ def execute(self, context: Context) -> None:
filename=tmp_file.name, key=object_key, bucket_name=self.s3_bucket, replace=self.replace
)

-def _partition_dataframe(self, df: DataFrame) -> Iterable[tuple[str, DataFrame]]:
+def _partition_dataframe(self, df: pd.DataFrame) -> Iterable[tuple[str, pd.DataFrame]]:
"""Partition dataframe using pandas groupby() method."""
if not self.groupby_kwargs:
yield "", df
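The `_fix_dtypes` hunk above is where the alias matters most: integer-like float columns that picked up NaN from SQL NULLs are moved to pandas' nullable extension dtypes, now spelled `pd.Int64Dtype()` / `pd.Float64Dtype()`. A minimal sketch of that conversion outside Airflow, with a simplified check but the same idea:

```python
import numpy as np
import pandas as pd

# A float64 column that is really integers plus a missing value (a SQL NULL).
df = pd.DataFrame({"col": [1.0, 2.0, np.nan]})

notna = df["col"].dropna()
if (notna == notna.astype(int)).all():
    # Swap NaN for None, then adopt the nullable Int64 extension dtype so the
    # integer values survive the missing entry instead of staying float64.
    df["col"] = np.where(df["col"].isnull(), None, df["col"])
    df["col"] = df["col"].astype(pd.Int64Dtype())

print(df["col"].dtype)  # Int64
```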
10 changes: 5 additions & 5 deletions airflow/providers/apache/hive/hooks/hive.py
@@ -31,7 +31,7 @@
from airflow.exceptions import AirflowProviderDeprecationWarning

try:
-import pandas
+import pandas as pd
except ImportError as e:
from airflow.exceptions import AirflowOptionalProviderFeatureException

@@ -336,7 +336,7 @@ def test_hql(self, hql: str) -> None:

def load_df(
self,
-df: pandas.DataFrame,
+df: pd.DataFrame,
table: str,
field_dict: dict[Any, Any] | None = None,
delimiter: str = ",",
@@ -361,7 +361,7 @@ def load_df(
:param kwargs: passed to self.load_file
"""

-def _infer_field_types_from_df(df: pandas.DataFrame) -> dict[Any, Any]:
+def _infer_field_types_from_df(df: pd.DataFrame) -> dict[Any, Any]:
dtype_kind_hive_type = {
"b": "BOOLEAN", # boolean
"i": "BIGINT", # signed integer
@@ -1037,7 +1037,7 @@ def get_pandas_df( # type: ignore
schema: str = "default",
hive_conf: dict[Any, Any] | None = None,
**kwargs,
-) -> pandas.DataFrame:
+) -> pd.DataFrame:
"""
Get a pandas dataframe from a Hive query.
@@ -1056,5 +1056,5 @@ def get_pandas_df( # type: ignore
:return: pandas.DataFrame
"""
res = self.get_results(sql, schema=schema, hive_conf=hive_conf)
-df = pandas.DataFrame(res["data"], columns=[c[0] for c in res["header"]], **kwargs)
+df = pd.DataFrame(res["data"], columns=[c[0] for c in res["header"]], **kwargs)
return df
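For context, `load_df` in this hook infers Hive column types from the frame's numpy dtype kinds; the hunk above shows the start of that mapping ("b" -> BOOLEAN, "i" -> BIGINT). A rough sketch of the idea, with an illustrative and deliberately incomplete mapping:

```python
import pandas as pd

# Illustrative subset of a dtype.kind -> Hive type mapping; only the first two
# entries are visible in the hunk above, the rest are assumptions.
DTYPE_KIND_HIVE_TYPE = {
    "b": "BOOLEAN",  # boolean
    "i": "BIGINT",   # signed integer
    "f": "DOUBLE",   # floating point
    "O": "STRING",   # object, typically strings
}

df = pd.DataFrame({"id": [1, 2], "price": [9.5, 3.25], "name": ["a", "b"]})
field_dict = {
    col: DTYPE_KIND_HIVE_TYPE.get(dtype.kind, "STRING") for col, dtype in df.dtypes.items()
}
print(field_dict)  # {'id': 'BIGINT', 'price': 'DOUBLE', 'name': 'STRING'}
```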
4 changes: 2 additions & 2 deletions airflow/providers/google/cloud/hooks/bigquery.py
@@ -30,6 +30,7 @@
from datetime import datetime, timedelta
from typing import Any, Iterable, Mapping, NoReturn, Sequence, Union, cast

+import pandas as pd
from aiohttp import ClientSession as ClientSession
from gcloud.aio.bigquery import Job, Table as Table_async
from google.api_core.page_iterator import HTTPIterator
@@ -49,7 +50,6 @@
from google.cloud.bigquery.table import EncryptionConfiguration, Row, RowIterator, Table, TableReference
from google.cloud.exceptions import NotFound
from googleapiclient.discovery import Resource, build
-from pandas import DataFrame
from pandas_gbq import read_gbq
from pandas_gbq.gbq import GbqConnector # noqa
from requests import Session
@@ -244,7 +244,7 @@ def get_pandas_df(
parameters: Iterable | Mapping[str, Any] | None = None,
dialect: str | None = None,
**kwargs,
-) -> DataFrame:
+) -> pd.DataFrame:
"""Get a Pandas DataFrame for the BigQuery results.
The DbApiHook method must be overridden because Pandas doesn't support
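`BigQueryHook.get_pandas_df` overrides the generic DbApiHook implementation and builds the frame through pandas-gbq (`read_gbq` is imported a few lines above the removed pandas import). A hedged sketch of that delegation; the project id and default dialect below are placeholders rather than the hook's real connection wiring:

```python
import pandas as pd
from pandas_gbq import read_gbq


def get_pandas_df(sql: str, dialect: str = "standard") -> pd.DataFrame:
    # Placeholder project id; the hook resolves this from its GCP connection.
    return read_gbq(sql, project_id="my-gcp-project", dialect=dialect)
```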
6 changes: 3 additions & 3 deletions airflow/providers/presto/hooks/presto.py
@@ -158,7 +158,7 @@ def get_first(
raise PrestoException(e)

def get_pandas_df(self, sql: str = "", parameters=None, **kwargs):
-import pandas
+import pandas as pd

cursor = self.get_cursor()
try:
@@ -168,10 +168,10 @@ def get_pandas_df(self, sql: str = "", parameters=None, **kwargs):
raise PrestoException(e)
column_descriptions = cursor.description
if data:
-df = pandas.DataFrame(data, **kwargs)
+df = pd.DataFrame(data, **kwargs)
df.columns = [c[0] for c in column_descriptions]
else:
-df = pandas.DataFrame(**kwargs)
+df = pd.DataFrame(**kwargs)
return df

def insert_rows(
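The Presto hook (and the Trino hook further down, which is nearly identical) builds the frame straight from a DB-API cursor: fetch all rows, construct `pd.DataFrame(data)`, then name the columns from `cursor.description`. A self-contained sketch of the same pattern using sqlite3 as a stand-in connection; the PrestoException/TrinoException error wrapping is omitted:

```python
import sqlite3

import pandas as pd

conn = sqlite3.connect(":memory:")
cursor = conn.cursor()
cursor.execute("SELECT 1 AS a, 'x' AS b UNION ALL SELECT 2, 'y'")
data = cursor.fetchall()

if data:
    df = pd.DataFrame(data)
    # DB-API puts the column name first in each description tuple.
    df.columns = [c[0] for c in cursor.description]
else:
    # No rows: return an empty frame, matching the hook's fallback branch.
    df = pd.DataFrame()
print(df)
```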
4 changes: 2 additions & 2 deletions airflow/providers/slack/transfers/sql_to_slack.py
@@ -19,7 +19,7 @@
from tempfile import NamedTemporaryFile
from typing import TYPE_CHECKING, Any, Iterable, Mapping, Sequence

-from pandas import DataFrame
+import pandas as pd
from tabulate import tabulate

from airflow.exceptions import AirflowException
@@ -70,7 +70,7 @@ def _get_hook(self) -> DbApiHook:
)
return hook

-def _get_query_results(self) -> DataFrame:
+def _get_query_results(self) -> pd.DataFrame:
sql_hook = self._get_hook()

self.log.info("Running SQL query: %s", self.sql)
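In this operator the frame returned by `_get_query_results` is later rendered as a text table via `tabulate` (imported right after the pandas import above). A small sketch of that rendering step; the exact tabulate options used by the operator are an assumption here:

```python
import pandas as pd
from tabulate import tabulate

df = pd.DataFrame({"state": ["queued", "running"], "count": [3, 7]})
# headers="keys" takes the column names from the frame; showindex=False drops
# the positional index so the message only shows real columns.
message = tabulate(df, headers="keys", tablefmt="pretty", showindex=False)
print(message)
```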
6 changes: 3 additions & 3 deletions airflow/providers/trino/hooks/trino.py
@@ -178,7 +178,7 @@ def get_first(
def get_pandas_df(
self, sql: str = "", parameters: Iterable | Mapping[str, Any] | None = None, **kwargs
): # type: ignore[override]
-import pandas
+import pandas as pd

cursor = self.get_cursor()
try:
@@ -188,10 +188,10 @@ def get_pandas_df(
raise TrinoException(e)
column_descriptions = cursor.description
if data:
-df = pandas.DataFrame(data, **kwargs)
+df = pd.DataFrame(data, **kwargs)
df.columns = [c[0] for c in column_descriptions]
else:
-df = pandas.DataFrame(**kwargs)
+df = pd.DataFrame(**kwargs)
return df

def insert_rows(
8 changes: 4 additions & 4 deletions airflow/serialization/serializers/pandas.py
@@ -28,19 +28,19 @@
deserializers = serializers

if TYPE_CHECKING:
-from pandas import DataFrame
+import pandas as pd

from airflow.serialization.serde import U

__version__ = 1


def serialize(o: object) -> tuple[U, str, int, bool]:
+import pandas as pd
import pyarrow as pa
-from pandas import DataFrame
from pyarrow import parquet as pq

-if not isinstance(o, DataFrame):
+if not isinstance(o, pd.DataFrame):
return "", "", 0, False

# for now, we *always* serialize into in memory
@@ -53,7 +53,7 @@ def serialize(o: object) -> tuple[U, str, int, bool]:
return buf.getvalue().hex().decode("utf-8"), qualname(o), __version__, True


-def deserialize(classname: str, version: int, data: object) -> DataFrame:
+def deserialize(classname: str, version: int, data: object) -> pd.DataFrame:
if version > __version__:
raise TypeError(f"serialized {version} of {classname} > {__version__}")

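The serializer above round-trips a DataFrame through Arrow and Parquet held entirely in memory, hex-encoding the bytes so they can travel inside serialized task data. A hedged sketch of that round trip; the buffer handling is simplified compared to the real `serialize`/`deserialize` pair:

```python
import io

import pandas as pd
import pyarrow as pa
from pyarrow import parquet as pq

df = pd.DataFrame({"col1": [1, 2], "col2": [3, 4]})

# DataFrame -> Arrow table -> in-memory Parquet -> hex string.
sink = io.BytesIO()
pq.write_table(pa.Table.from_pandas(df), sink)
payload = sink.getvalue().hex()

# Hex string -> Parquet bytes -> Arrow table -> DataFrame.
restored = pq.read_table(pa.BufferReader(bytes.fromhex(payload))).to_pandas()
assert df.equals(restored)
```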
4 changes: 2 additions & 2 deletions tests/serialization/serializers/test_serializers.py
@@ -20,7 +20,7 @@
import decimal

import numpy
-import pandas
+import pandas as pd
import pendulum.tz
import pytest
from pendulum import DateTime
@@ -94,7 +94,7 @@ def test_params(self):
assert i["x"] == d["x"]

def test_pandas(self):
-i = pandas.DataFrame(data={"col1": [1, 2], "col2": [3, 4]})
+i = pd.DataFrame(data={"col1": [1, 2], "col2": [3, 4]})
e = serialize(i)
d = deserialize(e)
assert i.equals(d)
