@property
def canonical_uri(self):
    """
    Resolve the canonical URI for this dataset.

    A URI whose scheme is anything other than ``airflow`` or
    ``airflow+<override>`` is returned unchanged.

    For an ``airflow`` URI the connection id is read from the username part of
    the userinfo (``airflow://conn_id@...``). The connection's URI and the
    dataset URI are then combined as follows:

    * scheme: the connection's scheme, unless the dataset scheme carries an
      override (``airflow+override://``)
    * host and port: the dataset's values take precedence over the connection's
    * path: the connection path followed by the dataset path
    * query: both are merged; on a key conflict the dataset's value wins
    * fragment: the dataset's, falling back to the connection's
    * userinfo is never copied into the result

    Examples of accepted dataset URIs::

        airflow://conn_id@/
        airflow+override://conn_id@/
        airflow://conn_id@/some_extra_path?key=value

    :return: the canonical URI as a string
    :raises ValueError: if an ``airflow`` URI does not carry a connection id
    """
    from urllib.parse import parse_qs, urlencode, urlparse, urlunparse

    parsed = urlparse(self.uri)

    # Match the scheme exactly: a prefix test would also treat unrelated
    # schemes such as ``airflowish://`` as connection references.
    base_scheme, _, scheme_override = parsed.scheme.partition("+")
    if base_scheme != "airflow":
        return self.uri

    conn_id = parsed.username
    if not conn_id:
        raise ValueError(
            f"Dataset URI {self.uri!r} uses the `airflow` scheme but has no connection id; "
            "expected the form `airflow://<conn_id>@...`."
        )

    # Imported here because it is only needed for `airflow` URIs.
    from airflow.models.connection import Connection

    conn = urlparse(Connection.get_connection_from_secrets(conn_id).get_uri())

    # Take the scheme from the connection, unless it is overridden in the dataset.
    scheme = scheme_override or conn.scheme

    # Rebuild the netloc from host and port only, which strips userinfo.
    hostname = parsed.hostname or conn.hostname or ""
    if ":" in hostname:
        # `hostname` drops the brackets of an IPv6 literal; restore them so
        # the port separator stays unambiguous.
        hostname = f"[{hostname}]"
    port = parsed.port if parsed.port is not None else conn.port
    netloc = hostname if port is None else f"{hostname}:{port}"

    # Combine the paths (connection followed by dataset) without producing a
    # doubled slash at the join.
    path = conn.path
    if parsed.path:
        path = f"{conn.path.rstrip('/')}{parsed.path}"

    # Merge the query args; blank values are kept so no argument is dropped.
    query = parse_qs(conn.query, keep_blank_values=True)
    query.update(parse_qs(parsed.query, keep_blank_values=True))

    fragment = parsed.fragment or conn.fragment

    return urlunparse((scheme, netloc, path, "", urlencode(query, doseq=True), fragment))
class TestDatasetModel:
    """Tests for ``DatasetModel.canonical_uri``."""

    @pytest.mark.parametrize(
        "conn_uri, dataset_uri, expected_canonical_uri",
        [
            # Root path on both sides yields a single slash.
            ("postgres://somehost/", "airflow://testconn@/", "postgres://somehost/"),
            # Dataset adds nothing: the connection URI comes back unchanged.
            ("postgres://somehost:111/base", "airflow://testconn@", "postgres://somehost:111/base"),
            # Scheme override given as `airflow+foo`.
            ("postgres://somehost:111/base", "airflow+foo://testconn@", "foo://somehost:111/base"),
            # Dataset host and port replace the connection's.
            ("postgres://somehost:111", "airflow://testconn@foo:222", "postgres://foo:222"),
            # Dataset path is appended to the connection path.
            (
                "postgres://somehost:111/base",
                "airflow://testconn@/extra",
                "postgres://somehost:111/base/extra",
            ),
            # Dataset query is carried over.
            ("postgres://somehost:111", "airflow://testconn@/?foo=bar", "postgres://somehost:111/?foo=bar"),
            # Distinct query keys from both sides are merged.
            (
                "postgres://somehost?biz=baz",
                "airflow://testconn@/?foo=bar",
                "postgres://somehost/?biz=baz&foo=bar",
            ),
            # On a shared query key the dataset value wins.
            (
                "postgres://somehost?foo=baz",
                "airflow://testconn@/?foo=bar",
                "postgres://somehost/?foo=bar",
            ),
            # Connection credentials are not part of the result.
            ("postgres://user:pass@somehost", "airflow://testconn@", "postgres://somehost"),
        ],
    )
    def test_canonical_uri(self, conn_uri, dataset_uri, expected_canonical_uri):
        """Resolve ``dataset_uri`` against a connection supplied via the environment."""
        env_override = {"AIRFLOW_CONN_TESTCONN": conn_uri}
        with mock.patch.dict('os.environ', env_override):
            resolved = DatasetModel(uri=dataset_uri).canonical_uri
        assert resolved == expected_canonical_uri