diff --git a/airflow/providers/http/hooks/http.py b/airflow/providers/http/hooks/http.py index 5bff4716a9542..0b0443efc484d 100644 --- a/airflow/providers/http/hooks/http.py +++ b/airflow/providers/http/hooks/http.py @@ -369,7 +369,7 @@ async def run( url, json=data if self.method in ("POST", "PATCH") else None, params=data if self.method == "GET" else None, - headers=headers, + headers=_headers, auth=auth, **extra_options, ) diff --git a/tests/providers/http/hooks/test_http.py b/tests/providers/http/hooks/test_http.py index 8d2bc8691d9a4..e702dded9c102 100644 --- a/tests/providers/http/hooks/test_http.py +++ b/tests/providers/http/hooks/test_http.py @@ -82,7 +82,6 @@ def test_get_request_with_port(self, mock_session, mock_request): ): expected_url = "http://test.com:1234/some/endpoint" for endpoint in ["some/endpoint", "/some/endpoint"]: - try: self.get_hook.run(endpoint) except MissingSchema: @@ -175,7 +174,6 @@ def test_post_request_do_not_raise_for_status_if_check_response_is_false(self, r @mock.patch("airflow.providers.http.hooks.http.requests.Session") def test_retry_on_conn_error(self, mocked_session): - retry_args = dict( wait=tenacity.wait_none(), stop=tenacity.stop_after_attempt(7), @@ -192,7 +190,6 @@ def send_and_raise(unused_request, **kwargs): assert self.get_hook._retry_obj.stop.max_attempt_number + 1 == mocked_session.call_count def test_run_with_advanced_retry(self, requests_mock): - requests_mock.get("http://test:8080/v1/test", status_code=200, reason="OK") retry_args = dict( @@ -298,7 +295,6 @@ def test_requests_ca_bundle_env_var(self, mock_session_send): with mock.patch( "airflow.hooks.base.BaseHook.get_connection", side_effect=get_airflow_connection_with_port ): - self.get_hook.run("/some/endpoint") mock_session_send.assert_called_once_with( @@ -317,7 +313,6 @@ def test_verify_respects_requests_ca_bundle_env_var(self, mock_session_send): with mock.patch( "airflow.hooks.base.BaseHook.get_connection", side_effect=get_airflow_connection_with_port ): - self.get_hook.run("/some/endpoint", extra_options={"verify": True}) mock_session_send.assert_called_once_with( @@ -530,3 +525,32 @@ async def test_async_post_request_with_error_code(aioresponse): with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=get_airflow_connection): with pytest.raises(AirflowException): await hook.run("v1/test") + + +@pytest.mark.asyncio +async def test_async_request_uses_connection_extra(aioresponse): + """Test api call asynchronously with a connection that has extra field.""" + + connection_extra = {"bareer": "test"} + connection_id = "http_default" + + def get_airflow_connection_with_extra(unused_conn_id=None): + return Connection( + conn_id=connection_id, conn_type="http", host="test:8080/", extra=json.dumps(connection_extra) + ) + + aioresponse.post( + "http://test:8080/v1/test", + status=200, + payload='{"status":{"status": 200}}', + reason="OK", + ) + + with mock.patch( + "airflow.hooks.base.BaseHook.get_connection", side_effect=get_airflow_connection_with_extra + ): + hook = HttpAsyncHook() + with mock.patch("aiohttp.ClientSession.post", new_callable=mock.AsyncMock) as mocked_function: + await hook.run("v1/test") + headers = mocked_function.call_args.kwargs.get("headers") + assert all(key in headers and headers[key] == value for key, value in connection_extra.items())