Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
from airflow.utils.helpers import is_container
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.net import get_hostname
from airflow.utils.operator_helpers import context_to_airflow_vars
from airflow.utils.session import provide_session
from airflow.utils.sqlalchemy import UtcDateTime
from airflow.utils.state import State
Expand Down Expand Up @@ -915,6 +916,12 @@ def signal_handler(signum, frame):
start_time = time.time()

self.render_templates(context=context)
# Export context to make it available for operators to use.
airflow_context_vars = context_to_airflow_vars(context, in_env_var_format=True)
self.log.info("Exporting the following env vars:\n%s",
'\n'.join(["{}={}".format(k, v)
for k, v in airflow_context_vars.items()]))
os.environ.update(airflow_context_vars)
task_copy.pre_execute(context=context)

try:
Expand Down
8 changes: 0 additions & 8 deletions airflow/operators/python.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator, SkipMixin
from airflow.utils.decorators import apply_defaults
from airflow.utils.operator_helpers import context_to_airflow_vars
from airflow.utils.process_utils import execute_in_subprocess
from airflow.utils.python_virtualenv import prepare_virtualenv

Expand Down Expand Up @@ -127,13 +126,6 @@ def determine_op_kwargs(python_callable: Callable,
return op_kwargs

def execute(self, context: Dict):
# Export context to make it available for callables to use.
airflow_context_vars = context_to_airflow_vars(context, in_env_var_format=True)
self.log.debug("Exporting the following env vars:\n%s",
'\n'.join(["{}={}".format(k, v)
for k, v in airflow_context_vars.items()]))
os.environ.update(airflow_context_vars)

context.update(self.op_kwargs)
context['templates_dict'] = self.templates_dict

Expand Down
36 changes: 33 additions & 3 deletions tests/models/test_taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# under the License.

import datetime
import os
import time
import unittest
import urllib
Expand Down Expand Up @@ -1220,9 +1221,9 @@ def test_set_duration_empty_dates(self):
ti.set_duration()
self.assertIsNone(ti.duration)

def test_success_callbak_no_race_condition(self):
def test_success_callback_no_race_condition(self):
callback_wrapper = CallbackWrapper()
dag = DAG('test_success_callbak_no_race_condition', start_date=DEFAULT_DATE,
dag = DAG('test_success_callback_no_race_condition', start_date=DEFAULT_DATE,
end_date=DEFAULT_DATE + datetime.timedelta(days=10))
task = DummyOperator(task_id='op', email='test@test.test',
on_success_callback=callback_wrapper.success_handler, dag=dag)
Expand Down Expand Up @@ -1445,7 +1446,7 @@ def on_execute_callable(context):
'test_dagrun_execute_callback'
)

dag = DAG('test_execute_callbak', start_date=DEFAULT_DATE,
dag = DAG('test_execute_callback', start_date=DEFAULT_DATE,
end_date=DEFAULT_DATE + datetime.timedelta(days=10))
task = DummyOperator(task_id='op', email='test@test.test',
on_execute_callback=on_execute_callable,
Expand Down Expand Up @@ -1496,6 +1497,35 @@ def test_handle_failure(self):
context_arg_2 = mock_on_retry_2.call_args[0][0]
assert context_arg_2 and "task_instance" in context_arg_2

def _env_var_check_callback(self):
self.assertEqual('test_echo_env_variables', os.environ['AIRFLOW_CTX_DAG_ID'])
self.assertEqual('hive_in_python_op', os.environ['AIRFLOW_CTX_TASK_ID'])
self.assertEqual(DEFAULT_DATE.isoformat(),
os.environ['AIRFLOW_CTX_EXECUTION_DATE'])
self.assertEqual('manual__' + DEFAULT_DATE.isoformat(),
os.environ['AIRFLOW_CTX_DAG_RUN_ID'])

def test_echo_env_variables(self):
dag = DAG('test_echo_env_variables', start_date=DEFAULT_DATE,
end_date=DEFAULT_DATE + datetime.timedelta(days=10))
op = PythonOperator(task_id='hive_in_python_op',
dag=dag,
python_callable=self._env_var_check_callback)
dag.create_dagrun(
run_id='manual__' + DEFAULT_DATE.isoformat(),
execution_date=DEFAULT_DATE,
start_date=DEFAULT_DATE,
state=State.RUNNING,
external_trigger=False)
ti = TI(task=op, execution_date=DEFAULT_DATE)
ti.state = State.RUNNING
session = settings.Session()
session.merge(ti)
session.commit()
ti._run_raw_task()
ti.refresh_from_db()
self.assertEqual(ti.state, State.SUCCESS)


@pytest.mark.parametrize("pool_override", [None, "test_pool2"])
def test_refresh_from_task(pool_override):
Expand Down
33 changes: 0 additions & 33 deletions tests/operators/test_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import copy
import datetime
import logging
import os
import sys
import unittest
import unittest.mock
Expand Down Expand Up @@ -120,12 +119,6 @@ def _assert_calls_equal(self, first, second):
self.assertDictEqual(first.kwargs, second.kwargs)


@unittest.mock.patch('os.environ', {
'AIRFLOW_CTX_DAG_ID': None,
'AIRFLOW_CTX_TASK_ID': None,
'AIRFLOW_CTX_EXECUTION_DATE': None,
'AIRFLOW_CTX_DAG_RUN_ID': None
})
class TestPythonOperator(TestPythonBase):

def do_run(self):
Expand Down Expand Up @@ -249,32 +242,6 @@ def test_python_operator_shallow_copy_attr(self):
self.assertEqual(id(original_task.python_callable),
id(new_task.python_callable))

def _env_var_check_callback(self):
self.assertEqual('test_dag', os.environ['AIRFLOW_CTX_DAG_ID'])
self.assertEqual('hive_in_python_op', os.environ['AIRFLOW_CTX_TASK_ID'])
self.assertEqual(DEFAULT_DATE.isoformat(),
os.environ['AIRFLOW_CTX_EXECUTION_DATE'])
self.assertEqual('manual__' + DEFAULT_DATE.isoformat(),
os.environ['AIRFLOW_CTX_DAG_RUN_ID'])

def test_echo_env_variables(self):
"""
Test that env variables are exported correctly to the
python callback in the task.
"""
self.dag.create_dagrun(
run_id='manual__' + DEFAULT_DATE.isoformat(),
execution_date=DEFAULT_DATE,
start_date=DEFAULT_DATE,
state=State.RUNNING,
external_trigger=False,
)

op = PythonOperator(task_id='hive_in_python_op',
dag=self.dag,
python_callable=self._env_var_check_callback)
op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

def test_conflicting_kwargs(self):
self.dag.create_dagrun(
run_id='manual__' + DEFAULT_DATE.isoformat(),
Expand Down