Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AIRFLOW-6704] Copy common TaskInstance attributes from Task #7324

Merged
merged 6 commits into from Feb 21, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
35 changes: 22 additions & 13 deletions airflow/models/taskinstance.py
Expand Up @@ -85,6 +85,7 @@ def clear_task_instances(tis,
task_id = ti.task_id
if dag and dag.has_task(task_id):
task = dag.get_task(task_id)
ti.refresh_from_task(task)
task_retries = task.retries
ti.max_tries = ti.try_number + task_retries - 1
else:
Expand Down Expand Up @@ -177,6 +178,7 @@ def __init__(self, task, execution_date, state=None):
self.dag_id = task.dag_id
self.task_id = task.task_id
self.task = task
self.refresh_from_task(task)
self._log = logging.getLogger("airflow.task")

# make sure we have a localized execution_date stored in UTC
Expand All @@ -193,18 +195,11 @@ def __init__(self, task, execution_date, state=None):

self.execution_date = execution_date

self.queue = task.queue
self.pool = task.pool
self.pool_slots = task.pool_slots
self.priority_weight = task.priority_weight_total
self.try_number = 0
self.max_tries = self.task.retries
self.unixname = getpass.getuser()
self.run_as_user = task.run_as_user
if state:
self.state = state
self.hostname = ''
self.executor_config = task.executor_config
self.init_on_load()
# Is this TaskInstance being currently running within `airflow tasks run --raw`.
# Not persisted to the database so only valid for the current process
Expand Down Expand Up @@ -471,6 +466,24 @@ def refresh_from_db(self, session=None, lock_for_update=False, refresh_executor_
else:
self.state = None

def refresh_from_task(self, task, pool_override=None):
    """
    Copy the task-derived attributes of this TaskInstance from the given task.

    :param task: The task object to copy from
    :type task: airflow.models.BaseOperator
    :param pool_override: Use the pool_override instead of task's pool
    :type pool_override: str
    """
    # Attributes taken verbatim from the task definition.
    self.executor_config = task.executor_config
    self.max_tries = task.retries
    self.run_as_user = task.run_as_user
    self.priority_weight = task.priority_weight_total
    self.pool_slots = task.pool_slots
    self.queue = task.queue
    # A truthy override (e.g. supplied by the caller) wins over the task's own pool.
    self.pool = pool_override or task.pool
    # Record the concrete operator class name for display and queries.
    self.operator = type(task).__name__

@provide_session
def clear_xcom_data(self, session=None):
"""
Expand Down Expand Up @@ -772,13 +785,11 @@ def _check_and_change_state_before_execution(
:rtype: bool
"""
task = self.task
self.pool = pool or task.pool
self.pool_slots = task.pool_slots
self.refresh_from_task(task, pool_override=pool)
self.test_mode = test_mode
self.refresh_from_db(session=session, lock_for_update=True)
self.job_id = job_id
self.hostname = get_hostname()
self.operator = task.__class__.__name__

if not ignore_all_deps and not ignore_ti_state and self.state == State.SUCCESS:
Stats.incr('previously_succeeded', 1, 1)
Expand Down Expand Up @@ -888,13 +899,11 @@ def _run_raw_task(
from airflow.sensors.base_sensor_operator import BaseSensorOperator

task = self.task
self.pool = pool or task.pool
self.pool_slots = task.pool_slots
self.test_mode = test_mode
self.refresh_from_task(task, pool_override=pool)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it needed here if we call self.refresh_from_db immediately?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is mostly trying to preserve the existing behavior and also move some duplicated code into refresh_from_task(). However, you are right that this part is not perfect:

Ideally we should first call refresh_from_db() and then call refresh_from_task(). The call to refresh_from_db() is to load those cumulative values such as self.try_number and self.max_tries from the db so that individual runs of the task can increment these numbers. The call to refresh_from_task() is to get those configurable values from the latest DAG definition. However, at the moment refresh_from_db() is loading both cumulative values and configurable attributes. So it also sets configurable values such as self.queue and self.operator, which are most likely more useful to read from the DAG definition via refresh_from_task().

This PR is not trying to fix everything. It only consolidates some duplicated code and makes attributes such as self.queue and self.pool updatable when tasks are cleared in clear_task_instances(). It's probably worth a separate and bigger PR to make sure refresh_from_db() only reads the attributes that really should come from the db, leaving the other attributes to refresh_from_task().

self.refresh_from_db(session=session)
self.job_id = job_id
self.hostname = get_hostname()
self.operator = task.__class__.__name__

context = {} # type: Dict
actual_start_date = timezone.utcnow()
Expand Down
1 change: 1 addition & 0 deletions tests/api/common/experimental/test_mark_tasks.py
Expand Up @@ -103,6 +103,7 @@ def verify_state(self, dag, task_ids, execution_dates, state, old_tis, session=N
self.assertTrue(len(tis) > 0)

for ti in tis: # pylint: disable=too-many-nested-blocks
self.assertEqual(ti.operator, dag.get_task(ti.task_id).__class__.__name__)
if ti.task_id in task_ids and ti.execution_date in execution_dates:
self.assertEqual(ti.state, state)
if state in State.finished():
Expand Down
24 changes: 24 additions & 0 deletions tests/models/test_taskinstance.py
Expand Up @@ -24,6 +24,7 @@
from unittest.mock import mock_open, patch

import pendulum
import pytest
from freezegun import freeze_time
from parameterized import param, parameterized
from sqlalchemy.orm.session import Session
Expand Down Expand Up @@ -1496,3 +1497,26 @@ def test_handle_failure(self):

context_arg_2 = mock_on_retry_2.call_args[0][0]
assert context_arg_2 and "task_instance" in context_arg_2


@pytest.mark.parametrize("pool_override", [None, "test_pool2"])
def test_refresh_from_task(pool_override):
    """refresh_from_task() must mirror the task's attributes, honouring pool_override."""
    task = DummyOperator(
        task_id="dummy",
        queue="test_queue",
        pool="test_pool1",
        pool_slots=3,
        priority_weight=10,
        run_as_user="test",
        retries=30,
        executor_config={"KubernetesExecutor": {"image": "myCustomDockerImage"}},
    )
    ti = TI(task, execution_date=pendulum.datetime(2020, 1, 1))
    ti.refresh_from_task(task, pool_override=pool_override)

    expected = {
        "queue": task.queue,
        # An explicit override replaces the task's pool; otherwise the task's pool is used.
        "pool": pool_override or task.pool,
        "pool_slots": task.pool_slots,
        "priority_weight": task.priority_weight_total,
        "run_as_user": task.run_as_user,
        "max_tries": task.retries,
        "executor_config": task.executor_config,
        "operator": DummyOperator.__name__,
    }
    for attr, value in expected.items():
        assert getattr(ti, attr) == value