Skip to content

Commit

Permalink
Fix rendering nested task fields (#18516)
Browse files Browse the repository at this point in the history
When we refer to ``{{ task }}`` in Jinja templates we can also
refer to some of its fields, which are themselves templated. We are
not able to solve all the problems with such rendering (specifically,
recursive rendering of the fields used in Jinja templating might
be problematic). Currently, whether you see the original or the rendered
field depends solely on the sequence in templated_fields.

However that would not even explain the rendering problem
described in #13559 where kwargs were defined after opargs and
the rendering of opargs **should** work. It turned out that
the problem was with a change introduced in #8805 which made
the context effectively hold a DIFFERENT task than the current
one. Context held an original task, and the current task was
actually a locked copy of it (to allow resolving upstream
args before locking). As a result, any changes done by
rendering templates were not visible in the task accessed
via {{ task }} jinja variable.

This change replaces the task stored in context with the
same copy that is then used later during execution so that
at least the "sequential" rendering works and templated
fields which are 'earlier' in the list of templated fields
can be used (and render correctly) in the following fields.

Fixes: #13559
  • Loading branch information
potiuk committed Sep 28, 2021
1 parent b0a2977 commit 1ac63cd
Showing 1 changed file with 14 additions and 17 deletions.
31 changes: 14 additions & 17 deletions airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -1297,9 +1297,8 @@ def _run_raw_task(
:param session: SQLAlchemy ORM Session
:type session: Session
"""
task = self.task
self.test_mode = test_mode
self.refresh_from_task(task, pool_override=pool)
self.refresh_from_task(self.task, pool_override=pool)
self.refresh_from_db(session=session)
self.job_id = job_id
self.hostname = get_hostname()
Expand All @@ -1308,11 +1307,12 @@ def _run_raw_task(
session.merge(self)
session.commit()
actual_start_date = timezone.utcnow()
Stats.incr(f'ti.start.{task.dag_id}.{task.task_id}')
Stats.incr(f'ti.start.{self.task.dag_id}.{self.task.task_id}')
try:
if not mark_success:
self.task = self.task.prepare_for_execution()
context = self.get_template_context(ignore_param_exceptions=False)
self._prepare_and_execute_task_with_callbacks(context, task)
self._execute_task_with_callbacks(context)
if not test_mode:
self.refresh_from_db(lock_for_update=True, session=session)
self.state = State.SUCCESS
Expand Down Expand Up @@ -1367,7 +1367,7 @@ def _run_raw_task(
session.commit()
raise
finally:
Stats.incr(f'ti.finish.{task.dag_id}.{task.task_id}.{self.state}')
Stats.incr(f'ti.finish.{self.task.dag_id}.{self.task.task_id}.{self.state}')

# Recording SKIPPED or SUCCESS
self.end_date = timezone.utcnow()
Expand All @@ -1379,23 +1379,20 @@ def _run_raw_task(

session.commit()

def _prepare_and_execute_task_with_callbacks(self, context, task):
def _execute_task_with_callbacks(self, context):
"""Prepare Task for Execution"""
from airflow.models.renderedtifields import RenderedTaskInstanceFields

task_copy = task.prepare_for_execution()
self.task = task_copy

def signal_handler(signum, frame):
self.log.error("Received SIGTERM. Terminating subprocesses.")
task_copy.on_kill()
self.task.on_kill()
raise AirflowException("Task received SIGTERM signal")

signal.signal(signal.SIGTERM, signal_handler)

# Don't clear Xcom until the task is certain to execute
self.clear_xcom_data()
with Stats.timer(f'dag.{task_copy.dag_id}.{task_copy.task_id}.duration'):
with Stats.timer(f'dag.{self.task.dag_id}.{self.task.task_id}.duration'):

self.render_templates(context=context)
RenderedTaskInstanceFields.write(RenderedTaskInstanceFields(ti=self, render_templates=False))
Expand All @@ -1411,16 +1408,16 @@ def signal_handler(signum, frame):
os.environ.update(airflow_context_vars)

# Run pre_execute callback
task_copy.pre_execute(context=context)
self.task.pre_execute(context=context)

# Run on_execute callback
self._run_execute_callback(context, task)
self._run_execute_callback(context, self.task)

if task_copy.is_smart_sensor_compatible():
if self.task.is_smart_sensor_compatible():
# Try to register it in the smart sensor service.
registered = False
try:
registered = task_copy.register_in_sensor_service(self, context)
registered = self.task.register_in_sensor_service(self, context)
except Exception:
self.log.warning(
"Failed to register in sensor service."
Expand All @@ -1434,10 +1431,10 @@ def signal_handler(signum, frame):

# Execute the task
with set_current_context(context):
result = self._execute_task(context, task_copy)
result = self._execute_task(context, self.task)

# Run post_execute callback
task_copy.post_execute(context=context, result=result)
self.task.post_execute(context=context, result=result)

Stats.incr(f'operator_successes_{self.task.task_type}', 1, 1)
Stats.incr('ti_successes')
Expand Down

0 comments on commit 1ac63cd

Please sign in to comment.