Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add typing for SqlSensor #39773

Merged
merged 1 commit into from
May 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 20 additions & 20 deletions airflow/providers/common/sql/sensors/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,19 @@
# under the License.
from __future__ import annotations

from typing import Any, Sequence
from typing import TYPE_CHECKING, Any, Callable, Mapping, Sequence

from airflow.exceptions import AirflowException, AirflowSkipException
from airflow.hooks.base import BaseHook
from airflow.providers.common.sql.hooks.sql import DbApiHook
from airflow.sensors.base import BaseSensorOperator

if TYPE_CHECKING:
from airflow.utils.context import Context


class SqlSensor(BaseSensorOperator):
"""Run a sql statement repeatedly until a criteria is met.
"""Run a SQL statement repeatedly until a criteria is met.

This will keep trying until success or failure criteria are met, or if the
first cell is not either ``0``, ``'0'``, ``''``, or ``None``. Optional
Expand All @@ -39,37 +42,34 @@ class SqlSensor(BaseSensorOperator):
in which case it will fail if no rows have been returned.

:param conn_id: The connection to run the sensor against
:param sql: The sql to run. To pass, it needs to return at least one cell
:param sql: The SQL to run. To pass, it needs to return at least one cell
that contains a non-zero / empty string value.
:param parameters: The parameters to render the SQL query with (optional).
:param success: Success criteria for the sensor is a Callable that takes first_cell
:param success: Success criteria for the sensor is a Callable that takes the first_cell's value
as the only argument, and returns a boolean (optional).
:param failure: Failure criteria for the sensor is a Callable that takes first_cell
as the only argument and return a boolean (optional).
:param failure: Failure criteria for the sensor is a Callable that takes the first_cell's value
as the only argument and returns a boolean (optional).
:param fail_on_empty: Explicitly fail on no rows returned.
:param hook_params: Extra config params to be passed to the underlying hook.
Should match the desired hook constructor params.
"""

template_fields: Sequence[str] = ("sql", "hook_params", "parameters")
template_ext: Sequence[str] = (
".hql",
".sql",
)
template_ext: Sequence[str] = (".hql", ".sql")
ui_color = "#7c7287"

def __init__(
self,
*,
conn_id,
sql,
parameters=None,
success=None,
failure=None,
fail_on_empty=False,
hook_params=None,
conn_id: str,
sql: str,
parameters: Mapping[str, Any] | None = None,
success: Callable[[Any], bool] | None = None,
failure: Callable[[Any], bool] | None = None,
fail_on_empty: bool = False,
hook_params: Mapping[str, Any] | None = None,
**kwargs,
):
) -> None:
self.conn_id = conn_id
self.sql = sql
self.parameters = parameters
Expand All @@ -79,7 +79,7 @@ def __init__(
self.hook_params = hook_params
super().__init__(**kwargs)

def _get_hook(self):
def _get_hook(self) -> DbApiHook:
conn = BaseHook.get_connection(self.conn_id)
hook = conn.get_hook(hook_params=self.hook_params)
if not isinstance(hook, DbApiHook):
Expand All @@ -89,7 +89,7 @@ def _get_hook(self):
)
return hook

def poke(self, context: Any):
def poke(self, context: Context) -> bool:
hook = self._get_hook()

self.log.info("Poking: %s (with parameters %s)", self.sql, self.parameters)
Expand Down
17 changes: 9 additions & 8 deletions airflow/providers/common/sql/sensors/sql.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ from airflow.exceptions import (
from airflow.hooks.base import BaseHook as BaseHook
from airflow.providers.common.sql.hooks.sql import DbApiHook as DbApiHook
from airflow.sensors.base import BaseSensorOperator as BaseSensorOperator
from typing import Any, Sequence
from airflow.utils.context import Context as Context
from typing import Any, Callable, Mapping, Sequence

class SqlSensor(BaseSensorOperator):
template_fields: Sequence[str]
Expand All @@ -55,13 +56,13 @@ class SqlSensor(BaseSensorOperator):
def __init__(
self,
*,
conn_id,
sql,
parameters: Incomplete | None = None,
success: Incomplete | None = None,
failure: Incomplete | None = None,
conn_id: str,
sql: str,
parameters: Mapping[str, Any] | None = None,
success: Callable[[Any], bool] | None = None,
failure: Callable[[Any], bool] | None = None,
fail_on_empty: bool = False,
hook_params: Incomplete | None = None,
hook_params: Mapping[str, Any] | None = None,
**kwargs,
) -> None: ...
def poke(self, context: Any): ...
def poke(self, context: Context) -> bool: ...
44 changes: 22 additions & 22 deletions tests/providers/common/sql/sensors/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,25 +98,25 @@ def test_sql_sensor_postgres_poke(self, mock_hook):
mock_get_records = mock_hook.get_connection.return_value.get_hook.return_value.get_records

mock_get_records.return_value = []
assert not op.poke(None)
assert not op.poke({})

mock_get_records.return_value = [[None]]
assert not op.poke(None)
assert not op.poke({})

mock_get_records.return_value = [["None"]]
assert op.poke(None)
assert op.poke({})

mock_get_records.return_value = [[0.0]]
assert not op.poke(None)
assert not op.poke({})

mock_get_records.return_value = [[0]]
assert not op.poke(None)
assert not op.poke({})

mock_get_records.return_value = [["0"]]
assert op.poke(None)
assert op.poke({})

mock_get_records.return_value = [["1"]]
assert op.poke(None)
assert op.poke({})

@pytest.mark.parametrize(
"soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException))
Expand All @@ -138,7 +138,7 @@ def test_sql_sensor_postgres_poke_fail_on_empty(

mock_get_records.return_value = []
with pytest.raises(expected_exception):
op.poke(None)
op.poke({})

@mock.patch("airflow.providers.common.sql.sensors.sql.BaseHook")
def test_sql_sensor_postgres_poke_success(self, mock_hook):
Expand All @@ -150,13 +150,13 @@ def test_sql_sensor_postgres_poke_success(self, mock_hook):
mock_get_records = mock_hook.get_connection.return_value.get_hook.return_value.get_records

mock_get_records.return_value = []
assert not op.poke(None)
assert not op.poke({})

mock_get_records.return_value = [[1]]
assert op.poke(None)
assert op.poke({})

mock_get_records.return_value = [["1"]]
assert not op.poke(None)
assert not op.poke({})

@pytest.mark.parametrize(
"soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException))
Expand All @@ -177,11 +177,11 @@ def test_sql_sensor_postgres_poke_failure(
mock_get_records = mock_hook.get_connection.return_value.get_hook.return_value.get_records

mock_get_records.return_value = []
assert not op.poke(None)
assert not op.poke({})

mock_get_records.return_value = [[1]]
with pytest.raises(expected_exception):
op.poke(None)
op.poke({})

@pytest.mark.parametrize(
"soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException))
Expand All @@ -203,14 +203,14 @@ def test_sql_sensor_postgres_poke_failure_success(
mock_get_records = mock_hook.get_connection.return_value.get_hook.return_value.get_records

mock_get_records.return_value = []
assert not op.poke(None)
assert not op.poke({})

mock_get_records.return_value = [[1]]
with pytest.raises(expected_exception):
op.poke(None)
op.poke({})

mock_get_records.return_value = [[2]]
assert op.poke(None)
assert op.poke({})

@pytest.mark.parametrize(
"soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException))
Expand All @@ -232,11 +232,11 @@ def test_sql_sensor_postgres_poke_failure_success_same(
mock_get_records = mock_hook.get_connection.return_value.get_hook.return_value.get_records

mock_get_records.return_value = []
assert not op.poke(None)
assert not op.poke({})

mock_get_records.return_value = [[1]]
with pytest.raises(expected_exception):
op.poke(None)
op.poke({})

@pytest.mark.parametrize(
"soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException))
Expand All @@ -249,7 +249,7 @@ def test_sql_sensor_postgres_poke_invalid_failure(
task_id="sql_sensor_check",
conn_id="postgres_default",
sql="SELECT 1",
failure=[1],
failure=[1], # type: ignore[arg-type]
soft_fail=soft_fail,
)

Expand All @@ -258,7 +258,7 @@ def test_sql_sensor_postgres_poke_invalid_failure(

mock_get_records.return_value = [[1]]
with pytest.raises(expected_exception) as ctx:
op.poke(None)
op.poke({})
assert "self.failure is present, but not callable -> [1]" == str(ctx.value)

@pytest.mark.parametrize(
Expand All @@ -272,7 +272,7 @@ def test_sql_sensor_postgres_poke_invalid_success(
task_id="sql_sensor_check",
conn_id="postgres_default",
sql="SELECT 1",
success=[1],
success=[1], # type: ignore[arg-type]
soft_fail=soft_fail,
)

Expand All @@ -281,7 +281,7 @@ def test_sql_sensor_postgres_poke_invalid_success(

mock_get_records.return_value = [[1]]
with pytest.raises(expected_exception) as ctx:
op.poke(None)
op.poke({})
assert "self.success is present, but not callable -> [1]" == str(ctx.value)

@pytest.mark.db_test
Expand Down