Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,10 +425,13 @@ def error(self, session=None):
session.commit()

@provide_session
def refresh_from_db(self, session=None, lock_for_update=False):
def refresh_from_db(self, session=None, lock_for_update=False, refresh_executor_config=False):
"""
Refreshes the task instance from the database based on the primary key

:param refresh_executor_config: if True, revert executor config to
result from DB. Often, however, we will want to keep the newest
version
:param lock_for_update: if True, indicates that the database should
lock the TaskInstance (issuing a FOR UPDATE clause) until the
session is committed.
Expand All @@ -454,7 +457,8 @@ def refresh_from_db(self, session=None, lock_for_update=False):
self.max_tries = ti.max_tries
self.hostname = ti.hostname
self.pid = ti.pid
self.executor_config = ti.executor_config
if refresh_executor_config:
self.executor_config = ti.executor_config
else:
self.state = None

Expand Down
27 changes: 27 additions & 0 deletions tests/models/test_taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from airflow.utils.state import State
from tests.models import DEFAULT_DATE
from tests.test_utils import db
from airflow.utils.db import provide_session


class TestTaskInstance(unittest.TestCase):
Expand Down Expand Up @@ -343,6 +344,32 @@ def test_run_pooling_task(self):
db.clear_db_pools()
self.assertEqual(ti.state, State.SUCCESS)

@provide_session
def test_ti_updates_with_task(self, session=None):
"""
test that updating the executor_config propogates to the TaskInstance DB
"""
dag = models.DAG(dag_id='test_run_pooling_task')
task = DummyOperator(task_id='test_run_pooling_task_op', dag=dag, owner='airflow',
executor_config={'foo': 'bar'},
start_date=timezone.datetime(2016, 2, 1, 0, 0, 0))
ti = TI(
task=task, execution_date=timezone.utcnow())

ti.run(session=session)
tis = dag.get_task_instances()
self.assertEqual({'foo': 'bar'}, tis[0].executor_config)

task2 = DummyOperator(task_id='test_run_pooling_task_op', dag=dag, owner='airflow',
executor_config={'bar': 'baz'},
start_date=timezone.datetime(2016, 2, 1, 0, 0, 0))

ti = TI(
task=task2, execution_date=timezone.utcnow())
ti.run(session=session)
tis = dag.get_task_instances()
self.assertEqual({'bar': 'baz'}, tis[1].executor_config)

def test_run_pooling_task_with_mark_success(self):
"""
test that running task in an existing pool with mark_success param
Expand Down