Skip to content

Commit

Permalink
Extend hooks arguments into AwsBaseWaiterTrigger (#34884)
Browse files Browse the repository at this point in the history
* Extend hooks arguments into `AwsBaseWaiterTrigger`

* Use prune dictionary AwsBaseWaiterTrigger

---------
Co-authored-by: Vincent Beck <vincbeck@amazon.com>

* Add links to boto3 documentation in docstring

* Add super() into the AwsBaseWaiterTrigger
  • Loading branch information
Taragolis committed Oct 12, 2023
1 parent 8e06897 commit 545e4d5
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 4 deletions.
27 changes: 24 additions & 3 deletions airflow/providers/amazon/aws/triggers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

from airflow.providers.amazon.aws.utils.waiter_with_logging import async_wait
from airflow.triggers.base import BaseTrigger, TriggerEvent
from airflow.utils.helpers import prune_dict

if TYPE_CHECKING:
from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
Expand Down Expand Up @@ -55,6 +56,11 @@ class AwsBaseWaiterTrigger(BaseTrigger):
:param waiter_max_attempts: The maximum number of attempts to be made.
:param aws_conn_id: The Airflow connection used for AWS credentials. To be used to build the hook.
:param region_name: The AWS region where the resources to watch are. To be used to build the hook.
:param verify: Whether or not to verify SSL certificates. To be used to build the hook.
See: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
:param botocore_config: Configuration dictionary (key-values) for botocore client.
To be used to build the hook. For available key-values see:
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
"""

def __init__(
Expand All @@ -72,7 +78,10 @@ def __init__(
waiter_max_attempts: int,
aws_conn_id: str | None,
region_name: str | None = None,
verify: bool | str | None = None,
botocore_config: dict | None = None,
):
super().__init__()
# parameters that should be hardcoded in the child's implem
self.serialized_fields = serialized_fields

Expand All @@ -90,6 +99,8 @@ def __init__(
self.attempts = waiter_max_attempts
self.aws_conn_id = aws_conn_id
self.region_name = region_name
self.verify = verify
self.botocore_config = botocore_config

def serialize(self) -> tuple[str, dict[str, Any]]:
# here we put together the "common" params,
Expand All @@ -102,9 +113,19 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
},
**self.serialized_fields,
)
if self.region_name:
# if we serialize the None value from this, it breaks subclasses that don't have it in their ctor.
params["region_name"] = self.region_name

# if we serialize the None value from this, it breaks subclasses that don't have it in their ctor.
params.update(
prune_dict(
{
# Keep previous behaviour when empty string in region_name evaluated as `None`
"region_name": self.region_name or None,
"verify": self.verify,
"botocore_config": self.botocore_config,
}
)
)

return (
# remember that self is an instance of the subclass here, not of this class.
self.__class__.__module__ + "." + self.__class__.__qualname__,
Expand Down
36 changes: 35 additions & 1 deletion tests/providers/amazon/aws/triggers/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,41 @@ def test_region_serialized(self):
assert "region_name" in args
assert args["region_name"] == "my_region"

def test_region_not_serialized_if_omitted(self):
@pytest.mark.parametrize("verify", [True, False, pytest.param("/foo/bar.pem", id="path")])
def test_verify_serialized(self, verify):
self.trigger.verify = verify
_, args = self.trigger.serialize()

assert "verify" in args
assert args["verify"] == verify

@pytest.mark.parametrize(
"botocore_config",
[
pytest.param({"read_timeout": 10, "connect_timeout": 42, "keepalive": True}, id="non-empty-dict"),
pytest.param({}, id="empty-dict"),
],
)
def test_botocore_config_serialized(self, botocore_config):
self.trigger.botocore_config = botocore_config
_, args = self.trigger.serialize()

assert "botocore_config" in args
assert args["botocore_config"] == botocore_config

@pytest.mark.parametrize("param_name", ["region_name", "verify", "botocore_config"])
def test_hooks_args_not_serialized_if_omitted(self, param_name):
_, args = self.trigger.serialize()

assert param_name not in args

def test_region_name_not_serialized_if_empty_string(self):
"""
Compatibility with previous behaviour when empty string region name not serialised.
It would evaluate as None, however empty string it is not valid region name in boto3.
"""
self.trigger.region_name = ""
_, args = self.trigger.serialize()

assert "region_name" not in args
Expand Down

0 comments on commit 545e4d5

Please sign in to comment.