Skip to content

Commit

Permalink
Add deferrable implementation in HTTPSensor (#36904)
Browse files Browse the repository at this point in the history

Co-authored-by: Wei Lee <weilee.rx@gmail.com>
  • Loading branch information
vatsrahul1001 and Lee-W committed Jan 23, 2024
1 parent 1c958a2 commit 9596bbd
Show file tree
Hide file tree
Showing 5 changed files with 181 additions and 2 deletions.
31 changes: 31 additions & 0 deletions airflow/providers/http/sensors/http.py
Expand Up @@ -17,10 +17,13 @@
# under the License.
from __future__ import annotations

from datetime import timedelta
from typing import TYPE_CHECKING, Any, Callable, Sequence

from airflow.configuration import conf
from airflow.exceptions import AirflowException, AirflowSkipException
from airflow.providers.http.hooks.http import HttpHook
from airflow.providers.http.triggers.http import HttpSensorTrigger
from airflow.sensors.base import BaseSensorOperator

if TYPE_CHECKING:
Expand Down Expand Up @@ -78,6 +81,8 @@ def response_check(response, task_instance):
:param tcp_keep_alive_count: The TCP Keep Alive count parameter (corresponds to ``socket.TCP_KEEPCNT``)
:param tcp_keep_alive_interval: The TCP Keep Alive interval parameter (corresponds to
``socket.TCP_KEEPINTVL``)
:param deferrable: If waiting for completion, whether to defer the task until done,
default is ``False``
"""

template_fields: Sequence[str] = ("endpoint", "request_params", "headers")
Expand All @@ -97,6 +102,7 @@ def __init__(
tcp_keep_alive_idle: int = 120,
tcp_keep_alive_count: int = 20,
tcp_keep_alive_interval: int = 30,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
Expand All @@ -114,6 +120,7 @@ def __init__(
self.tcp_keep_alive_idle = tcp_keep_alive_idle
self.tcp_keep_alive_count = tcp_keep_alive_count
self.tcp_keep_alive_interval = tcp_keep_alive_interval
self.deferrable = deferrable

def poke(self, context: Context) -> bool:
from airflow.utils.operator_helpers import determine_kwargs
Expand All @@ -135,9 +142,12 @@ def poke(self, context: Context) -> bool:
headers=self.headers,
extra_options=self.extra_options,
)

if self.response_check:
kwargs = determine_kwargs(self.response_check, [response], context)

return self.response_check(response, **kwargs)

except AirflowException as exc:
if str(exc).startswith(self.response_error_codes_allowlist):
return False
Expand All @@ -148,3 +158,24 @@ def poke(self, context: Context) -> bool:
raise exc

return True

def execute(self, context: Context) -> None:
    """
    Run the sensor, deferring to the triggerer when that is possible.

    The regular (blocking) sensor loop of the parent class is used whenever
    ``deferrable`` is falsy or a ``response_check`` callable is set. Otherwise a
    single poke is made up front; only if it fails is the task deferred to an
    ``HttpSensorTrigger`` that resumes in ``execute_complete``.

    :param context: The task execution context.
    """
    can_defer = self.deferrable and not self.response_check
    if not can_defer:
        super().execute(context=context)
        return

    # One synchronous poke first: nothing to defer if the endpoint already answers.
    if self.poke(context):
        return

    defer_timeout = timedelta(seconds=self.timeout)
    sensor_trigger = HttpSensorTrigger(
        endpoint=self.endpoint,
        http_conn_id=self.http_conn_id,
        data=self.request_params,
        headers=self.headers,
        method=self.method,
        extra_options=self.extra_options,
        poke_interval=self.poke_interval,
    )
    self.defer(
        timeout=defer_timeout,
        trigger=sensor_trigger,
        method_name="execute_complete",
    )

def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
    """
    Log that the task finished; used as the resume callback after deferral.

    :param context: The task execution context (not used here).
    :param event: The payload of the trigger event (not used here).
    """
    finished_task = self.task_id
    self.log.info("%s completed successfully.", finished_task)
72 changes: 72 additions & 0 deletions airflow/providers/http/triggers/http.py
Expand Up @@ -16,6 +16,7 @@
# under the License.
from __future__ import annotations

import asyncio
import base64
import pickle
from typing import TYPE_CHECKING, Any, AsyncIterator
Expand All @@ -24,6 +25,7 @@
from requests.cookies import RequestsCookieJar
from requests.structures import CaseInsensitiveDict

from airflow.exceptions import AirflowException
from airflow.providers.http.hooks.http import HttpAsyncHook
from airflow.triggers.base import BaseTrigger, TriggerEvent

Expand Down Expand Up @@ -124,3 +126,73 @@ async def _convert_response(client_response: ClientResponse) -> requests.Respons
cookies.set(k, v)
response.cookies = cookies
return response


class HttpSensorTrigger(BaseTrigger):
    """
    A trigger that fires once a request to a URL stops failing with a 404.

    The request is repeated until the hook call completes without raising an
    ``AirflowException``; a single ``TriggerEvent(True)`` is then emitted and the
    trigger ends. While the raised error message starts with ``"404"`` the trigger
    sleeps for ``poke_interval`` seconds between attempts.

    :param endpoint: The relative part of the full url
    :param http_conn_id: The HTTP Connection ID to run the sensor against
    :param method: The HTTP request method to use
    :param data: payload to be uploaded or aiohttp parameters
    :param headers: The HTTP headers to be added to the request
    :param extra_options: Additional kwargs to pass when creating a request.
        For example, ``run(json=obj)`` is passed as ``aiohttp.ClientSession().get(json=obj)``
    :param poke_interval: Time to sleep using asyncio
    """

    def __init__(
        self,
        endpoint: str | None = None,
        http_conn_id: str = "http_default",
        method: str = "GET",
        data: dict[str, Any] | str | None = None,
        headers: dict[str, str] | None = None,
        extra_options: dict[str, Any] | None = None,
        poke_interval: float = 5.0,
    ):
        super().__init__()
        self.endpoint = endpoint
        self.method = method
        self.data = data
        self.headers = headers
        self.extra_options = extra_options or {}
        self.http_conn_id = http_conn_id
        self.poke_interval = poke_interval

    def serialize(self) -> tuple[str, dict[str, Any]]:
        """
        Serialize HttpSensorTrigger arguments and classpath.

        :return: The import path of this class and the keyword arguments needed
            to rebuild an equivalent instance.
        """
        return (
            "airflow.providers.http.triggers.http.HttpSensorTrigger",
            {
                "endpoint": self.endpoint,
                # ``method`` must be included, otherwise a trigger rebuilt from this
                # payload silently falls back to the constructor default ("GET").
                "method": self.method,
                "data": self.data,
                "headers": self.headers,
                "extra_options": self.extra_options,
                "http_conn_id": self.http_conn_id,
                "poke_interval": self.poke_interval,
            },
        )

    async def run(self) -> AsyncIterator[TriggerEvent]:
        """
        Make a series of asynchronous http calls via an http hook.

        :return: An async iterator yielding one ``TriggerEvent(True)`` after the
            first request that does not raise an ``AirflowException``.
        """
        hook = self._get_async_hook()
        while True:
            try:
                await hook.run(
                    endpoint=self.endpoint,
                    data=self.data,
                    headers=self.headers,
                    extra_options=self.extra_options,
                )
            except AirflowException as exc:
                if str(exc).startswith("404"):
                    await asyncio.sleep(self.poke_interval)
                # NOTE(review): any other AirflowException is swallowed and the request
                # is retried right away, without sleeping -- confirm this is intended.
            else:
                yield TriggerEvent(True)
                # Stop here: without this the loop would issue another request if the
                # generator were resumed after the event has been emitted.
                return

    def _get_async_hook(self) -> HttpAsyncHook:
        """Build the async hook from the configured method and connection id."""
        return HttpAsyncHook(
            method=self.method,
            http_conn_id=self.http_conn_id,
        )
7 changes: 7 additions & 0 deletions docs/apache-airflow-providers-http/operators.rst
Expand Up @@ -37,6 +37,13 @@ Here we are poking until httpbin gives us a response text containing ``httpbin``
:start-after: [START howto_operator_http_http_sensor_check]
:end-before: [END howto_operator_http_http_sensor_check]

This sensor can also be used in deferrable mode:

.. exampleinclude:: /../../tests/system/providers/http/example_http.py
:language: python
:start-after: [START howto_operator_http_http_sensor_check_deferrable]
:end-before: [END howto_operator_http_http_sensor_check_deferrable]

.. _howto/operator:HttpOperator:

HttpOperator
Expand Down
54 changes: 53 additions & 1 deletion tests/providers/http/sensors/test_http.py
Expand Up @@ -23,10 +23,11 @@
import pytest
import requests

from airflow.exceptions import AirflowException, AirflowSensorTimeout, AirflowSkipException
from airflow.exceptions import AirflowException, AirflowSensorTimeout, AirflowSkipException, TaskDeferred
from airflow.models.dag import DAG
from airflow.providers.http.operators.http import HttpOperator
from airflow.providers.http.sensors.http import HttpSensor
from airflow.providers.http.triggers.http import HttpSensorTrigger
from airflow.utils.timezone import datetime

pytestmark = pytest.mark.db_test
Expand Down Expand Up @@ -330,3 +331,54 @@ def test_sensor(self):
dag=self.dag,
)
sensor.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)


class TestHttpSensorAsync:
    """Tests for the deferrable execution path of ``HttpSensor``."""

    @mock.patch("airflow.providers.http.sensors.http.HttpSensor.defer")
    @mock.patch(
        "airflow.providers.http.sensors.http.HttpSensor.poke",
        return_value=True,
    )
    def test_execute_finished_before_deferred(
        self,
        mock_poke,
        mock_defer,
    ):
        """
        Asserts that a task is not deferred when the first poke already succeeds.
        """

        task = HttpSensor(task_id="run_now", endpoint="test-endpoint", deferrable=True)

        task.execute({})
        assert not mock_defer.called

    @mock.patch(
        "airflow.providers.http.sensors.http.HttpSensor.poke",
        return_value=False,
    )
    def test_execute_is_deferred(self, mock_poke):
        """
        Asserts that a task is deferred with a HttpSensorTrigger
        when the HttpSensor is executed in deferrable mode and the first poke fails.
        """

        task = HttpSensor(task_id="run_now", endpoint="test-endpoint", deferrable=True)

        with pytest.raises(TaskDeferred) as exc:
            task.execute({})

        assert isinstance(exc.value.trigger, HttpSensorTrigger), "Trigger is not a HttpSensorTrigger"
        # The task must resume in the callback the sensor actually defines.
        assert exc.value.method_name == "execute_complete"

    @mock.patch("airflow.providers.http.sensors.http.HttpSensor.defer")
    @mock.patch("airflow.sensors.base.BaseSensorOperator.execute")
    def test_execute_not_defer_when_response_check_is_not_none(self, mock_execute, mock_defer):
        """
        Asserts that a sensor with a ``response_check`` runs the regular sensor
        loop of the base class instead of deferring, even with ``deferrable=True``.
        """
        task = HttpSensor(
            task_id="run_now",
            endpoint="test-endpoint",
            response_check=lambda response: "httpbin" in response.text,
            deferrable=True,
        )
        task.execute({})
        mock_execute.assert_called_once()
        mock_defer.assert_not_called()
19 changes: 18 additions & 1 deletion tests/system/providers/http/example_http.py
Expand Up @@ -110,6 +110,17 @@
dag=dag,
)
# [END howto_operator_http_http_sensor_check]
# [START howto_operator_http_http_sensor_check_deferrable]
# With ``deferrable=True`` and no ``response_check``, the sensor pokes once and,
# only if that poke fails, defers the waiting to an ``HttpSensorTrigger``.
task_http_sensor_check_async = HttpSensor(
    task_id="http_sensor_check_async",
    http_conn_id="http_default",
    endpoint="",
    deferrable=True,
    request_params={},
    poke_interval=5,
    dag=dag,
)
# [END howto_operator_http_http_sensor_check_deferrable]
# [START howto_operator_http_pagination_function]


Expand All @@ -134,7 +145,13 @@ def get_next_page_cursor(response) -> dict | None:
dag=dag,
)
# [END howto_operator_http_pagination_function]
task_http_sensor_check >> task_post_op >> task_get_op >> task_get_op_response_filter
(
task_http_sensor_check
>> task_http_sensor_check_async
>> task_post_op
>> task_get_op
>> task_get_op_response_filter
)
task_get_op_response_filter >> task_put_op >> task_del_op >> task_post_op_formenc
task_post_op_formenc >> task_get_paginated

Expand Down

0 comments on commit 9596bbd

Please sign in to comment.