Skip to content

Commit

Permalink
Move TaskInstanceKey to a separate file (#31033)
Browse files Browse the repository at this point in the history
closes: #30988

Fixes a circular import of DagRun -> TaskInstance -> Sentry ->
KubernetesExecutor -> kubernetes_helper_functions.py (*) ->
TaskInstanceKey.

Moving TaskInstanceKey out to a separate file breaks the cycle at the
kubernetes_helper_functions.py -> taskinstance.py import.

Co-authored-by: Lipu Fei <lipu.fei@kpn.com>
  • Loading branch information
LipuFei and Lipu Fei committed May 4, 2023
1 parent 51603ef commit ac46902
Show file tree
Hide file tree
Showing 29 changed files with 103 additions and 60 deletions.
3 changes: 2 additions & 1 deletion airflow/executors/base_executor.py
Expand Up @@ -38,7 +38,8 @@
if TYPE_CHECKING:
from airflow.callbacks.base_callback_sink import BaseCallbackSink
from airflow.callbacks.callback_requests import CallbackRequest
from airflow.models.taskinstance import TaskInstance, TaskInstanceKey
from airflow.models.taskinstance import TaskInstance
from airflow.models.taskinstancekey import TaskInstanceKey

# Command to execute - list of strings
# the first element is always "airflow".
Expand Down
3 changes: 2 additions & 1 deletion airflow/executors/celery_executor.py
Expand Up @@ -56,7 +56,8 @@

if TYPE_CHECKING:
from airflow.executors.base_executor import CommandType, EventBufferValueType, TaskTuple
from airflow.models.taskinstance import TaskInstance, TaskInstanceKey
from airflow.models.taskinstance import TaskInstance
from airflow.models.taskinstancekey import TaskInstanceKey

# Task instance that is sent over Celery queues
# TaskInstanceKey, Command, queue_name, CallableTask
Expand Down
3 changes: 2 additions & 1 deletion airflow/executors/celery_kubernetes_executor.py
Expand Up @@ -28,7 +28,8 @@

if TYPE_CHECKING:
from airflow.executors.base_executor import CommandType, EventBufferValueType, QueuedTaskInstanceType
from airflow.models.taskinstance import SimpleTaskInstance, TaskInstance, TaskInstanceKey
from airflow.models.taskinstance import SimpleTaskInstance, TaskInstance
from airflow.models.taskinstancekey import TaskInstanceKey


class CeleryKubernetesExecutor(LoggingMixin):
Expand Down
2 changes: 1 addition & 1 deletion airflow/executors/dask_executor.py
Expand Up @@ -36,7 +36,7 @@

if TYPE_CHECKING:
from airflow.executors.base_executor import CommandType
from airflow.models.taskinstance import TaskInstanceKey
from airflow.models.taskinstancekey import TaskInstanceKey


# queue="default" is a special case since this is the base config default queue name,
Expand Down
3 changes: 2 additions & 1 deletion airflow/executors/debug_executor.py
Expand Up @@ -32,7 +32,8 @@
from airflow.utils.state import State

if TYPE_CHECKING:
from airflow.models.taskinstance import TaskInstance, TaskInstanceKey
from airflow.models.taskinstance import TaskInstance
from airflow.models.taskinstancekey import TaskInstanceKey


class DebugExecutor(BaseExecutor):
Expand Down
2 changes: 1 addition & 1 deletion airflow/executors/kubernetes_executor.py
Expand Up @@ -54,7 +54,7 @@

if TYPE_CHECKING:
from airflow.executors.base_executor import CommandType
from airflow.models.taskinstance import TaskInstanceKey
from airflow.models.taskinstancekey import TaskInstanceKey

# TaskInstance key, command, configuration, pod_template_file
KubernetesJobType = Tuple[TaskInstanceKey, CommandType, Any, Optional[str]]
Expand Down
3 changes: 2 additions & 1 deletion airflow/executors/local_executor.py
Expand Up @@ -43,7 +43,8 @@

if TYPE_CHECKING:
from airflow.executors.base_executor import CommandType
from airflow.models.taskinstance import TaskInstanceKey, TaskInstanceStateType
from airflow.models.taskinstance import TaskInstanceStateType
from airflow.models.taskinstancekey import TaskInstanceKey

# This is a work to be executed by a worker.
# It can Key and Command - but it can also be None, None which is actually a
Expand Down
2 changes: 1 addition & 1 deletion airflow/executors/sequential_executor.py
Expand Up @@ -32,7 +32,7 @@

if TYPE_CHECKING:
from airflow.executors.base_executor import CommandType
from airflow.models.taskinstance import TaskInstanceKey
from airflow.models.taskinstancekey import TaskInstanceKey


class SequentialExecutor(BaseExecutor):
Expand Down
2 changes: 1 addition & 1 deletion airflow/kubernetes/kubernetes_helper_functions.py
Expand Up @@ -23,7 +23,7 @@
import pendulum
from slugify import slugify

from airflow.models.taskinstance import TaskInstanceKey
from airflow.models.taskinstancekey import TaskInstanceKey

log = logging.getLogger(__name__)

Expand Down
2 changes: 1 addition & 1 deletion airflow/models/baseoperator.py
Expand Up @@ -98,7 +98,7 @@
import jinja2 # Slow import.

from airflow.models.dag import DAG
from airflow.models.taskinstance import TaskInstanceKey
from airflow.models.taskinstancekey import TaskInstanceKey
from airflow.models.xcom_arg import XComArg
from airflow.utils.task_group import TaskGroup

Expand Down
37 changes: 2 additions & 35 deletions airflow/models/taskinstance.py
Expand Up @@ -32,7 +32,7 @@
from functools import partial
from pathlib import PurePath
from types import TracebackType
from typing import TYPE_CHECKING, Any, Callable, Collection, Generator, Iterable, NamedTuple, Tuple
from typing import TYPE_CHECKING, Any, Callable, Collection, Generator, Iterable, Tuple
from urllib.parse import quote

import dill
Expand Down Expand Up @@ -92,6 +92,7 @@
from airflow.models.mappedoperator import MappedOperator
from airflow.models.param import process_params
from airflow.models.taskfail import TaskFail
from airflow.models.taskinstancekey import TaskInstanceKey
from airflow.models.taskmap import TaskMap
from airflow.models.taskreschedule import TaskReschedule
from airflow.models.xcom import LazyXComAccess, XCom
Expand Down Expand Up @@ -343,40 +344,6 @@ def _is_mappable_value(value: Any) -> TypeGuard[Collection]:
return True


class TaskInstanceKey(NamedTuple):
    """Named tuple used as the identifying key of a task instance."""

    dag_id: str
    task_id: str
    run_id: str
    try_number: int = 1
    map_index: int = -1

    @property
    def primary(self) -> tuple[str, str, str, int]:
        """Return the primary-key part of the key (all fields except ``try_number``)."""
        return (self.dag_id, self.task_id, self.run_id, self.map_index)

    @property
    def reduced(self) -> TaskInstanceKey:
        """Return a key whose try number is one lower, to match in-memory information.

        The try number is clamped so that it never drops below 1.
        """
        lowered_try = max(1, self.try_number - 1)
        return TaskInstanceKey(
            dag_id=self.dag_id,
            task_id=self.task_id,
            run_id=self.run_id,
            try_number=lowered_try,
            map_index=self.map_index,
        )

    def with_try_number(self, try_number: int) -> TaskInstanceKey:
        """Return a copy of this key carrying the provided ``try_number``."""
        return TaskInstanceKey(
            dag_id=self.dag_id,
            task_id=self.task_id,
            run_id=self.run_id,
            try_number=try_number,
            map_index=self.map_index,
        )

    @property
    def key(self) -> TaskInstanceKey:
        """Return ``self``; provided for API compatibility with TaskInstance."""
        return self


def _creator_note(val):
"""Custom creator for the ``note`` association proxy."""
if isinstance(val, str):
Expand Down
54 changes: 54 additions & 0 deletions airflow/models/taskinstancekey.py
@@ -0,0 +1,54 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from typing import NamedTuple


class TaskInstanceKey(NamedTuple):
    """Named tuple used as the identifying key of a task instance."""

    dag_id: str
    task_id: str
    run_id: str
    try_number: int = 1
    map_index: int = -1

    @property
    def primary(self) -> tuple[str, str, str, int]:
        """Return the primary-key part of the key (all fields except ``try_number``)."""
        return (self.dag_id, self.task_id, self.run_id, self.map_index)

    @property
    def reduced(self) -> TaskInstanceKey:
        """Return a key whose try number is one lower, to match in-memory information.

        The try number is clamped so that it never drops below 1.
        """
        lowered_try = max(1, self.try_number - 1)
        return TaskInstanceKey(
            dag_id=self.dag_id,
            task_id=self.task_id,
            run_id=self.run_id,
            try_number=lowered_try,
            map_index=self.map_index,
        )

    def with_try_number(self, try_number: int) -> TaskInstanceKey:
        """Return a copy of this key carrying the provided ``try_number``."""
        return TaskInstanceKey(
            dag_id=self.dag_id,
            task_id=self.task_id,
            run_id=self.run_id,
            try_number=try_number,
            map_index=self.map_index,
        )

    @property
    def key(self) -> TaskInstanceKey:
        """Return ``self``; provided for API compatibility with TaskInstance."""
        return self
2 changes: 1 addition & 1 deletion airflow/models/xcom.py
Expand Up @@ -68,7 +68,7 @@
log = logging.getLogger(__name__)

if TYPE_CHECKING:
from airflow.models.taskinstance import TaskInstanceKey
from airflow.models.taskinstancekey import TaskInstanceKey


class BaseXCom(Base, LoggingMixin):
Expand Down
2 changes: 1 addition & 1 deletion airflow/operators/trigger_dagrun.py
Expand Up @@ -46,7 +46,7 @@
if TYPE_CHECKING:
from sqlalchemy.orm.session import Session

from airflow.models.taskinstance import TaskInstanceKey
from airflow.models.taskinstancekey import TaskInstanceKey


class TriggerDagRunLink(BaseOperatorLink):
Expand Down
2 changes: 1 addition & 1 deletion airflow/providers/amazon/aws/links/base_aws.py
Expand Up @@ -23,7 +23,7 @@

if TYPE_CHECKING:
from airflow.models import BaseOperator
from airflow.models.taskinstance import TaskInstanceKey
from airflow.models.taskinstancekey import TaskInstanceKey
from airflow.utils.context import Context


Expand Down
2 changes: 1 addition & 1 deletion airflow/providers/databricks/operators/databricks.py
Expand Up @@ -31,7 +31,7 @@
from airflow.providers.databricks.utils.databricks import normalise_json_content, validate_trigger_event

if TYPE_CHECKING:
from airflow.models.taskinstance import TaskInstanceKey
from airflow.models.taskinstancekey import TaskInstanceKey
from airflow.utils.context import Context

DEFER_METHOD_NAME = "execute_complete"
Expand Down
2 changes: 1 addition & 1 deletion airflow/providers/google/cloud/links/base.py
Expand Up @@ -23,7 +23,7 @@

if TYPE_CHECKING:
from airflow.models import BaseOperator
from airflow.models.taskinstance import TaskInstanceKey
from airflow.models.taskinstancekey import TaskInstanceKey


BASE_LINK = "https://console.cloud.google.com"
Expand Down
2 changes: 1 addition & 1 deletion airflow/providers/google/cloud/links/datafusion.py
Expand Up @@ -24,7 +24,7 @@

if TYPE_CHECKING:
from airflow.models import BaseOperator
from airflow.models.taskinstance import TaskInstanceKey
from airflow.models.taskinstancekey import TaskInstanceKey
from airflow.utils.context import Context


Expand Down
2 changes: 1 addition & 1 deletion airflow/providers/google/cloud/links/dataproc.py
Expand Up @@ -25,7 +25,7 @@

if TYPE_CHECKING:
from airflow.models import BaseOperator
from airflow.models.taskinstance import TaskInstanceKey
from airflow.models.taskinstancekey import TaskInstanceKey
from airflow.utils.context import Context

DATAPROC_BASE_LINK = BASE_LINK + "/dataproc"
Expand Down
2 changes: 1 addition & 1 deletion airflow/providers/google/cloud/operators/bigquery.py
Expand Up @@ -52,7 +52,7 @@
)

if TYPE_CHECKING:
from airflow.models.taskinstance import TaskInstanceKey
from airflow.models.taskinstancekey import TaskInstanceKey
from airflow.utils.context import Context


Expand Down
Expand Up @@ -37,7 +37,7 @@
from airflow.providers.google.common.links.storage import StorageLink

if TYPE_CHECKING:
from airflow.models.taskinstance import TaskInstanceKey
from airflow.models.taskinstancekey import TaskInstanceKey
from airflow.utils.context import Context


Expand Down
Expand Up @@ -33,7 +33,7 @@
from airflow.utils.log.logging_mixin import LoggingMixin

if TYPE_CHECKING:
from airflow.models.taskinstance import TaskInstanceKey
from airflow.models.taskinstancekey import TaskInstanceKey
from airflow.utils.context import Context


Expand Down
2 changes: 1 addition & 1 deletion airflow/providers/qubole/operators/qubole.py
Expand Up @@ -33,7 +33,7 @@

if TYPE_CHECKING:

from airflow.models.taskinstance import TaskInstanceKey
from airflow.models.taskinstancekey import TaskInstanceKey
from airflow.utils.context import Context


Expand Down
15 changes: 15 additions & 0 deletions docs/apache-airflow/public-airflow-interface.rst
Expand Up @@ -151,6 +151,21 @@ passed to the execute method of the operators via the :class:`~airflow.models.ta
_api/airflow/models/taskinstance/index


Task Instance Keys
------------------

Task instance keys are unique identifiers of task instances in a DAG (in a DAG Run). A key is a tuple that consists of
``dag_id``, ``task_id``, ``run_id``, ``try_number``, and ``map_index``. The key of a task instance can be retrieved via
:meth:`~airflow.models.taskinstance.TaskInstance.key`.

.. toctree::
:includehidden:
:glob:
:maxdepth: 1

_api/airflow/models/taskinstancekey/index


Hooks
-----

Expand Down
1 change: 1 addition & 0 deletions docs/conf.py
Expand Up @@ -246,6 +246,7 @@ def _get_rst_filepath_from_path(filepath: pathlib.Path):
"dagbag.py",
"param.py",
"taskinstance.py",
"taskinstancekey.py",
"variable.py",
"xcom.py",
}
Expand Down
2 changes: 1 addition & 1 deletion tests/executors/test_kubernetes_executor.py
Expand Up @@ -32,7 +32,7 @@

from airflow import AirflowException
from airflow.exceptions import PodReconciliationError
from airflow.models.taskinstance import TaskInstanceKey
from airflow.models.taskinstancekey import TaskInstanceKey
from airflow.operators.bash import BashOperator
from airflow.operators.empty import EmptyOperator
from airflow.utils import timezone
Expand Down
2 changes: 1 addition & 1 deletion tests/jobs/test_backfill_job.py
Expand Up @@ -45,7 +45,7 @@
from airflow.models import DagBag, Pool, TaskInstance as TI
from airflow.models.dagrun import DagRun
from airflow.models.serialized_dag import SerializedDagModel
from airflow.models.taskinstance import TaskInstanceKey
from airflow.models.taskinstancekey import TaskInstanceKey
from airflow.models.taskmap import TaskMap
from airflow.operators.empty import EmptyOperator
from airflow.utils import timezone
Expand Down
3 changes: 2 additions & 1 deletion tests/models/test_xcom.py
Expand Up @@ -27,7 +27,8 @@

from airflow.configuration import conf
from airflow.models.dagrun import DagRun, DagRunType
from airflow.models.taskinstance import TaskInstance, TaskInstanceKey
from airflow.models.taskinstance import TaskInstance
from airflow.models.taskinstancekey import TaskInstanceKey
from airflow.models.xcom import BaseXCom, XCom, resolve_xcom_backend
from airflow.operators.empty import EmptyOperator
from airflow.settings import json
Expand Down
2 changes: 1 addition & 1 deletion tests/test_utils/mock_executor.py
Expand Up @@ -21,7 +21,7 @@
from unittest.mock import MagicMock

from airflow.executors.base_executor import BaseExecutor
from airflow.models.taskinstance import TaskInstanceKey
from airflow.models.taskinstancekey import TaskInstanceKey
from airflow.utils.session import create_session
from airflow.utils.state import State

Expand Down

0 comments on commit ac46902

Please sign in to comment.