Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 9 additions & 10 deletions tests/contrib/operators/test_sftp_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@
import os
import unittest
from base64 import b64encode
from unittest import mock

from airflow import models
from airflow import AirflowException, models
from airflow.contrib.operators.sftp_operator import SFTPOperation, SFTPOperator
from airflow.contrib.operators.ssh_operator import SSHOperator
from airflow.models import DAG, TaskInstance
Expand All @@ -32,6 +33,7 @@

TEST_DAG_ID = 'unit_tests'
DEFAULT_DATE = datetime(2017, 1, 1)
TEST_CONN_ID = "conn_id_for_testing"


def reset(dag_id=TEST_DAG_ID):
Expand Down Expand Up @@ -366,11 +368,8 @@ def test_file_transfer_with_intermediate_dir_error_get(self):
content_received = file.read()
self.assertEqual(content_received.strip(), test_remote_file_content)

@mock.patch.dict('os.environ', {'AIRFLOW_CONN_' + TEST_CONN_ID.upper(): "ssh://test_id@localhost"})
def test_arg_checking(self):
from airflow.exceptions import AirflowException
conn_id = "conn_id_for_testing"
os.environ['AIRFLOW_CONN_' + conn_id.upper()] = "ssh://test_id@localhost"

# Exception should be raised if neither ssh_hook nor ssh_conn_id is provided
with self.assertRaisesRegex(AirflowException,
"Cannot operate without ssh_hook or ssh_conn_id."):
Expand All @@ -387,7 +386,7 @@ def test_arg_checking(self):
task_1 = SFTPOperator(
task_id="test_sftp",
ssh_hook="string_rather_than_SSHHook", # invalid ssh_hook
ssh_conn_id=conn_id,
ssh_conn_id=TEST_CONN_ID,
local_filepath=self.test_local_filepath,
remote_filepath=self.test_remote_filepath,
operation=SFTPOperation.PUT,
Expand All @@ -397,11 +396,11 @@ def test_arg_checking(self):
task_1.execute(None)
except Exception:
pass
self.assertEqual(task_1.ssh_hook.ssh_conn_id, conn_id)
self.assertEqual(task_1.ssh_hook.ssh_conn_id, TEST_CONN_ID)

task_2 = SFTPOperator(
task_id="test_sftp",
ssh_conn_id=conn_id, # no ssh_hook provided
ssh_conn_id=TEST_CONN_ID, # no ssh_hook provided
local_filepath=self.test_local_filepath,
remote_filepath=self.test_remote_filepath,
operation=SFTPOperation.PUT,
Expand All @@ -411,13 +410,13 @@ def test_arg_checking(self):
task_2.execute(None)
except Exception:
pass
self.assertEqual(task_2.ssh_hook.ssh_conn_id, conn_id)
self.assertEqual(task_2.ssh_hook.ssh_conn_id, TEST_CONN_ID)

# if both valid ssh_hook and ssh_conn_id are provided, ignore ssh_conn_id
task_3 = SFTPOperator(
task_id="test_sftp",
ssh_hook=self.hook,
ssh_conn_id=conn_id,
ssh_conn_id=TEST_CONN_ID,
local_filepath=self.test_local_filepath,
remote_filepath=self.test_remote_filepath,
operation=SFTPOperation.PUT,
Expand Down
24 changes: 12 additions & 12 deletions tests/contrib/operators/test_ssh_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@
# under the License.

import unittest
import unittest.mock
from base64 import b64encode

from airflow import models
from airflow import AirflowException, models
from airflow.contrib.operators.ssh_operator import SSHOperator
from airflow.models import DAG, TaskInstance
from airflow.settings import Session
Expand All @@ -29,6 +30,8 @@
from tests.test_utils.config import conf_vars

TEST_DAG_ID = 'unit_tests'
TEST_CONN_ID = "conn_id_for_testing"
TIMEOUT = 5
DEFAULT_DATE = datetime(2017, 1, 1)


Expand Down Expand Up @@ -148,13 +151,10 @@ def test_no_output_command(self):
self.assertIsNotNone(ti.duration)
self.assertEqual(ti.xcom_pull(task_ids='test', key='return_value'), b'')

@unittest.mock.patch('os.environ', {
'AIRFLOW_CONN_' + TEST_CONN_ID.upper(): "ssh://test_id@localhost"
})
def test_arg_checking(self):
import os
from airflow.exceptions import AirflowException
conn_id = "conn_id_for_testing"
TIMEOUT = 5
os.environ['AIRFLOW_CONN_' + conn_id.upper()] = "ssh://test_id@localhost"

# Exception should be raised if neither ssh_hook nor ssh_conn_id is provided
with self.assertRaisesRegex(AirflowException,
"Cannot operate without ssh_hook or ssh_conn_id."):
Expand All @@ -166,7 +166,7 @@ def test_arg_checking(self):
task_1 = SSHOperator(
task_id="test_1",
ssh_hook="string_rather_than_SSHHook", # invalid ssh_hook
ssh_conn_id=conn_id,
ssh_conn_id=TEST_CONN_ID,
command="echo -n airflow",
timeout=TIMEOUT,
dag=self.dag
Expand All @@ -175,11 +175,11 @@ def test_arg_checking(self):
task_1.execute(None)
except Exception:
pass
self.assertEqual(task_1.ssh_hook.ssh_conn_id, conn_id)
self.assertEqual(task_1.ssh_hook.ssh_conn_id, TEST_CONN_ID)

task_2 = SSHOperator(
task_id="test_2",
ssh_conn_id=conn_id, # no ssh_hook provided
ssh_conn_id=TEST_CONN_ID, # no ssh_hook provided
command="echo -n airflow",
timeout=TIMEOUT,
dag=self.dag
Expand All @@ -188,13 +188,13 @@ def test_arg_checking(self):
task_2.execute(None)
except Exception:
pass
self.assertEqual(task_2.ssh_hook.ssh_conn_id, conn_id)
self.assertEqual(task_2.ssh_hook.ssh_conn_id, TEST_CONN_ID)

# if both valid ssh_hook and ssh_conn_id are provided, ignore ssh_conn_id
task_3 = SSHOperator(
task_id="test_3",
ssh_hook=self.hook,
ssh_conn_id=conn_id,
ssh_conn_id=TEST_CONN_ID,
command="echo -n airflow",
timeout=TIMEOUT,
dag=self.dag
Expand Down
11 changes: 5 additions & 6 deletions tests/contrib/utils/test_sendgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def setUp(self):
}

# Test the right email is constructed.
@mock.patch('os.environ', dict(os.environ, SENDGRID_MAIL_FROM='foo@bar.com'))
@mock.patch.dict('os.environ', SENDGRID_MAIL_FROM='foo@bar.com')
@mock.patch('airflow.contrib.utils.sendgrid._post_sendgrid_mail')
def test_send_email_sendgrid_correct_email(self, mock_post):
with tempfile.NamedTemporaryFile(mode='wt', suffix='.txt') as f:
Expand Down Expand Up @@ -92,11 +92,10 @@ def test_send_email_sendgrid_correct_email(self, mock_post):
mock_post.assert_called_once_with(expected_mail_data)

# Test the right email is constructed.
@mock.patch(
@mock.patch.dict(
'os.environ',
dict(os.environ,
SENDGRID_MAIL_FROM='foo@bar.com',
SENDGRID_MAIL_SENDER='Foo')
SENDGRID_MAIL_FROM='foo@bar.com',
SENDGRID_MAIL_SENDER='Foo'
)
@mock.patch('airflow.contrib.utils.sendgrid._post_sendgrid_mail')
def test_send_email_sendgrid_correct_email_extras(self, mock_post):
Expand All @@ -105,7 +104,7 @@ def test_send_email_sendgrid_correct_email_extras(self, mock_post):
categories=self.categories)
mock_post.assert_called_once_with(self.expected_mail_data_extras)

@mock.patch('os.environ', {})
@mock.patch.dict('os.environ', clear=True)
@mock.patch('airflow.contrib.utils.sendgrid._post_sendgrid_mail')
def test_send_email_sendgrid_sender(self, mock_post):
send_email(self.recepients, self.subject, self.html_content, cc=self.carbon_copy, bcc=self.bcc,
Expand Down
78 changes: 43 additions & 35 deletions tests/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -809,26 +809,20 @@ def test_config_throw_error_when_original_and_fallback_is_absent(self):
def test_config_override_original_when_non_empty_envvar_is_provided(self):
key = "AIRFLOW__CORE__FERNET_KEY"
value = "some value"
self.assertNotIn(key, os.environ)

os.environ[key] = value
FERNET_KEY = conf.get('core', 'FERNET_KEY')
self.assertEqual(value, FERNET_KEY)
with unittest.mock.patch.dict('os.environ', {key: value}):
FERNET_KEY = conf.get('core', 'FERNET_KEY')

# restore the envvar back to the original state
del os.environ[key]
self.assertEqual(value, FERNET_KEY)

def test_config_override_original_when_empty_envvar_is_provided(self):
key = "AIRFLOW__CORE__FERNET_KEY"
value = ""
self.assertNotIn(key, os.environ)

os.environ[key] = value
FERNET_KEY = conf.get('core', 'FERNET_KEY')
self.assertEqual(value, FERNET_KEY)
with unittest.mock.patch.dict('os.environ', {key: value}):
FERNET_KEY = conf.get('core', 'FERNET_KEY')

# restore the envvar back to the original state
del os.environ[key]
self.assertEqual(value, FERNET_KEY)

def test_round_time(self):

Expand Down Expand Up @@ -2042,17 +2036,10 @@ def get_conn(self):
class TestConnection(unittest.TestCase):
def setUp(self):
utils.db.initdb()
os.environ['AIRFLOW_CONN_TEST_URI'] = (
'postgres://username:password@ec2.compute.com:5432/the_database')
os.environ['AIRFLOW_CONN_TEST_URI_NO_CREDS'] = (
'postgres://ec2.compute.com/the_database')

def tearDown(self):
env_vars = ['AIRFLOW_CONN_TEST_URI', 'AIRFLOW_CONN_AIRFLOW_DB']
for ev in env_vars:
if ev in os.environ:
del os.environ[ev]

@unittest.mock.patch.dict('os.environ', {
'AIRFLOW_CONN_TEST_URI': 'postgres://username:password@ec2.compute.com:5432/the_database',
})
def test_using_env_var(self):
c = SqliteHook.get_connection(conn_id='test_uri')
self.assertEqual('ec2.compute.com', c.host)
Expand All @@ -2061,6 +2048,9 @@ def test_using_env_var(self):
self.assertEqual('password', c.password)
self.assertEqual(5432, c.port)

@unittest.mock.patch.dict('os.environ', {
'AIRFLOW_CONN_TEST_URI_NO_CREDS': 'postgres://ec2.compute.com/the_database',
})
def test_using_unix_socket_env_var(self):
c = SqliteHook.get_connection(conn_id='test_uri_no_creds')
self.assertEqual('ec2.compute.com', c.host)
Expand All @@ -2083,16 +2073,20 @@ def test_env_var_priority(self):
c = SqliteHook.get_connection(conn_id='airflow_db')
self.assertNotEqual('ec2.compute.com', c.host)

os.environ['AIRFLOW_CONN_AIRFLOW_DB'] = \
'postgres://username:password@ec2.compute.com:5432/the_database'
c = SqliteHook.get_connection(conn_id='airflow_db')
self.assertEqual('ec2.compute.com', c.host)
self.assertEqual('the_database', c.schema)
self.assertEqual('username', c.login)
self.assertEqual('password', c.password)
self.assertEqual(5432, c.port)
del os.environ['AIRFLOW_CONN_AIRFLOW_DB']

with unittest.mock.patch.dict('os.environ', {
'AIRFLOW_CONN_AIRFLOW_DB': 'postgres://username:password@ec2.compute.com:5432/the_database',
}):
c = SqliteHook.get_connection(conn_id='airflow_db')
self.assertEqual('ec2.compute.com', c.host)
self.assertEqual('the_database', c.schema)
self.assertEqual('username', c.login)
self.assertEqual('password', c.password)
self.assertEqual(5432, c.port)

@unittest.mock.patch.dict('os.environ', {
'AIRFLOW_CONN_TEST_URI': 'postgres://username:password@ec2.compute.com:5432/the_database',
'AIRFLOW_CONN_TEST_URI_NO_CREDS': 'postgres://ec2.compute.com/the_database',
})
def test_dbapi_get_uri(self):
conn = BaseHook.get_connection(conn_id='test_uri')
hook = conn.get_hook()
Expand All @@ -2101,13 +2095,21 @@ def test_dbapi_get_uri(self):
hook2 = conn2.get_hook()
self.assertEqual('postgres://ec2.compute.com/the_database', hook2.get_uri())

@unittest.mock.patch.dict('os.environ', {
'AIRFLOW_CONN_TEST_URI': 'postgres://username:password@ec2.compute.com:5432/the_database',
'AIRFLOW_CONN_TEST_URI_NO_CREDS': 'postgres://ec2.compute.com/the_database',
})
def test_dbapi_get_sqlalchemy_engine(self):
conn = BaseHook.get_connection(conn_id='test_uri')
hook = conn.get_hook()
engine = hook.get_sqlalchemy_engine()
self.assertIsInstance(engine, sqlalchemy.engine.Engine)
self.assertEqual('postgres://username:password@ec2.compute.com:5432/the_database', str(engine.url))

@unittest.mock.patch.dict('os.environ', {
'AIRFLOW_CONN_TEST_URI': 'postgres://username:password@ec2.compute.com:5432/the_database',
'AIRFLOW_CONN_TEST_URI_NO_CREDS': 'postgres://ec2.compute.com/the_database',
})
def test_get_connections_env_var(self):
conns = SqliteHook.get_connections(conn_id='test_uri')
assert len(conns) == 1
Expand Down Expand Up @@ -2137,16 +2139,19 @@ def test_init_proxy_user(self):
@unittest.skipIf(HDFSHook is None,
"Skipping test because HDFSHook is not installed")
class TestHDFSHook(unittest.TestCase):
def setUp(self):
os.environ['AIRFLOW_CONN_HDFS_DEFAULT'] = 'hdfs://localhost:8020'

@unittest.mock.patch.dict('os.environ', {
'AIRFLOW_CONN_HDFS_DEFAULT': 'hdfs://localhost:8020',
})
def test_get_client(self):
client = HDFSHook(proxy_user='foo').get_conn()
self.assertIsInstance(client, snakebite.client.Client)
self.assertEqual('localhost', client.host)
self.assertEqual(8020, client.port)
self.assertEqual('foo', client.service.channel.effective_user)

@unittest.mock.patch.dict('os.environ', {
'AIRFLOW_CONN_HDFS_DEFAULT': 'hdfs://localhost:8020',
})
@mock.patch('airflow.hooks.hdfs_hook.AutoConfigClient')
@mock.patch('airflow.hooks.hdfs_hook.HDFSHook.get_connections')
def test_get_autoconfig_client(self, mock_get_connections,
Expand All @@ -2159,6 +2164,9 @@ def test_get_autoconfig_client(self, mock_get_connections,
MockAutoConfigClient.assert_called_once_with(effective_user='foo',
use_sasl=False)

@unittest.mock.patch.dict('os.environ', {
'AIRFLOW_CONN_HDFS_DEFAULT': 'hdfs://localhost:8020',
})
@mock.patch('airflow.hooks.hdfs_hook.AutoConfigClient')
def test_get_autoconfig_client_no_conn(self, MockAutoConfigClient):
HDFSHook(hdfs_conn_id='hdfs_missing', autoconfig=True).get_conn()
Expand Down
2 changes: 0 additions & 2 deletions tests/gcp/operators/test_kubernetes_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,6 @@ def setUp(self):
name=TASK_NAME,
namespace=NAMESPACE,
image=IMAGE)
if CREDENTIALS in os.environ:
del os.environ[CREDENTIALS]

def test_template_fields(self):
self.assertTrue(set(KubernetesPodOperator.template_fields).issubset(
Expand Down
Loading