Skip to content

Commit

Permalink
Fix HttpOperator pagination with str data (#35782)
Browse files Browse the repository at this point in the history
* feat: Restrict `data` parameter typing

Follows the hook's typing

* feat: Implement `data` override when string

* feat: Improve docstring about merging and overriding behavior

* fix: Add correct typing for mypy

* feat: add test

* fix: remove unused imports

* fix: Update SimpleHttpOperator docstring

* feat: Correctly test parameters overriding
  • Loading branch information
joffreybienvenu-infrabel committed Nov 22, 2023
1 parent 172f573 commit 5588a95
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 31 deletions.
50 changes: 34 additions & 16 deletions airflow/providers/http/operators/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,17 @@ class HttpOperator(BaseOperator):
:param pagination_function: A callable that generates the parameters used to call the API again,
based on the previous response. Typically used when the API is paginated and returns for e.g a
cursor, a 'next page id', or a 'next page URL'. When provided, the Operator will call the API
repeatedly until this callable returns None. Also, the result of the Operator will become by
default a list of Response.text objects (instead of a single response object). Same with the
other injected functions (like response_check, response_filter, ...) which will also receive a
list of Response object. This function receives a Response object form previous call, and should
return a dict of parameters (`endpoint`, `data`, `headers`, `extra_options`), which will be merged
and will override the one used in the initial API call.
repeatedly until this callable returns None. The result of the Operator will become by default a
list of Response.text objects (instead of a single response object). Same with the other injected
functions (like response_check, response_filter, ...) which will also receive a list of Response
objects. This function receives a Response object form previous call, and should return a nested
dictionary with the following optional keys: `endpoint`, `data`, `headers` and `extra_options.
Those keys will be merged and/or override the parameters provided into the HttpOperator declaration.
Parameters are merged when they are both a dictionary (e.g.: HttpOperator.headers will be merged
with the `headers` dict provided by this function). When merging, dict items returned by this
function will override initial ones (e.g: if both HttpOperator.headers and `headers` have a 'cookie'
item, the one provided by `headers` is kept). Parameters are simply overridden when any of them are
string (e.g.: HttpOperator.endpoint is overridden by `endpoint`).
:param response_check: A check against the 'requests' response object.
The callable takes the response object as the first positional argument
and optionally any number of keyword arguments available in the context dictionary.
Expand Down Expand Up @@ -101,7 +106,7 @@ def __init__(
*,
endpoint: str | None = None,
method: str = "POST",
data: Any = None,
data: dict[str, Any] | str | None = None,
headers: dict[str, str] | None = None,
pagination_function: Callable[..., Any] | None = None,
response_check: Callable[..., bool] | None = None,
Expand Down Expand Up @@ -271,9 +276,16 @@ def _merge_next_page_parameters(self, next_page_params: dict) -> dict:
:param next_page_params: A dictionary containing the parameters for the next page.
:return: A dictionary containing the merged parameters.
"""
data: str | dict | None = None # makes mypy happy
next_page_data_param = next_page_params.get("data")
if isinstance(self.data, dict) and isinstance(next_page_data_param, dict):
data = merge_dicts(self.data, next_page_data_param)
else:
data = next_page_data_param or self.data

return dict(
endpoint=next_page_params.get("endpoint") or self.endpoint,
data=merge_dicts(self.data, next_page_params.get("data", {})),
data=data,
headers=merge_dicts(self.headers, next_page_params.get("headers", {})),
extra_options=merge_dicts(self.extra_options, next_page_params.get("extra_options", {})),
)
Expand All @@ -294,14 +306,20 @@ class SimpleHttpOperator(HttpOperator):
:param data: The data to pass. POST-data in POST/PUT and params
in the URL for a GET request. (templated)
:param headers: The HTTP headers to be added to the GET request
:param pagination_function: A callable that generates the parameters used to call the API again.
Typically used when the API is paginated and returns for e.g a cursor, a 'next page id', or
a 'next page URL'. When provided, the Operator will call the API repeatedly until this callable
returns None. Also, the result of the Operator will become by default a list of Response.text
objects (instead of a single response object). Same with the other injected functions (like
response_check, response_filter, ...) which will also receive a list of Response object. This
function should return a dict of parameters (`endpoint`, `data`, `headers`, `extra_options`),
which will be merged and override the one used in the initial API call.
:param pagination_function: A callable that generates the parameters used to call the API again,
based on the previous response. Typically used when the API is paginated and returns for e.g a
cursor, a 'next page id', or a 'next page URL'. When provided, the Operator will call the API
repeatedly until this callable returns None. The result of the Operator will become by default a
list of Response.text objects (instead of a single response object). Same with the other injected
functions (like response_check, response_filter, ...) which will also receive a list of Response
objects. This function receives a Response object form previous call, and should return a nested
dictionary with the following optional keys: `endpoint`, `data`, `headers` and `extra_options.
Those keys will be merged and/or override the parameters provided into the HttpOperator declaration.
Parameters are merged when they are both a dictionary (e.g.: HttpOperator.headers will be merged
with the `headers` dict provided by this function). When merging, dict items returned by this
function will override initial ones (e.g: if both HttpOperator.headers and `headers` have a 'cookie'
item, the one provided by `headers` is kept). Parameters are simply overridden when any of them are
string (e.g.: HttpOperator.endpoint is overridden by `endpoint`).
:param response_check: A check against the 'requests' response object.
The callable takes the response object as the first positional argument
and optionally any number of keyword arguments available in the context dictionary.
Expand Down
2 changes: 1 addition & 1 deletion airflow/providers/http/triggers/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def __init__(
method: str = "POST",
endpoint: str | None = None,
headers: dict[str, str] | None = None,
data: Any = None,
data: dict[str, Any] | str | None = None,
extra_options: dict[str, Any] | None = None,
):
super().__init__()
Expand Down
63 changes: 49 additions & 14 deletions tests/providers/http/operators/test_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

import pytest
from requests import Response
from requests.models import RequestEncodingMixin

from airflow.exceptions import AirflowException, TaskDeferred
from airflow.providers.http.operators.http import HttpOperator
Expand Down Expand Up @@ -112,41 +113,75 @@ def test_async_execute_successfully(self, requests_mock):
)
assert result == "content"

def test_paginated_responses(self, requests_mock):
@pytest.mark.parametrize(
"data, headers, extra_options, pagination_data, pagination_headers, pagination_extra_options",
[
({"data": 1}, {"x-head": "1"}, {"verify": False}, {"data": 2}, {"x-head": "0"}, {"verify": True}),
("data foo", {"x-head": "1"}, {"verify": False}, {"data": 2}, {"x-head": "0"}, {"verify": True}),
("data foo", {"x-head": "1"}, {"verify": False}, "data bar", {"x-head": "0"}, {"verify": True}),
({"data": 1}, {"x-head": "1"}, {"verify": False}, "data foo", {"x-head": "0"}, {"verify": True}),
],
)
def test_pagination(
self,
requests_mock,
data,
headers,
extra_options,
pagination_data,
pagination_headers,
pagination_extra_options,
):
"""
Test that the HttpOperator calls repetitively the API when a
pagination_function is provided, and as long as this function returns
a dictionary that override previous' call parameters.
"""
iterations: int = 0
is_second_call: bool = False

def pagination_function(response: Response) -> dict | None:
"""Paginated function which returns None at the second call."""
nonlocal iterations
if iterations < 2:
iterations += 1
nonlocal is_second_call
if not is_second_call:
is_second_call = True
return dict(
endpoint=response.json()["endpoint"],
data={},
headers={},
extra_options={},
data=pagination_data,
headers=pagination_headers,
extra_options=pagination_extra_options,
)
return None

requests_mock.get("http://www.example.com/foo", json={"value": 5, "endpoint": "bar"})
requests_mock.get("http://www.example.com/bar", json={"value": 10, "endpoint": "foo"})
first_endpoint = requests_mock.post("http://www.example.com/1", json={"value": 5, "endpoint": "2"})
second_endpoint = requests_mock.post("http://www.example.com/2", json={"value": 10, "endpoint": "3"})
operator = HttpOperator(
task_id="test_HTTP_op",
method="GET",
endpoint="/foo",
method="POST",
endpoint="/1",
data=data,
headers=headers,
extra_options=extra_options,
http_conn_id="HTTP_EXAMPLE",
pagination_function=pagination_function,
response_filter=lambda resp: [entry.json()["value"] for entry in resp],
)
result = operator.execute({})
assert result == [5, 10, 5]

def test_async_paginated_responses(self, requests_mock):
# Ensure the initial call is made with parameters passed to the Operator
first_call = first_endpoint.request_history[0]
assert first_call.headers.items() >= headers.items()
assert first_call.body == RequestEncodingMixin._encode_params(data)
assert first_call.verify is extra_options["verify"]

# Ensure the second - paginated - call is made with parameters merged from the pagination function
second_call = second_endpoint.request_history[0]
assert second_call.headers.items() >= pagination_headers.items()
assert second_call.body == RequestEncodingMixin._encode_params(pagination_data)
assert second_call.verify is pagination_extra_options["verify"]

assert result == [5, 10]

def test_async_pagination(self, requests_mock):
"""
Test that the HttpOperator calls asynchronously and repetitively
the API when a pagination_function is provided, and as long as this function
Expand Down

0 comments on commit 5588a95

Please sign in to comment.