Merge branch 'apache:main' into main
Lee2532 committed May 23, 2024
2 parents 48022d6 + 61e9070 commit be291d1
Showing 94 changed files with 2,809 additions and 617 deletions.
6 changes: 5 additions & 1 deletion Dockerfile.ci
@@ -1005,8 +1005,12 @@ function check_boto_upgrade() {
     ${PACKAGING_TOOL_CMD} uninstall ${EXTRA_UNINSTALL_FLAGS} aiobotocore s3fs || true
     # We need to include oss2 as dependency as otherwise jmespath will be bumped and it will not pass
     # the pip check test, Similarly gcloud-aio-auth limit is needed to be included as it bumps cryptography
+    # Also until docker-py compatibility with requests 2.32 is fixed we need to limit requests version
+    # Should be removed after https://github.com/docker/docker-py/issues/3256 together with removal of similar
+    # limitation in providers/docker/pyproject.toml
     # shellcheck disable=SC2086
-    ${PACKAGING_TOOL_CMD} install ${EXTRA_INSTALL_FLAGS} --upgrade boto3 botocore "oss2>=2.14.0" "gcloud-aio-auth>=4.0.0,<5.0.0"
+    ${PACKAGING_TOOL_CMD} install ${EXTRA_INSTALL_FLAGS} --upgrade boto3 botocore \
+        "oss2>=2.14.0" "gcloud-aio-auth>=4.0.0,<5.0.0" "requests<2.32.0"
     pip check
 }

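For context (not part of the commit): the pin above exists because docker-py is not yet compatible with requests 2.32. A minimal sanity-check sketch of the constraint, assuming the packaging and requests distributions are installed as they are in the CI image:

    # Hypothetical check mirroring the "requests<2.32.0" pin added above; drop it
    # once https://github.com/docker/docker-py/issues/3256 is resolved.
    from packaging.version import Version

    import requests

    if Version(requests.__version__) >= Version("2.32.0"):
        raise RuntimeError("requests must stay <2.32.0 until docker-py issue 3256 is fixed")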
3 changes: 1 addition & 2 deletions airflow/api_connexion/endpoints/task_instance_endpoint.py
@@ -47,7 +47,6 @@
 from airflow.exceptions import TaskNotFound
 from airflow.models import SlaMiss
 from airflow.models.dagrun import DagRun as DR
-from airflow.models.operator import needs_expansion
 from airflow.models.taskinstance import TaskInstance as TI, clear_task_instances
 from airflow.utils.airflow_flask_app import get_airflow_app
 from airflow.utils.db import get_query_count
@@ -201,7 +200,7 @@ def get_mapped_task_instances(
     except TaskNotFound:
         error_message = f"Task id {task_id} not found"
         raise NotFound(error_message)
-    if not needs_expansion(task):
+    if not task.get_needs_expansion():
         error_message = f"Task id {task_id} is not mapped"
         raise NotFound(error_message)

11 changes: 5 additions & 6 deletions airflow/cli/commands/task_command.py
@@ -45,7 +45,6 @@
 from airflow.models import DagPickle, TaskInstance
 from airflow.models.dag import DAG, _run_inline_trigger
 from airflow.models.dagrun import DagRun
-from airflow.models.operator import needs_expansion
 from airflow.models.param import ParamsDict
 from airflow.models.taskinstance import TaskReturnCode
 from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic
@@ -177,7 +176,7 @@ def _get_ti_db_access(

     if not exec_date_or_run_id and not create_if_necessary:
         raise ValueError("Must provide `exec_date_or_run_id` if not `create_if_necessary`.")
-    if needs_expansion(task):
+    if task.get_needs_expansion():
         if map_index < 0:
             raise RuntimeError("No map_index passed to mapped task")
     elif map_index >= 0:
@@ -228,10 +227,10 @@ def _get_ti(
         pool=pool,
         create_if_necessary=create_if_necessary,
     )
-    # setting ti.task is necessary for AIP-44 since the task object does not serialize perfectly
-    # if we update the serialization logic for Operator to also serialize the dag object on it,
-    # then this would not be necessary;
-    ti.task = task
+
+    # we do refresh_from_task so that if TI has come back via RPC, we ensure that ti.task
+    # is the original task object and not the result of the round trip
+    ti.refresh_from_task(task, pool_override=pool)
     return ti, dr_created


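For context (not part of the diff): refresh_from_task copies pool, queue, priority and the other scheduling attributes from the given operator onto the task instance and rebinds ti.task, which is what the new comment above relies on. A rough sketch using the names from the hunk (ti, task and pool as they exist inside _get_ti):

    # Rough sketch of the intent behind the change in _get_ti above.
    # A bare "ti.task = task" only rebinds the attribute; refresh_from_task also
    # re-applies pool/queue/priority from the locally parsed operator, so a TI that
    # came back over the internal API no longer carries a round-tripped task copy.
    ti.refresh_from_task(task, pool_override=pool)
    assert ti.task is task  # the original operator object, per the comment above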
15 changes: 7 additions & 8 deletions airflow/datasets/__init__.py
@@ -94,14 +94,13 @@ def _sanitize_uri(uri: str) -> str:
             parsed = normalizer(parsed)
         except ValueError as exception:
             if conf.getboolean("core", "strict_dataset_uri_validation", fallback=False):
-                raise exception
-            else:
-                warnings.warn(
-                    f"The dataset URI {uri} is not AIP-60 compliant. "
-                    f"In Airflow 3, this will raise an exception. More information: {repr(exception)}",
-                    UserWarning,
-                    stacklevel=3,
-                )
+                raise
+            warnings.warn(
+                f"The dataset URI {uri} is not AIP-60 compliant: {exception}. "
+                f"In Airflow 3, this will raise an exception.",
+                UserWarning,
+                stacklevel=3,
+            )
     return urllib.parse.urlunsplit(parsed)


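For context (not part of the diff): whether the warning branch above fires or the normalizer's ValueError propagates is governed by the [core] strict_dataset_uri_validation option that _sanitize_uri reads via conf.getboolean. A minimal sketch of the toggle (the URI below is only illustrative, and the check only applies to schemes with a registered normalizer):

    from airflow.configuration import conf
    from airflow.datasets import Dataset

    strict = conf.getboolean("core", "strict_dataset_uri_validation", fallback=False)
    # strict == False (the default fallback): a URI rejected by its scheme's normalizer
    #   emits a UserWarning naming the URI and the error, and the Dataset is still created.
    # strict == True (e.g. AIRFLOW__CORE__STRICT_DATASET_URI_VALIDATION=True): the
    #   normalizer's ValueError is re-raised, the behavior planned for Airflow 3.
    example = Dataset("s3://my-bucket/my-key")  # a compliant URI passes either way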
4 changes: 2 additions & 2 deletions airflow/example_dags/example_branch_operator.py
@@ -141,7 +141,7 @@ def branch_with_venv(choices):

     branching_venv = BranchPythonVirtualenvOperator(
         task_id="branching_venv",
-        requirements=["numpy~=1.24.4"],
+        requirements=["numpy~=1.26.0"],
         venv_cache_path=VENV_CACHE_PATH,
         python_callable=branch_with_venv,
         op_args=[options],
@@ -162,7 +162,7 @@ def hello_world_with_venv():
     for option in options:
         t = PythonVirtualenvOperator(
             task_id=f"venv_{option}",
-            requirements=["numpy~=1.24.4"],
+            requirements=["numpy~=1.26.0"],
             venv_cache_path=VENV_CACHE_PATH,
             python_callable=hello_world_with_venv,
         )
15 changes: 14 additions & 1 deletion airflow/models/abstractoperator.py
@@ -112,7 +112,7 @@ class AbstractOperator(Templater, DAGNode):
     outlets: list
     inlets: list
     trigger_rule: TriggerRule
-
+    _needs_expansion: bool | None = None
     _on_failure_fail_dagrun = False
 
     HIDE_ATTRS_FROM_UI: ClassVar[frozenset[str]] = frozenset(
@@ -395,6 +395,19 @@ def get_closest_mapped_task_group(self) -> MappedTaskGroup | None:
"""
return next(self.iter_mapped_task_groups(), None)

def get_needs_expansion(self) -> bool:
"""
Return true if the task is MappedOperator or is in a mapped task group.
:meta private:
"""
if self._needs_expansion is None:
if self.get_closest_mapped_task_group() is not None:
self._needs_expansion = True
else:
self._needs_expansion = False
return self._needs_expansion

def unmap(self, resolve: None | dict[str, Any] | tuple[Context, Session]) -> BaseOperator:
"""Get the "normal" operator from current abstract operator.
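Together with the _needs_expansion attribute added above, call sites move from the removed free function to a method on the operator itself. A minimal sketch of the new call pattern (the helper name is made up for illustration):

    from __future__ import annotations

    from airflow.models.abstractoperator import AbstractOperator


    def is_mapped_work(task: AbstractOperator) -> bool:
        # Before this commit: from airflow.models.operator import needs_expansion
        # After it: the operator answers for itself, caching the result in
        # _needs_expansion (MappedOperator pre-sets it to True; plain operators
        # compute it lazily from their closest mapped task group).
        return task.get_needs_expansion()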
4 changes: 3 additions & 1 deletion airflow/models/baseoperator.py
@@ -41,6 +41,7 @@
     Callable,
     Collection,
     Iterable,
+    NoReturn,
     Sequence,
     TypeVar,
     Union,
@@ -1680,6 +1681,7 @@ def get_serialized_fields(cls):
"map_index_template",
"start_trigger",
"next_method",
"_needs_expansion",
}
)
DagContext.pop_context_managed_dag()
@@ -1705,7 +1707,7 @@ def defer(
         method_name: str,
         kwargs: dict[str, Any] | None = None,
         timeout: timedelta | None = None,
-    ):
+    ) -> NoReturn:
         """
         Mark this Operator "deferred", suspending its execution until the provided trigger fires an event.
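The new -> NoReturn annotation records that defer() never returns normally (it raises TaskDeferred), so type checkers can flag code placed after the call. A minimal deferrable-operator sketch (the operator class and trigger choice are hypothetical, not from this commit):

    from __future__ import annotations

    from datetime import timedelta
    from typing import Any

    from airflow.models.baseoperator import BaseOperator
    from airflow.triggers.temporal import TimeDeltaTrigger
    from airflow.utils.context import Context


    class WaitFiveMinutesOperator(BaseOperator):
        """Hypothetical deferrable operator illustrating the defer() contract."""

        def execute(self, context: Context) -> None:
            # defer() raises TaskDeferred and never returns; the NoReturn annotation
            # added above makes anything after this call provably unreachable.
            self.defer(
                trigger=TimeDeltaTrigger(timedelta(minutes=5)),
                method_name="execute_complete",
            )

        def execute_complete(self, context: Context, event: Any = None) -> Any:
            # Invoked in a fresh worker slot once the trigger fires.
            return event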
1 change: 1 addition & 0 deletions airflow/models/mappedoperator.py
@@ -283,6 +283,7 @@ class MappedOperator(AbstractOperator):
     _operator_name: str
     start_trigger: BaseTrigger | None
     next_method: str | None
+    _needs_expansion: bool = True
 
     dag: DAG | None
     task_group: TaskGroup | None
26 changes: 2 additions & 24 deletions airflow/models/operator.py
@@ -17,34 +17,12 @@
 # under the License.
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Union
+from typing import Union
 
 from airflow.models.baseoperator import BaseOperator
 from airflow.models.mappedoperator import MappedOperator
-
-if TYPE_CHECKING:
-    from airflow.models.abstractoperator import AbstractOperator
-    from airflow.typing_compat import TypeGuard
 
 Operator = Union[BaseOperator, MappedOperator]
 
 
-def needs_expansion(task: AbstractOperator) -> TypeGuard[Operator]:
-    """Whether a task needs expansion at runtime.
-
-    A task needs expansion if it either
-
-    * Is a mapped operator, or
-    * Is in a mapped task group.
-
-    This is implemented as a free function (instead of a property) so we can
-    make it a type guard.
-    """
-    if isinstance(task, MappedOperator):
-        return True
-    if task.get_closest_mapped_task_group() is not None:
-        return True
-    return False
-
-
-__all__ = ["Operator", "needs_expansion"]
+__all__ = ["Operator"]
1 change: 1 addition & 0 deletions airflow/providers/amazon/aws/executors/ecs/ecs_executor.py
@@ -405,6 +405,7 @@ def attempt_task_runs(self):
                 except AttributeError:
                     # running_state is newly added, and only needed to support task adoption (an optional
                     # executor feature).
+                    # TODO: remove when min airflow version >= 2.9.2
                     pass
         if failure_reasons:
             self.log.error(
37 changes: 36 additions & 1 deletion airflow/providers/amazon/aws/hooks/neptune.py
@@ -34,6 +34,12 @@ class NeptuneHook(AwsBaseHook):

AVAILABLE_STATES = ["available"]
STOPPED_STATES = ["stopped"]
ERROR_STATES = [
"cloning-failed",
"inaccessible-encryption-credentials",
"inaccessible-encryption-credentials-recoverable",
"migration-failed",
]

def __init__(self, *args, **kwargs):
kwargs["client_type"] = "neptune"
@@ -82,4 +88,33 @@ def get_cluster_status(self, cluster_id: str) -> str:
         :param cluster_id: The ID of the cluster to get the status of.
         :return: The status of the cluster.
         """
-        return self.get_conn().describe_db_clusters(DBClusterIdentifier=cluster_id)["DBClusters"][0]["Status"]
+        return self.conn.describe_db_clusters(DBClusterIdentifier=cluster_id)["DBClusters"][0]["Status"]
+
+    def get_db_instance_status(self, instance_id: str) -> str:
+        """
+        Get the status of a Neptune instance.
+
+        :param instance_id: The ID of the instance to get the status of.
+        :return: The status of the instance.
+        """
+        return self.conn.describe_db_instances(DBInstanceIdentifier=instance_id)["DBInstances"][0][
+            "DBInstanceStatus"
+        ]
+
+    def wait_for_cluster_instance_availability(
+        self, cluster_id: str, delay: int = 30, max_attempts: int = 60
+    ) -> None:
+        """
+        Wait for Neptune instances in a cluster to be available.
+
+        :param cluster_id: The cluster ID of the instances to wait for.
+        :param delay: Time in seconds to delay between polls.
+        :param max_attempts: Maximum number of attempts to poll for completion.
+        :return: The status of the instances.
+        """
+        filters = [{"Name": "db-cluster-id", "Values": [cluster_id]}]
+        self.log.info("Waiting for instances in cluster %s.", cluster_id)
+        self.get_waiter("db_instance_available").wait(
+            Filters=filters, WaiterConfig={"Delay": delay, "MaxAttempts": max_attempts}
+        )
+        self.log.info("Finished waiting for instances in cluster %s.", cluster_id)
8 changes: 8 additions & 0 deletions airflow/providers/amazon/aws/operators/emr.py
@@ -617,6 +617,14 @@ def execute(self, context: Context) -> str | None:
                     job_id=self.job_id,
                     aws_conn_id=self.aws_conn_id,
                     waiter_delay=self.poll_interval,
+                    waiter_max_attempts=self.max_polling_attempts,
+                )
+                if self.max_polling_attempts
+                else EmrContainerTrigger(
+                    virtual_cluster_id=self.virtual_cluster_id,
+                    job_id=self.job_id,
+                    aws_conn_id=self.aws_conn_id,
+                    waiter_delay=self.poll_interval,
                 ),
                 method_name="execute_complete",
             )
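The hunk above forwards waiter_max_attempts to the trigger only when the operator was configured with max_polling_attempts. An equivalent-intent sketch of that branch using a kwargs dict instead of a conditional expression (a reformulation for illustration, not the committed code; it assumes the enclosing execute(), where self.virtual_cluster_id and the other attributes exist as in the hunk):

    # Equivalent-intent sketch, inside the operator's execute():
    trigger_kwargs = {
        "virtual_cluster_id": self.virtual_cluster_id,
        "job_id": self.job_id,
        "aws_conn_id": self.aws_conn_id,
        "waiter_delay": self.poll_interval,
    }
    if self.max_polling_attempts:
        # Only forward the cap when it was configured on the operator.
        trigger_kwargs["waiter_max_attempts"] = self.max_polling_attempts

    self.defer(
        trigger=EmrContainerTrigger(**trigger_kwargs),
        method_name="execute_complete",
    )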