diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index f98d6edde306e..e700fb5c408ba 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -304,6 +304,7 @@ ContainerPort contentUrl contextmgr contrib +convertToIri copyable CoreV coroutine @@ -439,6 +440,7 @@ deidentify DeidentifyTemplate del delim +deliverability deltalake denylist dep @@ -988,6 +990,7 @@ longblob lookups lshift lxml +m-NCUs machineTypes macOS mae @@ -1073,6 +1076,7 @@ nat natively nav navbar +NCUs nd ndjson nearText @@ -1101,9 +1105,11 @@ NotFound notificationChannels notin npm +nquads ns ntlm ntpd +ntriples Nullable nullable num @@ -1128,6 +1134,7 @@ Oozie OpenAI openai openapi +opencypher openfaas OpenID openlineage @@ -1321,6 +1328,7 @@ RaG RBAC rbac rc +rdfxml RDS rds readme diff --git a/providers/amazon/docs/operators/neptune_analytics.rst b/providers/amazon/docs/operators/neptune_analytics.rst new file mode 100644 index 0000000000000..ceaefda13a2d0 --- /dev/null +++ b/providers/amazon/docs/operators/neptune_analytics.rst @@ -0,0 +1,148 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +======================== +Amazon Neptune Analytics +======================== + +`Amazon Neptune Analytics `__ is a memory-optimized graph database engine for analytics. With Neptune Analytics, you can get insights and find trends by processing large amounts of graph data in seconds. + +Prerequisite Tasks +------------------ + +.. include:: ../_partials/prerequisite_tasks.rst + +Generic Parameters +------------------ + +.. include:: ../_partials/generic_parameters.rst + +Operators +--------- + +.. _howto/operator:NeptuneCreateGraphOperator: + +Create a new Neptune Graph +========================== + +To create a new Neptune Analytics Graph, you can use +:class:`~airflow.providers.amazon.aws.operators.neptune_analytics.NeptuneCreateGraphOperator`. +This operator can be run in deferrable mode by passing ``deferrable=True`` as a parameter. This requires +the aiobotocore module to be installed. + +.. exampleinclude:: /../../amazon/tests/system/amazon/aws/example_neptune_analytics.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_neptune_analytics_create_graph] + :end-before: [END howto_operator_neptune_analytics_create_graph] + + +.. _howto/operator:NeptuneDeleteGraphOperator: + +Delete a Neptune Graph +====================== + +To delete an existing Neptune Analytics Graph, you can use +:class:`~airflow.providers.amazon.aws.operators.neptune_analytics.NeptuneDeleteGraphOperator`. +This operator can be run in deferrable mode by passing ``deferrable=True`` as a parameter. This requires +the aiobotocore module to be installed. + +.. exampleinclude:: /../../amazon/tests/system/amazon/aws/example_neptune_analytics.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_neptune_analytics_delete_graph] + :end-before: [END howto_operator_neptune_analytics_delete_graph] + +.. _howto/operator:NeptuneCreatePrivateGraphEndpointOperator: + +Create a Neptune Graph private endpoint +======================================= + +To create a VPC Endpoint for connecting to an existing Neptune Graph, you can use +:class:`~airflow.providers.amazon.aws.operators.neptune_analytics.NeptuneCreatePrivateGraphEndpointOperator`. +This operator can be run in deferrable mode by passing ``deferrable=True`` as a parameter. This requires +the aiobotocore module to be installed. + +.. exampleinclude:: /../../amazon/tests/system/amazon/aws/example_neptune_analytics.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_neptune_analytics_create_private_endpoint] + :end-before: [END howto_operator_neptune_analytics_create_private_endpoint] + +.. _howto/operator:NeptuneDeletePrivateGraphEndpointOperator: + +Delete a Neptune Graph private endpoint +======================================= + +To delete a VPC Endpoint attached to an existing Neptune Graph, you can use +:class:`~airflow.providers.amazon.aws.operators.neptune_analytics.NeptuneDeletePrivateGraphEndpointOperator`. +This operator can be run in deferrable mode by passing ``deferrable=True`` as a parameter. This requires +the aiobotocore module to be installed. + +.. exampleinclude:: /../../amazon/tests/system/amazon/aws/example_neptune_analytics.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_neptune_analytics_delete_private_endpoint] + :end-before: [END howto_operator_neptune_analytics_delete_private_endpoint] + +.. _howto/operator:NeptuneCreateGraphWithImportOperator: + +Create a Neptune Graph with a data import task +============================================== + +To create a Neptune Analytics Graph and immediately import data, you can use +:class:`~airflow.providers.amazon.aws.operators.neptune_analytics.NeptuneCreateGraphWithImportOperator`. +This operator can be run in deferrable mode by passing ``deferrable=True`` as a parameter. This requires +the aiobotocore module to be installed. + +.. exampleinclude:: /../../amazon/tests/system/amazon/aws/example_neptune_analytics.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_neptune_analytics_create_graph_with_import] + :end-before: [END howto_operator_neptune_analytics_create_graph_with_import] + +.. _howto/operator:NeptuneStartImportTaskOperator: + +Import data into an existing Neptune Graph +========================================== + +To import data into an existing Neptune Analytics Graph, you can use +:class:`~airflow.providers.amazon.aws.operators.neptune_analytics.NeptuneStartImportTaskOperator`. +This operator can be run in deferrable mode by passing ``deferrable=True`` as a parameter. This requires +the aiobotocore module to be installed. + +.. exampleinclude:: /../../amazon/tests/system/amazon/aws/example_neptune_analytics.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_neptune_analytics_start_import_task] + :end-before: [END howto_operator_neptune_analytics_start_import_task] + +.. _howto/operator:NeptuneCancelImportTaskOperator: + +Cancel a running import task +============================ + +To cancel an existing import task, you can use +:class:`~airflow.providers.amazon.aws.operators.neptune_analytics.NeptuneCancelImportTaskOperator`. +This operator can be run in deferrable mode by passing ``deferrable=True`` as a parameter. This requires +the aiobotocore module to be installed. + +.. exampleinclude:: /../../amazon/tests/system/amazon/aws/example_neptune_analytics.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_neptune_analytics_cancel_import_task] + :end-before: [END howto_operator_neptune_analytics_cancel_import_task] diff --git a/providers/amazon/provider.yaml b/providers/amazon/provider.yaml index f396e2184e574..ae026bba5d72e 100644 --- a/providers/amazon/provider.yaml +++ b/providers/amazon/provider.yaml @@ -402,6 +402,12 @@ integrations: how-to-guide: - /docs/apache-airflow-providers-amazon/operators/mwaa.rst tags: [aws] + - integration-name: Amazon Neptune Analytics + external-doc-url: https://docs.aws.amazon.com/neptune-analytics/latest/userguide/gettingStarted.html + logo: /docs/integration-logos/Amazon-Neptune_64.png + how-to-guide: + - /docs/apache-airflow-providers-amazon/operators/neptune_analytics.rst + tags: [aws] - integration-name: Amazon S3 Vectors external-doc-url: https://docs.aws.amazon.com/AmazonS3/latest/userguide/s3-vectors.html logo: /docs/integration-logos/Amazon-Simple-Storage-Service-S3_light-bg@4x.png @@ -520,6 +526,7 @@ operators: - integration-name: Amazon Neptune python-modules: - airflow.providers.amazon.aws.operators.neptune + - airflow.providers.amazon.aws.operators.neptune_analytics - integration-name: Amazon S3 Vectors python-modules: - airflow.providers.amazon.aws.operators.s3_vectors @@ -784,6 +791,8 @@ hooks: - integration-name: Amazon Neptune python-modules: - airflow.providers.amazon.aws.hooks.neptune + - airflow.providers.amazon.aws.hooks.neptune_analytics + bundles: - integration-name: Amazon Simple Storage Service (S3) @@ -866,6 +875,7 @@ triggers: - integration-name: Amazon Neptune python-modules: - airflow.providers.amazon.aws.triggers.neptune + - airflow.providers.amazon.aws.triggers.neptune_analytics - integration-name: AWS Database Migration Service python-modules: - airflow.providers.amazon.aws.triggers.dms @@ -981,6 +991,9 @@ extra-links: - airflow.providers.amazon.aws.links.datasync.DataSyncTaskExecutionLink - airflow.providers.amazon.aws.links.ec2.EC2InstanceLink - airflow.providers.amazon.aws.links.ec2.EC2InstanceDashboardLink + - airflow.providers.amazon.aws.links.neptune_analytics.NeptuneGraphLink + - airflow.providers.amazon.aws.links.neptune_analytics.NeptuneImportTaskLink + - airflow.providers.amazon.aws.links.ec2.VpcEndpointLink connection-types: - hook-class-name: airflow.providers.amazon.aws.hooks.base_aws.AwsGenericHook diff --git a/providers/amazon/src/airflow/providers/amazon/aws/exceptions.py b/providers/amazon/src/airflow/providers/amazon/aws/exceptions.py index ae1863bec9c04..e289099d831a5 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/exceptions.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/exceptions.py @@ -50,3 +50,27 @@ def __reduce__(self): class S3HookUriParseFailure(AirflowException): """When parse_s3_url fails to parse URL, this error is thrown.""" + + +class NeptuneGraphCreationFailedError(AirflowException): + """Raised when a Neptune Analytics graph fails to reach the available state.""" + + +class NeptunePrivateEndpointCreationFailedError(AirflowException): + """Raised when a Neptune Analytics private graph endpoint fails to be created.""" + + +class NeptunePrivateEndpointDeletionFailedError(AirflowException): + """Raised when a Neptune Analytics private graph endpoint fails to be deleted.""" + + +class NeptuneGraphDeletionFailedError(AirflowException): + """Raised when a Neptune Analytics graph deletion encounters an unexpected AWS error.""" + + +class NeptuneImportTaskCancellationFailedError(AirflowException): + """Raised when a Neptune Analytics import task cancellation fails or returns an unexpected status.""" + + +class NeptuneImportTaskFailedError(AirflowException): + """Raised when a Neptune Analytics import task fails to complete successfully.""" diff --git a/providers/amazon/src/airflow/providers/amazon/aws/hooks/neptune_analytics.py b/providers/amazon/src/airflow/providers/amazon/aws/hooks/neptune_analytics.py new file mode 100644 index 0000000000000..8f079b9194b2c --- /dev/null +++ b/providers/amazon/src/airflow/providers/amazon/aws/hooks/neptune_analytics.py @@ -0,0 +1,42 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook + + +class NeptuneAnalyticsHook(AwsBaseHook): + """ + Interact with Amazon Neptune Analytics. + + Additional arguments (such as ``aws_conn_id``) may be specified and + are passed down to the underlying AwsBaseHook. + + .. seealso:: + - :class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook` + """ + + def __init__(self, *args, **kwargs): + kwargs["client_type"] = "neptune-graph" + super().__init__(*args, **kwargs) + + def _get_graph_endpoint_id(self, graph_id: str, vpc_id: str): + """Return the vpc endpoint id for this graph.""" + result = self.conn.get_private_graph_endpoint(graphIdentifier=graph_id, vpcId=vpc_id) + return result.get("vpcEndpointId") diff --git a/providers/amazon/src/airflow/providers/amazon/aws/links/ec2.py b/providers/amazon/src/airflow/providers/amazon/aws/links/ec2.py index 38a23956cddbb..96fb03e9130d4 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/links/ec2.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/links/ec2.py @@ -44,3 +44,14 @@ class EC2InstanceDashboardLink(BaseAwsLink): @staticmethod def format_instance_id_filter(instance_ids: list[str]) -> str: return ",:".join(instance_ids) + + +class VpcEndpointLink(BaseAwsLink): + """Helper class for constructing a VPC Endpoint link.""" + + name = "VPC Endpoint" + key = "_vpc_endpoint" + format_str = ( + BASE_AWS_CONSOLE_LINK + + "/vpcconsole/home?region={region_name}#EndpointDetails:vpcEndpointId={endpoint_id}" + ) diff --git a/providers/amazon/src/airflow/providers/amazon/aws/links/neptune_analytics.py b/providers/amazon/src/airflow/providers/amazon/aws/links/neptune_analytics.py new file mode 100644 index 0000000000000..d3b13e48ad4fb --- /dev/null +++ b/providers/amazon/src/airflow/providers/amazon/aws/links/neptune_analytics.py @@ -0,0 +1,42 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from airflow.providers.amazon.aws.links.base_aws import BASE_AWS_CONSOLE_LINK, BaseAwsLink + + +class NeptuneGraphLink(BaseAwsLink): + """Helper class for constructing an Amazon Neptune Analytics Graph Link.""" + + name = "Neptune Graph" + key = "_neptune_graph" + format_str = ( + BASE_AWS_CONSOLE_LINK + + "/neptune/home?region={region_name}#analytics-graph-details:id={graph_id}" + + ";tab=connectivity" + ) + + +class NeptuneImportTaskLink(BaseAwsLink): + """Helper class for constructing an Amazon Neptune Analytics import task link.""" + + name = "Neptune Import Task" + key = "_import_task" + format_str = ( + BASE_AWS_CONSOLE_LINK + + "/neptune/home?region={region_name}#analytics-import-task-details:id={import_task_id}" + ) diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/neptune_analytics.py b/providers/amazon/src/airflow/providers/amazon/aws/operators/neptune_analytics.py new file mode 100644 index 0000000000000..ecfe15d193eef --- /dev/null +++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/neptune_analytics.py @@ -0,0 +1,1025 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from collections.abc import Sequence +from typing import TYPE_CHECKING, Any + +from botocore.exceptions import ClientError + +from airflow.providers.amazon.aws.hooks.neptune_analytics import NeptuneAnalyticsHook +from airflow.providers.amazon.aws.links.ec2 import VpcEndpointLink +from airflow.providers.amazon.aws.links.neptune_analytics import NeptuneGraphLink, NeptuneImportTaskLink +from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator +from airflow.providers.amazon.aws.triggers.neptune_analytics import ( + NeptuneGraphAvailableTrigger, + NeptuneGraphDeletedTrigger, + NeptuneGraphPrivateEndpointAvailableTrigger, + NeptuneGraphPrivateEndpointDeletedTrigger, + NeptuneImportTaskCancelledTrigger, + NeptuneImportTaskCompleteTrigger, +) +from airflow.providers.amazon.aws.utils import validate_execute_complete_event +from airflow.providers.amazon.aws.utils.mixins import aws_template_fields +from airflow.providers.common.compat.sdk import conf + +if TYPE_CHECKING: + from airflow.sdk import Context + +from airflow.providers.amazon.aws.exceptions import ( + NeptuneGraphCreationFailedError, + NeptuneGraphDeletionFailedError, + NeptuneImportTaskCancellationFailedError, + NeptuneImportTaskFailedError, + NeptunePrivateEndpointCreationFailedError, + NeptunePrivateEndpointDeletionFailedError, +) + + +class NeptuneCreateGraphOperator(AwsBaseOperator[NeptuneAnalyticsHook]): + """ + Creates an empty Amazon Neptune Graph database. + + Neptune Analytics is a memory-optimized graph database engine for analytics. With Neptune Analytics, you can get insights and find trends by processing large amounts of graph data in seconds. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:NeptuneCreateGraphOperator` + + :param graph_name: Name of Neptune graph to create + :param vector_search_config: Specifies the number of dimensions for vector embeddings that will be loaded into the graph. + :param provisioned_memory: The provisioned memory-optimized Neptune Capacity Units (m-NCUs) to use for the graph. + :param public_connectivity: Specifies whether or not the graph can be reachable over the internet. + :param replica_count: The number of replicas in other AZs. + :param deletion_protection: Indicates whether or not to enable deletion protection on the graph. + The graph can't be deleted when deletion protection is enabled. + :param kms_key_id: Specifies a KMS key to use to encrypt data in the new graph. + :param tags: Specifies metadata tags to add to the graph. + :param wait_for_completion: Whether to wait for the graph to start. (default: True) + :param deferrable: If True, the operator will wait asynchronously for the graph to start. + This implies waiting for completion. This mode requires aiobotocore module to be installed. + (default: False) + :param waiter_delay: Time in seconds to wait between status checks. + :param waiter_max_attempts: Maximum number of attempts to check for job completion. + :param aws_conn_id: The Airflow connection used for AWS credentials. + If this is ``None`` or empty then the default boto3 behaviour is used. If + running Airflow in a distributed manner and aws_conn_id is None or + empty, then default boto3 configuration would be used (and must be + maintained on each worker node). + :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used. + :param botocore_config: Configuration dictionary (key-values) for botocore client. See: + https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html + :return: dictionary with Neptune graph id + """ + + aws_hook_class = NeptuneAnalyticsHook + template_fields: Sequence[str] = aws_template_fields( + "graph_name", "vector_search_config", "provisioned_memory" + ) + + template_fields_renderers = { + "vector_search_config": "json", + } + + operator_extra_links = (NeptuneGraphLink(),) + + def __init__( + self, + graph_name: str, + vector_search_config: dict, + provisioned_memory: int, + public_connectivity: bool | None = None, + replica_count: int | None = None, + deletion_protection: bool = False, + kms_key_id: str | None = None, + tags: dict | None = None, + wait_for_completion: bool = True, + waiter_delay: int = 30, + waiter_max_attempts: int = 60, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), + **kwargs, + ): + super().__init__(**kwargs) + self.graph_name = graph_name + self.vector_search_config = vector_search_config + self.replica_count = replica_count + self.provisioned_memory = provisioned_memory + self.public_connectivity = public_connectivity + self.deletion_protect = deletion_protection + self.kms_key_id = kms_key_id + self.tags = tags + self.wait_for_completion = wait_for_completion + self.deferrable = deferrable + self.waiter_delay = waiter_delay + self.waiter_max_attempts = waiter_max_attempts + + def execute(self, context: Context) -> dict: + self.log.info("Creating graph %s", self.graph_name) + + create_params = { + "graphName": self.graph_name, + "vectorSearchConfiguration": self.vector_search_config, + "provisionedMemory": self.provisioned_memory, + **{ + k: v + for k, v in { + "replicaCount": self.replica_count, + "publicConnectivity": self.public_connectivity, + "deletionProtection": self.deletion_protect, + "kmsKeyIdentifier": self.kms_key_id, + "tags": self.tags, + }.items() + if v is not None + }, + } + + response = self.hook.conn.create_graph(**create_params) + + self.log.info("Graph %s in status %s", self.graph_name, response.get("status", "Unknown")) + self.graph_id = response.get("id", None) + + graph_url = NeptuneGraphLink.format_str.format( + graph_id=self.graph_id, + aws_domain=NeptuneGraphLink.get_aws_domain(self.hook.conn_partition), + region_name=self.hook.conn_region_name, + ) + + NeptuneGraphLink.persist( + context=context, + operator=self, + region_name=self.hook.conn_region_name, + aws_partition=self.hook.conn_partition, + graph_id=self.graph_id, + ) + self.log.info("You can view this Neptune Graph at : %s", graph_url) + + if self.deferrable: + self.log.info("Deferring until graph %s is available", self.graph_id) + self.defer( + trigger=NeptuneGraphAvailableTrigger( + aws_conn_id=self.aws_conn_id, + graph_id=self.graph_id, + waiter_delay=self.waiter_delay, + waiter_max_attempts=self.waiter_max_attempts, + ), + method_name="execute_complete", + ) + + if self.wait_for_completion: + self.log.info("Waiting until graph %s is available", self.graph_id) + self.hook.get_waiter("graph_available").wait( + graphIdentifier=self.graph_id, + WaiterConfig={"Delay": self.waiter_delay, "MaxAttempts": self.waiter_max_attempts}, + ) + + return {"graph_id": self.graph_id} + + def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> dict[str, Any]: + + validated_event = validate_execute_complete_event(event) + + if validated_event.get("status") != "success": + raise NeptuneGraphCreationFailedError( + validated_event.get( + "message", + f"Neptune graph {validated_event.get('return_key')} creation did not complete successfully", + ) + ) + + self.log.info("Neptune graph %s complete", self.graph_id) + + return {"graph_id": self.graph_id} + + +class NeptuneCreatePrivateGraphEndpointOperator(AwsBaseOperator[NeptuneAnalyticsHook]): + """ + Creates a Neptune Graph private endpoint. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:NeptuneCreatePrivateGraphEndpointOperator` + + :param graph_identifier: Neptune Graph id + :param vpc_id: VPC to create endpoint in + :param subnet_ids: Subnets in which private graph endpoint ENIs are created + :param vpc_security_group_ids: Security groups to be attached to the private graph endpoint + :param wait_for_completion: Whether to wait for the endpoint to be available. (default: True) + :param deferrable: If True, the operator will wait asynchronously for the endpoint to become available. + This implies waiting for completion. This mode requires aiobotocore module to be installed. + (default: False) + :param waiter_delay: Time in seconds to wait between status checks. + :param waiter_max_attempts: Maximum number of attempts to check for job completion. + :param aws_conn_id: The Airflow connection used for AWS credentials. + If this is ``None`` or empty then the default boto3 behaviour is used. If + running Airflow in a distributed manner and aws_conn_id is None or + empty, then default boto3 configuration would be used (and must be + maintained on each worker node). + :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used. + :param botocore_config: Configuration dictionary (key-values) for botocore client. See: + https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html + :return: dictionary with Neptune graph id + """ + + aws_hook_class = NeptuneAnalyticsHook + template_fields: Sequence[str] = aws_template_fields( + "graph_identifier", "vpc_id", "subnet_ids", "vpc_security_group_ids" + ) + + def __init__( + self, + graph_identifier: str, + vpc_id: str | None = None, + subnet_ids: list[str] | None = None, + vpc_security_group_ids: list[str] | None = None, + wait_for_completion: bool = True, + waiter_delay: int = 30, + waiter_max_attempts: int = 60, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), + **kwargs, + ): + super().__init__(**kwargs) + self.graph_identifier = graph_identifier + self.vpc_id = vpc_id + self.subnet_ids = subnet_ids + self.vpc_security_group_ids = vpc_security_group_ids + self.wait_for_completion = wait_for_completion + self.deferrable = deferrable + self.waiter_delay = waiter_delay + self.waiter_max_attempts = waiter_max_attempts + + def execute(self, context: Context) -> dict: + self.log.info("Creating private endpoint for graph %s", self.graph_identifier) + + create_params = { + "graphIdentifier": self.graph_identifier, + **{ + k: v + for k, v in { + "vpcId": self.vpc_id, + "subnetIds": self.subnet_ids, + "vpcSecurityGroupIds": self.vpc_security_group_ids, + }.items() + if v is not None + }, + } + + # create the endpoint + + result = self.hook.conn.create_private_graph_endpoint(**create_params) + status = result.get("status", "Unknown") + + self.log.info("Status of endpoint: %s", status) + + if status in ["FAILED"]: + raise NeptunePrivateEndpointCreationFailedError( + f"Private endpoint failed to create for graph {self.graph_identifier}" + ) + + # if VPC not provided, use the one that is returned, which is the default VPC. Required for the waiter + self.vpc_id = result.get("vpcId", self.vpc_id) + + # get the vpce id since it may not be returned immediately + endpoint_id = self.hook._get_graph_endpoint_id(graph_id=self.graph_identifier, vpc_id=self.vpc_id) + + endpoint_url = VpcEndpointLink.format_str.format( + endpoint_id=endpoint_id, + aws_domain=VpcEndpointLink.get_aws_domain(self.hook.conn_partition), + region_name=self.hook.conn_region_name, + ) + + VpcEndpointLink.persist( + context=context, + operator=self, + region_name=self.hook.conn_region_name, + aws_partition=self.hook.conn_partition, + endpoint_id=endpoint_id, + ) + self.log.info("You can view this private endpoint at : %s", endpoint_url) + + if self.deferrable: + self.log.info("Deferring until endpoint is available") + self.defer( + trigger=NeptuneGraphPrivateEndpointAvailableTrigger( + aws_conn_id=self.aws_conn_id, + graph_id=self.graph_identifier, + vpc_id=self.vpc_id, + waiter_delay=self.waiter_delay, + waiter_max_attempts=self.waiter_max_attempts, + ), + method_name="execute_complete", + kwargs={"vpc_id": self.vpc_id}, + ) + + if self.wait_for_completion: + self.log.info("Waiting until endpoint is available") + self.hook.get_waiter("private_graph_endpoint_available").wait( + graphIdentifier=self.graph_identifier, + vpcId=self.vpc_id, + WaiterConfig={"Delay": self.waiter_delay, "MaxAttempts": self.waiter_max_attempts}, + ) + + return {"vpc_endpoint_id": endpoint_id, "graph_id": self.graph_identifier, "vpc_id": self.vpc_id} + + def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> dict[str, Any]: + validated_event = validate_execute_complete_event(event) + + if validated_event.get("status") != "success": + raise NeptunePrivateEndpointCreationFailedError( + validated_event.get("message", "Endpoint failed to create") + ) + + graph_id = validated_event.get("value") + vpc_id = validated_event.get("vpc_id") + vpc_endpoint_id = self.hook._get_graph_endpoint_id(graph_id=graph_id, vpc_id=vpc_id) + return {"vpc_endpoint_id": vpc_endpoint_id, "graph_id": graph_id, "vpc_id": vpc_id} + + +class NeptuneDeletePrivateGraphEndpointOperator(AwsBaseOperator[NeptuneAnalyticsHook]): + """ + Deletes a Neptune Graph private endpoint. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:NeptuneDeletePrivateGraphEndpointOperator` + + :param graph_identifier: Neptune Graph id + :param vpc_id: VPC where endpoint resides + :param wait_for_completion: Whether to wait for the endpoint to be deleted. (default: True) + :param deferrable: If True, the operator will wait asynchronously for the endpoint to be deleted. + This implies waiting for completion. This mode requires aiobotocore module to be installed. + (default: False) + :param waiter_delay: Time in seconds to wait between status checks. + :param waiter_max_attempts: Maximum number of attempts to check for job completion. + :param aws_conn_id: The Airflow connection used for AWS credentials. + If this is ``None`` or empty then the default boto3 behaviour is used. If + running Airflow in a distributed manner and aws_conn_id is None or + empty, then default boto3 configuration would be used (and must be + maintained on each worker node). + :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used. + + :param botocore_config: Configuration dictionary (key-values) for botocore client. See: + https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html + :return: dictionary with Neptune graph id + """ + + aws_hook_class = NeptuneAnalyticsHook + template_fields: Sequence[str] = aws_template_fields("graph_identifier", "vpc_id") + + def __init__( + self, + graph_identifier: str, + vpc_id: str, + wait_for_completion: bool = True, + waiter_delay: int = 30, + waiter_max_attempts: int = 60, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), + **kwargs, + ): + super().__init__(**kwargs) + self.graph_identifier = graph_identifier + self.vpc_id = vpc_id + self.wait_for_completion = wait_for_completion + self.deferrable = deferrable + self.waiter_delay = waiter_delay + self.waiter_max_attempts = waiter_max_attempts + + def execute(self, context: Context) -> None: + self.log.info("Deleting private endpoint for graph %s", self.graph_identifier) + + result = self.hook.conn.delete_private_graph_endpoint( + graphIdentifier=self.graph_identifier, vpcId=self.vpc_id + ) + + status = result.get("status") + endpoint_id = result.get("vpcEndpointId") + + if status == "FAILED": + raise NeptunePrivateEndpointDeletionFailedError( + f"Failed to delete private endpoint {endpoint_id}" + ) + + if self.deferrable: + self.log.info("Deferring until endpoint %s is deleted", endpoint_id) + self.defer( + trigger=NeptuneGraphPrivateEndpointDeletedTrigger( + aws_conn_id=self.aws_conn_id, + graph_id=self.graph_identifier, + vpc_id=self.vpc_id, + endpoint_id=endpoint_id, + waiter_delay=self.waiter_delay, + waiter_max_attempts=self.waiter_max_attempts, + ), + method_name="execute_complete", + ) + if self.wait_for_completion: + self.log.info("Waiting until endpoint %s is deleted", endpoint_id) + self.hook.get_waiter("private_graph_endpoint_deleted").wait( + graphIdentifier=self.graph_identifier, + vpcId=self.vpc_id, + WaiterConfig={"Delay": self.waiter_delay, "MaxAttempts": self.waiter_max_attempts}, + ) + self.log.info("Endpoint %s deleted", endpoint_id) + + def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None: + validated_event = validate_execute_complete_event(event) + + if validated_event.get("status") != "success": + raise NeptunePrivateEndpointDeletionFailedError( + validated_event.get("message", "Endpoint failed to delete.") + ) + + vpc_endpoint_id = validated_event.get("endpoint_id", "Unknown") + self.log.info("Endpoint id %s deleted", vpc_endpoint_id) + + +class NeptuneDeleteGraphOperator(AwsBaseOperator[NeptuneAnalyticsHook]): + """ + Deletes an Amazon Neptune Graph database. + + Neptune Analytics is a memory-optimized graph database engine for analytics. With Neptune Analytics, you can get insights and find trends by processing large amounts of graph data in seconds. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:NeptuneDeleteGraphOperator` + + :param graph_id: Name of Neptune graph to delete + :param skip_snapshot: Determines whether a final graph snapshot is created before the graph is deleted. If true is specified, no graph snapshot is created. If false is specified, a graph snapshot is created before the graph is deleted. + :param wait_for_completion: Whether to wait for the graph to delete. (default: True) + :param deferrable: If True, the operator will wait asynchronously for the graph to be deleted. + This implies waiting for completion. This mode requires aiobotocore module to be installed. + (default: False) + :param waiter_delay: Time in seconds to wait between status checks. + :param waiter_max_attempts: Maximum number of attempts to check for job completion. + :param aws_conn_id: The Airflow connection used for AWS credentials. + If this is ``None`` or empty then the default boto3 behaviour is used. If + running Airflow in a distributed manner and aws_conn_id is None or + empty, then default boto3 configuration would be used (and must be + maintained on each worker node). + :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used. + + :param botocore_config: Configuration dictionary (key-values) for botocore client. See: + https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html + :return: dictionary with Neptune graph id + """ + + aws_hook_class = NeptuneAnalyticsHook + template_fields: Sequence[str] = aws_template_fields("graph_id", "skip_snapshot") + + def __init__( + self, + graph_id: str, + skip_snapshot: bool, + wait_for_completion: bool = True, + waiter_delay: int = 30, + waiter_max_attempts: int = 60, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), + **kwargs, + ): + super().__init__(**kwargs) + self.graph_id = graph_id + self.skip_snapshot = skip_snapshot + self.wait_for_completion = wait_for_completion + self.deferrable = deferrable + self.waiter_delay = waiter_delay + self.waiter_max_attempts = waiter_max_attempts + + def execute(self, context: Context): + self.log.info("Deleting graph %s", self.graph_id) + + try: + self.hook.conn.delete_graph(graphIdentifier=self.graph_id, skipSnapshot=self.skip_snapshot) + except ClientError as e: + # if not found, just exit because there is nothing to delete + if e.response["Error"]["Code"] == "ResourceNotFoundException": + self.log.info("Graph %s not found. Nothing to delete", self.graph_id) + return + raise NeptuneGraphDeletionFailedError(e.response["Error"]) + + if self.deferrable: + self.log.info("Deferring until graph %s is deleted", self.graph_id) + self.defer( + trigger=NeptuneGraphDeletedTrigger( + aws_conn_id=self.aws_conn_id, + graph_id=self.graph_id, + waiter_delay=self.waiter_delay, + waiter_max_attempts=self.waiter_max_attempts, + ), + method_name="execute_complete", + ) + if self.wait_for_completion: + self.log.info("Waiting to delete %s", self.graph_id) + + self.hook.get_waiter("graph_deleted").wait( + graphIdentifier=self.graph_id, + WaiterConfig={"Delay": self.waiter_delay, "MaxAttempts": self.waiter_max_attempts}, + ) + + def execute_complete(self, context: Context, event: dict[str, Any] | None = None): + validated_event = validate_execute_complete_event(event) + graph_id = validated_event.get("graph_id") + if validated_event.get("status") != "success": + raise NeptuneGraphDeletionFailedError( + validated_event.get("message", f"Neptune graph {graph_id} deletion failed") + ) + + self.log.info("Neptune graph %s deleted", validated_event.get("graph_id", graph_id)) + + +class NeptuneCreateGraphWithImportOperator(AwsBaseOperator[NeptuneAnalyticsHook]): + """ + Creates a Neptune Graph and imports data into it. + + Neptune Analytics is a memory-optimized graph database engine for analytics. With Neptune Analytics, + you can get insights and find trends by processing large amounts of graph data in seconds. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:NeptuneCreateGraphWithImportOperator` + + :param graph_name: Name of Neptune graph to create + :param vector_search_config: Specifies the number of dimensions for vector embeddings that will be loaded into the graph. + :param source: The source from which to import data. Can be an S3 URI or Neptune database snapshot. + :param role_arn: The ARN of the IAM role that Neptune Analytics can assume to access the data source. + :param blank_node_handling: The method to handle blank nodes in the dataset. Options include 'convertToIri' or other handling strategies. + :param parquet_type: The type of Parquet files in the data source (if applicable). + :param format: The format of the data to be imported (e.g., 'csv', 'opencypher', 'ntriples', 'nquads', 'rdfxml', 'turtle'). + :param min_provisioned_memory: The minimum provisioned memory for the graph in GBs. + :param max_provisioned_memory: The maximum provisioned memory for the graph in GBs. + :param fail_on_error: If True, the import will fail if any errors are encountered. If False, the import will continue despite errors. + :param public_connectivity: Specifies whether or not the graph can be reachable over the internet. + :param replica_count: The number of replicas in other AZs. + :param deletion_protection: Indicates whether or not to enable deletion protection on the graph. + The graph can't be deleted when deletion protection is enabled. (default: False) + :param kms_key_id: Specifies a KMS key to use to encrypt data in the new graph. + :param tags: Specifies metadata tags to add to the graph. + :param import_options: Contains options for controlling the import process. + :param wait_for_completion: Whether to wait for the graph to be created and data imported. (default: True) + :param waiter_delay: Time in seconds to wait between status checks. (default: 30) + :param waiter_max_attempts: Maximum number of attempts to check for job completion. (default: 60) + :param deferrable: If True, the operator will wait asynchronously for the graph to be created and data imported. + This implies waiting for completion. This mode requires aiobotocore module to be installed. + (default: False) + :param aws_conn_id: The Airflow connection used for AWS credentials. + If this is ``None`` or empty then the default boto3 behaviour is used. If + running Airflow in a distributed manner and aws_conn_id is None or + empty, then default boto3 configuration would be used (and must be + maintained on each worker node). + :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used. + :param botocore_config: Configuration dictionary (key-values) for botocore client. See: + https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html + :return: dictionary with Neptune graph id + """ + + aws_hook_class = NeptuneAnalyticsHook + template_fields: Sequence[str] = aws_template_fields( + "graph_name", "vector_search_config", "source", "role_arn" + ) + + template_fields_renderers = { + "vector_search_config": "json", + } + + operator_extra_links = ( + NeptuneImportTaskLink(), + NeptuneGraphLink(), + ) + + def __init__( + self, + graph_name: str, + vector_search_config: dict, + source: str, + role_arn: str, + blank_node_handling: str | None = None, + parquet_type: str | None = None, + format: str | None = None, + min_provisioned_memory: int | None = None, + max_provisioned_memory: int | None = None, + fail_on_error: bool | None = None, + public_connectivity: bool | None = None, + replica_count: int | None = None, + deletion_protection: bool | None = None, + kms_key_id: str | None = None, + tags: dict | None = None, + import_options: dict | None = None, + wait_for_completion: bool = True, + waiter_delay: int = 30, + waiter_max_attempts: int = 60, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), + **kwargs, + ): + super().__init__(**kwargs) + self.graph_name = graph_name + self.vector_search_config = vector_search_config + self.source = source + self.role_arn = role_arn + self.blank_node_handling = blank_node_handling + self.parquet_type = parquet_type + self.format = format + self.min_provisioned_memory = min_provisioned_memory + self.max_provisioned_memory = max_provisioned_memory + self.fail_on_error = fail_on_error + self.public_connectivity = public_connectivity + self.replica_count = replica_count + self.deletion_protect = deletion_protection + self.kms_key_id = kms_key_id + self.tags = tags + self.import_options = import_options + self.wait_for_completion = wait_for_completion + self.deferrable = deferrable + self.waiter_delay = waiter_delay + self.waiter_max_attempts = waiter_max_attempts + + def execute(self, context: Context) -> dict: + self.log.info("Creating graph %s with import", self.graph_name) + + # Build the import options + import_options = { + "neptune-analytics:blank-node-handling": self.blank_node_handling, + "neptune-analytics:parquet-type": self.parquet_type, + } + + # Remove None values from import_options + import_options = {k: v for k, v in import_options.items() if v is not None} + + # Merge with user-provided import_options + if self.import_options: + import_options.update(self.import_options) + + create_params = { + "graphName": self.graph_name, + "vectorSearchConfiguration": self.vector_search_config, + "source": self.source, + "roleArn": self.role_arn, + **{ + k: v + for k, v in { + "format": self.format, + "minProvisionedMemory": self.min_provisioned_memory, + "maxProvisionedMemory": self.max_provisioned_memory, + "failOnError": self.fail_on_error, + "replicaCount": self.replica_count, + "publicConnectivity": self.public_connectivity, + "deletionProtection": self.deletion_protect, + "kmsKeyIdentifier": self.kms_key_id, + "tags": self.tags, + "importOptions": import_options if import_options else None, + }.items() + if v is not None + }, + } + + response = self.hook.conn.create_graph_using_import_task(**create_params) + + self.log.info("Graph %s import task in status %s", self.graph_name, response.get("status", "Unknown")) + self.graph_id = response.get("graphId", None) + import_task_id = response.get("taskId") + + graph_url = NeptuneGraphLink.format_str.format( + graph_id=self.graph_id, + aws_domain=NeptuneGraphLink.get_aws_domain(self.hook.conn_partition), + region_name=self.hook.conn_region_name, + ) + + NeptuneGraphLink.persist( + context=context, + operator=self, + region_name=self.hook.conn_region_name, + aws_partition=self.hook.conn_partition, + graph_id=self.graph_id, + ) + + import_task_url = NeptuneImportTaskLink.format_str.format( + import_task_id=import_task_id, + aws_domain=NeptuneImportTaskLink.get_aws_domain(self.hook.conn_partition), + region_name=self.hook.conn_region_name, + ) + + NeptuneImportTaskLink.persist( + context=context, + operator=self, + region_name=self.hook.conn_region_name, + aws_partition=self.hook.conn_partition, + import_task_id=import_task_id, + ) + + self.log.info("You can view this import task at : %s", import_task_url) + + self.log.info("You can view this Neptune Graph at : %s", graph_url) + + if self.deferrable: + self.log.info("Deferring until graph %s is available", self.graph_id) + self.defer( + trigger=NeptuneGraphAvailableTrigger( + aws_conn_id=self.aws_conn_id, + graph_id=self.graph_id, + waiter_delay=self.waiter_delay, + waiter_max_attempts=self.waiter_max_attempts, + ), + method_name="defer_wait_for_task", + kwargs={"import_task_id": import_task_id}, + ) + + if self.wait_for_completion: + self.log.info("Waiting until graph %s is available", self.graph_id) + self.hook.get_waiter("graph_available").wait( + graphIdentifier=self.graph_id, + WaiterConfig={"Delay": self.waiter_delay, "MaxAttempts": self.waiter_max_attempts}, + ) + # Once the graph is available, wait for the task to complete + + self.log.info("Waiting for import task %s", import_task_id) + self.hook.get_waiter("import_task_successful").wait( + taskIdentifier=import_task_id, + WaiterConfig={"Delay": self.waiter_delay, "MaxAttempts": self.waiter_max_attempts}, + ) + + return {"graph_id": self.graph_id} + + def defer_wait_for_task( + self, context: Context, event: dict[str, Any] | None = None, import_task_id: str | None = None + ) -> None: + """Defers for import task completion.""" + validated_event = validate_execute_complete_event(event) + graph_id = validated_event.get("value") + + if validated_event.get("status") != "success": + raise NeptuneGraphCreationFailedError( + validated_event.get("message", f"Neptune graph {graph_id} did not become available") + ) + + if import_task_id: + self.log.info("Deferring for import task %s completion", import_task_id) + self.defer( + trigger=NeptuneImportTaskCompleteTrigger( + import_task_id=import_task_id, + waiter_delay=self.waiter_delay, + waiter_max_attempts=self.waiter_max_attempts, + aws_conn_id=self.aws_conn_id, + ), + method_name="execute_complete", + kwargs={"graph_id": graph_id}, + ) + + def execute_complete( + self, context: Context, event: dict[str, Any] | None = None, graph_id: str | None = None + ) -> dict[str, Any]: + validated_event = validate_execute_complete_event(event) + + if validated_event.get("status") != "success": + raise NeptuneGraphCreationFailedError( + validated_event.get( + "message", f"Neptune graph {graph_id} import did not complete successfully" + ) + ) + + self.log.info("Import complete for graph %s", graph_id) + return {"graph_id": graph_id} + + +class NeptuneStartImportTaskOperator(AwsBaseOperator[NeptuneAnalyticsHook]): + """ + Starts a bulk data import task to load data into an empty Neptune graph. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:NeptuneStartImportTaskOperator` + + :param graph_identifier: Graph Id of target Neptune Graph + :param role_arn: IAM role ARN granting access to source data + :param source: URL identifying the source data location. + :param blank_node_handling: Method to handle blank nodes in dataset. + :param fail_on_error: If set to true, the task halts when an import error is encountered. If set to false, the task skips the data that caused the error and continues if possible. + :param format: Specifies the format of the Amazon S3 data to be imported. + :param import_options: Options on how to perform an import + :param parquet_type: Parquet type of import task + :param wait_for_completion: Whether to wait for the endpoint to be available. (default: True) + :param deferrable: If True, the operator will wait asynchronously for the endpoint to become available. + This implies waiting for completion. This mode requires aiobotocore module to be installed. + (default: False) + :param waiter_delay: Time in seconds to wait between status checks. + :param waiter_max_attempts: Maximum number of attempts to check for job completion. + :param aws_conn_id: The Airflow connection used for AWS credentials. + If this is ``None`` or empty then the default boto3 behaviour is used. If + running Airflow in a distributed manner and aws_conn_id is None or + empty, then default boto3 configuration would be used (and must be + maintained on each worker node). + :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used. + :param botocore_config: Configuration dictionary (key-values) for botocore client. See: + https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html + :return: dictionary with Neptune graph id + """ + + aws_hook_class = NeptuneAnalyticsHook + template_fields: Sequence[str] = aws_template_fields( + "graph_identifier", "role_arn", "source", "import_options" + ) + template_fields_renderers = { + "import_options": "json", + } + operator_extra_links = (NeptuneImportTaskLink(),) + + def __init__( + self, + graph_identifier: str, + role_arn: str, + source: str, + blank_node_handling: str | None = None, + fail_on_error: bool = True, + format: str | None = None, + import_options: dict | None = None, + parquet_type: str | None = "COLUMNAR", + wait_for_completion: bool = True, + waiter_delay: int = 30, + waiter_max_attempts: int = 60, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), + **kwargs, + ): + super().__init__(**kwargs) + self.graph_identifier = graph_identifier + self.role_arn = role_arn + self.source = source + self.blank_node_handling = blank_node_handling + self.fail_on_error = fail_on_error + self.format = format + self.import_options = import_options + self.parquet_type = parquet_type + self.wait_for_completion = wait_for_completion + self.waiter_delay = waiter_delay + self.waiter_max_attempts = waiter_max_attempts + self.deferrable = deferrable + + def execute(self, context: Context) -> dict: + self.log.info("Starting data import to graph %s", self.graph_identifier) + + create_params = { + "graphIdentifier": self.graph_identifier, + "roleArn": self.role_arn, + "source": self.source, + **{ + k: v + for k, v in { + "blankNodeHandling": self.blank_node_handling, + "failOnError": self.fail_on_error, + "format": self.format, + "importOptions": self.import_options, + "parquetType": self.parquet_type, + }.items() + if v is not None + }, + } + + response = self.hook.conn.start_import_task(**create_params) + import_task_id = response.get("taskId") + self.log.info("Import task %s started for graph %s", import_task_id, self.graph_identifier) + + # Create the console link + import_task_url = NeptuneImportTaskLink.format_str.format( + import_task_id=import_task_id, + aws_domain=NeptuneImportTaskLink.get_aws_domain(self.hook.conn_partition), + region_name=self.hook.conn_region_name, + ) + + NeptuneImportTaskLink.persist( + context=context, + operator=self, + region_name=self.hook.conn_region_name, + aws_partition=self.hook.conn_partition, + import_task_id=import_task_id, + ) + self.log.info("You can view this import task at : %s", import_task_url) + + if self.deferrable: + self.log.info("Deferring until import task %s completes", import_task_id) + self.defer( + trigger=NeptuneImportTaskCompleteTrigger( + import_task_id=import_task_id, + waiter_delay=self.waiter_delay, + waiter_max_attempts=self.waiter_max_attempts, + aws_conn_id=self.aws_conn_id, + ), + method_name="execute_complete", + ) + + if self.wait_for_completion: + self.log.info("Waiting for import task %s to complete", import_task_id) + self.hook.get_waiter("import_task_successful").wait( + taskIdentifier=import_task_id, + WaiterConfig={"Delay": self.waiter_delay, "MaxAttempts": self.waiter_max_attempts}, + ) + + return {"import_task_id": import_task_id, "graph_id": self.graph_identifier} + + def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> dict[str, Any]: + validated_event = validate_execute_complete_event(event) + + if validated_event.get("status") != "success": + raise NeptuneImportTaskFailedError( + validated_event.get("message", "Import task did not complete successfully") + ) + + task_id = validated_event.get("import_task_id", "") + self.log.info("Import task %s completed", task_id) + return {"graph_id": self.graph_identifier, "import_task_id": task_id} + + +class NeptuneCancelImportTaskOperator(AwsBaseOperator[NeptuneAnalyticsHook]): + """ + Cancels an active Neptune Graph import task. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:NeptuneCancelImportTaskOperator` + + :param import_task_id: Neptune Graph import task id to cancel. + :param wait_for_completion: Whether to wait for the endpoint to be available. (default: True) + :param deferrable: If True, the operator will wait asynchronously for the endpoint to become available. + This implies waiting for completion. This mode requires aiobotocore module to be installed. + (default: False) + :param waiter_delay: Time in seconds to wait between status checks. + :param waiter_max_attempts: Maximum number of attempts to check for job completion. + :param aws_conn_id: The Airflow connection used for AWS credentials. + If this is ``None`` or empty then the default boto3 behaviour is used. If + running Airflow in a distributed manner and aws_conn_id is None or + empty, then default boto3 configuration would be used (and must be + maintained on each worker node). + :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used. + :param botocore_config: Configuration dictionary (key-values) for botocore client. See: + https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html + :return: dictionary with Neptune graph id + + """ + + aws_hook_class = NeptuneAnalyticsHook + template_fields: Sequence[str] = aws_template_fields("import_task_id") + + def __init__( + self, + import_task_id: str, + wait_for_completion: bool = True, + waiter_delay: int = 30, + waiter_max_attempts: int = 60, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), + **kwargs, + ): + super().__init__(**kwargs) + self.import_task_id = import_task_id + self.wait_for_completion = wait_for_completion + self.waiter_delay = waiter_delay + self.waiter_max_attempts = waiter_max_attempts + self.deferrable = deferrable + + def execute(self, context: Context) -> dict: + self.log.info("Cancelling import task %s", self.import_task_id) + + response = self.hook.conn.cancel_import_task(taskIdentifier=self.import_task_id) + + self.log.info("Import task %s status is %s", self.import_task_id, response.get("status", "Unknown")) + + if self.deferrable: + self.log.info("Deferring until import task %s is cancelled", self.import_task_id) + self.defer( + trigger=NeptuneImportTaskCancelledTrigger( + task_identifier=self.import_task_id, + waiter_delay=self.waiter_delay, + waiter_max_attempts=self.waiter_max_attempts, + aws_conn_id=self.aws_conn_id, + ), + method_name="execute_complete", + ) + + if self.wait_for_completion: + self.log.info("Waiting for import task %s to be cancelled", self.import_task_id) + self.hook.get_waiter("import_task_cancelled").wait( + taskIdentifier=self.import_task_id, + WaiterConfig={"Delay": self.waiter_delay, "MaxAttempts": self.waiter_max_attempts}, + ) + + return {"import_task_id": self.import_task_id} + + def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> dict[str, Any]: + validated_event = validate_execute_complete_event(event) + + if validated_event.get("status") != "success": + raise NeptuneImportTaskCancellationFailedError( + validated_event.get("message", "Error while waiting for Neptune import task cancellation") + ) + + task_id = validated_event.get("value", "") + self.log.info("Import task %s cancelled", task_id) + return {"import_task_id": task_id} diff --git a/providers/amazon/src/airflow/providers/amazon/aws/triggers/neptune_analytics.py b/providers/amazon/src/airflow/providers/amazon/aws/triggers/neptune_analytics.py new file mode 100644 index 0000000000000..50a4d1edc5abe --- /dev/null +++ b/providers/amazon/src/airflow/providers/amazon/aws/triggers/neptune_analytics.py @@ -0,0 +1,282 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import TYPE_CHECKING + +from airflow.providers.amazon.aws.hooks.neptune_analytics import NeptuneAnalyticsHook +from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger + +if TYPE_CHECKING: + from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook + + +class NeptuneGraphAvailableTrigger(AwsBaseWaiterTrigger): + """ + Triggers when a Neptune graph is available. + + :param graph_id: Graph ID to poll. + :param waiter_delay: The amount of time in seconds to wait between attempts. + :param waiter_max_attempts: The maximum number of attempts to be made. + :param aws_conn_id: The Airflow connection used for AWS credentials. + :param region_name: AWS region name (example: us-east-1) + """ + + def __init__( + self, + *, + graph_id: str, + waiter_delay: int = 30, + waiter_max_attempts: int = 60, + **kwargs, + ) -> None: + super().__init__( + serialized_fields={"graph_id": graph_id}, + waiter_name="graph_available", + waiter_args={"graphIdentifier": graph_id}, + failure_message="Failed to create Neptune graph", + status_message="Status of Neptune graph is", + status_queries=["status"], + return_key="graph_id", + return_value=graph_id, + waiter_delay=waiter_delay, + waiter_max_attempts=waiter_max_attempts, + **kwargs, + ) + + def hook(self) -> AwsGenericHook: + return NeptuneAnalyticsHook( + aws_conn_id=self.aws_conn_id, + region_name=self.region_name, + verify=self.verify, + config=self.botocore_config, + ) + + +class NeptuneGraphPrivateEndpointAvailableTrigger(AwsBaseWaiterTrigger): + """ + Triggers when a Neptune Graph private endpoint is available. + + :param graph_id: Graph Id waiting for the endpoint + :param vpc_id: VPC id where endpoint is creating + :param waiter_delay: The amount of time in seconds to wait between attempts. + :param waiter_max_attempts: The maximum number of attempts to be made. + :param aws_conn_id: The Airflow connection used for AWS credentials. + :param region_name: AWS region name (example: us-east-1) + """ + + def __init__( + self, + *, + graph_id: str, + vpc_id: str, + waiter_delay: int = 30, + waiter_max_attempts: int = 60, + **kwargs, + ) -> None: + super().__init__( + serialized_fields={"graph_id": graph_id, "vpc_id": vpc_id}, + waiter_name="private_graph_endpoint_available", + waiter_args={"graphIdentifier": graph_id, "vpcId": vpc_id}, + failure_message="Failed to create Neptune graph endpoint", + status_message="Status of Neptune graph endpoint is", + status_queries=["status"], + return_key="graph_id", + return_value=graph_id, + waiter_delay=waiter_delay, + waiter_max_attempts=waiter_max_attempts, + **kwargs, + ) + + def hook(self) -> AwsGenericHook: + return NeptuneAnalyticsHook( + aws_conn_id=self.aws_conn_id, + region_name=self.region_name, + verify=self.verify, + config=self.botocore_config, + ) + + +class NeptuneGraphPrivateEndpointDeletedTrigger(AwsBaseWaiterTrigger): + """ + Triggers when a Neptune Graph private endpoint is deleted. + + :param graph_id: Graph Id of the endpoint + :param vpc_id: VPC id where endpoint resides + :param waiter_delay: The amount of time in seconds to wait between attempts. + :param waiter_max_attempts: The maximum number of attempts to be made. + :param aws_conn_id: The Airflow connection used for AWS credentials. + :param region_name: AWS region name (example: us-east-1) + """ + + def __init__( + self, + *, + graph_id: str, + vpc_id: str, + endpoint_id: str, + waiter_delay: int = 30, + waiter_max_attempts: int = 60, + **kwargs, + ) -> None: + super().__init__( + serialized_fields={"graph_id": graph_id, "vpc_id": vpc_id, "endpoint_id": endpoint_id}, + waiter_name="private_graph_endpoint_deleted", + waiter_args={"graphIdentifier": graph_id, "vpcId": vpc_id}, + failure_message="Failed to delete Neptune graph endpoint", + status_message="Status of Neptune graph endpoint is", + status_queries=["status"], + return_key="endpoint_id", + return_value=endpoint_id, + waiter_delay=waiter_delay, + waiter_max_attempts=waiter_max_attempts, + **kwargs, + ) + + def hook(self) -> AwsGenericHook: + return NeptuneAnalyticsHook( + aws_conn_id=self.aws_conn_id, + region_name=self.region_name, + verify=self.verify, + config=self.botocore_config, + ) + + +class NeptuneGraphDeletedTrigger(AwsBaseWaiterTrigger): + """ + Triggers when a Neptune Graph is deleted. + + :param graph_id: Graph Id to be deleted + :param waiter_delay: The amount of time in seconds to wait between attempts. + :param waiter_max_attempts: The maximum number of attempts to be made. + :param aws_conn_id: The Airflow connection used for AWS credentials. + :param region_name: AWS region name (example: us-east-1) + """ + + def __init__( + self, + *, + graph_id: str, + waiter_delay: int = 30, + waiter_max_attempts: int = 60, + **kwargs, + ) -> None: + super().__init__( + serialized_fields={"graph_id": graph_id}, + waiter_name="graph_deleted", + waiter_args={"graphIdentifier": graph_id}, + failure_message="Failed to delete Neptune graph", + status_message="Status of Neptune graph is", + status_queries=["status"], + return_key="graph_id", + return_value=graph_id, + waiter_delay=waiter_delay, + waiter_max_attempts=waiter_max_attempts, + **kwargs, + ) + + def hook(self) -> AwsGenericHook: + return NeptuneAnalyticsHook( + aws_conn_id=self.aws_conn_id, + region_name=self.region_name, + verify=self.verify, + config=self.botocore_config, + ) + + +class NeptuneImportTaskCompleteTrigger(AwsBaseWaiterTrigger): + """ + Triggers when a Neptune import task successfully completes. + + :param task_id: Import task id to monitor + :param waiter_delay: The amount of time in seconds to wait between attempts. + :param waiter_max_attempts: The maximum number of attempts to be made. + :param aws_conn_id: The Airflow connection used for AWS credentials. + :param region_name: AWS region name (example: us-east-1) + """ + + def __init__( + self, + *, + import_task_id: str, + waiter_delay: int = 30, + waiter_max_attempts: int = 60, + **kwargs, + ) -> None: + super().__init__( + serialized_fields={"import_task_id": import_task_id}, + waiter_name="import_task_successful", + waiter_args={"taskIdentifier": import_task_id}, + failure_message="Import task failed", + status_message="Status of import task is", + status_queries=["status"], + return_key="import_task_id", + return_value=import_task_id, + waiter_delay=waiter_delay, + waiter_max_attempts=waiter_max_attempts, + **kwargs, + ) + + def hook(self) -> AwsGenericHook: + return NeptuneAnalyticsHook( + aws_conn_id=self.aws_conn_id, + region_name=self.region_name, + verify=self.verify, + config=self.botocore_config, + ) + + +class NeptuneImportTaskCancelledTrigger(AwsBaseWaiterTrigger): + """ + Triggers when a Neptune import task is successfully cancelled. + + :param task_id: Import task id to monitor. + :param waiter_delay: The amount of time in seconds to wait between attempts. + :param waiter_max_attempts: The maximum number of attempts to be made. + :param aws_conn_id: The Airflow connection used for AWS credentials. + :param region_name: AWS region name (example: us-east-1) + """ + + def __init__( + self, + *, + task_identifier: str, + waiter_delay: int = 30, + waiter_max_attempts: int = 60, + **kwargs, + ) -> None: + super().__init__( + serialized_fields={"task_identifier": task_identifier}, + waiter_name="import_task_cancelled", + waiter_args={"taskIdentifier": task_identifier}, + failure_message="Import task cancellation failed", + status_message="Status of import task is", + status_queries=["status"], + return_key="import_task_id", + return_value=task_identifier, + waiter_delay=waiter_delay, + waiter_max_attempts=waiter_max_attempts, + **kwargs, + ) + + def hook(self) -> AwsGenericHook: + return NeptuneAnalyticsHook( + aws_conn_id=self.aws_conn_id, + region_name=self.region_name, + verify=self.verify, + config=self.botocore_config, + ) diff --git a/providers/amazon/src/airflow/providers/amazon/aws/waiters/neptune_analytics.json b/providers/amazon/src/airflow/providers/amazon/aws/waiters/neptune_analytics.json new file mode 100644 index 0000000000000..e5a8e82712f3f --- /dev/null +++ b/providers/amazon/src/airflow/providers/amazon/aws/waiters/neptune_analytics.json @@ -0,0 +1,38 @@ +{ + "version": 2, + "waiters": { + + "import_task_cancelled":{ + "operation": "GetImportTask", + "delay": 30, + "maxAttempts": 60, + "acceptors": [ + { + "matcher": "path", + "argument": "status", + "expected": "SUCCEEDED", + "state": "success" + }, + { + "matcher": "path", + "argument": "status", + "expected": "CANCELLED", + "state": "success" + }, + { + "matcher": "path", + "argument": "status", + "expected": "ERROR_ENCOUNTERED", + "state": "error" + }, + { + "matcher": "path", + "argument": "status", + "expected": "FAILED", + "state": "success" + } + ] + } + + + }} diff --git a/providers/amazon/src/airflow/providers/amazon/get_provider_info.py b/providers/amazon/src/airflow/providers/amazon/get_provider_info.py index 99721f3c481b9..3d158a94523b0 100644 --- a/providers/amazon/src/airflow/providers/amazon/get_provider_info.py +++ b/providers/amazon/src/airflow/providers/amazon/get_provider_info.py @@ -372,6 +372,13 @@ def get_provider_info(): "how-to-guide": ["/docs/apache-airflow-providers-amazon/operators/mwaa.rst"], "tags": ["aws"], }, + { + "integration-name": "Amazon Neptune Analytics", + "external-doc-url": "https://docs.aws.amazon.com/neptune-analytics/latest/userguide/gettingStarted.html", + "logo": "/docs/integration-logos/Amazon-Neptune_64.png", + "how-to-guide": ["/docs/apache-airflow-providers-amazon/operators/neptune_analytics.rst"], + "tags": ["aws"], + }, { "integration-name": "Amazon S3 Vectors", "external-doc-url": "https://docs.aws.amazon.com/AmazonS3/latest/userguide/s3-vectors.html", @@ -529,7 +536,10 @@ def get_provider_info(): }, { "integration-name": "Amazon Neptune", - "python-modules": ["airflow.providers.amazon.aws.operators.neptune"], + "python-modules": [ + "airflow.providers.amazon.aws.operators.neptune", + "airflow.providers.amazon.aws.operators.neptune_analytics", + ], }, { "integration-name": "Amazon S3 Vectors", @@ -874,7 +884,10 @@ def get_provider_info(): }, { "integration-name": "Amazon Neptune", - "python-modules": ["airflow.providers.amazon.aws.hooks.neptune"], + "python-modules": [ + "airflow.providers.amazon.aws.hooks.neptune", + "airflow.providers.amazon.aws.hooks.neptune_analytics", + ], }, ], "bundles": [ @@ -987,7 +1000,10 @@ def get_provider_info(): }, { "integration-name": "Amazon Neptune", - "python-modules": ["airflow.providers.amazon.aws.triggers.neptune"], + "python-modules": [ + "airflow.providers.amazon.aws.triggers.neptune", + "airflow.providers.amazon.aws.triggers.neptune_analytics", + ], }, { "integration-name": "AWS Database Migration Service", @@ -1149,6 +1165,9 @@ def get_provider_info(): "airflow.providers.amazon.aws.links.datasync.DataSyncTaskExecutionLink", "airflow.providers.amazon.aws.links.ec2.EC2InstanceLink", "airflow.providers.amazon.aws.links.ec2.EC2InstanceDashboardLink", + "airflow.providers.amazon.aws.links.neptune_analytics.NeptuneGraphLink", + "airflow.providers.amazon.aws.links.neptune_analytics.NeptuneImportTaskLink", + "airflow.providers.amazon.aws.links.ec2.VpcEndpointLink", ], "connection-types": [ { diff --git a/providers/amazon/tests/system/amazon/aws/example_neptune_analytics.py b/providers/amazon/tests/system/amazon/aws/example_neptune_analytics.py new file mode 100644 index 0000000000000..79c64f951ce3d --- /dev/null +++ b/providers/amazon/tests/system/amazon/aws/example_neptune_analytics.py @@ -0,0 +1,346 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import contextlib +import json +import time +from datetime import datetime + +import boto3 + +from airflow.providers.amazon.aws.hooks.neptune_analytics import NeptuneAnalyticsHook +from airflow.providers.amazon.aws.operators.neptune_analytics import ( + NeptuneCancelImportTaskOperator, + NeptuneCreateGraphOperator, + NeptuneCreateGraphWithImportOperator, + NeptuneCreatePrivateGraphEndpointOperator, + NeptuneDeleteGraphOperator, + NeptuneDeletePrivateGraphEndpointOperator, + NeptuneStartImportTaskOperator, +) +from airflow.providers.amazon.aws.operators.s3 import ( + S3CreateBucketOperator, + S3CreateObjectOperator, + S3DeleteBucketOperator, +) +from airflow.providers.common.compat.sdk import DAG, chain, task + +try: + from airflow.sdk import TriggerRule +except ImportError: + # Compatibility for Airflow < 3.1 + from airflow.utils.trigger_rule import TriggerRule # type: ignore[no-redef,attr-defined] + +from system.amazon.aws.utils import SystemTestContextBuilder + +DAG_ID = "example_neptune_analytics" + +sys_test_context_task = SystemTestContextBuilder().build() + +# Minimal OpenCypher CSV data for import testing. +NODES_CSV = """~id,~label,name:String +n1,Person,Alice +n2,Person,Bob +""" + +EDGES_CSV = """~id,~from,~to,~label +e1,n1,n2,KNOWS +""" + +NEPTUNE_ANALYTICS_TRUST_POLICY = json.dumps( + { + "Version": "2012-10-17", + "Statement": [ + { + "Effect": "Allow", + "Principal": {"Service": "neptune-graph.amazonaws.com"}, + "Action": "sts:AssumeRole", + } + ], + } +) + +S3_READ_POLICY_DOCUMENT = json.dumps( + { + "Version": "2012-10-17", + "Statement": [ + { + "Effect": "Allow", + "Action": ["s3:GetObject", "s3:ListBucket"], + "Resource": ["arn:aws:s3:::*", "arn:aws:s3:::*/*"], + } + ], + } +) + + +@task +def create_neptune_import_role(role_name: str) -> str: + iam_client = boto3.client("iam") + iam_client.create_role( + RoleName=role_name, + AssumeRolePolicyDocument=NEPTUNE_ANALYTICS_TRUST_POLICY, + Description="Role for Neptune Analytics import system test", + ) + iam_client.put_role_policy( + RoleName=role_name, + PolicyName="NeptuneAnalyticsS3Access", + PolicyDocument=S3_READ_POLICY_DOCUMENT, + ) + role = iam_client.get_role(RoleName=role_name) + time.sleep(60) # Wait for IAM eventual consistency (role + inline policy propagation) + return role["Role"]["Arn"] + + +@task(trigger_rule=TriggerRule.ALL_DONE) +def delete_neptune_import_role(role_name: str) -> None: + iam_client = boto3.client("iam") + with contextlib.suppress(iam_client.exceptions.NoSuchEntityException): + iam_client.delete_role_policy(RoleName=role_name, PolicyName="NeptuneAnalyticsS3Access") + iam_client.delete_role(RoleName=role_name) + + +@task(trigger_rule=TriggerRule.ALL_DONE) +def delete_graph_if_exists(graph_name: str) -> None: + """Safety net to clean up the graph in case a previous task failed.""" + hook = NeptuneAnalyticsHook() + with contextlib.suppress(Exception): + # List graphs and find by name + paginator = hook.conn.get_paginator("list_graphs") + for page in paginator.paginate(): + for graph in page.get("graphs", []): + if graph.get("name") == graph_name: + graph_id = graph["id"] + + # Delete any attached private graph endpoints before deleting the graph + endpoints_paginator = hook.conn.get_paginator("list_private_graph_endpoints") + for ep_page in endpoints_paginator.paginate(graphIdentifier=graph_id): + for endpoint in ep_page.get("privateGraphEndpoints", []): + vpc_id = endpoint["vpcId"] + hook.conn.delete_private_graph_endpoint(graphIdentifier=graph_id, vpcId=vpc_id) + hook.conn.get_waiter("private_graph_endpoint_deleted").wait( + graphIdentifier=graph_id, + vpcId=vpc_id, + WaiterConfig={"Delay": 30, "MaxAttempts": 60}, + ) + + # Disable deletion protection if enabled + hook.conn.update_graph(graphIdentifier=graph_id, deletionProtection=False) + + hook.conn.delete_graph(graphIdentifier=graph_id, skipSnapshot=True) + hook.conn.get_waiter("graph_deleted").wait( + graphIdentifier=graph_id, + WaiterConfig={"Delay": 30, "MaxAttempts": 60}, + ) + return + + +with DAG( + dag_id=DAG_ID, + schedule="@once", + start_date=datetime(2021, 1, 1), + catchup=False, +) as dag: + test_context = sys_test_context_task() + + env_id = test_context["ENV_ID"] + graph_name = f"{env_id}-graph" + import_graph_name = f"{env_id}-import-graph" + bucket_name = f"{env_id}-neptune-analytics" + import_role_name = f"{env_id}-neptune-import" + region = boto3.session.Session().region_name + + # --- TEST SETUP --- + + create_bucket = S3CreateBucketOperator( + task_id="create_bucket", + bucket_name=bucket_name, + ) + + upload_nodes = S3CreateObjectOperator( + task_id="upload_nodes", + s3_bucket=bucket_name, + s3_key="data/nodes.csv", + data=NODES_CSV, + replace=True, + ) + + upload_edges = S3CreateObjectOperator( + task_id="upload_edges", + s3_bucket=bucket_name, + s3_key="data/edges.csv", + data=EDGES_CSV, + replace=True, + ) + + create_role = create_neptune_import_role(import_role_name) + + # --- TEST BODY --- + + # [START howto_operator_neptune_analytics_create_graph] + create_graph = NeptuneCreateGraphOperator( + task_id="create_graph", + graph_name=graph_name, + vector_search_config={"dimension": 128}, + provisioned_memory=32, + public_connectivity=True, + replica_count=0, + deletion_protection=False, + wait_for_completion=True, + deferrable=False, + waiter_delay=30, + waiter_max_attempts=60, + ) + # [END howto_operator_neptune_analytics_create_graph] + + # [START howto_operator_neptune_analytics_create_private_endpoint] + create_endpoint = NeptuneCreatePrivateGraphEndpointOperator( + task_id="create_endpoint", + graph_identifier="{{ ti.xcom_pull(task_ids='create_graph')['graph_id']}}", + wait_for_completion=True, + ) + # [END howto_operator_neptune_analytics_create_private_endpoint] + + # [START howto_operator_neptune_analytics_delete_private_endpoint] + delete_endpoint = NeptuneDeletePrivateGraphEndpointOperator( + task_id="delete_endpoint", + graph_identifier="{{ ti.xcom_pull(task_ids='create_graph')['graph_id'] }}", + vpc_id="{{ ti.xcom_pull(task_ids='create_endpoint')['vpc_id'] }}", + wait_for_completion=True, + deferrable=False, + waiter_delay=30, + waiter_max_attempts=60, + ) + # [END howto_operator_neptune_analytics_delete_private_endpoint] + + # [START howto_operator_neptune_analytics_start_import_task] + start_import = NeptuneStartImportTaskOperator( + task_id="start_import", + graph_identifier="{{ ti.xcom_pull(task_ids='create_graph')['graph_id'] }}", + role_arn=create_role, + source=f"s3://{bucket_name}/data/", + format="CSV", + fail_on_error=True, + wait_for_completion=False, + deferrable=False, + waiter_delay=30, + waiter_max_attempts=60, + ) + # [END howto_operator_neptune_analytics_start_import_task] + + # [START howto_operator_neptune_analytics_cancel_import_task] + cancel_import = NeptuneCancelImportTaskOperator( + task_id="cancel_import", + import_task_id="{{ ti.xcom_pull(task_ids='start_import')['import_task_id']}}", + wait_for_completion=True, + aws_conn_id="aws_default", + ) + # [END howto_operator_neptune_analytics_cancel_import_task] + + # [START howto_operator_neptune_analytics_delete_graph] + delete_graph = NeptuneDeleteGraphOperator( + task_id="delete_graph", + graph_id="{{ ti.xcom_pull(task_ids='create_graph')['graph_id'] }}", + skip_snapshot=True, + wait_for_completion=True, + deferrable=False, + waiter_delay=30, + waiter_max_attempts=60, + ) + # [END howto_operator_neptune_analytics_delete_graph] + + # [START howto_operator_neptune_analytics_create_graph_with_import] + create_graph_with_import = NeptuneCreateGraphWithImportOperator( + task_id="create_graph_with_import", + graph_name=import_graph_name, + vector_search_config={"dimension": 128}, + source=f"s3://{bucket_name}/data/", + role_arn=create_role, + format="CSV", + fail_on_error=True, + public_connectivity=True, + replica_count=0, + deletion_protection=False, + min_provisioned_memory=32, + max_provisioned_memory=32, + wait_for_completion=True, + deferrable=False, + waiter_delay=30, + waiter_max_attempts=60, + ) + # [END howto_operator_neptune_analytics_create_graph_with_import] + + # [START howto_operator_neptune_analytics_delete_import_graph] + delete_import_graph = NeptuneDeleteGraphOperator( + task_id="delete_import_graph", + graph_id="{{ ti.xcom_pull(task_ids='create_graph_with_import')['graph_id'] }}", + skip_snapshot=True, + wait_for_completion=True, + deferrable=False, + trigger_rule=TriggerRule.ALL_DONE, + waiter_delay=30, + waiter_max_attempts=60, + ) + # [END howto_operator_neptune_analytics_delete_import_graph] + + # --- TEST TEARDOWN --- + + delete_bucket = S3DeleteBucketOperator( + task_id="delete_bucket", + trigger_rule=TriggerRule.ALL_DONE, + bucket_name=bucket_name, + force_delete=True, + ) + + delete_role = delete_neptune_import_role(import_role_name) + + cleanup_graph = delete_graph_if_exists.override(task_id="cleanup_graph")(graph_name) + cleanup_import_graph = delete_graph_if_exists.override(task_id="cleanup_import_graph")(import_graph_name) + + chain( + # TEST SETUP + test_context, + create_bucket, + [upload_nodes, upload_edges], + create_role, + # TEST BODY: Create graph, import data, then delete + create_graph, + create_endpoint, + start_import, + cancel_import, + delete_endpoint, + delete_graph, + # TEST BODY: Create graph with import, then delete + create_graph_with_import, + delete_import_graph, + # TEST TEARDOWN + [cleanup_graph, cleanup_import_graph], + delete_bucket, + delete_role, + ) + + from tests_common.test_utils.watcher import watcher + + # This test needs watcher in order to properly mark success/failure + # when "tearDown" task with trigger rule is part of the DAG + list(dag.tasks) >> watcher() + +from tests_common.test_utils.system_tests import get_test_run # noqa: E402 + +# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest) +test_run = get_test_run(dag) diff --git a/providers/amazon/tests/unit/amazon/aws/hooks/test_neptune_analytics.py b/providers/amazon/tests/unit/amazon/aws/hooks/test_neptune_analytics.py new file mode 100644 index 0000000000000..0e22d6361bbab --- /dev/null +++ b/providers/amazon/tests/unit/amazon/aws/hooks/test_neptune_analytics.py @@ -0,0 +1,62 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from collections.abc import Generator +from unittest import mock + +import pytest +from moto import mock_aws + +from airflow.providers.amazon.aws.hooks.neptune_analytics import NeptuneAnalyticsHook + + +@pytest.fixture +def neptune_hook() -> Generator[NeptuneAnalyticsHook, None, None]: + """Returns a NeptuneAnalyticsHook mocked with moto""" + with mock_aws(): + yield NeptuneAnalyticsHook(aws_conn_id="aws_default") + + +class TestNeptuneAnalyticsHook: + def test_get_conn_returns_a_boto3_connection(self): + hook = NeptuneAnalyticsHook(aws_conn_id="aws_default") + assert hook.get_conn() is not None + + @mock.patch.object(NeptuneAnalyticsHook, "conn") + def test_get_graph_endpoint_id(self, mock_conn): + mock_conn.get_private_graph_endpoint.return_value = { + "vpcEndpointId": "vpce-12345", + } + + hook = NeptuneAnalyticsHook(aws_conn_id="aws_default") + result = hook._get_graph_endpoint_id(graph_id="g-abc123", vpc_id="vpc-99999") + + mock_conn.get_private_graph_endpoint.assert_called_once_with( + graphIdentifier="g-abc123", vpcId="vpc-99999" + ) + assert result == "vpce-12345" + + @mock.patch.object(NeptuneAnalyticsHook, "conn") + def test_get_graph_endpoint_id_missing_key(self, mock_conn): + mock_conn.get_private_graph_endpoint.return_value = {} + + hook = NeptuneAnalyticsHook(aws_conn_id="aws_default") + result = hook._get_graph_endpoint_id(graph_id="g-abc123", vpc_id="vpc-99999") + + assert result is None diff --git a/providers/amazon/tests/unit/amazon/aws/links/test_ec2.py b/providers/amazon/tests/unit/amazon/aws/links/test_ec2.py index ff2f48e9be174..9aa8e4ce65904 100644 --- a/providers/amazon/tests/unit/amazon/aws/links/test_ec2.py +++ b/providers/amazon/tests/unit/amazon/aws/links/test_ec2.py @@ -16,7 +16,7 @@ # under the License. from __future__ import annotations -from airflow.providers.amazon.aws.links.ec2 import EC2InstanceDashboardLink, EC2InstanceLink +from airflow.providers.amazon.aws.links.ec2 import EC2InstanceDashboardLink, EC2InstanceLink, VpcEndpointLink from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS from unit.amazon.aws.links.test_base_aws import BaseAwsLinksTestCase @@ -85,3 +85,30 @@ def test_extra_link(self, mock_supervisor_comms): aws_partition="aws", instance_ids=EC2InstanceDashboardLink.format_instance_id_filter(self.INSTANCE_IDS), ) + + +class TestVpcEndpointLink(BaseAwsLinksTestCase): + link_class = VpcEndpointLink + + ENDPOINT_ID = "vpce-0123456789abcdef0" + + def test_extra_link(self, mock_supervisor_comms): + if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: + mock_supervisor_comms.send.return_value = XComResult( + key=self.link_class.key, + value={ + "region_name": "us-east-1", + "aws_domain": self.link_class.get_aws_domain("aws"), + "aws_partition": "aws", + "endpoint_id": self.ENDPOINT_ID, + }, + ) + self.assert_extra_link_url( + expected_url=( + "https://console.aws.amazon.com/vpcconsole/home" + f"?region=us-east-1#EndpointDetails:vpcEndpointId={self.ENDPOINT_ID}" + ), + region_name="us-east-1", + aws_partition="aws", + endpoint_id=self.ENDPOINT_ID, + ) diff --git a/providers/amazon/tests/unit/amazon/aws/links/test_neptune_analytics.py b/providers/amazon/tests/unit/amazon/aws/links/test_neptune_analytics.py new file mode 100644 index 0000000000000..d332a916125d5 --- /dev/null +++ b/providers/amazon/tests/unit/amazon/aws/links/test_neptune_analytics.py @@ -0,0 +1,54 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import pytest + +from airflow.providers.amazon.aws.links.neptune_analytics import NeptuneGraphLink + +from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS +from unit.amazon.aws.links.test_base_aws import BaseAwsLinksTestCase + +if AIRFLOW_V_3_0_PLUS: + from airflow.sdk.execution_time.comms import XComResult + +pytestmark = pytest.mark.db_test + + +class TestNeptuneGraphLink(BaseAwsLinksTestCase): + link_class = NeptuneGraphLink + + def test_extra_link(self, mock_supervisor_comms): + if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: + mock_supervisor_comms.send.return_value = XComResult( + key=self.link_class.key, + value={ + "region_name": "us-east-1", + "aws_domain": self.link_class.get_aws_domain("aws"), + "aws_partition": "aws", + "graph_id": "g-fake123456", + }, + ) + self.assert_extra_link_url( + expected_url=( + "https://console.aws.amazon.com/neptune/home?region=us-east-1" + "#analytics-graph-details:id=g-fake123456;tab=connectivity" + ), + region_name="us-east-1", + aws_partition="aws", + graph_id="g-fake123456", + ) diff --git a/providers/amazon/tests/unit/amazon/aws/operators/test_neptune_analytics.py b/providers/amazon/tests/unit/amazon/aws/operators/test_neptune_analytics.py new file mode 100644 index 0000000000000..fdc0eeec9b1a8 --- /dev/null +++ b/providers/amazon/tests/unit/amazon/aws/operators/test_neptune_analytics.py @@ -0,0 +1,1277 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from unittest import mock + +import pytest +from botocore.exceptions import ClientError + +from airflow.providers.amazon.aws.exceptions import ( + NeptuneGraphDeletionFailedError, + NeptunePrivateEndpointCreationFailedError, + NeptunePrivateEndpointDeletionFailedError, +) +from airflow.providers.amazon.aws.hooks.neptune_analytics import NeptuneAnalyticsHook +from airflow.providers.amazon.aws.links.neptune_analytics import NeptuneGraphLink, NeptuneImportTaskLink +from airflow.providers.amazon.aws.operators.neptune_analytics import ( + NeptuneCancelImportTaskOperator, + NeptuneCreateGraphOperator, + NeptuneCreateGraphWithImportOperator, + NeptuneCreatePrivateGraphEndpointOperator, + NeptuneDeleteGraphOperator, + NeptuneDeletePrivateGraphEndpointOperator, + NeptuneStartImportTaskOperator, +) +from airflow.providers.amazon.aws.triggers.neptune_analytics import ( + NeptuneGraphAvailableTrigger, + NeptuneImportTaskCompleteTrigger, +) +from airflow.providers.common.compat.sdk import TaskDeferred + +GRAPH_NAME = "test_graph" +GRAPH_ID = "test-graph-id" +VPC_ID = "vpc-12345" +SUBNET_IDS = ["subnet-1", "subnet-2"] +SECURITY_GROUP_IDS = ["sg-1", "sg-2"] +ENDPOINT_ID = "vpce-12345" +SOURCE_S3_URI = "s3://my-bucket/my-data/" +ROLE_ARN = "arn:aws:iam::123456789012:role/NeptuneImportRole" + + +class TestNeptuneCreateGraphOperator: + def test_template_fields(self): + # Verify template_fields includes the expected fields + fields = NeptuneCreateGraphOperator.template_fields + assert "graph_name" in fields + assert "vector_search_config" in fields + assert "provisioned_memory" in fields + + def test_template_fields_renderers(self): + assert NeptuneCreateGraphOperator.template_fields_renderers == {"vector_search_config": "json"} + + def test_operator_extra_links(self): + + assert len(NeptuneCreateGraphOperator.operator_extra_links) == 1 + assert isinstance(NeptuneCreateGraphOperator.operator_extra_links[0], NeptuneGraphLink) + + @mock.patch("airflow.providers.amazon.aws.operators.neptune_analytics.NeptuneGraphLink.persist") + @mock.patch.object(NeptuneAnalyticsHook, "conn") + def test_init_defaults(self, mock_conn, mock_persist): + mock_conn.create_graph.return_value = {"id": GRAPH_ID, "status": "CREATING"} + + operator = NeptuneCreateGraphOperator( + task_id="test_task", + graph_name=GRAPH_NAME, + vector_search_config={"test": 123}, + provisioned_memory=16, + ) + + assert operator.public_connectivity is None + assert operator.replica_count is None + assert operator.deletion_protect is False + assert operator.kms_key_id is None + assert operator.tags is None + + operator.execute(None) + + mock_conn.create_graph.assert_called_once_with( + graphName=GRAPH_NAME, + vectorSearchConfiguration={"test": 123}, + provisionedMemory=16, + deletionProtection=False, + ) + + @mock.patch("airflow.providers.amazon.aws.operators.neptune_analytics.NeptuneGraphLink.persist") + @mock.patch.object(NeptuneAnalyticsHook, "conn") + def test_init_custom_args(self, mock_conn, mock_persist): + mock_conn.create_graph.return_value = {"id": GRAPH_ID, "status": "CREATING"} + + operator = NeptuneCreateGraphOperator( + task_id="test_task", + graph_name=GRAPH_NAME, + vector_search_config={"test": 123}, + provisioned_memory=16, + public_connectivity=True, + replica_count=3, + kms_key_id="test-key", + tags={"key1": "test"}, + deletion_protection=True, + ) + + assert operator.public_connectivity is True + assert operator.replica_count == 3 + assert operator.deletion_protect is True + assert operator.kms_key_id == "test-key" + assert operator.tags == {"key1": "test"} + + operator.execute(None) + + mock_conn.create_graph.assert_called_once_with( + graphName=GRAPH_NAME, + vectorSearchConfiguration={"test": 123}, + replicaCount=3, + publicConnectivity=True, + provisionedMemory=16, + deletionProtection=True, + kmsKeyIdentifier="test-key", + tags={"key1": "test"}, + ) + + @mock.patch("airflow.providers.amazon.aws.operators.neptune_analytics.NeptuneGraphLink.persist") + @mock.patch.object(NeptuneAnalyticsHook, "conn") + @mock.patch.object(NeptuneAnalyticsHook, "get_waiter") + def test_create_graph(self, mock_hook_get_waiter, mock_conn, mock_persist): + mock_conn.create_graph.return_value = {"id": GRAPH_ID, "status": "CREATING"} + + operator = NeptuneCreateGraphOperator( + task_id="test_task", + graph_name=GRAPH_NAME, + provisioned_memory=16, + vector_search_config={"test": 123}, + wait_for_completion=False, + ) + resp = operator.execute(None) + + mock_hook_get_waiter.assert_not_called() + assert "graph_id" in resp + assert resp["graph_id"] == GRAPH_ID + + @mock.patch("airflow.providers.amazon.aws.operators.neptune_analytics.NeptuneGraphLink.persist") + @mock.patch.object(NeptuneAnalyticsHook, "conn") + @mock.patch.object(NeptuneAnalyticsHook, "get_waiter") + def test_create_graph_wait_for_completion(self, mock_hook_get_waiter, mock_conn, mock_persist): + mock_conn.create_graph.return_value = {"id": GRAPH_ID, "status": "CREATING"} + + operator = NeptuneCreateGraphOperator( + task_id="test_task", + graph_name=GRAPH_NAME, + provisioned_memory=16, + vector_search_config={"test": 123}, + wait_for_completion=True, + ) + resp = operator.execute(None) + + mock_hook_get_waiter.assert_called_once_with("graph_available") + assert "graph_id" in resp + assert resp["graph_id"] == GRAPH_ID + + @mock.patch.object(NeptuneAnalyticsHook, "conn") + def test_persist_called_with_correct_args(self, mock_conn): + """Test that NeptuneGraphLink.persist is called with the correct arguments.""" + mock_conn.create_graph.return_value = {"id": GRAPH_ID, "status": "CREATING"} + + operator = NeptuneCreateGraphOperator( + task_id="test_task", + graph_name=GRAPH_NAME, + vector_search_config={"test": 123}, + provisioned_memory=16, + wait_for_completion=False, + ) + + mock_context = mock.MagicMock() + with mock.patch( + "airflow.providers.amazon.aws.operators.neptune_analytics.NeptuneGraphLink.persist" + ) as mock_persist: + operator.execute(mock_context) + + mock_persist.assert_called_once_with( + context=mock_context, + operator=operator, + region_name=mock.ANY, + aws_partition=mock.ANY, + graph_id=GRAPH_ID, + ) + + @mock.patch("airflow.providers.amazon.aws.operators.neptune_analytics.NeptuneGraphLink.persist") + @mock.patch.object(NeptuneAnalyticsHook, "conn") + def test_deferrable_defers_with_graph_available_trigger(self, mock_conn, mock_persist): + """Test that deferrable mode defers with NeptuneGraphAvailableTrigger.""" + mock_conn.create_graph.return_value = {"id": GRAPH_ID, "status": "CREATING"} + + operator = NeptuneCreateGraphOperator( + task_id="test_task", + graph_name=GRAPH_NAME, + vector_search_config={"test": 123}, + provisioned_memory=16, + deferrable=True, + ) + + with pytest.raises(TaskDeferred) as exc_info: + operator.execute(None) + + trigger = exc_info.value.trigger + assert isinstance(trigger, NeptuneGraphAvailableTrigger) + assert exc_info.value.method_name == "execute_complete" + + +class TestNeptuneCreatePrivateGraphEndpointOperator: + @mock.patch.object(NeptuneAnalyticsHook, "conn") + def test_init_defaults(self, mock_conn): + mock_conn.create_private_graph_endpoint.return_value = { + "status": "CREATING", + "vpcEndpointId": ENDPOINT_ID, + "vpcId": VPC_ID, + } + mock_conn.get_private_graph_endpoint.return_value = { + "vpcEndpointId": ENDPOINT_ID, + } + + operator = NeptuneCreatePrivateGraphEndpointOperator( + task_id="test_task", + graph_identifier=GRAPH_ID, + ) + + assert operator.graph_identifier == GRAPH_ID + assert operator.vpc_id is None + assert operator.subnet_ids is None + assert operator.vpc_security_group_ids is None + assert operator.wait_for_completion is True + assert operator.waiter_delay == 30 + assert operator.waiter_max_attempts == 60 + + result = operator.execute(None) + + mock_conn.create_private_graph_endpoint.assert_called_once_with( + graphIdentifier=GRAPH_ID, + ) + mock_conn.get_private_graph_endpoint.assert_called_once_with( + graphIdentifier=GRAPH_ID, + vpcId=VPC_ID, + ) + + assert result is not None + assert result["vpc_endpoint_id"] == ENDPOINT_ID + assert result["graph_id"] == GRAPH_ID + assert result["vpc_id"] == VPC_ID + + @mock.patch.object(NeptuneAnalyticsHook, "conn") + def test_init_custom_args(self, mock_conn): + mock_conn.create_private_graph_endpoint.return_value = { + "status": "CREATING", + "vpcEndpointId": ENDPOINT_ID, + "vpcId": VPC_ID, + } + mock_conn.get_private_graph_endpoint.return_value = { + "vpcEndpointId": ENDPOINT_ID, + } + + operator = NeptuneCreatePrivateGraphEndpointOperator( + task_id="test_task", + graph_identifier=GRAPH_ID, + vpc_id=VPC_ID, + subnet_ids=SUBNET_IDS, + vpc_security_group_ids=SECURITY_GROUP_IDS, + waiter_delay=60, + waiter_max_attempts=100, + ) + + assert operator.graph_identifier == GRAPH_ID + assert operator.vpc_id == VPC_ID + assert operator.subnet_ids == SUBNET_IDS + assert operator.vpc_security_group_ids == SECURITY_GROUP_IDS + assert operator.waiter_delay == 60 + assert operator.waiter_max_attempts == 100 + + operator.execute(None) + + mock_conn.create_private_graph_endpoint.assert_called_once_with( + graphIdentifier=GRAPH_ID, + vpcId=VPC_ID, + subnetIds=SUBNET_IDS, + vpcSecurityGroupIds=SECURITY_GROUP_IDS, + ) + + @mock.patch.object(NeptuneAnalyticsHook, "conn") + @mock.patch.object(NeptuneAnalyticsHook, "get_waiter") + def test_create_endpoint_no_wait(self, mock_hook_get_waiter, mock_conn): + mock_conn.create_private_graph_endpoint.return_value = { + "status": "CREATING", + "vpcEndpointId": ENDPOINT_ID, + "vpcId": VPC_ID, + } + mock_conn.get_private_graph_endpoint.return_value = { + "vpcEndpointId": ENDPOINT_ID, + } + + operator = NeptuneCreatePrivateGraphEndpointOperator( + task_id="test_task", + graph_identifier=GRAPH_ID, + vpc_id=VPC_ID, + wait_for_completion=False, + ) + operator.execute(None) + + mock_hook_get_waiter.assert_not_called() + + @mock.patch.object(NeptuneAnalyticsHook, "conn") + @mock.patch.object(NeptuneAnalyticsHook, "get_waiter") + def test_create_endpoint_wait_for_completion(self, mock_hook_get_waiter, mock_conn): + mock_conn.create_private_graph_endpoint.return_value = { + "status": "CREATING", + "vpcEndpointId": ENDPOINT_ID, + "vpcId": VPC_ID, + } + mock_conn.get_private_graph_endpoint.return_value = { + "vpcEndpointId": ENDPOINT_ID, + } + + operator = NeptuneCreatePrivateGraphEndpointOperator( + task_id="test_task", + graph_identifier=GRAPH_ID, + vpc_id=VPC_ID, + wait_for_completion=True, + ) + result = operator.execute(None) + + mock_hook_get_waiter.assert_called_once_with("private_graph_endpoint_available") + mock_hook_get_waiter.return_value.wait.assert_called_once_with( + graphIdentifier=GRAPH_ID, + vpcId=VPC_ID, + WaiterConfig={"Delay": 30, "MaxAttempts": 60}, + ) + assert result == {"vpc_endpoint_id": ENDPOINT_ID, "graph_id": GRAPH_ID, "vpc_id": VPC_ID} + + @mock.patch.object(NeptuneAnalyticsHook, "conn") + def test_create_endpoint_sets_vpc_id_from_response(self, mock_conn): + """When vpc_id is not provided, the operator should use the vpc_id from the API response.""" + mock_conn.create_private_graph_endpoint.return_value = { + "status": "CREATING", + "vpcEndpointId": ENDPOINT_ID, + "vpcId": VPC_ID, + } + mock_conn.get_private_graph_endpoint.return_value = { + "vpcEndpointId": ENDPOINT_ID, + } + + operator = NeptuneCreatePrivateGraphEndpointOperator( + task_id="test_task", + graph_identifier=GRAPH_ID, + ) + + assert operator.vpc_id is None + result = operator.execute(None) + + # vpc_id should be set from the create response + assert operator.vpc_id == VPC_ID + assert result["vpc_id"] == VPC_ID + + @mock.patch.object(NeptuneAnalyticsHook, "conn") + def test_create_endpoint_failed_status(self, mock_conn): + + mock_conn.create_private_graph_endpoint.return_value = { + "status": "FAILED", + "vpcEndpointId": ENDPOINT_ID, + "vpcId": VPC_ID, + } + + operator = NeptuneCreatePrivateGraphEndpointOperator( + task_id="test_task", + graph_identifier=GRAPH_ID, + vpc_id=VPC_ID, + ) + + with pytest.raises( + NeptunePrivateEndpointCreationFailedError, + match=f"Private endpoint failed to create for graph {GRAPH_ID}", + ): + operator.execute(None) + + @mock.patch.object(NeptuneAnalyticsHook, "conn") + @mock.patch.object(NeptuneAnalyticsHook, "_get_graph_endpoint_id") + def test_execute_complete(self, mock_get_endpoint, mock_conn): + mock_get_endpoint.return_value = ENDPOINT_ID + mock_conn.get_private_graph_endpoint.return_value = { + "vpcEndpointId": ENDPOINT_ID, + } + + mock_conn.create_private_graph_endpoint.return_value = {"vpcId": VPC_ID} + + operator = NeptuneCreatePrivateGraphEndpointOperator( + task_id="test_task", graph_identifier=GRAPH_ID, vpc_id=VPC_ID + ) + + result = operator.execute_complete( + context=None, event={"status": "success", "value": GRAPH_ID, "vpc_id": VPC_ID} + ) + + # mock_conn.get_private_graph_endpoint.assert_called_once_with( + mock_get_endpoint.assert_called_once_with( + graph_id=GRAPH_ID, + vpc_id=VPC_ID, + ) + assert result == {"vpc_endpoint_id": ENDPOINT_ID, "graph_id": GRAPH_ID, "vpc_id": VPC_ID} + + +class TestNeptuneDeletePrivateGraphEndpointOperator: + @mock.patch.object(NeptuneAnalyticsHook, "conn") + def test_init_defaults(self, mock_conn): + mock_conn.delete_private_graph_endpoint.return_value = { + "status": "DELETING", + "vpcEndpointId": ENDPOINT_ID, + "vpcId": VPC_ID, + } + + operator = NeptuneDeletePrivateGraphEndpointOperator( + task_id="test_task", + graph_identifier=GRAPH_ID, + vpc_id=VPC_ID, + ) + + assert operator.graph_identifier == GRAPH_ID + assert operator.vpc_id == VPC_ID + assert operator.wait_for_completion is True + assert operator.waiter_delay == 30 + assert operator.waiter_max_attempts == 60 + + operator.execute(None) + + mock_conn.delete_private_graph_endpoint.assert_called_once_with( + graphIdentifier=GRAPH_ID, + vpcId=VPC_ID, + ) + + @mock.patch.object(NeptuneAnalyticsHook, "conn") + def test_init_custom_args(self, mock_conn): + mock_conn.delete_private_graph_endpoint.return_value = { + "status": "DELETING", + "vpcEndpointId": ENDPOINT_ID, + "vpcId": VPC_ID, + } + + operator = NeptuneDeletePrivateGraphEndpointOperator( + task_id="test_task", + graph_identifier=GRAPH_ID, + vpc_id=VPC_ID, + waiter_delay=60, + waiter_max_attempts=100, + ) + + assert operator.graph_identifier == GRAPH_ID + assert operator.vpc_id == VPC_ID + assert operator.waiter_delay == 60 + assert operator.waiter_max_attempts == 100 + + operator.execute(None) + + mock_conn.delete_private_graph_endpoint.assert_called_once_with( + graphIdentifier=GRAPH_ID, + vpcId=VPC_ID, + ) + + @mock.patch.object(NeptuneAnalyticsHook, "conn") + @mock.patch.object(NeptuneAnalyticsHook, "get_waiter") + def test_delete_endpoint_no_wait(self, mock_hook_get_waiter, mock_conn): + mock_conn.delete_private_graph_endpoint.return_value = { + "status": "DELETING", + "vpcEndpointId": ENDPOINT_ID, + "vpcId": VPC_ID, + } + + operator = NeptuneDeletePrivateGraphEndpointOperator( + task_id="test_task", + graph_identifier=GRAPH_ID, + vpc_id=VPC_ID, + wait_for_completion=False, + ) + operator.execute(None) + + mock_hook_get_waiter.assert_not_called() + + @mock.patch.object(NeptuneAnalyticsHook, "conn") + @mock.patch.object(NeptuneAnalyticsHook, "get_waiter") + def test_delete_endpoint_wait_for_completion(self, mock_hook_get_waiter, mock_conn): + mock_conn.delete_private_graph_endpoint.return_value = { + "status": "DELETING", + "vpcEndpointId": ENDPOINT_ID, + "vpcId": VPC_ID, + } + + operator = NeptuneDeletePrivateGraphEndpointOperator( + task_id="test_task", + graph_identifier=GRAPH_ID, + vpc_id=VPC_ID, + wait_for_completion=True, + ) + operator.execute(None) + + mock_hook_get_waiter.assert_called_once_with("private_graph_endpoint_deleted") + + @mock.patch.object(NeptuneAnalyticsHook, "conn") + def test_delete_endpoint_failed_status(self, mock_conn): + mock_conn.delete_private_graph_endpoint.return_value = { + "status": "FAILED", + "vpcEndpointId": ENDPOINT_ID, + "vpcId": VPC_ID, + } + + operator = NeptuneDeletePrivateGraphEndpointOperator( + task_id="test_task", + graph_identifier=GRAPH_ID, + vpc_id=VPC_ID, + ) + + with pytest.raises( + NeptunePrivateEndpointDeletionFailedError, + match=f"Failed to delete private endpoint {ENDPOINT_ID}", + ): + operator.execute(None) + + def test_execute_complete_success(self): + operator = NeptuneDeletePrivateGraphEndpointOperator( + task_id="test_task", + graph_identifier=GRAPH_ID, + vpc_id=VPC_ID, + ) + + event = { + "status": "success", + "endpoint_id": ENDPOINT_ID, + } + + operator.execute_complete(None, event) + + # Verify the method completes without error and logs the endpoint_id + + +class TestNeptuneDeleteGraphOperator: + @mock.patch.object(NeptuneAnalyticsHook, "conn") + def test_init_defaults(self, mock_conn): + mock_conn.delete_graph.return_value = { + "id": GRAPH_ID, + "name": GRAPH_NAME, + "status": "DELETING", + } + + operator = NeptuneDeleteGraphOperator( + task_id="test_task", + graph_id=GRAPH_ID, + skip_snapshot=True, + ) + + assert operator.graph_id == GRAPH_ID + assert operator.skip_snapshot is True + assert operator.wait_for_completion is True + assert operator.waiter_delay == 30 + assert operator.waiter_max_attempts == 60 + + operator.execute(None) + + mock_conn.delete_graph.assert_called_once_with( + graphIdentifier=GRAPH_ID, + skipSnapshot=True, + ) + + @mock.patch.object(NeptuneAnalyticsHook, "conn") + def test_init_custom_args(self, mock_conn): + mock_conn.delete_graph.return_value = { + "id": GRAPH_ID, + "name": GRAPH_NAME, + "status": "DELETING", + } + + operator = NeptuneDeleteGraphOperator( + task_id="test_task", + graph_id=GRAPH_ID, + skip_snapshot=False, + waiter_delay=60, + waiter_max_attempts=100, + ) + + assert operator.graph_id == GRAPH_ID + assert operator.skip_snapshot is False + assert operator.waiter_delay == 60 + assert operator.waiter_max_attempts == 100 + + operator.execute(None) + + mock_conn.delete_graph.assert_called_once_with( + graphIdentifier=GRAPH_ID, + skipSnapshot=False, + ) + + @mock.patch.object(NeptuneAnalyticsHook, "conn") + @mock.patch.object(NeptuneAnalyticsHook, "get_waiter") + def test_delete_graph_no_wait(self, mock_get_waiter, mock_conn): + mock_conn.delete_graph.return_value = { + "id": GRAPH_ID, + "name": GRAPH_NAME, + "status": "DELETING", + } + + operator = NeptuneDeleteGraphOperator( + task_id="test_task", + graph_id=GRAPH_ID, + skip_snapshot=True, + wait_for_completion=False, + ) + operator.execute(None) + + mock_conn.delete_graph.assert_called_once() + mock_get_waiter.assert_not_called() + + @mock.patch.object(NeptuneAnalyticsHook, "conn") + @mock.patch.object(NeptuneAnalyticsHook, "get_waiter") + def test_delete_graph_wait_for_completion(self, mock_get_waiter, mock_conn): + mock_conn.delete_graph.return_value = { + "id": GRAPH_ID, + "name": GRAPH_NAME, + "status": "DELETING", + } + mock_waiter = mock.MagicMock() + mock_get_waiter.return_value = mock_waiter + + operator = NeptuneDeleteGraphOperator( + task_id="test_task", + graph_id=GRAPH_ID, + skip_snapshot=True, + wait_for_completion=True, + ) + operator.execute(None) + + mock_get_waiter.assert_called_once_with("graph_deleted") + mock_waiter.wait.assert_called_once_with( + graphIdentifier=GRAPH_ID, + WaiterConfig={"Delay": 30, "MaxAttempts": 60}, + ) + + @mock.patch.object(NeptuneAnalyticsHook, "conn") + def test_delete_graph_resource_not_found(self, mock_conn): + + # Simulate ResourceNotFoundException + error_response = { + "Error": { + "Code": "ResourceNotFoundException", + "Message": "Graph not found", + }, + "ResponseMetadata": { + "HTTPStatusCode": 404, + }, + } + mock_conn.delete_graph.side_effect = ClientError(error_response, "delete_graph") + + operator = NeptuneDeleteGraphOperator( + task_id="test_task", + graph_id=GRAPH_ID, + skip_snapshot=True, + ) + + # Should not raise an exception, just log that graph not found + operator.execute(None) + + mock_conn.delete_graph.assert_called_once_with( + graphIdentifier=GRAPH_ID, + skipSnapshot=True, + ) + + @mock.patch.object(NeptuneAnalyticsHook, "conn") + def test_delete_graph_other_client_error(self, mock_conn): + + # Simulate other ClientError + error_response = { + "Error": { + "Code": "ValidationException", + "Message": "Invalid parameter", + }, + "ResponseMetadata": { + "HTTPStatusCode": 400, + }, + } + mock_conn.delete_graph.side_effect = ClientError(error_response, "delete_graph") + + operator = NeptuneDeleteGraphOperator( + task_id="test_task", + graph_id=GRAPH_ID, + skip_snapshot=True, + ) + + # Should raise NeptuneGraphDeletionFailedError for non-ResourceNotFoundException errors + with pytest.raises(NeptuneGraphDeletionFailedError): + operator.execute(None) + + +class TestNeptuneCreateGraphWithImportOperator: + IMPORT_TASK_ID = "import-task-12345" + + @mock.patch.object(NeptuneAnalyticsHook, "conn") + def test_init_defaults(self, mock_conn): + mock_conn.create_graph_using_import_task.return_value = { + "graphId": GRAPH_ID, + "taskId": self.IMPORT_TASK_ID, + "status": "IMPORTING", + } + + operator = NeptuneCreateGraphWithImportOperator( + task_id="test_task", + graph_name=GRAPH_NAME, + vector_search_config={"dimension": 128}, + source=SOURCE_S3_URI, + role_arn=ROLE_ARN, + ) + + assert operator.graph_name == GRAPH_NAME + assert operator.vector_search_config == {"dimension": 128} + assert operator.source == SOURCE_S3_URI + assert operator.role_arn == ROLE_ARN + assert operator.blank_node_handling is None + assert operator.parquet_type is None + assert operator.format is None + assert operator.min_provisioned_memory is None + assert operator.max_provisioned_memory is None + assert operator.fail_on_error is None + assert operator.public_connectivity is None + assert operator.replica_count is None + assert operator.deletion_protect is None + assert operator.kms_key_id is None + assert operator.tags is None + assert operator.import_options is None + assert operator.wait_for_completion is True + assert operator.waiter_delay == 30 + assert operator.waiter_max_attempts == 60 + + operator.execute(None) + + mock_conn.create_graph_using_import_task.assert_called_once_with( + graphName=GRAPH_NAME, + vectorSearchConfiguration={"dimension": 128}, + source=SOURCE_S3_URI, + roleArn=ROLE_ARN, + ) + + @mock.patch.object(NeptuneAnalyticsHook, "conn") + def test_init_with_all_optional_params(self, mock_conn): + mock_conn.create_graph_using_import_task.return_value = { + "graphId": GRAPH_ID, + "taskId": self.IMPORT_TASK_ID, + "status": "IMPORTING", + } + + operator = NeptuneCreateGraphWithImportOperator( + task_id="test_task", + graph_name=GRAPH_NAME, + vector_search_config={"dimension": 128}, + source=SOURCE_S3_URI, + role_arn=ROLE_ARN, + blank_node_handling="convertToIri", + parquet_type="COLUMNAR", + format="csv", + min_provisioned_memory=16, + max_provisioned_memory=32, + fail_on_error=True, + public_connectivity=True, + replica_count=2, + deletion_protection=True, + kms_key_id="test-kms-key", + tags={"env": "test"}, + import_options={"custom-option": "value"}, + waiter_delay=60, + waiter_max_attempts=100, + ) + + assert operator.blank_node_handling == "convertToIri" + assert operator.parquet_type == "COLUMNAR" + assert operator.format == "csv" + assert operator.min_provisioned_memory == 16 + assert operator.max_provisioned_memory == 32 + assert operator.fail_on_error is True + assert operator.public_connectivity is True + assert operator.replica_count == 2 + assert operator.deletion_protect is True + assert operator.kms_key_id == "test-kms-key" + assert operator.tags == {"env": "test"} + assert operator.import_options == {"custom-option": "value"} + assert operator.waiter_delay == 60 + assert operator.waiter_max_attempts == 100 + + operator.execute(None) + + # Verify the call includes all parameters + call_args = mock_conn.create_graph_using_import_task.call_args[1] + assert call_args["graphName"] == GRAPH_NAME + assert call_args["vectorSearchConfiguration"] == {"dimension": 128} + assert call_args["source"] == SOURCE_S3_URI + assert call_args["roleArn"] == ROLE_ARN + assert call_args["format"] == "csv" + assert call_args["minProvisionedMemory"] == 16 + assert call_args["maxProvisionedMemory"] == 32 + assert call_args["failOnError"] is True + assert call_args["replicaCount"] == 2 + assert call_args["publicConnectivity"] is True + assert call_args["deletionProtection"] is True + assert call_args["kmsKeyIdentifier"] == "test-kms-key" + assert call_args["tags"] == {"env": "test"} + # Check import options were merged + assert "neptune-analytics:blank-node-handling" in call_args["importOptions"] + assert call_args["importOptions"]["neptune-analytics:blank-node-handling"] == "convertToIri" + assert "neptune-analytics:parquet-type" in call_args["importOptions"] + assert call_args["importOptions"]["neptune-analytics:parquet-type"] == "COLUMNAR" + assert call_args["importOptions"]["custom-option"] == "value" + + @mock.patch.object(NeptuneAnalyticsHook, "conn") + def test_import_options_handling(self, mock_conn): + mock_conn.create_graph_using_import_task.return_value = { + "graphId": GRAPH_ID, + "taskId": self.IMPORT_TASK_ID, + "status": "IMPORTING", + } + + operator = NeptuneCreateGraphWithImportOperator( + task_id="test_task", + graph_name=GRAPH_NAME, + vector_search_config={"dimension": 128}, + source=SOURCE_S3_URI, + role_arn=ROLE_ARN, + blank_node_handling="convertToIri", + import_options={"another-option": "test"}, + ) + + operator.execute(None) + + call_args = mock_conn.create_graph_using_import_task.call_args[1] + # Verify import options were properly merged + assert call_args["importOptions"]["neptune-analytics:blank-node-handling"] == "convertToIri" + assert call_args["importOptions"]["another-option"] == "test" + + @mock.patch.object(NeptuneAnalyticsHook, "conn") + @mock.patch.object(NeptuneAnalyticsHook, "get_waiter") + def test_create_graph_with_import_no_wait(self, mock_hook_get_waiter, mock_conn): + mock_conn.create_graph_using_import_task.return_value = { + "graphId": GRAPH_ID, + "taskId": self.IMPORT_TASK_ID, + "status": "IMPORTING", + } + + operator = NeptuneCreateGraphWithImportOperator( + task_id="test_task", + graph_name=GRAPH_NAME, + vector_search_config={"dimension": 128}, + source=SOURCE_S3_URI, + role_arn=ROLE_ARN, + wait_for_completion=False, + ) + result = operator.execute(None) + + mock_hook_get_waiter.assert_not_called() + assert result == {"graph_id": GRAPH_ID} + + @mock.patch.object(NeptuneAnalyticsHook, "conn") + @mock.patch.object(NeptuneAnalyticsHook, "get_waiter") + def test_create_graph_with_import_wait_for_completion(self, mock_hook_get_waiter, mock_conn): + mock_conn.create_graph_using_import_task.return_value = { + "graphId": GRAPH_ID, + "taskId": self.IMPORT_TASK_ID, + "status": "IMPORTING", + } + + operator = NeptuneCreateGraphWithImportOperator( + task_id="test_task", + graph_name=GRAPH_NAME, + vector_search_config={"dimension": 128}, + source=SOURCE_S3_URI, + role_arn=ROLE_ARN, + wait_for_completion=True, + ) + result = operator.execute(None) + + # Should wait for both graph_available and import_task_successful + assert mock_hook_get_waiter.call_count == 2 + mock_hook_get_waiter.assert_any_call("graph_available") + mock_hook_get_waiter.assert_any_call("import_task_successful") + assert result == {"graph_id": GRAPH_ID} + + @mock.patch.object(NeptuneAnalyticsHook, "conn") + def test_import_options_none_values_filtered(self, mock_conn): + mock_conn.create_graph_using_import_task.return_value = { + "graphId": GRAPH_ID, + "taskId": self.IMPORT_TASK_ID, + "status": "IMPORTING", + } + + # Test that None values in blank_node_handling and parquet_type are filtered out + operator = NeptuneCreateGraphWithImportOperator( + task_id="test_task", + graph_name=GRAPH_NAME, + vector_search_config={"dimension": 128}, + source=SOURCE_S3_URI, + role_arn=ROLE_ARN, + blank_node_handling=None, + parquet_type=None, + ) + + operator.execute(None) + + call_args = mock_conn.create_graph_using_import_task.call_args[1] + # importOptions should not be in call_args if all values are None + assert "importOptions" not in call_args + + @mock.patch.object(NeptuneAnalyticsHook, "conn") + def test_defer_wait_for_task(self, mock_conn): + """Test that defer_wait_for_task defers with the import task trigger.""" + + operator = NeptuneCreateGraphWithImportOperator( + task_id="test_task", + graph_name=GRAPH_NAME, + vector_search_config={"dimension": 128}, + source=SOURCE_S3_URI, + role_arn=ROLE_ARN, + waiter_delay=30, + waiter_max_attempts=60, + ) + + with pytest.raises(TaskDeferred) as exc_info: + operator.defer_wait_for_task( + import_task_id=self.IMPORT_TASK_ID, + context=None, + event={"status": "success"}, + ) + + trigger = exc_info.value.trigger + assert isinstance(trigger, NeptuneImportTaskCompleteTrigger) + assert exc_info.value.method_name == "execute_complete" + + @mock.patch.object(NeptuneAnalyticsHook, "conn") + def test_deferrable_defers_with_graph_available_trigger(self, mock_conn): + """Test that execute defers with graph_available trigger and passes import_task_id.""" + + mock_conn.create_graph_using_import_task.return_value = { + "graphId": GRAPH_ID, + "taskId": self.IMPORT_TASK_ID, + "status": "IMPORTING", + } + + operator = NeptuneCreateGraphWithImportOperator( + task_id="test_task", + graph_name=GRAPH_NAME, + vector_search_config={"dimension": 128}, + source=SOURCE_S3_URI, + role_arn=ROLE_ARN, + deferrable=True, + ) + + with pytest.raises(TaskDeferred) as exc_info: + operator.execute(None) + + trigger = exc_info.value.trigger + assert isinstance(trigger, NeptuneGraphAvailableTrigger) + assert exc_info.value.method_name == "defer_wait_for_task" + assert exc_info.value.kwargs == {"import_task_id": self.IMPORT_TASK_ID} + + +TASK_ID = "import-task-id-12345" + + +class TestNeptuneStartImportTaskOperator: + def test_template_fields(self): + fields = NeptuneStartImportTaskOperator.template_fields + assert "graph_identifier" in fields + assert "role_arn" in fields + assert "source" in fields + assert "import_options" in fields + + def test_template_fields_renderers(self): + assert NeptuneStartImportTaskOperator.template_fields_renderers == {"import_options": "json"} + + def test_operator_extra_links(self): + assert len(NeptuneStartImportTaskOperator.operator_extra_links) == 1 + assert isinstance(NeptuneStartImportTaskOperator.operator_extra_links[0], NeptuneImportTaskLink) + + @mock.patch("airflow.providers.amazon.aws.operators.neptune_analytics.NeptuneImportTaskLink.persist") + @mock.patch.object(NeptuneAnalyticsHook, "conn") + def test_init_defaults(self, mock_conn, mock_persist): + mock_conn.start_import_task.return_value = { + "taskId": TASK_ID, + "graphId": GRAPH_ID, + "status": "IMPORTING", + } + + operator = NeptuneStartImportTaskOperator( + task_id="test_task", + graph_identifier=GRAPH_ID, + role_arn=ROLE_ARN, + source=SOURCE_S3_URI, + ) + + assert operator.graph_identifier == GRAPH_ID + assert operator.role_arn == ROLE_ARN + assert operator.source == SOURCE_S3_URI + assert operator.blank_node_handling is None + assert operator.fail_on_error is True + assert operator.format is None + assert operator.import_options is None + assert operator.parquet_type == "COLUMNAR" + assert operator.wait_for_completion is True + assert operator.waiter_delay == 30 + assert operator.waiter_max_attempts == 60 + + operator.execute(None) + + mock_conn.start_import_task.assert_called_once_with( + graphIdentifier=GRAPH_ID, + roleArn=ROLE_ARN, + source=SOURCE_S3_URI, + failOnError=True, + parquetType="COLUMNAR", + ) + + @mock.patch("airflow.providers.amazon.aws.operators.neptune_analytics.NeptuneImportTaskLink.persist") + @mock.patch.object(NeptuneAnalyticsHook, "conn") + def test_init_custom_args(self, mock_conn, mock_persist): + mock_conn.start_import_task.return_value = { + "taskId": TASK_ID, + "graphId": GRAPH_ID, + "status": "IMPORTING", + } + + operator = NeptuneStartImportTaskOperator( + task_id="test_task", + graph_identifier=GRAPH_ID, + role_arn=ROLE_ARN, + source=SOURCE_S3_URI, + blank_node_handling=None, + fail_on_error=False, + format="CSV", + import_options={"neptune.csv.allowEmptyStrings": True}, + parquet_type=None, + waiter_delay=60, + waiter_max_attempts=100, + ) + + assert operator.blank_node_handling is None + assert operator.fail_on_error is False + assert operator.format == "CSV" + assert operator.import_options == {"neptune.csv.allowEmptyStrings": True} + assert operator.parquet_type is None + assert operator.waiter_delay == 60 + assert operator.waiter_max_attempts == 100 + + operator.execute(None) + + mock_conn.start_import_task.assert_called_once_with( + graphIdentifier=GRAPH_ID, + roleArn=ROLE_ARN, + source=SOURCE_S3_URI, + failOnError=False, + format="CSV", + importOptions={"neptune.csv.allowEmptyStrings": True}, + ) + + @mock.patch("airflow.providers.amazon.aws.operators.neptune_analytics.NeptuneImportTaskLink.persist") + @mock.patch.object(NeptuneAnalyticsHook, "conn") + @mock.patch.object(NeptuneAnalyticsHook, "get_waiter") + def test_start_import_no_wait(self, mock_get_waiter, mock_conn, mock_persist): + mock_conn.start_import_task.return_value = { + "taskId": TASK_ID, + "graphId": GRAPH_ID, + "status": "IMPORTING", + } + + operator = NeptuneStartImportTaskOperator( + task_id="test_task", + graph_identifier=GRAPH_ID, + role_arn=ROLE_ARN, + source=SOURCE_S3_URI, + wait_for_completion=False, + ) + result = operator.execute(None) + + mock_get_waiter.assert_not_called() + assert result == {"import_task_id": TASK_ID, "graph_id": GRAPH_ID} + + @mock.patch("airflow.providers.amazon.aws.operators.neptune_analytics.NeptuneImportTaskLink.persist") + @mock.patch.object(NeptuneAnalyticsHook, "conn") + @mock.patch.object(NeptuneAnalyticsHook, "get_waiter") + def test_start_import_wait_for_completion(self, mock_get_waiter, mock_conn, mock_persist): + mock_conn.start_import_task.return_value = { + "taskId": TASK_ID, + "graphId": GRAPH_ID, + "status": "IMPORTING", + } + + operator = NeptuneStartImportTaskOperator( + task_id="test_task", + graph_identifier=GRAPH_ID, + role_arn=ROLE_ARN, + source=SOURCE_S3_URI, + wait_for_completion=True, + ) + result = operator.execute(None) + + mock_get_waiter.assert_called_once_with("import_task_successful") + assert result == {"import_task_id": TASK_ID, "graph_id": GRAPH_ID} + + @mock.patch.object(NeptuneAnalyticsHook, "conn") + def test_persist_called_with_correct_args(self, mock_conn): + """Test that NeptuneImportTaskLink.persist is called with the correct arguments.""" + mock_conn.start_import_task.return_value = { + "taskId": TASK_ID, + "graphId": GRAPH_ID, + "status": "IMPORTING", + } + + operator = NeptuneStartImportTaskOperator( + task_id="test_task", + graph_identifier=GRAPH_ID, + role_arn=ROLE_ARN, + source=SOURCE_S3_URI, + wait_for_completion=False, + ) + + mock_context = mock.MagicMock() + with mock.patch( + "airflow.providers.amazon.aws.operators.neptune_analytics.NeptuneImportTaskLink.persist" + ) as mock_persist: + operator.execute(mock_context) + + mock_persist.assert_called_once_with( + context=mock_context, + operator=operator, + region_name=mock.ANY, + aws_partition=mock.ANY, + import_task_id=TASK_ID, + ) + + @mock.patch("airflow.providers.amazon.aws.operators.neptune_analytics.NeptuneImportTaskLink.persist") + @mock.patch.object(NeptuneAnalyticsHook, "conn") + def test_deferrable_defers_with_import_task_trigger(self, mock_conn, mock_persist): + """Test that deferrable mode defers with NeptuneImportTaskCompleteTrigger.""" + + mock_conn.start_import_task.return_value = { + "taskId": TASK_ID, + "graphId": GRAPH_ID, + "status": "IMPORTING", + } + + operator = NeptuneStartImportTaskOperator( + task_id="test_task", + graph_identifier=GRAPH_ID, + role_arn=ROLE_ARN, + source=SOURCE_S3_URI, + deferrable=True, + ) + + with pytest.raises(TaskDeferred) as exc_info: + operator.execute(None) + + trigger = exc_info.value.trigger + assert isinstance(trigger, NeptuneImportTaskCompleteTrigger) + assert exc_info.value.method_name == "execute_complete" + + def test_execute_complete_success(self): + operator = NeptuneStartImportTaskOperator( + task_id="test_task", + graph_identifier=GRAPH_ID, + role_arn=ROLE_ARN, + source=SOURCE_S3_URI, + ) + + event = {"status": "success", "import_task_id": TASK_ID} + result = operator.execute_complete(None, event) + + assert result == {"graph_id": GRAPH_ID, "import_task_id": TASK_ID} + + +class TestNeptuneCancelImportTaskOperator: + @mock.patch.object(NeptuneAnalyticsHook, "conn") + def test_init_defaults(self, mock_conn): + mock_conn.cancel_import_task.return_value = { + "taskId": TASK_ID, + "graphId": GRAPH_ID, + "status": "CANCELLING", + } + + operator = NeptuneCancelImportTaskOperator( + task_id="test_task", + import_task_id=TASK_ID, + ) + + assert operator.import_task_id == TASK_ID + assert operator.wait_for_completion is True + assert operator.waiter_delay == 30 + assert operator.waiter_max_attempts == 60 + + operator.execute(None) + + mock_conn.cancel_import_task.assert_called_once_with(taskIdentifier=TASK_ID) + + @mock.patch.object(NeptuneAnalyticsHook, "conn") + def test_init_custom_args(self, mock_conn): + mock_conn.cancel_import_task.return_value = { + "taskId": TASK_ID, + "graphId": GRAPH_ID, + "status": "CANCELLING", + } + + operator = NeptuneCancelImportTaskOperator( + task_id="test_task", + import_task_id=TASK_ID, + waiter_delay=60, + waiter_max_attempts=100, + ) + + assert operator.import_task_id == TASK_ID + assert operator.waiter_delay == 60 + assert operator.waiter_max_attempts == 100 + + @mock.patch.object(NeptuneAnalyticsHook, "conn") + @mock.patch.object(NeptuneAnalyticsHook, "get_waiter") + def test_cancel_no_wait(self, mock_get_waiter, mock_conn): + mock_conn.cancel_import_task.return_value = { + "taskId": TASK_ID, + "graphId": GRAPH_ID, + "status": "CANCELLING", + } + + operator = NeptuneCancelImportTaskOperator( + task_id="test_task", + import_task_id=TASK_ID, + wait_for_completion=False, + ) + result = operator.execute(None) + + mock_get_waiter.assert_not_called() + assert result == {"import_task_id": TASK_ID} + + @mock.patch.object(NeptuneAnalyticsHook, "conn") + @mock.patch.object(NeptuneAnalyticsHook, "get_waiter") + def test_cancel_wait_for_completion(self, mock_get_waiter, mock_conn): + mock_conn.cancel_import_task.return_value = { + "taskId": TASK_ID, + "graphId": GRAPH_ID, + "status": "CANCELLING", + } + + operator = NeptuneCancelImportTaskOperator( + task_id="test_task", + import_task_id=TASK_ID, + wait_for_completion=True, + ) + result = operator.execute(None) + + mock_get_waiter.assert_called_once_with("import_task_cancelled") + assert result == {"import_task_id": TASK_ID} + + def test_execute_complete_success(self): + operator = NeptuneCancelImportTaskOperator( + task_id="test_task", + import_task_id=TASK_ID, + ) + + event = {"status": "success", "value": TASK_ID} + result = operator.execute_complete(None, event) + + assert result == {"import_task_id": TASK_ID} diff --git a/providers/amazon/tests/unit/amazon/aws/triggers/test_neptune_analytics.py b/providers/amazon/tests/unit/amazon/aws/triggers/test_neptune_analytics.py new file mode 100644 index 0000000000000..9232a9582322b --- /dev/null +++ b/providers/amazon/tests/unit/amazon/aws/triggers/test_neptune_analytics.py @@ -0,0 +1,343 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from unittest import mock +from unittest.mock import AsyncMock + +import pytest +from botocore.exceptions import WaiterError + +from airflow.providers.amazon.aws.triggers.neptune_analytics import ( + NeptuneGraphAvailableTrigger, + NeptuneGraphDeletedTrigger, + NeptuneGraphPrivateEndpointAvailableTrigger, + NeptuneGraphPrivateEndpointDeletedTrigger, + NeptuneImportTaskCancelledTrigger, + NeptuneImportTaskCompleteTrigger, +) +from airflow.triggers.base import TriggerEvent + +GRAPH_ID = "test-graph" +VPC_ID = "test-vpc" +ENDPOINT_ID = "test-endpoint" +TASK_ID = "test-task-id" + + +class TestNeptuneGraphAvailableTrigger: + def test_serialization(self): + """ + Asserts that the NeptuneGraphAvailableTrigger correctly serializes its arguments + and classpath. + """ + trigger = NeptuneGraphAvailableTrigger(graph_id=GRAPH_ID) + classpath, kwargs = trigger.serialize() + assert ( + classpath + == "airflow.providers.amazon.aws.triggers.neptune_analytics.NeptuneGraphAvailableTrigger" + ) + assert "graph_id" in kwargs + assert kwargs["graph_id"] == GRAPH_ID + + @pytest.mark.asyncio + @mock.patch("airflow.providers.amazon.aws.hooks.neptune_analytics.NeptuneAnalyticsHook.get_waiter") + @mock.patch("airflow.providers.amazon.aws.hooks.neptune_analytics.NeptuneAnalyticsHook.get_async_conn") + async def test_run_success(self, mock_async_conn, mock_get_waiter): + mock_async_conn.return_value.__aenter__.return_value = "AVAILABLE" + mock_get_waiter().wait = AsyncMock() + trigger = NeptuneGraphAvailableTrigger(graph_id=GRAPH_ID) + generator = trigger.run() + resp = await generator.asend(None) + + assert resp == TriggerEvent({"status": "success", "graph_id": GRAPH_ID}) + assert mock_get_waiter().wait.call_count == 1 + + @pytest.mark.asyncio + @mock.patch("airflow.providers.amazon.aws.hooks.neptune_analytics.NeptuneAnalyticsHook.get_waiter") + @mock.patch("airflow.providers.amazon.aws.hooks.neptune_analytics.NeptuneAnalyticsHook.get_async_conn") + async def test_run_failure(self, mock_async_conn, mock_get_waiter): + wait_mock = AsyncMock() + wait_mock.side_effect = WaiterError( + name="graph_available", + reason='Waiter encountered a terminal failure state: For expression "status" we matched expected path: "FAILED"', + last_response={"status": "FAILED", "graphIdentifier": GRAPH_ID}, + ) + mock_get_waiter.return_value.wait = wait_mock + + trigger = NeptuneGraphAvailableTrigger(graph_id=GRAPH_ID) + generator = trigger.run() + resp = await generator.asend(None) + + assert resp.payload["status"] == "error" + assert resp.payload["graph_id"] == GRAPH_ID + assert "Failed to create Neptune graph" in resp.payload["message"] + + +class TestNeptuneGraphPrivateEndpointAvailableTrigger: + def test_serialization(self): + """ + Asserts that the NeptuneGraphPrivateEndpointAvailableTrigger correctly serializes its arguments + and classpath. + """ + trigger = NeptuneGraphPrivateEndpointAvailableTrigger(graph_id=GRAPH_ID, vpc_id=VPC_ID) + classpath, kwargs = trigger.serialize() + assert ( + classpath + == "airflow.providers.amazon.aws.triggers.neptune_analytics.NeptuneGraphPrivateEndpointAvailableTrigger" + ) + assert "graph_id" in kwargs + assert kwargs["graph_id"] == GRAPH_ID + assert "vpc_id" in kwargs + assert kwargs["vpc_id"] == VPC_ID + + @pytest.mark.asyncio + @mock.patch("airflow.providers.amazon.aws.hooks.neptune_analytics.NeptuneAnalyticsHook.get_waiter") + @mock.patch("airflow.providers.amazon.aws.hooks.neptune_analytics.NeptuneAnalyticsHook.get_async_conn") + async def test_run_success(self, mock_async_conn, mock_get_waiter): + mock_async_conn.return_value.__aenter__.return_value = "AVAILABLE" + mock_get_waiter().wait = AsyncMock() + trigger = NeptuneGraphPrivateEndpointAvailableTrigger(graph_id=GRAPH_ID, vpc_id=VPC_ID) + generator = trigger.run() + resp = await generator.asend(None) + + assert resp == TriggerEvent({"status": "success", "graph_id": GRAPH_ID}) + assert mock_get_waiter().wait.call_count == 1 + + @pytest.mark.asyncio + @mock.patch("airflow.providers.amazon.aws.hooks.neptune_analytics.NeptuneAnalyticsHook.get_waiter") + @mock.patch("airflow.providers.amazon.aws.hooks.neptune_analytics.NeptuneAnalyticsHook.get_async_conn") + async def test_run_failure(self, mock_async_conn, mock_get_waiter): + wait_mock = AsyncMock() + wait_mock.side_effect = WaiterError( + name="private_graph_endpoint_available", + reason='Waiter encountered a terminal failure state: For expression "status" we matched expected path: "FAILED"', + last_response={"status": "FAILED", "graphIdentifier": GRAPH_ID, "vpcId": VPC_ID}, + ) + mock_get_waiter.return_value.wait = wait_mock + + trigger = NeptuneGraphPrivateEndpointAvailableTrigger(graph_id=GRAPH_ID, vpc_id=VPC_ID) + generator = trigger.run() + resp = await generator.asend(None) + + assert resp.payload["status"] == "error" + assert resp.payload["graph_id"] == GRAPH_ID + assert "Failed to create Neptune graph endpoint" in resp.payload["message"] + + +class TestNeptuneGraphPrivateEndpointDeletedTrigger: + def test_serialization(self): + """ + Asserts that the NeptuneGraphPrivateEndpointDeletedTrigger correctly serializes its arguments + and classpath. + """ + trigger = NeptuneGraphPrivateEndpointDeletedTrigger( + graph_id=GRAPH_ID, vpc_id=VPC_ID, endpoint_id=ENDPOINT_ID + ) + classpath, kwargs = trigger.serialize() + assert ( + classpath + == "airflow.providers.amazon.aws.triggers.neptune_analytics.NeptuneGraphPrivateEndpointDeletedTrigger" + ) + assert "graph_id" in kwargs + assert kwargs["graph_id"] == GRAPH_ID + assert "vpc_id" in kwargs + assert kwargs["vpc_id"] == VPC_ID + assert "endpoint_id" in kwargs + assert kwargs["endpoint_id"] == ENDPOINT_ID + + @pytest.mark.asyncio + @mock.patch("airflow.providers.amazon.aws.hooks.neptune_analytics.NeptuneAnalyticsHook.get_waiter") + @mock.patch("airflow.providers.amazon.aws.hooks.neptune_analytics.NeptuneAnalyticsHook.get_async_conn") + async def test_run_success(self, mock_async_conn, mock_get_waiter): + mock_async_conn.return_value.__aenter__.return_value = "DELETED" + mock_get_waiter().wait = AsyncMock() + trigger = NeptuneGraphPrivateEndpointDeletedTrigger( + graph_id=GRAPH_ID, vpc_id=VPC_ID, endpoint_id=ENDPOINT_ID + ) + generator = trigger.run() + resp = await generator.asend(None) + + assert resp == TriggerEvent({"status": "success", "endpoint_id": ENDPOINT_ID}) + assert mock_get_waiter().wait.call_count == 1 + + @pytest.mark.asyncio + @mock.patch("airflow.providers.amazon.aws.hooks.neptune_analytics.NeptuneAnalyticsHook.get_waiter") + @mock.patch("airflow.providers.amazon.aws.hooks.neptune_analytics.NeptuneAnalyticsHook.get_async_conn") + async def test_run_failure(self, mock_async_conn, mock_get_waiter): + wait_mock = AsyncMock() + wait_mock.side_effect = WaiterError( + name="private_graph_endpoint_deleted", + reason='Waiter encountered a terminal failure state: For expression "status" we matched expected path: "FAILED"', + last_response={"status": "FAILED", "graphIdentifier": GRAPH_ID, "vpcId": VPC_ID}, + ) + mock_get_waiter.return_value.wait = wait_mock + + trigger = NeptuneGraphPrivateEndpointDeletedTrigger( + graph_id=GRAPH_ID, vpc_id=VPC_ID, endpoint_id=ENDPOINT_ID + ) + generator = trigger.run() + resp = await generator.asend(None) + + assert resp.payload["status"] == "error" + assert resp.payload["endpoint_id"] == ENDPOINT_ID + assert "Failed to delete Neptune graph endpoint" in resp.payload["message"] + + +class TestNeptuneImportTaskCompleteTrigger: + def test_serialization(self): + """ + Asserts that the NeptuneImportTaskCompleteTrigger correctly serializes its arguments + and classpath. + """ + trigger = NeptuneImportTaskCompleteTrigger(import_task_id=TASK_ID) + classpath, kwargs = trigger.serialize() + assert ( + classpath + == "airflow.providers.amazon.aws.triggers.neptune_analytics.NeptuneImportTaskCompleteTrigger" + ) + assert "import_task_id" in kwargs + assert kwargs["import_task_id"] == TASK_ID + + @pytest.mark.asyncio + @mock.patch("airflow.providers.amazon.aws.hooks.neptune_analytics.NeptuneAnalyticsHook.get_waiter") + @mock.patch("airflow.providers.amazon.aws.hooks.neptune_analytics.NeptuneAnalyticsHook.get_async_conn") + async def test_run_success(self, mock_async_conn, mock_get_waiter): + mock_async_conn.return_value.__aenter__.return_value = "COMPLETED" + mock_get_waiter().wait = AsyncMock() + trigger = NeptuneImportTaskCompleteTrigger(import_task_id=TASK_ID) + generator = trigger.run() + resp = await generator.asend(None) + + assert resp == TriggerEvent({"status": "success", "import_task_id": TASK_ID}) + assert mock_get_waiter().wait.call_count == 1 + + @pytest.mark.asyncio + @mock.patch("airflow.providers.amazon.aws.hooks.neptune_analytics.NeptuneAnalyticsHook.get_waiter") + @mock.patch("airflow.providers.amazon.aws.hooks.neptune_analytics.NeptuneAnalyticsHook.get_async_conn") + async def test_run_failure(self, mock_async_conn, mock_get_waiter): + wait_mock = AsyncMock() + wait_mock.side_effect = WaiterError( + name="import_task_successful", + reason='Waiter encountered a terminal failure state: For expression "status" we matched expected path: "FAILED"', + last_response={"status": "FAILED", "taskIdentifier": TASK_ID}, + ) + mock_get_waiter.return_value.wait = wait_mock + + trigger = NeptuneImportTaskCompleteTrigger(import_task_id=TASK_ID) + generator = trigger.run() + resp = await generator.asend(None) + + assert resp.payload["status"] == "error" + assert resp.payload["import_task_id"] == TASK_ID + assert "Import task failed" in resp.payload["message"] + + +class TestNeptuneImportTaskCancelledTrigger: + def test_serialization(self): + """ + Asserts that the NeptuneImportTaskCancelledTrigger correctly serializes its arguments + and classpath. + """ + trigger = NeptuneImportTaskCancelledTrigger(task_identifier=TASK_ID) + classpath, kwargs = trigger.serialize() + assert ( + classpath + == "airflow.providers.amazon.aws.triggers.neptune_analytics.NeptuneImportTaskCancelledTrigger" + ) + assert "task_identifier" in kwargs + assert kwargs["task_identifier"] == TASK_ID + + @pytest.mark.asyncio + @mock.patch("airflow.providers.amazon.aws.hooks.neptune_analytics.NeptuneAnalyticsHook.get_waiter") + @mock.patch("airflow.providers.amazon.aws.hooks.neptune_analytics.NeptuneAnalyticsHook.get_async_conn") + async def test_run_success(self, mock_async_conn, mock_get_waiter): + mock_async_conn.return_value.__aenter__.return_value = "CANCELLED" + mock_get_waiter().wait = AsyncMock() + trigger = NeptuneImportTaskCancelledTrigger(task_identifier=TASK_ID) + generator = trigger.run() + resp = await generator.asend(None) + + assert resp == TriggerEvent({"status": "success", "import_task_id": TASK_ID}) + assert mock_get_waiter().wait.call_count == 1 + + @pytest.mark.asyncio + @mock.patch("airflow.providers.amazon.aws.hooks.neptune_analytics.NeptuneAnalyticsHook.get_waiter") + @mock.patch("airflow.providers.amazon.aws.hooks.neptune_analytics.NeptuneAnalyticsHook.get_async_conn") + async def test_run_failure(self, mock_async_conn, mock_get_waiter): + wait_mock = AsyncMock() + wait_mock.side_effect = WaiterError( + name="import_task_cancelled", + reason='Waiter encountered a terminal failure state: For expression "status" we matched expected path: "FAILED"', + last_response={"status": "FAILED", "taskIdentifier": TASK_ID}, + ) + mock_get_waiter.return_value.wait = wait_mock + + trigger = NeptuneImportTaskCancelledTrigger(task_identifier=TASK_ID) + generator = trigger.run() + resp = await generator.asend(None) + + assert resp.payload["status"] == "error" + assert resp.payload["import_task_id"] == TASK_ID + assert "Import task cancellation failed" in resp.payload["message"] + + +class TestNeptuneGraphDeletedTrigger: + def test_serialization(self): + """ + Asserts that the NeptuneGraphDeletedTrigger correctly serializes its arguments + and classpath. + """ + trigger = NeptuneGraphDeletedTrigger(graph_id=GRAPH_ID) + classpath, kwargs = trigger.serialize() + assert ( + classpath == "airflow.providers.amazon.aws.triggers.neptune_analytics.NeptuneGraphDeletedTrigger" + ) + assert "graph_id" in kwargs + assert kwargs["graph_id"] == GRAPH_ID + + @pytest.mark.asyncio + @mock.patch("airflow.providers.amazon.aws.hooks.neptune_analytics.NeptuneAnalyticsHook.get_waiter") + @mock.patch("airflow.providers.amazon.aws.hooks.neptune_analytics.NeptuneAnalyticsHook.get_async_conn") + async def test_run_success(self, mock_async_conn, mock_get_waiter): + mock_async_conn.return_value.__aenter__.return_value = "DELETED" + mock_get_waiter().wait = AsyncMock() + trigger = NeptuneGraphDeletedTrigger(graph_id=GRAPH_ID) + generator = trigger.run() + resp = await generator.asend(None) + + assert resp == TriggerEvent({"status": "success", "graph_id": GRAPH_ID}) + assert mock_get_waiter().wait.call_count == 1 + + @pytest.mark.asyncio + @mock.patch("airflow.providers.amazon.aws.hooks.neptune_analytics.NeptuneAnalyticsHook.get_waiter") + @mock.patch("airflow.providers.amazon.aws.hooks.neptune_analytics.NeptuneAnalyticsHook.get_async_conn") + async def test_run_failure(self, mock_async_conn, mock_get_waiter): + wait_mock = AsyncMock() + wait_mock.side_effect = WaiterError( + name="graph_deleted", + reason='Waiter encountered a terminal failure state: For expression "status" we matched expected path: "FAILED"', + last_response={"status": "FAILED", "graphIdentifier": GRAPH_ID}, + ) + mock_get_waiter.return_value.wait = wait_mock + + trigger = NeptuneGraphDeletedTrigger(graph_id=GRAPH_ID) + generator = trigger.run() + resp = await generator.asend(None) + + assert resp.payload["status"] == "error" + assert resp.payload["graph_id"] == GRAPH_ID + assert "Failed to delete Neptune graph" in resp.payload["message"]