Skip to content

Commit

Permalink
Use base aws classes in Amazon ECS Operators/Sensors/Triggers (#36393)
Browse files Browse the repository at this point in the history
* Use base aws classes in Amazon ECS Operators/Sensors/Triggers

* remove redundant init and wrapper for the hook that essentially did nothing

* split out half of test

* change test to properly mock hook

* format
  • Loading branch information
jayceslesar committed Dec 26, 2023
1 parent 2bd6077 commit 73d8794
Show file tree
Hide file tree
Showing 6 changed files with 35 additions and 49 deletions.
19 changes: 5 additions & 14 deletions airflow/providers/amazon/aws/operators/ecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,18 @@

from airflow.configuration import conf
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.exceptions import EcsOperatorError, EcsTaskFailToStart
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
from airflow.providers.amazon.aws.hooks.ecs import EcsClusterStates, EcsHook, should_retry_eni
from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook
from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
from airflow.providers.amazon.aws.triggers.ecs import (
ClusterActiveTrigger,
ClusterInactiveTrigger,
TaskDoneTrigger,
)
from airflow.providers.amazon.aws.utils.identifiers import generate_uuid
from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
from airflow.providers.amazon.aws.utils.task_log_fetcher import AwsTaskLogFetcher
from airflow.utils.helpers import prune_dict

Expand All @@ -45,21 +46,11 @@
from airflow.models import TaskInstance
from airflow.utils.context import Context

DEFAULT_CONN_ID = "aws_default"


class EcsBaseOperator(BaseOperator):
class EcsBaseOperator(AwsBaseOperator[EcsHook]):
"""This is the base operator for all Elastic Container Service operators."""

def __init__(self, *, aws_conn_id: str | None = DEFAULT_CONN_ID, region: str | None = None, **kwargs):
self.aws_conn_id = aws_conn_id
self.region = region
super().__init__(**kwargs)

@cached_property
def hook(self) -> EcsHook:
"""Create and return an EcsHook."""
return EcsHook(aws_conn_id=self.aws_conn_id, region_name=self.region)
aws_hook_class = EcsHook

@cached_property
def client(self) -> boto3.client:
Expand Down Expand Up @@ -101,7 +92,7 @@ class EcsCreateClusterOperator(EcsBaseOperator):
(default: False)
"""

template_fields: Sequence[str] = (
template_fields: Sequence[str] = aws_template_fields(
"cluster_name",
"create_cluster_kwargs",
"wait_for_completion",
Expand Down
23 changes: 7 additions & 16 deletions airflow/providers/amazon/aws/sensors/ecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,14 @@
EcsTaskDefinitionStates,
EcsTaskStates,
)
from airflow.sensors.base import BaseSensorOperator
from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
from airflow.providers.amazon.aws.utils.mixins import aws_template_fields

if TYPE_CHECKING:
import boto3

from airflow.utils.context import Context

DEFAULT_CONN_ID: str = "aws_default"


def _check_failed(current_state, target_state, failure_states, soft_fail: bool) -> None:
if (current_state != target_state) and (current_state in failure_states):
Expand All @@ -45,18 +44,10 @@ def _check_failed(current_state, target_state, failure_states, soft_fail: bool)
raise AirflowException(message)


class EcsBaseSensor(BaseSensorOperator):
class EcsBaseSensor(AwsBaseSensor[EcsHook]):
"""Contains general sensor behavior for Elastic Container Service."""

def __init__(self, *, aws_conn_id: str | None = DEFAULT_CONN_ID, region: str | None = None, **kwargs):
self.aws_conn_id = aws_conn_id
self.region = region
super().__init__(**kwargs)

@cached_property
def hook(self) -> EcsHook:
"""Create and return an EcsHook."""
return EcsHook(aws_conn_id=self.aws_conn_id, region_name=self.region)
aws_hook_class = EcsHook

@cached_property
def client(self) -> boto3.client:
Expand All @@ -78,7 +69,7 @@ class EcsClusterStateSensor(EcsBaseSensor):
Success State. (Default: "FAILED" or "INACTIVE")
"""

template_fields: Sequence[str] = ("cluster_name", "target_state", "failure_states")
template_fields: Sequence[str] = aws_template_fields("cluster_name", "target_state", "failure_states")

def __init__(
self,
Expand Down Expand Up @@ -116,7 +107,7 @@ class EcsTaskDefinitionStateSensor(EcsBaseSensor):
:param target_state: Success state to watch for. (Default: "ACTIVE")
"""

template_fields: Sequence[str] = ("task_definition", "target_state", "failure_states")
template_fields: Sequence[str] = aws_template_fields("task_definition", "target_state", "failure_states")

def __init__(
self,
Expand Down Expand Up @@ -162,7 +153,7 @@ class EcsTaskStateSensor(EcsBaseSensor):
the Success State. (Default: "STOPPED")
"""

template_fields: Sequence[str] = ("cluster", "task", "target_state", "failure_states")
template_fields: Sequence[str] = aws_template_fields("cluster", "task", "target_state", "failure_states")

def __init__(
self,
Expand Down
4 changes: 4 additions & 0 deletions airflow/providers/amazon/aws/triggers/ecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def __init__(
waiter_max_attempts: int,
aws_conn_id: str | None,
region_name: str | None = None,
**kwargs,
):
super().__init__(
serialized_fields={"cluster_arn": cluster_arn},
Expand All @@ -66,6 +67,7 @@ def __init__(
waiter_max_attempts=waiter_max_attempts,
aws_conn_id=aws_conn_id,
region_name=region_name,
**kwargs,
)

def hook(self) -> AwsGenericHook:
Expand All @@ -91,6 +93,7 @@ def __init__(
waiter_max_attempts: int,
aws_conn_id: str | None,
region_name: str | None = None,
**kwargs,
):
super().__init__(
serialized_fields={"cluster_arn": cluster_arn},
Expand All @@ -104,6 +107,7 @@ def __init__(
waiter_max_attempts=waiter_max_attempts,
aws_conn_id=aws_conn_id,
region_name=region_name,
**kwargs,
)

def hook(self) -> AwsGenericHook:
Expand Down
4 changes: 4 additions & 0 deletions docs/apache-airflow-providers-amazon/operators/ecs.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ Prerequisite Tasks

.. include:: ../_partials/prerequisite_tasks.rst

Generic Parameters
------------------
.. include:: ../_partials/generic_parameters.rst

Operators
---------

Expand Down
31 changes: 14 additions & 17 deletions tests/providers/amazon/aws/operators/test_ecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
from airflow.providers.amazon.aws.exceptions import EcsOperatorError, EcsTaskFailToStart
from airflow.providers.amazon.aws.hooks.ecs import EcsClusterStates, EcsHook
from airflow.providers.amazon.aws.operators.ecs import (
DEFAULT_CONN_ID,
EcsBaseOperator,
EcsCreateClusterOperator,
EcsDeleteClusterOperator,
Expand Down Expand Up @@ -112,30 +111,28 @@ def test_initialise_operator(self, aws_conn_id, region_name):
op_kw = {k: v for k, v in op_kw.items() if v is not NOTSET}
op = EcsBaseOperator(task_id="test_ecs_base", **op_kw)

assert op.aws_conn_id == (aws_conn_id if aws_conn_id is not NOTSET else DEFAULT_CONN_ID)
assert op.aws_conn_id == (aws_conn_id if aws_conn_id is not NOTSET else "aws_default")
assert op.region == (region_name if region_name is not NOTSET else None)

@mock.patch("airflow.providers.amazon.aws.operators.ecs.EcsHook")
@pytest.mark.parametrize("aws_conn_id", [None, NOTSET, "aws_test_conn"])
@pytest.mark.parametrize("region_name", [None, NOTSET, "ca-central-1"])
def test_hook_and_client(self, mock_ecs_hook_cls, aws_conn_id, region_name):
"""Test initialize ``EcsHook`` and ``boto3.client``."""
mock_ecs_hook = mock_ecs_hook_cls.return_value
mock_conn = mock.MagicMock()
type(mock_ecs_hook).conn = mock.PropertyMock(return_value=mock_conn)

def test_initialise_operator_hook(self, aws_conn_id, region_name):
"""Test initialize operator."""
op_kw = {"aws_conn_id": aws_conn_id, "region": region_name}
op_kw = {k: v for k, v in op_kw.items() if v is not NOTSET}
op = EcsBaseOperator(task_id="test_ecs_base_hook_client", **op_kw)
op = EcsBaseOperator(task_id="test_ecs_base", **op_kw)

assert op.hook.aws_conn_id == (aws_conn_id if aws_conn_id is not NOTSET else "aws_default")
assert op.hook.region_name == (region_name if region_name is not NOTSET else None)

hook = op.hook
assert op.hook is hook
mock_ecs_hook_cls.assert_called_once_with(aws_conn_id=op.aws_conn_id, region_name=op.region)
with mock.patch.object(EcsBaseOperator, "hook", new_callable=mock.PropertyMock) as m:
mocked_hook = mock.MagicMock(name="MockHook")
mocked_client = mock.MagicMock(name="Mocklient")
mocked_hook.conn = mocked_client
m.return_value = mocked_hook

client = op.client
mock_ecs_hook_cls.assert_called_once_with(aws_conn_id=op.aws_conn_id, region_name=op.region)
assert client == mock_conn
assert op.client is client
assert op.client == mocked_client
m.assert_called_once()


class TestEcsRunTaskOperator(EcsBaseTestCase):
Expand Down
3 changes: 1 addition & 2 deletions tests/providers/amazon/aws/sensors/test_ecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@

from airflow.exceptions import AirflowException, AirflowSkipException
from airflow.providers.amazon.aws.sensors.ecs import (
DEFAULT_CONN_ID,
EcsBaseSensor,
EcsClusterStates,
EcsClusterStateSensor,
Expand Down Expand Up @@ -79,7 +78,7 @@ def test_initialise_operator(self, aws_conn_id, region_name):
op_kw = {k: v for k, v in op_kw.items() if v is not NOTSET}
op = EcsBaseSensor(task_id="test_ecs_base", **op_kw)

assert op.aws_conn_id == (aws_conn_id if aws_conn_id is not NOTSET else DEFAULT_CONN_ID)
assert op.aws_conn_id == (aws_conn_id if aws_conn_id is not NOTSET else "aws_default")
assert op.region == (region_name if region_name is not NOTSET else None)

@pytest.mark.parametrize("aws_conn_id", [None, NOTSET, "aws_test_conn"])
Expand Down

0 comments on commit 73d8794

Please sign in to comment.