Skip to content

Commit

Permalink
Fix AWS DataSync tests failing (#11020)
Browse files Browse the repository at this point in the history
closes #10985
  • Loading branch information
baolsen committed Nov 25, 2020
1 parent 58e21ed commit 663259d
Show file tree
Hide file tree
Showing 9 changed files with 60 additions and 72 deletions.
8 changes: 6 additions & 2 deletions airflow/providers/amazon/aws/log/cloudwatch_task_handler.py
Expand Up @@ -107,8 +107,12 @@ def get_cloudwatch_logs(self, stream_name: str) -> str:
:return: string of all logs from the given log stream
"""
try:
events = list(self.hook.get_log_events(log_group=self.log_group, log_stream_name=stream_name))
return '\n'.join(reversed([event['message'] for event in events]))
events = list(
self.hook.get_log_events(
log_group=self.log_group, log_stream_name=stream_name, start_from_head=True
)
)
return '\n'.join([event['message'] for event in events])
except Exception: # pylint: disable=broad-except
msg = 'Could not read remote logs from log_group: {} log_stream: {}.'.format(
self.log_group, stream_name
Expand Down
12 changes: 5 additions & 7 deletions airflow/providers/amazon/aws/operators/datasync.py
Expand Up @@ -261,7 +261,7 @@ def choose_task(self, task_arn_list: list) -> Optional[str]:
return random.choice(task_arn_list)
raise AirflowException(f"Unable to choose a Task from {task_arn_list}")

def choose_location(self, location_arn_list: List[str]) -> Optional[str]:
def choose_location(self, location_arn_list: Optional[List[str]]) -> Optional[str]:
"""Select 1 DataSync LocationArn from a list"""
if not location_arn_list:
return None
Expand All @@ -277,9 +277,6 @@ def choose_location(self, location_arn_list: List[str]) -> Optional[str]:

def _create_datasync_task(self) -> None:
"""Create a AWS DataSyncTask."""
if not self.candidate_source_location_arns or not self.candidate_destination_location_arns:
return

hook = self.get_hook()

self.source_location_arn = self.choose_location(self.candidate_source_location_arns)
Expand Down Expand Up @@ -348,9 +345,10 @@ def _execute_datasync_task(self) -> None:
# Log some meaningful statuses
level = logging.ERROR if not result else logging.INFO
self.log.log(level, 'Status=%s', task_execution_description['Status'])
for k, v in task_execution_description['Result'].items():
if 'Status' in k or 'Error' in k:
self.log.log(level, '%s=%s', k, v)
if 'Result' in task_execution_description:
for k, v in task_execution_description['Result'].items():
if 'Status' in k or 'Error' in k:
self.log.log(level, '%s=%s', k, v)

if not result:
raise AirflowException("Failed TaskExecutionArn %s" % self.task_execution_arn)
Expand Down
3 changes: 1 addition & 2 deletions setup.py
Expand Up @@ -463,8 +463,7 @@ def write_version(filename: str = os.path.join(*[my_dir, "airflow", "git_version
'ipdb',
'jira',
'mongomock',
'moto==1.3.14', # TODO - fix Datasync issues to get higher version of moto:
# See: https://github.com/apache/airflow/issues/10985
'moto>=1.3.16',
'parameterized',
'paramiko',
'pipdeptree',
Expand Down
1 change: 1 addition & 0 deletions tests/providers/amazon/aws/hooks/test_batch_waiters.py
Expand Up @@ -229,6 +229,7 @@ def test_aws_batch_waiters(aws_region):
@mock_ecs
@mock_iam
@mock_logs
@pytest.mark.xfail(condition=True, reason="Inexplicable timeout issue when running this test. See PR 11020")
def test_aws_batch_job_waiting(aws_clients, aws_region, job_queue_name, job_definition_name):
"""
Submit batch jobs and wait for various job status indicators or errors.
Expand Down
10 changes: 9 additions & 1 deletion tests/providers/amazon/aws/hooks/test_cloud_formation.py
Expand Up @@ -23,6 +23,7 @@

try:
from moto import mock_cloudformation
from moto.ec2.models import NetworkInterface as some_model
except ImportError:
mock_cloudformation = None

Expand All @@ -35,7 +36,14 @@ def setUp(self):
def create_stack(self, stack_name):
timeout = 15
template_body = json.dumps(
{'Resources': {"myResource": {"Type": "emr", "Properties": {"myProperty": "myPropertyValue"}}}}
{
'Resources': {
"myResource": {
"Type": some_model.cloudformation_type(),
"Properties": {"myProperty": "myPropertyValue"},
}
}
}
)

self.hook.create_stack(
Expand Down
21 changes: 1 addition & 20 deletions tests/providers/amazon/aws/hooks/test_datasync.py
Expand Up @@ -20,30 +20,13 @@
from unittest import mock

import boto3
from moto import mock_datasync

from airflow.exceptions import AirflowTaskTimeout
from airflow.providers.amazon.aws.hooks.datasync import AWSDataSyncHook


def no_datasync(x):
return x


try:
from moto import mock_datasync
from moto.datasync.models import DataSyncBackend

# ToDo: Remove after the moto>1.3.14 is released and contains following commit:
# https://github.com/spulec/moto/commit/5cfbe2bb3d24886f2b33bb4480c60b26961226fc
if "create_task" not in dir(DataSyncBackend) or "delete_task" not in dir(DataSyncBackend):
mock_datasync = no_datasync
except ImportError:
# flake8: noqa: F811
mock_datasync = no_datasync


@mock_datasync
@unittest.skipIf(mock_datasync == no_datasync, "moto datasync package missing") # pylint: disable=W0143
class TestAwsDataSyncHook(unittest.TestCase):
def test_get_conn(self):
hook = AWSDataSyncHook(aws_conn_id="aws_default")
Expand All @@ -65,7 +48,6 @@ def test_get_conn(self):

@mock_datasync
@mock.patch.object(AWSDataSyncHook, "get_conn")
@unittest.skipIf(mock_datasync == no_datasync, "moto datasync package missing") # pylint: disable=W0143
class TestAWSDataSyncHookMocked(unittest.TestCase):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
Expand Down Expand Up @@ -118,7 +100,6 @@ def test_init(self, mock_get_conn):
mock_get_conn.return_value = self.client
# ### Begin tests:

self.assertIsNone(self.hook.conn)
self.assertFalse(self.hook.locations)
self.assertFalse(self.hook.tasks)
self.assertEqual(self.hook.wait_interval_seconds, 0)
Expand Down
10 changes: 7 additions & 3 deletions tests/providers/amazon/aws/log/test_cloudwatch_task_handler.py
Expand Up @@ -111,13 +111,17 @@ def test_write(self):
mock_emit.assert_has_calls([call(message) for message in messages])

def test_read(self):
# Confirmed via AWS Support call:
# CloudWatch events must be ordered chronologically otherwise
# boto3 put_log_event API throws InvalidParameterException
# (moto does not throw this exception)
generate_log_events(
self.conn,
self.remote_log_group,
self.remote_log_stream,
[
{'timestamp': 20000, 'message': 'Second'},
{'timestamp': 10000, 'message': 'First'},
{'timestamp': 20000, 'message': 'Second'},
{'timestamp': 30000, 'message': 'Third'},
],
)
Expand All @@ -139,8 +143,8 @@ def test_read_wrong_log_stream(self):
self.remote_log_group,
'alternate_log_stream',
[
{'timestamp': 20000, 'message': 'Second'},
{'timestamp': 10000, 'message': 'First'},
{'timestamp': 20000, 'message': 'Second'},
{'timestamp': 30000, 'message': 'Third'},
],
)
Expand All @@ -163,8 +167,8 @@ def test_read_wrong_log_group(self):
'alternate_log_group',
self.remote_log_stream,
[
{'timestamp': 20000, 'message': 'Second'},
{'timestamp': 10000, 'message': 'First'},
{'timestamp': 20000, 'message': 'Second'},
{'timestamp': 30000, 'message': 'Third'},
],
)
Expand Down
4 changes: 3 additions & 1 deletion tests/providers/amazon/aws/log/test_s3_task_handler.py
Expand Up @@ -20,6 +20,8 @@
import unittest
from unittest import mock

from botocore.exceptions import ClientError

from airflow.models import DAG, TaskInstance
from airflow.operators.dummy_operator import DummyOperator
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
Expand Down Expand Up @@ -216,5 +218,5 @@ def test_close_no_upload(self):
self.assertFalse(self.s3_task_handler.upload_on_close)
self.s3_task_handler.close()

with self.assertRaises(self.conn.exceptions.NoSuchKey):
with self.assertRaises(ClientError):
boto3.resource('s3').Object('bucket', self.remote_log_key).get() # pylint: disable=no-member
63 changes: 27 additions & 36 deletions tests/providers/amazon/aws/operators/test_datasync.py
Expand Up @@ -18,6 +18,7 @@
from unittest import mock

import boto3
from moto import mock_datasync

from airflow.exceptions import AirflowException
from airflow.models import DAG, TaskInstance
Expand All @@ -26,23 +27,6 @@
from airflow.utils import timezone
from airflow.utils.timezone import datetime


def no_datasync(x):
return x


try:
from moto import mock_datasync
from moto.datasync.models import DataSyncBackend

# ToDo: Remove after the moto>1.3.14 is released and contains following commit:
# https://github.com/spulec/moto/commit/5cfbe2bb3d24886f2b33bb4480c60b26961226fc
if "create_task" not in dir(DataSyncBackend) or "delete_task" not in dir(DataSyncBackend):
mock_datasync = no_datasync
except ImportError:
# flake8: noqa: F811
mock_datasync = no_datasync

TEST_DAG_ID = "unit_tests"
DEFAULT_DATE = datetime(2018, 1, 1)

Expand Down Expand Up @@ -82,7 +66,6 @@ def no_datasync(x):

@mock_datasync
@mock.patch.object(AWSDataSyncHook, "get_conn")
@unittest.skipIf(mock_datasync == no_datasync, "moto datasync package missing") # pylint: disable=W0143
class AWSDataSyncTestCaseBase(unittest.TestCase):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
Expand Down Expand Up @@ -128,18 +111,18 @@ def tearDown(self):

@mock_datasync
@mock.patch.object(AWSDataSyncHook, "get_conn")
@unittest.skipIf(mock_datasync == no_datasync, "moto datasync package missing") # pylint: disable=W0143
class TestAWSDataSyncOperatorCreate(AWSDataSyncTestCaseBase):
def set_up_operator(
self,
task_id="test_aws_datasync_create_task_operator",
task_arn=None,
source_location_uri=SOURCE_LOCATION_URI,
destination_location_uri=DESTINATION_LOCATION_URI,
allow_random_location_choice=False,
):
# Create operator
self.datasync = AWSDataSyncOperator(
task_id="test_aws_datasync_create_task_operator",
task_id=task_id,
dag=self.dag,
task_arn=task_arn,
source_location_uri=source_location_uri,
Expand Down Expand Up @@ -285,19 +268,28 @@ def test_dont_create_task(self, mock_get_conn):
# ### Check mocks:
mock_get_conn.assert_called()

def create_task_many_locations(self, mock_get_conn):
def test_create_task_many_locations(self, mock_get_conn):
# ### Set up mocks:
mock_get_conn.return_value = self.client
# ### Begin tests:

# Delete all tasks:
tasks = self.client.list_tasks()
for task in tasks["Tasks"]:
self.client.delete_task(TaskArn=task["TaskArn"])
# Create duplicate source location to choose from
self.client.create_location_smb(**MOCK_DATA["create_source_location_kwargs"])

self.set_up_operator(task_arn=self.task_arn)
self.set_up_operator(task_id='datasync_task1')
with self.assertRaises(AirflowException):
self.datasync.execute(None)

self.set_up_operator(task_arn=self.task_arn, allow_random_location_choice=True)
# Delete all tasks:
tasks = self.client.list_tasks()
for task in tasks["Tasks"]:
self.client.delete_task(TaskArn=task["TaskArn"])

self.set_up_operator(task_id='datasync_task2', allow_random_location_choice=True)
self.datasync.execute(None)
# ### Check mocks:
mock_get_conn.assert_called()
Expand Down Expand Up @@ -335,18 +327,18 @@ def test_xcom_push(self, mock_get_conn):

@mock_datasync
@mock.patch.object(AWSDataSyncHook, "get_conn")
@unittest.skipIf(mock_datasync == no_datasync, "moto datasync package missing") # pylint: disable=W0143
class TestAWSDataSyncOperatorGetTasks(AWSDataSyncTestCaseBase):
def set_up_operator(
self,
task_id="test_aws_datasync_get_tasks_operator",
task_arn=None,
source_location_uri=SOURCE_LOCATION_URI,
destination_location_uri=DESTINATION_LOCATION_URI,
allow_random_task_choice=False,
):
# Create operator
self.datasync = AWSDataSyncOperator(
task_id="test_aws_datasync_get_tasks_operator",
task_id=task_id,
dag=self.dag,
task_arn=task_arn,
source_location_uri=source_location_uri,
Expand Down Expand Up @@ -468,7 +460,7 @@ def test_get_many_tasks(self, mock_get_conn):
mock_get_conn.return_value = self.client
# ### Begin tests:

self.set_up_operator()
self.set_up_operator(task_id='datasync_task1')

self.client.create_task(
SourceLocationArn=self.source_location_arn,
Expand All @@ -491,7 +483,7 @@ def test_get_many_tasks(self, mock_get_conn):
locations = self.client.list_locations()
self.assertEqual(len(locations["Locations"]), 2)

self.set_up_operator(task_arn=self.task_arn, allow_random_task_choice=True)
self.set_up_operator(task_id='datasync_task2', task_arn=self.task_arn, allow_random_task_choice=True)
self.datasync.execute(None)
# ### Check mocks:
mock_get_conn.assert_called()
Expand Down Expand Up @@ -529,20 +521,21 @@ def test_xcom_push(self, mock_get_conn):

@mock_datasync
@mock.patch.object(AWSDataSyncHook, "get_conn")
@unittest.skipIf(mock_datasync == no_datasync, "moto datasync package missing") # pylint: disable=W0143
class TestAWSDataSyncOperatorUpdate(AWSDataSyncTestCaseBase):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.datasync = None

def set_up_operator(self, task_arn="self", update_task_kwargs="default"):
def set_up_operator(
self, task_id="test_aws_datasync_update_task_operator", task_arn="self", update_task_kwargs="default"
):
if task_arn == "self":
task_arn = self.task_arn
if update_task_kwargs == "default":
update_task_kwargs = {"Options": {"VerifyMode": "BEST_EFFORT", "Atime": "NONE"}}
# Create operator
self.datasync = AWSDataSyncOperator(
task_id="test_aws_datasync_update_task_operator",
task_id=task_id,
dag=self.dag,
task_arn=task_arn,
update_task_kwargs=update_task_kwargs,
Expand Down Expand Up @@ -628,14 +621,13 @@ def test_xcom_push(self, mock_get_conn):

@mock_datasync
@mock.patch.object(AWSDataSyncHook, "get_conn")
@unittest.skipIf(mock_datasync == no_datasync, "moto datasync package missing") # pylint: disable=W0143
class TestAWSDataSyncOperator(AWSDataSyncTestCaseBase):
def set_up_operator(self, task_arn="self"):
def set_up_operator(self, task_id="test_aws_datasync_task_operator", task_arn="self"):
if task_arn == "self":
task_arn = self.task_arn
# Create operator
self.datasync = AWSDataSyncOperator(
task_id="test_aws_datasync_task_operator",
task_id=task_id,
dag=self.dag,
wait_interval_seconds=0,
task_arn=task_arn,
Expand Down Expand Up @@ -782,14 +774,13 @@ def test_xcom_push(self, mock_get_conn):

@mock_datasync
@mock.patch.object(AWSDataSyncHook, "get_conn")
@unittest.skipIf(mock_datasync == no_datasync, "moto datasync package missing") # pylint: disable=W0143
class TestAWSDataSyncOperatorDelete(AWSDataSyncTestCaseBase):
def set_up_operator(self, task_arn="self"):
def set_up_operator(self, task_id="test_aws_datasync_delete_task_operator", task_arn="self"):
if task_arn == "self":
task_arn = self.task_arn
# Create operator
self.datasync = AWSDataSyncOperator(
task_id="test_aws_datasync_delete_task_operator",
task_id=task_id,
dag=self.dag,
task_arn=task_arn,
delete_task_after_execution=True,
Expand Down

0 comments on commit 663259d

Please sign in to comment.