Skip to content

Commit

Permalink
Fix update-common-sql-api-stubs pre-commit check (#38915)
Browse files Browse the repository at this point in the history
(cherry picked from commit 4f169bd)
  • Loading branch information
Taragolis authored and ephraimbuddy committed Apr 16, 2024
1 parent a1a565e commit eda7481
Show file tree
Hide file tree
Showing 4 changed files with 137 additions and 87 deletions.
85 changes: 56 additions & 29 deletions airflow/providers/common/sql/hooks/sql.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -32,71 +32,98 @@ Definition of the public interface for airflow.providers.common.sql.hooks.sql
isort:skip_file
"""
from _typeshed import Incomplete
from airflow.hooks.base import BaseHook as BaseForDbApiHook
from typing import Any, Callable, Iterable, Mapping, Sequence, Union
from typing_extensions import Protocol
from airflow.exceptions import (
AirflowException as AirflowException,
AirflowOptionalProviderFeatureException as AirflowOptionalProviderFeatureException,
AirflowProviderDeprecationWarning as AirflowProviderDeprecationWarning,
)
from airflow.hooks.base import BaseHook as BaseHook
from airflow.providers.openlineage.extractors import OperatorLineage as OperatorLineage
from airflow.providers.openlineage.sqlparser import DatabaseInfo as DatabaseInfo
from pandas import DataFrame as DataFrame
from typing import Any, Callable, Generator, Iterable, Mapping, Protocol, Sequence, TypeVar, overload

def return_single_query_results(
sql: Union[str, Iterable[str]], return_last: bool, split_statements: bool
): ...
def fetch_all_handler(cursor) -> Union[list[tuple], None]: ...
def fetch_one_handler(cursor) -> Union[list[tuple], None]: ...
T = TypeVar("T")
SQL_PLACEHOLDERS: Incomplete

def return_single_query_results(sql: str | Iterable[str], return_last: bool, split_statements: bool): ...
def fetch_all_handler(cursor) -> list[tuple] | None: ...
def fetch_one_handler(cursor) -> list[tuple] | None: ...

class ConnectorProtocol(Protocol):
def connect(self, host: str, port: int, username: str, schema: str) -> Any: ...

class DbApiHook(BaseForDbApiHook):
class DbApiHook(BaseHook):
conn_name_attr: str
default_conn_name: str
supports_autocommit: bool
connector: Union[ConnectorProtocol, None]
placeholder: str
connector: ConnectorProtocol | None
log_sql: Incomplete
descriptions: Incomplete
_placeholder: str
def __init__(self, *args, schema: Union[str, None] = ..., log_sql: bool = ..., **kwargs) -> None: ...
def __init__(self, *args, schema: str | None = None, log_sql: bool = True, **kwargs) -> None: ...
@property
def placeholder(self): ...
def get_conn(self): ...
def get_uri(self) -> str: ...
def get_sqlalchemy_engine(self, engine_kwargs: Incomplete | None = ...): ...
def get_pandas_df(self, sql, parameters: Incomplete | None = ..., **kwargs): ...
def get_sqlalchemy_engine(self, engine_kwargs: Incomplete | None = None): ...
def get_pandas_df(
self, sql, parameters: list | tuple | Mapping[str, Any] | None = None, **kwargs
) -> DataFrame: ...
def get_pandas_df_by_chunks(
self, sql, parameters: Incomplete | None = ..., *, chunksize, **kwargs
) -> None: ...
self, sql, parameters: list | tuple | Mapping[str, Any] | None = None, *, chunksize: int, **kwargs
) -> Generator[DataFrame, None, None]: ...
def get_records(
self, sql: Union[str, list[str]], parameters: Union[Iterable, Mapping, None] = ...
self, sql: str | list[str], parameters: Iterable | Mapping[str, Any] | None = None
) -> Any: ...
def get_first(
self, sql: Union[str, list[str]], parameters: Union[Iterable, Mapping, None] = ...
self, sql: str | list[str], parameters: Iterable | Mapping[str, Any] | None = None
) -> Any: ...
@staticmethod
def strip_sql_string(sql: str) -> str: ...
@staticmethod
def split_sql_string(sql: str) -> list[str]: ...
@property
def last_description(self) -> Union[Sequence[Sequence], None]: ...
def last_description(self) -> Sequence[Sequence] | None: ...
@overload
def run(
self,
sql: Union[str, Iterable[str]],
sql: str | Iterable[str],
autocommit: bool = ...,
parameters: Union[Iterable, Mapping, None] = ...,
handler: Union[Callable, None] = ...,
parameters: Iterable | Mapping[str, Any] | None = ...,
handler: None = ...,
split_statements: bool = ...,
return_last: bool = ...,
) -> Union[Any, list[Any], None]: ...
) -> None: ...
@overload
def run(
self,
sql: str | Iterable[str],
autocommit: bool = ...,
parameters: Iterable | Mapping[str, Any] | None = ...,
handler: Callable[[Any], T] = ...,
split_statements: bool = ...,
return_last: bool = ...,
) -> tuple | list[tuple] | list[list[tuple] | tuple] | None: ...
def set_autocommit(self, conn, autocommit) -> None: ...
def get_autocommit(self, conn) -> bool: ...
def get_cursor(self): ...
def insert_rows(
self,
table,
rows,
target_fields: Incomplete | None = ...,
commit_every: int = ...,
replace: bool = ...,
target_fields: Incomplete | None = None,
commit_every: int = 1000,
replace: bool = False,
*,
executemany: bool = ...,
executemany: bool = False,
**kwargs,
) -> None: ...
): ...
def bulk_dump(self, table, tmp_file) -> None: ...
def bulk_load(self, table, tmp_file) -> None: ...
def test_connection(self): ...
def get_openlineage_database_info(self, connection) -> DatabaseInfo | None: ...
def get_openlineage_database_dialect(self, connection) -> str: ...
def get_openlineage_default_schema(self) -> str | None: ...
def get_openlineage_database_specific_lineage(self, task_instance) -> OperatorLineage | None: ...
@staticmethod
def get_openlineage_authority_part(connection, default_port: int | None = None) -> str: ...
100 changes: 59 additions & 41 deletions airflow/providers/common/sql/operators/sql.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -31,27 +31,40 @@
Definition of the public interface for airflow.providers.common.sql.operators.sql
isort:skip_file
"""
from _typeshed import Incomplete # noqa: F401
from airflow.models import BaseOperator, SkipMixin
from airflow.providers.common.sql.hooks.sql import DbApiHook
from airflow.utils.context import Context
from typing import Any, Callable, Iterable, Mapping, Sequence, SupportsAbs, Union
from _typeshed import Incomplete
from airflow.exceptions import (
AirflowException as AirflowException,
AirflowFailException as AirflowFailException,
)
from airflow.hooks.base import BaseHook as BaseHook
from airflow.models import BaseOperator as BaseOperator, SkipMixin as SkipMixin
from airflow.providers.common.sql.hooks.sql import (
DbApiHook as DbApiHook,
fetch_all_handler as fetch_all_handler,
return_single_query_results as return_single_query_results,
)
from airflow.providers.openlineage.extractors import OperatorLineage as OperatorLineage
from airflow.utils.context import Context as Context
from airflow.utils.helpers import merge_dicts as merge_dicts
from functools import cached_property as cached_property
from typing import Any, Callable, Iterable, Mapping, Sequence, SupportsAbs

def _parse_boolean(val: str) -> str | bool: ...
def parse_boolean(val: str) -> str | bool: ...

class BaseSQLOperator(BaseOperator):
conn_id_field: str
conn_id: Incomplete
database: Incomplete
hook_params: Incomplete
retry_on_failure: Incomplete
def __init__(
self,
*,
conn_id: Union[str, None] = ...,
database: Union[str, None] = ...,
hook_params: Union[dict, None] = ...,
retry_on_failure: bool = ...,
conn_id: str | None = None,
database: str | None = None,
hook_params: dict | None = None,
retry_on_failure: bool = True,
**kwargs,
) -> None: ...
def get_db_hook(self) -> DbApiHook: ...
Expand All @@ -72,20 +85,24 @@ class SQLExecuteQueryOperator(BaseSQLOperator):
def __init__(
self,
*,
sql: Union[str, list[str]],
autocommit: bool = ...,
parameters: Union[Mapping, Iterable, None] = ...,
sql: str | list[str],
autocommit: bool = False,
parameters: Mapping | Iterable | None = None,
handler: Callable[[Any], Any] = ...,
split_statements: Union[bool, None] = ...,
return_last: bool = ...,
show_return_value_in_logs: bool = ...,
conn_id: str | None = None,
database: str | None = None,
split_statements: bool | None = None,
return_last: bool = True,
show_return_value_in_logs: bool = False,
**kwargs,
) -> None: ...
def execute(self, context): ...
def prepare_template(self) -> None: ...
def get_openlineage_facets_on_start(self) -> OperatorLineage | None: ...
def get_openlineage_facets_on_complete(self, task_instance) -> OperatorLineage | None: ...

class SQLColumnCheckOperator(BaseSQLOperator):
template_fields: Incomplete
template_fields: Sequence[str]
template_fields_renderers: Incomplete
sql_check_template: str
column_checks: Incomplete
Expand All @@ -99,16 +116,16 @@ class SQLColumnCheckOperator(BaseSQLOperator):
*,
table: str,
column_mapping: dict[str, dict[str, Any]],
partition_clause: Union[str, None] = ...,
conn_id: Union[str, None] = ...,
database: Union[str, None] = ...,
accept_none: bool = ...,
partition_clause: str | None = None,
conn_id: str | None = None,
database: str | None = None,
accept_none: bool = True,
**kwargs,
) -> None: ...
def execute(self, context: Context): ...

class SQLTableCheckOperator(BaseSQLOperator):
template_fields: Incomplete
template_fields: Sequence[str]
template_fields_renderers: Incomplete
sql_check_template: str
table: Incomplete
Expand All @@ -120,9 +137,9 @@ class SQLTableCheckOperator(BaseSQLOperator):
*,
table: str,
checks: dict[str, dict[str, Any]],
partition_clause: Union[str, None] = ...,
conn_id: Union[str, None] = ...,
database: Union[str, None] = ...,
partition_clause: str | None = None,
conn_id: str | None = None,
database: str | None = None,
**kwargs,
) -> None: ...
def execute(self, context: Context): ...
Expand All @@ -138,9 +155,9 @@ class SQLCheckOperator(BaseSQLOperator):
self,
*,
sql: str,
conn_id: Union[str, None] = ...,
database: Union[str, None] = ...,
parameters: Union[Iterable, Mapping, None] = ...,
conn_id: str | None = None,
database: str | None = None,
parameters: Iterable | Mapping[str, Any] | None = None,
**kwargs,
) -> None: ...
def execute(self, context: Context): ...
Expand All @@ -160,11 +177,12 @@ class SQLValueCheckOperator(BaseSQLOperator):
*,
sql: str,
pass_value: Any,
tolerance: Any = ...,
conn_id: Union[str, None] = ...,
database: Union[str, None] = ...,
tolerance: Any = None,
conn_id: str | None = None,
database: str | None = None,
**kwargs,
) -> None: ...
def check_value(self, records) -> None: ...
def execute(self, context: Context): ...

class SQLIntervalCheckOperator(BaseSQLOperator):
Expand All @@ -188,12 +206,12 @@ class SQLIntervalCheckOperator(BaseSQLOperator):
*,
table: str,
metrics_thresholds: dict[str, int],
date_filter_column: Union[str, None] = ...,
days_back: SupportsAbs[int] = ...,
ratio_formula: Union[str, None] = ...,
ignore_zero: bool = ...,
conn_id: Union[str, None] = ...,
database: Union[str, None] = ...,
date_filter_column: str | None = "ds",
days_back: SupportsAbs[int] = -7,
ratio_formula: str | None = "max_over_min",
ignore_zero: bool = True,
conn_id: str | None = None,
database: str | None = None,
**kwargs,
) -> None: ...
def execute(self, context: Context): ...
Expand All @@ -211,8 +229,8 @@ class SQLThresholdCheckOperator(BaseSQLOperator):
sql: str,
min_threshold: Any,
max_threshold: Any,
conn_id: Union[str, None] = ...,
database: Union[str, None] = ...,
conn_id: str | None = None,
database: str | None = None,
**kwargs,
) -> None: ...
def execute(self, context: Context): ...
Expand All @@ -234,9 +252,9 @@ class BranchSQLOperator(BaseSQLOperator, SkipMixin):
sql: str,
follow_task_ids_if_true: list[str],
follow_task_ids_if_false: list[str],
conn_id: str = ...,
database: Union[str, None] = ...,
parameters: Union[Iterable, Mapping, None] = ...,
conn_id: str = "default_conn_id",
database: str | None = None,
parameters: Iterable | Mapping[str, Any] | None = None,
**kwargs,
) -> None: ...
def execute(self, context: Context): ...
18 changes: 12 additions & 6 deletions airflow/providers/common/sql/sensors/sql.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,13 @@ Definition of the public interface for airflow.providers.common.sql.sensors.sql
isort:skip_file
"""
from _typeshed import Incomplete
from airflow.sensors.base import BaseSensorOperator
from airflow.exceptions import (
AirflowException as AirflowException,
AirflowSkipException as AirflowSkipException,
)
from airflow.hooks.base import BaseHook as BaseHook
from airflow.providers.common.sql.hooks.sql import DbApiHook as DbApiHook
from airflow.sensors.base import BaseSensorOperator as BaseSensorOperator
from typing import Any, Sequence

class SqlSensor(BaseSensorOperator):
Expand All @@ -51,11 +57,11 @@ class SqlSensor(BaseSensorOperator):
*,
conn_id,
sql,
parameters: Incomplete | None = ...,
success: Incomplete | None = ...,
failure: Incomplete | None = ...,
fail_on_empty: bool = ...,
hook_params: Incomplete | None = ...,
parameters: Incomplete | None = None,
success: Incomplete | None = None,
failure: Incomplete | None = None,
fail_on_empty: bool = False,
hook_params: Incomplete | None = None,
**kwargs,
) -> None: ...
def poke(self, context: Any): ...
Loading

0 comments on commit eda7481

Please sign in to comment.