Skip to content

Commit

Permalink
Use base aws classes in Amazon EventBridge Operators (#36765)
Browse files Browse the repository at this point in the history
  • Loading branch information
Taragolis committed Jan 14, 2024
1 parent 9eab3e1 commit 1e0a99c
Show file tree
Hide file tree
Showing 3 changed files with 142 additions and 92 deletions.
137 changes: 57 additions & 80 deletions airflow/providers/amazon/aws/operators/eventbridge.py
Expand Up @@ -16,19 +16,19 @@
# under the License.
from __future__ import annotations

from functools import cached_property
from typing import TYPE_CHECKING, Sequence

from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.eventbridge import EventBridgeHook
from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
from airflow.utils.helpers import prune_dict

if TYPE_CHECKING:
from airflow.utils.context import Context


class EventBridgePutEventsOperator(BaseOperator):
class EventBridgePutEventsOperator(AwsBaseOperator[EventBridgeHook]):
"""
Put Events onto Amazon EventBridge.
Expand All @@ -38,32 +38,25 @@ class EventBridgePutEventsOperator(BaseOperator):
:param entries: the list of events to be put onto EventBridge, each event is a dict (required)
:param endpoint_id: the URL subdomain of the endpoint
:param aws_conn_id: the AWS connection to use
:param region_name: the region where events are to be sent
:param aws_conn_id: The Airflow connection used for AWS credentials.
If this is ``None`` or empty then the default boto3 behaviour is used. If
running Airflow in a distributed manner and aws_conn_id is None or
empty, then default boto3 configuration would be used (and must be
maintained on each worker node).
:param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
:param verify: Whether or not to verify SSL certificates. See:
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
:param botocore_config: Configuration dictionary (key-values) for botocore client. See:
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.htmlt
"""

template_fields: Sequence[str] = ("entries", "endpoint_id", "aws_conn_id", "region_name")
aws_hook_class = EventBridgeHook
template_fields: Sequence[str] = aws_template_fields("entries", "endpoint_id")

def __init__(
self,
*,
entries: list[dict],
endpoint_id: str | None = None,
aws_conn_id: str = "aws_default",
region_name: str | None = None,
**kwargs,
):
def __init__(self, *, entries: list[dict], endpoint_id: str | None = None, **kwargs):
super().__init__(**kwargs)
self.entries = entries
self.endpoint_id = endpoint_id
self.aws_conn_id = aws_conn_id
self.region_name = region_name

@cached_property
def hook(self) -> EventBridgeHook:
"""Create and return an EventBridgeHook."""
return EventBridgeHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)

def execute(self, context: Context):
response = self.hook.conn.put_events(
Expand All @@ -90,7 +83,7 @@ def execute(self, context: Context):
return [e["EventId"] for e in response["Entries"]]


class EventBridgePutRuleOperator(BaseOperator):
class EventBridgePutRuleOperator(AwsBaseOperator[EventBridgeHook]):
"""
Create or update a specified EventBridge rule.
Expand All @@ -106,12 +99,20 @@ class EventBridgePutRuleOperator(BaseOperator):
:param schedule_expression: the scheduling expression (for example, a cron or rate expression)
:param state: indicates whether rule is set to be "ENABLED" or "DISABLED"
:param tags: list of key-value pairs to associate with the rule
:param region: the region where rule is to be created or updated
:param aws_conn_id: The Airflow connection used for AWS credentials.
If this is ``None`` or empty then the default boto3 behaviour is used. If
running Airflow in a distributed manner and aws_conn_id is None or
empty, then default boto3 configuration would be used (and must be
maintained on each worker node).
:param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
:param verify: Whether or not to verify SSL certificates. See:
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
:param botocore_config: Configuration dictionary (key-values) for botocore client. See:
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.htmlt
"""

template_fields: Sequence[str] = (
"aws_conn_id",
aws_hook_class = EventBridgeHook
template_fields: Sequence[str] = aws_template_fields(
"name",
"description",
"event_bus_name",
Expand All @@ -120,7 +121,6 @@ class EventBridgePutRuleOperator(BaseOperator):
"schedule_expression",
"state",
"tags",
"region_name",
)

def __init__(
Expand All @@ -134,8 +134,6 @@ def __init__(
schedule_expression: str | None = None,
state: str | None = None,
tags: list | None = None,
region_name: str | None = None,
aws_conn_id: str = "aws_default",
**kwargs,
):
super().__init__(**kwargs)
Expand All @@ -144,16 +142,9 @@ def __init__(
self.event_bus_name = event_bus_name
self.event_pattern = event_pattern
self.role_arn = role_arn
self.region_name = region_name
self.schedule_expression = schedule_expression
self.state = state
self.tags = tags
self.aws_conn_id = aws_conn_id

@cached_property
def hook(self) -> EventBridgeHook:
"""Create and return an EventBridgeHook."""
return EventBridgeHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)

def execute(self, context: Context):
self.log.info('Sending rule "%s" to EventBridge.', self.name)
Expand All @@ -170,7 +161,7 @@ def execute(self, context: Context):
)


class EventBridgeEnableRuleOperator(BaseOperator):
class EventBridgeEnableRuleOperator(AwsBaseOperator[EventBridgeHook]):
"""
Enable an EventBridge Rule.
Expand All @@ -180,32 +171,25 @@ class EventBridgeEnableRuleOperator(BaseOperator):
:param name: the name of the rule to enable
:param event_bus_name: the name or ARN of the event bus associated with the rule (default if omitted)
:param aws_conn_id: the AWS connection to use
:param region_name: the region of the rule to be enabled
:param aws_conn_id: The Airflow connection used for AWS credentials.
If this is ``None`` or empty then the default boto3 behaviour is used. If
running Airflow in a distributed manner and aws_conn_id is None or
empty, then default boto3 configuration would be used (and must be
maintained on each worker node).
:param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
:param verify: Whether or not to verify SSL certificates. See:
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
:param botocore_config: Configuration dictionary (key-values) for botocore client. See:
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.htmlt
"""

template_fields: Sequence[str] = ("name", "event_bus_name", "region_name", "aws_conn_id")
aws_hook_class = EventBridgeHook
template_fields: Sequence[str] = aws_template_fields("name", "event_bus_name")

def __init__(
self,
*,
name: str,
event_bus_name: str | None = None,
region_name: str | None = None,
aws_conn_id: str = "aws_default",
**kwargs,
):
def __init__(self, *, name: str, event_bus_name: str | None = None, **kwargs):
super().__init__(**kwargs)
self.name = name
self.event_bus_name = event_bus_name
self.region_name = region_name
self.aws_conn_id = aws_conn_id

@cached_property
def hook(self) -> EventBridgeHook:
"""Create and return an EventBridgeHook."""
return EventBridgeHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)

def execute(self, context: Context):
self.hook.conn.enable_rule(
Expand All @@ -220,7 +204,7 @@ def execute(self, context: Context):
self.log.info('Enabled rule "%s"', self.name)


class EventBridgeDisableRuleOperator(BaseOperator):
class EventBridgeDisableRuleOperator(AwsBaseOperator[EventBridgeHook]):
"""
Disable an EventBridge Rule.
Expand All @@ -230,32 +214,25 @@ class EventBridgeDisableRuleOperator(BaseOperator):
:param name: the name of the rule to disable
:param event_bus_name: the name or ARN of the event bus associated with the rule (default if omitted)
:param aws_conn_id: the AWS connection to use
:param region_name: the region of the rule to be disabled
:param aws_conn_id: The Airflow connection used for AWS credentials.
If this is ``None`` or empty then the default boto3 behaviour is used. If
running Airflow in a distributed manner and aws_conn_id is None or
empty, then default boto3 configuration would be used (and must be
maintained on each worker node).
:param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
:param verify: Whether or not to verify SSL certificates. See:
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
:param botocore_config: Configuration dictionary (key-values) for botocore client. See:
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.htmlt
"""

template_fields: Sequence[str] = ("name", "event_bus_name", "region_name", "aws_conn_id")
aws_hook_class = EventBridgeHook
template_fields: Sequence[str] = aws_template_fields("name", "event_bus_name")

def __init__(
self,
*,
name: str,
event_bus_name: str | None = None,
region_name: str | None = None,
aws_conn_id: str = "aws_default",
**kwargs,
):
def __init__(self, *, name: str, event_bus_name: str | None = None, **kwargs):
super().__init__(**kwargs)
self.name = name
self.event_bus_name = event_bus_name
self.region_name = region_name
self.aws_conn_id = aws_conn_id

@cached_property
def hook(self) -> EventBridgeHook:
"""Create and return an EventBridgeHook."""
return EventBridgeHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)

def execute(self, context: Context):
self.hook.conn.disable_rule(
Expand Down
Expand Up @@ -31,6 +31,11 @@ Prerequisite Tasks

.. include:: ../_partials/prerequisite_tasks.rst

Generic Parameters
------------------

.. include:: ../_partials/generic_parameters.rst

Operators
---------

Expand Down
92 changes: 80 additions & 12 deletions tests/providers/amazon/aws/operators/test_eventbridge.py
Expand Up @@ -41,12 +41,28 @@

class TestEventBridgePutEventsOperator:
def test_init(self):
operator = EventBridgePutEventsOperator(
op = EventBridgePutEventsOperator(
task_id="put_events_job",
entries=ENTRIES,
aws_conn_id="fake-conn-id",
region_name="eu-central-1",
verify="/spam/egg.pem",
botocore_config={"read_timeout": 42},
)

assert operator.entries == ENTRIES
assert op.entries == ENTRIES
assert op.hook.client_type == "events"
assert op.hook.resource_type is None
assert op.hook.aws_conn_id == "fake-conn-id"
assert op.hook._region_name == "eu-central-1"
assert op.hook._verify == "/spam/egg.pem"
assert op.hook._config is not None
assert op.hook._config.read_timeout == 42

op = EventBridgePutEventsOperator(task_id="put_events_job", entries=ENTRIES)
assert op.hook.aws_conn_id == "aws_default"
assert op.hook._region_name is None
assert op.hook._verify is None
assert op.hook._config is None

@mock.patch.object(EventBridgeHook, "conn")
def test_execute(self, mock_conn: MagicMock):
Expand Down Expand Up @@ -83,11 +99,31 @@ def test_failed_to_send(self, mock_conn: MagicMock):

class TestEventBridgePutRuleOperator:
def test_init(self):
operator = EventBridgePutRuleOperator(
op = EventBridgePutRuleOperator(
task_id="events_put_rule_job",
name=RULE_NAME,
event_pattern=EVENT_PATTERN,
aws_conn_id="fake-conn-id",
region_name="eu-west-1",
verify="/spam/egg.pem",
botocore_config={"read_timeout": 42},
)
assert op.event_pattern == EVENT_PATTERN
assert op.hook.client_type == "events"
assert op.hook.resource_type is None
assert op.hook.aws_conn_id == "fake-conn-id"
assert op.hook._region_name == "eu-west-1"
assert op.hook._verify == "/spam/egg.pem"
assert op.hook._config is not None
assert op.hook._config.read_timeout == 42

op = EventBridgePutRuleOperator(
task_id="events_put_rule_job", name=RULE_NAME, event_pattern=EVENT_PATTERN
)

assert operator.event_pattern == EVENT_PATTERN
assert op.hook.aws_conn_id == "aws_default"
assert op.hook._region_name is None
assert op.hook._verify is None
assert op.hook._config is None

@mock.patch.object(EventBridgeHook, "conn")
def test_execute(self, mock_conn: MagicMock):
Expand Down Expand Up @@ -117,12 +153,28 @@ def test_put_rule_with_bad_json_fails(self):

class TestEventBridgeEnableRuleOperator:
def test_init(self):
operator = EventBridgeDisableRuleOperator(
op = EventBridgeEnableRuleOperator(
task_id="enable_rule_task",
name=RULE_NAME,
aws_conn_id="fake-conn-id",
region_name="us-west-1",
verify=False,
botocore_config={"read_timeout": 42},
)

assert operator.name == RULE_NAME
assert op.name == RULE_NAME
assert op.hook.client_type == "events"
assert op.hook.resource_type is None
assert op.hook.aws_conn_id == "fake-conn-id"
assert op.hook._region_name == "us-west-1"
assert op.hook._verify is False
assert op.hook._config is not None
assert op.hook._config.read_timeout == 42

op = EventBridgeEnableRuleOperator(task_id="enable_rule_task", name=RULE_NAME)
assert op.hook.aws_conn_id == "aws_default"
assert op.hook._region_name is None
assert op.hook._verify is None
assert op.hook._config is None

@mock.patch.object(EventBridgeHook, "conn")
def test_enable_rule(self, mock_conn: MagicMock):
Expand All @@ -137,12 +189,28 @@ def test_enable_rule(self, mock_conn: MagicMock):

class TestEventBridgeDisableRuleOperator:
def test_init(self):
operator = EventBridgeDisableRuleOperator(
op = EventBridgeDisableRuleOperator(
task_id="disable_rule_task",
name=RULE_NAME,
aws_conn_id="fake-conn-id",
region_name="ca-west-1",
verify=True,
botocore_config={"read_timeout": 42},
)

assert operator.name == RULE_NAME
assert op.name == RULE_NAME
assert op.hook.client_type == "events"
assert op.hook.resource_type is None
assert op.hook.aws_conn_id == "fake-conn-id"
assert op.hook._region_name == "ca-west-1"
assert op.hook._verify is True
assert op.hook._config is not None
assert op.hook._config.read_timeout == 42

op = EventBridgeDisableRuleOperator(task_id="disable_rule_task", name=RULE_NAME)
assert op.hook.aws_conn_id == "aws_default"
assert op.hook._region_name is None
assert op.hook._verify is None
assert op.hook._config is None

@mock.patch.object(EventBridgeHook, "conn")
def test_disable_rule(self, mock_conn: MagicMock):
Expand Down

0 comments on commit 1e0a99c

Please sign in to comment.