diff --git a/airflow/bin/cli.py b/airflow/bin/cli.py index 86587a8a4dc7..0bc4b7132138 100644 --- a/airflow/bin/cli.py +++ b/airflow/bin/cli.py @@ -19,6 +19,7 @@ """Command-line interface""" import argparse +import json import os import textwrap from argparse import RawTextHelpFormatter @@ -459,6 +460,10 @@ class CLIFactory: action="store_true", help="Open debugger on uncaught exception", ), + 'env_vars': Arg( + ("--env-vars", ), + help="Set env var in both parsing time and runtime for each of entry supplied in a JSON dict", + type=json.loads), # connections 'conn_id': Arg( ('conn_id',), @@ -734,7 +739,7 @@ class CLIFactory: "dependencies or recording its state in the database"), 'args': ( 'dag_id', 'task_id', 'execution_date', 'subdir', 'dry_run', - 'task_params', 'post_mortem'), + 'task_params', 'post_mortem', 'env_vars'), }, { 'func': lazy_load_command('airflow.cli.commands.task_command.task_states_for_dag_run'), diff --git a/airflow/cli/commands/task_command.py b/airflow/cli/commands/task_command.py index de1aacc8f981..1c05357612c5 100644 --- a/airflow/cli/commands/task_command.py +++ b/airflow/cli/commands/task_command.py @@ -320,6 +320,11 @@ def task_test(args, dag=None): if not already_has_stream_handler: logging.getLogger('airflow.task').propagate = True + env_vars = {'AIRFLOW_TEST_MODE': 'True'} + if args.env_vars: + env_vars.update(args.env_vars) + os.environ.update(env_vars) + dag = dag or get_dag(args.subdir, args.dag_id) task = dag.get_task(task_id=args.task_id) diff --git a/airflow/example_dags/example_passing_params_via_test_command.py b/airflow/example_dags/example_passing_params_via_test_command.py index 76ac246f2828..5b56a9ae37eb 100644 --- a/airflow/example_dags/example_passing_params_via_test_command.py +++ b/airflow/example_dags/example_passing_params_via_test_command.py @@ -18,6 +18,7 @@ """Example DAG demonstrating the usage of the params arguments in templated arguments.""" +import os from datetime import timedelta from airflow import DAG @@ -70,4 +71,22 @@ def my_py_command(test_mode, params): dag=dag, ) + +def print_env_vars(test_mode): + """ + Print out the "foo" param passed in via + `airflow tasks test example_passing_params_via_test_command env_var_test_task + --env-vars '{"foo":"bar"}'` + """ + if test_mode: + print("foo={}".format(os.environ.get('foo'))) + print("AIRFLOW_TEST_MODE={}".format(os.environ.get('AIRFLOW_TEST_MODE'))) + + +env_var_test_task = PythonOperator( + task_id='env_var_test_task', + python_callable=print_env_vars, + dag=dag +) + run_this >> also_run_this diff --git a/tests/cli/commands/test_task_command.py b/tests/cli/commands/test_task_command.py index 2441d3653541..b76de741ba58 100644 --- a/tests/cli/commands/test_task_command.py +++ b/tests/cli/commands/test_task_command.py @@ -119,6 +119,15 @@ def test_cli_test_with_params(self): 'tasks', 'test', 'example_passing_params_via_test_command', 'also_run_this', '--task-params', '{"foo":"bar"}', DEFAULT_DATE.isoformat()])) + def test_cli_test_with_env_vars(self): + with redirect_stdout(io.StringIO()) as stdout: + task_command.task_test(self.parser.parse_args([ + 'tasks', 'test', 'example_passing_params_via_test_command', 'env_var_test_task', + '--env-vars', '{"foo":"bar"}', DEFAULT_DATE.isoformat()])) + output = stdout.getvalue() + self.assertIn('foo=bar', output) + self.assertIn('AIRFLOW_TEST_MODE=True', output) + def test_cli_run(self): task_command.task_run(self.parser.parse_args([ 'tasks', 'run', 'example_bash_operator', 'runme_0', '--local',