Skip to content

Commit

Permalink
Fix headers passed into HttpAsyncHook (#32409)
Browse files Browse the repository at this point in the history
  • Loading branch information
sumeshpremraj committed Jul 6, 2023
1 parent ee38382 commit 358e6e8
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 6 deletions.
2 changes: 1 addition & 1 deletion airflow/providers/http/hooks/http.py
Expand Up @@ -369,7 +369,7 @@ async def run(
url,
json=data if self.method in ("POST", "PATCH") else None,
params=data if self.method == "GET" else None,
headers=headers,
headers=_headers,
auth=auth,
**extra_options,
)
Expand Down
34 changes: 29 additions & 5 deletions tests/providers/http/hooks/test_http.py
Expand Up @@ -82,7 +82,6 @@ def test_get_request_with_port(self, mock_session, mock_request):
):
expected_url = "http://test.com:1234/some/endpoint"
for endpoint in ["some/endpoint", "/some/endpoint"]:

try:
self.get_hook.run(endpoint)
except MissingSchema:
Expand Down Expand Up @@ -175,7 +174,6 @@ def test_post_request_do_not_raise_for_status_if_check_response_is_false(self, r

@mock.patch("airflow.providers.http.hooks.http.requests.Session")
def test_retry_on_conn_error(self, mocked_session):

retry_args = dict(
wait=tenacity.wait_none(),
stop=tenacity.stop_after_attempt(7),
Expand All @@ -192,7 +190,6 @@ def send_and_raise(unused_request, **kwargs):
assert self.get_hook._retry_obj.stop.max_attempt_number + 1 == mocked_session.call_count

def test_run_with_advanced_retry(self, requests_mock):

requests_mock.get("http://test:8080/v1/test", status_code=200, reason="OK")

retry_args = dict(
Expand Down Expand Up @@ -298,7 +295,6 @@ def test_requests_ca_bundle_env_var(self, mock_session_send):
with mock.patch(
"airflow.hooks.base.BaseHook.get_connection", side_effect=get_airflow_connection_with_port
):

self.get_hook.run("/some/endpoint")

mock_session_send.assert_called_once_with(
Expand All @@ -317,7 +313,6 @@ def test_verify_respects_requests_ca_bundle_env_var(self, mock_session_send):
with mock.patch(
"airflow.hooks.base.BaseHook.get_connection", side_effect=get_airflow_connection_with_port
):

self.get_hook.run("/some/endpoint", extra_options={"verify": True})

mock_session_send.assert_called_once_with(
Expand Down Expand Up @@ -530,3 +525,32 @@ async def test_async_post_request_with_error_code(aioresponse):
with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=get_airflow_connection):
with pytest.raises(AirflowException):
await hook.run("v1/test")


@pytest.mark.asyncio
async def test_async_request_uses_connection_extra(aioresponse):
"""Test api call asynchronously with a connection that has extra field."""

connection_extra = {"bareer": "test"}
connection_id = "http_default"

def get_airflow_connection_with_extra(unused_conn_id=None):
return Connection(
conn_id=connection_id, conn_type="http", host="test:8080/", extra=json.dumps(connection_extra)
)

aioresponse.post(
"http://test:8080/v1/test",
status=200,
payload='{"status":{"status": 200}}',
reason="OK",
)

with mock.patch(
"airflow.hooks.base.BaseHook.get_connection", side_effect=get_airflow_connection_with_extra
):
hook = HttpAsyncHook()
with mock.patch("aiohttp.ClientSession.post", new_callable=mock.AsyncMock) as mocked_function:
await hook.run("v1/test")
headers = mocked_function.call_args.kwargs.get("headers")
assert all(key in headers and headers[key] == value for key, value in connection_extra.items())

0 comments on commit 358e6e8

Please sign in to comment.