Skip to content

Commit

Permalink
[AIRFLOW-6704] Copy common TaskInstance attributes from Task (apache#7324)
Browse files Browse the repository at this point in the history

cherry-picked from 3ef7118
  • Loading branch information
yuqian90 authored and kaxil committed Mar 30, 2020
1 parent ca9352f commit eb96c13
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 14 deletions.
35 changes: 22 additions & 13 deletions airflow/models/taskinstance.py
Expand Up @@ -91,6 +91,7 @@ def clear_task_instances(tis,
task_id = ti.task_id
if dag and dag.has_task(task_id):
task = dag.get_task(task_id)
ti.refresh_from_task(task)
task_retries = task.retries
ti.max_tries = ti.try_number + task_retries - 1
else:
Expand Down Expand Up @@ -178,6 +179,7 @@ def __init__(self, task, execution_date, state=None):
self.dag_id = task.dag_id
self.task_id = task.task_id
self.task = task
self.refresh_from_task(task)
self._log = logging.getLogger("airflow.task")

# make sure we have a localized execution_date stored in UTC
Expand All @@ -194,18 +196,11 @@ def __init__(self, task, execution_date, state=None):

self.execution_date = execution_date

self.queue = task.queue
self.pool = task.pool
self.pool_slots = task.pool_slots
self.priority_weight = task.priority_weight_total
self.try_number = 0
self.max_tries = self.task.retries
self.unixname = getpass.getuser()
self.run_as_user = task.run_as_user
if state:
self.state = state
self.hostname = ''
self.executor_config = task.executor_config
self.init_on_load()
# Is this TaskInstance currently running within `airflow run --raw`.
# Not persisted to the database so only valid for the current process
Expand Down Expand Up @@ -500,6 +495,24 @@ def refresh_from_db(self, session=None, lock_for_update=False):
else:
self.state = None

def refresh_from_task(self, task, pool_override=None):
    """
    Synchronise the task-derived fields of this instance with ``task``.

    :param task: The task object whose attributes are copied
    :type task: airflow.models.BaseOperator
    :param pool_override: Pool to use in place of the task's own pool; any
        falsy value (``None``, ``""``) falls back to ``task.pool``
    :type pool_override: str
    """
    # Fields that carry the same name on both objects are copied verbatim.
    for attr in ("queue", "pool_slots", "run_as_user", "executor_config"):
        setattr(self, attr, getattr(task, attr))

    # Fields whose source is named differently or needs a fallback.
    self.pool = pool_override if pool_override else task.pool
    self.priority_weight = task.priority_weight_total
    self.max_tries = task.retries
    self.operator = task.__class__.__name__

@provide_session
def clear_xcom_data(self, session=None):
"""
Expand Down Expand Up @@ -798,13 +811,11 @@ def _check_and_change_state_before_execution(
:rtype: bool
"""
task = self.task
self.pool = pool or task.pool
self.pool_slots = task.pool_slots
self.refresh_from_task(task, pool_override=pool)
self.test_mode = test_mode
self.refresh_from_db(session=session, lock_for_update=True)
self.job_id = job_id
self.hostname = get_hostname()
self.operator = task.__class__.__name__

if not ignore_all_deps and not ignore_ti_state and self.state == State.SUCCESS:
Stats.incr('previously_succeeded', 1, 1)
Expand Down Expand Up @@ -915,13 +926,11 @@ def _run_raw_task(
from airflow.models.renderedtifields import RenderedTaskInstanceFields as RTIF

task = self.task
self.pool = pool or task.pool
self.pool_slots = task.pool_slots
self.test_mode = test_mode
self.refresh_from_task(task, pool_override=pool)
self.refresh_from_db(session=session)
self.job_id = job_id
self.hostname = get_hostname()
self.operator = task.__class__.__name__

context = {}
actual_start_date = timezone.utcnow()
Expand Down
3 changes: 2 additions & 1 deletion tests/api/common/experimental/test_mark_tasks.py
Expand Up @@ -102,7 +102,8 @@ def verify_state(self, dag, task_ids, execution_dates, state, old_tis, session=N

self.assertTrue(len(tis) > 0)

for ti in tis:
for ti in tis: # pylint: disable=too-many-nested-blocks
self.assertEqual(ti.operator, dag.get_task(ti.task_id).__class__.__name__)
if ti.task_id in task_ids and ti.execution_date in execution_dates:
self.assertEqual(ti.state, state)
else:
Expand Down
24 changes: 24 additions & 0 deletions tests/models/test_taskinstance.py
Expand Up @@ -23,6 +23,7 @@
import urllib
from typing import Union, List
import pendulum
import pytest
from freezegun import freeze_time
from mock import patch, mock_open
from parameterized import parameterized, param
Expand Down Expand Up @@ -1494,3 +1495,26 @@ def test_get_rendered_template_fields(self, store_serialized_dag):
# CleanUp
with create_session() as session:
session.query(RenderedTaskInstanceFields).delete()


@pytest.mark.parametrize("pool_override", [None, "test_pool2"])
def test_refresh_from_task(pool_override):
    """Check that refresh_from_task copies each task attribute onto the TI."""
    task = DummyOperator(
        task_id="dummy",
        queue="test_queue",
        pool="test_pool1",
        pool_slots=3,
        priority_weight=10,
        run_as_user="test",
        retries=30,
        executor_config={"KubernetesExecutor": {"image": "myCustomDockerImage"}},
    )
    ti = TI(task, execution_date=pendulum.datetime(2020, 1, 1))
    ti.refresh_from_task(task, pool_override=pool_override)

    # A truthy override wins; otherwise the task's own pool is expected.
    expected_pool = pool_override if pool_override else task.pool

    assert ti.queue == task.queue
    assert ti.pool == expected_pool
    assert ti.pool_slots == task.pool_slots
    assert ti.priority_weight == task.priority_weight_total
    assert ti.run_as_user == task.run_as_user
    assert ti.max_tries == task.retries
    assert ti.executor_config == task.executor_config
    assert ti.operator == DummyOperator.__name__

0 comments on commit eb96c13

Please sign in to comment.