Avoid importing pandas and numpy at runtime and at module level (#33483)
* Avoid importing pandas and numpy at runtime; import them in the methods that use them instead of at module level

* Fix salesforce tests
hussein-awala committed Aug 18, 2023
1 parent 996d8c5 commit ea8519c
Showing 9 changed files with 44 additions and 26 deletions.
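
The pattern applied across these files, in short: the pandas (or numpy) import moves under TYPE_CHECKING, where only static type checkers evaluate it, and the methods that actually need the library import it lazily on first call. A minimal sketch of the idea, with an illustrative hook name (not the exact Airflow code):

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    import pandas as pd  # evaluated by type checkers only, never at runtime


class ExampleHook:
    def get_pandas_df(self, sql: str) -> pd.DataFrame:
        try:
            # Deferred import: the cost is paid on first call, not at module import.
            import pandas as pd
        except ImportError as e:
            from airflow.exceptions import AirflowOptionalProviderFeatureException

            raise AirflowOptionalProviderFeatureException(e)

        return pd.DataFrame(self._fetch(sql))

    def _fetch(self, sql: str) -> list[dict]:
        return [{"value": 1}]  # stand-in for a real query

Because of `from __future__ import annotations`, the `pd.DataFrame` annotation is stored as a string and never evaluated at runtime, so the TYPE_CHECKING import satisfies type checkers without slowing module import.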
15 changes: 9 additions & 6 deletions airflow/providers/apache/hive/hooks/hive.py
@@ -26,16 +26,12 @@
 import warnings
 from collections import OrderedDict
 from tempfile import NamedTemporaryFile, TemporaryDirectory
-from typing import Any, Iterable, Mapping
+from typing import TYPE_CHECKING, Any, Iterable, Mapping

 from airflow.exceptions import AirflowProviderDeprecationWarning

-try:
+if TYPE_CHECKING:
     import pandas as pd
-except ImportError as e:
-    from airflow.exceptions import AirflowOptionalProviderFeatureException
-
-    raise AirflowOptionalProviderFeatureException(e)

 import csv

@@ -1055,6 +1051,13 @@ def get_pandas_df(  # type: ignore
         :return: pandas.DataFrame
         """
+        try:
+            import pandas as pd
+        except ImportError as e:
+            from airflow.exceptions import AirflowOptionalProviderFeatureException
+
+            raise AirflowOptionalProviderFeatureException(e)
+
         res = self.get_results(sql, schema=schema, hive_conf=hive_conf)
         df = pd.DataFrame(res["data"], columns=[c[0] for c in res["header"]], **kwargs)
         return df
6 changes: 4 additions & 2 deletions airflow/providers/exasol/hooks/exasol.py
@@ -18,14 +18,16 @@
 from __future__ import annotations

 from contextlib import closing
-from typing import Any, Callable, Iterable, Mapping, Sequence, TypeVar, overload
+from typing import TYPE_CHECKING, Any, Callable, Iterable, Mapping, Sequence, TypeVar, overload

-import pandas as pd
 import pyexasol
 from pyexasol import ExaConnection, ExaStatement

 from airflow.providers.common.sql.hooks.sql import DbApiHook, return_single_query_results

+if TYPE_CHECKING:
+    import pandas as pd
+
 T = TypeVar("T")


6 changes: 4 additions & 2 deletions airflow/providers/google/cloud/hooks/bigquery.py
@@ -28,9 +28,8 @@
 import warnings
 from copy import deepcopy
 from datetime import datetime, timedelta
-from typing import Any, Iterable, Mapping, NoReturn, Sequence, Union, cast
+from typing import TYPE_CHECKING, Any, Iterable, Mapping, NoReturn, Sequence, Union, cast

-import pandas as pd
 from aiohttp import ClientSession as ClientSession
 from gcloud.aio.bigquery import Job, Table as Table_async
 from google.api_core.page_iterator import HTTPIterator
@@ -69,6 +68,9 @@
 from airflow.utils.helpers import convert_camel_to_snake
 from airflow.utils.log.logging_mixin import LoggingMixin

+if TYPE_CHECKING:
+    import pandas as pd
+
 log = logging.getLogger(__name__)

 BigQueryJob = Union[CopyJob, QueryJob, LoadJob, ExtractJob]
6 changes: 5 additions & 1 deletion airflow/providers/influxdb/hooks/influxdb.py
@@ -24,7 +24,8 @@
 """
 from __future__ import annotations

-import pandas as pd
+from typing import TYPE_CHECKING
+
 from influxdb_client import InfluxDBClient
 from influxdb_client.client.flux_table import FluxTable
 from influxdb_client.client.write.point import Point
@@ -33,6 +34,9 @@
 from airflow.hooks.base import BaseHook
 from airflow.models import Connection

+if TYPE_CHECKING:
+    import pandas as pd
+

 class InfluxDBHook(BaseHook):
     """Interact with InfluxDB.
11 changes: 5 additions & 6 deletions airflow/providers/oracle/hooks/oracle.py
@@ -24,12 +24,6 @@
 import oracledb

 from airflow.exceptions import AirflowProviderDeprecationWarning
-
-try:
-    import numpy
-except ImportError:
-    numpy = None  # type: ignore
-
 from airflow.providers.common.sql.hooks.sql import DbApiHook

 PARAM_TYPES = {bool, float, int, str}
@@ -280,6 +274,11 @@ def insert_rows(
             Set 1 to insert each row in each single transaction
         :param replace: Whether to replace instead of insert
         """
+        try:
+            import numpy
+        except ImportError:
+            numpy = None  # type: ignore
+
         if target_fields:
            target_fields = ", ".join(target_fields)
            target_fields = f"({target_fields})"
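
Moving the numpy fallback into insert_rows keeps the hook importable and the method usable without numpy installed; code further down can then guard on the name before touching numpy types. A sketch of that guard pattern under the same fallback (not the hook's exact serialization code):

from datetime import datetime

try:
    import numpy
except ImportError:
    numpy = None  # type: ignore


def serialize_cell(cell):
    # Short-circuit on `numpy` so the isinstance check is skipped entirely
    # when numpy is not installed.
    if numpy is not None and isinstance(cell, numpy.datetime64):
        return str(cell)
    if isinstance(cell, datetime):
        return cell.isoformat()
    return cell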
10 changes: 8 additions & 2 deletions airflow/providers/salesforce/hooks/salesforce.py
@@ -26,14 +26,16 @@
 import logging
 import time
 from functools import cached_property
-from typing import Any, Iterable
+from typing import TYPE_CHECKING, Any, Iterable

-import pandas as pd
 from requests import Session
 from simple_salesforce import Salesforce, api

 from airflow.hooks.base import BaseHook

+if TYPE_CHECKING:
+    import pandas as pd
+
 log = logging.getLogger(__name__)


@@ -240,6 +242,8 @@ def _to_timestamp(cls, column: pd.Series) -> pd.Series:
         # between 0 and 10 are turned into timestamps
         # if the column cannot be converted,
         # just return the original column untouched
+        import pandas as pd
+
         try:
             column = pd.to_datetime(column)
         except ValueError:
@@ -355,6 +359,8 @@ def object_to_df(
             to the resulting data that marks when the data was fetched from Salesforce. Default: False
         :return: the dataframe.
         """
+        import pandas as pd
+
         # this line right here will convert all integers to floats
         # if there are any None/np.nan values in the column
         # that's because None/np.nan cannot exist in an integer column
3 changes: 2 additions & 1 deletion airflow/providers/slack/transfers/sql_to_slack.py
@@ -19,7 +19,6 @@
 from tempfile import NamedTemporaryFile
 from typing import TYPE_CHECKING, Any, Iterable, Mapping, Sequence

-import pandas as pd
 from tabulate import tabulate

 from airflow.exceptions import AirflowException
@@ -31,6 +30,8 @@
 from airflow.providers.slack.utils import parse_filename

 if TYPE_CHECKING:
+    import pandas as pd
+
     from airflow.utils.context import Context


10 changes: 5 additions & 5 deletions tests/providers/salesforce/hooks/test_salesforce.py
@@ -338,7 +338,7 @@ def test_write_object_to_file_invalid_format(self):
         self.salesforce_hook.write_object_to_file(query_results=[], filename="test", fmt="test")

     @patch(
-        "airflow.providers.salesforce.hooks.salesforce.pd.DataFrame.from_records",
+        "pandas.DataFrame.from_records",
         return_value=pd.DataFrame({"test": [1, 2, 3], "dict": [nan, nan, {"foo": "bar"}]}),
     )
     def test_write_object_to_file_csv(self, mock_data_frame):
@@ -360,7 +360,7 @@ def test_write_object_to_file_csv(self, mock_data_frame):
         return_value={"fields": [{"name": "field_1", "type": "date"}]},
     )
     @patch(
-        "airflow.providers.salesforce.hooks.salesforce.pd.DataFrame.from_records",
+        "pandas.DataFrame.from_records",
         return_value=pd.DataFrame({"test": [1, 2, 3], "field_1": ["2019-01-01", "2019-01-02", "2019-01-03"]}),
     )
     def test_write_object_to_file_json_with_timestamp_conversion(self, mock_data_frame, mock_describe_object):
@@ -383,7 +383,7 @@ def test_write_object_to_file_json_with_timestamp_conversion(self, mock_data_frame, mock_describe_object):

     @patch("airflow.providers.salesforce.hooks.salesforce.time.time", return_value=1.23)
     @patch(
-        "airflow.providers.salesforce.hooks.salesforce.pd.DataFrame.from_records",
+        "pandas.DataFrame.from_records",
         return_value=pd.DataFrame({"test": [1, 2, 3]}),
     )
     def test_write_object_to_file_ndjson_with_record_time(self, mock_data_frame, mock_time):
@@ -416,7 +416,7 @@ def test_write_object_to_file_ndjson_with_record_time(self, mock_data_frame, mock_time):
         return_value={"fields": [{"name": "field_1", "type": "date"}]},
     )
     @patch(
-        "airflow.providers.salesforce.hooks.salesforce.pd.DataFrame.from_records",
+        "pandas.DataFrame.from_records",
         return_value=pd.DataFrame({"test": [1, 2, 3], "field_1": ["2019-01-01", "2019-01-02", "2019-01-03"]}),
     )
     def test_object_to_df_with_timestamp_conversion(self, mock_data_frame, mock_describe_object):
@@ -434,7 +434,7 @@ def test_object_to_df_with_timestamp_conversion(self, mock_data_frame, mock_describe_object):

     @patch("airflow.providers.salesforce.hooks.salesforce.time.time", return_value=1.23)
     @patch(
-        "airflow.providers.salesforce.hooks.salesforce.pd.DataFrame.from_records",
+        "pandas.DataFrame.from_records",
         return_value=pd.DataFrame({"test": [1, 2, 3]}),
     )
     def test_object_to_df_with_record_time(self, mock_data_frame, mock_time):
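Why the patch targets above changed: with the module-level import gone, airflow.providers.salesforce.hooks.salesforce no longer has a pd attribute for mock to resolve, so the old target fails at patch time. Patching pandas.DataFrame.from_records on pandas itself still intercepts the call, because the hook's deferred import looks the attribute up when the method runs. A minimal sketch of the mechanism (the test name is illustrative):

from unittest.mock import patch

import pandas as pd


@patch("pandas.DataFrame.from_records", return_value=pd.DataFrame({"test": [1, 2, 3]}))
def test_patch_seen_by_deferred_import(mock_from_records):
    # What the hook effectively does inside the method body:
    import pandas

    df = pandas.DataFrame.from_records([("ignored",)])  # returns the mocked frame
    assert list(df["test"]) == [1, 2, 3]
    mock_from_records.assert_called_once()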
3 changes: 2 additions & 1 deletion tests/serialization/serializers/test_serializers.py
@@ -20,7 +20,6 @@
 import decimal

 import numpy
-import pandas as pd
 import pendulum.tz
 import pytest
 from pendulum import DateTime
@@ -94,6 +93,8 @@ def test_params(self):
         assert i["x"] == d["x"]

     def test_pandas(self):
+        import pandas as pd
+
         i = pd.DataFrame(data={"col1": [1, 2], "col2": [3, 4]})
         e = serialize(i)
         d = deserialize(e)
