Skip to content

Commit

Permalink
Implement AIP-60 Dataset URI formats (#37005)
Browse files Browse the repository at this point in the history
  • Loading branch information
uranusjr committed Feb 26, 2024
1 parent 1d74685 commit b52b227
Show file tree
Hide file tree
Showing 31 changed files with 730 additions and 45 deletions.
79 changes: 63 additions & 16 deletions airflow/datasets/__init__.py
Expand Up @@ -18,14 +18,70 @@
from __future__ import annotations

import os
from typing import Any, Callable, ClassVar, Iterable, Iterator, Protocol, runtime_checkable
from urllib.parse import urlsplit
import urllib.parse
import warnings
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Iterable, Iterator, Protocol, runtime_checkable

import attr

if TYPE_CHECKING:
from urllib.parse import SplitResult

__all__ = ["Dataset", "DatasetAll", "DatasetAny"]


def normalize_noop(parts: SplitResult) -> SplitResult:
    """Return *parts* unchanged.

    Identity normalizer for schemes that need no scheme-specific validation
    or rewriting; ``_get_uri_normalizer`` returns it for the ``file`` scheme.

    :param parts: The split URI to "normalize".
    :return: The very same object that was passed in.
    """
    return parts


def _get_uri_normalizer(scheme: str) -> Callable[[SplitResult], SplitResult] | None:
    """Look up the normalizer callable to use for *scheme*.

    The ``file`` scheme is answered directly with :func:`normalize_noop`.
    Every other scheme is looked up in the providers manager's
    ``dataset_uri_handlers`` mapping.

    :param scheme: URI scheme to look up.
    :return: The normalizer for the scheme, or whatever
        ``dataset_uri_handlers.get(scheme)`` yields when none is registered.
    """
    if scheme != "file":
        # Imported inside the function so the ``file`` path never touches the
        # providers manager.
        from airflow.providers_manager import ProvidersManager

        handlers = ProvidersManager().dataset_uri_handlers
        return handlers.get(scheme)
    return normalize_noop


def _sanitize_uri(uri: str) -> str:
if not uri:
raise ValueError("Dataset URI cannot be empty")
if uri.isspace():
raise ValueError("Dataset URI cannot be just whitespace")
if not uri.isascii():
raise ValueError("Dataset URI must only consist of ASCII characters")
parsed = urllib.parse.urlsplit(uri)
if not parsed.scheme and not parsed.netloc: # Does not look like a URI.
return uri
normalized_scheme = parsed.scheme.lower()
if normalized_scheme.startswith("x-"):
return uri
if normalized_scheme == "airflow":
raise ValueError("Dataset scheme 'airflow' is reserved")
_, auth_exists, normalized_netloc = parsed.netloc.rpartition("@")
if auth_exists:
# TODO: Collect this into a DagWarning.
warnings.warn(
"A dataset URI should not contain auth info (e.g. username or "
"password). It has been automatically dropped.",
UserWarning,
stacklevel=3,
)
if parsed.query:
normalized_query = urllib.parse.urlencode(sorted(urllib.parse.parse_qsl(parsed.query)))
else:
normalized_query = ""
parsed = parsed._replace(
scheme=normalized_scheme,
netloc=normalized_netloc,
path=parsed.path.rstrip("/") or "/", # Remove all trailing slashes.
query=normalized_query,
fragment="", # Ignore any fragments.
)
if (normalizer := _get_uri_normalizer(normalized_scheme)) is not None:
parsed = normalizer(parsed)
return urllib.parse.urlunsplit(parsed)


@runtime_checkable
class BaseDatasetEventInput(Protocol):
"""Protocol for all dataset triggers to use in ``DAG(schedule=...)``.
Expand All @@ -50,23 +106,14 @@ def iter_datasets(self) -> Iterator[tuple[str, Dataset]]:
class Dataset(os.PathLike, BaseDatasetEventInput):
"""A representation of data dependencies between workflows."""

uri: str = attr.field(validator=[attr.validators.min_len(1), attr.validators.max_len(3000)])
uri: str = attr.field(
converter=_sanitize_uri,
validator=[attr.validators.min_len(1), attr.validators.max_len(3000)],
)
extra: dict[str, Any] | None = None

__version__: ClassVar[int] = 1

@uri.validator
def _check_uri(self, attr, uri: str) -> None:
if uri.isspace():
raise ValueError(f"{attr.name} cannot be just whitespace")
try:
uri.encode("ascii")
except UnicodeEncodeError:
raise ValueError(f"{attr.name!r} must be ascii")
parsed = urlsplit(uri)
if parsed.scheme and parsed.scheme.lower() == "airflow":
raise ValueError(f"{attr.name!r} scheme `airflow` is reserved")

def __fspath__(self) -> str:
return self.uri

Expand All @@ -76,7 +123,7 @@ def __eq__(self, other: Any) -> bool:
else:
return NotImplemented

def __hash__(self):
def __hash__(self) -> int:
return hash(self.uri)

def iter_datasets(self) -> Iterator[tuple[str, Dataset]]:
Expand Down
2 changes: 1 addition & 1 deletion airflow/example_dags/example_datasets.py
Expand Up @@ -119,7 +119,7 @@
start_date=pendulum.datetime(2021, 1, 1, tz="UTC"),
schedule=[
dag1_dataset,
Dataset("s3://this-dataset-doesnt-get-triggered"),
Dataset("s3://unrelated/this-dataset-doesnt-get-triggered"),
],
tags=["consumes", "dataset-scheduled"],
) as dag5:
Expand Down
20 changes: 20 additions & 0 deletions airflow/provider.yaml.schema.json
Expand Up @@ -196,6 +196,26 @@
"type": "string"
}
},
"dataset-uris": {
"type": "array",
"description": "Dataset URI formats",
"items": {
"type": "object",
"properties": {
"schemes": {
"type": "array",
"description": "List of supported URI schemes",
"items": {
"type": "string"
}
},
"handler": {
"type": ["string", "null"],
"description": "Normalization function for specified URI schemes. Import path to a callable taking and returning a SplitResult. 'null' specifies a no-op."
}
}
}
},
"transfers": {
"type": "array",
"items": {
Expand Down
4 changes: 4 additions & 0 deletions airflow/providers/amazon/provider.yaml
Expand Up @@ -497,6 +497,10 @@ sensors:
python-modules:
- airflow.providers.amazon.aws.sensors.quicksight

dataset-uris:
- schemes: [s3]
handler: null

filesystems:
- airflow.providers.amazon.aws.fs.s3

Expand Down
16 changes: 16 additions & 0 deletions airflow/providers/google/datasets/__init__.py
@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
31 changes: 31 additions & 0 deletions airflow/providers/google/datasets/bigquery.py
@@ -0,0 +1,31 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
from urllib.parse import SplitResult


def sanitize_uri(uri: SplitResult) -> SplitResult:
    """Validate a ``bigquery://`` dataset URI.

    The expected shape is ``bigquery://<project>/<dataset>/<table>``. The URI
    is only validated; it is returned unchanged when valid.

    :param uri: The split URI to validate.
    :return: *uri*, unmodified.
    :raises ValueError: If the project ID (netloc) is missing, or the path
        does not consist of exactly a non-empty dataset name and a non-empty
        table name.
    """
    if not uri.netloc:
        raise ValueError("URI format bigquery:// must contain a project ID")
    # Leading slash, dataset name, and table name.
    path_parts = uri.path.split("/")
    # Counting segments alone would let empty names such as "//table" through.
    if len(path_parts) != 3 or not all(path_parts[1:]):
        raise ValueError("URI format bigquery:// must contain dataset and table names")
    return uri
6 changes: 6 additions & 0 deletions airflow/providers/google/provider.yaml
Expand Up @@ -751,6 +751,12 @@ sensors:
filesystems:
- airflow.providers.google.cloud.fs.gcs

dataset-uris:
- schemes: [gcp]
handler: null
- schemes: [bigquery]
handler: airflow.providers.google.datasets.bigquery.sanitize_uri

hooks:
- integration-name: Google Ads
python-modules:
Expand Down
16 changes: 16 additions & 0 deletions airflow/providers/mysql/datasets/__init__.py
@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
34 changes: 34 additions & 0 deletions airflow/providers/mysql/datasets/mysql.py
@@ -0,0 +1,34 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
from urllib.parse import SplitResult


def sanitize_uri(uri: SplitResult) -> SplitResult:
    """Validate and normalize a ``mysql://`` dataset URI.

    The expected shape is ``<scheme>://<host>[:<port>]/<database>/<table>``.
    A missing port is filled in as ``3306`` and the scheme of the result is
    always ``mysql``, whatever scheme came in.

    :param uri: The split URI to sanitize.
    :return: The normalized split URI.
    :raises ValueError: If the host is missing, or the path does not consist
        of exactly a non-empty database name and a non-empty table name.
    """
    # Check the hostname, not the netloc: a netloc such as ":3306" is
    # non-empty yet carries no host.
    if not uri.hostname:
        raise ValueError("URI format mysql:// must contain a host")
    if uri.port is None:
        # Drop a dangling ":" (as in "host:") before appending the default port.
        host = uri.netloc.rstrip(":")
        uri = uri._replace(netloc=f"{host}:3306")
    # Leading slash, database name, and table name.
    path_parts = uri.path.split("/")
    # Counting segments alone would let empty names such as "//table" through.
    if len(path_parts) != 3 or not all(path_parts[1:]):
        raise ValueError("URI format mysql:// must contain database and table names")
    return uri._replace(scheme="mysql")
5 changes: 4 additions & 1 deletion airflow/providers/mysql/provider.yaml
Expand Up @@ -102,7 +102,10 @@ transfers:
target-integration-name: MySQL
python-module: airflow.providers.mysql.transfers.trino_to_mysql


connection-types:
- hook-class-name: airflow.providers.mysql.hooks.mysql.MySqlHook
connection-type: mysql

dataset-uris:
- schemes: [mysql, mariadb]
handler: airflow.providers.mysql.datasets.mysql.sanitize_uri
16 changes: 16 additions & 0 deletions airflow/providers/postgres/datasets/__init__.py
@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
37 changes: 37 additions & 0 deletions airflow/providers/postgres/datasets/postgres.py
@@ -0,0 +1,37 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
from urllib.parse import SplitResult


def sanitize_uri(uri: SplitResult) -> SplitResult:
    """Validate and normalize a ``postgres://`` dataset URI.

    The expected shape is
    ``<scheme>://<host>[:<port>]/<database>/<schema>/<table>``. A missing port
    is filled in as ``5432``, an empty schema segment is replaced with
    ``default``, and the scheme of the result is always ``postgres``.

    :param uri: The split URI to sanitize.
    :return: The normalized split URI.
    :raises ValueError: If the host is missing, the path does not have exactly
        three segments, or the database or table name is empty.
    """
    # Check the hostname, not the netloc: a netloc such as ":5432" is
    # non-empty yet carries no host.
    if not uri.hostname:
        raise ValueError("URI format postgres:// must contain a host")
    if uri.port is None:
        # Drop a dangling ":" (as in "host:") before appending the default port.
        host = uri.netloc.rstrip(":")
        uri = uri._replace(netloc=f"{host}:5432")
    path_parts = uri.path.split("/")
    if len(path_parts) != 4:  # Leading slash, database, schema, and table names.
        raise ValueError("URI format postgres:// must contain database, schema, and table names")
    _, database, schema, table = path_parts
    # Only the schema may be left empty; database and table have no default.
    if not database or not table:
        raise ValueError("URI format postgres:// must contain database, schema, and table names")
    if not schema:
        path_parts[2] = "default"
    return uri._replace(scheme="postgres", path="/".join(path_parts))
4 changes: 4 additions & 0 deletions airflow/providers/postgres/provider.yaml
Expand Up @@ -88,3 +88,7 @@ hooks:
connection-types:
- hook-class-name: airflow.providers.postgres.hooks.postgres.PostgresHook
connection-type: postgres

dataset-uris:
- schemes: [postgres, postgresql]
handler: airflow.providers.postgres.datasets.postgres.sanitize_uri
16 changes: 16 additions & 0 deletions airflow/providers/trino/datasets/__init__.py
@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
34 changes: 34 additions & 0 deletions airflow/providers/trino/datasets/trino.py
@@ -0,0 +1,34 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
from urllib.parse import SplitResult


def sanitize_uri(uri: SplitResult) -> SplitResult:
    """Validate and normalize a ``trino://`` dataset URI.

    The expected shape is
    ``trino://<host>[:<port>]/<catalog>/<schema>/<table>``. A missing port is
    filled in as ``8080``; nothing else is rewritten.

    :param uri: The split URI to sanitize.
    :return: The normalized split URI.
    :raises ValueError: If the host is missing, or the path does not consist
        of exactly three non-empty segments (catalog, schema, table).
    """
    # Check the hostname, not the netloc: a netloc such as ":8080" is
    # non-empty yet carries no host.
    if not uri.hostname:
        raise ValueError("URI format trino:// must contain a host")
    if uri.port is None:
        # Drop a dangling ":" (as in "host:") before appending the default port.
        host = uri.netloc.rstrip(":")
        uri = uri._replace(netloc=f"{host}:8080")
    # Leading slash, catalog, schema, and table names.
    path_parts = uri.path.split("/")
    # Counting segments alone would let empty names such as "/cat//tbl" through.
    if len(path_parts) != 4 or not all(path_parts[1:]):
        raise ValueError("URI format trino:// must contain catalog, schema, and table names")
    return uri

0 comments on commit b52b227

Please sign in to comment.