diff --git a/tests/contrib/operators/test_sftp_operator.py b/tests/contrib/operators/test_sftp_operator.py index 662b2fad1a588..3d4c6a7e70e2a 100644 --- a/tests/contrib/operators/test_sftp_operator.py +++ b/tests/contrib/operators/test_sftp_operator.py @@ -20,8 +20,9 @@ import os import unittest from base64 import b64encode +from unittest import mock -from airflow import models +from airflow import AirflowException, models from airflow.contrib.operators.sftp_operator import SFTPOperation, SFTPOperator from airflow.contrib.operators.ssh_operator import SSHOperator from airflow.models import DAG, TaskInstance @@ -32,6 +33,7 @@ TEST_DAG_ID = 'unit_tests' DEFAULT_DATE = datetime(2017, 1, 1) +TEST_CONN_ID = "conn_id_for_testing" def reset(dag_id=TEST_DAG_ID): @@ -366,11 +368,8 @@ def test_file_transfer_with_intermediate_dir_error_get(self): content_received = file.read() self.assertEqual(content_received.strip(), test_remote_file_content) + @mock.patch.dict('os.environ', {'AIRFLOW_CONN_' + TEST_CONN_ID.upper(): "ssh://test_id@localhost"}) def test_arg_checking(self): - from airflow.exceptions import AirflowException - conn_id = "conn_id_for_testing" - os.environ['AIRFLOW_CONN_' + conn_id.upper()] = "ssh://test_id@localhost" - # Exception should be raised if neither ssh_hook nor ssh_conn_id is provided with self.assertRaisesRegex(AirflowException, "Cannot operate without ssh_hook or ssh_conn_id."): @@ -387,7 +386,7 @@ def test_arg_checking(self): task_1 = SFTPOperator( task_id="test_sftp", ssh_hook="string_rather_than_SSHHook", # invalid ssh_hook - ssh_conn_id=conn_id, + ssh_conn_id=TEST_CONN_ID, local_filepath=self.test_local_filepath, remote_filepath=self.test_remote_filepath, operation=SFTPOperation.PUT, @@ -397,11 +396,11 @@ def test_arg_checking(self): task_1.execute(None) except Exception: pass - self.assertEqual(task_1.ssh_hook.ssh_conn_id, conn_id) + self.assertEqual(task_1.ssh_hook.ssh_conn_id, TEST_CONN_ID) task_2 = SFTPOperator( task_id="test_sftp", - ssh_conn_id=conn_id, # no ssh_hook provided + ssh_conn_id=TEST_CONN_ID, # no ssh_hook provided local_filepath=self.test_local_filepath, remote_filepath=self.test_remote_filepath, operation=SFTPOperation.PUT, @@ -411,13 +410,13 @@ def test_arg_checking(self): task_2.execute(None) except Exception: pass - self.assertEqual(task_2.ssh_hook.ssh_conn_id, conn_id) + self.assertEqual(task_2.ssh_hook.ssh_conn_id, TEST_CONN_ID) # if both valid ssh_hook and ssh_conn_id are provided, ignore ssh_conn_id task_3 = SFTPOperator( task_id="test_sftp", ssh_hook=self.hook, - ssh_conn_id=conn_id, + ssh_conn_id=TEST_CONN_ID, local_filepath=self.test_local_filepath, remote_filepath=self.test_remote_filepath, operation=SFTPOperation.PUT, diff --git a/tests/contrib/operators/test_ssh_operator.py b/tests/contrib/operators/test_ssh_operator.py index e8cb8d00d72a2..aa1291112ae16 100644 --- a/tests/contrib/operators/test_ssh_operator.py +++ b/tests/contrib/operators/test_ssh_operator.py @@ -18,9 +18,10 @@ # under the License. import unittest +import unittest.mock from base64 import b64encode -from airflow import models +from airflow import AirflowException, models from airflow.contrib.operators.ssh_operator import SSHOperator from airflow.models import DAG, TaskInstance from airflow.settings import Session @@ -29,6 +30,8 @@ from tests.test_utils.config import conf_vars TEST_DAG_ID = 'unit_tests' +TEST_CONN_ID = "conn_id_for_testing" +TIMEOUT = 5 DEFAULT_DATE = datetime(2017, 1, 1) @@ -148,13 +151,10 @@ def test_no_output_command(self): self.assertIsNotNone(ti.duration) self.assertEqual(ti.xcom_pull(task_ids='test', key='return_value'), b'') + @unittest.mock.patch('os.environ', { + 'AIRFLOW_CONN_' + TEST_CONN_ID.upper(): "ssh://test_id@localhost" + }) def test_arg_checking(self): - import os - from airflow.exceptions import AirflowException - conn_id = "conn_id_for_testing" - TIMEOUT = 5 - os.environ['AIRFLOW_CONN_' + conn_id.upper()] = "ssh://test_id@localhost" - # Exception should be raised if neither ssh_hook nor ssh_conn_id is provided with self.assertRaisesRegex(AirflowException, "Cannot operate without ssh_hook or ssh_conn_id."): @@ -166,7 +166,7 @@ def test_arg_checking(self): task_1 = SSHOperator( task_id="test_1", ssh_hook="string_rather_than_SSHHook", # invalid ssh_hook - ssh_conn_id=conn_id, + ssh_conn_id=TEST_CONN_ID, command="echo -n airflow", timeout=TIMEOUT, dag=self.dag @@ -175,11 +175,11 @@ def test_arg_checking(self): task_1.execute(None) except Exception: pass - self.assertEqual(task_1.ssh_hook.ssh_conn_id, conn_id) + self.assertEqual(task_1.ssh_hook.ssh_conn_id, TEST_CONN_ID) task_2 = SSHOperator( task_id="test_2", - ssh_conn_id=conn_id, # no ssh_hook provided + ssh_conn_id=TEST_CONN_ID, # no ssh_hook provided command="echo -n airflow", timeout=TIMEOUT, dag=self.dag @@ -188,13 +188,13 @@ def test_arg_checking(self): task_2.execute(None) except Exception: pass - self.assertEqual(task_2.ssh_hook.ssh_conn_id, conn_id) + self.assertEqual(task_2.ssh_hook.ssh_conn_id, TEST_CONN_ID) # if both valid ssh_hook and ssh_conn_id are provided, ignore ssh_conn_id task_3 = SSHOperator( task_id="test_3", ssh_hook=self.hook, - ssh_conn_id=conn_id, + ssh_conn_id=TEST_CONN_ID, command="echo -n airflow", timeout=TIMEOUT, dag=self.dag diff --git a/tests/contrib/utils/test_sendgrid.py b/tests/contrib/utils/test_sendgrid.py index bc76b8048c91c..c75736311c4f3 100644 --- a/tests/contrib/utils/test_sendgrid.py +++ b/tests/contrib/utils/test_sendgrid.py @@ -64,7 +64,7 @@ def setUp(self): } # Test the right email is constructed. - @mock.patch('os.environ', dict(os.environ, SENDGRID_MAIL_FROM='foo@bar.com')) + @mock.patch.dict('os.environ', SENDGRID_MAIL_FROM='foo@bar.com') @mock.patch('airflow.contrib.utils.sendgrid._post_sendgrid_mail') def test_send_email_sendgrid_correct_email(self, mock_post): with tempfile.NamedTemporaryFile(mode='wt', suffix='.txt') as f: @@ -92,11 +92,10 @@ def test_send_email_sendgrid_correct_email(self, mock_post): mock_post.assert_called_once_with(expected_mail_data) # Test the right email is constructed. - @mock.patch( + @mock.patch.dict( 'os.environ', - dict(os.environ, - SENDGRID_MAIL_FROM='foo@bar.com', - SENDGRID_MAIL_SENDER='Foo') + SENDGRID_MAIL_FROM='foo@bar.com', + SENDGRID_MAIL_SENDER='Foo' ) @mock.patch('airflow.contrib.utils.sendgrid._post_sendgrid_mail') def test_send_email_sendgrid_correct_email_extras(self, mock_post): @@ -105,7 +104,7 @@ def test_send_email_sendgrid_correct_email_extras(self, mock_post): categories=self.categories) mock_post.assert_called_once_with(self.expected_mail_data_extras) - @mock.patch('os.environ', {}) + @mock.patch.dict('os.environ', clear=True) @mock.patch('airflow.contrib.utils.sendgrid._post_sendgrid_mail') def test_send_email_sendgrid_sender(self, mock_post): send_email(self.recepients, self.subject, self.html_content, cc=self.carbon_copy, bcc=self.bcc, diff --git a/tests/core.py b/tests/core.py index 91645cefee757..a932393d1a67c 100644 --- a/tests/core.py +++ b/tests/core.py @@ -809,26 +809,20 @@ def test_config_throw_error_when_original_and_fallback_is_absent(self): def test_config_override_original_when_non_empty_envvar_is_provided(self): key = "AIRFLOW__CORE__FERNET_KEY" value = "some value" - self.assertNotIn(key, os.environ) - os.environ[key] = value - FERNET_KEY = conf.get('core', 'FERNET_KEY') - self.assertEqual(value, FERNET_KEY) + with unittest.mock.patch.dict('os.environ', {key: value}): + FERNET_KEY = conf.get('core', 'FERNET_KEY') - # restore the envvar back to the original state - del os.environ[key] + self.assertEqual(value, FERNET_KEY) def test_config_override_original_when_empty_envvar_is_provided(self): key = "AIRFLOW__CORE__FERNET_KEY" value = "" - self.assertNotIn(key, os.environ) - os.environ[key] = value - FERNET_KEY = conf.get('core', 'FERNET_KEY') - self.assertEqual(value, FERNET_KEY) + with unittest.mock.patch.dict('os.environ', {key: value}): + FERNET_KEY = conf.get('core', 'FERNET_KEY') - # restore the envvar back to the original state - del os.environ[key] + self.assertEqual(value, FERNET_KEY) def test_round_time(self): @@ -2042,17 +2036,10 @@ def get_conn(self): class TestConnection(unittest.TestCase): def setUp(self): utils.db.initdb() - os.environ['AIRFLOW_CONN_TEST_URI'] = ( - 'postgres://username:password@ec2.compute.com:5432/the_database') - os.environ['AIRFLOW_CONN_TEST_URI_NO_CREDS'] = ( - 'postgres://ec2.compute.com/the_database') - - def tearDown(self): - env_vars = ['AIRFLOW_CONN_TEST_URI', 'AIRFLOW_CONN_AIRFLOW_DB'] - for ev in env_vars: - if ev in os.environ: - del os.environ[ev] + @unittest.mock.patch.dict('os.environ', { + 'AIRFLOW_CONN_TEST_URI': 'postgres://username:password@ec2.compute.com:5432/the_database', + }) def test_using_env_var(self): c = SqliteHook.get_connection(conn_id='test_uri') self.assertEqual('ec2.compute.com', c.host) @@ -2061,6 +2048,9 @@ def test_using_env_var(self): self.assertEqual('password', c.password) self.assertEqual(5432, c.port) + @unittest.mock.patch.dict('os.environ', { + 'AIRFLOW_CONN_TEST_URI_NO_CREDS': 'postgres://ec2.compute.com/the_database', + }) def test_using_unix_socket_env_var(self): c = SqliteHook.get_connection(conn_id='test_uri_no_creds') self.assertEqual('ec2.compute.com', c.host) @@ -2083,16 +2073,20 @@ def test_env_var_priority(self): c = SqliteHook.get_connection(conn_id='airflow_db') self.assertNotEqual('ec2.compute.com', c.host) - os.environ['AIRFLOW_CONN_AIRFLOW_DB'] = \ - 'postgres://username:password@ec2.compute.com:5432/the_database' - c = SqliteHook.get_connection(conn_id='airflow_db') - self.assertEqual('ec2.compute.com', c.host) - self.assertEqual('the_database', c.schema) - self.assertEqual('username', c.login) - self.assertEqual('password', c.password) - self.assertEqual(5432, c.port) - del os.environ['AIRFLOW_CONN_AIRFLOW_DB'] - + with unittest.mock.patch.dict('os.environ', { + 'AIRFLOW_CONN_AIRFLOW_DB': 'postgres://username:password@ec2.compute.com:5432/the_database', + }): + c = SqliteHook.get_connection(conn_id='airflow_db') + self.assertEqual('ec2.compute.com', c.host) + self.assertEqual('the_database', c.schema) + self.assertEqual('username', c.login) + self.assertEqual('password', c.password) + self.assertEqual(5432, c.port) + + @unittest.mock.patch.dict('os.environ', { + 'AIRFLOW_CONN_TEST_URI': 'postgres://username:password@ec2.compute.com:5432/the_database', + 'AIRFLOW_CONN_TEST_URI_NO_CREDS': 'postgres://ec2.compute.com/the_database', + }) def test_dbapi_get_uri(self): conn = BaseHook.get_connection(conn_id='test_uri') hook = conn.get_hook() @@ -2101,6 +2095,10 @@ def test_dbapi_get_uri(self): hook2 = conn2.get_hook() self.assertEqual('postgres://ec2.compute.com/the_database', hook2.get_uri()) + @unittest.mock.patch.dict('os.environ', { + 'AIRFLOW_CONN_TEST_URI': 'postgres://username:password@ec2.compute.com:5432/the_database', + 'AIRFLOW_CONN_TEST_URI_NO_CREDS': 'postgres://ec2.compute.com/the_database', + }) def test_dbapi_get_sqlalchemy_engine(self): conn = BaseHook.get_connection(conn_id='test_uri') hook = conn.get_hook() @@ -2108,6 +2106,10 @@ def test_dbapi_get_sqlalchemy_engine(self): self.assertIsInstance(engine, sqlalchemy.engine.Engine) self.assertEqual('postgres://username:password@ec2.compute.com:5432/the_database', str(engine.url)) + @unittest.mock.patch.dict('os.environ', { + 'AIRFLOW_CONN_TEST_URI': 'postgres://username:password@ec2.compute.com:5432/the_database', + 'AIRFLOW_CONN_TEST_URI_NO_CREDS': 'postgres://ec2.compute.com/the_database', + }) def test_get_connections_env_var(self): conns = SqliteHook.get_connections(conn_id='test_uri') assert len(conns) == 1 @@ -2137,9 +2139,9 @@ def test_init_proxy_user(self): @unittest.skipIf(HDFSHook is None, "Skipping test because HDFSHook is not installed") class TestHDFSHook(unittest.TestCase): - def setUp(self): - os.environ['AIRFLOW_CONN_HDFS_DEFAULT'] = 'hdfs://localhost:8020' - + @unittest.mock.patch.dict('os.environ', { + 'AIRFLOW_CONN_HDFS_DEFAULT': 'hdfs://localhost:8020', + }) def test_get_client(self): client = HDFSHook(proxy_user='foo').get_conn() self.assertIsInstance(client, snakebite.client.Client) @@ -2147,6 +2149,9 @@ def test_get_client(self): self.assertEqual(8020, client.port) self.assertEqual('foo', client.service.channel.effective_user) + @unittest.mock.patch.dict('os.environ', { + 'AIRFLOW_CONN_HDFS_DEFAULT': 'hdfs://localhost:8020', + }) @mock.patch('airflow.hooks.hdfs_hook.AutoConfigClient') @mock.patch('airflow.hooks.hdfs_hook.HDFSHook.get_connections') def test_get_autoconfig_client(self, mock_get_connections, @@ -2159,6 +2164,9 @@ def test_get_autoconfig_client(self, mock_get_connections, MockAutoConfigClient.assert_called_once_with(effective_user='foo', use_sasl=False) + @unittest.mock.patch.dict('os.environ', { + 'AIRFLOW_CONN_HDFS_DEFAULT': 'hdfs://localhost:8020', + }) @mock.patch('airflow.hooks.hdfs_hook.AutoConfigClient') def test_get_autoconfig_client_no_conn(self, MockAutoConfigClient): HDFSHook(hdfs_conn_id='hdfs_missing', autoconfig=True).get_conn() diff --git a/tests/gcp/operators/test_kubernetes_engine.py b/tests/gcp/operators/test_kubernetes_engine.py index bcd61694747a1..43c29ed17e0bb 100644 --- a/tests/gcp/operators/test_kubernetes_engine.py +++ b/tests/gcp/operators/test_kubernetes_engine.py @@ -145,8 +145,6 @@ def setUp(self): name=TASK_NAME, namespace=NAMESPACE, image=IMAGE) - if CREDENTIALS in os.environ: - del os.environ[CREDENTIALS] def test_template_fields(self): self.assertTrue(set(KubernetesPodOperator.template_fields).issubset( diff --git a/tests/hooks/test_hive_hook.py b/tests/hooks/test_hive_hook.py index 52af425917e02..f1e4007c5a3c3 100644 --- a/tests/hooks/test_hive_hook.py +++ b/tests/hooks/test_hive_hook.py @@ -107,23 +107,19 @@ def test_run_cli_with_hive_conf(self): dag_run_id_ctx_var_name = \ AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_DAG_RUN_ID'][ 'env_var_format'] - os.environ[dag_id_ctx_var_name] = 'test_dag_id' - os.environ[task_id_ctx_var_name] = 'test_task_id' - os.environ[execution_date_ctx_var_name] = 'test_execution_date' - os.environ[dag_run_id_ctx_var_name] = 'test_dag_run_id' - - hook = HiveCliHook() - output = hook.run_cli(hql=hql, hive_conf={'key': 'value'}) - self.assertIn('value', output) - self.assertIn('test_dag_id', output) - self.assertIn('test_task_id', output) - self.assertIn('test_execution_date', output) - self.assertIn('test_dag_run_id', output) - - del os.environ[dag_id_ctx_var_name] - del os.environ[task_id_ctx_var_name] - del os.environ[execution_date_ctx_var_name] - del os.environ[dag_run_id_ctx_var_name] + with mock.patch.dict('os.environ', { + dag_id_ctx_var_name: 'test_dag_id', + task_id_ctx_var_name: 'test_task_id', + execution_date_ctx_var_name: 'test_execution_date', + dag_run_id_ctx_var_name: 'test_dag_run_id', + }): + hook = HiveCliHook() + output = hook.run_cli(hql=hql, hive_conf={'key': 'value'}) + self.assertIn('value', output) + self.assertIn('test_dag_id', output) + self.assertIn('test_task_id', output) + self.assertIn('test_execution_date', output) + self.assertIn('test_dag_run_id', output) @mock.patch('airflow.hooks.hive_hooks.HiveCliHook.run_cli') def test_load_file(self, mock_run_cli): @@ -415,21 +411,20 @@ def test_get_conn_with_password(self, mock_connect): from airflow.hooks.base_hook import CONN_ENV_PREFIX conn_id = "conn_with_password" conn_env = CONN_ENV_PREFIX + conn_id.upper() - conn_value = os.environ.get(conn_env) - os.environ[conn_env] = "jdbc+hive2://conn_id:conn_pass@localhost:10000/default?authMechanism=LDAP" - - HiveServer2Hook(hiveserver2_conn_id=conn_id).get_conn() - mock_connect.assert_called_once_with( - host='localhost', - port=10000, - auth='LDAP', - kerberos_service_name=None, - username='conn_id', - password='conn_pass', - database='default') - - if conn_value: - os.environ[conn_env] = conn_value + + with mock.patch.dict( + 'os.environ', + {conn_env: "jdbc+hive2://conn_id:conn_pass@localhost:10000/default?authMechanism=LDAP"} + ): + HiveServer2Hook(hiveserver2_conn_id=conn_id).get_conn() + mock_connect.assert_called_once_with( + host='localhost', + port=10000, + auth='LDAP', + kerberos_service_name=None, + username='conn_id', + password='conn_pass', + database='default') def test_get_records(self): hook = HiveServer2Hook() @@ -504,18 +499,20 @@ def test_get_results_with_hive_conf(self): os.environ[execution_date_ctx_var_name] = 'test_execution_date' os.environ[dag_run_id_ctx_var_name] = 'test_dag_run_id' - hook = HiveServer2Hook() - output = '\n'.join(res_tuple[0] - for res_tuple - in hook.get_results(hql=hql, - hive_conf={'key': 'value'})['data']) + with mock.patch.dict('os.environ', { + dag_id_ctx_var_name: 'test_dag_id', + task_id_ctx_var_name: 'test_task_id', + execution_date_ctx_var_name: 'test_execution_date', + dag_run_id_ctx_var_name: 'test_dag_run_id', + + }): + hook = HiveServer2Hook() + output = '\n'.join(res_tuple[0] + for res_tuple + in hook.get_results(hql=hql, + hive_conf={'key': 'value'})['data']) self.assertIn('value', output) self.assertIn('test_dag_id', output) self.assertIn('test_task_id', output) self.assertIn('test_execution_date', output) self.assertIn('test_dag_run_id', output) - - del os.environ[dag_id_ctx_var_name] - del os.environ[task_id_ctx_var_name] - del os.environ[execution_date_ctx_var_name] - del os.environ[dag_run_id_ctx_var_name] diff --git a/tests/operators/test_http_operator.py b/tests/operators/test_http_operator.py index 1555d44f3d667..36c7b4020ed5d 100644 --- a/tests/operators/test_http_operator.py +++ b/tests/operators/test_http_operator.py @@ -17,7 +17,6 @@ # specific language governing permissions and limitations # under the License. -import os import unittest import requests_mock @@ -27,9 +26,8 @@ from tests.compat import mock +@mock.patch.dict('os.environ', AIRFLOW_CONN_HTTP_EXAMPLE='http://www.example.com') class TestSimpleHttpOp(unittest.TestCase): - def setUp(self): - os.environ['AIRFLOW_CONN_HTTP_EXAMPLE'] = 'http://www.example.com' @requests_mock.mock() def test_response_in_logs(self, m): diff --git a/tests/operators/test_python_operator.py b/tests/operators/test_python_operator.py index 7cbe080b85162..8d30492c5751c 100644 --- a/tests/operators/test_python_operator.py +++ b/tests/operators/test_python_operator.py @@ -21,6 +21,7 @@ import logging import os import unittest +import unittest.mock from collections import namedtuple from datetime import date, timedelta @@ -61,6 +62,12 @@ def recording_function(*args): return recording_function +@unittest.mock.patch('os.environ', { + 'AIRFLOW_CTX_DAG_ID': None, + 'AIRFLOW_CTX_TASK_ID': None, + 'AIRFLOW_CTX_EXECUTION_DATE': None, + 'AIRFLOW_CTX_DAG_RUN_ID': None +}) class TestPythonOperator(unittest.TestCase): @classmethod def setUpClass(cls): @@ -89,10 +96,6 @@ def tearDown(self): session.query(DagRun).delete() session.query(TI).delete() - for var in TI_CONTEXT_ENV_VARS: - if var in os.environ: - del os.environ[var] - def do_run(self): self.run = True diff --git a/tests/test_configuration.py b/tests/test_configuration.py index 182158a0f9c55..79c4ca4c526ed 100644 --- a/tests/test_configuration.py +++ b/tests/test_configuration.py @@ -17,7 +17,6 @@ # specific language governing permissions and limitations # under the License. -import contextlib import os import unittest import warnings @@ -28,56 +27,40 @@ from airflow.configuration import AirflowConfigParser, conf, parameterized_config -@contextlib.contextmanager -def env_vars(**vars): - original = {} - for key, value in vars.items(): - original[key] = os.environ.get(key) - if value is not None: - os.environ[key] = value - else: - os.environ.pop(key, None) - yield - for key, value in original.items(): - if value is not None: - os.environ[key] = value - else: - os.environ.pop(key, None) - - +@unittest.mock.patch.dict('os.environ', { + 'AIRFLOW__TESTSECTION__TESTKEY': 'testvalue', + 'AIRFLOW__TESTSECTION__TESTPERCENT': 'with%percent' +}) class TestConf(unittest.TestCase): @classmethod def setUpClass(cls): - os.environ['AIRFLOW__TESTSECTION__TESTKEY'] = 'testvalue' - os.environ['AIRFLOW__TESTSECTION__TESTPERCENT'] = 'with%percent' conf.set('core', 'percent', 'with%%inside') - @classmethod - def tearDownClass(cls): - del os.environ['AIRFLOW__TESTSECTION__TESTKEY'] - del os.environ['AIRFLOW__TESTSECTION__TESTPERCENT'] - def test_airflow_home_default(self): - with env_vars(AIRFLOW_HOME=None): + with unittest.mock.patch.dict('os.environ'): + if 'AIRFLOW_HOME' in os.environ: + del os.environ['AIRFLOW_HOME'] self.assertEqual( configuration.get_airflow_home(), configuration.expand_env_var('~/airflow')) def test_airflow_home_override(self): - with env_vars(AIRFLOW_HOME='/path/to/airflow'): + with unittest.mock.patch.dict('os.environ', AIRFLOW_HOME='/path/to/airflow'): self.assertEqual( configuration.get_airflow_home(), '/path/to/airflow') def test_airflow_config_default(self): - with env_vars(AIRFLOW_CONFIG=None): + with unittest.mock.patch.dict('os.environ'): + if 'AIRFLOW_CONFIG' in os.environ: + del os.environ['AIRFLOW_CONFIG'] self.assertEqual( configuration.get_airflow_config('/home/airflow'), configuration.expand_env_var('/home/airflow/airflow.cfg')) def test_airflow_config_override(self): - with env_vars(AIRFLOW_CONFIG='/path/to/airflow/airflow.cfg'): + with unittest.mock.patch.dict('os.environ', AIRFLOW_CONFIG='/path/to/airflow/airflow.cfg'): self.assertEqual( configuration.get_airflow_config('/home//airflow'), '/path/to/airflow/airflow.cfg') @@ -98,13 +81,18 @@ def test_env_var_config(self): self.assertTrue(conf.has_option('testsection', 'testkey')) - os.environ['AIRFLOW__KUBERNETES_ENVIRONMENT_VARIABLES__AIRFLOW__TESTSECTION__TESTKEY'] = 'nested' - opt = conf.get('kubernetes_environment_variables', 'AIRFLOW__TESTSECTION__TESTKEY') - self.assertEqual(opt, 'nested') - del os.environ['AIRFLOW__KUBERNETES_ENVIRONMENT_VARIABLES__AIRFLOW__TESTSECTION__TESTKEY'] - + with unittest.mock.patch.dict( + 'os.environ', + AIRFLOW__KUBERNETES_ENVIRONMENT_VARIABLES__AIRFLOW__TESTSECTION__TESTKEY='nested' + ): + opt = conf.get('kubernetes_environment_variables', 'AIRFLOW__TESTSECTION__TESTKEY') + self.assertEqual(opt, 'nested') + + @mock.patch.dict( + 'os.environ', + AIRFLOW__KUBERNETES_ENVIRONMENT_VARIABLES__AIRFLOW__TESTSECTION__TESTKEY='nested' + ) def test_conf_as_dict(self): - os.environ['AIRFLOW__KUBERNETES_ENVIRONMENT_VARIABLES__AIRFLOW__TESTSECTION__TESTKEY'] = 'nested' cfg_dict = conf.as_dict() # test that configs are picked up @@ -117,7 +105,6 @@ def test_conf_as_dict(self): self.assertEqual( cfg_dict['kubernetes_environment_variables']['AIRFLOW__TESTSECTION__TESTKEY'], '< hidden >') - del os.environ['AIRFLOW__KUBERNETES_ENVIRONMENT_VARIABLES__AIRFLOW__TESTSECTION__TESTKEY'] def test_conf_as_dict_source(self): # test display_source @@ -364,9 +351,8 @@ def test_deprecated_options(self): conf.remove_option('celery', 'worker_concurrency') with self.assertWarns(DeprecationWarning): - os.environ['AIRFLOW__CELERY__CELERYD_CONCURRENCY'] = '99' - self.assertEqual(conf.getint('celery', 'worker_concurrency'), 99) - os.environ.pop('AIRFLOW__CELERY__CELERYD_CONCURRENCY') + with mock.patch.dict('os.environ', AIRFLOW__CELERY__CELERYD_CONCURRENCY="99"): + self.assertEqual(conf.getint('celery', 'worker_concurrency'), 99) with self.assertWarns(DeprecationWarning): conf.set('celery', 'celeryd_concurrency', '99') @@ -414,13 +400,13 @@ def make_config(): self.assertEqual(test_conf.get('core', 'task_runner'), 'StandardTaskRunner') with self.assertWarns(FutureWarning): - with env_vars(AIRFLOW__CORE__TASK_RUNNER='BashTaskRunner'): + with unittest.mock.patch.dict('os.environ', AIRFLOW__CORE__TASK_RUNNER='BashTaskRunner'): test_conf = make_config() self.assertEqual(test_conf.get('core', 'task_runner'), 'StandardTaskRunner') with warnings.catch_warnings(record=True) as w: - with env_vars(AIRFLOW__CORE__TASK_RUNNER='NotBashTaskRunner'): + with unittest.mock.patch.dict('os.environ', AIRFLOW__CORE__TASK_RUNNER='NotBashTaskRunner'): test_conf = make_config() self.assertEqual(test_conf.get('core', 'task_runner'), 'NotBashTaskRunner') diff --git a/tests/test_impersonation.py b/tests/test_impersonation.py index 5ae6d1f206cc3..cc723d9bd480d 100644 --- a/tests/test_impersonation.py +++ b/tests/test_impersonation.py @@ -161,20 +161,16 @@ def test_no_impersonation(self): 'test_superuser', ) + @unittest.mock.patch.dict('os.environ', AIRFLOW__CORE__DEFAULT_IMPERSONATION=TEST_USER) def test_default_impersonation(self): """ If default_impersonation=TEST_USER, tests that the job defaults to running as TEST_USER for a test without run_as_user set """ - os.environ['AIRFLOW__CORE__DEFAULT_IMPERSONATION'] = TEST_USER - - try: - self.run_backfill( - 'test_default_impersonation', - 'test_deelevated_user' - ) - finally: - del os.environ['AIRFLOW__CORE__DEFAULT_IMPERSONATION'] + self.run_backfill( + 'test_default_impersonation', + 'test_deelevated_user' + ) def test_impersonation_subdag(self): """