diff --git a/airflow/providers/common/sql/sensors/sql.py b/airflow/providers/common/sql/sensors/sql.py index f7d32c6562e730..7193156ae10e55 100644 --- a/airflow/providers/common/sql/sensors/sql.py +++ b/airflow/providers/common/sql/sensors/sql.py @@ -16,16 +16,19 @@ # under the License. from __future__ import annotations -from typing import Any, Sequence +from typing import TYPE_CHECKING, Any, Callable, Mapping, Sequence from airflow.exceptions import AirflowException, AirflowSkipException from airflow.hooks.base import BaseHook from airflow.providers.common.sql.hooks.sql import DbApiHook from airflow.sensors.base import BaseSensorOperator +if TYPE_CHECKING: + from airflow.utils.context import Context + class SqlSensor(BaseSensorOperator): - """Run a sql statement repeatedly until a criteria is met. + """Run a SQL statement repeatedly until a criteria is met. This will keep trying until success or failure criteria are met, or if the first cell is not either ``0``, ``'0'``, ``''``, or ``None``. Optional @@ -39,37 +42,34 @@ class SqlSensor(BaseSensorOperator): in which case it will fail if no rows have been returned. :param conn_id: The connection to run the sensor against - :param sql: The sql to run. To pass, it needs to return at least one cell + :param sql: The SQL to run. To pass, it needs to return at least one cell that contains a non-zero / empty string value. :param parameters: The parameters to render the SQL query with (optional). - :param success: Success criteria for the sensor is a Callable that takes first_cell + :param success: Success criteria for the sensor is a Callable that takes the first_cell's value as the only argument, and returns a boolean (optional). - :param failure: Failure criteria for the sensor is a Callable that takes first_cell - as the only argument and return a boolean (optional). + :param failure: Failure criteria for the sensor is a Callable that takes the first_cell's value + as the only argument and returns a boolean (optional). :param fail_on_empty: Explicitly fail on no rows returned. :param hook_params: Extra config params to be passed to the underlying hook. Should match the desired hook constructor params. """ template_fields: Sequence[str] = ("sql", "hook_params", "parameters") - template_ext: Sequence[str] = ( - ".hql", - ".sql", - ) + template_ext: Sequence[str] = (".hql", ".sql") ui_color = "#7c7287" def __init__( self, *, - conn_id, - sql, - parameters=None, - success=None, - failure=None, - fail_on_empty=False, - hook_params=None, + conn_id: str, + sql: str, + parameters: Mapping[str, Any] | None = None, + success: Callable[[Any], bool] | None = None, + failure: Callable[[Any], bool] | None = None, + fail_on_empty: bool = False, + hook_params: Mapping[str, Any] | None = None, **kwargs, - ): + ) -> None: self.conn_id = conn_id self.sql = sql self.parameters = parameters @@ -79,7 +79,7 @@ def __init__( self.hook_params = hook_params super().__init__(**kwargs) - def _get_hook(self): + def _get_hook(self) -> DbApiHook: conn = BaseHook.get_connection(self.conn_id) hook = conn.get_hook(hook_params=self.hook_params) if not isinstance(hook, DbApiHook): @@ -89,7 +89,7 @@ def _get_hook(self): ) return hook - def poke(self, context: Any): + def poke(self, context: Context) -> bool: hook = self._get_hook() self.log.info("Poking: %s (with parameters %s)", self.sql, self.parameters) diff --git a/airflow/providers/common/sql/sensors/sql.pyi b/airflow/providers/common/sql/sensors/sql.pyi index c8159f0d8562ee..12084e4533f48a 100644 --- a/airflow/providers/common/sql/sensors/sql.pyi +++ b/airflow/providers/common/sql/sensors/sql.pyi @@ -39,7 +39,8 @@ from airflow.exceptions import ( from airflow.hooks.base import BaseHook as BaseHook from airflow.providers.common.sql.hooks.sql import DbApiHook as DbApiHook from airflow.sensors.base import BaseSensorOperator as BaseSensorOperator -from typing import Any, Sequence +from airflow.utils.context import Context as Context +from typing import Any, Callable, Mapping, Sequence class SqlSensor(BaseSensorOperator): template_fields: Sequence[str] @@ -55,13 +56,13 @@ class SqlSensor(BaseSensorOperator): def __init__( self, *, - conn_id, - sql, - parameters: Incomplete | None = None, - success: Incomplete | None = None, - failure: Incomplete | None = None, + conn_id: str, + sql: str, + parameters: Mapping[str, Any] | None = None, + success: Callable[[Any], bool] | None = None, + failure: Callable[[Any], bool] | None = None, fail_on_empty: bool = False, - hook_params: Incomplete | None = None, + hook_params: Mapping[str, Any] | None = None, **kwargs, ) -> None: ... - def poke(self, context: Any): ... + def poke(self, context: Context) -> bool: ... diff --git a/tests/providers/common/sql/sensors/test_sql.py b/tests/providers/common/sql/sensors/test_sql.py index 39b74272494096..fc2d83895e0658 100644 --- a/tests/providers/common/sql/sensors/test_sql.py +++ b/tests/providers/common/sql/sensors/test_sql.py @@ -98,25 +98,25 @@ def test_sql_sensor_postgres_poke(self, mock_hook): mock_get_records = mock_hook.get_connection.return_value.get_hook.return_value.get_records mock_get_records.return_value = [] - assert not op.poke(None) + assert not op.poke({}) mock_get_records.return_value = [[None]] - assert not op.poke(None) + assert not op.poke({}) mock_get_records.return_value = [["None"]] - assert op.poke(None) + assert op.poke({}) mock_get_records.return_value = [[0.0]] - assert not op.poke(None) + assert not op.poke({}) mock_get_records.return_value = [[0]] - assert not op.poke(None) + assert not op.poke({}) mock_get_records.return_value = [["0"]] - assert op.poke(None) + assert op.poke({}) mock_get_records.return_value = [["1"]] - assert op.poke(None) + assert op.poke({}) @pytest.mark.parametrize( "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) @@ -138,7 +138,7 @@ def test_sql_sensor_postgres_poke_fail_on_empty( mock_get_records.return_value = [] with pytest.raises(expected_exception): - op.poke(None) + op.poke({}) @mock.patch("airflow.providers.common.sql.sensors.sql.BaseHook") def test_sql_sensor_postgres_poke_success(self, mock_hook): @@ -150,13 +150,13 @@ def test_sql_sensor_postgres_poke_success(self, mock_hook): mock_get_records = mock_hook.get_connection.return_value.get_hook.return_value.get_records mock_get_records.return_value = [] - assert not op.poke(None) + assert not op.poke({}) mock_get_records.return_value = [[1]] - assert op.poke(None) + assert op.poke({}) mock_get_records.return_value = [["1"]] - assert not op.poke(None) + assert not op.poke({}) @pytest.mark.parametrize( "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) @@ -177,11 +177,11 @@ def test_sql_sensor_postgres_poke_failure( mock_get_records = mock_hook.get_connection.return_value.get_hook.return_value.get_records mock_get_records.return_value = [] - assert not op.poke(None) + assert not op.poke({}) mock_get_records.return_value = [[1]] with pytest.raises(expected_exception): - op.poke(None) + op.poke({}) @pytest.mark.parametrize( "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) @@ -203,14 +203,14 @@ def test_sql_sensor_postgres_poke_failure_success( mock_get_records = mock_hook.get_connection.return_value.get_hook.return_value.get_records mock_get_records.return_value = [] - assert not op.poke(None) + assert not op.poke({}) mock_get_records.return_value = [[1]] with pytest.raises(expected_exception): - op.poke(None) + op.poke({}) mock_get_records.return_value = [[2]] - assert op.poke(None) + assert op.poke({}) @pytest.mark.parametrize( "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) @@ -232,11 +232,11 @@ def test_sql_sensor_postgres_poke_failure_success_same( mock_get_records = mock_hook.get_connection.return_value.get_hook.return_value.get_records mock_get_records.return_value = [] - assert not op.poke(None) + assert not op.poke({}) mock_get_records.return_value = [[1]] with pytest.raises(expected_exception): - op.poke(None) + op.poke({}) @pytest.mark.parametrize( "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) @@ -249,7 +249,7 @@ def test_sql_sensor_postgres_poke_invalid_failure( task_id="sql_sensor_check", conn_id="postgres_default", sql="SELECT 1", - failure=[1], + failure=[1], # type: ignore[arg-type] soft_fail=soft_fail, ) @@ -258,7 +258,7 @@ def test_sql_sensor_postgres_poke_invalid_failure( mock_get_records.return_value = [[1]] with pytest.raises(expected_exception) as ctx: - op.poke(None) + op.poke({}) assert "self.failure is present, but not callable -> [1]" == str(ctx.value) @pytest.mark.parametrize( @@ -272,7 +272,7 @@ def test_sql_sensor_postgres_poke_invalid_success( task_id="sql_sensor_check", conn_id="postgres_default", sql="SELECT 1", - success=[1], + success=[1], # type: ignore[arg-type] soft_fail=soft_fail, ) @@ -281,7 +281,7 @@ def test_sql_sensor_postgres_poke_invalid_success( mock_get_records.return_value = [[1]] with pytest.raises(expected_exception) as ctx: - op.poke(None) + op.poke({}) assert "self.success is present, but not callable -> [1]" == str(ctx.value) @pytest.mark.db_test