Skip to content

Commit

Permalink
Add typing for SqlSensor
Browse files Browse the repository at this point in the history
The SqlSensor was missing typing mainly in its constructor where it would be most valuable from a DAG author's perspective. This PR adds said typing of the sensor, its stubs, and tests (where applicable) as well as a small cleanup of the sensor's docstring.
  • Loading branch information
josh-fell committed May 23, 2024
1 parent 3b1ecbc commit 0f58d17
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 50 deletions.
40 changes: 20 additions & 20 deletions airflow/providers/common/sql/sensors/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,19 @@
# under the License.
from __future__ import annotations

from typing import Any, Sequence
from typing import TYPE_CHECKING, Any, Callable, Mapping, Sequence

from airflow.exceptions import AirflowException, AirflowSkipException
from airflow.hooks.base import BaseHook
from airflow.providers.common.sql.hooks.sql import DbApiHook
from airflow.sensors.base import BaseSensorOperator

if TYPE_CHECKING:
from airflow.utils.context import Context


class SqlSensor(BaseSensorOperator):
"""Run a sql statement repeatedly until a criteria is met.
"""Run a SQL statement repeatedly until a criteria is met.
This will keep trying until success or failure criteria are met, or if the
first cell is not either ``0``, ``'0'``, ``''``, or ``None``. Optional
Expand All @@ -39,37 +42,34 @@ class SqlSensor(BaseSensorOperator):
in which case it will fail if no rows have been returned.
:param conn_id: The connection to run the sensor against
:param sql: The sql to run. To pass, it needs to return at least one cell
:param sql: The SQL to run. To pass, it needs to return at least one cell
that contains a non-zero / empty string value.
:param parameters: The parameters to render the SQL query with (optional).
:param success: Success criteria for the sensor is a Callable that takes first_cell
:param success: Success criteria for the sensor is a Callable that takes the first_cell's value
as the only argument, and returns a boolean (optional).
:param failure: Failure criteria for the sensor is a Callable that takes first_cell
as the only argument and return a boolean (optional).
:param failure: Failure criteria for the sensor is a Callable that takes the first_cell's value
as the only argument and returns a boolean (optional).
:param fail_on_empty: Explicitly fail on no rows returned.
:param hook_params: Extra config params to be passed to the underlying hook.
Should match the desired hook constructor params.
"""

template_fields: Sequence[str] = ("sql", "hook_params", "parameters")
template_ext: Sequence[str] = (
".hql",
".sql",
)
template_ext: Sequence[str] = (".hql", ".sql")
ui_color = "#7c7287"

def __init__(
self,
*,
conn_id,
sql,
parameters=None,
success=None,
failure=None,
fail_on_empty=False,
hook_params=None,
conn_id: str,
sql: str,
parameters: Mapping[str, Any] | None = None,
success: Callable[[Any], bool] | None = None,
failure: Callable[[Any], bool] | None = None,
fail_on_empty: bool = False,
hook_params: Mapping[str, Any] | None = None,
**kwargs,
):
) -> None:
self.conn_id = conn_id
self.sql = sql
self.parameters = parameters
Expand All @@ -79,7 +79,7 @@ def __init__(
self.hook_params = hook_params
super().__init__(**kwargs)

def _get_hook(self):
def _get_hook(self) -> DbApiHook:
conn = BaseHook.get_connection(self.conn_id)
hook = conn.get_hook(hook_params=self.hook_params)
if not isinstance(hook, DbApiHook):
Expand All @@ -89,7 +89,7 @@ def _get_hook(self):
)
return hook

def poke(self, context: Any):
def poke(self, context: Context) -> bool:
hook = self._get_hook()

self.log.info("Poking: %s (with parameters %s)", self.sql, self.parameters)
Expand Down
17 changes: 9 additions & 8 deletions airflow/providers/common/sql/sensors/sql.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ from airflow.exceptions import (
from airflow.hooks.base import BaseHook as BaseHook
from airflow.providers.common.sql.hooks.sql import DbApiHook as DbApiHook
from airflow.sensors.base import BaseSensorOperator as BaseSensorOperator
from typing import Any, Sequence
from airflow.utils.context import Context as Context
from typing import Any, Callable, Mapping, Sequence

class SqlSensor(BaseSensorOperator):
template_fields: Sequence[str]
Expand All @@ -55,13 +56,13 @@ class SqlSensor(BaseSensorOperator):
def __init__(
self,
*,
conn_id,
sql,
parameters: Incomplete | None = None,
success: Incomplete | None = None,
failure: Incomplete | None = None,
conn_id: str,
sql: str,
parameters: Mapping[str, Any] | None = None,
success: Callable[[Any], bool] | None = None,
failure: Callable[[Any], bool] | None = None,
fail_on_empty: bool = False,
hook_params: Incomplete | None = None,
hook_params: Mapping[str, Any] | None = None,
**kwargs,
) -> None: ...
def poke(self, context: Any): ...
def poke(self, context: Context) -> bool: ...
44 changes: 22 additions & 22 deletions tests/providers/common/sql/sensors/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,25 +98,25 @@ def test_sql_sensor_postgres_poke(self, mock_hook):
mock_get_records = mock_hook.get_connection.return_value.get_hook.return_value.get_records

mock_get_records.return_value = []
assert not op.poke(None)
assert not op.poke({})

mock_get_records.return_value = [[None]]
assert not op.poke(None)
assert not op.poke({})

mock_get_records.return_value = [["None"]]
assert op.poke(None)
assert op.poke({})

mock_get_records.return_value = [[0.0]]
assert not op.poke(None)
assert not op.poke({})

mock_get_records.return_value = [[0]]
assert not op.poke(None)
assert not op.poke({})

mock_get_records.return_value = [["0"]]
assert op.poke(None)
assert op.poke({})

mock_get_records.return_value = [["1"]]
assert op.poke(None)
assert op.poke({})

@pytest.mark.parametrize(
"soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException))
Expand All @@ -138,7 +138,7 @@ def test_sql_sensor_postgres_poke_fail_on_empty(

mock_get_records.return_value = []
with pytest.raises(expected_exception):
op.poke(None)
op.poke({})

@mock.patch("airflow.providers.common.sql.sensors.sql.BaseHook")
def test_sql_sensor_postgres_poke_success(self, mock_hook):
Expand All @@ -150,13 +150,13 @@ def test_sql_sensor_postgres_poke_success(self, mock_hook):
mock_get_records = mock_hook.get_connection.return_value.get_hook.return_value.get_records

mock_get_records.return_value = []
assert not op.poke(None)
assert not op.poke({})

mock_get_records.return_value = [[1]]
assert op.poke(None)
assert op.poke({})

mock_get_records.return_value = [["1"]]
assert not op.poke(None)
assert not op.poke({})

@pytest.mark.parametrize(
"soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException))
Expand All @@ -177,11 +177,11 @@ def test_sql_sensor_postgres_poke_failure(
mock_get_records = mock_hook.get_connection.return_value.get_hook.return_value.get_records

mock_get_records.return_value = []
assert not op.poke(None)
assert not op.poke({})

mock_get_records.return_value = [[1]]
with pytest.raises(expected_exception):
op.poke(None)
op.poke({})

@pytest.mark.parametrize(
"soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException))
Expand All @@ -203,14 +203,14 @@ def test_sql_sensor_postgres_poke_failure_success(
mock_get_records = mock_hook.get_connection.return_value.get_hook.return_value.get_records

mock_get_records.return_value = []
assert not op.poke(None)
assert not op.poke({})

mock_get_records.return_value = [[1]]
with pytest.raises(expected_exception):
op.poke(None)
op.poke({})

mock_get_records.return_value = [[2]]
assert op.poke(None)
assert op.poke({})

@pytest.mark.parametrize(
"soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException))
Expand All @@ -232,11 +232,11 @@ def test_sql_sensor_postgres_poke_failure_success_same(
mock_get_records = mock_hook.get_connection.return_value.get_hook.return_value.get_records

mock_get_records.return_value = []
assert not op.poke(None)
assert not op.poke({})

mock_get_records.return_value = [[1]]
with pytest.raises(expected_exception):
op.poke(None)
op.poke({})

@pytest.mark.parametrize(
"soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException))
Expand All @@ -249,7 +249,7 @@ def test_sql_sensor_postgres_poke_invalid_failure(
task_id="sql_sensor_check",
conn_id="postgres_default",
sql="SELECT 1",
failure=[1],
failure=[1], # type: ignore[arg-type]
soft_fail=soft_fail,
)

Expand All @@ -258,7 +258,7 @@ def test_sql_sensor_postgres_poke_invalid_failure(

mock_get_records.return_value = [[1]]
with pytest.raises(expected_exception) as ctx:
op.poke(None)
op.poke({})
assert "self.failure is present, but not callable -> [1]" == str(ctx.value)

@pytest.mark.parametrize(
Expand All @@ -272,7 +272,7 @@ def test_sql_sensor_postgres_poke_invalid_success(
task_id="sql_sensor_check",
conn_id="postgres_default",
sql="SELECT 1",
success=[1],
success=[1], # type: ignore[arg-type]
soft_fail=soft_fail,
)

Expand All @@ -281,7 +281,7 @@ def test_sql_sensor_postgres_poke_invalid_success(

mock_get_records.return_value = [[1]]
with pytest.raises(expected_exception) as ctx:
op.poke(None)
op.poke({})
assert "self.success is present, but not callable -> [1]" == str(ctx.value)

@pytest.mark.db_test
Expand Down

0 comments on commit 0f58d17

Please sign in to comment.