diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index bc284e5e8d0db..3d4a9082c5edf 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -56,6 +56,7 @@
 from airflow.utils.helpers import is_container
 from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.net import get_hostname
+from airflow.utils.operator_helpers import context_to_airflow_vars
 from airflow.utils.session import provide_session
 from airflow.utils.sqlalchemy import UtcDateTime
 from airflow.utils.state import State
@@ -915,6 +916,12 @@ def signal_handler(signum, frame):
                 start_time = time.time()
 
                 self.render_templates(context=context)
+                # Export context to make it available for operators to use.
+                airflow_context_vars = context_to_airflow_vars(context, in_env_var_format=True)
+                self.log.info("Exporting the following env vars:\n%s",
+                              '\n'.join(["{}={}".format(k, v)
+                                         for k, v in airflow_context_vars.items()]))
+                os.environ.update(airflow_context_vars)
                 task_copy.pre_execute(context=context)
 
                 try:
diff --git a/airflow/operators/python.py b/airflow/operators/python.py
index ff82f4d15d98e..9699fa8b17c8e 100644
--- a/airflow/operators/python.py
+++ b/airflow/operators/python.py
@@ -32,7 +32,6 @@
 from airflow.exceptions import AirflowException
 from airflow.models import BaseOperator, SkipMixin
 from airflow.utils.decorators import apply_defaults
-from airflow.utils.operator_helpers import context_to_airflow_vars
 from airflow.utils.process_utils import execute_in_subprocess
 from airflow.utils.python_virtualenv import prepare_virtualenv
 
@@ -127,13 +126,6 @@ def determine_op_kwargs(python_callable: Callable,
         return op_kwargs
 
     def execute(self, context: Dict):
-        # Export context to make it available for callables to use.
-        airflow_context_vars = context_to_airflow_vars(context, in_env_var_format=True)
-        self.log.debug("Exporting the following env vars:\n%s",
-                       '\n'.join(["{}={}".format(k, v)
-                                  for k, v in airflow_context_vars.items()]))
-        os.environ.update(airflow_context_vars)
-
         context.update(self.op_kwargs)
         context['templates_dict'] = self.templates_dict
 
diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py
index 28339a920a959..3142ce1d2051f 100644
--- a/tests/models/test_taskinstance.py
+++ b/tests/models/test_taskinstance.py
@@ -17,6 +17,7 @@
 # under the License.
 
 import datetime
+import os
 import time
 import unittest
 import urllib
@@ -1220,9 +1221,9 @@ def test_set_duration_empty_dates(self):
         ti.set_duration()
         self.assertIsNone(ti.duration)
 
-    def test_success_callbak_no_race_condition(self):
+    def test_success_callback_no_race_condition(self):
         callback_wrapper = CallbackWrapper()
-        dag = DAG('test_success_callbak_no_race_condition', start_date=DEFAULT_DATE,
+        dag = DAG('test_success_callback_no_race_condition', start_date=DEFAULT_DATE,
                   end_date=DEFAULT_DATE + datetime.timedelta(days=10))
         task = DummyOperator(task_id='op', email='test@test.test',
                              on_success_callback=callback_wrapper.success_handler, dag=dag)
@@ -1445,7 +1446,7 @@ def on_execute_callable(context):
                 'test_dagrun_execute_callback'
             )
 
-        dag = DAG('test_execute_callbak', start_date=DEFAULT_DATE,
+        dag = DAG('test_execute_callback', start_date=DEFAULT_DATE,
                   end_date=DEFAULT_DATE + datetime.timedelta(days=10))
         task = DummyOperator(task_id='op', email='test@test.test',
                              on_execute_callback=on_execute_callable,
@@ -1496,6 +1497,35 @@ def test_handle_failure(self):
         context_arg_2 = mock_on_retry_2.call_args[0][0]
         assert context_arg_2 and "task_instance" in context_arg_2
 
+    def _env_var_check_callback(self):
+        self.assertEqual('test_echo_env_variables', os.environ['AIRFLOW_CTX_DAG_ID'])
+        self.assertEqual('hive_in_python_op', os.environ['AIRFLOW_CTX_TASK_ID'])
+        self.assertEqual(DEFAULT_DATE.isoformat(),
+                         os.environ['AIRFLOW_CTX_EXECUTION_DATE'])
+        self.assertEqual('manual__' + DEFAULT_DATE.isoformat(),
+                         os.environ['AIRFLOW_CTX_DAG_RUN_ID'])
+
+    def test_echo_env_variables(self):
+        dag = DAG('test_echo_env_variables', start_date=DEFAULT_DATE,
+                  end_date=DEFAULT_DATE + datetime.timedelta(days=10))
+        op = PythonOperator(task_id='hive_in_python_op',
+                            dag=dag,
+                            python_callable=self._env_var_check_callback)
+        dag.create_dagrun(
+            run_id='manual__' + DEFAULT_DATE.isoformat(),
+            execution_date=DEFAULT_DATE,
+            start_date=DEFAULT_DATE,
+            state=State.RUNNING,
+            external_trigger=False)
+        ti = TI(task=op, execution_date=DEFAULT_DATE)
+        ti.state = State.RUNNING
+        session = settings.Session()
+        session.merge(ti)
+        session.commit()
+        ti._run_raw_task()
+        ti.refresh_from_db()
+        self.assertEqual(ti.state, State.SUCCESS)
+
 
 @pytest.mark.parametrize("pool_override", [None, "test_pool2"])
 def test_refresh_from_task(pool_override):
diff --git a/tests/operators/test_python.py b/tests/operators/test_python.py
index aa3100cc17863..ab2313cfc1eae 100644
--- a/tests/operators/test_python.py
+++ b/tests/operators/test_python.py
@@ -19,7 +19,6 @@
 import copy
 import datetime
 import logging
-import os
 import sys
 import unittest
 import unittest.mock
@@ -120,12 +119,6 @@ def _assert_calls_equal(self, first, second):
         self.assertDictEqual(first.kwargs, second.kwargs)
 
 
-@unittest.mock.patch('os.environ', {
-    'AIRFLOW_CTX_DAG_ID': None,
-    'AIRFLOW_CTX_TASK_ID': None,
-    'AIRFLOW_CTX_EXECUTION_DATE': None,
-    'AIRFLOW_CTX_DAG_RUN_ID': None
-})
 class TestPythonOperator(TestPythonBase):
 
     def do_run(self):
@@ -249,32 +242,6 @@ def test_python_operator_shallow_copy_attr(self):
         self.assertEqual(id(original_task.python_callable),
                          id(new_task.python_callable))
 
-    def _env_var_check_callback(self):
-        self.assertEqual('test_dag', os.environ['AIRFLOW_CTX_DAG_ID'])
-        self.assertEqual('hive_in_python_op', os.environ['AIRFLOW_CTX_TASK_ID'])
-        self.assertEqual(DEFAULT_DATE.isoformat(),
-                         os.environ['AIRFLOW_CTX_EXECUTION_DATE'])
-        self.assertEqual('manual__' + DEFAULT_DATE.isoformat(),
-                         os.environ['AIRFLOW_CTX_DAG_RUN_ID'])
-
-    def test_echo_env_variables(self):
-        """
-        Test that env variables are exported correctly to the
-        python callback in the task.
-        """
-        self.dag.create_dagrun(
-            run_id='manual__' + DEFAULT_DATE.isoformat(),
-            execution_date=DEFAULT_DATE,
-            start_date=DEFAULT_DATE,
-            state=State.RUNNING,
-            external_trigger=False,
-        )
-
-        op = PythonOperator(task_id='hive_in_python_op',
-                            dag=self.dag,
-                            python_callable=self._env_var_check_callback)
-        op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
-
     def test_conflicting_kwargs(self):
         self.dag.create_dagrun(
             run_id='manual__' + DEFAULT_DATE.isoformat(),