diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 01b75c4f1e76c..23a135b206d58 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -1297,9 +1297,8 @@ def _run_raw_task( :param session: SQLAlchemy ORM Session :type session: Session """ - task = self.task self.test_mode = test_mode - self.refresh_from_task(task, pool_override=pool) + self.refresh_from_task(self.task, pool_override=pool) self.refresh_from_db(session=session) self.job_id = job_id self.hostname = get_hostname() @@ -1308,11 +1307,12 @@ def _run_raw_task( session.merge(self) session.commit() actual_start_date = timezone.utcnow() - Stats.incr(f'ti.start.{task.dag_id}.{task.task_id}') + Stats.incr(f'ti.start.{self.task.dag_id}.{self.task.task_id}') try: if not mark_success: + self.task = self.task.prepare_for_execution() context = self.get_template_context(ignore_param_exceptions=False) - self._prepare_and_execute_task_with_callbacks(context, task) + self._execute_task_with_callbacks(context) if not test_mode: self.refresh_from_db(lock_for_update=True, session=session) self.state = State.SUCCESS @@ -1367,7 +1367,7 @@ def _run_raw_task( session.commit() raise finally: - Stats.incr(f'ti.finish.{task.dag_id}.{task.task_id}.{self.state}') + Stats.incr(f'ti.finish.{self.task.dag_id}.{self.task.task_id}.{self.state}') # Recording SKIPPED or SUCCESS self.end_date = timezone.utcnow() @@ -1379,23 +1379,20 @@ def _run_raw_task( session.commit() - def _prepare_and_execute_task_with_callbacks(self, context, task): + def _execute_task_with_callbacks(self, context): """Prepare Task for Execution""" from airflow.models.renderedtifields import RenderedTaskInstanceFields - task_copy = task.prepare_for_execution() - self.task = task_copy - def signal_handler(signum, frame): self.log.error("Received SIGTERM. Terminating subprocesses.") - task_copy.on_kill() + self.task.on_kill() raise AirflowException("Task received SIGTERM signal") signal.signal(signal.SIGTERM, signal_handler) # Don't clear Xcom until the task is certain to execute self.clear_xcom_data() - with Stats.timer(f'dag.{task_copy.dag_id}.{task_copy.task_id}.duration'): + with Stats.timer(f'dag.{self.task.dag_id}.{self.task.task_id}.duration'): self.render_templates(context=context) RenderedTaskInstanceFields.write(RenderedTaskInstanceFields(ti=self, render_templates=False)) @@ -1411,16 +1408,16 @@ def signal_handler(signum, frame): os.environ.update(airflow_context_vars) # Run pre_execute callback - task_copy.pre_execute(context=context) + self.task.pre_execute(context=context) # Run on_execute callback - self._run_execute_callback(context, task) + self._run_execute_callback(context, self.task) - if task_copy.is_smart_sensor_compatible(): + if self.task.is_smart_sensor_compatible(): # Try to register it in the smart sensor service. registered = False try: - registered = task_copy.register_in_sensor_service(self, context) + registered = self.task.register_in_sensor_service(self, context) except Exception: self.log.warning( "Failed to register in sensor service." @@ -1434,10 +1431,10 @@ def signal_handler(signum, frame): # Execute the task with set_current_context(context): - result = self._execute_task(context, task_copy) + result = self._execute_task(context, self.task) # Run post_execute callback - task_copy.post_execute(context=context, result=result) + self.task.post_execute(context=context, result=result) Stats.incr(f'operator_successes_{self.task.task_type}', 1, 1) Stats.incr('ti_successes')