diff --git a/airflow/providers/amazon/aws/hooks/bedrock.py b/airflow/providers/amazon/aws/hooks/bedrock.py index 96636eb952aa0a..b7c43be67ed850 100644 --- a/airflow/providers/amazon/aws/hooks/bedrock.py +++ b/airflow/providers/amazon/aws/hooks/bedrock.py @@ -57,3 +57,23 @@ class BedrockRuntimeHook(AwsBaseHook): def __init__(self, *args, **kwargs) -> None: kwargs["client_type"] = self.client_type super().__init__(*args, **kwargs) + + +class BedrockAgentHook(AwsBaseHook): + """ + Interact with the Amazon Agents for Bedrock API. + + Provide thin wrapper around :external+boto3:py:class:`boto3.client("bedrock-agent") `. + + Additional arguments (such as ``aws_conn_id``) may be specified and + are passed down to the underlying AwsBaseHook. + + .. seealso:: + - :class:`airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook` + """ + + client_type = "bedrock-agent" + + def __init__(self, *args, **kwargs) -> None: + kwargs["client_type"] = self.client_type + super().__init__(*args, **kwargs) diff --git a/airflow/providers/amazon/aws/hooks/opensearch_serverless.py b/airflow/providers/amazon/aws/hooks/opensearch_serverless.py new file mode 100644 index 00000000000000..f21d60300671e2 --- /dev/null +++ b/airflow/providers/amazon/aws/hooks/opensearch_serverless.py @@ -0,0 +1,39 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook + + +class OpenSearchServerlessHook(AwsBaseHook): + """ + Interact with the Amazon OpenSearch Serverless API. + + Provide thin wrapper around :external+boto3:py:class:`boto3.client("opensearchserverless") `. + + Additional arguments (such as ``aws_conn_id``) may be specified and + are passed down to the underlying AwsBaseHook. + + .. 
seealso:: + - :class:`airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook` + """ + + client_type = "opensearchserverless" + + def __init__(self, *args, **kwargs) -> None: + kwargs["client_type"] = self.client_type + super().__init__(*args, **kwargs) diff --git a/airflow/providers/amazon/aws/operators/bedrock.py b/airflow/providers/amazon/aws/operators/bedrock.py index bcb0840ecfbdfe..807adaf3938cd1 100644 --- a/airflow/providers/amazon/aws/operators/bedrock.py +++ b/airflow/providers/amazon/aws/operators/bedrock.py @@ -17,16 +17,19 @@ from __future__ import annotations import json +from time import sleep from typing import TYPE_CHECKING, Any, Sequence from botocore.exceptions import ClientError from airflow.configuration import conf from airflow.exceptions import AirflowException -from airflow.providers.amazon.aws.hooks.bedrock import BedrockHook, BedrockRuntimeHook +from airflow.providers.amazon.aws.hooks.bedrock import BedrockAgentHook, BedrockHook, BedrockRuntimeHook from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator from airflow.providers.amazon.aws.triggers.bedrock import ( BedrockCustomizeModelCompletedTrigger, + BedrockIngestionJobTrigger, + BedrockKnowledgeBaseActiveTrigger, BedrockProvisionModelThroughputCompletedTrigger, ) from airflow.providers.amazon.aws.utils import validate_execute_complete_event @@ -351,3 +354,313 @@ def execute_complete(self, context: Context, event: dict[str, Any] | None = None self.log.info("Bedrock provisioned throughput job `%s` complete.", event["provisioned_model_id"]) return event["provisioned_model_id"] + + +class BedrockCreateKnowledgeBaseOperator(AwsBaseOperator[BedrockAgentHook]): + """ + Create a knowledge base that contains data sources used by Amazon Bedrock LLMs and Agents. + + To create a knowledge base, you must first set up your data sources and configure a supported vector store. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:BedrockCreateKnowledgeBaseOperator` + + :param name: The name of the knowledge base. (templated) + :param embedding_model_arn: ARN of the model used to create vector embeddings for the knowledge base. (templated) + :param role_arn: The ARN of the IAM role with permissions to create the knowledge base. (templated) + :param storage_config: Configuration details of the vector database used for the knowledge base. (templated) + :param wait_for_indexing: Vector indexing can take some time and there is no apparent way to check the state + before trying to create the Knowledge Base. If this is True and creation fails due to the index not + being available, the operator will wait and retry. (default: True) (templated) + :param indexing_error_retry_delay: Seconds between retries if an index error is encountered. (default: 5) (templated) + :param indexing_error_max_attempts: Maximum number of times to retry when encountering an index error. (default: 20) (templated) + :param create_knowledge_base_kwargs: Any additional optional parameters to pass to the API call. (templated) + + :param wait_for_completion: Whether to wait for the knowledge base to become active. (default: True) + :param waiter_delay: Time in seconds to wait between status checks. (default: 60) + :param waiter_max_attempts: Maximum number of attempts to check for job completion. (default: 20) + :param deferrable: If True, the operator will wait asynchronously for the knowledge base to become active. + This implies waiting for completion. This mode requires aiobotocore module to be installed.
+ (default: False) + :param aws_conn_id: The Airflow connection used for AWS credentials. + If this is ``None`` or empty then the default boto3 behaviour is used. If + running Airflow in a distributed manner and aws_conn_id is None or + empty, then default boto3 configuration would be used (and must be + maintained on each worker node). + :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used. + :param verify: Whether or not to verify SSL certificates. See: + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html + :param botocore_config: Configuration dictionary (key-values) for botocore client. See: + https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html + """ + + aws_hook_class = BedrockAgentHook + template_fields: Sequence[str] = aws_template_fields( + "name", + "embedding_model_arn", + "role_arn", + "storage_config", + "wait_for_indexing", + "indexing_error_retry_delay", + "indexing_error_max_attempts", + "create_knowledge_base_kwargs", + ) + + def __init__( + self, + name: str, + embedding_model_arn: str, + role_arn: str, + storage_config: dict[str, Any], + create_knowledge_base_kwargs: dict[str, Any] | None = None, + wait_for_indexing: bool = True, + indexing_error_retry_delay: int = 5, # seconds + indexing_error_max_attempts: int = 20, + wait_for_completion: bool = True, + waiter_delay: int = 60, + waiter_max_attempts: int = 20, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), + **kwargs, + ): + super().__init__(**kwargs) + self.name = name + self.role_arn = role_arn + self.storage_config = storage_config + self.create_knowledge_base_kwargs = create_knowledge_base_kwargs or {} + self.embedding_model_arn = embedding_model_arn + self.knowledge_base_config = { + "type": "VECTOR", + "vectorKnowledgeBaseConfiguration": {"embeddingModelArn": self.embedding_model_arn}, + } + self.wait_for_indexing = wait_for_indexing + self.indexing_error_retry_delay = indexing_error_retry_delay + self.indexing_error_max_attempts = indexing_error_max_attempts + + self.wait_for_completion = wait_for_completion + self.waiter_delay = waiter_delay + self.waiter_max_attempts = waiter_max_attempts + self.deferrable = deferrable + + def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str: + event = validate_execute_complete_event(event) + + if event["status"] != "success": + raise AirflowException(f"Error while running job: {event}") + + self.log.info("Bedrock knowledge base creation job `%s` complete.", self.name) + return event["knowledge_base_id"] + + def execute(self, context: Context) -> str: + def _create_kb(): + # This API call will return the following if the index has not completed, but there is no apparent + # way to check the state of the index beforehand, so retry on index failure if set to do so. + # botocore.errorfactory.ValidationException: An error occurred (ValidationException) + # when calling the CreateKnowledgeBase operation: The knowledge base storage configuration + # provided is invalid... 
no such index [bedrock-sample-rag-index-abc108] + try: + return self.hook.conn.create_knowledge_base( + name=self.name, + roleArn=self.role_arn, + knowledgeBaseConfiguration=self.knowledge_base_config, + storageConfiguration=self.storage_config, + **self.create_knowledge_base_kwargs, + )["knowledgeBase"]["knowledgeBaseId"] + except ClientError as error: + if all( + [ + error.response["Error"]["Code"] == "ValidationException", + "no such index" in error.response["Error"]["Message"], + self.wait_for_indexing, + self.indexing_error_max_attempts > 0, + ] + ): + self.indexing_error_max_attempts -= 1 + self.log.warning( + "Vector index not ready, retrying in %s seconds.", self.indexing_error_retry_delay + ) + self.log.debug("%s retries remaining.", self.indexing_error_max_attempts) + sleep(self.indexing_error_retry_delay) + return _create_kb() + raise + + self.log.info("Creating Amazon Bedrock Knowledge Base %s", self.name) + knowledge_base_id = _create_kb() + + if self.deferrable: + self.log.info("Deferring for Knowledge base creation.") + self.defer( + trigger=BedrockKnowledgeBaseActiveTrigger( + knowledge_base_id=knowledge_base_id, + waiter_delay=self.waiter_delay, + waiter_max_attempts=self.waiter_max_attempts, + aws_conn_id=self.aws_conn_id, + ), + method_name="execute_complete", + ) + if self.wait_for_completion: + self.log.info("Waiting for Knowledge Base creation.") + self.hook.get_waiter("knowledge_base_active").wait( + knowledgeBaseId=knowledge_base_id, + WaiterConfig={"Delay": self.waiter_delay, "MaxAttempts": self.waiter_max_attempts}, + ) + + return knowledge_base_id + + +class BedrockCreateDataSourceOperator(AwsBaseOperator[BedrockAgentHook]): + """ + Set up an Amazon Bedrock Data Source to be added to an Amazon Bedrock Knowledge Base. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:BedrockCreateDataSourceOperator` + + :param name: name for the Amazon Bedrock Data Source being created. (templated). + :param bucket_name: The name of the Amazon S3 bucket to use for data source storage. (templated) + :param knowledge_base_id: The unique identifier of the knowledge base to which to add the data source. (templated) + :param create_data_source_kwargs: Any additional optional parameters to pass to the API call. (templated) + + :param aws_conn_id: The Airflow connection used for AWS credentials. + If this is ``None`` or empty then the default boto3 behaviour is used. If + running Airflow in a distributed manner and aws_conn_id is None or + empty, then default boto3 configuration would be used (and must be + maintained on each worker node). + :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used. + :param verify: Whether or not to verify SSL certificates. See: + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html + :param botocore_config: Configuration dictionary (key-values) for botocore client. 
See: + https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html + """ + + aws_hook_class = BedrockAgentHook + template_fields: Sequence[str] = aws_template_fields( + "name", + "bucket_name", + "knowledge_base_id", + "create_data_source_kwargs", + ) + + def __init__( + self, + name: str, + knowledge_base_id: str, + bucket_name: str | None = None, + create_data_source_kwargs: dict[str, Any] | None = None, + **kwargs, + ): + super().__init__(**kwargs) + self.name = name + self.knowledge_base_id = knowledge_base_id + self.bucket_name = bucket_name + self.create_data_source_kwargs = create_data_source_kwargs or {} + + def execute(self, context: Context) -> str: + create_ds_response = self.hook.conn.create_data_source( + name=self.name, + knowledgeBaseId=self.knowledge_base_id, + dataSourceConfiguration={ + "type": "S3", + "s3Configuration": {"bucketArn": f"arn:aws:s3:::{self.bucket_name}"}, + }, + **self.create_data_source_kwargs, + ) + + return create_ds_response["dataSource"]["dataSourceId"] + + +class BedrockIngestDataOperator(AwsBaseOperator[BedrockAgentHook]): + """ + Begin an ingestion job, in which an Amazon Bedrock data source is added to an Amazon Bedrock knowledge base. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:BedrockIngestDataOperator` + + :param knowledge_base_id: The unique identifier of the knowledge base to which to add the data source. (templated) + :param data_source_id: The unique identifier of the data source to ingest. (templated) + :param ingest_data_kwargs: Any additional optional parameters to pass to the API call. (templated) + + :param wait_for_completion: Whether to wait for the ingestion job to complete. (default: True) + :param waiter_delay: Time in seconds to wait between status checks. (default: 60) + :param waiter_max_attempts: Maximum number of attempts to check for job completion. (default: 10) + :param deferrable: If True, the operator will wait asynchronously for the ingestion job to complete. + This implies waiting for completion. This mode requires aiobotocore module to be installed. + (default: False) + :param aws_conn_id: The Airflow connection used for AWS credentials. + If this is ``None`` or empty then the default boto3 behaviour is used. If + running Airflow in a distributed manner and aws_conn_id is None or + empty, then default boto3 configuration would be used (and must be + maintained on each worker node). + :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used. + :param verify: Whether or not to verify SSL certificates. See: + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html + :param botocore_config: Configuration dictionary (key-values) for botocore client.
See: + https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html + """ + + aws_hook_class = BedrockAgentHook + template_fields: Sequence[str] = aws_template_fields( + "knowledge_base_id", + "data_source_id", + "ingest_data_kwargs", + ) + + def __init__( + self, + knowledge_base_id: str, + data_source_id: str, + ingest_data_kwargs: dict[str, Any] | None = None, + wait_for_completion: bool = True, + waiter_delay: int = 60, + waiter_max_attempts: int = 10, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), + **kwargs, + ): + super().__init__(**kwargs) + self.knowledge_base_id = knowledge_base_id + self.data_source_id = data_source_id + self.ingest_data_kwargs = ingest_data_kwargs or {} + + self.wait_for_completion = wait_for_completion + self.waiter_delay = waiter_delay + self.waiter_max_attempts = waiter_max_attempts + self.deferrable = deferrable + + def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str: + event = validate_execute_complete_event(event) + + if event["status"] != "success": + raise AirflowException(f"Error while running ingestion job: {event}") + + self.log.info("Bedrock ingestion job `%s` complete.", event["ingestion_job_id"]) + + return event["ingestion_job_id"] + + def execute(self, context: Context) -> str: + ingestion_job_id = self.hook.conn.start_ingestion_job( + knowledgeBaseId=self.knowledge_base_id, dataSourceId=self.data_source_id + )["ingestionJob"]["ingestionJobId"] + + if self.deferrable: + self.log.info("Deferring for ingestion job.") + self.defer( + trigger=BedrockIngestionJobTrigger( + knowledge_base_id=self.knowledge_base_id, + data_source_id=self.data_source_id, + ingestion_job_id=ingestion_job_id, + waiter_delay=self.waiter_delay, + waiter_max_attempts=self.waiter_max_attempts, + aws_conn_id=self.aws_conn_id, + ), + method_name="execute_complete", + ) + if self.wait_for_completion: + self.log.info("Waiting for ingestion job %s", ingestion_job_id) + self.hook.get_waiter(waiter_name="ingestion_job_complete").wait( + knowledgeBaseId=self.knowledge_base_id, + dataSourceId=self.data_source_id, + ingestionJobId=ingestion_job_id, + ) + + return ingestion_job_id diff --git a/airflow/providers/amazon/aws/sensors/bedrock.py b/airflow/providers/amazon/aws/sensors/bedrock.py index 533a8cab52b488..85328865548680 100644 --- a/airflow/providers/amazon/aws/sensors/bedrock.py +++ b/airflow/providers/amazon/aws/sensors/bedrock.py @@ -18,14 +18,16 @@ from __future__ import annotations import abc -from typing import TYPE_CHECKING, Any, Sequence +from typing import TYPE_CHECKING, Any, Sequence, TypeVar from airflow.configuration import conf from airflow.exceptions import AirflowException, AirflowSkipException -from airflow.providers.amazon.aws.hooks.bedrock import BedrockHook +from airflow.providers.amazon.aws.hooks.bedrock import BedrockAgentHook, BedrockHook from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor from airflow.providers.amazon.aws.triggers.bedrock import ( BedrockCustomizeModelCompletedTrigger, + BedrockIngestionJobTrigger, + BedrockKnowledgeBaseActiveTrigger, BedrockProvisionModelThroughputCompletedTrigger, ) from airflow.providers.amazon.aws.utils.mixins import aws_template_fields @@ -34,7 +36,10 @@ from airflow.utils.context import Context -class BedrockBaseSensor(AwsBaseSensor[BedrockHook]): +_GenericBedrockHook = TypeVar("_GenericBedrockHook", BedrockAgentHook, BedrockHook) + + +class BedrockBaseSensor(AwsBaseSensor[_GenericBedrockHook]): """ 
General sensor behavior for Amazon Bedrock. @@ -57,7 +62,7 @@ class BedrockBaseSensor(AwsBaseSensor[BedrockHook]): SUCCESS_STATES: tuple[str, ...] = () FAILURE_MESSAGE = "" - aws_hook_class = BedrockHook + aws_hook_class: type[_GenericBedrockHook] ui_color = "#66c3ff" def __init__( @@ -68,7 +73,7 @@ def __init__( super().__init__(**kwargs) self.deferrable = deferrable - def poke(self, context: Context) -> bool: + def poke(self, context: Context, **kwargs) -> bool: state = self.get_state() if state in self.FAILURE_STATES: # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1 @@ -83,7 +88,7 @@ def get_state(self) -> str: """Implement in subclasses.""" -class BedrockCustomizeModelCompletedSensor(BedrockBaseSensor): +class BedrockCustomizeModelCompletedSensor(BedrockBaseSensor[BedrockHook]): """ Poll the state of the model customization job until it reaches a terminal state; fails if the job fails. @@ -115,6 +120,8 @@ class BedrockCustomizeModelCompletedSensor(BedrockBaseSensor): SUCCESS_STATES: tuple[str, ...] = ("Completed",) FAILURE_MESSAGE = "Bedrock model customization job sensor failed." + aws_hook_class = BedrockHook + template_fields: Sequence[str] = aws_template_fields("job_name") def __init__( @@ -148,7 +155,7 @@ def get_state(self) -> str: return self.hook.conn.get_model_customization_job(jobIdentifier=self.job_name)["status"] -class BedrockProvisionModelThroughputCompletedSensor(BedrockBaseSensor): +class BedrockProvisionModelThroughputCompletedSensor(BedrockBaseSensor[BedrockHook]): """ Poll the provisioned model throughput job until it reaches a terminal state; fails if the job fails. @@ -180,6 +187,8 @@ class BedrockProvisionModelThroughputCompletedSensor(BedrockBaseSensor): SUCCESS_STATES: tuple[str, ...] = ("InService",) FAILURE_MESSAGE = "Bedrock provision model throughput sensor failed." + aws_hook_class = BedrockHook + template_fields: Sequence[str] = aws_template_fields("model_id") def __init__( @@ -211,3 +220,153 @@ def execute(self, context: Context) -> Any: ) else: super().execute(context=context) + + +class BedrockKnowledgeBaseActiveSensor(BedrockBaseSensor[BedrockAgentHook]): + """ + Poll the Knowledge Base status until it reaches a terminal state; fails if creation fails. + + .. seealso:: + For more information on how to use this sensor, take a look at the guide: + :ref:`howto/sensor:BedrockKnowledgeBaseActiveSensor` + + :param knowledge_base_id: The unique identifier of the knowledge base for which to get information. (templated) + + :param deferrable: If True, the sensor will operate in deferrable more. This mode requires aiobotocore + module to be installed. + (default: False, but can be overridden in config file by setting default_deferrable to True) + :param poke_interval: Polling period in seconds to check for the status of the job. (default: 5) + :param max_retries: Number of times before returning the current state (default: 24) + :param aws_conn_id: The Airflow connection used for AWS credentials. + If this is ``None`` or empty then the default boto3 behaviour is used. If + running Airflow in a distributed manner and aws_conn_id is None or + empty, then default boto3 configuration would be used (and must be + maintained on each worker node). + :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used. + :param verify: Whether or not to verify SSL certificates. 
See: + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html + :param botocore_config: Configuration dictionary (key-values) for botocore client. See: + https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html + """ + + INTERMEDIATE_STATES: tuple[str, ...] = ("CREATING", "UPDATING") + FAILURE_STATES: tuple[str, ...] = ("DELETING", "FAILED") + SUCCESS_STATES: tuple[str, ...] = ("ACTIVE",) + FAILURE_MESSAGE = "Bedrock Knowledge Base Active sensor failed." + + aws_hook_class = BedrockAgentHook + + template_fields: Sequence[str] = aws_template_fields("knowledge_base_id") + + def __init__( + self, + *, + knowledge_base_id: str, + poke_interval: int = 5, + max_retries: int = 24, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.poke_interval = poke_interval + self.max_retries = max_retries + self.knowledge_base_id = knowledge_base_id + + def get_state(self) -> str: + return self.hook.conn.get_knowledge_base(knowledgeBaseId=self.knowledge_base_id)["knowledgeBase"][ + "status" + ] + + def execute(self, context: Context) -> Any: + if self.deferrable: + self.defer( + trigger=BedrockKnowledgeBaseActiveTrigger( + knowledge_base_id=self.knowledge_base_id, + waiter_delay=int(self.poke_interval), + waiter_max_attempts=self.max_retries, + aws_conn_id=self.aws_conn_id, + ), + method_name="poke", + ) + else: + super().execute(context=context) + + +class BedrockIngestionJobSensor(BedrockBaseSensor[BedrockAgentHook]): + """ + Poll the ingestion job status until it reaches a terminal state; fails if creation fails. + + .. seealso:: + For more information on how to use this sensor, take a look at the guide: + :ref:`howto/sensor:BedrockIngestionJobSensor` + + :param knowledge_base_id: The unique identifier of the knowledge base for which to get information. (templated) + :param data_source_id: The unique identifier of the data source in the ingestion job. (templated) + :param ingestion_job_id: The unique identifier of the ingestion job. (templated) + + :param deferrable: If True, the sensor will operate in deferrable more. This mode requires aiobotocore + module to be installed. + (default: False, but can be overridden in config file by setting default_deferrable to True) + :param poke_interval: Polling period in seconds to check for the status of the job. (default: 60) + :param max_retries: Number of times before returning the current state (default: 10) + :param aws_conn_id: The Airflow connection used for AWS credentials. + If this is ``None`` or empty then the default boto3 behaviour is used. If + running Airflow in a distributed manner and aws_conn_id is None or + empty, then default boto3 configuration would be used (and must be + maintained on each worker node). + :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used. + :param verify: Whether or not to verify SSL certificates. See: + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html + :param botocore_config: Configuration dictionary (key-values) for botocore client. See: + https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html + """ + + INTERMEDIATE_STATES: tuple[str, ...] = ("STARTING", "IN_PROGRESS") + FAILURE_STATES: tuple[str, ...] = ("FAILED",) + SUCCESS_STATES: tuple[str, ...] = ("COMPLETE",) + FAILURE_MESSAGE = "Bedrock ingestion job sensor failed." 
+ + aws_hook_class = BedrockAgentHook + + template_fields: Sequence[str] = aws_template_fields( + "knowledge_base_id", "data_source_id", "ingestion_job_id" + ) + + def __init__( + self, + *, + knowledge_base_id: str, + data_source_id: str, + ingestion_job_id: str, + poke_interval: int = 60, + max_retries: int = 10, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.poke_interval = poke_interval + self.max_retries = max_retries + self.knowledge_base_id = knowledge_base_id + self.data_source_id = data_source_id + self.ingestion_job_id = ingestion_job_id + + def get_state(self) -> str: + return self.hook.conn.get_ingestion_job( + knowledgeBaseId=self.knowledge_base_id, + ingestionJobId=self.ingestion_job_id, + dataSourceId=self.data_source_id, + )["ingestionJob"]["status"] + + def execute(self, context: Context) -> Any: + if self.deferrable: + self.defer( + trigger=BedrockIngestionJobTrigger( + knowledge_base_id=self.knowledge_base_id, + ingestion_job_id=self.ingestion_job_id, + data_source_id=self.data_source_id, + waiter_delay=int(self.poke_interval), + waiter_max_attempts=self.max_retries, + aws_conn_id=self.aws_conn_id, + ), + method_name="poke", + ) + else: + super().execute(context=context) diff --git a/airflow/providers/amazon/aws/sensors/opensearch_serverless.py b/airflow/providers/amazon/aws/sensors/opensearch_serverless.py new file mode 100644 index 00000000000000..7f5f650e0ee073 --- /dev/null +++ b/airflow/providers/amazon/aws/sensors/opensearch_serverless.py @@ -0,0 +1,129 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Sequence + +from airflow.configuration import conf +from airflow.exceptions import AirflowException, AirflowSkipException +from airflow.providers.amazon.aws.hooks.opensearch_serverless import OpenSearchServerlessHook +from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor +from airflow.providers.amazon.aws.triggers.opensearch_serverless import ( + OpenSearchServerlessCollectionActiveTrigger, +) +from airflow.providers.amazon.aws.utils.mixins import aws_template_fields +from airflow.utils.helpers import exactly_one + +if TYPE_CHECKING: + from airflow.utils.context import Context + + +class OpenSearchServerlessCollectionActiveSensor(AwsBaseSensor[OpenSearchServerlessHook]): + """ + Poll the state of the Collection until it reaches a terminal state; fails if the query fails. + + .. seealso:: + For more information on how to use this sensor, take a look at the guide: + :ref:`howto/sensor:OpenSearchServerlessCollectionAvailableSensor` + + :param collection_id: A collection ID. You can't provide a name and an ID in the same request. + :param collection_name: A collection name. You can't provide a name and an ID in the same request. 
+ + :param deferrable: If True, the sensor will operate in deferrable mode. This mode requires aiobotocore + module to be installed. + (default: False, but can be overridden in config file by setting default_deferrable to True) + :param poke_interval: Polling period in seconds to check for the status of the job. (default: 10) + :param max_retries: Number of times before returning the current state (default: 60) + :param aws_conn_id: The Airflow connection used for AWS credentials. + If this is ``None`` or empty then the default boto3 behaviour is used. If + running Airflow in a distributed manner and aws_conn_id is None or + empty, then default boto3 configuration would be used (and must be + maintained on each worker node). + :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used. + :param verify: Whether or not to verify SSL certificates. See: + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html + :param botocore_config: Configuration dictionary (key-values) for botocore client. See: + https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html + """ + + INTERMEDIATE_STATES = ("CREATING",) + FAILURE_STATES = ( + "DELETING", + "FAILED", + ) + SUCCESS_STATES = ("ACTIVE",) + FAILURE_MESSAGE = "OpenSearch Serverless Collection sensor failed" + + aws_hook_class = OpenSearchServerlessHook + template_fields: Sequence[str] = aws_template_fields( + "collection_id", + "collection_name", + ) + ui_color = "#66c3ff" + + def __init__( + self, + *, + collection_id: str | None = None, + collection_name: str | None = None, + poke_interval: int = 10, + max_retries: int = 60, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + if not exactly_one(collection_id is None, collection_name is None): + raise AttributeError("Either collection_ids or collection_names must be provided, not both.") + self.collection_id = collection_id + self.collection_name = collection_name + + self.poke_interval = poke_interval + self.max_retries = max_retries + self.deferrable = deferrable + + def poke(self, context: Context, **kwargs) -> bool: + call_args = ( + {"ids": [str(self.collection_id)]} + if self.collection_id + else {"names": [str(self.collection_name)]} + ) + state = self.hook.conn.batch_get_collection(**call_args)["collectionDetails"][0]["status"] + + if state in self.FAILURE_STATES: + # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1 + if self.soft_fail: + raise AirflowSkipException(self.FAILURE_MESSAGE) + raise AirflowException(self.FAILURE_MESSAGE) + + if state in self.INTERMEDIATE_STATES: + return False + return True + + def execute(self, context: Context) -> Any: + if self.deferrable: + self.defer( + trigger=OpenSearchServerlessCollectionActiveTrigger( + collection_id=self.collection_id, + collection_name=self.collection_name, + waiter_delay=int(self.poke_interval), + waiter_max_attempts=self.max_retries, + aws_conn_id=self.aws_conn_id, + ), + method_name="poke", + ) + else: + super().execute(context=context) diff --git a/airflow/providers/amazon/aws/triggers/bedrock.py b/airflow/providers/amazon/aws/triggers/bedrock.py index cee4f6cee782cb..99d632d26e0c3d 100644 --- a/airflow/providers/amazon/aws/triggers/bedrock.py +++ b/airflow/providers/amazon/aws/triggers/bedrock.py @@ -18,7 +18,7 @@ from typing import TYPE_CHECKING -from airflow.providers.amazon.aws.hooks.bedrock import BedrockHook +from
airflow.providers.amazon.aws.hooks.bedrock import BedrockAgentHook, BedrockHook from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger if TYPE_CHECKING: @@ -61,11 +61,49 @@ def hook(self) -> AwsGenericHook: return BedrockHook(aws_conn_id=self.aws_conn_id) +class BedrockKnowledgeBaseActiveTrigger(AwsBaseWaiterTrigger): + """ + Trigger when a Bedrock Knowledge Base reaches the ACTIVE state. + + :param knowledge_base_id: The unique identifier of the knowledge base for which to get information. + + :param waiter_delay: The amount of time in seconds to wait between attempts. (default: 5) + :param waiter_max_attempts: The maximum number of attempts to be made. (default: 24) + :param aws_conn_id: The Airflow connection used for AWS credentials. + """ + + def __init__( + self, + *, + knowledge_base_id: str, + waiter_delay: int = 5, + waiter_max_attempts: int = 24, + aws_conn_id: str | None = None, + ) -> None: + super().__init__( + serialized_fields={"knowledge_base_id": knowledge_base_id}, + waiter_name="knowledge_base_active", + waiter_args={"knowledgeBaseId": knowledge_base_id}, + failure_message="Bedrock Knowledge Base creation failed.", + status_message="Status of Bedrock Knowledge Base job is", + status_queries=["status"], + return_key="knowledge_base_id", + return_value=knowledge_base_id, + waiter_delay=waiter_delay, + waiter_max_attempts=waiter_max_attempts, + aws_conn_id=aws_conn_id, + ) + + def hook(self) -> AwsGenericHook: + return BedrockAgentHook(aws_conn_id=self.aws_conn_id) + + class BedrockProvisionModelThroughputCompletedTrigger(AwsBaseWaiterTrigger): """ Trigger when a provisioned throughput job is complete. :param provisioned_model_id: The ARN or name of the provisioned throughput. + :param waiter_delay: The amount of time in seconds to wait between attempts. (default: 120) :param waiter_max_attempts: The maximum number of attempts to be made. (default: 75) :param aws_conn_id: The Airflow connection used for AWS credentials. @@ -95,3 +133,52 @@ def __init__( def hook(self) -> AwsGenericHook: return BedrockHook(aws_conn_id=self.aws_conn_id) + + +class BedrockIngestionJobTrigger(AwsBaseWaiterTrigger): + """ + Trigger when a Bedrock ingestion job reaches the COMPLETE state. + + :param knowledge_base_id: The unique identifier of the knowledge base for which to get information. + :param data_source_id: The unique identifier of the data source in the ingestion job. + :param ingestion_job_id: The unique identifier of the ingestion job. + + :param waiter_delay: The amount of time in seconds to wait between attempts. (default: 60) + :param waiter_max_attempts: The maximum number of attempts to be made. (default: 10) + :param aws_conn_id: The Airflow connection used for AWS credentials. 
+ """ + + def __init__( + self, + *, + knowledge_base_id: str, + data_source_id: str, + ingestion_job_id: str, + waiter_delay: int = 60, + waiter_max_attempts: int = 10, + aws_conn_id: str | None = None, + ) -> None: + super().__init__( + serialized_fields={ + "knowledge_base_id": knowledge_base_id, + "data_source_id": data_source_id, + "ingestion_job_id": ingestion_job_id, + }, + waiter_name="ingestion_job_complete", + waiter_args={ + "knowledgeBaseId": knowledge_base_id, + "dataSourceId": data_source_id, + "ingestionJobId": ingestion_job_id, + }, + failure_message="Bedrock ingestion job creation failed.", + status_message="Status of Bedrock ingestion job is", + status_queries=["status"], + return_key="ingestion_job_id", + return_value=ingestion_job_id, + waiter_delay=waiter_delay, + waiter_max_attempts=waiter_max_attempts, + aws_conn_id=aws_conn_id, + ) + + def hook(self) -> AwsGenericHook: + return BedrockAgentHook(aws_conn_id=self.aws_conn_id) diff --git a/airflow/providers/amazon/aws/triggers/opensearch_serverless.py b/airflow/providers/amazon/aws/triggers/opensearch_serverless.py new file mode 100644 index 00000000000000..bf94c83bf5f8a9 --- /dev/null +++ b/airflow/providers/amazon/aws/triggers/opensearch_serverless.py @@ -0,0 +1,68 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import TYPE_CHECKING + +from airflow.providers.amazon.aws.hooks.opensearch_serverless import OpenSearchServerlessHook +from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger +from airflow.utils.helpers import exactly_one + +if TYPE_CHECKING: + from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook + + +class OpenSearchServerlessCollectionActiveTrigger(AwsBaseWaiterTrigger): + """ + Trigger when an Amazon OpenSearch Serverless Collection reaches the ACTIVE state. + + :param collection_id: A collection ID. You can't provide a name and an ID in the same request. + :param collection_name: A collection name. You can't provide a name and an ID in the same request. + + :param waiter_delay: The amount of time in seconds to wait between attempts. (default: 60) + :param waiter_max_attempts: The maximum number of attempts to be made. (default: 20) + :param aws_conn_id: The Airflow connection used for AWS credentials. 
+ """ + + def __init__( + self, + *, + collection_id: str | None = None, + collection_name: str | None = None, + waiter_delay: int = 60, + waiter_max_attempts: int = 20, + aws_conn_id: str | None = None, + ) -> None: + if not exactly_one(collection_id is None, collection_name is None): + raise AttributeError("Either collection_ids or collection_names must be provided, not both.") + + super().__init__( + serialized_fields={"collection_id": collection_id, "collection_name": collection_name}, + waiter_name="collection_available", + waiter_args={"ids": [collection_id]} if collection_id else {"names": [collection_name]}, + failure_message="OpenSearch Serverless Collection creation failed.", + status_message="Status of OpenSearch Serverless Collection is", + status_queries=["status"], + return_key="collection_id" if collection_id else "collection_name", + return_value=collection_id if collection_id else collection_name, + waiter_delay=waiter_delay, + waiter_max_attempts=waiter_max_attempts, + aws_conn_id=aws_conn_id, + ) + + def hook(self) -> AwsGenericHook: + return OpenSearchServerlessHook(aws_conn_id=self.aws_conn_id) diff --git a/airflow/providers/amazon/aws/waiters/bedrock-agent.json b/airflow/providers/amazon/aws/waiters/bedrock-agent.json new file mode 100644 index 00000000000000..c59e84bfda74bb --- /dev/null +++ b/airflow/providers/amazon/aws/waiters/bedrock-agent.json @@ -0,0 +1,73 @@ +{ + "version": 2, + "waiters": { + "knowledge_base_active": { + "delay": 5, + "maxAttempts": 24, + "operation": "getKnowledgeBase", + "acceptors": [ + { + "matcher": "path", + "argument": "knowledgeBase.status", + "expected": "ACTIVE", + "state": "success" + }, + { + "matcher": "path", + "argument": "knowledgeBase.status", + "expected": "CREATING", + "state": "retry" + }, + { + "matcher": "path", + "argument": "knowledgeBase.status", + "expected": "DELETING", + "state": "failure" + }, + { + "matcher": "path", + "argument": "knowledgeBase.status", + "expected": "UPDATING", + "state": "retry" + }, + { + "matcher": "path", + "argument": "knowledgeBase.status", + "expected": "FAILED", + "state": "failure" + } + ] + }, + "ingestion_job_complete": { + "delay": 60, + "maxAttempts": 10, + "operation": "getIngestionJob", + "acceptors": [ + { + "matcher": "path", + "argument": "ingestionJob.status", + "expected": "COMPLETE", + "state": "success" + }, + { + "matcher": "path", + "argument": "ingestionJob.status", + "expected": "STARTING", + "state": "retry" + }, + { + "matcher": "path", + "argument": "ingestionJob.status", + "expected": "IN_PROGRESS", + "state": "retry" + }, + { + "matcher": "path", + "argument": "ingestionJob.status", + "expected": "FAILED", + "state": "failure" + } + ] + } + } +} diff --git a/airflow/providers/amazon/aws/waiters/opensearchserverless.json b/airflow/providers/amazon/aws/waiters/opensearchserverless.json new file mode 100644 index 00000000000000..b21fa318b330eb --- /dev/null +++ b/airflow/providers/amazon/aws/waiters/opensearchserverless.json @@ -0,0 +1,36 @@ +{ + "version": 2, + "waiters": { + "collection_available": { + "operation": "BatchGetCollection", + "delay": 10, + "maxAttempts": 120, + "acceptors": [ + { + "matcher": "path", + "argument": "collectionDetails[0].status", + "expected": "ACTIVE", + "state": "success" + }, + { + "matcher": "path", + "argument": "collectionDetails[0].status", + "expected": "DELETING", + "state": "failure" + }, + { + "matcher": "path", + "argument": "collectionDetails[0].status", + "expected": "CREATING", + "state": "retry" + }, + { + "matcher": 
"path", + "argument": "collectionDetails[0].status", + "expected": "FAILED", + "state": "failure" + } + ] + } + } +} diff --git a/airflow/providers/amazon/provider.yaml b/airflow/providers/amazon/provider.yaml index 6973bf7c63121f..97fdb9fd8f65e0 100644 --- a/airflow/providers/amazon/provider.yaml +++ b/airflow/providers/amazon/provider.yaml @@ -238,6 +238,12 @@ integrations: external-doc-url: https://aws.amazon.com/kinesis/data-firehose/ logo: /integration-logos/aws/Amazon-Kinesis-Data-Firehose_light-bg@4x.png tags: [aws] + - integration-name: Amazon OpenSearch Serverless + how-to-guide: + - /docs/apache-airflow-providers-amazon/operators/opensearchserverless.rst + external-doc-url: https://aws.amazon.com/opensearchserverless/ + logo: /integration-logos/aws/Amazon-OpenSearch_light-bg@4x.png + tags: [aws] - integration-name: Amazon RDS external-doc-url: https://aws.amazon.com/rds/ logo: /integration-logos/aws/Amazon-RDS_light-bg@4x.png @@ -499,6 +505,9 @@ sensors: - integration-name: AWS Lambda python-modules: - airflow.providers.amazon.aws.sensors.lambda_function + - integration-name: Amazon OpenSearch Serverless + python-modules: + - airflow.providers.amazon.aws.sensors.opensearch_serverless - integration-name: Amazon RDS python-modules: - airflow.providers.amazon.aws.sensors.rds @@ -599,6 +608,9 @@ hooks: - integration-name: Amazon CloudWatch Logs python-modules: - airflow.providers.amazon.aws.hooks.logs + - integration-name: Amazon OpenSearch Serverless + python-modules: + - airflow.providers.amazon.aws.hooks.opensearch_serverless - integration-name: Amazon RDS python-modules: - airflow.providers.amazon.aws.hooks.rds @@ -669,6 +681,9 @@ triggers: - integration-name: AWS Lambda python-modules: - airflow.providers.amazon.aws.triggers.lambda_function + - integration-name: Amazon OpenSearch Serverless + python-modules: + - airflow.providers.amazon.aws.triggers.opensearch_serverless - integration-name: Amazon Redshift python-modules: - airflow.providers.amazon.aws.triggers.redshift_cluster diff --git a/docs/apache-airflow-providers-amazon/operators/bedrock.rst b/docs/apache-airflow-providers-amazon/operators/bedrock.rst index 4fbe8b7f1a03e6..d5074c319d5de8 100644 --- a/docs/apache-airflow-providers-amazon/operators/bedrock.rst +++ b/docs/apache-airflow-providers-amazon/operators/bedrock.rst @@ -108,6 +108,73 @@ Trigger. :start-after: [START howto_operator_provision_throughput] :end-before: [END howto_operator_provision_throughput] +.. _howto/operator:BedrockCreateKnowledgeBaseOperator: + +Create an Amazon Bedrock Knowledge Base +======================================== + +To create an Amazon Bedrock Knowledge Base, you can use +:class:`~airflow.providers.amazon.aws.operators.bedrock.BedrockCreateKnowledgeBaseOperator`. + +.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_bedrock_knowledge_base.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_bedrock_create_knowledge_base] + :end-before: [END howto_operator_bedrock_create_knowledge_base] + +.. _howto/operator:BedrockDeleteKnowledgeBase: + +Delete an Amazon Bedrock Knowledge Base +======================================= + +Deleting a Knowledge Base is a simple boto API call and can be done in a TaskFlow task like the example below. + +.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_bedrock_knowledge_base.py + :language: python + :start-after: [START howto_operator_bedrock_delete_knowledge_base] + :end-before: [END howto_operator_bedrock_delete_knowledge_base] + +.. 
_howto/operator:BedrockCreateDataSourceOperator: + +Create an Amazon Bedrock Data Source +==================================== + +To create an Amazon Bedrock Data Source, you can use +:class:`~airflow.providers.amazon.aws.operators.bedrock.BedrockCreateDataSourceOperator`. + +.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_bedrock_knowledge_base.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_bedrock_create_data_source] + :end-before: [END howto_operator_bedrock_create_data_source] + +.. _howto/operator:BedrockDeleteDataSource: + +Delete an Amazon Bedrock Data Source +==================================== + +Deleting a Data Source is a simple boto API call and can be done in a TaskFlow task like the example below. + +.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_bedrock_knowledge_base.py + :language: python + :start-after: [START howto_operator_bedrock_delete_data_source] + :end-before: [END howto_operator_bedrock_delete_data_source] + +.. _howto/operator:BedrockIngestDataOperator: + +Ingest data into an Amazon Bedrock Data Source +============================================== + +To add data from an Amazon S3 bucket into an Amazon Bedrock Data Source, you can use +:class:`~airflow.providers.amazon.aws.operators.bedrock.BedrockIngestDataOperator`. + +.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_bedrock_knowledge_base.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_bedrock_ingest_data] + :end-before: [END howto_operator_bedrock_ingest_data] + + Sensors ------- @@ -141,7 +208,37 @@ terminal state you can use :start-after: [START howto_sensor_provision_throughput] :end-before: [END howto_sensor_provision_throughput] +.. _howto/sensor:BedrockKnowledgeBaseActiveSensor: + +Wait for an Amazon Bedrock Knowledge Base +========================================= + +To wait on the state of an Amazon Bedrock Knowledge Base until it reaches a terminal state you can use +:class:`~airflow.providers.amazon.aws.sensors.bedrock.BedrockKnowledgeBaseActiveSensor` + +.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_bedrock_knowledge_base.py + :language: python + :dedent: 4 + :start-after: [START howto_sensor_bedrock_knowledge_base_active] + :end-before: [END howto_sensor_bedrock_knowledge_base_active] + +.. _howto/sensor:BedrockIngestionJobSensor: + +Wait for an Amazon Bedrock ingestion job to finish +================================================== + +To wait on the state of an Amazon Bedrock data ingestion job until it reaches a terminal state you can use +:class:`~airflow.providers.amazon.aws.sensors.bedrock.BedrockIngestionJobSensor` + +.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_bedrock_knowledge_base.py + :language: python + :dedent: 4 + :start-after: [START howto_sensor_bedrock_ingest_data] + :end-before: [END howto_sensor_bedrock_ingest_data] + Reference --------- * `AWS boto3 library documentation for Amazon Bedrock `__ +* `AWS boto3 library documentation for Amazon Bedrock Runtime `__ +* `AWS boto3 library documentation for Amazon Bedrock Agents `__ diff --git a/docs/apache-airflow-providers-amazon/operators/opensearchserverless.rst b/docs/apache-airflow-providers-amazon/operators/opensearchserverless.rst new file mode 100644 index 00000000000000..f947c91e307791 --- /dev/null +++ b/docs/apache-airflow-providers-amazon/operators/opensearchserverless.rst @@ -0,0 +1,59 @@ +..
Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +============================ +Amazon OpenSearch Serverless +============================ + +`Amazon OpenSearch Serverless `__ is an +on-demand, auto-scaling configuration for Amazon OpenSearch Service. An OpenSearch +Serverless collection is an OpenSearch cluster that scales compute capacity based on +your application's needs. This contrasts with provisioned OpenSearch Service +domains, for which you manually manage capacity. + +Prerequisite Tasks +------------------ + +.. include:: ../_partials/prerequisite_tasks.rst + +Generic Parameters +------------------ + +.. include:: ../_partials/generic_parameters.rst + +Sensors +------- + +.. _howto/sensor:OpenSearchServerlessCollectionAvailableSensor: + +Wait for an Amazon OpenSearch Serverless Collection to become active +==================================================================== + +To wait on the state of an Amazon OpenSearch Serverless Collection until it reaches a terminal state you can use +:class:`~airflow.providers.amazon.aws.sensors.opensearch_serverless.OpenSearchServerlessCollectionActiveSensor` + +..
exampleinclude:: /../../tests/system/providers/amazon/aws/example_bedrock_knowledge_base.py + :language: python + :dedent: 4 + :start-after: [START howto_sensor_opensearch_collection_active] + :end-before: [END howto_sensor_opensearch_collection_active] + +Reference +--------- + +* `AWS boto3 library documentation for Amazon OpenSearch Service `__ +* `AWS boto3 library documentation for Amazon OpenSearch Serverless `__ diff --git a/docs/integration-logos/aws/Amazon-OpenSearch_light-bg@4x.png b/docs/integration-logos/aws/Amazon-OpenSearch_light-bg@4x.png new file mode 100644 index 00000000000000..c5ce09d7a16c6f Binary files /dev/null and b/docs/integration-logos/aws/Amazon-OpenSearch_light-bg@4x.png differ diff --git a/tests/providers/amazon/aws/hooks/test_bedrock.py b/tests/providers/amazon/aws/hooks/test_bedrock.py index 43b467d549a71c..c47b46e9a86474 100644 --- a/tests/providers/amazon/aws/hooks/test_bedrock.py +++ b/tests/providers/amazon/aws/hooks/test_bedrock.py @@ -18,7 +18,7 @@ import pytest -from airflow.providers.amazon.aws.hooks.bedrock import BedrockHook, BedrockRuntimeHook +from airflow.providers.amazon.aws.hooks.bedrock import BedrockAgentHook, BedrockHook, BedrockRuntimeHook class TestBedrockHooks: @@ -27,6 +27,7 @@ class TestBedrockHooks: [ pytest.param(BedrockHook(), "bedrock", id="bedrock"), pytest.param(BedrockRuntimeHook(), "bedrock-runtime", id="bedrock-runtime"), + pytest.param(BedrockAgentHook(), "bedrock-agent", id="bedrock-agent"), ], ) def test_bedrock_hooks(self, test_hook, service_name): diff --git a/tests/providers/amazon/aws/hooks/test_opensearch_serverless.py b/tests/providers/amazon/aws/hooks/test_opensearch_serverless.py new file mode 100644 index 00000000000000..19ed07b40d3d7d --- /dev/null +++ b/tests/providers/amazon/aws/hooks/test_opensearch_serverless.py @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +from airflow.providers.amazon.aws.hooks.opensearch_serverless import OpenSearchServerlessHook + + +class TestOpenSearchServerlessHook: + def test_opensearch_serverless_hook(self): + hook = OpenSearchServerlessHook() + service_name = "opensearchserverless" + + assert hook.conn is not None + assert hook.conn.meta.service_model.service_name == service_name diff --git a/tests/providers/amazon/aws/operators/test_bedrock.py b/tests/providers/amazon/aws/operators/test_bedrock.py index 8d7e16361f6e9a..ae255881467972 100644 --- a/tests/providers/amazon/aws/operators/test_bedrock.py +++ b/tests/providers/amazon/aws/operators/test_bedrock.py @@ -25,10 +25,13 @@ from botocore.exceptions import ClientError from moto import mock_aws -from airflow.providers.amazon.aws.hooks.bedrock import BedrockHook, BedrockRuntimeHook +from airflow.providers.amazon.aws.hooks.bedrock import BedrockAgentHook, BedrockHook, BedrockRuntimeHook from airflow.providers.amazon.aws.operators.bedrock import ( + BedrockCreateDataSourceOperator, + BedrockCreateKnowledgeBaseOperator, BedrockCreateProvisionedModelThroughputOperator, BedrockCustomizeModelOperator, + BedrockIngestDataOperator, BedrockInvokeModelOperator, ) @@ -217,3 +220,129 @@ def test_provisioned_model_wait_combinations( assert response == self.MODEL_ARN assert bedrock_hook.get_waiter.call_count == wait_for_completion assert self.operator.defer.call_count == deferrable + + +class TestBedrockCreateKnowledgeBaseOperator: + KNOWLEDGE_BASE_ID = "knowledge_base_id" + + @pytest.fixture + def mock_conn(self) -> Generator[BaseAwsConnection, None, None]: + with mock.patch.object(BedrockAgentHook, "conn") as _conn: + _conn.create_knowledge_base.return_value = { + "knowledgeBase": {"knowledgeBaseId": self.KNOWLEDGE_BASE_ID} + } + yield _conn + + @pytest.fixture + def bedrock_hook(self) -> Generator[BedrockAgentHook, None, None]: + with mock_aws(): + hook = BedrockAgentHook() + yield hook + + def setup_method(self): + self.operator = BedrockCreateKnowledgeBaseOperator( + task_id="create_knowledge_base", + name=self.KNOWLEDGE_BASE_ID, + embedding_model_arn="arn:aws:bedrock:us-east-1::foundation-model/amazon.titan-embed-text-v1", + role_arn="role-arn", + storage_config={ + "type": "OPENSEARCH_SERVERLESS", + "opensearchServerlessConfiguration": { + "collectionArn": "collection_arn", + "vectorIndexName": "index_name", + "fieldMapping": { + "vectorField": "vector", + "textField": "text", + "metadataField": "text-metadata", + }, + }, + }, + ) + self.operator.defer = mock.MagicMock() + + @pytest.mark.parametrize( + "wait_for_completion, deferrable", + [ + pytest.param(False, False, id="no_wait"), + pytest.param(True, False, id="wait"), + pytest.param(False, True, id="defer"), + ], + ) + @mock.patch.object(BedrockAgentHook, "get_waiter") + def test_create_knowledge_base_wait_combinations( + self, _, wait_for_completion, deferrable, mock_conn, bedrock_hook + ): + self.operator.wait_for_completion = wait_for_completion + self.operator.deferrable = deferrable + + response = self.operator.execute({}) + + assert response == self.KNOWLEDGE_BASE_ID + assert bedrock_hook.get_waiter.call_count == wait_for_completion + assert self.operator.defer.call_count == deferrable + + def test_returns_id(self, mock_conn): + self.operator.wait_for_completion = False + result = self.operator.execute({}) + + assert result == self.KNOWLEDGE_BASE_ID + + +class TestBedrockCreateDataSourceOperator: + DATA_SOURCE_ID = "data_source_id" + + @pytest.fixture + def mock_conn(self) 
-> Generator[BaseAwsConnection, None, None]: + with mock.patch.object(BedrockAgentHook, "conn") as _conn: + _conn.create_data_source.return_value = {"dataSource": {"dataSourceId": self.DATA_SOURCE_ID}} + yield _conn + + @pytest.fixture + def bedrock_hook(self) -> Generator[BedrockAgentHook, None, None]: + with mock_aws(): + hook = BedrockAgentHook() + yield hook + + def setup_method(self): + self.operator = BedrockCreateDataSourceOperator( + task_id="create_data_source", + name=self.DATA_SOURCE_ID, + knowledge_base_id="test_knowledge_base_id", + bucket_name="test_bucket", + ) + + def test_id_returned(self, mock_conn): + result = self.operator.execute({}) + + assert result == self.DATA_SOURCE_ID + + +class TestBedrockIngestDataOperator: + INGESTION_JOB_ID = "ingestion_job_id" + + @pytest.fixture + def mock_conn(self) -> Generator[BaseAwsConnection, None, None]: + with mock.patch.object(BedrockAgentHook, "conn") as _conn: + _conn.start_ingestion_job.return_value = { + "ingestionJob": {"ingestionJobId": self.INGESTION_JOB_ID} + } + yield _conn + + @pytest.fixture + def bedrock_hook(self) -> Generator[BedrockAgentHook, None, None]: + with mock_aws(): + hook = BedrockAgentHook() + yield hook + + def setup_method(self): + self.operator = BedrockIngestDataOperator( + task_id="create_data_source", + data_source_id="data_source_id", + knowledge_base_id="knowledge_base_id", + wait_for_completion=False, + ) + + def test_id_returned(self, mock_conn): + result = self.operator.execute({}) + + assert result == self.INGESTION_JOB_ID diff --git a/tests/providers/amazon/aws/sensors/test_bedrock.py b/tests/providers/amazon/aws/sensors/test_bedrock.py index 69b80c72f7e7e7..4ed531931608d9 100644 --- a/tests/providers/amazon/aws/sensors/test_bedrock.py +++ b/tests/providers/amazon/aws/sensors/test_bedrock.py @@ -22,14 +22,18 @@ import pytest from airflow.exceptions import AirflowException, AirflowSkipException -from airflow.providers.amazon.aws.hooks.bedrock import BedrockHook +from airflow.providers.amazon.aws.hooks.bedrock import BedrockAgentHook, BedrockHook from airflow.providers.amazon.aws.sensors.bedrock import ( BedrockCustomizeModelCompletedSensor, + BedrockIngestionJobSensor, + BedrockKnowledgeBaseActiveSensor, BedrockProvisionModelThroughputCompletedSensor, ) class TestBedrockCustomizeModelCompletedSensor: + SENSOR = BedrockCustomizeModelCompletedSensor + def setup_method(self): self.default_op_kwargs = dict( task_id="test_bedrock_customize_model_sensor", @@ -37,16 +41,16 @@ def setup_method(self): poke_interval=5, max_retries=1, ) - self.sensor = BedrockCustomizeModelCompletedSensor(**self.default_op_kwargs, aws_conn_id=None) + self.sensor = self.SENSOR(**self.default_op_kwargs, aws_conn_id=None) def test_base_aws_op_attributes(self): - op = BedrockCustomizeModelCompletedSensor(**self.default_op_kwargs) + op = self.SENSOR(**self.default_op_kwargs) assert op.hook.aws_conn_id == "aws_default" assert op.hook._region_name is None assert op.hook._verify is None assert op.hook._config is None - op = BedrockCustomizeModelCompletedSensor( + op = self.SENSOR( **self.default_op_kwargs, aws_conn_id="aws-test-custom-conn", region_name="eu-west-1", @@ -59,13 +63,13 @@ def test_base_aws_op_attributes(self): assert op.hook._config is not None assert op.hook._config.read_timeout == 42 - @pytest.mark.parametrize("state", list(BedrockCustomizeModelCompletedSensor.SUCCESS_STATES)) + @pytest.mark.parametrize("state", SENSOR.SUCCESS_STATES) @mock.patch.object(BedrockHook, "conn") def test_poke_success_states(self, 
mock_conn, state): mock_conn.get_model_customization_job.return_value = {"status": state} assert self.sensor.poke({}) is True - @pytest.mark.parametrize("state", list(BedrockCustomizeModelCompletedSensor.INTERMEDIATE_STATES)) + @pytest.mark.parametrize("state", SENSOR.INTERMEDIATE_STATES) @mock.patch.object(BedrockHook, "conn") def test_poke_intermediate_states(self, mock_conn, state): mock_conn.get_model_customization_job.return_value = {"status": state} @@ -78,18 +82,18 @@ def test_poke_intermediate_states(self, mock_conn, state): pytest.param(True, AirflowSkipException, id="soft-fail"), ], ) - @pytest.mark.parametrize("state", list(BedrockCustomizeModelCompletedSensor.FAILURE_STATES)) + @pytest.mark.parametrize("state", SENSOR.FAILURE_STATES) @mock.patch.object(BedrockHook, "conn") def test_poke_failure_states(self, mock_conn, state, soft_fail, expected_exception): mock_conn.get_model_customization_job.return_value = {"status": state} - sensor = BedrockCustomizeModelCompletedSensor( - **self.default_op_kwargs, aws_conn_id=None, soft_fail=soft_fail - ) + sensor = self.SENSOR(**self.default_op_kwargs, aws_conn_id=None, soft_fail=soft_fail) with pytest.raises(expected_exception, match=sensor.FAILURE_MESSAGE): sensor.poke({}) class TestBedrockProvisionModelThroughputCompletedSensor: + SENSOR = BedrockProvisionModelThroughputCompletedSensor + def setup_method(self): self.default_op_kwargs = dict( task_id="test_bedrock_provision_model_sensor", @@ -97,18 +101,16 @@ def setup_method(self): poke_interval=5, max_retries=1, ) - self.sensor = BedrockProvisionModelThroughputCompletedSensor( - **self.default_op_kwargs, aws_conn_id=None - ) + self.sensor = self.SENSOR(**self.default_op_kwargs, aws_conn_id=None) def test_base_aws_op_attributes(self): - op = BedrockProvisionModelThroughputCompletedSensor(**self.default_op_kwargs) + op = self.SENSOR(**self.default_op_kwargs) assert op.hook.aws_conn_id == "aws_default" assert op.hook._region_name is None assert op.hook._verify is None assert op.hook._config is None - op = BedrockProvisionModelThroughputCompletedSensor( + op = self.SENSOR( **self.default_op_kwargs, aws_conn_id="aws-test-custom-conn", region_name="eu-west-1", @@ -121,15 +123,13 @@ def test_base_aws_op_attributes(self): assert op.hook._config is not None assert op.hook._config.read_timeout == 42 - @pytest.mark.parametrize("state", list(BedrockProvisionModelThroughputCompletedSensor.SUCCESS_STATES)) + @pytest.mark.parametrize("state", SENSOR.SUCCESS_STATES) @mock.patch.object(BedrockHook, "conn") def test_poke_success_states(self, mock_conn, state): mock_conn.get_provisioned_model_throughput.return_value = {"status": state} assert self.sensor.poke({}) is True - @pytest.mark.parametrize( - "state", list(BedrockProvisionModelThroughputCompletedSensor.INTERMEDIATE_STATES) - ) + @pytest.mark.parametrize("state", SENSOR.INTERMEDIATE_STATES) @mock.patch.object(BedrockHook, "conn") def test_poke_intermediate_states(self, mock_conn, state): mock_conn.get_provisioned_model_throughput.return_value = {"status": state} @@ -142,13 +142,133 @@ def test_poke_intermediate_states(self, mock_conn, state): pytest.param(True, AirflowSkipException, id="soft-fail"), ], ) - @pytest.mark.parametrize("state", list(BedrockProvisionModelThroughputCompletedSensor.FAILURE_STATES)) + @pytest.mark.parametrize("state", SENSOR.FAILURE_STATES) @mock.patch.object(BedrockHook, "conn") def test_poke_failure_states(self, mock_conn, state, soft_fail, expected_exception): mock_conn.get_provisioned_model_throughput.return_value = 
{"status": state} - sensor = BedrockProvisionModelThroughputCompletedSensor( - **self.default_op_kwargs, aws_conn_id=None, soft_fail=soft_fail + sensor = self.SENSOR(**self.default_op_kwargs, aws_conn_id=None, soft_fail=soft_fail) + + with pytest.raises(expected_exception, match=sensor.FAILURE_MESSAGE): + sensor.poke({}) + + +class TestBedrockKnowledgeBaseActiveSensor: + SENSOR = BedrockKnowledgeBaseActiveSensor + + def setup_method(self): + self.default_op_kwargs = dict( + task_id="test_bedrock_knowledge_base_active_sensor", + knowledge_base_id="knowledge_base_id", + poke_interval=5, + max_retries=1, + ) + self.sensor = self.SENSOR(**self.default_op_kwargs, aws_conn_id=None) + + def test_base_aws_op_attributes(self): + op = self.SENSOR(**self.default_op_kwargs) + assert op.hook.aws_conn_id == "aws_default" + assert op.hook._region_name is None + assert op.hook._verify is None + assert op.hook._config is None + + op = self.SENSOR( + **self.default_op_kwargs, + aws_conn_id="aws-test-custom-conn", + region_name="eu-west-1", + verify=False, + botocore_config={"read_timeout": 42}, ) + assert op.hook.aws_conn_id == "aws-test-custom-conn" + assert op.hook._region_name == "eu-west-1" + assert op.hook._verify is False + assert op.hook._config is not None + assert op.hook._config.read_timeout == 42 + @pytest.mark.parametrize("state", SENSOR.SUCCESS_STATES) + @mock.patch.object(BedrockAgentHook, "conn") + def test_poke_success_states(self, mock_conn, state): + mock_conn.get_knowledge_base.return_value = {"knowledgeBase": {"status": state}} + assert self.sensor.poke({}) is True + + @pytest.mark.parametrize("state", SENSOR.INTERMEDIATE_STATES) + @mock.patch.object(BedrockAgentHook, "conn") + def test_poke_intermediate_states(self, mock_conn, state): + mock_conn.get_knowledge_base.return_value = {"knowledgeBase": {"status": state}} + assert self.sensor.poke({}) is False + + @pytest.mark.parametrize( + "soft_fail, expected_exception", + [ + pytest.param(False, AirflowException, id="not-soft-fail"), + pytest.param(True, AirflowSkipException, id="soft-fail"), + ], + ) + @pytest.mark.parametrize("state", SENSOR.FAILURE_STATES) + @mock.patch.object(BedrockAgentHook, "conn") + def test_poke_failure_states(self, mock_conn, state, soft_fail, expected_exception): + mock_conn.get_knowledge_base.return_value = {"knowledgeBase": {"status": state}} + sensor = self.SENSOR(**self.default_op_kwargs, aws_conn_id=None, soft_fail=soft_fail) + with pytest.raises(expected_exception, match=sensor.FAILURE_MESSAGE): + sensor.poke({}) + + +class TestBedrockIngestionJobSensor: + SENSOR = BedrockIngestionJobSensor + + def setup_method(self): + self.default_op_kwargs = dict( + task_id="test_bedrock_knowledge_base_active_sensor", + knowledge_base_id="knowledge_base_id", + data_source_id="data_source_id", + ingestion_job_id="ingestion_job_id", + poke_interval=5, + max_retries=1, + ) + self.sensor = self.SENSOR(**self.default_op_kwargs, aws_conn_id=None) + + def test_base_aws_op_attributes(self): + op = self.SENSOR(**self.default_op_kwargs) + assert op.hook.aws_conn_id == "aws_default" + assert op.hook._region_name is None + assert op.hook._verify is None + assert op.hook._config is None + + op = self.SENSOR( + **self.default_op_kwargs, + aws_conn_id="aws-test-custom-conn", + region_name="eu-west-1", + verify=False, + botocore_config={"read_timeout": 42}, + ) + assert op.hook.aws_conn_id == "aws-test-custom-conn" + assert op.hook._region_name == "eu-west-1" + assert op.hook._verify is False + assert op.hook._config is not None + 
assert op.hook._config.read_timeout == 42 + + @pytest.mark.parametrize("state", SENSOR.SUCCESS_STATES) + @mock.patch.object(BedrockAgentHook, "conn") + def test_poke_success_states(self, mock_conn, state): + mock_conn.get_ingestion_job.return_value = {"ingestionJob": {"status": state}} + assert self.sensor.poke({}) is True + + @pytest.mark.parametrize("state", SENSOR.INTERMEDIATE_STATES) + @mock.patch.object(BedrockAgentHook, "conn") + def test_poke_intermediate_states(self, mock_conn, state): + mock_conn.get_ingestion_job.return_value = {"ingestionJob": {"status": state}} + assert self.sensor.poke({}) is False + + @pytest.mark.parametrize( + "soft_fail, expected_exception", + [ + pytest.param(False, AirflowException, id="not-soft-fail"), + pytest.param(True, AirflowSkipException, id="soft-fail"), + ], + ) + @pytest.mark.parametrize("state", SENSOR.FAILURE_STATES) + @mock.patch.object(BedrockAgentHook, "conn") + def test_poke_failure_states(self, mock_conn, state, soft_fail, expected_exception): + mock_conn.get_ingestion_job.return_value = {"ingestionJob": {"status": state}} + sensor = self.SENSOR(**self.default_op_kwargs, aws_conn_id=None, soft_fail=soft_fail) with pytest.raises(expected_exception, match=sensor.FAILURE_MESSAGE): sensor.poke({}) diff --git a/tests/providers/amazon/aws/sensors/test_opensearch_serverless.py b/tests/providers/amazon/aws/sensors/test_opensearch_serverless.py new file mode 100644 index 00000000000000..3b7474aeac19d5 --- /dev/null +++ b/tests/providers/amazon/aws/sensors/test_opensearch_serverless.py @@ -0,0 +1,113 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +from unittest import mock + +import pytest + +from airflow.exceptions import AirflowException, AirflowSkipException +from airflow.providers.amazon.aws.hooks.opensearch_serverless import OpenSearchServerlessHook +from airflow.providers.amazon.aws.sensors.opensearch_serverless import ( + OpenSearchServerlessCollectionActiveSensor, +) + + +class TestOpenSearchServerlessCollectionActiveSensor: + def setup_method(self): + self.default_op_kwargs = dict( + task_id="test_sensor", + collection_id="knowledge_base_id", + poke_interval=5, + max_retries=1, + ) + self.sensor = OpenSearchServerlessCollectionActiveSensor(**self.default_op_kwargs, aws_conn_id=None) + + def test_base_aws_op_attributes(self): + op = OpenSearchServerlessCollectionActiveSensor(**self.default_op_kwargs) + assert op.hook.aws_conn_id == "aws_default" + assert op.hook._region_name is None + assert op.hook._verify is None + assert op.hook._config is None + + op = OpenSearchServerlessCollectionActiveSensor( + **self.default_op_kwargs, + aws_conn_id="aws-test-custom-conn", + region_name="eu-west-1", + verify=False, + botocore_config={"read_timeout": 42}, + ) + assert op.hook.aws_conn_id == "aws-test-custom-conn" + assert op.hook._region_name == "eu-west-1" + assert op.hook._verify is False + assert op.hook._config is not None + assert op.hook._config.read_timeout == 42 + + @pytest.mark.parametrize( + "collection_name, collection_id, expected_pass", + [ + pytest.param("name", "id", False, id="both_provided_fails"), + pytest.param("name", None, True, id="only_name_provided_passes"), + pytest.param(None, "id", True, id="only_id_provided_passes"), + ], + ) + def test_name_and_id_combinations(self, collection_name, collection_id, expected_pass): + call_args = { + "task_id": "test_sensor", + "collection_name": collection_name, + "collection_id": collection_id, + "poke_interval": 5, + "max_retries": 1, + } + if expected_pass: + op = OpenSearchServerlessCollectionActiveSensor(**call_args) + assert op.collection_id == collection_id + assert op.collection_name == collection_name + if not expected_pass: + with pytest.raises( + AttributeError, match="Either collection_ids or collection_names must be provided, not both." 
+ ): + OpenSearchServerlessCollectionActiveSensor(**call_args) + + @pytest.mark.parametrize("state", list(OpenSearchServerlessCollectionActiveSensor.SUCCESS_STATES)) + @mock.patch.object(OpenSearchServerlessHook, "conn") + def test_poke_success_states(self, mock_conn, state): + mock_conn.batch_get_collection.return_value = {"collectionDetails": [{"status": state}]} + assert self.sensor.poke({}) is True + + @pytest.mark.parametrize("state", list(OpenSearchServerlessCollectionActiveSensor.INTERMEDIATE_STATES)) + @mock.patch.object(OpenSearchServerlessHook, "conn") + def test_poke_intermediate_states(self, mock_conn, state): + mock_conn.batch_get_collection.return_value = {"collectionDetails": [{"status": state}]} + assert self.sensor.poke({}) is False + + @pytest.mark.parametrize( + "soft_fail, expected_exception", + [ + pytest.param(False, AirflowException, id="not-soft-fail"), + pytest.param(True, AirflowSkipException, id="soft-fail"), + ], + ) + @pytest.mark.parametrize("state", list(OpenSearchServerlessCollectionActiveSensor.FAILURE_STATES)) + @mock.patch.object(OpenSearchServerlessHook, "conn") + def test_poke_failure_states(self, mock_conn, state, soft_fail, expected_exception): + mock_conn.batch_get_collection.return_value = {"collectionDetails": [{"status": state}]} + sensor = OpenSearchServerlessCollectionActiveSensor( + **self.default_op_kwargs, aws_conn_id=None, soft_fail=soft_fail + ) + with pytest.raises(expected_exception, match=sensor.FAILURE_MESSAGE): + sensor.poke({}) diff --git a/tests/providers/amazon/aws/triggers/test_bedrock.py b/tests/providers/amazon/aws/triggers/test_bedrock.py index 64942db5792a1c..90112d8c1dc1d3 100644 --- a/tests/providers/amazon/aws/triggers/test_bedrock.py +++ b/tests/providers/amazon/aws/triggers/test_bedrock.py @@ -21,17 +21,30 @@ import pytest -from airflow.providers.amazon.aws.hooks.bedrock import BedrockHook +from airflow.providers.amazon.aws.hooks.bedrock import BedrockAgentHook, BedrockHook from airflow.providers.amazon.aws.triggers.bedrock import ( BedrockCustomizeModelCompletedTrigger, + BedrockIngestionJobTrigger, + BedrockKnowledgeBaseActiveTrigger, BedrockProvisionModelThroughputCompletedTrigger, ) from airflow.triggers.base import TriggerEvent +from tests.providers.amazon.aws.utils.test_waiter import assert_expected_waiter_type BASE_TRIGGER_CLASSPATH = "airflow.providers.amazon.aws.triggers.bedrock." -class TestBedrockCustomizeModelCompletedTrigger: +class TestBaseBedrockTrigger: + EXPECTED_WAITER_NAME: str | None = None + + def test_setup(self): + # Ensure that all subclasses have an expected waiter name set. 
+ if self.__class__.__name__ != "TestBaseBedrockTrigger": + assert isinstance(self.EXPECTED_WAITER_NAME, str) + + +class TestBedrockCustomizeModelCompletedTrigger(TestBaseBedrockTrigger): + EXPECTED_WAITER_NAME = "model_customization_job_complete" JOB_NAME = "test_job" def test_serialization(self): @@ -53,10 +66,12 @@ async def test_run_success(self, mock_async_conn, mock_get_waiter): response = await generator.asend(None) assert response == TriggerEvent({"status": "success", "job_name": self.JOB_NAME}) - assert mock_get_waiter().wait.call_count == 1 + assert_expected_waiter_type(mock_get_waiter, self.EXPECTED_WAITER_NAME) + mock_get_waiter().wait.assert_called_once() -class TestBedrockProvisionModelThroughputCompletedTrigger: +class TestBedrockProvisionModelThroughputCompletedTrigger(TestBaseBedrockTrigger): + EXPECTED_WAITER_NAME = "provisioned_model_throughput_complete" PROVISIONED_MODEL_ID = "provisioned_model_id" def test_serialization(self): @@ -84,4 +99,72 @@ async def test_run_success(self, mock_async_conn, mock_get_waiter): assert response == TriggerEvent( {"status": "success", "provisioned_model_id": self.PROVISIONED_MODEL_ID} ) - assert mock_get_waiter().wait.call_count == 1 + assert_expected_waiter_type(mock_get_waiter, self.EXPECTED_WAITER_NAME) + mock_get_waiter().wait.assert_called_once() + + +class TestBedrockKnowledgeBaseActiveTrigger(TestBaseBedrockTrigger): + EXPECTED_WAITER_NAME = "knowledge_base_active" + KNOWLEDGE_BASE_NAME = "test_kb" + + def test_serialization(self): + """Assert that arguments and classpath are correctly serialized.""" + trigger = BedrockKnowledgeBaseActiveTrigger(knowledge_base_id=self.KNOWLEDGE_BASE_NAME) + classpath, kwargs = trigger.serialize() + assert classpath == BASE_TRIGGER_CLASSPATH + "BedrockKnowledgeBaseActiveTrigger" + assert kwargs.get("knowledge_base_id") == self.KNOWLEDGE_BASE_NAME + + @pytest.mark.asyncio + @mock.patch.object(BedrockAgentHook, "get_waiter") + @mock.patch.object(BedrockAgentHook, "async_conn") + async def test_run_success(self, mock_async_conn, mock_get_waiter): + mock_async_conn.__aenter__.return_value = mock.MagicMock() + mock_get_waiter().wait = AsyncMock() + trigger = BedrockKnowledgeBaseActiveTrigger(knowledge_base_id=self.KNOWLEDGE_BASE_NAME) + + generator = trigger.run() + response = await generator.asend(None) + + assert response == TriggerEvent({"status": "success", "knowledge_base_id": self.KNOWLEDGE_BASE_NAME}) + assert_expected_waiter_type(mock_get_waiter, self.EXPECTED_WAITER_NAME) + mock_get_waiter().wait.assert_called_once() + + +class TestBedrockIngestionJobTrigger(TestBaseBedrockTrigger): + EXPECTED_WAITER_NAME = "ingestion_job_complete" + + KNOWLEDGE_BASE_ID = "test_kb" + DATA_SOURCE_ID = "test_ds" + INGESTION_JOB_ID = "test_ingestion_job" + + def test_serialization(self): + """Assert that arguments and classpath are correctly serialized.""" + trigger = BedrockIngestionJobTrigger( + knowledge_base_id=self.KNOWLEDGE_BASE_ID, + data_source_id=self.DATA_SOURCE_ID, + ingestion_job_id=self.INGESTION_JOB_ID, + ) + classpath, kwargs = trigger.serialize() + assert classpath == BASE_TRIGGER_CLASSPATH + "BedrockIngestionJobTrigger" + assert kwargs.get("knowledge_base_id") == self.KNOWLEDGE_BASE_ID + assert kwargs.get("data_source_id") == self.DATA_SOURCE_ID + assert kwargs.get("ingestion_job_id") == self.INGESTION_JOB_ID + + @pytest.mark.asyncio + @mock.patch.object(BedrockAgentHook, "get_waiter") + @mock.patch.object(BedrockAgentHook, "async_conn") + async def test_run_success(self, mock_async_conn, 
mock_get_waiter): + mock_async_conn.__aenter__.return_value = mock.MagicMock() + mock_get_waiter().wait = AsyncMock() + trigger = BedrockIngestionJobTrigger( + knowledge_base_id=self.KNOWLEDGE_BASE_ID, + data_source_id=self.DATA_SOURCE_ID, + ingestion_job_id=self.INGESTION_JOB_ID, + ) + + generator = trigger.run() + response = await generator.asend(None) + + assert_expected_waiter_type(mock_get_waiter, self.EXPECTED_WAITER_NAME) + assert response == TriggerEvent({"status": "success", "ingestion_job_id": self.INGESTION_JOB_ID}) + mock_get_waiter().wait.assert_called_once() diff --git a/tests/providers/amazon/aws/triggers/test_opensearch_serverless.py b/tests/providers/amazon/aws/triggers/test_opensearch_serverless.py new file mode 100644 index 00000000000000..c992d6a50da690 --- /dev/null +++ b/tests/providers/amazon/aws/triggers/test_opensearch_serverless.py @@ -0,0 +1,88 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from unittest import mock +from unittest.mock import AsyncMock + +import pytest + +from airflow.providers.amazon.aws.hooks.opensearch_serverless import OpenSearchServerlessHook +from airflow.providers.amazon.aws.triggers.opensearch_serverless import ( + OpenSearchServerlessCollectionActiveTrigger, +) +from airflow.triggers.base import TriggerEvent +from airflow.utils.helpers import prune_dict +from tests.providers.amazon.aws.triggers.test_base import TestAwsBaseWaiterTrigger + +BASE_TRIGGER_CLASSPATH = "airflow.providers.amazon.aws.triggers.opensearch_serverless." + + +class TestBaseBedrockTrigger(TestAwsBaseWaiterTrigger): + EXPECTED_WAITER_NAME: str | None = None + + def test_setup(self): + # Ensure that all subclasses have an expected waiter name set. 
+ if self.__class__.__name__ != "TestBaseBedrockTrigger": + assert isinstance(self.EXPECTED_WAITER_NAME, str) + + +class TestOpenSearchServerlessCollectionActiveTrigger: + EXPECTED_WAITER_NAME = "collection_available" + COLLECTION_NAME = "test_collection_name" + COLLECTION_ID = "test_collection_id" + + @pytest.mark.parametrize( + "collection_name, collection_id, expected_pass", + [ + pytest.param(COLLECTION_NAME, COLLECTION_ID, False, id="both_provided_fails"), + pytest.param(COLLECTION_NAME, None, True, id="only_name_provided_passes"), + pytest.param(None, COLLECTION_ID, True, id="only_id_provided_passes"), + ], + ) + def test_serialization(self, collection_name, collection_id, expected_pass): + """Assert that arguments and classpath are correctly serialized.""" + call_args = prune_dict({"collection_id": collection_id, "collection_name": collection_name}) + + if expected_pass: + trigger = OpenSearchServerlessCollectionActiveTrigger(**call_args) + classpath, kwargs = trigger.serialize() + assert classpath == BASE_TRIGGER_CLASSPATH + "OpenSearchServerlessCollectionActiveTrigger" + if call_args.get("collection_name"): + assert kwargs.get("collection_name") == self.COLLECTION_NAME + if call_args.get("collection_id"): + assert kwargs.get("collection_id") == self.COLLECTION_ID + + if not expected_pass: + with pytest.raises( + AttributeError, match="Either collection_ids or collection_names must be provided, not both." + ): + OpenSearchServerlessCollectionActiveTrigger(**call_args) + + @pytest.mark.asyncio + @mock.patch.object(OpenSearchServerlessHook, "get_waiter") + @mock.patch.object(OpenSearchServerlessHook, "async_conn") + async def test_run_success(self, mock_async_conn, mock_get_waiter): + mock_async_conn.__aenter__.return_value = mock.MagicMock() + mock_get_waiter().wait = AsyncMock() + trigger = OpenSearchServerlessCollectionActiveTrigger(collection_id=self.COLLECTION_ID) + + generator = trigger.run() + response = await generator.asend(None) + + assert response == TriggerEvent({"status": "success", "collection_id": self.COLLECTION_ID}) + assert mock_get_waiter().wait.call_count == 1 diff --git a/tests/providers/amazon/aws/utils/test_waiter.py b/tests/providers/amazon/aws/utils/test_waiter.py index d074a45e189bf3..fc8496ce463656 100644 --- a/tests/providers/amazon/aws/utils/test_waiter.py +++ b/tests/providers/amazon/aws/utils/test_waiter.py @@ -41,12 +41,15 @@ def generate_response(state: str) -> dict[str, Any]: def assert_expected_waiter_type(waiter: mock.MagicMock, expected: str): """ There does not appear to be a straight-forward way to assert the type of waiter. - Instead, get the class name and check if it contains the expected name. + + If a Mock of get_waiter() is provided, args[0] will be the waiter name parameter. + If a Mock of get_waiter().wait() is provided, args[0] is the class name of the waiter. + Either way, checking if the expected name is in that string value should result in a match. :param waiter: A mocked Boto3 Waiter object. :param expected: The expected class name of the Waiter object, for example "ClusterActive". 
""" - assert expected in str(type(waiter.call_args.args[0])) + assert expected in str(waiter.call_args.args[0]) class TestWaiter: diff --git a/tests/providers/amazon/aws/waiters/test_bedrock_agent.py b/tests/providers/amazon/aws/waiters/test_bedrock_agent.py new file mode 100644 index 00000000000000..b0740da4094617 --- /dev/null +++ b/tests/providers/amazon/aws/waiters/test_bedrock_agent.py @@ -0,0 +1,111 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from unittest import mock + +import boto3 +import botocore +import pytest + +from airflow.providers.amazon.aws.hooks.bedrock import BedrockAgentHook +from airflow.providers.amazon.aws.sensors.bedrock import ( + BedrockIngestionJobSensor, + BedrockKnowledgeBaseActiveSensor, +) + + +class TestBedrockAgentCustomWaiters: + def test_service_waiters(self): + assert "knowledge_base_active" in BedrockAgentHook().list_waiters() + assert "ingestion_job_complete" in BedrockAgentHook().list_waiters() + + +class TestBedrockAgentCustomWaitersBase: + @pytest.fixture(autouse=True) + def mock_conn(self, monkeypatch): + self.client = boto3.client("bedrock-agent") + monkeypatch.setattr(BedrockAgentHook, "conn", self.client) + + +class TestKnowledgeBaseActiveWaiter(TestBedrockAgentCustomWaitersBase): + WAITER_NAME = "knowledge_base_active" + WAITER_ARGS = {"knowledgeBaseId": "kb_id"} + SENSOR = BedrockKnowledgeBaseActiveSensor + + @pytest.fixture + def mock_getter(self): + with mock.patch.object(self.client, "get_knowledge_base") as getter: + yield getter + + @pytest.mark.parametrize("state", SENSOR.SUCCESS_STATES) + def test_knowledge_base_active_complete(self, state, mock_getter): + mock_getter.return_value = {"knowledgeBase": {"status": state}} + + BedrockAgentHook().get_waiter(self.WAITER_NAME).wait(**self.WAITER_ARGS) + + @pytest.mark.parametrize("state", SENSOR.FAILURE_STATES) + def test_knowledge_base_active_failed(self, state, mock_getter): + mock_getter.return_value = {"knowledgeBase": {"status": state}} + + with pytest.raises(botocore.exceptions.WaiterError): + BedrockAgentHook().get_waiter(self.WAITER_NAME).wait(**self.WAITER_ARGS) + + @pytest.mark.parametrize("state", SENSOR.INTERMEDIATE_STATES) + def test_knowledge_base_active_wait(self, state, mock_getter): + wait = {"knowledgeBase": {"status": state}} + success = {"knowledgeBase": {"status": "ACTIVE"}} + mock_getter.side_effect = [wait, wait, success] + + BedrockAgentHook().get_waiter(self.WAITER_NAME).wait( + **self.WAITER_ARGS, + WaiterConfig={"Delay": 0.01, "MaxAttempts": 3}, + ) + + +class TestIngestionJobWaiter(TestBedrockAgentCustomWaitersBase): + WAITER_NAME = "ingestion_job_complete" + WAITER_ARGS = {"knowledgeBaseId": "kb_id", "dataSourceId": "ds_id", "ingestionJobId": "job_id"} + SENSOR = 
BedrockIngestionJobSensor + + @pytest.fixture + def mock_getter(self): + with mock.patch.object(self.client, "get_ingestion_job") as getter: + yield getter + + @pytest.mark.parametrize("state", SENSOR.SUCCESS_STATES) + def test_knowledge_base_active_complete(self, state, mock_getter): + mock_getter.return_value = {"ingestionJob": {"status": state}} + + BedrockAgentHook().get_waiter(self.WAITER_NAME).wait(**self.WAITER_ARGS) + + @pytest.mark.parametrize("state", SENSOR.FAILURE_STATES) + def test_knowledge_base_active_failed(self, state, mock_getter): + mock_getter.return_value = {"ingestionJob": {"status": state}} + + with pytest.raises(botocore.exceptions.WaiterError): + BedrockAgentHook().get_waiter(self.WAITER_NAME).wait(**self.WAITER_ARGS) + + @pytest.mark.parametrize("state", SENSOR.INTERMEDIATE_STATES) + def test_knowledge_base_active_wait(self, state, mock_getter): + wait = {"ingestionJob": {"status": state}} + success = {"ingestionJob": {"status": "COMPLETE"}} + mock_getter.side_effect = [wait, wait, success] + + BedrockAgentHook().get_waiter(self.WAITER_NAME).wait( + **self.WAITER_ARGS, WaiterConfig={"Delay": 0.01, "MaxAttempts": 3} + ) diff --git a/tests/providers/amazon/aws/waiters/test_opensearch_serverless.py b/tests/providers/amazon/aws/waiters/test_opensearch_serverless.py new file mode 100644 index 00000000000000..54ef68eeff055b --- /dev/null +++ b/tests/providers/amazon/aws/waiters/test_opensearch_serverless.py @@ -0,0 +1,71 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +from unittest import mock + +import boto3 +import botocore +import pytest + +from airflow.providers.amazon.aws.hooks.opensearch_serverless import OpenSearchServerlessHook +from airflow.providers.amazon.aws.sensors.opensearch_serverless import ( + OpenSearchServerlessCollectionActiveSensor, +) + + +class TestOpenSearchServerlessCustomWaiters: + def test_service_waiters(self): + assert "collection_available" in OpenSearchServerlessHook().list_waiters() + + +class TestOpenSearchServerlessCustomWaitersBase: + @pytest.fixture(autouse=True) + def mock_conn(self, monkeypatch): + self.client = boto3.client("opensearchserverless") + monkeypatch.setattr(OpenSearchServerlessHook, "conn", self.client) + + +class TestCollectionAvailableWaiter(TestOpenSearchServerlessCustomWaitersBase): + WAITER_NAME = "collection_available" + + @pytest.fixture + def mock_getter(self): + with mock.patch.object(self.client, "batch_get_collection") as getter: + yield getter + + @pytest.mark.parametrize("state", OpenSearchServerlessCollectionActiveSensor.SUCCESS_STATES) + def test_model_customization_job_complete(self, state, mock_getter): + mock_getter.return_value = {"collectionDetails": [{"status": state}]} + + OpenSearchServerlessHook().get_waiter(self.WAITER_NAME).wait(collection_id="collection_id") + + @pytest.mark.parametrize("state", OpenSearchServerlessCollectionActiveSensor.FAILURE_STATES) + def test_model_customization_job_failed(self, state, mock_getter): + mock_getter.return_value = {"collectionDetails": [{"status": state}]} + + with pytest.raises(botocore.exceptions.WaiterError): + OpenSearchServerlessHook().get_waiter(self.WAITER_NAME).wait(collection_id="collection_id") + + def test_model_customization_job_wait(self, mock_getter): + wait = {"collectionDetails": [{"status": "CREATING"}]} + success = {"collectionDetails": [{"status": "ACTIVE"}]} + mock_getter.side_effect = [wait, wait, success] + + OpenSearchServerlessHook().get_waiter(self.WAITER_NAME).wait( + collection_id="collection_id", WaiterConfig={"Delay": 0.01, "MaxAttempts": 3} + ) diff --git a/tests/system/providers/amazon/aws/example_bedrock_knowledge_base.py b/tests/system/providers/amazon/aws/example_bedrock_knowledge_base.py new file mode 100644 index 00000000000000..959d4ba2c69faf --- /dev/null +++ b/tests/system/providers/amazon/aws/example_bedrock_knowledge_base.py @@ -0,0 +1,527 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +import json +import logging +import os.path +import tempfile +from datetime import datetime +from time import sleep +from urllib.request import urlretrieve + +import boto3 +from botocore.exceptions import ClientError +from opensearchpy import ( + AuthorizationException, + AWSV4SignerAuth, + OpenSearch, + RequestsHttpConnection, +) + +from airflow import DAG +from airflow.decorators import task +from airflow.providers.amazon.aws.hooks.bedrock import BedrockAgentHook +from airflow.providers.amazon.aws.hooks.opensearch_serverless import OpenSearchServerlessHook +from airflow.providers.amazon.aws.hooks.s3 import S3Hook +from airflow.providers.amazon.aws.hooks.sts import StsHook +from airflow.providers.amazon.aws.operators.bedrock import ( + BedrockCreateDataSourceOperator, + BedrockCreateKnowledgeBaseOperator, + BedrockIngestDataOperator, +) +from airflow.providers.amazon.aws.operators.s3 import S3CreateBucketOperator, S3DeleteBucketOperator +from airflow.providers.amazon.aws.sensors.bedrock import ( + BedrockIngestionJobSensor, + BedrockKnowledgeBaseActiveSensor, +) +from airflow.providers.amazon.aws.sensors.opensearch_serverless import ( + OpenSearchServerlessCollectionActiveSensor, +) +from airflow.utils.helpers import chain +from airflow.utils.trigger_rule import TriggerRule +from tests.system.providers.amazon.aws.utils import SystemTestContextBuilder + +############################################################################### +# NOTE: The account running this test must first manually request access to +# the `Titan Embeddings G1 - Text` foundation model via the Bedrock console. +# Gaining access to the model can take 24 hours from the time of request. +############################################################################### + +# Externally fetched variables: +ROLE_ARN_KEY = "ROLE_ARN" +sys_test_context_task = SystemTestContextBuilder().add_variable(ROLE_ARN_KEY).build() + +DAG_ID = "example_bedrock_knowledge_base" + +log = logging.getLogger(__name__) + + +@task +def create_opensearch_policies(bedrock_role_arn: str, collection_name: str, policy_name_suffix: str) -> None: + """ + Create security, network and data access policies within Amazon OpenSearch Serverless. + + :param bedrock_role_arn: Arn of the Bedrock Knowledge Base Execution Role. + :param collection_name: Name of the OpenSearch collection to apply the policies to. + :param policy_name_suffix: EnvironmentID or other unique suffix to append to the policy name. 
+ """ + + encryption_policy_name = f"{naming_prefix}sp-{policy_name_suffix}" + network_policy_name = f"{naming_prefix}np-{policy_name_suffix}" + access_policy_name = f"{naming_prefix}ap-{policy_name_suffix}" + + def _create_security_policy(name, policy_type, policy): + try: + aoss_client.create_security_policy(name=name, policy=json.dumps(policy), type=policy_type) + except ClientError as e: + if e.response["Error"]["Code"] == "ConflictException": + log.info("OpenSearch security policy %s already exists.", name) + raise + + def _create_access_policy(name, policy_type, policy): + try: + aoss_client.create_access_policy(name=name, policy=json.dumps(policy), type=policy_type) + except ClientError as e: + if e.response["Error"]["Code"] == "ConflictException": + log.info("OpenSearch data access policy %s already exists.", name) + raise + + _create_security_policy( + name=encryption_policy_name, + policy_type="encryption", + policy={ + "Rules": [{"Resource": [f"collection/{collection_name}"], "ResourceType": "collection"}], + "AWSOwnedKey": True, + }, + ) + + _create_security_policy( + name=network_policy_name, + policy_type="network", + policy=[ + { + "Rules": [{"Resource": [f"collection/{collection_name}"], "ResourceType": "collection"}], + "AllowFromPublic": True, + } + ], + ) + + _create_access_policy( + name=access_policy_name, + policy_type="data", + policy=[ + { + "Rules": [ + { + "Resource": [f"collection/{collection_name}"], + "Permission": [ + "aoss:CreateCollectionItems", + "aoss:DeleteCollectionItems", + "aoss:UpdateCollectionItems", + "aoss:DescribeCollectionItems", + ], + "ResourceType": "collection", + }, + { + "Resource": [f"index/{collection_name}/*"], + "Permission": [ + "aoss:CreateIndex", + "aoss:DeleteIndex", + "aoss:UpdateIndex", + "aoss:DescribeIndex", + "aoss:ReadDocument", + "aoss:WriteDocument", + ], + "ResourceType": "index", + }, + ], + "Principal": [(StsHook().conn.get_caller_identity()["Arn"]), bedrock_role_arn], + } + ], + ) + + +@task +def create_collection(collection_name: str): + """ + Call the Amazon OpenSearch Serverless API and create a collection with the provided name. + + :param collection_name: The name of the Collection to create. + """ + log.info("\nCreating collection: %s.", collection_name) + return aoss_client.create_collection(name=collection_name, type="VECTORSEARCH")["createCollectionDetail"][ + "id" + ] + + +@task +def create_vector_index(index_name: str, collection_id: str, region: str): + """ + Use the OpenSearchPy client to create the vector index for the Amazon Open Search Serverless Collection. + + :param index_name: The vector index name to create. + :param collection_id: ID of the collection to be indexed. + :param region: Name of the AWS region the collection resides in. 
+ """ + # Build the OpenSearch client + oss_client = OpenSearch( + hosts=[{"host": f"{collection_id}.{region}.aoss.amazonaws.com", "port": 443}], + http_auth=AWSV4SignerAuth(boto3.Session().get_credentials(), region, "aoss"), + use_ssl=True, + verify_certs=True, + connection_class=RequestsHttpConnection, + timeout=300, + ) + index_config = { + "settings": { + "index.knn": "true", + "number_of_shards": 1, + "knn.algo_param.ef_search": 512, + "number_of_replicas": 0, + }, + "mappings": { + "properties": { + "vector": { + "type": "knn_vector", + "dimension": 1536, + "method": {"name": "hnsw", "engine": "faiss", "space_type": "l2"}, + }, + "text": {"type": "text"}, + "text-metadata": {"type": "text"}, + } + }, + } + + retries = 35 + while retries > 0: + try: + response = oss_client.indices.create(index=index_name, body=json.dumps(index_config)) + log.info("Creating index: %s.", response) + break + except AuthorizationException as e: + # Index creation can take up to a minute and there is no (apparent?) way to check the current state. + log.info( + "Access denied; policy permissions have likely not yet propagated, %s tries remaining.", + retries, + ) + log.debug(e) + retries -= 1 + sleep(2) + + +@task +def copy_data_to_s3(bucket: str): + """ + Download some sample data and upload it to 3S. + + :param bucket: Name of the Amazon S3 bucket to send the data to. + """ + + # Monkey patch the list of names available for NamedTempFile so we can pick the names of the downloaded files. + backup_get_candidate_names = tempfile._get_candidate_names # type: ignore[attr-defined] + destinations = iter( + [ + "AMZN-2022-Shareholder-Letter.pdf", + "AMZN-2021-Shareholder-Letter.pdf", + "AMZN-2020-Shareholder-Letter.pdf", + "AMZN-2019-Shareholder-Letter.pdf", + ] + ) + tempfile._get_candidate_names = lambda: destinations # type: ignore[attr-defined] + + # Download the sample data files, save them as named temp files using the names above, and upload to S3. + sources = [ + "https://s2.q4cdn.com/299287126/files/doc_financials/2023/ar/2022-Shareholder-Letter.pdf", + "https://s2.q4cdn.com/299287126/files/doc_financials/2022/ar/2021-Shareholder-Letter.pdf", + "https://s2.q4cdn.com/299287126/files/doc_financials/2021/ar/Amazon-2020-Shareholder-Letter-and-1997-Shareholder-Letter.pdf", + "https://s2.q4cdn.com/299287126/files/doc_financials/2020/ar/2019-Shareholder-Letter.pdf", + ] + + for source in sources: + with tempfile.NamedTemporaryFile(mode="w", prefix="") as data_file: + urlretrieve(source, data_file.name) + S3Hook().conn.upload_file( + Filename=data_file.name, Bucket=bucket, Key=os.path.basename(data_file.name) + ) + + # Revert the monkey patch. + tempfile._get_candidate_names = backup_get_candidate_names # type: ignore[attr-defined] + # Verify the path reversion worked. + with tempfile.NamedTemporaryFile(mode="w", prefix=""): + # If the reversion above did not apply correctly, this will fail with + # a StopIteration error because the iterator will run out of names. + ... + + +@task +def get_collection_arn(collection_id: str): + """ + Return a collection ARN for a given collection ID. + + :param collection_id: ID of the collection to be indexed. + """ + return next( + colxn["arn"] + for colxn in aoss_client.list_collections()["collectionSummaries"] + if colxn["id"] == collection_id + ) + + +# [START howto_operator_bedrock_delete_data_source] +@task(trigger_rule=TriggerRule.ALL_DONE) +def delete_data_source(knowledge_base_id: str, data_source_id: str): + """ + Delete the Amazon Bedrock data source created earlier. 
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:BedrockDeleteDataSource`
+
+    :param knowledge_base_id: The unique identifier of the knowledge base which the data source is attached to.
+    :param data_source_id: The unique identifier of the data source to delete.
+    """
+    log.info("Deleting data source %s from Knowledge Base %s.", data_source_id, knowledge_base_id)
+    bedrock_agent_client.delete_data_source(dataSourceId=data_source_id, knowledgeBaseId=knowledge_base_id)
+
+
+# [END howto_operator_bedrock_delete_data_source]
+
+
+# [START howto_operator_bedrock_delete_knowledge_base]
+@task(trigger_rule=TriggerRule.ALL_DONE)
+def delete_knowledge_base(knowledge_base_id: str):
+    """
+    Delete the Amazon Bedrock knowledge base created earlier.
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:BedrockDeleteKnowledgeBase`
+
+    :param knowledge_base_id: The unique identifier of the knowledge base to delete.
+    """
+    log.info("Deleting Knowledge Base %s.", knowledge_base_id)
+    bedrock_agent_client.delete_knowledge_base(knowledgeBaseId=knowledge_base_id)
+
+
+# [END howto_operator_bedrock_delete_knowledge_base]
+
+
+@task(trigger_rule=TriggerRule.ALL_DONE)
+def delete_vector_index(index_name: str, collection_id: str):
+    """
+    Delete the vector index created earlier.
+
+    :param index_name: The name of the vector index to delete.
+    :param collection_id: ID of the collection that the vector index belongs to.
+    """
+    host = f"{collection_id}.{region_name}.aoss.amazonaws.com"
+    credentials = boto3.Session().get_credentials()
+    awsauth = AWSV4SignerAuth(credentials, region_name, "aoss")
+
+    # Build the OpenSearch client
+    oss_client = OpenSearch(
+        hosts=[{"host": host, "port": 443}],
+        http_auth=awsauth,
+        use_ssl=True,
+        verify_certs=True,
+        connection_class=RequestsHttpConnection,
+        timeout=300,
+    )
+    oss_client.indices.delete(index=index_name)
+
+
+@task(trigger_rule=TriggerRule.ALL_DONE)
+def delete_collection(collection_id: str):
+    """
+    Delete the OpenSearch collection created earlier.
+
+    :param collection_id: ID of the collection to delete.
+    """
+    log.info("Deleting collection %s.", collection_id)
+    aoss_client.delete_collection(id=collection_id)
+
+
+@task(trigger_rule=TriggerRule.ALL_DONE)
+def delete_opensearch_policies(collection_name: str):
+    """
+    Delete the security, network and data access policies created earlier.
+
+    :param collection_name: All policies in the given collection name will be deleted.
+ """ + + access_policies = aoss_client.list_access_policies( + type="data", resource=[f"collection/{collection_name}"] + )["accessPolicySummaries"] + log.info("Found access policies for %s: %s", collection_name, access_policies) + if not access_policies: + raise Exception("No access policies found?") + for policy in access_policies: + log.info("Deleting access policy for %s: %s", collection_name, policy["name"]) + aoss_client.delete_access_policy(name=policy["name"], type="data") + + for policy_type in ["encryption", "network"]: + policies = aoss_client.list_security_policies( + type=policy_type, resource=[f"collection/{collection_name}"] + )["securityPolicySummaries"] + if not policies: + raise Exception("No security policies found?") + log.info("Found %s security policies for %s: %s", policy_type, collection_name, policies) + for policy in policies: + log.info("Deleting %s security policy for %s: %s", policy_type, collection_name, policy["name"]) + aoss_client.delete_security_policy(name=policy["name"], type=policy_type) + + +with DAG( + dag_id=DAG_ID, + schedule="@once", + start_date=datetime(2021, 1, 1), + tags=["example"], + catchup=False, +) as dag: + test_context = sys_test_context_task() + env_id = test_context["ENV_ID"] + + aoss_client = OpenSearchServerlessHook(aws_conn_id=None).conn + bedrock_agent_client = BedrockAgentHook(aws_conn_id=None).conn + + region_name = boto3.session.Session().region_name + + naming_prefix = "bedrock-kb-" + bucket_name = f"{naming_prefix}{env_id}" + index_name = f"{naming_prefix}index-{env_id}" + knowledge_base_name = f"{naming_prefix}{env_id}" + vector_store_name = f"{naming_prefix}{env_id}" + data_source_name = f"{naming_prefix}ds-{env_id}" + + create_bucket = S3CreateBucketOperator(task_id="create_bucket", bucket_name=bucket_name) + + create_opensearch_policies = create_opensearch_policies( + bedrock_role_arn=test_context[ROLE_ARN_KEY], + collection_name=vector_store_name, + policy_name_suffix=env_id, + ) + + collection = create_collection(collection_name=vector_store_name) + + # [START howto_sensor_opensearch_collection_active] + await_collection = OpenSearchServerlessCollectionActiveSensor( + task_id="await_collection", + collection_name=vector_store_name, + ) + # [END howto_sensor_opensearch_collection_active] + + # [START howto_operator_bedrock_create_knowledge_base] + create_knowledge_base = BedrockCreateKnowledgeBaseOperator( + task_id="create_knowledge_base", + name=knowledge_base_name, + embedding_model_arn=f"arn:aws:bedrock:{region_name}::foundation-model/amazon.titan-embed-text-v1", + role_arn=test_context[ROLE_ARN_KEY], + storage_config={ + "type": "OPENSEARCH_SERVERLESS", + "opensearchServerlessConfiguration": { + "collectionArn": get_collection_arn(collection), + "vectorIndexName": index_name, + "fieldMapping": { + "vectorField": "vector", + "textField": "text", + "metadataField": "text-metadata", + }, + }, + }, + ) + # [END howto_operator_bedrock_create_knowledge_base] + create_knowledge_base.wait_for_completion = False + + # [START howto_sensor_bedrock_knowledge_base_active] + await_knowledge_base = BedrockKnowledgeBaseActiveSensor( + task_id="await_knowledge_base", knowledge_base_id=create_knowledge_base.output + ) + # [END howto_sensor_bedrock_knowledge_base_active] + + # [START howto_operator_bedrock_create_data_source] + create_data_source = BedrockCreateDataSourceOperator( + task_id="create_data_source", + knowledge_base_id=create_knowledge_base.output, + name=data_source_name, + bucket_name=bucket_name, + ) + # [END 
howto_operator_bedrock_create_data_source] + + # [START howto_operator_bedrock_ingest_data] + ingest_data = BedrockIngestDataOperator( + task_id="ingest_data", + knowledge_base_id=create_knowledge_base.output, + data_source_id=create_data_source.output, + ) + # [END howto_operator_bedrock_ingest_data] + ingest_data.wait_for_completion = False + + # [START howto_sensor_bedrock_ingest_data] + await_ingest = BedrockIngestionJobSensor( + task_id="await_ingest", + knowledge_base_id=create_knowledge_base.output, + data_source_id=create_data_source.output, + ingestion_job_id=ingest_data.output, + ) + # [END howto_sensor_bedrock_ingest_data] + + delete_bucket = S3DeleteBucketOperator( + task_id="delete_bucket", + trigger_rule=TriggerRule.ALL_DONE, + bucket_name=bucket_name, + force_delete=True, + ) + + chain( + # TEST SETUP + test_context, + create_bucket, + create_opensearch_policies, + collection, + await_collection, + create_vector_index(index_name=index_name, collection_id=collection, region=region_name), + copy_data_to_s3(bucket=bucket_name), + # TEST BODY + create_knowledge_base, + await_knowledge_base, + create_data_source, + ingest_data, + await_ingest, + delete_data_source( + knowledge_base_id=create_knowledge_base.output, + data_source_id=create_data_source.output, + ), + delete_knowledge_base(knowledge_base_id=create_knowledge_base.output), + # TEST TEARDOWN + delete_vector_index(index_name=index_name, collection_id=collection), + delete_opensearch_policies(collection_name=vector_store_name), + delete_collection(collection_id=collection), + delete_bucket, + ) + + from tests.system.utils.watcher import watcher + + # This test needs watcher in order to properly mark success/failure + # when "tearDown" task with trigger rule is part of the DAG + list(dag.tasks) >> watcher() + + +from tests.system.utils import get_test_run # noqa: E402 + +# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest) +test_run = get_test_run(dag)
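A note on the relaxed check in tests/providers/amazon/aws/utils/test_waiter.py above: the helper now string-matches against waiter.call_args.args[0] rather than the argument's type name. The standalone sketch below (names are illustrative and not taken from the provider code) shows why that single comparison covers both mock forms described in the updated docstring; treat it as a sketch of the idea, not the project's canonical test pattern.

from unittest import mock


def assert_expected_waiter_type(waiter: mock.MagicMock, expected: str) -> None:
    # Same comparison as the updated helper: substring match on the string form of args[0].
    assert expected in str(waiter.call_args.args[0])


# Form 1: the test mocks hook.get_waiter(), so the first recorded positional
# argument is the waiter name string itself.
mock_get_waiter = mock.MagicMock()
mock_get_waiter("knowledge_base_active")
assert_expected_waiter_type(mock_get_waiter, "knowledge_base_active")


# Form 2: the test mocks the waiter's wait() at the class level, so args[0] is the
# waiter instance (self) and its default repr contains the generated class name.
class KnowledgeBaseActive:
    """Stand-in for a botocore-generated waiter class; purely illustrative."""


mock_wait = mock.MagicMock()
mock_wait(KnowledgeBaseActive(), knowledgeBaseId="kb_id")
assert_expected_waiter_type(mock_wait, "KnowledgeBaseActive")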
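For completeness, the custom waiters exercised in tests/providers/amazon/aws/waiters/test_bedrock_agent.py can also be driven directly from the new hook. A minimal sketch, assuming the waiter name and request parameters used in those tests; the knowledge base id and WaiterConfig values here are placeholders:

from airflow.providers.amazon.aws.hooks.bedrock import BedrockAgentHook

# Poll get_knowledge_base() until the knowledge base reports an ACTIVE status,
# using the "knowledge_base_active" waiter added in this change set.
BedrockAgentHook(aws_conn_id=None).get_waiter("knowledge_base_active").wait(
    knowledgeBaseId="kb_id",  # placeholder id
    WaiterConfig={"Delay": 60, "MaxAttempts": 20},
)

Based on the tests above, the deferrable trigger relies on this same waiter, while the sensor polls get_knowledge_base() itself, so the direct form is mostly useful for ad-hoc debugging.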