Skip to content

Commit

Permalink
Change the URL building in HttpHookAsync to match the behavior of Htt…
Browse files Browse the repository at this point in the history
…pHook (apache#37696)

They are moved from airflow.models.datasets to airflow.datasets since
the intention is to use them with Dataset, not DatasetModel. It is more
natural for users to import from the latter module instead.

A new (abstract) base class is added for the two classes, plus the OG
Dataset class, to inherit from. This allows us to replace a few
isinstance checks with simple molymorphism and make the logic a bit
simpler.

Co-authored-by: Tzu-ping Chung <uranusjr@gmail.com>
Co-authored-by: Wei Lee <weilee.rx@gmail.com>
Co-authored-by: Jed Cunningham <66968678+jedcunningham@users.noreply.github.com>
  • Loading branch information
4 people authored and abhishekbhakat committed Mar 5, 2024
1 parent db6a081 commit 23a686e
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 10 deletions.
19 changes: 9 additions & 10 deletions airflow/providers/http/hooks/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,13 @@
from airflow.models import Connection


def _url_from_endpoint(base_url: str | None, endpoint: str | None) -> str:
"""Combine base url with endpoint."""
if base_url and not base_url.endswith("/") and endpoint and not endpoint.startswith("/"):
return f"{base_url}/{endpoint}"
return (base_url or "") + (endpoint or "")


class HttpHook(BaseHook):
"""Interact with HTTP servers.
Expand Down Expand Up @@ -158,7 +165,7 @@ def run(

session = self.get_conn(headers)

url = self.url_from_endpoint(endpoint)
url = _url_from_endpoint(self.base_url, endpoint)

if self.tcp_keep_alive:
keep_alive_adapter = TCPKeepAliveAdapter(
Expand Down Expand Up @@ -261,12 +268,6 @@ def run_with_advanced_retry(self, _retry_args: dict[Any, Any], *args: Any, **kwa
# TODO: remove ignore type when https://github.com/jd/tenacity/issues/428 is resolved
return self._retry_obj(self.run, *args, **kwargs) # type: ignore

def url_from_endpoint(self, endpoint: str | None) -> str:
"""Combine base url with endpoint."""
if self.base_url and not self.base_url.endswith("/") and endpoint and not endpoint.startswith("/"):
return self.base_url + "/" + endpoint
return (self.base_url or "") + (endpoint or "")

def test_connection(self):
"""Test HTTP Connection."""
try:
Expand Down Expand Up @@ -357,9 +358,7 @@ async def run(
if headers:
_headers.update(headers)

base_url = (self.base_url or "").rstrip("/")
endpoint = (endpoint or "").lstrip("/")
url = f"{base_url}/{endpoint}"
url = _url_from_endpoint(self.base_url, endpoint)

async with aiohttp.ClientSession() as session:
if self.method == "GET":
Expand Down
22 changes: 22 additions & 0 deletions tests/providers/http/hooks/test_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -648,3 +648,25 @@ def test_process_extra_options_from_connection(self):
"max_redirects": 3,
}
assert actual == {"bearer": "test"}

@pytest.mark.asyncio
async def test_build_request_url_from_connection(self):
conn = get_airflow_connection()
schema = conn.schema or "http" # default to http
with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=get_airflow_connection):
hook = HttpAsyncHook()
with mock.patch("aiohttp.ClientSession.post", new_callable=mock.AsyncMock) as mocked_function:
await hook.run("v1/test")
assert mocked_function.call_args.args[0] == f"{schema}://{conn.host}v1/test"

@pytest.mark.asyncio
async def test_build_request_url_from_endpoint_param(self):
def get_empty_conn(conn_id: str = "http_default"):
return Connection(conn_id=conn_id, conn_type="http")

hook = HttpAsyncHook()
with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=get_empty_conn), mock.patch(
"aiohttp.ClientSession.post", new_callable=mock.AsyncMock
) as mocked_function:
await hook.run("test.com:8080/v1/test")
assert mocked_function.call_args.args[0] == "http://test.com:8080/v1/test"

0 comments on commit 23a686e

Please sign in to comment.