Skip to content
This repository has been archived by the owner on May 22, 2021. It is now read-only.

Commit

Permalink
[AIRFLOW-6704] Copy common TaskInstance attributes from Task (apache#…
Browse files Browse the repository at this point in the history
  • Loading branch information
yuqian90 authored and galuszkak committed Mar 5, 2020
1 parent c8b11be commit b750c36
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 13 deletions.
35 changes: 22 additions & 13 deletions airflow/models/taskinstance.py
Expand Up @@ -86,6 +86,7 @@ def clear_task_instances(tis,
task_id = ti.task_id
if dag and dag.has_task(task_id):
task = dag.get_task(task_id)
ti.refresh_from_task(task)
task_retries = task.retries
ti.max_tries = ti.try_number + task_retries - 1
else:
Expand Down Expand Up @@ -178,6 +179,7 @@ def __init__(self, task, execution_date, state=None):
self.dag_id = task.dag_id
self.task_id = task.task_id
self.task = task
self.refresh_from_task(task)
self._log = logging.getLogger("airflow.task")

# make sure we have a localized execution_date stored in UTC
Expand All @@ -194,18 +196,11 @@ def __init__(self, task, execution_date, state=None):

self.execution_date = execution_date

self.queue = task.queue
self.pool = task.pool
self.pool_slots = task.pool_slots
self.priority_weight = task.priority_weight_total
self.try_number = 0
self.max_tries = self.task.retries
self.unixname = getpass.getuser()
self.run_as_user = task.run_as_user
if state:
self.state = state
self.hostname = ''
self.executor_config = task.executor_config
self.init_on_load()
# Is this TaskInstance being currently running within `airflow tasks run --raw`.
# Not persisted to the database so only valid for the current process
Expand Down Expand Up @@ -472,6 +467,24 @@ def refresh_from_db(self, session=None, lock_for_update=False, refresh_executor_
else:
self.state = None

def refresh_from_task(self, task, pool_override=None):
    """
    Synchronise this task instance's task-derived attributes with ``task``.

    Every call overwrites ``queue``, ``pool``, ``pool_slots``,
    ``priority_weight``, ``run_as_user``, ``max_tries``, ``executor_config``
    and ``operator``. Because ``max_tries`` is reset to ``task.retries``
    each time, a caller that keeps a different retry budget has to
    re-apply it after calling this method.

    :param task: The task object to copy from
    :type task: airflow.models.BaseOperator
    :param pool_override: Pool to use in place of the task's own pool;
        any falsy value (``None``, ``""``) falls back to ``task.pool``
    :type pool_override: str
    """
    # (attribute on self, attribute on task) pairs that are copied verbatim.
    mirrored = (
        ("queue", "queue"),
        ("pool_slots", "pool_slots"),
        ("priority_weight", "priority_weight_total"),
        ("run_as_user", "run_as_user"),
        ("max_tries", "retries"),
        ("executor_config", "executor_config"),
    )
    for own_attr, task_attr in mirrored:
        setattr(self, own_attr, getattr(task, task_attr))

    # A truthy override wins; otherwise use the pool configured on the task.
    self.pool = pool_override if pool_override else task.pool

    # Record the operator's class name for display and bookkeeping.
    self.operator = task.__class__.__name__

@provide_session
def clear_xcom_data(self, session=None):
"""
Expand Down Expand Up @@ -773,13 +786,11 @@ def _check_and_change_state_before_execution(
:rtype: bool
"""
task = self.task
self.pool = pool or task.pool
self.pool_slots = task.pool_slots
self.refresh_from_task(task, pool_override=pool)
self.test_mode = test_mode
self.refresh_from_db(session=session, lock_for_update=True)
self.job_id = job_id
self.hostname = get_hostname()
self.operator = task.__class__.__name__

if not ignore_all_deps and not ignore_ti_state and self.state == State.SUCCESS:
Stats.incr('previously_succeeded', 1, 1)
Expand Down Expand Up @@ -889,13 +900,11 @@ def _run_raw_task(
from airflow.sensors.base_sensor_operator import BaseSensorOperator

task = self.task
self.pool = pool or task.pool
self.pool_slots = task.pool_slots
self.test_mode = test_mode
self.refresh_from_task(task, pool_override=pool)
self.refresh_from_db(session=session)
self.job_id = job_id
self.hostname = get_hostname()
self.operator = task.__class__.__name__

context = {} # type: Dict
actual_start_date = timezone.utcnow()
Expand Down
1 change: 1 addition & 0 deletions tests/api/common/experimental/test_mark_tasks.py
Expand Up @@ -103,6 +103,7 @@ def verify_state(self, dag, task_ids, execution_dates, state, old_tis, session=N
self.assertTrue(len(tis) > 0)

for ti in tis: # pylint: disable=too-many-nested-blocks
self.assertEqual(ti.operator, dag.get_task(ti.task_id).__class__.__name__)
if ti.task_id in task_ids and ti.execution_date in execution_dates:
self.assertEqual(ti.state, state)
if state in State.finished():
Expand Down
24 changes: 24 additions & 0 deletions tests/models/test_taskinstance.py
Expand Up @@ -24,6 +24,7 @@
from unittest.mock import mock_open, patch

import pendulum
import pytest
from freezegun import freeze_time
from parameterized import param, parameterized
from sqlalchemy.orm.session import Session
Expand Down Expand Up @@ -1496,3 +1497,26 @@ def test_handle_failure(self):

context_arg_2 = mock_on_retry_2.call_args[0][0]
assert context_arg_2 and "task_instance" in context_arg_2


@pytest.mark.parametrize("pool_override", [None, "test_pool2"])
def test_refresh_from_task(pool_override):
    """
    Check that ``refresh_from_task`` mirrors the operator's settings onto the
    task instance, and that a supplied ``pool_override`` replaces the
    operator's own pool.
    """
    operator = DummyOperator(
        task_id="dummy",
        queue="test_queue",
        pool="test_pool1",
        pool_slots=3,
        priority_weight=10,
        run_as_user="test",
        retries=30,
        executor_config={"KubernetesExecutor": {"image": "myCustomDockerImage"}},
    )
    task_instance = TI(operator, execution_date=pendulum.datetime(2020, 1, 1))
    task_instance.refresh_from_task(operator, pool_override=pool_override)

    # Without an override the instance must inherit the operator's pool.
    expected_pool = pool_override if pool_override else operator.pool

    assert task_instance.queue == operator.queue
    assert task_instance.pool == expected_pool
    assert task_instance.pool_slots == operator.pool_slots
    assert task_instance.priority_weight == operator.priority_weight_total
    assert task_instance.run_as_user == operator.run_as_user
    assert task_instance.max_tries == operator.retries
    assert task_instance.executor_config == operator.executor_config
    assert task_instance.operator == DummyOperator.__name__

0 comments on commit b750c36

Please sign in to comment.