Add pagination to HttpOperator and make it more modular (#34669)
* feat: Make SimpleHttpOperator extendable

* feat: Implement ExtendedHttpOperator

* feat: Add sync and async tests for `pagination_function`

* feat: Add example and documentation

* fix: Add missing return statements

* fix: typo in class docstring

Co-authored-by: Jens Scheffler <95105677+jens-scheffler-bosch@users.noreply.github.com>

* fix: make use of hook property in DiscordWebhookHook

* fix: rename to PaginatedHttpOperator

* fix: Correctly route reference link to PaginatedHttpOperator docs

* fix: makes SimpleHttpOperator types customizable for mypy

* fix: add missing dashes in docs + add missing reference to paginated operator

* fix: add missing reference to `PaginatedHttpOperator`

* feat: implement hook retrieval based on connection id

* feat: Merge PaginatedOperator to SimpleHttpOperator

* fix: Removes mention of PaginatedHttpOperator

* fix: Apply static checks code quality

* fix: Reformulate docs

* feat: Deprecate `SimpleHttpOperator` and rename to `HttpOperator`

* fix: Remove 'HttpOperator' from `__deprecated_classes`

---------

Co-authored-by: Jens Scheffler <95105677+jens-scheffler-bosch@users.noreply.github.com>
Joffreybvn and jscheffl committed Nov 3, 2023
1 parent 829d10a commit 70b3bd3
Showing 9 changed files with 356 additions and 70 deletions.
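Before the diffs, a minimal, hypothetical sketch of what the rename described in the commit message means for DAG code; the import paths come from this commit, the comments are editorial:

# Old name: still importable, but instantiating it now emits an
# AirflowProviderDeprecationWarning, since SimpleHttpOperator is kept only as a
# thin deprecated subclass of HttpOperator.
from airflow.providers.http.operators.http import SimpleHttpOperator

# New name: the same operator, now with optional pagination support via
# the `pagination_function` parameter.
from airflow.providers.http.operators.http import HttpOperator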
15 changes: 9 additions & 6 deletions airflow/providers/discord/operators/discord_webhook.py
@@ -21,13 +21,13 @@

from airflow.exceptions import AirflowException
from airflow.providers.discord.hooks.discord_webhook import DiscordWebhookHook
from airflow.providers.http.operators.http import SimpleHttpOperator
from airflow.providers.http.operators.http import HttpOperator

if TYPE_CHECKING:
from airflow.utils.context import Context


class DiscordWebhookOperator(SimpleHttpOperator):
class DiscordWebhookOperator(HttpOperator):
"""
This operator allows you to post messages to Discord using incoming webhooks.
@@ -77,11 +77,10 @@ def __init__(
self.avatar_url = avatar_url
self.tts = tts
self.proxy = proxy
self.hook: DiscordWebhookHook | None = None

def execute(self, context: Context) -> None:
"""Call the DiscordWebhookHook to post message."""
self.hook = DiscordWebhookHook(
@property
def hook(self) -> DiscordWebhookHook:
hook = DiscordWebhookHook(
self.http_conn_id,
self.webhook_endpoint,
self.message,
@@ -90,4 +89,8 @@ def execute(self, context: Context) -> None:
self.tts,
self.proxy,
)
return hook

def execute(self, context: Context) -> None:
"""Call the DiscordWebhookHook to post a message."""
self.hook.execute()
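For context, a minimal usage sketch of the refactored operator; the task id, connection id, and message below are placeholders and not part of the commit:

from airflow.providers.discord.operators.discord_webhook import DiscordWebhookOperator

notify_discord = DiscordWebhookOperator(
    task_id="notify_discord",
    http_conn_id="discord_webhook_default",  # assumed connection id
    message="Pipeline finished successfully",
)
# execute() now builds a fresh DiscordWebhookHook through the `hook` property
# and posts the message; no hook instance is stored on the operator anymore.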
223 changes: 192 additions & 31 deletions airflow/providers/http/operators/http.py
@@ -19,28 +19,32 @@

import base64
import pickle
import warnings
from typing import TYPE_CHECKING, Any, Callable, Sequence

from requests import Response

from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.hooks.base import BaseHook
from airflow.models import BaseOperator
from airflow.providers.http.hooks.http import HttpHook
from airflow.providers.http.triggers.http import HttpTrigger
from airflow.utils.helpers import merge_dicts

if TYPE_CHECKING:
from requests import Response
from requests.auth import AuthBase

from airflow.providers.http.hooks.http import HttpHook
from airflow.utils.context import Context


class SimpleHttpOperator(BaseOperator):
class HttpOperator(BaseOperator):
"""
Calls an endpoint on an HTTP system to execute an action.
.. seealso::
For more information on how to use this operator, take a look at the guide:
:ref:`howto/operator:SimpleHttpOperator`
:ref:`howto/operator:HttpOperator`
:param http_conn_id: The :ref:`http connection<howto/connection:http>` to run
the operator against
@@ -49,14 +53,26 @@ class SimpleHttpOperator(BaseOperator):
:param data: The data to pass. POST-data in POST/PUT and params
in the URL for a GET request. (templated)
:param headers: The HTTP headers to be added to the GET request
:param pagination_function: A callable that generates the parameters used to call the API again.
Typically used when the API is paginated and returns, for example, a cursor, a 'next page id', or
a 'next page URL'. When provided, the Operator calls the API repeatedly until this callable
returns None. The result of the Operator then defaults to a list of Response.text
objects (instead of a single response object), and the other injected callables (such as
response_check and response_filter) also receive a list of Response objects. This
function should return a dict of parameters (`endpoint`, `data`, `headers`, `extra_options`),
which are merged with, and override, the parameters used in the initial API call.
:param response_check: A check against the 'requests' response object.
The callable takes the response object as the first positional argument
and optionally any number of keyword arguments available in the context dictionary.
It should return True for 'pass' and False otherwise.
It should return True for 'pass' and False otherwise. If a pagination_function
is provided, this function will receive a list of response objects instead of a
single response object.
:param response_filter: A function allowing you to manipulate the response
text, e.g. response_filter=lambda response: json.loads(response.text).
The callable takes the response object as the first positional argument
and optionally any number of keyword arguments available in the context dictionary.
If a pagination_function is provided, this function will receive a list of response
objects instead of a single response object.
:param extra_options: Extra options for the 'requests' library, see the
'requests' documentation (options to modify timeout, ssl, etc.)
:param log_response: Log the response (default: False)
@@ -69,6 +85,7 @@ class SimpleHttpOperator(BaseOperator):
:param deferrable: Run operator in the deferrable mode
"""

conn_id_field = "http_conn_id"
template_fields: Sequence[str] = (
"endpoint",
"data",
@@ -85,6 +102,7 @@ def __init__(
method: str = "POST",
data: Any = None,
headers: dict[str, str] | None = None,
pagination_function: Callable[..., Any] | None = None,
response_check: Callable[..., bool] | None = None,
response_filter: Callable[..., Any] | None = None,
extra_options: dict[str, Any] | None = None,
@@ -104,6 +122,7 @@ def __init__(
self.endpoint = endpoint
self.headers = headers or {}
self.data = data or {}
self.pagination_function = pagination_function
self.response_check = response_check
self.response_filter = response_filter
self.extra_options = extra_options or {}
@@ -115,59 +134,201 @@ def __init__(
self.tcp_keep_alive_interval = tcp_keep_alive_interval
self.deferrable = deferrable

def execute(self, context: Context) -> Any:
if self.deferrable:
self.defer(
trigger=HttpTrigger(
http_conn_id=self.http_conn_id,
auth_type=self.auth_type,
method=self.method,
endpoint=self.endpoint,
headers=self.headers,
data=self.data,
extra_options=self.extra_options,
),
method_name="execute_complete",
)
else:
http = HttpHook(
self.method,
http_conn_id=self.http_conn_id,
@property
def hook(self) -> HttpHook:
"""Get Http Hook based on connection type."""
conn_id = getattr(self, self.conn_id_field)
self.log.debug("Get connection for %s", conn_id)
conn = BaseHook.get_connection(conn_id)

hook = conn.get_hook(
hook_params=dict(
method=self.method,
auth_type=self.auth_type,
tcp_keep_alive=self.tcp_keep_alive,
tcp_keep_alive_idle=self.tcp_keep_alive_idle,
tcp_keep_alive_count=self.tcp_keep_alive_count,
tcp_keep_alive_interval=self.tcp_keep_alive_interval,
)
)
return hook

def execute(self, context: Context) -> Any:
if self.deferrable:
self.execute_async(context=context)
else:
return self.execute_sync(context=context)

self.log.info("Calling HTTP method")
def execute_sync(self, context: Context) -> Any:
self.log.info("Calling HTTP method")
response = self.hook.run(self.endpoint, self.data, self.headers, self.extra_options)
response = self.paginate_sync(first_response=response)
return self.process_response(context=context, response=response)

response = http.run(self.endpoint, self.data, self.headers, self.extra_options)
return self.process_response(context=context, response=response)
def paginate_sync(self, first_response: Response) -> Response | list[Response]:
if not self.pagination_function:
return first_response

all_responses = [first_response]
while True:
next_page_params = self.pagination_function(all_responses[-1])  # paginate based on the most recent response
if not next_page_params:
break
response = self.hook.run(**self._merge_next_page_parameters(next_page_params))
all_responses.append(response)
return all_responses

def process_response(self, context: Context, response: Response) -> str:
def execute_async(self, context: Context) -> None:
self.defer(
trigger=HttpTrigger(
http_conn_id=self.http_conn_id,
auth_type=self.auth_type,
method=self.method,
endpoint=self.endpoint,
headers=self.headers,
data=self.data,
extra_options=self.extra_options,
),
method_name="execute_complete",
)

def process_response(self, context: Context, response: Response | list[Response]) -> Any:
"""Process the response."""
from airflow.utils.operator_helpers import determine_kwargs

make_default_response: Callable = self._default_response_maker(response=response)

if self.log_response:
self.log.info(response.text)
self.log.info(make_default_response())
if self.response_check:
kwargs = determine_kwargs(self.response_check, [response], context)
if not self.response_check(response, **kwargs):
raise AirflowException("Response check returned False.")
if self.response_filter:
kwargs = determine_kwargs(self.response_filter, [response], context)
return self.response_filter(response, **kwargs)
return response.text
return make_default_response()

@staticmethod
def _default_response_maker(response: Response | list[Response]) -> Callable:
"""Create a default response maker function based on the type of response.
def execute_complete(self, context: Context, event: dict):
:param response: The response object or list of response objects.
:return: A function that returns response text(s).
"""
if isinstance(response, Response):
response_object = response # Makes mypy happy
return lambda: response_object.text

response_list: list[Response] = response # Makes mypy happy
return lambda: [entry.text for entry in response_list]

def execute_complete(
self, context: Context, event: dict, paginated_responses: None | list[Response] = None
):
"""
Callback for when the trigger fires - returns immediately.
Relies on trigger to throw an exception, otherwise it assumes execution was successful.
"""
if event["status"] == "success":
response = pickle.loads(base64.standard_b64decode(event["response"]))

if self.pagination_function:
return self.paginate_async(
context=context, response=response, previous_responses=paginated_responses
)
return self.process_response(context=context, response=response)
else:
raise AirflowException(f"Unexpected error in the operation: {event['message']}")

def paginate_async(
self, context: Context, response: Response, previous_responses: None | list[Response] = None
):
if self.pagination_function:
all_responses = previous_responses or []
all_responses.append(response)

next_page_params = self.pagination_function(response)
if not next_page_params:
return self.process_response(context=context, response=all_responses)
self.defer(
trigger=HttpTrigger(
http_conn_id=self.http_conn_id,
auth_type=self.auth_type,
method=self.method,
**self._merge_next_page_parameters(next_page_params),
),
method_name="execute_complete",
kwargs={"paginated_responses": all_responses},
)

def _merge_next_page_parameters(self, next_page_params: dict) -> dict:
"""Merge initial request parameters with next page parameters.
Merge the initial request parameters with the ones for the next page, generated by
the pagination function. Items in 'next_page_params' override those defined
in the previous request.
:param next_page_params: A dictionary containing the parameters for the next page.
:return: A dictionary containing the merged parameters.
"""
return dict(
endpoint=next_page_params.get("endpoint") or self.endpoint,
data=merge_dicts(self.data, next_page_params.get("data", {})),
headers=merge_dicts(self.headers, next_page_params.get("headers", {})),
extra_options=merge_dicts(self.extra_options, next_page_params.get("extra_options", {})),
)


class SimpleHttpOperator(HttpOperator):
"""
Calls an endpoint on an HTTP system to execute an action.
.. seealso::
For more information on how to use this operator, take a look at the guide:
:ref:`howto/operator:HttpOperator`
:param http_conn_id: The :ref:`http connection<howto/connection:http>` to run
the operator against
:param endpoint: The relative part of the full url. (templated)
:param method: The HTTP method to use, default = "POST"
:param data: The data to pass. POST-data in POST/PUT and params
in the URL for a GET request. (templated)
:param headers: The HTTP headers to be added to the GET request
:param pagination_function: A callable that generates the parameters used to call the API again.
Typically used when the API is paginated and returns, for example, a cursor, a 'next page id', or
a 'next page URL'. When provided, the Operator calls the API repeatedly until this callable
returns None. The result of the Operator then defaults to a list of Response.text
objects (instead of a single response object), and the other injected callables (such as
response_check and response_filter) also receive a list of Response objects. This
function should return a dict of parameters (`endpoint`, `data`, `headers`, `extra_options`),
which are merged with, and override, the parameters used in the initial API call.
:param response_check: A check against the 'requests' response object.
The callable takes the response object as the first positional argument
and optionally any number of keyword arguments available in the context dictionary.
It should return True for 'pass' and False otherwise. If a pagination_function
is provided, this function will receive a list of response objects instead of a
single response object.
:param response_filter: A function allowing you to manipulate the response
text, e.g. response_filter=lambda response: json.loads(response.text).
The callable takes the response object as the first positional argument
and optionally any number of keyword arguments available in the context dictionary.
If a pagination_function is provided, this function will receive a list of response
objects instead of a single response object.
:param extra_options: Extra options for the 'requests' library, see the
'requests' documentation (options to modify timeout, ssl, etc.)
:param log_response: Log the response (default: False)
:param auth_type: The auth type for the service
:param tcp_keep_alive: Enable TCP Keep Alive for the connection.
:param tcp_keep_alive_idle: The TCP Keep Alive Idle parameter (corresponds to ``socket.TCP_KEEPIDLE``).
:param tcp_keep_alive_count: The TCP Keep Alive count parameter (corresponds to ``socket.TCP_KEEPCNT``)
:param tcp_keep_alive_interval: The TCP Keep Alive interval parameter (corresponds to
``socket.TCP_KEEPINTVL``)
:param deferrable: Run operator in the deferrable mode
"""

def __init__(self, **kwargs: Any):
warnings.warn(
"Class `SimpleHttpOperator` is deprecated and "
"will be removed in a future release. Please use `HttpOperator` instead.",
AirflowProviderDeprecationWarning,
)
super().__init__(**kwargs)
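To show how the new pagination hook is meant to be wired up, here is a minimal, hypothetical sketch; the connection id, endpoint, and the `next_cursor` field are assumptions for illustration and do not come from this commit:

import json

from airflow.providers.http.operators.http import HttpOperator


def get_next_page(response):
    """Return parameters for the next request, or None to stop paginating."""
    cursor = json.loads(response.text).get("next_cursor")  # assumed cursor field
    if not cursor:
        return None  # stop: the operator returns the collected Response.text values
    # Returned keys (endpoint, data, headers, extra_options) are merged with,
    # and override, the parameters of the initial request.
    return {"data": {"cursor": cursor}}


fetch_all_items = HttpOperator(
    task_id="fetch_all_items",
    http_conn_id="http_default",
    method="GET",
    endpoint="/api/items",  # assumed endpoint
    pagination_function=get_next_page,
    # With pagination enabled, response_filter receives a list of Response objects.
    response_filter=lambda responses: [json.loads(r.text) for r in responses],
)

The same callable drives both the synchronous path (paginate_sync) and the deferrable path (paginate_async), so switching to deferrable=True does not change how pagination is expressed.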
@@ -56,7 +56,7 @@ Host (required)
Login (optional)
* If authentication with *Databricks login credentials* is used then specify the ``username`` used to login to Databricks.
* If authentication with *Azure Service Principal* is used then specify the ID of the Azure Service Principal
* If authentication with *PAT* is used then either leave this field empty or use 'token' as login (both work; the only difference is that if login is empty the token is sent in the request header as a Bearer token, whereas if login is 'token' it is sent using Basic Auth, which the Databricks API allows; this may be useful if you plan to reuse this connection with e.g. SimpleHttpOperator)
* If authentication with *PAT* is used then either leave this field empty or use 'token' as login (both work; the only difference is that if login is empty the token is sent in the request header as a Bearer token, whereas if login is 'token' it is sent using Basic Auth, which the Databricks API allows; this may be useful if you plan to reuse this connection with e.g. HttpOperator)
* If authentication with *Databricks Service Principal OAuth* is used then specify the ID of the Service Principal (Databricks on AWS)

Password (optional)
