2 changes: 1 addition & 1 deletion airflow/providers/amazon/aws/hooks/datasync.py
@@ -311,7 +311,7 @@ def wait_for_task_execution(self, task_execution_arn, max_iterations=2 * 180):
TaskExecutionArn=task_execution_arn
)
status = task_execution["Status"]
self.log.info("status=%s", status)
self.log.info("TaskExecution status=%s", status)
iterations = iterations - 1
if status in self.TASK_EXECUTION_FAILURE_STATES:
break
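For context, a minimal sketch of how this hook-level wait is typically driven (the ARN is a
placeholder; as the operator code below relies on, wait_for_task_execution returns a truthy
result only when the TaskExecution finishes successfully):

    from airflow.exceptions import AirflowException
    from airflow.providers.amazon.aws.hooks.datasync import AWSDataSyncHook

    hook = AWSDataSyncHook(aws_conn_id="aws_default")
    # Poll the TaskExecution status until it reaches a terminal state,
    # checking at most max_iterations times.
    result = hook.wait_for_task_execution(
        task_execution_arn="arn:aws:datasync:...",  # placeholder ARN
        max_iterations=2 * 180,
    )
    if not result:
        raise AirflowException("DataSync TaskExecution did not succeed")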
98 changes: 88 additions & 10 deletions airflow/providers/amazon/aws/operators/datasync.py
@@ -21,6 +21,8 @@

import logging
import random
import time
from typing import List

from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
@@ -105,6 +107,12 @@ class AWSDataSyncOperator(BaseOperator):
)
ui_color = "#44b5e2"

    # Control when we execute a Task, based on its initial status
    TASK_STATUS_WAIT_BEFORE_START: List[str] = ['CREATING']  # poll until the status changes
    TASK_STATUS_START: List[str] = ['AVAILABLE']  # start a new TaskExecution
    TASK_STATUS_SKIP_START: List[str] = []  # reuse the Task's CurrentTaskExecutionArn instead
    TASK_STATUS_FAIL: List[str] = ['UNAVAILABLE', 'QUEUED', 'RUNNING']  # raise AirflowException

@apply_defaults
def __init__(
self,
@@ -174,6 +182,7 @@ def __init__(
self.source_location_arn = None
self.destination_location_arn = None
self.task_execution_arn = None
self.task_status = None

def get_hook(self):
"""Create and return AWSDataSyncHook.
@@ -208,12 +217,41 @@ def execute(self, context):

self.log.info("Using DataSync TaskArn %s", self.task_arn)

-        # Update the DataSync Task
+        # Update the DataSync Task definition
if self.update_task_kwargs:
self._update_datasync_task()

-        # Execute the DataSync Task
-        self._execute_datasync_task()
# Wait for the Task to be in a valid state to Start
self.task_status = self._wait_get_status_before_start()

self.log.info('Task status is %s.', self.task_status)
if self.task_status in self.TASK_STATUS_START:
self.log.info(
'The Task will be started because its status is in %s.',
self.TASK_STATUS_START)
# Start the DataSync Task
Suggested change
-            # Start the DataSync Task

self._start_datasync_task()
elif self.task_status in self.TASK_STATUS_SKIP_START:
self.log.info(
'The Task will NOT be started because its status is in %s.',
self.TASK_STATUS_SKIP_START)
if not self.task_execution_arn:
task_description = self.get_hook().get_task_description(self.task_arn)
if 'CurrentTaskExecutionArn' in task_description:
self.task_execution_arn = task_description['CurrentTaskExecutionArn']
else:
raise AirflowException(
'Starting the Task was skipped,'
' but no CurrentTaskExecutionArn was found.')
elif self.task_status in self.TASK_STATUS_FAIL:
raise AirflowException(
'Task cannot be started because its status is in %s.'
% self.TASK_STATUS_FAIL
)
Comment on lines +247 to +250
Suggested change
-            raise AirflowException(
-                'Task cannot be started because its status is in %s.'
-                % self.TASK_STATUS_FAIL
-            )
+            raise AirflowException(
+                f'Task cannot be started because its status is in {self.TASK_STATUS_FAIL}.'
+            )

Please use f-strings :)

else:
raise AirflowException('Unexpected task status %s.' % self.task_status)
Suggested change
-            raise AirflowException('Unexpected task status %s.' % self.task_status)
+            raise AirflowException(f'Unexpected task status {self.task_status}')


self._wait_for_datasync_task_execution()

if not self.task_execution_arn:
raise AirflowException("Nothing was executed")
@@ -222,7 +260,10 @@ def execute(self, context):
if self.delete_task_after_execution:
self._delete_datasync_task()

return {"TaskArn": self.task_arn, "TaskExecutionArn": self.task_execution_arn}
return {
"TaskArn": self.task_arn,
"TaskExecutionArn": self.task_execution_arn
}

def _get_tasks_and_locations(self):
"""Find existing DataSync Task based on source and dest Locations."""
@@ -331,16 +372,52 @@ def _update_datasync_task(self):
self.log.info("Updated TaskArn %s", self.task_arn)
return self.task_arn

-    def _execute_datasync_task(self):
-        """Create and monitor an AWSDataSync TaskExecution for a Task."""
-        hook = self.get_hook()
def _wait_get_status_before_start(
Suggested change
-    def _wait_get_status_before_start(
+    def _wait_for_status(

self,
max_iterations=12 * 180): # wait_interval_seconds*12*180=180 minutes by default
Suggested change
-            max_iterations=12 * 180):  # wait_interval_seconds*12*180=180 minutes by default
+            max_iterations: int = 12 * 180) -> str:

You already have that comment in the docs below. Even better would be to add a :param entry and move the comment there.

"""
Wait until the Task can be started.

The Task can be started when its Status is not in TASK_STATUS_WAIT_BEFORE_START
Uses wait_interval_seconds (which is also used while waiting for TaskExecution)
So, max_iterations=12*180 gives 180 minutes wait by default.
"""
Please also add :returns: here.
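A possible revision addressing both docstring comments — a sketch only, not part of this PR:

    def _wait_for_status(
            self,
            max_iterations: int = 12 * 180) -> str:
        """
        Wait until the Task can be started.

        :param max_iterations: Maximum number of status polls; each poll sleeps
            wait_interval_seconds, so the default 12 * 180 waits up to
            180 minutes at the default 5-second interval.
        :return: The last observed Task status.
        :rtype: str
        """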

hook = self.get_hook()
task_status = hook.get_task_description(self.task_arn)['Status']
iteration = 0
while task_status in self.TASK_STATUS_WAIT_BEFORE_START:
self.log.info(
'Task status is %s.'
' Waiting for it to not be %s.'
' Iteration %s/%s.',
task_status,
self.TASK_STATUS_WAIT_BEFORE_START,
iteration,
max_iterations)
time.sleep(self.wait_interval_seconds)
task_status = hook.get_task_description(self.task_arn)['Status']
iteration = iteration + 1
if iteration >= max_iterations:
break

return task_status
Comment on lines +385 to +403
Suggested change
-        hook = self.get_hook()
-        task_status = hook.get_task_description(self.task_arn)['Status']
-        iteration = 0
-        while task_status in self.TASK_STATUS_WAIT_BEFORE_START:
-            self.log.info(
-                'Task status is %s.'
-                ' Waiting for it to not be %s.'
-                ' Iteration %s/%s.',
-                task_status,
-                self.TASK_STATUS_WAIT_BEFORE_START,
-                iteration,
-                max_iterations)
-            time.sleep(self.wait_interval_seconds)
-            task_status = hook.get_task_description(self.task_arn)['Status']
-            iteration = iteration + 1
-            if iteration >= max_iterations:
-                break
-        return task_status
+        hook = self.get_hook()
+        for iteration in range(max_iterations):
+            task_status = hook.get_task_description(self.task_arn)['Status']
+            self.log.info(
+                'Task status is %s.'
+                ' Waiting for it to not be %s.'
+                ' Iteration %s/%s.',
+                task_status,
+                self.TASK_STATUS_WAIT_BEFORE_START,
+                iteration,
+                max_iterations)
+            if task_status not in self.TASK_STATUS_WAIT_BEFORE_START:
+                break
+            time.sleep(self.wait_interval_seconds)
+        return task_status

WDYT?

def _start_datasync_task(self):
"""Create an AWSDataSync TaskExecution for a Task."""
hook = self.get_hook()
# Create a task execution:
self.log.info("Starting execution for TaskArn %s", self.task_arn)
self.task_execution_arn = hook.start_task_execution(
self.task_arn, **self.task_execution_kwargs)
self.log.info("Started TaskExecutionArn %s", self.task_execution_arn)

def _wait_for_datasync_task_execution(self):
"""Monitor an AWSDataSync TaskExecution for a Task."""
hook = self.get_hook()
if not self.task_execution_arn:
raise AirflowException(
'Unable to wait for TaskExecutionArn to complete'
' because none was provided')
# Wait for task execution to complete
self.log.info("Waiting for TaskExecutionArn %s",
self.task_execution_arn)
@@ -355,9 +432,10 @@ def _execute_datasync_task(self):
# Log some meaningful statuses
level = logging.ERROR if not result else logging.INFO
self.log.log(level, 'Status=%s', task_execution_description['Status'])
-        for k, v in task_execution_description['Result'].items():
-            if 'Status' in k or 'Error' in k:
-                self.log.log(level, '%s=%s', k, v)
+        if 'Result' in task_execution_description:
+            for k, v in task_execution_description['Result'].items():
+                if 'Status' in k or 'Error' in k:
+                    self.log.log(level, '%s=%s', k, v)

if not result:
raise AirflowException(
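To illustrate the new control flow end to end, a minimal usage sketch (the DAG and the TaskArn
value are hypothetical, not part of this PR):

    from airflow import DAG
    from airflow.providers.amazon.aws.operators.datasync import AWSDataSyncOperator
    from airflow.utils.dates import days_ago

    with DAG(dag_id="example_datasync", start_date=days_ago(1), schedule_interval=None) as dag:
        sync = AWSDataSyncOperator(
            task_id="sync_task",
            # Hypothetical ARN of an existing DataSync Task
            task_arn="arn:aws:datasync:us-east-1:111222333444:task/task-0example",
            wait_interval_seconds=5,
        )

With this change, execute() first polls a CREATING Task until it leaves
TASK_STATUS_WAIT_BEFORE_START, starts it once AVAILABLE, and raises AirflowException for
UNAVAILABLE, QUEUED, or RUNNING unless those statuses are moved into TASK_STATUS_SKIP_START.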
144 changes: 144 additions & 0 deletions tests/providers/amazon/aws/operators/test_datasync.py
@@ -37,6 +37,8 @@ def no_datasync(x):
from moto.datasync.models import DataSyncBackend
# ToDo: Remove after the moto>1.3.14 is released and contains following commit:
# https://github.com/spulec/moto/commit/5cfbe2bb3d24886f2b33bb4480c60b26961226fc
# For now, testing is done (not skipped) using a local pip install of the latest
# moto dev version, e.g. 'pip install moto==1.3.15.dev432'. See the moto release page.
if "create_task" not in dir(DataSyncBackend) or "delete_task" not in dir(DataSyncBackend):
mock_datasync = no_datasync
except ImportError:
@@ -916,3 +918,145 @@ def test_xcom_push(self, mock_get_conn):
self.assertEqual(pushed_task_arn, self.task_arn)
# ### Check mocks:
mock_get_conn.assert_called()


@mock_datasync
@mock.patch.object(AWSDataSyncHook, "get_conn")
@mock.patch.object(AWSDataSyncHook, "get_task_description")
@unittest.skipIf(
mock_datasync == no_datasync, "moto datasync package missing"
) # pylint: disable=W0143
class TestAWSDataSyncOperatorTaskStatus(AWSDataSyncTestCaseBase):
def set_up_operator(self, task_arn="self"):
if task_arn == "self":
task_arn = self.task_arn
# Create operator
self.datasync = AWSDataSyncOperator(
task_id="test_aws_datasync_task_status",
dag=self.dag,
task_arn=task_arn,
wait_interval_seconds=0,
)

def test_task_status_wait_before_start_wait(self, mock_get, mock_get_conn):
# ### Set up mocks:
mock_get_conn.return_value = self.client
# Set the Task Status to CREATING and transition it to AVAILABLE thereafter
mock_get.side_effect = [
{'Status': 'CREATING'},
{'Status': 'CREATING'},
{'Status': 'AVAILABLE'}
]
# ### Begin tests:
self.set_up_operator()
self.datasync.execute(None)

self.assertEqual(mock_get.call_count, 3)

def test_task_status_wait_before_start_no_wait(self, mock_get, mock_get_conn):
# ### Set up mocks:
mock_get_conn.return_value = self.client
mock_get.side_effect = [
{'Status': 'AVAILABLE'},
{'Status': 'AVAILABLE'},
{'Status': 'AVAILABLE'}
]
# ### Begin tests:
self.set_up_operator()
self.datasync.execute(None)

self.assertEqual(mock_get.call_count, 1)

def test_task_status_custom(self, mock_get, mock_get_conn):
# ### Set up mocks:
mock_get_conn.return_value = self.client
mock_get.side_effect = [
{'Status': 'CUSTOM'},
{'Status': 'AVAILABLE'}
]
# ### Begin tests:
self.set_up_operator()
self.datasync.TASK_STATUS_WAIT_BEFORE_START = ['CUSTOM', 'WAIT', 'STATUSES']
self.datasync.execute(None)

self.assertEqual(mock_get.call_count, 2)

def test_task_status_start(self, mock_get, mock_get_conn):
# ### Set up mocks:
mock_get_conn.return_value = self.client
mock_get.side_effect = [{
'Status': 'AVAILABLE'
}]
# ### Begin tests:
task_arn = self.client.create_task(
SourceLocationArn=self.source_location_arn,
DestinationLocationArn=self.destination_location_arn,
)["TaskArn"]

self.set_up_operator(task_arn=task_arn)
result = self.datasync.execute(None)

self.assertEqual(result["TaskArn"], task_arn)
self.assertEqual(self.datasync.task_arn, task_arn)
# ### Check mocks:
mock_get_conn.assert_called()
mock_get.assert_called_with(task_arn)

@mock.patch.object(AWSDataSyncOperator, "_start_datasync_task")
def test_task_status_skip_start(self, mock_start, mock_get, mock_get_conn):
# Create and start a Task
task_arn = self.client.create_task(
SourceLocationArn=self.source_location_arn,
DestinationLocationArn=self.destination_location_arn,
)["TaskArn"]
task_execution_arn = self.client.start_task_execution(TaskArn=task_arn)['TaskExecutionArn']
# ### Set up mocks:
mock_get_conn.return_value = self.client
mock_get.side_effect = [
{'Status': 'RUNNING', 'CurrentTaskExecutionArn': task_execution_arn},
{'Status': 'RUNNING', 'CurrentTaskExecutionArn': task_execution_arn}
]
# ### Begin tests:
self.set_up_operator(task_arn=task_arn)
self.datasync.TASK_STATUS_SKIP_START = ['RUNNING']
result = self.datasync.execute(None)

self.assertEqual(result["TaskArn"], task_arn)
self.assertEqual(self.datasync.task_arn, task_arn)
mock_start.assert_not_called()

@mock.patch.object(AWSDataSyncOperator, "_start_datasync_task")
def test_task_status_skip_start_fail(self, mock_start, mock_get, mock_get_conn):
# Create and start a Task
task_arn = self.client.create_task(
SourceLocationArn=self.source_location_arn,
DestinationLocationArn=self.destination_location_arn,
)["TaskArn"]
# ### Set up mocks:
mock_get_conn.return_value = self.client
mock_get.side_effect = [
{'Status': 'RUNNING'},
{'Status': 'RUNNING'}
]
# ### Begin tests:
self.set_up_operator(task_arn=task_arn)
with self.assertRaises(AirflowException):
self.datasync.execute(None)

mock_start.assert_not_called()

def test_task_status_fail_start(self, mock_get, mock_get_conn):
# ### Set up mocks:
mock_get_conn.return_value = self.client
mock_get.side_effect = [
{'Status': 'INVALID'}
]
# ### Begin tests:
task_arn = self.client.create_task(
SourceLocationArn=self.source_location_arn,
DestinationLocationArn=self.destination_location_arn,
)["TaskArn"]

self.set_up_operator(task_arn=task_arn)
with self.assertRaises(AirflowException):
self.datasync.execute(None)
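As the tests above exercise by assigning to the class attributes directly, the status lists are
meant to be overridable. A hedged sketch of a subclass (the class name is hypothetical) that
attaches to an already-running Task instead of failing:

    from airflow.providers.amazon.aws.operators.datasync import AWSDataSyncOperator

    class ReuseRunningDataSyncOperator(AWSDataSyncOperator):
        # Treat RUNNING as "skip starting": execute() then reuses the Task's
        # CurrentTaskExecutionArn and waits for that execution to finish.
        TASK_STATUS_SKIP_START = ['RUNNING']
        TASK_STATUS_FAIL = ['UNAVAILABLE', 'QUEUED']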