Implement Metadata to emit runtime extra (apache#38650)
uranusjr authored and utkarsharma2 committed Apr 22, 2024
1 parent 43a6729 commit e47f3f2
Showing 11 changed files with 241 additions and 45 deletions.
24 changes: 20 additions & 4 deletions airflow/datasets/__init__.py
@@ -31,6 +31,10 @@


def normalize_noop(parts: SplitResult) -> SplitResult:
    """Place-hold a :class:`~urllib.parse.SplitResult` normalizer.

    :meta private:
    """
    return parts


@@ -42,13 +46,11 @@ def _get_uri_normalizer(scheme: str) -> Callable[[SplitResult], SplitResult] | N
return ProvidersManager().dataset_uri_handlers.get(scheme)


def sanitize_uri(uri: str) -> str:
def _sanitize_uri(uri: str) -> str:
"""Sanitize a dataset URI.
This checks for URI validity, and normalizes the URI if needed. A fully
normalized URI is returned.
:meta private:
"""
if not uri:
raise ValueError("Dataset URI cannot be empty")
@@ -89,6 +91,20 @@ def sanitize_uri(uri: str) -> str:
    return urllib.parse.urlunsplit(parsed)


def coerce_to_uri(value: str | Dataset) -> str:
"""Coerce a user input into a sanitized URI.
If the input value is a string, it is treated as a URI and sanitized. If the
input is a :class:`Dataset`, the URI it contains is considered sanitized and
returned directly.
:meta private:
"""
if isinstance(value, Dataset):
return value.uri
return _sanitize_uri(str(value))


class BaseDatasetEventInput:
"""Protocol for all dataset triggers to use in ``DAG(schedule=...)``.
@@ -127,7 +143,7 @@ class Dataset(os.PathLike, BaseDatasetEventInput):
"""A representation of data dependencies between workflows."""

uri: str = attr.field(
converter=sanitize_uri,
converter=_sanitize_uri,
validator=[attr.validators.min_len(1), attr.validators.max_len(3000)],
)
extra: dict[str, Any] | None = None
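In practice the rename leaves the public Dataset API unchanged while the new private helper accepts either form of input. A minimal illustrative sketch, not part of the commit; the URI below is made up:

from airflow.datasets import Dataset, coerce_to_uri

report = Dataset("s3://some-bucket/daily-report.csv")  # hypothetical dataset
coerce_to_uri(report)                                   # a Dataset's URI is returned as-is
coerce_to_uri("s3://some-bucket/daily-report.csv")      # a plain string is run through _sanitize_uri first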
39 changes: 39 additions & 0 deletions airflow/datasets/metadata.py
@@ -0,0 +1,39 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

from typing import TYPE_CHECKING, Any

import attrs

from airflow.datasets import coerce_to_uri

if TYPE_CHECKING:
from airflow.datasets import Dataset


@attrs.define(init=False)
class Metadata:
    """Metadata to attach to a DatasetEvent."""

    uri: str
    extra: dict[str, Any]

    def __init__(self, target: str | Dataset, extra: dict[str, Any]) -> None:
        self.uri = coerce_to_uri(target)
        self.extra = extra
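This class is the user-facing half of the change: task code yields Metadata objects, and the runtime (wired up in the hunks below) can attach the extra dict to the dataset events recorded for that URI. A hedged usage sketch with a made-up dataset and payload:

from airflow.datasets import Dataset
from airflow.datasets.metadata import Metadata

example = Dataset("s3://some-bucket/report.csv")                     # hypothetical outlet
meta = Metadata(example, {"row_count": 123})                          # target may be a Dataset ...
meta = Metadata("s3://some-bucket/report.csv", {"row_count": 123})    # ... or a plain URI string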
25 changes: 18 additions & 7 deletions airflow/models/baseoperator.py
@@ -28,12 +28,12 @@
import contextlib
import copy
import functools
import inspect
import logging
import sys
import warnings
from datetime import datetime, timedelta
from functools import total_ordering, wraps
from inspect import signature
from types import FunctionType
from typing import (
    TYPE_CHECKING,
@@ -91,10 +91,11 @@
from airflow.ti_deps.deps.prev_dagrun_dep import PrevDagrunDep
from airflow.ti_deps.deps.trigger_rule_dep import TriggerRuleDep
from airflow.utils import timezone
from airflow.utils.context import Context
from airflow.utils.context import Context, context_get_dataset_events
from airflow.utils.decorators import fixup_decorator_warning_stack
from airflow.utils.edgemodifier import EdgeModifier
from airflow.utils.helpers import validate_key
from airflow.utils.operator_helpers import ExecutionCallableRunner
from airflow.utils.operator_resources import Resources
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.setup_teardown import SetupTeardownContext
@@ -423,7 +424,7 @@ def _apply_defaults(cls, func: T) -> T:
        # at every decorated invocation. This is separate sig_cache created
        # per decoration, i.e. each function decorated using apply_defaults will
        # have a different sig_cache.
        sig_cache = signature(func)
        sig_cache = inspect.signature(func)
        non_variadic_params = {
            name: param
            for (name, param) in sig_cache.parameters.items()
@@ -1269,8 +1270,13 @@ def set_xcomargs_dependencies(self) -> None:
    @prepare_lineage
    def pre_execute(self, context: Any):
        """Execute right before self.execute() is called."""
        if self._pre_execute_hook is not None:
            self._pre_execute_hook(context)
        if self._pre_execute_hook is None:
            return
        ExecutionCallableRunner(
            self._pre_execute_hook,
            context_get_dataset_events(context),
            logger=self.log,
        ).run(context)

    def execute(self, context: Context) -> Any:
        """
@@ -1289,8 +1295,13 @@ def post_execute(self, context: Any, result: Any = None):
        It is passed the execution context and any results returned by the operator.
        """
        if self._post_execute_hook is not None:
            self._post_execute_hook(context, result)
        if self._post_execute_hook is None:
            return
        ExecutionCallableRunner(
            self._post_execute_hook,
            context_get_dataset_events(context),
            logger=self.log,
        ).run(context, result)

    def on_kill(self) -> None:
        """
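Both hooks now run through ExecutionCallableRunner (defined in airflow/utils/operator_helpers.py, not shown in this diff), so a hook written as a generator can yield Metadata just like execute() can. A rough, hedged sketch of what such a runner does; the real implementation differs in details:

import inspect

from airflow.datasets.metadata import Metadata


class SketchRunner:
    """Illustration only: run a callable, collecting any Metadata it yields."""

    def __init__(self, func, dataset_events, *, logger):
        self.func = func
        self.dataset_events = dataset_events  # a DatasetEventAccessors instance
        self.logger = logger

    def run(self, *args, **kwargs):
        if not inspect.isgeneratorfunction(self.func):
            return self.func(*args, **kwargs)  # plain callables behave exactly as before
        for item in self.func(*args, **kwargs):
            if isinstance(item, Metadata):
                # Assumes DatasetEventAccessor exposes its dict as `.extra`.
                self.dataset_events[item.uri].extra.update(item.extra)
            else:
                self.logger.warning("Ignoring unknown yielded value: %r", item)
        return None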
40 changes: 23 additions & 17 deletions airflow/models/taskinstance.py
@@ -111,14 +111,15 @@
    Context,
    DatasetEventAccessors,
    VariableAccessor,
    context_get_dataset_events,
    context_merge,
)
from airflow.utils.email import send_email
from airflow.utils.helpers import prune_dict, render_template_to_string
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.module_loading import qualname
from airflow.utils.net import get_hostname
from airflow.utils.operator_helpers import context_to_airflow_vars
from airflow.utils.operator_helpers import ExecutionCallableRunner, context_to_airflow_vars
from airflow.utils.platform import getuser
from airflow.utils.retries import run_with_db_retries
from airflow.utils.session import NEW_SESSION, create_session, provide_session
@@ -432,12 +433,16 @@ def _execute_task(task_instance: TaskInstance | TaskInstancePydantic, context: C
    if execute_callable.__name__ == "execute":
        execute_callable_kwargs[f"{task_to_execute.__class__.__name__}__sentinel"] = _sentinel

    def _execute_callable(context, **execute_callable_kwargs):
    def _execute_callable(context: Context, **execute_callable_kwargs):
        try:
            # Print a marker for log grouping of details before task execution
            log.info("::endgroup::")

            return execute_callable(context=context, **execute_callable_kwargs)
            return ExecutionCallableRunner(
                execute_callable,
                context_get_dataset_events(context),
                logger=log,
            ).run(context=context, **execute_callable_kwargs)
        except SystemExit as e:
            # Handle only successful cases here. Failure cases will be handled upper
            # in the exception chain.
@@ -2678,6 +2683,10 @@ def signal_handler(signum, frame):
        jinja_env = None
        task_orig = self.render_templates(context=context, jinja_env=jinja_env)

        # The task is never MappedOperator at this point.
        if TYPE_CHECKING:
            assert isinstance(self.task, BaseOperator)

        if not test_mode:
            rendered_fields = get_serialized_template_fields(task=self.task)
            _update_rtif(ti=self, rendered_fields=rendered_fields)
@@ -2695,8 +2704,7 @@ def signal_handler(signum, frame):
        )

        # Run pre_execute callback
        # Is never MappedOperator at this point
        self.task.pre_execute(context=context)  # type: ignore[union-attr]
        self.task.pre_execute(context=context)

        # Run on_execute callback
        self._run_execute_callback(context, self.task)
@@ -2711,8 +2719,7 @@ def signal_handler(signum, frame):
        result = self._execute_task(context, task_orig)

        # Run post_execute callback
        # Is never MappedOperator at this point
        self.task.post_execute(context=context, result=result)  # type: ignore[union-attr]
        self.task.post_execute(context=context, result=result)

        # DAG authors define map_index_template at the task level
        if jinja_env is not None and (template := context.get("map_index_template")) is not None:
@@ -2724,7 +2731,7 @@
Stats.incr("operator_successes", tags={**self.stats_tags, "task_type": self.task.task_type})
Stats.incr("ti_successes", tags=self.stats_tags)

    def _execute_task(self, context, task_orig):
    def _execute_task(self, context: Context, task_orig: Operator):
        """
        Execute Task (optionally with a Timeout) and push Xcom results.
@@ -2775,16 +2782,15 @@ def defer_task(self, session: Session, defer: TaskDeferred) -> None:
        else:
            self.trigger_timeout = self.start_date + execution_timeout

    def _run_execute_callback(self, context: Context, task: Operator) -> None:
    def _run_execute_callback(self, context: Context, task: BaseOperator) -> None:
        """Functions that need to be run before a Task is executed."""
        callbacks = task.on_execute_callback
        if callbacks:
            callbacks = callbacks if isinstance(callbacks, list) else [callbacks]
            for callback in callbacks:
                try:
                    callback(context)
                except Exception:
                    self.log.exception("Failed when executing execute callback")
        if not (callbacks := task.on_execute_callback):
            return
        for callback in callbacks if isinstance(callbacks, list) else [callbacks]:
            try:
                callback(context)
            except Exception:
                self.log.exception("Failed when executing execute callback")

@provide_session
def run(
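The callback handling keeps its existing contract: on_execute_callback may be a single callable or a list, and failures are logged rather than raised. A small hedged example; the DAG and task names are made up:

from datetime import datetime

from airflow import DAG
from airflow.operators.bash import BashOperator


def notify(context):
    # Receives the task context right before execute() runs.
    print("starting", context["ti"].task_id)


with DAG(dag_id="callback_demo", start_date=datetime(2024, 1, 1), schedule=None):
    BashOperator(
        task_id="example",
        bash_command="echo hi",
        on_execute_callback=[notify],  # a bare callable (no list) also works
    )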
12 changes: 8 additions & 4 deletions airflow/operators/python.py
@@ -52,9 +52,9 @@
from airflow.models.variable import Variable
from airflow.operators.branch import BranchMixIn
from airflow.utils import hashlib_wrapper
from airflow.utils.context import context_copy_partial, context_merge
from airflow.utils.context import context_copy_partial, context_get_dataset_events, context_merge
from airflow.utils.file import get_unique_dag_module_name
from airflow.utils.operator_helpers import KeywordParameters
from airflow.utils.operator_helpers import ExecutionCallableRunner, KeywordParameters
from airflow.utils.process_utils import execute_in_subprocess
from airflow.utils.python_virtualenv import prepare_virtualenv, write_python_script

@@ -231,6 +231,7 @@ def __init__(
    def execute(self, context: Context) -> Any:
        context_merge(context, self.op_kwargs, templates_dict=self.templates_dict)
        self.op_kwargs = self.determine_kwargs(context)
        self._dataset_events = context_get_dataset_events(context)

        return_value = self.execute_callable()
        if self.show_return_value_in_logs:
@@ -249,7 +250,8 @@ def execute_callable(self) -> Any:
        :return: the return value of the call.
        """
        return self.python_callable(*self.op_args, **self.op_kwargs)
        runner = ExecutionCallableRunner(self.python_callable, self._dataset_events, logger=self.log)
        return runner.run(*self.op_args, **self.op_kwargs)


class BranchPythonOperator(PythonOperator, BranchMixIn):
@@ -406,7 +408,9 @@ def __init__(
            or isinstance(python_callable, types.LambdaType)
            and python_callable.__name__ == "<lambda>"
        ):
            raise AirflowException("PythonVirtualenvOperator only supports functions for python_callable arg")
            raise ValueError(f"{type(self).__name__} only supports functions for python_callable arg")
        if inspect.isgeneratorfunction(python_callable):
            raise ValueError(f"{type(self).__name__} does not support using 'yield' in python_callable")
        super().__init__(
            python_callable=python_callable,
            op_args=op_args,
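With execute_callable routed through ExecutionCallableRunner, a python_callable may now be a generator that yields Metadata, while PythonVirtualenvOperator and its subclasses reject generator callables up front with the new ValueError. A hedged sketch with made-up names:

from datetime import datetime

from airflow import DAG
from airflow.datasets import Dataset
from airflow.datasets.metadata import Metadata
from airflow.operators.python import PythonOperator

report = Dataset("s3://some-bucket/report.csv")  # hypothetical outlet dataset


def produce_report(**context):
    # ... produce and upload the report somewhere ...
    yield Metadata(report, {"row_count": 123})  # collected by ExecutionCallableRunner at runtime


with DAG(dag_id="metadata_demo", start_date=datetime(2024, 1, 1), schedule=None):
    PythonOperator(task_id="produce_report", python_callable=produce_report, outlets=[report])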
17 changes: 9 additions & 8 deletions airflow/utils/context.py
@@ -39,7 +39,7 @@
import attrs
import lazy_object_proxy

from airflow.datasets import Dataset, sanitize_uri
from airflow.datasets import Dataset, coerce_to_uri
from airflow.exceptions import RemovedInAirflow3Warning
from airflow.utils.types import NOTSET

@@ -169,13 +169,7 @@ def __len__(self) -> int:
        return len(self._dict)

    def __getitem__(self, key: str | Dataset) -> DatasetEventAccessor:
        if isinstance(key, str):
            uri = sanitize_uri(key)
        elif isinstance(key, Dataset):
            uri = key.uri
        else:
            return NotImplemented
        if uri not in self._dict:
        if (uri := coerce_to_uri(key)) not in self._dict:
            self._dict[uri] = DatasetEventAccessor({})
        return self._dict[uri]

@@ -361,3 +355,10 @@ def _create_value(k: str, v: Any) -> Any:
        return lazy_object_proxy.Proxy(factory)

    return {k: _create_value(k, v) for k, v in source._context.items()}


def context_get_dataset_events(context: Context) -> DatasetEventAccessors:
    try:
        return context["dataset_events"]
    except KeyError:
        return DatasetEventAccessors()
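The helper gives hooks and callables a safe way to reach the per-task-instance accessor collection even when the context has no dataset_events key yet. A hedged sketch; the URI and extra key are made up, and it assumes DatasetEventAccessor exposes its dict as .extra, as DatasetEventAccessor({}) above suggests:

from airflow.utils.context import context_get_dataset_events


def my_pre_execute_hook(context):
    events = context_get_dataset_events(context)  # falls back to empty accessors
    # Keys may be a URI string or a Dataset; both are coerced via coerce_to_uri.
    events["s3://some-bucket/report.csv"].extra["note"] = "refreshed"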
1 change: 1 addition & 0 deletions airflow/utils/context.pyi
@@ -127,3 +127,4 @@ def context_merge(context: Context, **kwargs: Any) -> None: ...
def context_update_for_unmapped(context: Context, task: BaseOperator) -> None: ...
def context_copy_partial(source: Context, keys: Container[str]) -> Context: ...
def lazy_mapping_from_context(source: Context) -> Mapping[str, Any]: ...
def context_get_dataset_events(context: Context) -> DatasetEventAccessors: ...
