From 70b3bd3fb960e8b052f31b4acb59961357548e3a Mon Sep 17 00:00:00 2001 From: Joffrey Bienvenu Date: Fri, 3 Nov 2023 18:14:49 +0100 Subject: [PATCH] Add pagination to `HttpOperator` and make it more modular (#34669) * feat: Make SimpleHttpOperator extendable * feat: Implement ExtendedHttpOperator * feat: Add sync and async tests for `pagination_function` * feat: Add example and documentation * fix: Add missing return statements * fix: typo in class docstring Co-authored-by: Jens Scheffler <95105677+jens-scheffler-bosch@users.noreply.github.com> * fix: make use of hook property in DiscordWebhookHook * fix: rename to PaginatedHttpOperator * fix: Correctly route reference link to PaginatedHttpOperator docs * fix: makes SimpleHttpOperator types customizable for mypy * fix: add missing dashes in docs + add missing reference to paginated operator * fix: add missing reference to `PaginatedHttpOperator` * feat: implement hook retrieval based on connection id * feat: Merge PaginatedOperator to SimpleHttpOperator * fix: Removes mention of PaginatedHttpOperator * fix: Apply static checks code quality * fix: Reformulate docs * feat: Deprecate `SimpleHttpOperator` and rename to `HttpOperator` * fix: Remove 'HttpOperator' from `__deprecated_classes` --------- Co-authored-by: Jens Scheffler <95105677+jens-scheffler-bosch@users.noreply.github.com> --- .../discord/operators/discord_webhook.py | 15 +- airflow/providers/http/operators/http.py | 223 +++++++++++++++--- .../connections/databricks.rst | 2 +- .../operators.rst | 37 ++- .../core-concepts/operators.rst | 4 +- docs/apache-airflow/tutorial/taskflow.rst | 4 +- tests/providers/http/operators/test_http.py | 97 +++++++- tests/providers/http/sensors/test_http.py | 6 +- tests/system/providers/http/example_http.py | 38 ++- 9 files changed, 356 insertions(+), 70 deletions(-) diff --git a/airflow/providers/discord/operators/discord_webhook.py b/airflow/providers/discord/operators/discord_webhook.py index e4c38f8356e2b..0650b19b26cd8 
100644 --- a/airflow/providers/discord/operators/discord_webhook.py +++ b/airflow/providers/discord/operators/discord_webhook.py @@ -21,13 +21,13 @@ from airflow.exceptions import AirflowException from airflow.providers.discord.hooks.discord_webhook import DiscordWebhookHook -from airflow.providers.http.operators.http import SimpleHttpOperator +from airflow.providers.http.operators.http import HttpOperator if TYPE_CHECKING: from airflow.utils.context import Context -class DiscordWebhookOperator(SimpleHttpOperator): +class DiscordWebhookOperator(HttpOperator): """ This operator allows you to post messages to Discord using incoming webhooks. @@ -77,11 +77,10 @@ def __init__( self.avatar_url = avatar_url self.tts = tts self.proxy = proxy - self.hook: DiscordWebhookHook | None = None - def execute(self, context: Context) -> None: - """Call the DiscordWebhookHook to post message.""" - self.hook = DiscordWebhookHook( + @property + def hook(self) -> DiscordWebhookHook: + hook = DiscordWebhookHook( self.http_conn_id, self.webhook_endpoint, self.message, @@ -90,4 +89,8 @@ def execute(self, context: Context) -> None: self.tts, self.proxy, ) + return hook + + def execute(self, context: Context) -> None: + """Call the DiscordWebhookHook to post a message.""" self.hook.execute() diff --git a/airflow/providers/http/operators/http.py b/airflow/providers/http/operators/http.py index f4010e22d04af..216de77e311ed 100644 --- a/airflow/providers/http/operators/http.py +++ b/airflow/providers/http/operators/http.py @@ -19,28 +19,32 @@ import base64 import pickle +import warnings from typing import TYPE_CHECKING, Any, Callable, Sequence +from requests import Response + from airflow.configuration import conf -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning +from airflow.hooks.base import BaseHook from airflow.models import BaseOperator -from airflow.providers.http.hooks.http import HttpHook from 
airflow.providers.http.triggers.http import HttpTrigger +from airflow.utils.helpers import merge_dicts if TYPE_CHECKING: - from requests import Response from requests.auth import AuthBase + from airflow.providers.http.hooks.http import HttpHook from airflow.utils.context import Context -class SimpleHttpOperator(BaseOperator): +class HttpOperator(BaseOperator): """ Calls an endpoint on an HTTP system to execute an action. .. seealso:: For more information on how to use this operator, take a look at the guide: - :ref:`howto/operator:SimpleHttpOperator` + :ref:`howto/operator:HttpOperator` :param http_conn_id: The :ref:`http connection` to run the operator against @@ -49,14 +53,26 @@ class SimpleHttpOperator(BaseOperator): :param data: The data to pass. POST-data in POST/PUT and params in the URL for a GET request. (templated) :param headers: The HTTP headers to be added to the GET request + :param pagination_function: A callable that generates the parameters used to call the API again. + Typically used when the API is paginated and returns for e.g a cursor, a 'next page id', or + a 'next page URL'. When provided, the Operator will call the API repeatedly until this callable + returns None. Also, the result of the Operator will become by default a list of Response.text + objects (instead of a single response object). Same with the other injected functions (like + response_check, response_filter, ...) which will also receive a list of Response object. This + function should return a dict of parameters (`endpoint`, `data`, `headers`, `extra_options`), + which will be merged and override the one used in the initial API call. :param response_check: A check against the 'requests' response object. The callable takes the response object as the first positional argument and optionally any number of keyword arguments available in the context dictionary. - It should return True for 'pass' and False otherwise. + It should return True for 'pass' and False otherwise. 
If a pagination_function + is provided, this function will receive a list of response objects instead of a + single response object. :param response_filter: A function allowing you to manipulate the response text. e.g response_filter=lambda response: json.loads(response.text). The callable takes the response object as the first positional argument and optionally any number of keyword arguments available in the context dictionary. + If a pagination_function is provided, this function will receive a list of response + object instead of a single response object. :param extra_options: Extra options for the 'requests' library, see the 'requests' documentation (options to modify timeout, ssl, etc.) :param log_response: Log the response (default: False) @@ -69,6 +85,7 @@ class SimpleHttpOperator(BaseOperator): :param deferrable: Run operator in the deferrable mode """ + conn_id_field = "http_conn_id" template_fields: Sequence[str] = ( "endpoint", "data", @@ -85,6 +102,7 @@ def __init__( method: str = "POST", data: Any = None, headers: dict[str, str] | None = None, + pagination_function: Callable[..., Any] | None = None, response_check: Callable[..., bool] | None = None, response_filter: Callable[..., Any] | None = None, extra_options: dict[str, Any] | None = None, @@ -104,6 +122,7 @@ def __init__( self.endpoint = endpoint self.headers = headers or {} self.data = data or {} + self.pagination_function = pagination_function self.response_check = response_check self.response_filter = response_filter self.extra_options = extra_options or {} @@ -115,42 +134,72 @@ def __init__( self.tcp_keep_alive_interval = tcp_keep_alive_interval self.deferrable = deferrable - def execute(self, context: Context) -> Any: - if self.deferrable: - self.defer( - trigger=HttpTrigger( - http_conn_id=self.http_conn_id, - auth_type=self.auth_type, - method=self.method, - endpoint=self.endpoint, - headers=self.headers, - data=self.data, - extra_options=self.extra_options, - ), - 
method_name="execute_complete", - ) - else: - http = HttpHook( - self.method, - http_conn_id=self.http_conn_id, + @property + def hook(self) -> HttpHook: + """Get Http Hook based on connection type.""" + conn_id = getattr(self, self.conn_id_field) + self.log.debug("Get connection for %s", conn_id) + conn = BaseHook.get_connection(conn_id) + + hook = conn.get_hook( + hook_params=dict( + method=self.method, auth_type=self.auth_type, tcp_keep_alive=self.tcp_keep_alive, tcp_keep_alive_idle=self.tcp_keep_alive_idle, tcp_keep_alive_count=self.tcp_keep_alive_count, tcp_keep_alive_interval=self.tcp_keep_alive_interval, ) + ) + return hook + + def execute(self, context: Context) -> Any: + if self.deferrable: + self.execute_async(context=context) + else: + return self.execute_sync(context=context) - self.log.info("Calling HTTP method") + def execute_sync(self, context: Context) -> Any: + self.log.info("Calling HTTP method") + response = self.hook.run(self.endpoint, self.data, self.headers, self.extra_options) + response = self.paginate_sync(first_response=response) + return self.process_response(context=context, response=response) - response = http.run(self.endpoint, self.data, self.headers, self.extra_options) - return self.process_response(context=context, response=response) + def paginate_sync(self, first_response: Response) -> Response | list[Response]: + if not self.pagination_function: + return first_response + + all_responses = [first_response] + while True: + next_page_params = self.pagination_function(all_responses[-1]) + if not next_page_params: + break + response = self.hook.run(**self._merge_next_page_parameters(next_page_params)) + all_responses.append(response) + return all_responses - def process_response(self, context: Context, response: Response) -> str: + def execute_async(self, context: Context) -> None: + self.defer( + trigger=HttpTrigger( + http_conn_id=self.http_conn_id, + auth_type=self.auth_type, + method=self.method, + endpoint=self.endpoint,
headers=self.headers, + data=self.data, + extra_options=self.extra_options, + ), + method_name="execute_complete", + ) + + def process_response(self, context: Context, response: Response | list[Response]) -> Any: """Process the response.""" from airflow.utils.operator_helpers import determine_kwargs + make_default_response: Callable = self._default_response_maker(response=response) + if self.log_response: - self.log.info(response.text) + self.log.info(make_default_response()) if self.response_check: kwargs = determine_kwargs(self.response_check, [response], context) if not self.response_check(response, **kwargs): @@ -158,9 +207,25 @@ def process_response(self, context: Context, response: Response) -> str: if self.response_filter: kwargs = determine_kwargs(self.response_filter, [response], context) return self.response_filter(response, **kwargs) - return response.text + return make_default_response() + + @staticmethod + def _default_response_maker(response: Response | list[Response]) -> Callable: + """Create a default response maker function based on the type of response. - def execute_complete(self, context: Context, event: dict): + :param response: The response object or list of response objects. + :return: A function that returns response text(s). + """ + if isinstance(response, Response): + response_object = response # Makes mypy happy + return lambda: response_object.text + + response_list: list[Response] = response # Makes mypy happy + return lambda: [entry.text for entry in response_list] + + def execute_complete( + self, context: Context, event: dict, paginated_responses: None | list[Response] = None + ): """ Callback for when the trigger fires - returns immediately. 
@@ -168,6 +233,102 @@ def execute_complete(self, context: Context, event: dict): """ if event["status"] == "success": response = pickle.loads(base64.standard_b64decode(event["response"])) + if self.pagination_function: + return self.paginate_async(context=context, response=response, previous_responses=paginated_responses) return self.process_response(context=context, response=response) else: raise AirflowException(f"Unexpected error in the operation: {event['message']}") + + def paginate_async( + self, context: Context, response: Response, previous_responses: None | list[Response] = None + ): + if self.pagination_function: + all_responses = previous_responses or [] + all_responses.append(response) + + next_page_params = self.pagination_function(response) + if not next_page_params: + return self.process_response(context=context, response=all_responses) + self.defer( + trigger=HttpTrigger( + http_conn_id=self.http_conn_id, + auth_type=self.auth_type, + method=self.method, + **self._merge_next_page_parameters(next_page_params), + ), + method_name="execute_complete", + kwargs={"paginated_responses": all_responses}, + ) + + def _merge_next_page_parameters(self, next_page_params: dict) -> dict: + """Merge initial request parameters with next page parameters. + + Merge initial requests parameters with the ones for the next page, generated by + the pagination function. Items in the 'next_page_params' overrides those defined + in the previous request. + + :param next_page_params: A dictionary containing the parameters for the next page. + :return: A dictionary containing the merged parameters.
+ """ + return dict( + endpoint=next_page_params.get("endpoint") or self.endpoint, + data=merge_dicts(self.data, next_page_params.get("data", {})), + headers=merge_dicts(self.headers, next_page_params.get("headers", {})), + extra_options=merge_dicts(self.extra_options, next_page_params.get("extra_options", {})), + ) + + +class SimpleHttpOperator(HttpOperator): + """ + Calls an endpoint on an HTTP system to execute an action. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:HttpOperator` + + :param http_conn_id: The :ref:`http connection` to run + the operator against + :param endpoint: The relative part of the full url. (templated) + :param method: The HTTP method to use, default = "POST" + :param data: The data to pass. POST-data in POST/PUT and params + in the URL for a GET request. (templated) + :param headers: The HTTP headers to be added to the GET request + :param pagination_function: A callable that generates the parameters used to call the API again. + Typically used when the API is paginated and returns for e.g a cursor, a 'next page id', or + a 'next page URL'. When provided, the Operator will call the API repeatedly until this callable + returns None. Also, the result of the Operator will become by default a list of Response.text + objects (instead of a single response object). Same with the other injected functions (like + response_check, response_filter, ...) which will also receive a list of Response object. This + function should return a dict of parameters (`endpoint`, `data`, `headers`, `extra_options`), + which will be merged and override the one used in the initial API call. + :param response_check: A check against the 'requests' response object. + The callable takes the response object as the first positional argument + and optionally any number of keyword arguments available in the context dictionary. + It should return True for 'pass' and False otherwise. 
If a pagination_function + is provided, this function will receive a list of response objects instead of a + single response object. + :param response_filter: A function allowing you to manipulate the response + text. e.g response_filter=lambda response: json.loads(response.text). + The callable takes the response object as the first positional argument + and optionally any number of keyword arguments available in the context dictionary. + If a pagination_function is provided, this function will receive a list of response + object instead of a single response object. + :param extra_options: Extra options for the 'requests' library, see the + 'requests' documentation (options to modify timeout, ssl, etc.) + :param log_response: Log the response (default: False) + :param auth_type: The auth type for the service + :param tcp_keep_alive: Enable TCP Keep Alive for the connection. + :param tcp_keep_alive_idle: The TCP Keep Alive Idle parameter (corresponds to ``socket.TCP_KEEPIDLE``). + :param tcp_keep_alive_count: The TCP Keep Alive count parameter (corresponds to ``socket.TCP_KEEPCNT``) + :param tcp_keep_alive_interval: The TCP Keep Alive interval parameter (corresponds to + ``socket.TCP_KEEPINTVL``) + :param deferrable: Run operator in the deferrable mode + """ + + def __init__(self, **kwargs: Any): + warnings.warn( + "Class `SimpleHttpOperator` is deprecated and " + "will be removed in a future release. 
Please use `HttpOperator` instead.", + AirflowProviderDeprecationWarning, + ) + super().__init__(**kwargs) diff --git a/docs/apache-airflow-providers-databricks/connections/databricks.rst b/docs/apache-airflow-providers-databricks/connections/databricks.rst index 908b12eac7f03..7045d4a3ae22c 100644 --- a/docs/apache-airflow-providers-databricks/connections/databricks.rst +++ b/docs/apache-airflow-providers-databricks/connections/databricks.rst @@ -56,7 +56,7 @@ Host (required) Login (optional) * If authentication with *Databricks login credentials* is used then specify the ``username`` used to login to Databricks. * If authentication with *Azure Service Principal* is used then specify the ID of the Azure Service Principal - * If authentication with *PAT* is used then either leave this field empty or use 'token' as login (both work, the only difference is that if login is empty then token will be sent in request header as Bearer token, if login is 'token' then it will be sent using Basic Auth which is allowed by Databricks API, this may be useful if you plan to reuse this connection with e.g. SimpleHttpOperator) + * If authentication with *PAT* is used then either leave this field empty or use 'token' as login (both work, the only difference is that if login is empty then token will be sent in request header as Bearer token, if login is 'token' then it will be sent using Basic Auth which is allowed by Databricks API, this may be useful if you plan to reuse this connection with e.g. 
HttpOperator) * If authentication with *Databricks Service Principal OAuth* is used then specify the ID of the Service Principal (Databricks on AWS) Password (optional) diff --git a/docs/apache-airflow-providers-http/operators.rst b/docs/apache-airflow-providers-http/operators.rst index 45fa454a29fdc..944a2931201d1 100644 --- a/docs/apache-airflow-providers-http/operators.rst +++ b/docs/apache-airflow-providers-http/operators.rst @@ -37,23 +37,23 @@ Here we are poking until httpbin gives us a response text containing ``httpbin`` :start-after: [START howto_operator_http_http_sensor_check] :end-before: [END howto_operator_http_http_sensor_check] -.. _howto/operator:SimpleHttpOperator: +.. _howto/operator:HttpOperator: -SimpleHttpOperator ------------------- +HttpOperator +------------ -Use the :class:`~airflow.providers.http.operators.http.SimpleHttpOperator` to call HTTP requests and get +Use the :class:`~airflow.providers.http.operators.http.HttpOperator` to call HTTP requests and get the response text back. -.. warning:: Configuring ``https`` via SimpleHttpOperator is counter-intuitive +.. warning:: Configuring ``https`` via HttpOperator is counter-intuitive For historical reasons, configuring ``HTTPS`` connectivity via HTTP operator is, well, difficult and counter-intuitive. The Operator defaults to ``http`` protocol and you can change the schema used by the operator via ``scheme`` connection attribute. However, this field was originally added to connection for database type of URIs, where database schemes are set traditionally as first component of URI ``path``. Therefore if you want to configure as ``https`` connection via URI, you need to pass ``https`` scheme - to the SimpleHttpOperator. AS stupid as it looks, your connection URI will look like this: - ``http://your_host:443/https``. Then if you want to use different URL paths in SimpleHttpOperator + to the HttpOperator. 
AS stupid as it looks, your connection URI will look like this: + ``http://your_host:443/https``. Then if you want to use different URL paths in HttpOperator you should pass your path as ``endpoint`` parameter when running the task. For example to run a query to ``https://your_host:443/my_endpoint`` you need to set the endpoint parameter to ``my_endpoint``. Alternatively, if you want, you could also percent-encode the host including the ``https://`` prefix, @@ -62,7 +62,7 @@ the response text back. In this case, however, the ``path`` will not be used at all - you still need to use ``endpoint`` parameter in the task if wish to make a request with specific path. As counter-intuitive as it is, this is historically the way how the operator/hook works and it's not easy to change without breaking - backwards compatibility because there are other operators build on top of the ``SimpleHttpOperator`` that + backwards compatibility because there are other operators build on top of the ``HttpOperator`` that rely on that functionality and there are many users using it already. @@ -81,7 +81,7 @@ Here we are calling a ``GET`` request and pass params to it. The task will succe :start-after: [START howto_operator_http_task_get_op] :end-before: [END howto_operator_http_task_get_op] -SimpleHttpOperator returns the response body as text by default. If you want to modify the response before passing +HttpOperator returns the response body as text by default. If you want to modify the response before passing it on the next task downstream use ``response_filter``. 
This is useful if: - the API you are consuming returns a large JSON payload and you're interested in a subset of the data @@ -118,3 +118,22 @@ Here we pass form data to a ``POST`` operation which is equal to a usual form su :language: python :start-after: [START howto_operator_http_task_post_op_formenc] :end-before: [END howto_operator_http_task_post_op_formenc] + + + +The :class:`~airflow.providers.http.operators.http.HttpOperator` also allows you to repeatedly call an API +endpoint, typically to loop over its pages. All API responses are stored in memory by the Operator and returned +in one single result. Thus, it can be more memory and CPU intensive compared to a non-paginated call. + +By default, the result of the HttpOperator will become a list of Response.text (instead of one single +Response.text object). + +Example - Let's assume your API returns a JSON body containing a cursor: +You can write a ``pagination_function`` that will receive the raw ``requests.Response`` object of your request, and +generate new request parameters (as ``dict``) based on this cursor. The HttpOperator will repeat calls to the +API until the function stops returning anything. + +..
exampleinclude:: /../../tests/system/providers/http/example_http.py + :language: python + :start-after: [START howto_operator_http_pagination_function] + :end-before: [END howto_operator_http_pagination_function] diff --git a/docs/apache-airflow/core-concepts/operators.rst b/docs/apache-airflow/core-concepts/operators.rst index f915c60c64847..ddbcf7da7ab2d 100644 --- a/docs/apache-airflow/core-concepts/operators.rst +++ b/docs/apache-airflow/core-concepts/operators.rst @@ -21,7 +21,7 @@ Operators An Operator is conceptually a template for a predefined :doc:`Task `, that you can just define declaratively inside your DAG:: with DAG("my-dag") as dag: - ping = SimpleHttpOperator(endpoint="http://example.com/update/") + ping = HttpOperator(endpoint="http://example.com/update/") email = EmailOperator(to="admin@example.com", subject="Update complete") ping >> email @@ -41,7 +41,7 @@ For a list of all core operators, see: :doc:`Core Operators and Hooks Reference If the operator you need isn't installed with Airflow by default, you can probably find it as part of our huge set of community :doc:`provider packages `. Some popular operators from here include: -- :class:`~airflow.providers.http.operators.http.SimpleHttpOperator` +- :class:`~airflow.providers.http.operators.http.HttpOperator` - :class:`~airflow.providers.mysql.operators.mysql.MySqlOperator` - :class:`~airflow.providers.postgres.operators.postgres.PostgresOperator` - :class:`~airflow.providers.microsoft.mssql.operators.mssql.MsSqlOperator` diff --git a/docs/apache-airflow/tutorial/taskflow.rst b/docs/apache-airflow/tutorial/taskflow.rst index fc7297cd783f4..08967bdde37a8 100644 --- a/docs/apache-airflow/tutorial/taskflow.rst +++ b/docs/apache-airflow/tutorial/taskflow.rst @@ -498,13 +498,13 @@ To retrieve an XCom result for a key other than ``return_value``, you can use: Using the ``.output`` property as an input to another task is supported only for operator parameters listed as a ``template_field``. 
-In the code example below, a :class:`~airflow.providers.http.operators.http.SimpleHttpOperator` result +In the code example below, a :class:`~airflow.providers.http.operators.http.HttpOperator` result is captured via :doc:`XComs `. This XCom result, which is the task output, is then passed to a TaskFlow function which parses the response as JSON. .. code-block:: python - get_api_results_task = SimpleHttpOperator( + get_api_results_task = HttpOperator( task_id="get_api_results", endpoint="/api/query", do_xcom_push=True, diff --git a/tests/providers/http/operators/test_http.py b/tests/providers/http/operators/test_http.py index ad03ad5aeec89..2b91ad6296f68 100644 --- a/tests/providers/http/operators/test_http.py +++ b/tests/providers/http/operators/test_http.py @@ -18,6 +18,7 @@ from __future__ import annotations import base64 +import contextlib import pickle from unittest import mock @@ -25,20 +26,20 @@ from requests import Response from airflow.exceptions import AirflowException, TaskDeferred -from airflow.providers.http.operators.http import SimpleHttpOperator +from airflow.providers.http.operators.http import HttpOperator from airflow.providers.http.triggers.http import HttpTrigger @mock.patch.dict("os.environ", AIRFLOW_CONN_HTTP_EXAMPLE="http://www.example.com") -class TestSimpleHttpOp: +class TestHttpOperator: def test_response_in_logs(self, requests_mock): """ - Test that when using SimpleHttpOperator with 'GET', + Test that when using HttpOperator with 'GET', the log contains 'Example Domain' in it """ requests_mock.get("http://www.example.com", text="Example.com fake response") - operator = SimpleHttpOperator( + operator = HttpOperator( task_id="test_HTTP_op", method="GET", endpoint="/", @@ -51,7 +52,7 @@ def test_response_in_logs(self, requests_mock): def test_response_in_logs_after_failed_check(self, requests_mock): """ - Test that when using SimpleHttpOperator with log_response=True, + Test that when using HttpOperator with log_response=True, the 
response is logged even if request_check fails """ @@ -59,7 +60,7 @@ def response_check(response): return response.text != "invalid response" requests_mock.get("http://www.example.com", text="invalid response") - operator = SimpleHttpOperator( + operator = HttpOperator( task_id="test_HTTP_op", method="GET", endpoint="/", @@ -76,7 +77,7 @@ def response_check(response): def test_filters_response(self, requests_mock): requests_mock.get("http://www.example.com", json={"value": 5}) - operator = SimpleHttpOperator( + operator = HttpOperator( task_id="test_HTTP_op", method="GET", endpoint="/", @@ -87,7 +88,7 @@ def test_filters_response(self, requests_mock): assert result == {"value": 5} def test_async_defer_successfully(self, requests_mock): - operator = SimpleHttpOperator( + operator = HttpOperator( task_id="test_HTTP_op", deferrable=True, ) @@ -96,7 +97,7 @@ def test_async_defer_successfully(self, requests_mock): assert isinstance(exc.value.trigger, HttpTrigger), "Trigger is not a HttpTrigger" def test_async_execute_successfully(self, requests_mock): - operator = SimpleHttpOperator( + operator = HttpOperator( task_id="test_HTTP_op", deferrable=True, ) @@ -110,3 +111,81 @@ def test_async_execute_successfully(self, requests_mock): }, ) assert result == "content" + + def test_paginated_responses(self, requests_mock): + """ + Test that the HttpOperator calls repetitively the API when a + pagination_function is provided, and as long as this function returns + a dictionary that override previous' call parameters. 
+ """ + has_returned: bool = False + + def pagination_function(response: Response) -> dict | None: + """Paginated function which returns None at the second call.""" + nonlocal has_returned + if not has_returned: + has_returned = True + return dict( + endpoint="/", + data={"cursor": "example"}, + headers={}, + extra_options={}, + ) + + requests_mock.get("http://www.example.com", json={"value": 5}) + operator = HttpOperator( + task_id="test_HTTP_op", + method="GET", + endpoint="/", + http_conn_id="HTTP_EXAMPLE", + pagination_function=pagination_function, + ) + result = operator.execute({}) + assert result == ['{"value": 5}', '{"value": 5}'] + + def test_async_paginated_responses(self, requests_mock): + """ + Test that the HttpOperator calls asynchronously and repetitively + the API when a pagination_function is provided, and as long as this function + returns a dictionary that override previous' call parameters. + """ + + def make_response_object() -> Response: + response = Response() + response._content = b'{"value": 5}' + return response + + def create_resume_response_parameters() -> dict: + response = make_response_object() + return dict( + context={}, + event={ + "status": "success", + "response": base64.standard_b64encode(pickle.dumps(response)).decode("ascii"), + }, + ) + + has_returned: bool = False + + def pagination_function(response: Response) -> dict | None: + """Paginated function which returns None at the second call.""" + nonlocal has_returned + if not has_returned: + has_returned = True + return dict(endpoint="/") + + operator = HttpOperator( + task_id="test_HTTP_op", + pagination_function=pagination_function, + deferrable=True, + ) + + # Do two calls: On the first one, the pagination_function creates a new + # deferrable trigger. 
On the second one, the pagination_function returns + # None, which ends the execution of the Operator + with contextlib.suppress(TaskDeferred): + operator.execute_complete(**create_resume_response_parameters()) + result = operator.execute_complete( + **create_resume_response_parameters(), paginated_responses=[make_response_object()] + ) + assert result == ['{"value": 5}', '{"value": 5}'] diff --git a/tests/providers/http/sensors/test_http.py b/tests/providers/http/sensors/test_http.py index 8fed2ecb32d93..f842ea91fcd40 100644 --- a/tests/providers/http/sensors/test_http.py +++ b/tests/providers/http/sensors/test_http.py @@ -25,7 +25,7 @@ from airflow.exceptions import AirflowException, AirflowSensorTimeout, AirflowSkipException from airflow.models.dag import DAG -from airflow.providers.http.operators.http import SimpleHttpOperator +from airflow.providers.http.operators.http import HttpOperator from airflow.providers.http.sensors.http import HttpSensor from airflow.utils.timezone import datetime @@ -293,7 +293,7 @@ def setup_method(self): @mock.patch("requests.Session", FakeSession) def test_get(self): - op = SimpleHttpOperator( + op = HttpOperator( task_id="get_op", method="GET", endpoint="/search", @@ -305,7 +305,7 @@ def test_get(self): @mock.patch("requests.Session", FakeSession) def test_get_response_check(self): - op = SimpleHttpOperator( + op = HttpOperator( task_id="get_op", method="GET", endpoint="/search", diff --git a/tests/system/providers/http/example_http.py b/tests/system/providers/http/example_http.py index 02971585639f6..9e88614b98b74 100644 --- a/tests/system/providers/http/example_http.py +++ b/tests/system/providers/http/example_http.py @@ -23,7 +23,7 @@ from datetime import datetime from airflow import DAG -from airflow.providers.http.operators.http import SimpleHttpOperator +from airflow.providers.http.operators.http import HttpOperator from airflow.providers.http.sensors.http import HttpSensor ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID") @@ 
-42,7 +42,7 @@ # task_post_op, task_get_op and task_put_op are examples of tasks created by instantiating operators # [START howto_operator_http_task_post_op] -task_post_op = SimpleHttpOperator( +task_post_op = HttpOperator( task_id="post_op", endpoint="post", data=json.dumps({"priority": 5}), @@ -52,7 +52,7 @@ ) # [END howto_operator_http_task_post_op] # [START howto_operator_http_task_post_op_formenc] -task_post_op_formenc = SimpleHttpOperator( +task_post_op_formenc = HttpOperator( task_id="post_op_formenc", endpoint="post", data="name=Joe", @@ -61,7 +61,7 @@ ) # [END howto_operator_http_task_post_op_formenc] # [START howto_operator_http_task_get_op] -task_get_op = SimpleHttpOperator( +task_get_op = HttpOperator( task_id="get_op", method="GET", endpoint="get", @@ -71,7 +71,7 @@ ) # [END howto_operator_http_task_get_op] # [START howto_operator_http_task_get_op_response_filter] -task_get_op_response_filter = SimpleHttpOperator( +task_get_op_response_filter = HttpOperator( task_id="get_op_response_filter", method="GET", endpoint="get", @@ -80,7 +80,7 @@ ) # [END howto_operator_http_task_get_op_response_filter] # [START howto_operator_http_task_put_op] -task_put_op = SimpleHttpOperator( +task_put_op = HttpOperator( task_id="put_op", method="PUT", endpoint="put", @@ -90,7 +90,7 @@ ) # [END howto_operator_http_task_put_op] # [START howto_operator_http_task_del_op] -task_del_op = SimpleHttpOperator( +task_del_op = HttpOperator( task_id="del_op", method="DELETE", endpoint="delete", @@ -110,8 +110,32 @@ dag=dag, ) # [END howto_operator_http_http_sensor_check] +# [START howto_operator_http_pagination_function] + + +def get_next_page_cursor(response) -> dict | None: + """ + Take the raw `request.Response` object, and check for a cursor. + If a cursor exists, this function creates and return parameters to call + the next page of result. 
+ """ + next_cursor = response.json().get("cursor") + if next_cursor: + return dict(data={"cursor": next_cursor}) + + +task_get_paginated = HttpOperator( + task_id="get_paginated", + method="GET", + endpoint="get", + data={"cursor": ""}, + pagination_function=get_next_page_cursor, + dag=dag, +) +# [END howto_operator_http_pagination_function] task_http_sensor_check >> task_post_op >> task_get_op >> task_get_op_response_filter task_get_op_response_filter >> task_put_op >> task_del_op >> task_post_op_formenc +task_post_op_formenc >> task_get_paginated from tests.system.utils import get_test_run # noqa: E402