diff --git a/airflow/providers/amazon/aws/sensors/opensearch_serverless.py b/airflow/providers/amazon/aws/sensors/opensearch_serverless.py index 4a434865489137..8943113156553a 100644 --- a/airflow/providers/amazon/aws/sensors/opensearch_serverless.py +++ b/airflow/providers/amazon/aws/sensors/opensearch_serverless.py @@ -18,9 +18,13 @@ from typing import TYPE_CHECKING, Any, Sequence +from airflow.configuration import conf from airflow.exceptions import AirflowException, AirflowSkipException from airflow.providers.amazon.aws.hooks.opensearch_serverless import OpenSearchServerlessHook from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor +from airflow.providers.amazon.aws.triggers.opensearch_serverless import ( + OpenSearchServerlessCollectionActiveTrigger, +) from airflow.providers.amazon.aws.utils.mixins import aws_template_fields from airflow.utils.helpers import exactly_one @@ -42,8 +46,8 @@ class OpenSearchServerlessCollectionActiveSensor(AwsBaseSensor[OpenSearchServerl :param deferrable: If True, the sensor will operate in deferrable more. This mode requires aiobotocore module to be installed. (default: False, but can be overridden in config file by setting default_deferrable to True) - :param poke_interval: Polling period in seconds to check for the status of the job. (default: 60) - :param max_retries: Number of times before returning the current state (default: 20) + :param poke_interval: Polling period in seconds to check for the status of the job. (default: 10) + :param max_retries: Number of times before returning the current state (default: 60) :param aws_conn_id: The Airflow connection used for AWS credentials. If this is ``None`` or empty then the default boto3 behaviour is used. 
If running Airflow in a distributed manner and aws_conn_id is None or @@ -63,7 +67,6 @@ class OpenSearchServerlessCollectionActiveSensor(AwsBaseSensor[OpenSearchServerl ) SUCCESS_STATES = ("ACTIVE",) FAILURE_MESSAGE = "OpenSearch Serverless Collection sensor failed" - INVALID_ARGS_MESSAGE = "Either collection_ids or collection_names must be provided, not both." aws_hook_class = OpenSearchServerlessHook template_fields: Sequence[str] = aws_template_fields( @@ -78,12 +81,13 @@ def __init__( collection_id: str | None = None, collection_name: str | None = None, poke_interval: int = 10, - max_retries: int = 120, + max_retries: int = 60, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), **kwargs: Any, ) -> None: super().__init__(**kwargs) if not exactly_one(collection_id is None, collection_name is None): - raise AttributeError(self.INVALID_ARGS_MESSAGE) + raise AttributeError("Either collection_ids or collection_names must be provided, not both.") self.collection_id = collection_id self.collection_name = collection_name @@ -94,10 +98,10 @@ def __init__( ) self.poke_interval = poke_interval self.max_retries = max_retries + self.deferrable = deferrable - def poke(self, context: Context) -> bool: - collections = self.hook.conn.batch_get_collection(**self.call_args) - state = collections["collectionDetails"][0]["status"] + def poke(self, context: Context, **kwargs) -> bool: + state = self.hook.conn.batch_get_collection(**self.call_args)["collectionDetails"][0]["status"] if state in self.FAILURE_STATES: # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1 @@ -108,3 +112,18 @@ def poke(self, context: Context) -> bool: if state in self.INTERMEDIATE_STATES: return False return True + + def execute(self, context: Context) -> Any: + if self.deferrable: + self.defer( + trigger=OpenSearchServerlessCollectionActiveTrigger( + collection_id=self.collection_id, + collection_name=self.collection_name, + 
waiter_delay=int(self.poke_interval), + waiter_max_attempts=self.max_retries, + aws_conn_id=self.aws_conn_id, + ), + method_name="poke", + ) + else: + super().execute(context=context) diff --git a/airflow/providers/amazon/aws/triggers/opensearch_serverless.py b/airflow/providers/amazon/aws/triggers/opensearch_serverless.py index dc37c916de723a..35547a49acc7de 100644 --- a/airflow/providers/amazon/aws/triggers/opensearch_serverless.py +++ b/airflow/providers/amazon/aws/triggers/opensearch_serverless.py @@ -19,11 +19,8 @@ from typing import TYPE_CHECKING from airflow.providers.amazon.aws.hooks.opensearch_serverless import OpenSearchServerlessHook -from airflow.providers.amazon.aws.sensors.opensearch_serverless import ( - OpenSearchServerlessCollectionActiveSensor, -) from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger -from airflow.utils.helpers import exactly_one, prune_dict +from airflow.utils.helpers import exactly_one if TYPE_CHECKING: from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook @@ -51,19 +48,17 @@ def __init__( aws_conn_id: str | None = None, ) -> None: if not exactly_one(collection_id is None, collection_name is None): - raise AttributeError(OpenSearchServerlessCollectionActiveSensor.INVALID_ARGS_MESSAGE) - - call_args = prune_dict({"ids": [collection_id], "names": [collection_name]}) + raise AttributeError("Either collection_ids or collection_names must be provided, not both.") super().__init__( serialized_fields={"collection_id": collection_id, "collection_name": collection_name}, waiter_name="collection_available", - waiter_args=call_args, + waiter_args={"ids": [collection_id]} if collection_id else {"names": [collection_name]}, failure_message="OpenSearch Serverless Collection creation failed.", status_message="Status of OpenSearch Serverless Collection is", status_queries=["status"], - return_key="collection_id" if "ids" in call_args.keys() else "collection_name", - return_value=collection_id if "ids" in 
call_args.keys() else collection_name, + return_key="collection_id" if collection_id else "collection_name", + return_value=collection_id if collection_id else collection_name, waiter_delay=waiter_delay, waiter_max_attempts=waiter_max_attempts, aws_conn_id=aws_conn_id,