Introduce a base class for aws triggers (#32274)
vandonr-amz committed Jul 7, 2023
1 parent 6c854dc commit 05f1acf
Showing 32 changed files with 1,121 additions and 2,669 deletions.
2 changes: 1 addition & 1 deletion airflow/providers/amazon/aws/hooks/athena.py
@@ -253,7 +253,7 @@ def poll_query_status(
         try:
             wait(
                 waiter=self.get_waiter("query_complete"),
-                waiter_delay=sleep_time or self.sleep_time,
+                waiter_delay=self.sleep_time if sleep_time is None else sleep_time,
                 waiter_max_attempts=max_polling_attempts or 120,
                 args={"QueryExecutionId": query_execution_id},
                 failure_message=f"Error while waiting for query {query_execution_id} to complete",
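This hunk replaces a truthiness fallback with an explicit None check: `sleep_time or self.sleep_time` silently discards a caller's intentional `sleep_time=0`, because `0` is falsy in Python. A minimal sketch of the difference (values are illustrative):

    default = 30    # stand-in for self.sleep_time
    sleep_time = 0  # caller explicitly asks for no delay

    print(sleep_time or default)                          # 30 -- the explicit 0 is lost
    print(default if sleep_time is None else sleep_time)  # 0  -- only None falls back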
8 changes: 4 additions & 4 deletions airflow/providers/amazon/aws/operators/batch.py
@@ -41,7 +41,7 @@
 from airflow.providers.amazon.aws.links.logs import CloudWatchEventsLink
 from airflow.providers.amazon.aws.triggers.batch import (
     BatchCreateComputeEnvironmentTrigger,
-    BatchOperatorTrigger,
+    BatchJobTrigger,
 )
 from airflow.providers.amazon.aws.utils import trim_none_values
 from airflow.providers.amazon.aws.utils.task_log_fetcher import AwsTaskLogFetcher
@@ -221,12 +221,12 @@ def execute(self, context: Context):
         if self.deferrable:
             self.defer(
                 timeout=self.execution_timeout,
-                trigger=BatchOperatorTrigger(
+                trigger=BatchJobTrigger(
                     job_id=self.job_id,
-                    max_retries=self.max_retries or 10,
+                    waiter_max_attempts=self.max_retries or 10,
                     aws_conn_id=self.aws_conn_id,
                     region_name=self.region_name,
-                    poll_interval=self.poll_interval,
+                    waiter_delay=self.poll_interval,
                 ),
                 method_name="execute_complete",
             )
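The renames here (`max_retries` to `waiter_max_attempts`, `poll_interval` to `waiter_delay`) apply only to the keyword arguments passed into the trigger; the operator's own parameters keep their names. A hedged usage sketch (job, queue, and definition names are placeholders), showing that DAG-level code should be unaffected:

    from airflow.providers.amazon.aws.operators.batch import BatchOperator

    submit_job = BatchOperator(
        task_id="submit_batch_job",
        job_name="example-job",            # placeholder
        job_queue="example-queue",         # placeholder
        job_definition="example-job-def",  # placeholder
        deferrable=True,   # defers to BatchJobTrigger, per the hunk above
        poll_interval=30,  # forwarded to the trigger as waiter_delay
        max_retries=8,     # forwarded to the trigger as waiter_max_attempts
    )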
16 changes: 9 additions & 7 deletions airflow/providers/amazon/aws/operators/ecs.py
@@ -33,7 +33,11 @@
 from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
 from airflow.providers.amazon.aws.hooks.ecs import EcsClusterStates, EcsHook, should_retry_eni
 from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook
-from airflow.providers.amazon.aws.triggers.ecs import ClusterWaiterTrigger, TaskDoneTrigger
+from airflow.providers.amazon.aws.triggers.ecs import (
+    ClusterActiveTrigger,
+    ClusterInactiveTrigger,
+    TaskDoneTrigger,
+)
 from airflow.providers.amazon.aws.utils.task_log_fetcher import AwsTaskLogFetcher
 from airflow.utils.helpers import prune_dict
 from airflow.utils.session import provide_session
@@ -139,13 +143,12 @@ def execute(self, context: Context):
             self.log.info("Cluster %r in state: %r.", self.cluster_name, cluster_state)
         elif self.deferrable:
             self.defer(
-                trigger=ClusterWaiterTrigger(
-                    waiter_name="cluster_active",
+                trigger=ClusterActiveTrigger(
                     cluster_arn=cluster_details["clusterArn"],
                     waiter_delay=self.waiter_delay,
                     waiter_max_attempts=self.waiter_max_attempts,
                     aws_conn_id=self.aws_conn_id,
-                    region=self.region,
+                    region_name=self.region,
                 ),
                 method_name="_complete_exec_with_cluster_desc",
                 # timeout is set to ensure that if a trigger dies, the timeout does not restart
@@ -217,13 +220,12 @@ def execute(self, context: Context):
             self.log.info("Cluster %r in state: %r.", self.cluster_name, cluster_state)
         elif self.deferrable:
             self.defer(
-                trigger=ClusterWaiterTrigger(
-                    waiter_name="cluster_inactive",
+                trigger=ClusterInactiveTrigger(
                     cluster_arn=cluster_details["clusterArn"],
                     waiter_delay=self.waiter_delay,
                     waiter_max_attempts=self.waiter_max_attempts,
                     aws_conn_id=self.aws_conn_id,
-                    region=self.region,
+                    region_name=self.region,
                 ),
                 method_name="_complete_exec_with_cluster_desc",
                 # timeout is set to ensure that if a trigger dies, the timeout does not restart
13 changes: 6 additions & 7 deletions airflow/providers/amazon/aws/operators/eks.py
@@ -31,8 +31,9 @@
 from airflow.providers.amazon.aws.hooks.eks import EksHook
 from airflow.providers.amazon.aws.triggers.eks import (
     EksCreateFargateProfileTrigger,
+    EksCreateNodegroupTrigger,
     EksDeleteFargateProfileTrigger,
-    EksNodegroupTrigger,
+    EksDeleteNodegroupTrigger,
 )
 from airflow.providers.amazon.aws.utils.waiter_with_logging import wait
 from airflow.providers.cncf.kubernetes.utils.pod_manager import OnFinishAction
@@ -413,12 +414,11 @@ def execute(self, context: Context):

         if self.deferrable:
             self.defer(
-                trigger=EksNodegroupTrigger(
-                    waiter_name="nodegroup_active",
+                trigger=EksCreateNodegroupTrigger(
                     cluster_name=self.cluster_name,
                     nodegroup_name=self.nodegroup_name,
                     aws_conn_id=self.aws_conn_id,
-                    region=self.region,
+                    region_name=self.region,
                     waiter_delay=self.waiter_delay,
                     waiter_max_attempts=self.waiter_max_attempts,
                 ),
@@ -711,12 +711,11 @@ def execute(self, context: Context):
         eks_hook.delete_nodegroup(clusterName=self.cluster_name, nodegroupName=self.nodegroup_name)
         if self.deferrable:
             self.defer(
-                trigger=EksNodegroupTrigger(
-                    waiter_name="nodegroup_deleted",
+                trigger=EksDeleteNodegroupTrigger(
                     cluster_name=self.cluster_name,
                     nodegroup_name=self.nodegroup_name,
                     aws_conn_id=self.aws_conn_id,
-                    region=self.region,
+                    region_name=self.region,
                     waiter_delay=self.waiter_delay,
                     waiter_max_attempts=self.waiter_max_attempts,
                 ),
6 changes: 3 additions & 3 deletions airflow/providers/amazon/aws/operators/emr.py
@@ -577,7 +577,7 @@ def execute(self, context: Context) -> str | None:
                     virtual_cluster_id=self.virtual_cluster_id,
                     job_id=self.job_id,
                     aws_conn_id=self.aws_conn_id,
-                    poll_interval=self.poll_interval,
+                    waiter_delay=self.poll_interval,
                 ),
                 method_name="execute_complete",
             )
@@ -943,8 +943,8 @@ def execute(self, context: Context) -> None:
             self.defer(
                 trigger=EmrTerminateJobFlowTrigger(
                     job_flow_id=self.job_flow_id,
-                    poll_interval=self.waiter_delay,
-                    max_attempts=self.waiter_max_attempts,
+                    waiter_delay=self.waiter_delay,
+                    waiter_max_attempts=self.waiter_max_attempts,
                     aws_conn_id=self.aws_conn_id,
                 ),
                 method_name="execute_complete",
2 changes: 1 addition & 1 deletion airflow/providers/amazon/aws/operators/glue_crawler.py
@@ -96,7 +96,7 @@ def execute(self, context: Context):
             self.defer(
                 trigger=GlueCrawlerCompleteTrigger(
                     crawler_name=crawler_name,
-                    poll_interval=self.poll_interval,
+                    waiter_delay=self.poll_interval,
                     aws_conn_id=self.aws_conn_id,
                 ),
                 method_name="execute_complete",
20 changes: 10 additions & 10 deletions airflow/providers/amazon/aws/operators/redshift_cluster.py
@@ -267,8 +267,8 @@ def execute(self, context: Context):
             self.defer(
                 trigger=RedshiftCreateClusterTrigger(
                     cluster_identifier=self.cluster_identifier,
-                    poll_interval=self.poll_interval,
-                    max_attempt=self.max_attempt,
+                    waiter_delay=self.poll_interval,
+                    waiter_max_attempts=self.max_attempt,
                     aws_conn_id=self.aws_conn_id,
                 ),
                 method_name="execute_complete",
@@ -361,8 +361,8 @@ def execute(self, context: Context) -> Any:
             self.defer(
                 trigger=RedshiftCreateClusterSnapshotTrigger(
                     cluster_identifier=self.cluster_identifier,
-                    poll_interval=self.poll_interval,
-                    max_attempts=self.max_attempt,
+                    waiter_delay=self.poll_interval,
+                    waiter_max_attempts=self.max_attempt,
                     aws_conn_id=self.aws_conn_id,
                 ),
                 method_name="execute_complete",
@@ -510,8 +510,8 @@ def execute(self, context: Context):
             self.defer(
                 trigger=RedshiftResumeClusterTrigger(
                     cluster_identifier=self.cluster_identifier,
-                    poll_interval=self.poll_interval,
-                    max_attempts=self.max_attempts,
+                    waiter_delay=self.poll_interval,
+                    waiter_max_attempts=self.max_attempts,
                     aws_conn_id=self.aws_conn_id,
                 ),
                 method_name="execute_complete",
@@ -598,8 +598,8 @@ def execute(self, context: Context):
             self.defer(
                 trigger=RedshiftPauseClusterTrigger(
                     cluster_identifier=self.cluster_identifier,
-                    poll_interval=self.poll_interval,
-                    max_attempts=self.max_attempts,
+                    waiter_delay=self.poll_interval,
+                    waiter_max_attempts=self.max_attempts,
                     aws_conn_id=self.aws_conn_id,
                 ),
                 method_name="execute_complete",
@@ -690,8 +690,8 @@ def execute(self, context: Context):
                 timeout=timedelta(seconds=self.max_attempts * self.poll_interval + 60),
                 trigger=RedshiftDeleteClusterTrigger(
                     cluster_identifier=self.cluster_identifier,
-                    poll_interval=self.poll_interval,
-                    max_attempts=self.max_attempts,
+                    waiter_delay=self.poll_interval,
+                    waiter_max_attempts=self.max_attempts,
                     aws_conn_id=self.aws_conn_id,
                 ),
                 method_name="execute_complete",
14 changes: 8 additions & 6 deletions airflow/providers/amazon/aws/sensors/batch.py
@@ -25,7 +25,7 @@
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.batch_client import BatchClientHook
-from airflow.providers.amazon.aws.triggers.batch import BatchSensorTrigger
+from airflow.providers.amazon.aws.triggers.batch import BatchJobTrigger
 from airflow.sensors.base import BaseSensorOperator

 if TYPE_CHECKING:
@@ -98,11 +98,12 @@ def execute(self, context: Context) -> None:
             )
             self.defer(
                 timeout=timeout,
-                trigger=BatchSensorTrigger(
+                trigger=BatchJobTrigger(
                     job_id=self.job_id,
                     aws_conn_id=self.aws_conn_id,
                     region_name=self.region_name,
-                    poke_interval=self.poke_interval,
+                    waiter_delay=int(self.poke_interval),
+                    waiter_max_attempts=self.max_retries,
                 ),
                 method_name="execute_complete",
             )
@@ -113,9 +114,10 @@ def execute_complete(self, context: Context, event: dict[str, Any]) -> None:

         Relies on trigger to throw an exception, otherwise it assumes execution was successful.
         """
-        if "status" in event and event["status"] == "failure":
-            raise AirflowException(event["message"])
-        self.log.info(event["message"])
+        if event["status"] != "success":
+            raise AirflowException(f"Error while running job: {event}")
+        job_id = event["job_id"]
+        self.log.info("Batch Job %s complete", job_id)

     @deprecated(reason="use `hook` property instead.")
     def get_hook(self) -> BatchClientHook:
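The rewritten `execute_complete` implies a new event contract: the shared `BatchJobTrigger` reports the job id under a `job_id` key instead of a human-readable `message`. Roughly (shapes inferred from this hunk, not from the trigger source, which is not part of this excerpt; values are illustrative):

    # Success: logged as "Batch Job %s complete".
    success_event = {"status": "success", "job_id": "5a8b3c1d-example"}
    # Anything whose status is not "success" now raises AirflowException.
    failure_event = {"status": "failure", "message": "Job failed"}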
10 changes: 5 additions & 5 deletions airflow/providers/amazon/aws/sensors/emr.py
@@ -316,7 +316,7 @@ def execute(self, context: Context):
                     virtual_cluster_id=self.virtual_cluster_id,
                     job_id=self.job_id,
                     aws_conn_id=self.aws_conn_id,
-                    poll_interval=self.poll_interval,
+                    waiter_delay=self.poll_interval,
                 ),
                 method_name="execute_complete",
             )
@@ -501,9 +501,9 @@ def execute(self, context: Context) -> None:
                 timeout=timedelta(seconds=self.poke_interval * self.max_attempts),
                 trigger=EmrTerminateJobFlowTrigger(
                     job_flow_id=self.job_flow_id,
-                    max_attempts=self.max_attempts,
+                    waiter_max_attempts=self.max_attempts,
                     aws_conn_id=self.aws_conn_id,
-                    poll_interval=int(self.poke_interval),
+                    waiter_delay=int(self.poke_interval),
                 ),
                 method_name="execute_complete",
             )
@@ -628,9 +628,9 @@ def execute(self, context: Context) -> None:
                 trigger=EmrStepSensorTrigger(
                     job_flow_id=self.job_flow_id,
                     step_id=self.step_id,
+                    waiter_delay=int(self.poke_interval),
+                    waiter_max_attempts=self.max_attempts,
                     aws_conn_id=self.aws_conn_id,
-                    max_attempts=self.max_attempts,
-                    poke_interval=int(self.poke_interval),
                 ),
                 method_name="execute_complete",
             )
58 changes: 20 additions & 38 deletions airflow/providers/amazon/aws/triggers/athena.py
@@ -16,61 +16,43 @@
 # under the License.
 from __future__ import annotations

-from typing import Any
-
 from airflow.providers.amazon.aws.hooks.athena import AthenaHook
-from airflow.providers.amazon.aws.utils.waiter_with_logging import async_wait
-from airflow.triggers.base import BaseTrigger, TriggerEvent
+from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
+from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger


-class AthenaTrigger(BaseTrigger):
+class AthenaTrigger(AwsBaseWaiterTrigger):
     """
     Trigger for RedshiftCreateClusterOperator.

     The trigger will asynchronously poll the boto3 API and wait for the
     Redshift cluster to be in the `available` state.

     :param query_execution_id: ID of the Athena query execution to watch
-    :param poll_interval: The amount of time in seconds to wait between attempts.
-    :param max_attempt: The maximum number of attempts to be made.
+    :param waiter_delay: The amount of time in seconds to wait between attempts.
+    :param waiter_max_attempts: The maximum number of attempts to be made.
     :param aws_conn_id: The Airflow connection used for AWS credentials.
     """

     def __init__(
         self,
         query_execution_id: str,
-        poll_interval: int,
-        max_attempt: int,
+        waiter_delay: int,
+        waiter_max_attempts: int,
         aws_conn_id: str,
     ):
-        self.query_execution_id = query_execution_id
-        self.poll_interval = poll_interval
-        self.max_attempt = max_attempt
-        self.aws_conn_id = aws_conn_id
-
-    def serialize(self) -> tuple[str, dict[str, Any]]:
-        return (
-            self.__class__.__module__ + "." + self.__class__.__qualname__,
-            {
-                "query_execution_id": str(self.query_execution_id),
-                "poll_interval": str(self.poll_interval),
-                "max_attempt": str(self.max_attempt),
-                "aws_conn_id": str(self.aws_conn_id),
-            },
+        super().__init__(
+            serialized_fields={"query_execution_id": query_execution_id},
+            waiter_name="query_complete",
+            waiter_args={"QueryExecutionId": query_execution_id},
+            failure_message=f"Error while waiting for query {query_execution_id} to complete",
+            status_message=f"Query execution id: {query_execution_id}",
+            status_queries=["QueryExecution.Status"],
+            return_value=query_execution_id,
+            waiter_delay=waiter_delay,
+            waiter_max_attempts=waiter_max_attempts,
+            aws_conn_id=aws_conn_id,
         )

-    async def run(self):
-        hook = AthenaHook(self.aws_conn_id)
-        async with hook.async_conn as client:
-            waiter = hook.get_waiter("query_complete", deferrable=True, client=client)
-            await async_wait(
-                waiter=waiter,
-                waiter_delay=self.poll_interval,
-                waiter_max_attempts=self.max_attempt,
-                args={"QueryExecutionId": self.query_execution_id},
-                failure_message=f"Error while waiting for query {self.query_execution_id} to complete",
-                status_message=f"Query execution id: {self.query_execution_id}, "
-                "Query is still in non-terminal state",
-                status_args=["QueryExecution.Status.State"],
-            )
-        yield TriggerEvent({"status": "success", "value": self.query_execution_id})
+    def hook(self) -> AwsGenericHook:
+        return AthenaHook(self.aws_conn_id)
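This file shows the new pattern end to end: the subclass passes its waiter configuration to `AwsBaseWaiterTrigger` (defined in `airflow/providers/amazon/aws/triggers/base.py`, among the 32 changed files but not shown in this excerpt) and keeps only a `hook()` factory. Judging by the fields passed to `super().__init__` here and the `run()` body deleted above, a minimal sketch of what the base class centralizes could look like this (internal details, including the `return_key` default, are assumptions):

    from __future__ import annotations

    from abc import abstractmethod
    from typing import Any

    from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
    from airflow.providers.amazon.aws.utils.waiter_with_logging import async_wait
    from airflow.triggers.base import BaseTrigger, TriggerEvent


    class AwsBaseWaiterTrigger(BaseTrigger):
        """Sketch: hold the common waiter knobs, poll asynchronously, emit one event."""

        def __init__(
            self,
            *,
            serialized_fields: dict[str, Any],
            waiter_name: str,
            waiter_args: dict[str, Any],
            failure_message: str,
            status_message: str,
            status_queries: list[str],
            return_key: str = "value",  # assumption: triggers like BatchJobTrigger override this (e.g. "job_id")
            return_value: Any,
            waiter_delay: int,
            waiter_max_attempts: int,
            aws_conn_id: str,
        ) -> None:
            # Subclass-specific constructor args, kept so the trigger can be rebuilt after deserialization.
            self.serialized_fields = serialized_fields
            self.waiter_name = waiter_name
            self.waiter_args = waiter_args
            self.failure_message = failure_message
            self.status_message = status_message
            self.status_queries = status_queries
            self.return_key = return_key
            self.return_value = return_value
            self.waiter_delay = waiter_delay
            self.waiter_max_attempts = waiter_max_attempts
            self.aws_conn_id = aws_conn_id

        def serialize(self) -> tuple[str, dict[str, Any]]:
            # One serializer for every trigger: subclass fields plus the shared waiter knobs.
            params = dict(self.serialized_fields)
            params.update(
                {
                    "waiter_delay": self.waiter_delay,
                    "waiter_max_attempts": self.waiter_max_attempts,
                    "aws_conn_id": self.aws_conn_id,
                }
            )
            return self.__class__.__module__ + "." + self.__class__.__qualname__, params

        @abstractmethod
        def hook(self) -> AwsGenericHook:
            """Subclasses return the hook that owns the waiter (e.g. AthenaHook above)."""

        async def run(self):
            # The same polling loop the old AthenaTrigger.run() inlined, now shared by all subclasses.
            hook = self.hook()
            async with hook.async_conn as client:
                waiter = hook.get_waiter(self.waiter_name, deferrable=True, client=client)
                await async_wait(
                    waiter=waiter,
                    waiter_delay=self.waiter_delay,
                    waiter_max_attempts=self.waiter_max_attempts,
                    args=self.waiter_args,
                    failure_message=self.failure_message,
                    status_message=self.status_message,
                    status_args=self.status_queries,
                )
            yield TriggerEvent({"status": "success", self.return_key: self.return_value})

With this in place, each subclass shrinks to a constructor and a one-line hook factory, and the per-trigger serialize()/run() boilerplate deleted in this commit disappears.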
