Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement AIP-60 Dataset URI formats #37005

Merged
merged 18 commits into from
Feb 26, 2024
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 57 additions & 17 deletions airflow/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,63 @@
from __future__ import annotations

import os
from typing import Any, Callable, ClassVar, Iterable, Iterator, Protocol, runtime_checkable
from urllib.parse import urlsplit
import urllib.parse
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Iterable, Iterator, Protocol, runtime_checkable

import attr

if TYPE_CHECKING:
from urllib.parse import SplitResult

__all__ = ["Dataset", "DatasetAll", "DatasetAny"]


def normalize_noop(parts: SplitResult) -> SplitResult:
    """Return the parsed URI exactly as it was given.

    This is the normalizer used for schemes that need no scheme-specific
    processing.

    :param parts: The already-split URI.
    :return: The very same ``parts`` object, unmodified.
    """
    return parts


def _get_uri_normalizer(scheme: str) -> Callable[[SplitResult], SplitResult] | None:
    """Look up the normalizer callable registered for a URI scheme.

    The ``file`` scheme is handled here directly with a no-op normalizer;
    every other scheme is looked up in the handlers exposed by the providers
    manager.

    :param scheme: The URI scheme to look up.
    :return: A callable taking and returning a ``SplitResult``, or ``None``
        when no handler is registered for ``scheme``.
    """
    if scheme != "file":
        # Imported only when needed rather than at module import time.
        from airflow.providers_manager import ProvidersManager

        handlers = ProvidersManager().dataset_uri_handlers
        return handlers.get(scheme)
    return normalize_noop


def _sanitize_uri(uri: str) -> str:
    """Validate a dataset URI and return its normalized string form.

    A string with neither a scheme nor a network location is returned
    untouched, and so is any URI whose scheme starts with ``x-``. For
    everything else the scheme is lower-cased, trailing slashes are removed
    from the path, query parameters are sorted, and the fragment is dropped.
    If a normalizer is registered for the scheme, it is applied last.

    :param uri: The URI as supplied by the caller.
    :return: The normalized URI, or ``uri`` itself for the pass-through cases.
    :raises ValueError: If ``uri`` is empty, whitespace only, contains
        non-ASCII characters, uses the reserved ``airflow`` scheme, or
        carries auth information in its network location.
    """
    if not uri:
        raise ValueError("Dataset URI cannot be empty")
    if uri.isspace():
        raise ValueError("Dataset URI cannot be just whitespace")
    if not uri.isascii():
        raise ValueError("Dataset URI must only consist of ASCII characters")

    parts = urllib.parse.urlsplit(uri)
    if not (parts.scheme or parts.netloc):
        # Does not look like a URI; leave it alone.
        return uri

    scheme = parts.scheme.lower()
    if scheme.startswith("x-"):
        return uri
    if scheme == "airflow":
        raise ValueError("Dataset scheme 'airflow' is reserved")
    if "@" in parts.netloc:
        raise ValueError("Dataset URI must not contain auth information")

    # Sort the query parameters so that ordering does not affect the result.
    query = ""
    if parts.query:
        pairs = urllib.parse.parse_qsl(parts.query)
        pairs.sort()
        query = urllib.parse.urlencode(pairs)

    parts = parts._replace(
        scheme=scheme,
        path=parts.path.rstrip("/") or "/",  # Remove all trailing slashes.
        query=query,
        fragment="",  # Ignore any fragments.
    )

    normalizer = _get_uri_normalizer(scheme)
    if normalizer is not None:
        parts = normalizer(parts)
    return urllib.parse.urlunsplit(parts)


@runtime_checkable
class BaseDatasetEventInput(Protocol):
"""Protocol for all dataset triggers to use in ``DAG(schedule=...)``.
Expand All @@ -44,33 +93,24 @@ def iter_datasets(self) -> Iterator[tuple[str, Dataset]]:
class Dataset(os.PathLike, BaseDatasetEventInput):
"""A representation of data dependencies between workflows."""

uri: str = attr.field(validator=[attr.validators.min_len(1), attr.validators.max_len(3000)])
uri: str = attr.field(
converter=_sanitize_uri,
validator=[attr.validators.min_len(1), attr.validators.max_len(3000)],
)
extra: dict[str, Any] | None = None

__version__: ClassVar[int] = 1

@uri.validator
def _check_uri(self, attr, uri: str):
if uri.isspace():
raise ValueError(f"{attr.name} cannot be just whitespace")
try:
uri.encode("ascii")
except UnicodeEncodeError:
raise ValueError(f"{attr.name!r} must be ascii")
parsed = urlsplit(uri)
if parsed.scheme and parsed.scheme.lower() == "airflow":
raise ValueError(f"{attr.name!r} scheme `airflow` is reserved")

def __fspath__(self) -> str:
return self.uri

def __eq__(self, other):
def __eq__(self, other: Any) -> bool:
if isinstance(other, self.__class__):
return self.uri == other.uri
else:
return NotImplemented

def __hash__(self):
def __hash__(self) -> int:
return hash(self.uri)

def iter_datasets(self) -> Iterator[tuple[str, Dataset]]:
Expand Down
2 changes: 1 addition & 1 deletion airflow/example_dags/example_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@
start_date=pendulum.datetime(2021, 1, 1, tz="UTC"),
schedule=[
dag1_dataset,
Dataset("s3://this-dataset-doesnt-get-triggered"),
Dataset("s3://unrelated/this-dataset-doesnt-get-triggered"),
],
tags=["consumes", "dataset-scheduled"],
) as dag5:
Expand Down
20 changes: 20 additions & 0 deletions airflow/provider.yaml.schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,26 @@
"type": "string"
}
},
"dataset-uris": {
"type": "array",
"description": "Dataset URI formats",
"items": {
"type": "object",
"properties": {
"schemes": {
"type": "array",
"description": "List of supported URI schemes",
"items": {
"type": "string"
}
},
"handler": {
"type": ["string", "null"],
"description": "Normalization function for specified URI schemes. Import path to a callable taking and returning a SplitResult. 'null' specifies a no-op."
}
}
}
},
"transfers": {
"type": "array",
"items": {
Expand Down
4 changes: 4 additions & 0 deletions airflow/providers/amazon/provider.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,10 @@ sensors:
python-modules:
- airflow.providers.amazon.aws.sensors.quicksight

dataset-uris:
- schemes: [s3]
handler: null

filesystems:
- airflow.providers.amazon.aws.fs.s3

Expand Down
16 changes: 16 additions & 0 deletions airflow/providers/google/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
31 changes: 31 additions & 0 deletions airflow/providers/google/datasets/bigquery.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
from urllib.parse import SplitResult


def sanitize_uri(uri: SplitResult) -> SplitResult:
    """Validate a ``bigquery://`` dataset URI.

    The expected shape is ``bigquery://<project>/<dataset>/<table>``.

    :param uri: The already-split URI to validate.
    :return: ``uri`` unchanged when it is valid.
    :raises ValueError: If the project ID is missing, or if the path does not
        consist of exactly a non-empty dataset name and a non-empty table name.
    """
    if not uri.netloc:
        raise ValueError("URI format bigquery:// must contain a project ID")
    # Splitting "/<dataset>/<table>" yields an empty leading segment followed
    # by the two names. Counting segments alone would let empty names through
    # (e.g. "//table"), so each name is also checked for being non-empty.
    segments = uri.path.split("/")
    if len(segments) != 3 or not all(segments[1:]):
        raise ValueError("URI format bigquery:// must contain dataset and table names")
    return uri
6 changes: 6 additions & 0 deletions airflow/providers/google/provider.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -751,6 +751,12 @@ sensors:
filesystems:
- airflow.providers.google.cloud.fs.gcs

dataset-uris:
- schemes: [gcp]
handler: null
- schemes: [bigquery]
handler: airflow.providers.google.datasets.bigquery.sanitize_uri

hooks:
- integration-name: Google Ads
python-modules:
Expand Down
16 changes: 16 additions & 0 deletions airflow/providers/mysql/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
34 changes: 34 additions & 0 deletions airflow/providers/mysql/datasets/mysql.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
from urllib.parse import SplitResult


def sanitize_uri(uri: SplitResult) -> SplitResult:
    """Validate and normalize a ``mysql://`` (or ``mariadb://``) dataset URI.

    The expected shape is ``mysql://<host>[:<port>]/<database>/<table>``.
    When no port is given, ``3306`` is filled in, and the scheme of the
    returned URI is always ``mysql``.

    :param uri: The already-split URI to validate.
    :return: The normalized URI.
    :raises ValueError: If the host is missing, or if the path does not
        consist of exactly a non-empty database name and a non-empty table
        name.
    """
    # Check the hostname rather than the whole netloc, so that a netloc made
    # only of a port (e.g. ":3306") is not mistaken for a host.
    if not uri.hostname:
        raise ValueError("URI format mysql:// must contain a host")
    if uri.port is None:
        # Drop a dangling ":" (as in "host:") before appending the port.
        host = uri.netloc.rstrip(":")
        uri = uri._replace(netloc=f"{host}:3306")
    # Splitting "/<database>/<table>" yields an empty leading segment followed
    # by the two names; each name must also be non-empty.
    segments = uri.path.split("/")
    if len(segments) != 3 or not all(segments[1:]):
        raise ValueError("URI format mysql:// must contain database and table names")
    return uri._replace(scheme="mysql")
5 changes: 4 additions & 1 deletion airflow/providers/mysql/provider.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,10 @@ transfers:
target-integration-name: MySQL
python-module: airflow.providers.mysql.transfers.trino_to_mysql


connection-types:
- hook-class-name: airflow.providers.mysql.hooks.mysql.MySqlHook
connection-type: mysql

dataset-uris:
- schemes: [mysql, mariadb]
handler: airflow.providers.mysql.datasets.mysql.sanitize_uri
16 changes: 16 additions & 0 deletions airflow/providers/postgres/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
37 changes: 37 additions & 0 deletions airflow/providers/postgres/datasets/postgres.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
from urllib.parse import SplitResult


def sanitize_uri(uri: SplitResult) -> SplitResult:
    """Validate and normalize a ``postgres://`` (or ``postgresql://``) dataset URI.

    The expected shape is
    ``postgres://<host>[:<port>]/<database>/<schema>/<table>``. When no port
    is given, ``5432`` is filled in; an empty schema segment is replaced by
    ``default``; and the scheme of the returned URI is always ``postgres``.

    :param uri: The already-split URI to validate.
    :return: The normalized URI.
    :raises ValueError: If the host is missing, if the path does not have
        exactly three segments, or if the database or table name is empty.
    """
    # Check the hostname rather than the whole netloc, so that a netloc made
    # only of a port (e.g. ":5432") is not mistaken for a host.
    if not uri.hostname:
        raise ValueError("URI format postgres:// must contain a host")
    if uri.port is None:
        # Drop a dangling ":" (as in "host:") before appending the port.
        host = uri.netloc.rstrip(":")
        uri = uri._replace(netloc=f"{host}:5432")
    # Leading slash, database, schema, and table names. The schema may be
    # left empty (it is defaulted below), but database and table may not.
    path_parts = uri.path.split("/")
    if len(path_parts) != 4 or not path_parts[1] or not path_parts[3]:
        raise ValueError("URI format postgres:// must contain database, schema, and table names")
    if not path_parts[2]:
        # NOTE(review): PostgreSQL's out-of-the-box schema is usually named
        # "public"; confirm that "default" is the intended placeholder.
        path_parts[2] = "default"
    return uri._replace(scheme="postgres", path="/".join(path_parts))
4 changes: 4 additions & 0 deletions airflow/providers/postgres/provider.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -88,3 +88,7 @@ hooks:
connection-types:
- hook-class-name: airflow.providers.postgres.hooks.postgres.PostgresHook
connection-type: postgres

dataset-uris:
- schemes: [postgres, postgresql]
handler: airflow.providers.postgres.datasets.postgres.sanitize_uri
16 changes: 16 additions & 0 deletions airflow/providers/trino/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
34 changes: 34 additions & 0 deletions airflow/providers/trino/datasets/trino.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
from urllib.parse import SplitResult


def sanitize_uri(uri: SplitResult) -> SplitResult:
    """Validate and normalize a ``trino://`` dataset URI.

    The expected shape is
    ``trino://<host>[:<port>]/<catalog>/<schema>/<table>``. When no port is
    given, ``8080`` is filled in.

    :param uri: The already-split URI to validate.
    :return: The normalized URI.
    :raises ValueError: If the host is missing, or if the path does not
        consist of exactly three non-empty names (catalog, schema, table).
    """
    # Check the hostname rather than the whole netloc, so that a netloc made
    # only of a port (e.g. ":8080") is not mistaken for a host.
    if not uri.hostname:
        raise ValueError("URI format trino:// must contain a host")
    if uri.port is None:
        # Drop a dangling ":" (as in "host:") before appending the port.
        host = uri.netloc.rstrip(":")
        uri = uri._replace(netloc=f"{host}:8080")
    # Splitting "/<catalog>/<schema>/<table>" yields an empty leading segment
    # followed by the three names; each name must also be non-empty.
    segments = uri.path.split("/")
    if len(segments) != 4 or not all(segments[1:]):
        raise ValueError("URI format trino:// must contain catalog, schema, and table names")
    return uri
Loading