From 30750787761013f90cea6f9f18797b965be8418a Mon Sep 17 00:00:00 2001 From: mse139 Date: Wed, 4 Feb 2026 06:11:12 -0500 Subject: [PATCH 01/28] neptune analytics initial --- .../amazon/aws/hooks/neptune_analytics.py | 37 ++++++ .../amazon/aws/operators/neptune_analytics.py | 113 ++++++++++++++++++ 2 files changed, 150 insertions(+) create mode 100644 providers/amazon/src/airflow/providers/amazon/aws/hooks/neptune_analytics.py create mode 100644 providers/amazon/src/airflow/providers/amazon/aws/operators/neptune_analytics.py diff --git a/providers/amazon/src/airflow/providers/amazon/aws/hooks/neptune_analytics.py b/providers/amazon/src/airflow/providers/amazon/aws/hooks/neptune_analytics.py new file mode 100644 index 0000000000000..1614808b587d4 --- /dev/null +++ b/providers/amazon/src/airflow/providers/amazon/aws/hooks/neptune_analytics.py @@ -0,0 +1,37 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook + + +class NeptuneAnalyticsHook(AwsBaseHook): + """ + Interact with Amazon Neptune Analytics. + + Additional arguments (such as ``aws_conn_id``) may be specified and + are passed down to the underlying AwsBaseHook. + + .. 
seealso:: + - :class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook` + """ + + def __init__(self, *args, **kwargs): + kwargs["client_type"] = "neptune_graph" + super().__init__(*args, **kwargs) diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/neptune_analytics.py b/providers/amazon/src/airflow/providers/amazon/aws/operators/neptune_analytics.py new file mode 100644 index 0000000000000..8c4a51adf5cf7 --- /dev/null +++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/neptune_analytics.py @@ -0,0 +1,113 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from collections.abc import Sequence +from typing import TYPE_CHECKING + +from airflow.providers.amazon.aws.hooks.neptune import NeptuneAnalyticsHook +from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator +from airflow.providers.amazon.aws.utils.mixins import aws_template_fields +from airflow.providers.common.compat.sdk import conf + +if TYPE_CHECKING: + from airflow.sdk import Context + + +class CreateNeptuneGraphOperator(AwsBaseOperator[NeptuneAnalyticsHook]): + """ + Creates an empty Amazon Neptune Graph database. 
+ + Neptune Analytics is a memory-optimized graph database engine for analytics. With Neptune Analytics, you can get insights and find trends by processing large amounts of graph data in seconds. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:NeptuneStartDbClusterOperator` + + :param db_cluster_id: The DB cluster identifier of the Neptune DB cluster to be started. + :param wait_for_completion: Whether to wait for the cluster to start. (default: True) + :param deferrable: If True, the operator will wait asynchronously for the cluster to start. + This implies waiting for completion. This mode requires aiobotocore module to be installed. + (default: False) + :param waiter_delay: Time in seconds to wait between status checks. + :param waiter_max_attempts: Maximum number of attempts to check for job completion. + :param aws_conn_id: The Airflow connection used for AWS credentials. + If this is ``None`` or empty then the default boto3 behaviour is used. If + running Airflow in a distributed manner and aws_conn_id is None or + empty, then default boto3 configuration would be used (and must be + maintained on each worker node). + :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used. + + :param botocore_config: Configuration dictionary (key-values) for botocore client. 
See: + https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html + :return: dictionary with Neptune cluster id + """ + + aws_hook_class = NeptuneAnalyticsHook + template_fields: Sequence[str] = aws_template_fields() + + def __init__( + self, + graph_name: str, + vector_search_config: dict, + replica_count: int, + provisioned_memory: int, + deletion_protection: bool = False, + kms_key_id: str | None = None, + tags: dict | None = None, + wait_for_completion: bool = True, + waiter_delay: int = 30, + waiter_max_attempts: int = 60, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), + **kwargs, + ): + super().__init__(**kwargs) + self.graph_name = graph_name + self.vector_search_config = vector_search_config + self.replica_count = replica_count + self.provisioned_memory = provisioned_memory + self.deletion_protect = deletion_protection + self.kms_key = kms_key_id + self.tags = tags + self.wait_for_completion = wait_for_completion + self.deferrable = deferrable + self.waiter_delay = waiter_delay + self.waiter_max_attempts = waiter_max_attempts + + def execute(self, context: Context) -> dict: + self.log.info("Creating graph %s", self.graph_name) + response = self.hook.create_graph( + graphName=self.graph_name, + vectorSearchConfiguration=self.vector_search_config, + replicaCount=self.replica_count, + provisionedMemory=self.provisioned_memory, + deletionProtection=self.deletion_protect, + kmsKeyIdentifier=self.kms_key, + tags=self.tags, + ) + + self.log.info("Graph %s in status %s", self.graph_name, response.get("status", "Unknown")) + + if self.deferrable: + pass + # TODO add waiter + if self.wait_for_completion: + # TODO wait for good status + pass + + # TODO return - maybe store ID, ARN, Name From c28fefa6b15f13545d1ae15280b8c09e84cedec3 Mon Sep 17 00:00:00 2001 From: mse139 Date: Wed, 18 Feb 2026 19:34:41 -0500 Subject: [PATCH 02/28] Added NeptuneCreateGraphOperator and supporting files/classes --- 
providers/amazon/provider.yaml | 4 + .../amazon/aws/hooks/neptune_analytics.py | 2 +- .../amazon/aws/operators/neptune_analytics.py | 93 +++++++++---- .../amazon/aws/triggers/neptune_analytics.py | 67 ++++++++++ .../providers/amazon/get_provider_info.py | 15 ++- .../aws/hooks/test_neptune_analytics.py | 46 +++++++ .../aws/operators/test_neptune_analytics.py | 126 ++++++++++++++++++ .../aws/triggers/test_neptune_analytics.py | 77 +++++++++++ 8 files changed, 403 insertions(+), 27 deletions(-) create mode 100644 providers/amazon/src/airflow/providers/amazon/aws/triggers/neptune_analytics.py create mode 100644 providers/amazon/tests/unit/amazon/aws/hooks/test_neptune_analytics.py create mode 100644 providers/amazon/tests/unit/amazon/aws/operators/test_neptune_analytics.py create mode 100644 providers/amazon/tests/unit/amazon/aws/triggers/test_neptune_analytics.py diff --git a/providers/amazon/provider.yaml b/providers/amazon/provider.yaml index f396e2184e574..aa74974d6d834 100644 --- a/providers/amazon/provider.yaml +++ b/providers/amazon/provider.yaml @@ -520,6 +520,7 @@ operators: - integration-name: Amazon Neptune python-modules: - airflow.providers.amazon.aws.operators.neptune + - airflow.providers.amazon.aws.operators.neptune_analytics - integration-name: Amazon S3 Vectors python-modules: - airflow.providers.amazon.aws.operators.s3_vectors @@ -784,6 +785,8 @@ hooks: - integration-name: Amazon Neptune python-modules: - airflow.providers.amazon.aws.hooks.neptune + - airflow.providers.amazon.aws.hooks.neptune_analytics + bundles: - integration-name: Amazon Simple Storage Service (S3) @@ -866,6 +869,7 @@ triggers: - integration-name: Amazon Neptune python-modules: - airflow.providers.amazon.aws.triggers.neptune + - airflow.providers.amazon.aws.triggers.neptune_analytics - integration-name: AWS Database Migration Service python-modules: - airflow.providers.amazon.aws.triggers.dms diff --git a/providers/amazon/src/airflow/providers/amazon/aws/hooks/neptune_analytics.py 
b/providers/amazon/src/airflow/providers/amazon/aws/hooks/neptune_analytics.py index 1614808b587d4..252878fa11b4b 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/hooks/neptune_analytics.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/hooks/neptune_analytics.py @@ -33,5 +33,5 @@ class NeptuneAnalyticsHook(AwsBaseHook): """ def __init__(self, *args, **kwargs): - kwargs["client_type"] = "neptune_graph" + kwargs["client_type"] = "neptune-graph" super().__init__(*args, **kwargs) diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/neptune_analytics.py b/providers/amazon/src/airflow/providers/amazon/aws/operators/neptune_analytics.py index 8c4a51adf5cf7..7cca99683d2c7 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/operators/neptune_analytics.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/neptune_analytics.py @@ -18,10 +18,11 @@ from __future__ import annotations from collections.abc import Sequence -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any -from airflow.providers.amazon.aws.hooks.neptune import NeptuneAnalyticsHook +from airflow.providers.amazon.aws.hooks.neptune_analytics import NeptuneAnalyticsHook from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator +from airflow.providers.amazon.aws.triggers.neptune_analytics import NeptuneGraphAvailableTrigger from airflow.providers.amazon.aws.utils.mixins import aws_template_fields from airflow.providers.common.compat.sdk import conf @@ -29,7 +30,7 @@ from airflow.sdk import Context -class CreateNeptuneGraphOperator(AwsBaseOperator[NeptuneAnalyticsHook]): +class NeptuneCreateGraphOperator(AwsBaseOperator[NeptuneAnalyticsHook]): """ Creates an empty Amazon Neptune Graph database. @@ -37,11 +38,19 @@ class CreateNeptuneGraphOperator(AwsBaseOperator[NeptuneAnalyticsHook]): .. 
seealso:: For more information on how to use this operator, take a look at the guide: - :ref:`howto/operator:NeptuneStartDbClusterOperator` + :ref:`howto/operator:NeptuneCreateGraphOperator` - :param db_cluster_id: The DB cluster identifier of the Neptune DB cluster to be started. - :param wait_for_completion: Whether to wait for the cluster to start. (default: True) - :param deferrable: If True, the operator will wait asynchronously for the cluster to start. + :param graph_name: Name of Neptune graph to create + :param vector_search_config: Specifies the number of dimensions for vector embeddings that will be loaded into the graph. + :param provisioned_memory: The provisioned memory-optimized Neptune Capacity Units (m-NCUs) to use for the graph. + :param public_connectivity: Specifies whether or not the graph can be reachable over the internet. + :param replica_count: The number of replicas in other AZs. + :param deletion_protection: Indicates whether or not to enable deletion protection on the graph. + The graph can't be deleted when deletion protection is enabled. + :param kms_key_id: Specifies a KMS key to use to encrypt data in the new graph. + :param tags Specifies metadata tags to add to the graph. + :param wait_for_completion: Whether to wait for the graph to start. (default: True) + :param deferrable: If True, the operator will wait asynchronously for the graph to start. This implies waiting for completion. This mode requires aiobotocore module to be installed. (default: False) :param waiter_delay: Time in seconds to wait between status checks. @@ -55,7 +64,7 @@ class CreateNeptuneGraphOperator(AwsBaseOperator[NeptuneAnalyticsHook]): :param botocore_config: Configuration dictionary (key-values) for botocore client. 
See: https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html - :return: dictionary with Neptune cluster id + :return: dictionary with Neptune graph id """ aws_hook_class = NeptuneAnalyticsHook @@ -65,8 +74,9 @@ def __init__( self, graph_name: str, vector_search_config: dict, - replica_count: int, provisioned_memory: int, + public_connectivity: bool | None = None, + replica_count: int | None = None, deletion_protection: bool = False, kms_key_id: str | None = None, tags: dict | None = None, @@ -81,6 +91,7 @@ def __init__( self.vector_search_config = vector_search_config self.replica_count = replica_count self.provisioned_memory = provisioned_memory + self.public_connectivity = public_connectivity self.deletion_protect = deletion_protection self.kms_key = kms_key_id self.tags = tags @@ -91,23 +102,59 @@ def __init__( def execute(self, context: Context) -> dict: self.log.info("Creating graph %s", self.graph_name) - response = self.hook.create_graph( - graphName=self.graph_name, - vectorSearchConfiguration=self.vector_search_config, - replicaCount=self.replica_count, - provisionedMemory=self.provisioned_memory, - deletionProtection=self.deletion_protect, - kmsKeyIdentifier=self.kms_key, - tags=self.tags, - ) + + # TODO perform check + create_params = { + "graphName": self.graph_name, + "vectorSearchConfiguration": self.vector_search_config, + "provisionedMemory": self.provisioned_memory, + **{ + k: v + for k, v in { + "replicaCount": self.replica_count, + "publicConnectivity": self.public_connectivity, + "deletionProtection": self.deletion_protect, + "kmsKeyIdentifier": self.kms_key, + "tags": self.tags, + }.items() + if v is not None + }, + } + + response = self.hook.conn.create_graph(**create_params) self.log.info("Graph %s in status %s", self.graph_name, response.get("status", "Unknown")) + self.graph_id = response.get("id", None) + + # TODO build extra link to console if self.deferrable: - pass - # TODO add waiter + self.log.info("Deferring 
until graph %s is available", self.graph_id) + self.defer( + trigger=NeptuneGraphAvailableTrigger( + aws_conn_id=self.aws_conn_id, + graph_id=self.graph_id, + waiter_delay=self.waiter_delay, + waiter_max_attempts=self.waiter_max_attempts, + ), + method_name="execute_complete", + ) + if self.wait_for_completion: - # TODO wait for good status - pass + self.log.info("Waiting until graph %s is available", self.graph_id) + self.hook.get_waiter("graph_available").wait( + graphIdentifier=self.graph_id, + WaiterConfig={"Delay": self.waiter_delay, "MaxAttempts": self.waiter_max_attempts}, + ) + + return {"graph_id": self.graph_id} + + def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> dict[str, str]: + graph_id = "" + + if event: + graph_id = event.get("graph_id", "Unknown") + + self.log.info("Neptune graph % complete", graph_id) - # TODO return - maybe store ID, ARN, Name + return {"graph_id": graph_id} diff --git a/providers/amazon/src/airflow/providers/amazon/aws/triggers/neptune_analytics.py b/providers/amazon/src/airflow/providers/amazon/aws/triggers/neptune_analytics.py new file mode 100644 index 0000000000000..10f2804f5ba2b --- /dev/null +++ b/providers/amazon/src/airflow/providers/amazon/aws/triggers/neptune_analytics.py @@ -0,0 +1,67 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import TYPE_CHECKING + +from airflow.providers.amazon.aws.hooks.neptune_analytics import NeptuneAnalyticsHook +from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger + +if TYPE_CHECKING: + from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook + + +class NeptuneGraphAvailableTrigger(AwsBaseWaiterTrigger): + """ + Triggers when a Neptune graph is available. + + :param graph_id: Graph ID to poll. + :param waiter_delay: The amount of time in seconds to wait between attempts. + :param waiter_max_attempts: The maximum number of attempts to be made. + :param aws_conn_id: The Airflow connection used for AWS credentials. + :param region_name: AWS region name (example: us-east-1) + """ + + def __init__( + self, + *, + graph_id: str, + waiter_delay: int = 30, + waiter_max_attempts: int = 60, + **kwargs, + ) -> None: + super().__init__( + serialized_fields={"graph_id": graph_id}, + waiter_name="graph_available", + waiter_args={"graphIdentifier": graph_id}, + failure_message="Failed to create Neptune graph", + status_message="Status of Neptune graph is", + status_queries=["status"], + return_key="graph_id", + return_value=graph_id, + waiter_delay=waiter_delay, + waiter_max_attempts=waiter_max_attempts, + **kwargs, + ) + + def hook(self) -> AwsGenericHook: + return NeptuneAnalyticsHook( + aws_conn_id=self.aws_conn_id, + region_name=self.region_name, + verify=self.verify, + config=self.botocore_config, + ) diff --git a/providers/amazon/src/airflow/providers/amazon/get_provider_info.py b/providers/amazon/src/airflow/providers/amazon/get_provider_info.py index 99721f3c481b9..c4afec9768b88 100644 --- a/providers/amazon/src/airflow/providers/amazon/get_provider_info.py +++ b/providers/amazon/src/airflow/providers/amazon/get_provider_info.py @@ -529,7 +529,10 @@ def get_provider_info(): }, { 
"integration-name": "Amazon Neptune", - "python-modules": ["airflow.providers.amazon.aws.operators.neptune"], + "python-modules": [ + "airflow.providers.amazon.aws.operators.neptune", + "airflow.providers.amazon.aws.operators.neptune_analytics", + ], }, { "integration-name": "Amazon S3 Vectors", @@ -874,7 +877,10 @@ def get_provider_info(): }, { "integration-name": "Amazon Neptune", - "python-modules": ["airflow.providers.amazon.aws.hooks.neptune"], + "python-modules": [ + "airflow.providers.amazon.aws.hooks.neptune", + "airflow.providers.amazon.aws.hooks.neptune_analytics", + ], }, ], "bundles": [ @@ -987,7 +993,10 @@ def get_provider_info(): }, { "integration-name": "Amazon Neptune", - "python-modules": ["airflow.providers.amazon.aws.triggers.neptune"], + "python-modules": [ + "airflow.providers.amazon.aws.triggers.neptune", + "airflow.providers.amazon.aws.triggers.neptune_analytics", + ], }, { "integration-name": "AWS Database Migration Service", diff --git a/providers/amazon/tests/unit/amazon/aws/hooks/test_neptune_analytics.py b/providers/amazon/tests/unit/amazon/aws/hooks/test_neptune_analytics.py new file mode 100644 index 0000000000000..23cf28044135d --- /dev/null +++ b/providers/amazon/tests/unit/amazon/aws/hooks/test_neptune_analytics.py @@ -0,0 +1,46 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from collections.abc import Generator +from unittest import mock + +import pytest +from moto import mock_aws + +from airflow.providers.amazon.aws.hooks.neptune_analytics import NeptuneAnalyticsHook + + +@pytest.fixture +def neptune_hook() -> Generator[NeptuneAnalyticsHook, None, None]: + """Returns a NeptuneAnalyticsHook mocked with moto""" + with mock_aws(): + yield NeptuneAnalyticsHook(aws_conn_id="aws_default") + + +class TestNeptuneAnalyticsHook: + graph_id = "abc123" + + def test_get_conn_returns_a_boto3_connection(self): + hook = NeptuneAnalyticsHook(aws_conn_id="aws_default") + assert hook.get_conn() is not None + + @mock.patch.object(NeptuneAnalyticsHook, "get_waiter") + def test_wait_for_graph_availability(self, mock_get_waiter, neptune_hook: NeptuneAnalyticsHook): + waiter = mock_get_waiter("graph_available") + assert waiter is not None diff --git a/providers/amazon/tests/unit/amazon/aws/operators/test_neptune_analytics.py b/providers/amazon/tests/unit/amazon/aws/operators/test_neptune_analytics.py new file mode 100644 index 0000000000000..debbcd6918451 --- /dev/null +++ b/providers/amazon/tests/unit/amazon/aws/operators/test_neptune_analytics.py @@ -0,0 +1,126 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from collections.abc import Generator +from unittest import mock + +import pytest +from moto import mock_aws + +from airflow.providers.amazon.aws.hooks.neptune_analytics import NeptuneAnalyticsHook +from airflow.providers.amazon.aws.operators.neptune_analytics import NeptuneCreateGraphOperator + +GRAPH_NAME = "test_graph" + + +@pytest.fixture +def hook() -> Generator[NeptuneAnalyticsHook, None, None]: + with mock_aws(): + yield NeptuneAnalyticsHook(aws_conn_id="aws_default") + + +class TestNeptuneCreateGraphOperator: + @mock.patch.object(NeptuneAnalyticsHook, "conn") + def test_init_defaults(self, mock_conn): + operator = NeptuneCreateGraphOperator( + task_id="test_task", + graph_name=GRAPH_NAME, + vector_search_config={"test": 123}, + provisioned_memory=16, + ) + + assert operator.public_connectivity is None + assert operator.replica_count is None + assert operator.deletion_protect is False + assert operator.kms_key is None + assert operator.tags is None + + operator.execute(None) + + mock_conn.create_graph.assert_called_once_with( + graphName=GRAPH_NAME, + vectorSearchConfiguration={"test": 123}, + provisionedMemory=16, + deletionProtection=False, + ) + + @mock.patch.object(NeptuneAnalyticsHook, "conn") + def test_init_custom_args(self, mock_conn): + operator = NeptuneCreateGraphOperator( + task_id="test_task", + graph_name=GRAPH_NAME, + vector_search_config={"test": 123}, + provisioned_memory=16, + public_connectivity=True, + replica_count=3, + kms_key_id="test-key", + tags={"key1": "test"}, + 
deletion_protection=True, + ) + + assert operator.public_connectivity is True + assert operator.replica_count == 3 + assert operator.deletion_protect is True + assert operator.kms_key == "test-key" + assert operator.tags == {"key1": "test"} + + operator.execute(None) + + mock_conn.create_graph.assert_called_once_with( + graphName=GRAPH_NAME, + vectorSearchConfiguration={"test": 123}, + replicaCount=3, + publicConnectivity=True, + provisionedMemory=16, + deletionProtection=True, + kmsKeyIdentifier="test-key", + tags={"key1": "test"}, + ) + + @mock.patch.object(NeptuneAnalyticsHook, "conn") + @mock.patch.object(NeptuneAnalyticsHook, "get_waiter") + def test_create_graph(self, mock_hook_get_waiter, mock_conn): + operator = NeptuneCreateGraphOperator( + task_id="test_task", + graph_name=GRAPH_NAME, + provisioned_memory=16, + vector_search_config={"test": 123}, + wait_for_completion=False, + ) + resp = operator.execute(None) + + mock_hook_get_waiter.assert_not_called() + assert "graph_id" in resp + assert resp["graph_id"] is not None + + @mock.patch.object(NeptuneAnalyticsHook, "conn") + @mock.patch.object(NeptuneAnalyticsHook, "get_waiter") + def test_create_graph_wait_for_completion(self, mock_hook_get_waiter, mock_conn): + operator = NeptuneCreateGraphOperator( + task_id="test_task", + graph_name=GRAPH_NAME, + provisioned_memory=16, + vector_search_config={"test": 123}, + wait_for_completion=True, + ) + resp = operator.execute(None) + + mock_hook_get_waiter.assert_called_once_with("graph_available") + assert "graph_id" in resp + assert resp["graph_id"] is not None diff --git a/providers/amazon/tests/unit/amazon/aws/triggers/test_neptune_analytics.py b/providers/amazon/tests/unit/amazon/aws/triggers/test_neptune_analytics.py new file mode 100644 index 0000000000000..a3480531c8046 --- /dev/null +++ b/providers/amazon/tests/unit/amazon/aws/triggers/test_neptune_analytics.py @@ -0,0 +1,77 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more 
contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from unittest import mock +from unittest.mock import AsyncMock + +import pytest +from botocore.exceptions import WaiterError + +from airflow.providers.amazon.aws.triggers.neptune_analytics import ( + NeptuneGraphAvailableTrigger, +) +from airflow.providers.common.compat.sdk import AirflowException +from airflow.triggers.base import TriggerEvent + +GRAPH_ID = "test-graph" + + +class TestNeptuneGraphAvailableTrigger: + def test_serialization(self): + """ + Asserts that the TaskStateTrigger correctly serializes its arguments + and classpath. 
+ """ + trigger = NeptuneGraphAvailableTrigger(graph_id=GRAPH_ID) + classpath, kwargs = trigger.serialize() + assert ( + classpath + == "airflow.providers.amazon.aws.triggers.neptune_analytics.NeptuneGraphAvailableTrigger" + ) + assert "graph_id" in kwargs + assert kwargs["graph_id"] == GRAPH_ID + + @pytest.mark.asyncio + @mock.patch("airflow.providers.amazon.aws.hooks.neptune_analytics.NeptuneAnalyticsHook.get_waiter") + @mock.patch("airflow.providers.amazon.aws.hooks.neptune_analytics.NeptuneAnalyticsHook.get_async_conn") + async def test_run_success(self, mock_async_conn, mock_get_waiter): + mock_async_conn.return_value.__aenter__.return_value = "AVAILABLE" + mock_get_waiter().wait = AsyncMock() + trigger = NeptuneGraphAvailableTrigger(graph_id=GRAPH_ID) + generator = trigger.run() + resp = await generator.asend(None) + + assert resp == TriggerEvent({"status": "success", "graph_id": GRAPH_ID}) + assert mock_get_waiter().wait.call_count == 1 + + @pytest.mark.asyncio + @mock.patch("airflow.providers.amazon.aws.hooks.neptune_analytics.NeptuneAnalyticsHook.get_waiter") + @mock.patch("airflow.providers.amazon.aws.hooks.neptune_analytics.NeptuneAnalyticsHook.get_async_conn") + async def test_run_failure(self, mock_async_conn, mock_get_waiter): + wait_mock = AsyncMock() + wait_mock.side_effect = WaiterError( + name="graph_available", + reason='Waiter encountered a terminal failure state: For expression "status" we matched expected path: "FAILED"', + last_response={"status": "FAILED", "graphIdentifier": GRAPH_ID}, + ) + mock_get_waiter.return_value.wait = wait_mock + + trigger = NeptuneGraphAvailableTrigger(graph_id=GRAPH_ID) + + with pytest.raises(AirflowException): + await trigger.run().asend(None) From 820f68979f5dfd0d5fb0bf0d270896b6df1db2fb Mon Sep 17 00:00:00 2001 From: mse139 Date: Wed, 4 Mar 2026 06:15:11 -0500 Subject: [PATCH 03/28] Added private endpoint operators --- .../amazon/aws/operators/neptune_analytics.py | 226 ++++++++++++- 
.../amazon/aws/triggers/neptune_analytics.py | 90 ++++++ .../aws/operators/test_neptune_analytics.py | 296 +++++++++++++++++- .../aws/triggers/test_neptune_analytics.py | 116 +++++++ 4 files changed, 726 insertions(+), 2 deletions(-) diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/neptune_analytics.py b/providers/amazon/src/airflow/providers/amazon/aws/operators/neptune_analytics.py index 7cca99683d2c7..616ddadd94115 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/operators/neptune_analytics.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/neptune_analytics.py @@ -20,9 +20,14 @@ from collections.abc import Sequence from typing import TYPE_CHECKING, Any +from airflow.exceptions import AirflowException from airflow.providers.amazon.aws.hooks.neptune_analytics import NeptuneAnalyticsHook from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator -from airflow.providers.amazon.aws.triggers.neptune_analytics import NeptuneGraphAvailableTrigger +from airflow.providers.amazon.aws.triggers.neptune_analytics import ( + NeptuneGraphAvailableTrigger, + NeptuneGraphPrivateEndpointAvailableTrigger, + NeptuneGraphPrivateEndpointDeletedTrigger, +) from airflow.providers.amazon.aws.utils.mixins import aws_template_fields from airflow.providers.common.compat.sdk import conf @@ -158,3 +163,222 @@ def execute_complete(self, context: Context, event: dict[str, Any] | None = None self.log.info("Neptune graph % complete", graph_id) return {"graph_id": graph_id} + + +class NeptuneCreatePrivateGraphEndpointOperator(AwsBaseOperator[NeptuneAnalyticsHook]): + """ + Creates a Neptune Graph private endpoint. + + .. 
seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:NeptuneCreateGraphOperator` + + :param graph_identifier: Neptune Graph id + :param vpc_id: VPC to create endpoint in + :param subnet_ids: Subnets in which private graph endpoint ENIs are created + :param vpc_security_group_ids: Security groups to be attached to the private graph endpoint + + :param wait_for_completion: Whether to wait for the endpoint to be available. (default: True) + :param deferrable: If True, the operator will wait asynchronously for the endpoint to become available. + This implies waiting for completion. This mode requires aiobotocore module to be installed. + (default: False) + :param waiter_delay: Time in seconds to wait between status checks. + :param waiter_max_attempts: Maximum number of attempts to check for job completion. + :param aws_conn_id: The Airflow connection used for AWS credentials. + If this is ``None`` or empty then the default boto3 behaviour is used. If + running Airflow in a distributed manner and aws_conn_id is None or + empty, then default boto3 configuration would be used (and must be + maintained on each worker node). + :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used. + + :param botocore_config: Configuration dictionary (key-values) for botocore client. 
See: + https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html + :return: dictionary with Neptune graph id + """ + + aws_hook_class = NeptuneAnalyticsHook + template_fields: Sequence[str] = aws_template_fields( + "graph_identifier", "vpc_id", "subnet_ids", "vpc_security_group_ids" + ) + + def __init__( + self, + graph_identifier: str, + vpc_id: str | None = None, + subnet_ids: list[str] | None = None, + vpc_security_group_ids: list[str] | None = None, + wait_for_completion: bool = True, + waiter_delay: int = 30, + waiter_max_attempts: int = 60, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), + **kwargs, + ): + super().__init__(**kwargs) + self.graph_id = graph_identifier + self.vpc_id = vpc_id + self.subnet_ids = subnet_ids + self.vpc_security_group_ids = vpc_security_group_ids + self.wait_for_completion = wait_for_completion + self.deferrable = deferrable + self.waiter_delay = waiter_delay + self.waiter_max_attempts = waiter_max_attempts + + def execute(self, context: Context) -> dict: + self.log.info("Creating private endpoint for graph %s", self.graph_id) + + create_params = { + "graphIdentifier": self.graph_id, + **{ + k: v + for k, v in { + "vpcId": self.vpc_id, + "subnetIds": self.subnet_ids, + "vpcSecurityGroupIds": self.vpc_security_group_ids, + }.items() + if v is not None + }, + } + + # create the endpoint + + result = self.hook.conn.create_private_graph_endpoint(**create_params) + status = result.get("status", "Unknown") + endpoint_id = result.get("vpcEndpointId", "Unknown") + + self.log.info("Status of endpoint %s: %s", endpoint_id, status) + + if status in ["FAILED"]: + raise AirflowException(f"Private endpoint failed to create for graph {self.graph_id}") + + # if VPC not provided, use the one that is returned. 
Required for the waiter + self.vpc_id = result.get("vpcId", self.vpc_id) + + # TODO extra link to console + + if self.deferrable: + self.log.info("Deferring until endpoint %s is available", endpoint_id) + self.defer( + trigger=NeptuneGraphPrivateEndpointAvailableTrigger( + aws_conn_id=self.aws_conn_id, + graph_id=self.graph_id, + vpc_id=self.vpc_id, + endpoint_id=endpoint_id, + waiter_delay=self.waiter_delay, + waiter_max_attempts=self.waiter_max_attempts, + ), + method_name="execute_complete", + ) + + # TODO add test + if self.wait_for_completion: + self.log.info("Waiting until endpoint %s is available", endpoint_id) + self.hook.get_waiter("private_graph_endpoint_available").wait( + graphIdentifier=self.graph_id, + vpcId=self.vpc_id, + WaiterConfig={"Delay": self.waiter_delay, "MaxAttempts": self.waiter_max_attempts}, + ) + + return {"vpc_endpoint_id": endpoint_id, "graph_id": self.graph_id, "vpc_id": self.vpc_id} + + def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> dict[str, str]: + vpc_endpoint_id = "" + + if event and event.get("status") == "success": + vpc_endpoint_id = event.get("endpoint_id") + + return {"vpc_endpoint_id": vpc_endpoint_id, "graph_id": self.graph_id, "vpc_id": self.vpc_id} + + +class NeptuneDeletePrivateGraphEndpointOperator(AwsBaseOperator[NeptuneAnalyticsHook]): + """ + Deletes a Neptune Graph private endpoint. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:NeptuneDeletePrivateGraphEndpointOperator` + + :param graph_identifier: Neptune Graph id + :param vpc_id: VPC to create endpoint in + :param subnet_ids: Subnets in which private graph endpoint ENIs are created + :param vpc_security_group_ids: Security groups to be attached to the private graph endpoint + + :param wait_for_completion: Whether to wait for the endpoint to be available. 
(default: True) + :param deferrable: If True, the operator will wait asynchronously for the endpoint to become available. + This implies waiting for completion. This mode requires aiobotocore module to be installed. + (default: False) + :param waiter_delay: Time in seconds to wait between status checks. + :param waiter_max_attempts: Maximum number of attempts to check for job completion. + :param aws_conn_id: The Airflow connection used for AWS credentials. + If this is ``None`` or empty then the default boto3 behaviour is used. If + running Airflow in a distributed manner and aws_conn_id is None or + empty, then default boto3 configuration would be used (and must be + maintained on each worker node). + :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used. + + :param botocore_config: Configuration dictionary (key-values) for botocore client. See: + https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html + :return: dictionary with Neptune graph id + """ + + aws_hook_class = NeptuneAnalyticsHook + template_fields: Sequence[str] = aws_template_fields("graph_identifier", "vpc_id") + + def __init__( + self, + graph_identifier: str, + vpc_id: str, + wait_for_completion: bool = True, + waiter_delay: int = 30, + waiter_max_attempts: int = 60, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), + **kwargs, + ): + super().__init__(**kwargs) + self.graph_id = graph_identifier + self.vpc_id = vpc_id + self.wait_for_completion = wait_for_completion + self.deferrable = deferrable + self.waiter_delay = waiter_delay + self.waiter_max_attempts = waiter_max_attempts + + def execute(self, context: Context) -> None: + self.log.info("Deleting private endpoint for graph %s", self.graph_id) + + result = self.hook.conn.delete_private_graph_endpoint( + graphIdentifier=self.graph_id, vpcId=self.vpc_id + ) + + status = result.get("status") + endpoint_id = result.get("vpcEndpointId") + + if 
status == "FAILED": + raise AirflowException(f"Failed to delete private endpoint {endpoint_id}") + + if self.deferrable: + self.log.info("Deferring until endpoint %s is deleted", endpoint_id) + self.defer( + trigger=NeptuneGraphPrivateEndpointDeletedTrigger( + aws_conn_id=self.aws_conn_id, + graph_id=self.graph_id, + vpc_id=self.vpc_id, + endpoint_id=endpoint_id, + waiter_delay=self.waiter_delay, + waiter_max_attempts=self.waiter_max_attempts, + ), + method_name="execute_complete", + ) + if self.wait_for_completion: + self.log.info("Waiting until endpoint %s is deleted", endpoint_id) + self.hook.get_waiter("private_graph_endpoint_deleted").wait( + graphIdentifier=self.graph_id, + vpcId=self.vpc_id, + WaiterConfig={"Delay": self.waiter_delay, "MaxAttempts": self.waiter_max_attempts}, + ) + + def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None: + vpc_endpoint_id = "" + + if event and event.get("status") == "success": + vpc_endpoint_id = event.get("endpoint_id") + + self.log.info("Endpoint id %s deleted", vpc_endpoint_id) diff --git a/providers/amazon/src/airflow/providers/amazon/aws/triggers/neptune_analytics.py b/providers/amazon/src/airflow/providers/amazon/aws/triggers/neptune_analytics.py index 10f2804f5ba2b..f8c0b716044b2 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/triggers/neptune_analytics.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/triggers/neptune_analytics.py @@ -65,3 +65,93 @@ def hook(self) -> AwsGenericHook: verify=self.verify, config=self.botocore_config, ) + + +class NeptuneGraphPrivateEndpointAvailableTrigger(AwsBaseWaiterTrigger): + """ + Triggers when a Neptune Graph private endpoint is available. + + :param graph_id: Graph Id waiting for the endpoint + :param vpc_id: VPC id where endpoint is creating + :param waiter_delay: The amount of time in seconds to wait between attempts. + :param waiter_max_attempts: The maximum number of attempts to be made. 
+ :param aws_conn_id: The Airflow connection used for AWS credentials. + :param region_name: AWS region name (example: us-east-1) + """ + + def __init__( + self, + *, + graph_id: str, + vpc_id: str, + endpoint_id: str, + waiter_delay: int = 30, + waiter_max_attempts: int = 60, + **kwargs, + ) -> None: + super().__init__( + serialized_fields={"graph_id": graph_id, "vpc_id": vpc_id, "endpoint_id": endpoint_id}, + waiter_name="private_graph_endpoint_available", + waiter_args={"graphIdentifier": graph_id, "vpcId": vpc_id}, + failure_message="Failed to create Neptune graph endpoint", + status_message="Status of Neptune graph endpoint is", + status_queries=["status"], + return_key="endpoint_id", + return_value=endpoint_id, + waiter_delay=waiter_delay, + waiter_max_attempts=waiter_max_attempts, + **kwargs, + ) + + def hook(self) -> AwsGenericHook: + return NeptuneAnalyticsHook( + aws_conn_id=self.aws_conn_id, + region_name=self.region_name, + verify=self.verify, + config=self.botocore_config, + ) + + +class NeptuneGraphPrivateEndpointDeletedTrigger(AwsBaseWaiterTrigger): + """ + Triggers when a Neptune Graph private endpoint is deleted. + + :param graph_id: Graph Id of the endpoint + :param vpc_id: VPC id where endpoint resides + :param waiter_delay: The amount of time in seconds to wait between attempts. + :param waiter_max_attempts: The maximum number of attempts to be made. + :param aws_conn_id: The Airflow connection used for AWS credentials. 
+ :param region_name: AWS region name (example: us-east-1) + """ + + def __init__( + self, + *, + graph_id: str, + vpc_id: str, + endpoint_id: str, + waiter_delay: int = 30, + waiter_max_attempts: int = 60, + **kwargs, + ) -> None: + super().__init__( + serialized_fields={"graph_id": graph_id, "vpc_id": vpc_id, "endpoint_id": endpoint_id}, + waiter_name="private_graph_endpoint_deleted", + waiter_args={"graphIdentifier": graph_id, "vpcId": vpc_id}, + failure_message="Failed to delete Neptune graph endpoint", + status_message="Status of Neptune graph endpoint is", + status_queries=["status"], + return_key="endpoint_id", + return_value=endpoint_id, + waiter_delay=waiter_delay, + waiter_max_attempts=waiter_max_attempts, + **kwargs, + ) + + def hook(self) -> AwsGenericHook: + return NeptuneAnalyticsHook( + aws_conn_id=self.aws_conn_id, + region_name=self.region_name, + verify=self.verify, + config=self.botocore_config, + ) diff --git a/providers/amazon/tests/unit/amazon/aws/operators/test_neptune_analytics.py b/providers/amazon/tests/unit/amazon/aws/operators/test_neptune_analytics.py index debbcd6918451..57471cf66889b 100644 --- a/providers/amazon/tests/unit/amazon/aws/operators/test_neptune_analytics.py +++ b/providers/amazon/tests/unit/amazon/aws/operators/test_neptune_analytics.py @@ -24,9 +24,18 @@ from moto import mock_aws from airflow.providers.amazon.aws.hooks.neptune_analytics import NeptuneAnalyticsHook -from airflow.providers.amazon.aws.operators.neptune_analytics import NeptuneCreateGraphOperator +from airflow.providers.amazon.aws.operators.neptune_analytics import ( + NeptuneCreateGraphOperator, + NeptuneCreatePrivateGraphEndpointOperator, + NeptuneDeletePrivateGraphEndpointOperator, +) GRAPH_NAME = "test_graph" +GRAPH_ID = "test-graph-id" +VPC_ID = "vpc-12345" +SUBNET_IDS = ["subnet-1", "subnet-2"] +SECURITY_GROUP_IDS = ["sg-1", "sg-2"] +ENDPOINT_ID = "vpce-12345" @pytest.fixture @@ -124,3 +133,288 @@ def test_create_graph_wait_for_completion(self, 
mock_hook_get_waiter, mock_conn) mock_hook_get_waiter.assert_called_once_with("graph_available") assert "graph_id" in resp assert resp["graph_id"] is not None + + +class TestNeptuneCreatePrivateGraphEndpointOperator: + @mock.patch.object(NeptuneAnalyticsHook, "conn") + def test_init_defaults(self, mock_conn): + mock_conn.create_private_graph_endpoint.return_value = { + "status": "CREATING", + "vpcEndpointId": ENDPOINT_ID, + "vpcId": VPC_ID, + } + + operator = NeptuneCreatePrivateGraphEndpointOperator( + task_id="test_task", + graph_identifier=GRAPH_ID, + ) + + assert operator.graph_id == GRAPH_ID + assert operator.vpc_id is None + assert operator.subnet_ids is None + assert operator.vpc_security_group_ids is None + assert operator.wait_for_completion is True + assert operator.waiter_delay == 30 + assert operator.waiter_max_attempts == 60 + + result = operator.execute(None) + + mock_conn.create_private_graph_endpoint.assert_called_once_with( + graphIdentifier=GRAPH_ID, + ) + + assert result is not None + assert result["vpc_endpoint_id"] == ENDPOINT_ID + assert result["graph_id"] == GRAPH_ID + assert result["vpc_id"] == VPC_ID + + @mock.patch.object(NeptuneAnalyticsHook, "conn") + def test_init_custom_args(self, mock_conn): + mock_conn.create_private_graph_endpoint.return_value = { + "status": "CREATING", + "vpcEndpointId": ENDPOINT_ID, + "vpcId": VPC_ID, + } + + operator = NeptuneCreatePrivateGraphEndpointOperator( + task_id="test_task", + graph_identifier=GRAPH_ID, + vpc_id=VPC_ID, + subnet_ids=SUBNET_IDS, + vpc_security_group_ids=SECURITY_GROUP_IDS, + waiter_delay=60, + waiter_max_attempts=100, + ) + + assert operator.graph_id == GRAPH_ID + assert operator.vpc_id == VPC_ID + assert operator.subnet_ids == SUBNET_IDS + assert operator.vpc_security_group_ids == SECURITY_GROUP_IDS + assert operator.waiter_delay == 60 + assert operator.waiter_max_attempts == 100 + + operator.execute(None) + + mock_conn.create_private_graph_endpoint.assert_called_once_with( + 
graphIdentifier=GRAPH_ID, + vpcId=VPC_ID, + subnetIds=SUBNET_IDS, + vpcSecurityGroupIds=SECURITY_GROUP_IDS, + ) + + @mock.patch.object(NeptuneAnalyticsHook, "conn") + @mock.patch.object(NeptuneAnalyticsHook, "get_waiter") + def test_create_endpoint_no_wait(self, mock_hook_get_waiter, mock_conn): + mock_conn.create_private_graph_endpoint.return_value = { + "status": "CREATING", + "vpcEndpointId": ENDPOINT_ID, + "vpcId": VPC_ID, + } + + operator = NeptuneCreatePrivateGraphEndpointOperator( + task_id="test_task", + graph_identifier=GRAPH_ID, + vpc_id=VPC_ID, + wait_for_completion=False, + ) + operator.execute(None) + + mock_hook_get_waiter.assert_not_called() + + @mock.patch.object(NeptuneAnalyticsHook, "conn") + @mock.patch.object(NeptuneAnalyticsHook, "get_waiter") + def test_create_endpoint_wait_for_completion(self, mock_hook_get_waiter, mock_conn): + mock_conn.create_private_graph_endpoint.return_value = { + "status": "CREATING", + "vpcEndpointId": ENDPOINT_ID, + "vpcId": VPC_ID, + } + + operator = NeptuneCreatePrivateGraphEndpointOperator( + task_id="test_task", + graph_identifier=GRAPH_ID, + vpc_id=VPC_ID, + wait_for_completion=True, + ) + operator.execute(None) + + # Note: The operator currently has 'pass' for wait_for_completion + # This test documents the current behavior + # When wait_for_completion is implemented, this test should verify the waiter is called + + @mock.patch.object(NeptuneAnalyticsHook, "conn") + def test_create_endpoint_failed_status(self, mock_conn): + from airflow.providers.common.compat.sdk import AirflowException + + mock_conn.create_private_graph_endpoint.return_value = { + "status": "FAILED", + "vpcEndpointId": ENDPOINT_ID, + "vpcId": VPC_ID, + } + + operator = NeptuneCreatePrivateGraphEndpointOperator( + task_id="test_task", + graph_identifier=GRAPH_ID, + vpc_id=VPC_ID, + ) + + with pytest.raises(AirflowException, match=f"Private endpoint failed to create for graph {GRAPH_ID}"): + operator.execute(None) + + def 
test_execute_complete_success(self): + operator = NeptuneCreatePrivateGraphEndpointOperator( + task_id="test_task", graph_identifier=GRAPH_ID, vpc_id=VPC_ID + ) + + event = { + "status": "success", + "endpoint_id": ENDPOINT_ID, + } + + result = operator.execute_complete(None, event) + + assert result == {"vpc_endpoint_id": ENDPOINT_ID, "graph_id": GRAPH_ID, "vpc_id": VPC_ID} + + def test_execute_complete_failure_status(self): + operator = NeptuneCreatePrivateGraphEndpointOperator( + task_id="test_task", graph_identifier=GRAPH_ID, vpc_id=VPC_ID + ) + + event = { + "status": "failure", + "endpoint_id": ENDPOINT_ID, + } + + result = operator.execute_complete(None, event) + + assert result == {"vpc_endpoint_id": "", "graph_id": GRAPH_ID, "vpc_id": VPC_ID} + + +class TestNeptuneDeletePrivateGraphEndpointOperator: + @mock.patch.object(NeptuneAnalyticsHook, "conn") + def test_init_defaults(self, mock_conn): + mock_conn.delete_private_graph_endpoint.return_value = { + "status": "DELETING", + "vpcEndpointId": ENDPOINT_ID, + "vpcId": VPC_ID, + } + + operator = NeptuneDeletePrivateGraphEndpointOperator( + task_id="test_task", + graph_identifier=GRAPH_ID, + vpc_id=VPC_ID, + ) + + assert operator.graph_id == GRAPH_ID + assert operator.vpc_id == VPC_ID + assert operator.wait_for_completion is True + assert operator.waiter_delay == 30 + assert operator.waiter_max_attempts == 60 + + operator.execute(None) + + mock_conn.delete_private_graph_endpoint.assert_called_once_with( + graphIdentifier=GRAPH_ID, + vpcId=VPC_ID, + ) + + @mock.patch.object(NeptuneAnalyticsHook, "conn") + def test_init_custom_args(self, mock_conn): + mock_conn.delete_private_graph_endpoint.return_value = { + "status": "DELETING", + "vpcEndpointId": ENDPOINT_ID, + "vpcId": VPC_ID, + } + + operator = NeptuneDeletePrivateGraphEndpointOperator( + task_id="test_task", + graph_identifier=GRAPH_ID, + vpc_id=VPC_ID, + waiter_delay=60, + waiter_max_attempts=100, + ) + + assert operator.graph_id == GRAPH_ID + assert 
operator.vpc_id == VPC_ID + assert operator.waiter_delay == 60 + assert operator.waiter_max_attempts == 100 + + operator.execute(None) + + mock_conn.delete_private_graph_endpoint.assert_called_once_with( + graphIdentifier=GRAPH_ID, + vpcId=VPC_ID, + ) + + @mock.patch.object(NeptuneAnalyticsHook, "conn") + @mock.patch.object(NeptuneAnalyticsHook, "get_waiter") + def test_delete_endpoint_no_wait(self, mock_hook_get_waiter, mock_conn): + mock_conn.delete_private_graph_endpoint.return_value = { + "status": "DELETING", + "vpcEndpointId": ENDPOINT_ID, + "vpcId": VPC_ID, + } + + operator = NeptuneDeletePrivateGraphEndpointOperator( + task_id="test_task", + graph_identifier=GRAPH_ID, + vpc_id=VPC_ID, + wait_for_completion=False, + ) + operator.execute(None) + + mock_hook_get_waiter.assert_not_called() + + @mock.patch.object(NeptuneAnalyticsHook, "conn") + @mock.patch.object(NeptuneAnalyticsHook, "get_waiter") + def test_delete_endpoint_wait_for_completion(self, mock_hook_get_waiter, mock_conn): + mock_conn.delete_private_graph_endpoint.return_value = { + "status": "DELETING", + "vpcEndpointId": ENDPOINT_ID, + "vpcId": VPC_ID, + } + + operator = NeptuneDeletePrivateGraphEndpointOperator( + task_id="test_task", + graph_identifier=GRAPH_ID, + vpc_id=VPC_ID, + wait_for_completion=True, + ) + operator.execute(None) + + mock_hook_get_waiter.assert_called_once_with("private_graph_endpoint_deleted") + + @mock.patch.object(NeptuneAnalyticsHook, "conn") + def test_delete_endpoint_failed_status(self, mock_conn): + from airflow.providers.common.compat.sdk import AirflowException + + mock_conn.delete_private_graph_endpoint.return_value = { + "status": "FAILED", + "vpcEndpointId": ENDPOINT_ID, + "vpcId": VPC_ID, + } + + operator = NeptuneDeletePrivateGraphEndpointOperator( + task_id="test_task", + graph_identifier=GRAPH_ID, + vpc_id=VPC_ID, + ) + + with pytest.raises(AirflowException, match=f"Failed to delete private endpoint {ENDPOINT_ID}"): + operator.execute(None) + + def 
test_execute_complete_success(self): + operator = NeptuneDeletePrivateGraphEndpointOperator( + task_id="test_task", + graph_identifier=GRAPH_ID, + vpc_id=VPC_ID, + ) + + event = { + "status": "success", + "endpoint_id": ENDPOINT_ID, + } + + operator.execute_complete(None, event) + + # Verify the method completes without error and logs the endpoint_id diff --git a/providers/amazon/tests/unit/amazon/aws/triggers/test_neptune_analytics.py b/providers/amazon/tests/unit/amazon/aws/triggers/test_neptune_analytics.py index a3480531c8046..1637b0a15e8da 100644 --- a/providers/amazon/tests/unit/amazon/aws/triggers/test_neptune_analytics.py +++ b/providers/amazon/tests/unit/amazon/aws/triggers/test_neptune_analytics.py @@ -24,11 +24,15 @@ from airflow.providers.amazon.aws.triggers.neptune_analytics import ( NeptuneGraphAvailableTrigger, + NeptuneGraphPrivateEndpointAvailableTrigger, + NeptuneGraphPrivateEndpointDeletedTrigger, ) from airflow.providers.common.compat.sdk import AirflowException from airflow.triggers.base import TriggerEvent GRAPH_ID = "test-graph" +VPC_ID = "test-vpc" +ENDPOINT_ID = "test-endpoint" class TestNeptuneGraphAvailableTrigger: @@ -75,3 +79,115 @@ async def test_run_failure(self, mock_async_conn, mock_get_waiter): with pytest.raises(AirflowException): await trigger.run().asend(None) + + +class TestNeptuneGraphPrivateEndpointAvailableTrigger: + def test_serialization(self): + """ + Asserts that the NeptuneGraphPrivateEndpointAvailableTrigger correctly serializes its arguments + and classpath. 
+ """ + trigger = NeptuneGraphPrivateEndpointAvailableTrigger( + graph_id=GRAPH_ID, vpc_id=VPC_ID, endpoint_id=ENDPOINT_ID + ) + classpath, kwargs = trigger.serialize() + assert ( + classpath + == "airflow.providers.amazon.aws.triggers.neptune_analytics.NeptuneGraphPrivateEndpointAvailableTrigger" + ) + assert "graph_id" in kwargs + assert kwargs["graph_id"] == GRAPH_ID + assert "vpc_id" in kwargs + assert kwargs["vpc_id"] == VPC_ID + assert "endpoint_id" in kwargs + assert kwargs["endpoint_id"] == ENDPOINT_ID + + @pytest.mark.asyncio + @mock.patch("airflow.providers.amazon.aws.hooks.neptune_analytics.NeptuneAnalyticsHook.get_waiter") + @mock.patch("airflow.providers.amazon.aws.hooks.neptune_analytics.NeptuneAnalyticsHook.get_async_conn") + async def test_run_success(self, mock_async_conn, mock_get_waiter): + mock_async_conn.return_value.__aenter__.return_value = "AVAILABLE" + mock_get_waiter().wait = AsyncMock() + trigger = NeptuneGraphPrivateEndpointAvailableTrigger( + graph_id=GRAPH_ID, vpc_id=VPC_ID, endpoint_id=ENDPOINT_ID + ) + generator = trigger.run() + resp = await generator.asend(None) + + assert resp == TriggerEvent({"status": "success", "endpoint_id": ENDPOINT_ID}) + assert mock_get_waiter().wait.call_count == 1 + + @pytest.mark.asyncio + @mock.patch("airflow.providers.amazon.aws.hooks.neptune_analytics.NeptuneAnalyticsHook.get_waiter") + @mock.patch("airflow.providers.amazon.aws.hooks.neptune_analytics.NeptuneAnalyticsHook.get_async_conn") + async def test_run_failure(self, mock_async_conn, mock_get_waiter): + wait_mock = AsyncMock() + wait_mock.side_effect = WaiterError( + name="private_graph_endpoint_available", + reason='Waiter encountered a terminal failure state: For expression "status" we matched expected path: "FAILED"', + last_response={"status": "FAILED", "graphIdentifier": GRAPH_ID, "vpcId": VPC_ID}, + ) + mock_get_waiter.return_value.wait = wait_mock + + trigger = NeptuneGraphPrivateEndpointAvailableTrigger( + graph_id=GRAPH_ID, 
vpc_id=VPC_ID, endpoint_id=ENDPOINT_ID + ) + + with pytest.raises(AirflowException): + await trigger.run().asend(None) + + +class TestNeptuneGraphPrivateEndpointDeletedTrigger: + def test_serialization(self): + """ + Asserts that the NeptuneGraphPrivateEndpointDeletedTrigger correctly serializes its arguments + and classpath. + """ + trigger = NeptuneGraphPrivateEndpointDeletedTrigger( + graph_id=GRAPH_ID, vpc_id=VPC_ID, endpoint_id=ENDPOINT_ID + ) + classpath, kwargs = trigger.serialize() + assert ( + classpath + == "airflow.providers.amazon.aws.triggers.neptune_analytics.NeptuneGraphPrivateEndpointDeletedTrigger" + ) + assert "graph_id" in kwargs + assert kwargs["graph_id"] == GRAPH_ID + assert "vpc_id" in kwargs + assert kwargs["vpc_id"] == VPC_ID + assert "endpoint_id" in kwargs + assert kwargs["endpoint_id"] == ENDPOINT_ID + + @pytest.mark.asyncio + @mock.patch("airflow.providers.amazon.aws.hooks.neptune_analytics.NeptuneAnalyticsHook.get_waiter") + @mock.patch("airflow.providers.amazon.aws.hooks.neptune_analytics.NeptuneAnalyticsHook.get_async_conn") + async def test_run_success(self, mock_async_conn, mock_get_waiter): + mock_async_conn.return_value.__aenter__.return_value = "DELETED" + mock_get_waiter().wait = AsyncMock() + trigger = NeptuneGraphPrivateEndpointDeletedTrigger( + graph_id=GRAPH_ID, vpc_id=VPC_ID, endpoint_id=ENDPOINT_ID + ) + generator = trigger.run() + resp = await generator.asend(None) + + assert resp == TriggerEvent({"status": "success", "endpoint_id": ENDPOINT_ID}) + assert mock_get_waiter().wait.call_count == 1 + + @pytest.mark.asyncio + @mock.patch("airflow.providers.amazon.aws.hooks.neptune_analytics.NeptuneAnalyticsHook.get_waiter") + @mock.patch("airflow.providers.amazon.aws.hooks.neptune_analytics.NeptuneAnalyticsHook.get_async_conn") + async def test_run_failure(self, mock_async_conn, mock_get_waiter): + wait_mock = AsyncMock() + wait_mock.side_effect = WaiterError( + name="private_graph_endpoint_deleted", + reason='Waiter 
encountered a terminal failure state: For expression "status" we matched expected path: "FAILED"', + last_response={"status": "FAILED", "graphIdentifier": GRAPH_ID, "vpcId": VPC_ID}, + ) + mock_get_waiter.return_value.wait = wait_mock + + trigger = NeptuneGraphPrivateEndpointDeletedTrigger( + graph_id=GRAPH_ID, vpc_id=VPC_ID, endpoint_id=ENDPOINT_ID + ) + + with pytest.raises(AirflowException): + await trigger.run().asend(None) From ebdbbbc11ffe5063a9eff7bc759324bb2c5c9b28 Mon Sep 17 00:00:00 2001 From: mse139 Date: Thu, 5 Mar 2026 06:05:40 -0500 Subject: [PATCH 04/28] fixed assignment error --- .../airflow/providers/amazon/aws/operators/neptune_analytics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/neptune_analytics.py b/providers/amazon/src/airflow/providers/amazon/aws/operators/neptune_analytics.py index 616ddadd94115..175891c9589b5 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/operators/neptune_analytics.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/neptune_analytics.py @@ -284,7 +284,7 @@ def execute_complete(self, context: Context, event: dict[str, Any] | None = None vpc_endpoint_id = "" if event and event.get("status") == "success": - vpc_endpoint_id = event.get("endpoint_id") + vpc_endpoint_id = event.get("endpoint_id", "") return {"vpc_endpoint_id": vpc_endpoint_id, "graph_id": self.graph_id, "vpc_id": self.vpc_id} From 71aec4cde7d0f185557fa02c934889e8bb9ef042 Mon Sep 17 00:00:00 2001 From: mse139 Date: Thu, 5 Mar 2026 06:15:36 -0500 Subject: [PATCH 05/28] Fixed type error --- .../airflow/providers/amazon/aws/operators/neptune_analytics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/neptune_analytics.py b/providers/amazon/src/airflow/providers/amazon/aws/operators/neptune_analytics.py index 175891c9589b5..aa769e1af66ea 100644 --- 
a/providers/amazon/src/airflow/providers/amazon/aws/operators/neptune_analytics.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/neptune_analytics.py @@ -280,7 +280,7 @@ def execute(self, context: Context) -> dict: return {"vpc_endpoint_id": endpoint_id, "graph_id": self.graph_id, "vpc_id": self.vpc_id} - def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> dict[str, str]: + def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> dict[str, Any]: vpc_endpoint_id = "" if event and event.get("status") == "success": From d63f95488a8529ca1a25583ebd06945273e8a9be Mon Sep 17 00:00:00 2001 From: mse139 Date: Thu, 5 Mar 2026 06:22:27 -0500 Subject: [PATCH 06/28] fixed another return type error --- .../airflow/providers/amazon/aws/operators/neptune_analytics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/neptune_analytics.py b/providers/amazon/src/airflow/providers/amazon/aws/operators/neptune_analytics.py index aa769e1af66ea..564b5601001ce 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/operators/neptune_analytics.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/neptune_analytics.py @@ -154,7 +154,7 @@ def execute(self, context: Context) -> dict: return {"graph_id": self.graph_id} - def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> dict[str, str]: + def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> dict[str, Any]: graph_id = "" if event: From 0876b77f039da00c337dd4971ae07631b84e800b Mon Sep 17 00:00:00 2001 From: mse139 Date: Thu, 5 Mar 2026 17:24:05 -0500 Subject: [PATCH 07/28] Fixed assignment error in delete endpoint operator --- .../airflow/providers/amazon/aws/operators/neptune_analytics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/providers/amazon/src/airflow/providers/amazon/aws/operators/neptune_analytics.py b/providers/amazon/src/airflow/providers/amazon/aws/operators/neptune_analytics.py index 564b5601001ce..b025650c05eaa 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/operators/neptune_analytics.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/neptune_analytics.py @@ -379,6 +379,6 @@ def execute_complete(self, context: Context, event: dict[str, Any] | None = None vpc_endpoint_id = "" if event and event.get("status") == "success": - vpc_endpoint_id = event.get("endpoint_id") + vpc_endpoint_id = event.get("endpoint_id", "Unknown") self.log.info("Endpoint id %s deleted", vpc_endpoint_id) From d1405551c98a8650d892e68c09066c7bfad3ac6b Mon Sep 17 00:00:00 2001 From: mse139 Date: Mon, 16 Mar 2026 06:18:22 -0400 Subject: [PATCH 08/28] Added NeptuneDeleteGraphOperator and NeptuneStartImportTaskOperator --- .../amazon/aws/operators/neptune_analytics.py | 400 +++++++++++++- .../amazon/aws/triggers/neptune_analytics.py | 85 +++ .../aws/operators/test_neptune_analytics.py | 518 ++++++++++++++++++ .../aws/triggers/test_neptune_analytics.py | 48 ++ 4 files changed, 1041 insertions(+), 10 deletions(-) diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/neptune_analytics.py b/providers/amazon/src/airflow/providers/amazon/aws/operators/neptune_analytics.py index b025650c05eaa..a62e19175d34f 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/operators/neptune_analytics.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/neptune_analytics.py @@ -20,13 +20,17 @@ from collections.abc import Sequence from typing import TYPE_CHECKING, Any +from botocore.exceptions import ClientError + from airflow.exceptions import AirflowException from airflow.providers.amazon.aws.hooks.neptune_analytics import NeptuneAnalyticsHook from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator from 
airflow.providers.amazon.aws.triggers.neptune_analytics import ( NeptuneGraphAvailableTrigger, + NeptuneGraphDeletedTrigger, NeptuneGraphPrivateEndpointAvailableTrigger, NeptuneGraphPrivateEndpointDeletedTrigger, + NeptuneImportTaskCompleteTrigger, ) from airflow.providers.amazon.aws.utils.mixins import aws_template_fields from airflow.providers.common.compat.sdk import conf @@ -69,7 +73,7 @@ class NeptuneCreateGraphOperator(AwsBaseOperator[NeptuneAnalyticsHook]): :param botocore_config: Configuration dictionary (key-values) for botocore client. See: https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html - :return: dictionary with Neptune graph id + :return: dictionary with Neptune graph id and vpc id """ aws_hook_class = NeptuneAnalyticsHook @@ -127,7 +131,7 @@ def execute(self, context: Context) -> dict: } response = self.hook.conn.create_graph(**create_params) - + # TODO: need to return vpc id from response in case it wasn't supplied earlier. self.log.info("Graph %s in status %s", self.graph_name, response.get("status", "Unknown")) self.graph_id = response.get("id", None) @@ -155,14 +159,9 @@ def execute(self, context: Context) -> dict: return {"graph_id": self.graph_id} def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> dict[str, Any]: - graph_id = "" - - if event: - graph_id = event.get("graph_id", "Unknown") + self.log.info("Neptune graph % complete", self.graph_id) - self.log.info("Neptune graph % complete", graph_id) - - return {"graph_id": graph_id} + return {"graph_id": self.graph_id} class NeptuneCreatePrivateGraphEndpointOperator(AwsBaseOperator[NeptuneAnalyticsHook]): @@ -171,7 +170,7 @@ class NeptuneCreatePrivateGraphEndpointOperator(AwsBaseOperator[NeptuneAnalytics .. 
seealso:: For more information on how to use this operator, take a look at the guide: - :ref:`howto/operator:NeptuneCreateGraphOperator` + :ref:`howto/operator:NeptuneCreatePrivateGraphEndpointOperator` :param graph_identifier: Neptune Graph id :param vpc_id: VPC to create endpoint in @@ -382,3 +381,384 @@ def execute_complete(self, context: Context, event: dict[str, Any] | None = None vpc_endpoint_id = event.get("endpoint_id", "Unknown") self.log.info("Endpoint id %s deleted", vpc_endpoint_id) + + +class NeptuneDeleteGraphOperator(AwsBaseOperator[NeptuneAnalyticsHook]): + """ + Deletes an Amazon Neptune Graph database. + + Neptune Analytics is a memory-optimized graph database engine for analytics. With Neptune Analytics, you can get insights and find trends by processing large amounts of graph data in seconds. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:NeptuneCreateGraphOperator` + + :param graph_id: Id of the Neptune graph to delete + :param skip_snapshot: Determines whether a final graph snapshot is created before the graph is deleted. If true is specified, no graph snapshot is created. If false is specified, a graph snapshot is created before the graph is deleted. + :param wait_for_completion: Whether to wait for the graph to delete. (default: True) + :param deferrable: If True, the operator will wait asynchronously for the graph to be deleted. + This implies waiting for completion. This mode requires aiobotocore module to be installed. + (default: False) + :param waiter_delay: Time in seconds to wait between status checks. + :param waiter_max_attempts: Maximum number of attempts to check for job completion. + :param aws_conn_id: The Airflow connection used for AWS credentials. + If this is ``None`` or empty then the default boto3 behaviour is used. 
If + running Airflow in a distributed manner and aws_conn_id is None or + empty, then default boto3 configuration would be used (and must be + maintained on each worker node). + :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used. + + :param botocore_config: Configuration dictionary (key-values) for botocore client. See: + https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html + :return: dictionary with Neptune graph id + """ + + aws_hook_class = NeptuneAnalyticsHook + template_fields: Sequence[str] = aws_template_fields("graph_id", "skip_snapshot") + + def __init__( + self, + graph_id: str, + skip_snapshot: bool, + wait_for_completion: bool = True, + waiter_delay: int = 30, + waiter_max_attempts: int = 60, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), + **kwargs, + ): + super().__init__(**kwargs) + self.graph_id = graph_id + self.skip_snapshot = skip_snapshot + self.wait_for_completion = wait_for_completion + self.deferrable = deferrable + self.waiter_delay = waiter_delay + self.waiter_max_attempts = waiter_max_attempts + + def execute(self, context: Context) -> dict: + self.log.info("Deleting graph %s", self.graph_id) + + try: + self.hook.conn.delete_graph(graphIdentifier=self.graph_id, skipSnapshot=self.skip_snapshot) + except ClientError as e: + # if not found, just exit because there is nothing to delete + if e.response["Error"]["Code"] == "ResourceNotFoundException": + self.log.info("Graph %s not found. 
Nothing to delete", self.graph_id) + else: + raise AirflowException(e.response["Error"]) + + if self.deferrable: + self.log.info("Deferring until graph %s is deleted", self.graph_id) + self.defer( + trigger=NeptuneGraphDeletedTrigger( + aws_conn_id=self.aws_conn_id, + graph_id=self.graph_id, + waiter_delay=self.waiter_delay, + waiter_max_attempts=self.waiter_max_attempts, + ), + method_name="execute_complete", + ) + if self.wait_for_completion: + self.log.info("Waiting to delete %s", self.graph_id) + + self.hook.conn.get_waiter("graph_deleted").wait( + graphIdentifier=self.graph_id, + WaiterConfig={"Delay": self.waiter_delay, "MaxAttempts": self.waiter_max_attempts}, + ) + + def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> dict[str, Any]: + if event: + graph_id = event.get("graph_id", "Unknown") + + self.log.info("Neptune graph % deleted", graph_id) + + +class NeptuneCreateGraphWithImportOperator(AwsBaseOperator[NeptuneAnalyticsHook]): + """ + Creates a Neptune Graph and imports data into it. + + Neptune Analytics is a memory-optimized graph database engine for analytics. With Neptune Analytics, + you can get insights and find trends by processing large amounts of graph data in seconds. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:NeptuneCreateGraphWithImportOperator` + + :param graph_name: Name of Neptune graph to create + :param vector_search_config: Specifies the number of dimensions for vector embeddings that will be loaded into the graph. + :param source: The source from which to import data. Can be an S3 URI or Neptune database snapshot. + :param role_arn: The ARN of the IAM role that Neptune Analytics can assume to access the data source. + :param blank_node_handling: The method to handle blank nodes in the dataset. Options include 'convertToIri' or other handling strategies. 
+ :param parquet_type: The type of Parquet files in the data source (if applicable). + :param format: The format of the data to be imported (e.g., 'csv', 'opencypher', 'ntriples', 'nquads', 'rdfxml', 'turtle'). + :param min_provisioned_memory: The minimum provisioned memory for the graph in GBs. + :param max_provisioned_memory: The maximum provisioned memory for the graph in GBs. + :param fail_on_error: If True, the import will fail if any errors are encountered. If False, the import will continue despite errors. + :param public_connectivity: Specifies whether or not the graph can be reachable over the internet. + :param replica_count: The number of replicas in other AZs. + :param deletion_protection: Indicates whether or not to enable deletion protection on the graph. + The graph can't be deleted when deletion protection is enabled. (default: False) + :param kms_key_id: Specifies a KMS key to use to encrypt data in the new graph. + :param tags: Specifies metadata tags to add to the graph. + :param import_options: Contains options for controlling the import process. + :param wait_for_completion: Whether to wait for the graph to be created and data imported. (default: True) + :param waiter_delay: Time in seconds to wait between status checks. (default: 30) + :param waiter_max_attempts: Maximum number of attempts to check for job completion. (default: 60) + :param deferrable: If True, the operator will wait asynchronously for the graph to be created and data imported. + This implies waiting for completion. This mode requires aiobotocore module to be installed. + (default: False) + :param aws_conn_id: The Airflow connection used for AWS credentials. + If this is ``None`` or empty then the default boto3 behaviour is used. If + running Airflow in a distributed manner and aws_conn_id is None or + empty, then default boto3 configuration would be used (and must be + maintained on each worker node). + :param region_name: AWS region_name. 
If not specified then the default boto3 behaviour is used. + :param botocore_config: Configuration dictionary (key-values) for botocore client. See: + https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html + :return: dictionary with Neptune graph id and vpc id + """ + + aws_hook_class = NeptuneAnalyticsHook + template_fields: Sequence[str] = aws_template_fields() + + def __init__( + self, + graph_name: str, + vector_search_config: dict, + source: str, + role_arn: str, + blank_node_handling: str | None = None, + parquet_type: str | None = None, + format: str | None = None, + min_provisioned_memory: int | None = None, + max_provisioned_memory: int | None = None, + fail_on_error: bool | None = None, + public_connectivity: bool | None = None, + replica_count: int | None = None, + deletion_protection: bool | None = None, + kms_key_id: str | None = None, + tags: dict | None = None, + import_options: dict | None = None, + wait_for_completion: bool = True, + waiter_delay: int = 30, + waiter_max_attempts: int = 60, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), + **kwargs, + ): + super().__init__(**kwargs) + self.graph_name = graph_name + self.vector_search_config = vector_search_config + self.source = source + self.role_arn = role_arn + self.blank_node_handling = blank_node_handling + self.parquet_type = parquet_type + self.format = format + self.min_provisioned_memory = min_provisioned_memory + self.max_provisioned_memory = max_provisioned_memory + self.fail_on_error = fail_on_error + self.public_connectivity = public_connectivity + self.replica_count = replica_count + self.deletion_protect = deletion_protection + self.kms_key = kms_key_id + self.tags = tags + self.import_options = import_options + self.wait_for_completion = wait_for_completion + self.deferrable = deferrable + self.waiter_delay = waiter_delay + self.waiter_max_attempts = waiter_max_attempts + + def execute(self, context: Context) -> dict: 
+ self.log.info("Creating graph %s with import", self.graph_name) + + # Build the import options + import_options = { + "neptune-analytics:blank-node-handling": self.blank_node_handling, + "neptune-analytics:parquet-type": self.parquet_type, + } + + # Remove None values from import_options + import_options = {k: v for k, v in import_options.items() if v is not None} + + # Merge with user-provided import_options + if self.import_options: + import_options.update(self.import_options) + + create_params = { + "graphName": self.graph_name, + "vectorSearchConfiguration": self.vector_search_config, + "source": self.source, + "roleArn": self.role_arn, + **{ + k: v + for k, v in { + "format": self.format, + "minProvisionedMemory": self.min_provisioned_memory, + "maxProvisionedMemory": self.max_provisioned_memory, + "failOnError": self.fail_on_error, + "replicaCount": self.replica_count, + "publicConnectivity": self.public_connectivity, + "deletionProtection": self.deletion_protect, + "kmsKeyIdentifier": self.kms_key, + "tags": self.tags, + "importOptions": import_options if import_options else None, + }.items() + if v is not None + }, + } + + response = self.hook.conn.create_graph_using_import_task(**create_params) + + self.log.info("Graph %s import task in status %s", self.graph_name, response.get("status", "Unknown")) + self.graph_id = response.get("graphId", None) + + # TODO build extra link to console + + if self.deferrable: + self.log.info("Deferring until graph %s is available", self.graph_id) + self.defer( + trigger=NeptuneGraphAvailableTrigger( + aws_conn_id=self.aws_conn_id, + graph_id=self.graph_id, + waiter_delay=self.waiter_delay, + waiter_max_attempts=self.waiter_max_attempts, + ), + method_name="execute_complete", + ) + + if self.wait_for_completion: + self.log.info("Waiting until graph %s is available", self.graph_id) + self.hook.get_waiter("graph_available").wait( + graphIdentifier=self.graph_id, + WaiterConfig={"Delay": self.waiter_delay, "MaxAttempts": 
self.waiter_max_attempts}, + ) + + return {"graph_id": self.graph_id} + + def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> dict[str, Any]: + self.log.info("Trigger complete for graph %s", self.graph_id) + return {"graph_id": self.graph_id} + + +class NeptuneStartImportTaskOperator(AwsBaseOperator[NeptuneAnalyticsHook]): + """ + Starts a bulk data import task to load data into an empty Neptune graph. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:NeptuneStartImportTaskOperator` + + :param graph_identifier: Graph Id of target Neptune Graph + :param role_arn: IAM role ARN granting access to source data + :param source: URL identifying the source data location. + :param blank_node_handling: Method to handle blank nodes in dataset. + :param fail_on_error: If set to true, the task halts when an import error is encountered. If set to false, the task skips the data that caused the error and continues if possible. + :param format: Specifies the format of the Amazon S3 data to be imported. + :param import_options: Options on how to perform an import + :param parquet_type: Parquet type of import task + :param wait_for_completion: Whether to wait for the import task to complete. (default: True) + :param deferrable: If True, the operator will wait asynchronously for the import task to complete. + This implies waiting for completion. This mode requires aiobotocore module to be installed. + (default: False) + :param waiter_delay: Time in seconds to wait between status checks. + :param waiter_max_attempts: Maximum number of attempts to check for job completion. + :param aws_conn_id: The Airflow connection used for AWS credentials. + If this is ``None`` or empty then the default boto3 behaviour is used.
If + running Airflow in a distributed manner and aws_conn_id is None or + empty, then default boto3 configuration would be used (and must be + maintained on each worker node). + :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used. + + :param botocore_config: Configuration dictionary (key-values) for botocore client. See: + https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html + :return: dictionary with Neptune graph id + """ + + aws_hook_class = NeptuneAnalyticsHook + template_fields: Sequence[str] = aws_template_fields( + "graph_identifier", "role_arn", "source", "import_options" + ) + template_fields_renderers = { + "import_options": "json", + } + + def __init__( + self, + graph_identifier: str, + role_arn: str, + source: str, + blank_node_handling: str | None = "convertToIri", + fail_on_error: bool = True, + format: str | None = None, + import_options: dict | None = None, + parquet_type: str | None = "COLUMNAR", + wait_for_completion: bool = True, + waiter_delay: int = 30, + waiter_max_attempts: int = 60, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), + **kwargs, + ): + super().__init__(**kwargs) + self.graph_identifier = graph_identifier + self.role_arn = role_arn + self.source = source + self.blank_node_handling = blank_node_handling + self.fail_on_error = fail_on_error + self.format = format + self.import_options = import_options + self.parquet_type = parquet_type + self.wait_for_completion = wait_for_completion + self.waiter_delay = waiter_delay + self.waiter_max_attempts = waiter_max_attempts + self.deferrable = deferrable + + def execute(self, context: Context) -> dict: + self.log.info("Starting data import to graph %s", self.graph_identifier) + + create_params = { + "graphIdentifier": self.graph_identifier, + "roleArn": self.role_arn, + "source": self.source, + **{ + k: v + for k, v in { + "blankNodeHandling": self.blank_node_handling, + 
"failOnError": self.fail_on_error, + "format": self.format, + "importOptions": self.import_options, + "parquetType": self.parquet_type, + }.items() + if v is not None + }, + } + + response = self.hook.conn.start_import_task(**create_params) + + self.log.info("Import task %s started for graph %s", response.get("taskId"), self.graph_identifier) + task_id = response.get("taskId") + + if self.deferrable: + self.log.info("Deferring until import task %s completes", task_id) + self.defer( + trigger=NeptuneImportTaskCompleteTrigger( + task_id=task_id, + waiter_delay=self.waiter_delay, + waiter_max_attempts=self.waiter_max_attempts, + aws_conn_id=self.aws_conn_id, + ), + method_name="execute_complete", + ) + + if self.wait_for_completion: + self.log.info("Waiting for import task %s to complete", task_id) + self.hook.get_waiter("import_task_completed").wait( + taskIdentifier=task_id, + WaiterConfig={"Delay": self.waiter_delay, "MaxAttempts": self.waiter_max_attempts}, + ) + + return {"task_id": task_id, "graph_id": self.graph_identifier} + + def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> dict[str, Any]: + task_id = "" + if event: + task_id = event.get("task_id", "") + + return {"graph_id": self.graph_identifier, "task_id": task_id} diff --git a/providers/amazon/src/airflow/providers/amazon/aws/triggers/neptune_analytics.py b/providers/amazon/src/airflow/providers/amazon/aws/triggers/neptune_analytics.py index f8c0b716044b2..07a6ffe265d8c 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/triggers/neptune_analytics.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/triggers/neptune_analytics.py @@ -155,3 +155,88 @@ def hook(self) -> AwsGenericHook: verify=self.verify, config=self.botocore_config, ) + + +class NeptuneGraphDeletedTrigger(AwsBaseWaiterTrigger): + """ + Triggers when a Neptune Graph is deleted. 
+ + :param graph_id: Graph Id of the endpoint + :param waiter_delay: The amount of time in seconds to wait between attempts. + :param waiter_max_attempts: The maximum number of attempts to be made. + :param aws_conn_id: The Airflow connection used for AWS credentials. + :param region_name: AWS region name (example: us-east-1) + """ + + def __init__( + self, + *, + graph_id: str, + endpoint_id: str, + waiter_delay: int = 30, + waiter_max_attempts: int = 60, + **kwargs, + ) -> None: + super().__init__( + serialized_fields={"graph_id": graph_id}, + waiter_name="graph_deleted", + waiter_args={"graphIdentifier": graph_id}, + failure_message="Failed to delete Neptune graph", + status_message="Status of Neptune graph is", + status_queries=["status"], + return_key="graph_id", + return_value=graph_id, + waiter_delay=waiter_delay, + waiter_max_attempts=waiter_max_attempts, + **kwargs, + ) + + def hook(self) -> AwsGenericHook: + return NeptuneAnalyticsHook( + aws_conn_id=self.aws_conn_id, + region_name=self.region_name, + verify=self.verify, + config=self.botocore_config, + ) + + +class NeptuneImportTaskCompleteTrigger(AwsBaseWaiterTrigger): + """ + Triggers when a Neptune import task successfully completes. + + :param task_id: Import task id to monitor + :param waiter_delay: The amount of time in seconds to wait between attempts. + :param waiter_max_attempts: The maximum number of attempts to be made. + :param aws_conn_id: The Airflow connection used for AWS credentials. 
+ :param region_name: AWS region name (example: us-east-1) + """ + + def __init__( + self, + *, + task_id: str, + waiter_delay: int = 30, + waiter_max_attempts: int = 60, + **kwargs, + ) -> None: + super().__init__( + serialized_fields={"task_id": task_id}, + waiter_name="import_task_successful", + waiter_args={"taskIdentifier": task_id}, + failure_message="Import task failed", + status_message="Status of import task is", + status_queries=["status"], + return_key="task_id", + return_value=task_id, + waiter_delay=waiter_delay, + waiter_max_attempts=waiter_max_attempts, + **kwargs, + ) + + def hook(self) -> AwsGenericHook: + return NeptuneAnalyticsHook( + aws_conn_id=self.aws_conn_id, + region_name=self.region_name, + verify=self.verify, + config=self.botocore_config, + ) diff --git a/providers/amazon/tests/unit/amazon/aws/operators/test_neptune_analytics.py b/providers/amazon/tests/unit/amazon/aws/operators/test_neptune_analytics.py index 57471cf66889b..f779a25e94ce8 100644 --- a/providers/amazon/tests/unit/amazon/aws/operators/test_neptune_analytics.py +++ b/providers/amazon/tests/unit/amazon/aws/operators/test_neptune_analytics.py @@ -26,8 +26,11 @@ from airflow.providers.amazon.aws.hooks.neptune_analytics import NeptuneAnalyticsHook from airflow.providers.amazon.aws.operators.neptune_analytics import ( NeptuneCreateGraphOperator, + NeptuneCreateGraphWithImportOperator, NeptuneCreatePrivateGraphEndpointOperator, + NeptuneDeleteGraphOperator, NeptuneDeletePrivateGraphEndpointOperator, + NeptuneStartImportTaskOperator, ) GRAPH_NAME = "test_graph" @@ -36,6 +39,8 @@ SUBNET_IDS = ["subnet-1", "subnet-2"] SECURITY_GROUP_IDS = ["sg-1", "sg-2"] ENDPOINT_ID = "vpce-12345" +SOURCE_S3_URI = "s3://my-bucket/my-data/" +ROLE_ARN = "arn:aws:iam::123456789012:role/NeptuneImportRole" @pytest.fixture @@ -418,3 +423,516 @@ def test_execute_complete_success(self): operator.execute_complete(None, event) # Verify the method completes without error and logs the endpoint_id + + +class 
TestNeptuneDeleteGraphOperator: + @mock.patch.object(NeptuneAnalyticsHook, "conn") + def test_init_defaults(self, mock_conn): + mock_conn.delete_graph.return_value = { + "id": GRAPH_ID, + "name": GRAPH_NAME, + "status": "DELETING", + } + + operator = NeptuneDeleteGraphOperator( + task_id="test_task", + graph_id=GRAPH_ID, + skip_snapshot=True, + ) + + assert operator.graph_id == GRAPH_ID + assert operator.skip_snapshot is True + assert operator.wait_for_completion is True + assert operator.waiter_delay == 30 + assert operator.waiter_max_attempts == 60 + + operator.execute(None) + + mock_conn.delete_graph.assert_called_once_with( + graphIdentifier=GRAPH_ID, + skipSnapshot=True, + ) + + @mock.patch.object(NeptuneAnalyticsHook, "conn") + def test_init_custom_args(self, mock_conn): + mock_conn.delete_graph.return_value = { + "id": GRAPH_ID, + "name": GRAPH_NAME, + "status": "DELETING", + } + + operator = NeptuneDeleteGraphOperator( + task_id="test_task", + graph_id=GRAPH_ID, + skip_snapshot=False, + waiter_delay=60, + waiter_max_attempts=100, + ) + + assert operator.graph_id == GRAPH_ID + assert operator.skip_snapshot is False + assert operator.waiter_delay == 60 + assert operator.waiter_max_attempts == 100 + + operator.execute(None) + + mock_conn.delete_graph.assert_called_once_with( + graphIdentifier=GRAPH_ID, + skipSnapshot=False, + ) + + @mock.patch.object(NeptuneAnalyticsHook, "conn") + def test_delete_graph_no_wait(self, mock_conn): + mock_conn.delete_graph.return_value = { + "id": GRAPH_ID, + "name": GRAPH_NAME, + "status": "DELETING", + } + + operator = NeptuneDeleteGraphOperator( + task_id="test_task", + graph_id=GRAPH_ID, + skip_snapshot=True, + wait_for_completion=False, + ) + operator.execute(None) + + mock_conn.delete_graph.assert_called_once() + mock_conn.get_waiter.assert_not_called() + + @mock.patch.object(NeptuneAnalyticsHook, "conn") + def test_delete_graph_wait_for_completion(self, mock_conn): + mock_conn.delete_graph.return_value = { + "id": 
GRAPH_ID, + "name": GRAPH_NAME, + "status": "DELETING", + } + mock_waiter = mock.MagicMock() + mock_conn.get_waiter.return_value = mock_waiter + + operator = NeptuneDeleteGraphOperator( + task_id="test_task", + graph_id=GRAPH_ID, + skip_snapshot=True, + wait_for_completion=True, + ) + operator.execute(None) + + mock_conn.get_waiter.assert_called_once_with("graph_deleted") + mock_waiter.wait.assert_called_once_with( + graphIdentifier=GRAPH_ID, + WaiterConfig={"Delay": 30, "MaxAttempts": 60}, + ) + + @mock.patch.object(NeptuneAnalyticsHook, "conn") + def test_delete_graph_resource_not_found(self, mock_conn): + from botocore.exceptions import ClientError + + # Simulate ResourceNotFoundException + error_response = { + "Error": { + "Code": "ResourceNotFoundException", + "Message": "Graph not found", + }, + "ResponseMetadata": { + "HTTPStatusCode": 404, + }, + } + mock_conn.delete_graph.side_effect = ClientError(error_response, "delete_graph") + + operator = NeptuneDeleteGraphOperator( + task_id="test_task", + graph_id=GRAPH_ID, + skip_snapshot=True, + ) + + # Should not raise an exception, just log that graph not found + operator.execute(None) + + mock_conn.delete_graph.assert_called_once_with( + graphIdentifier=GRAPH_ID, + skipSnapshot=True, + ) + + @mock.patch.object(NeptuneAnalyticsHook, "conn") + def test_delete_graph_other_client_error(self, mock_conn): + from botocore.exceptions import ClientError + + from airflow.providers.common.compat.sdk import AirflowException + + # Simulate other ClientError + error_response = { + "Error": { + "Code": "ValidationException", + "Message": "Invalid parameter", + }, + "ResponseMetadata": { + "HTTPStatusCode": 400, + }, + } + mock_conn.delete_graph.side_effect = ClientError(error_response, "delete_graph") + + operator = NeptuneDeleteGraphOperator( + task_id="test_task", + graph_id=GRAPH_ID, + skip_snapshot=True, + ) + + # Should raise AirflowException for non-ResourceNotFoundException errors + with 
pytest.raises(AirflowException): + operator.execute(None) + + +class TestNeptuneCreateGraphWithImportOperator: + @mock.patch.object(NeptuneAnalyticsHook, "conn") + def test_init_defaults(self, mock_conn): + mock_conn.create_graph_using_import_task.return_value = { + "graphId": GRAPH_ID, + "status": "IMPORTING", + } + + operator = NeptuneCreateGraphWithImportOperator( + task_id="test_task", + graph_name=GRAPH_NAME, + vector_search_config={"dimension": 128}, + source=SOURCE_S3_URI, + role_arn=ROLE_ARN, + ) + + assert operator.graph_name == GRAPH_NAME + assert operator.vector_search_config == {"dimension": 128} + assert operator.source == SOURCE_S3_URI + assert operator.role_arn == ROLE_ARN + assert operator.blank_node_handling is None + assert operator.parquet_type is None + assert operator.format is None + assert operator.min_provisioned_memory is None + assert operator.max_provisioned_memory is None + assert operator.fail_on_error is None + assert operator.public_connectivity is None + assert operator.replica_count is None + assert operator.deletion_protect is None + assert operator.kms_key is None + assert operator.tags is None + assert operator.import_options is None + assert operator.wait_for_completion is True + assert operator.waiter_delay == 30 + assert operator.waiter_max_attempts == 60 + + operator.execute(None) + + mock_conn.create_graph_using_import_task.assert_called_once_with( + graphName=GRAPH_NAME, + vectorSearchConfiguration={"dimension": 128}, + source=SOURCE_S3_URI, + roleArn=ROLE_ARN, + ) + + @mock.patch.object(NeptuneAnalyticsHook, "conn") + def test_init_with_all_optional_params(self, mock_conn): + mock_conn.create_graph_using_import_task.return_value = { + "graphId": GRAPH_ID, + "status": "IMPORTING", + } + + operator = NeptuneCreateGraphWithImportOperator( + task_id="test_task", + graph_name=GRAPH_NAME, + vector_search_config={"dimension": 128}, + source=SOURCE_S3_URI, + role_arn=ROLE_ARN, + blank_node_handling="convertToIri", + 
parquet_type="COLUMNAR", + format="csv", + min_provisioned_memory=16, + max_provisioned_memory=32, + fail_on_error=True, + public_connectivity=True, + replica_count=2, + deletion_protection=True, + kms_key_id="test-kms-key", + tags={"env": "test"}, + import_options={"custom-option": "value"}, + waiter_delay=60, + waiter_max_attempts=100, + ) + + assert operator.blank_node_handling == "convertToIri" + assert operator.parquet_type == "COLUMNAR" + assert operator.format == "csv" + assert operator.min_provisioned_memory == 16 + assert operator.max_provisioned_memory == 32 + assert operator.fail_on_error is True + assert operator.public_connectivity is True + assert operator.replica_count == 2 + assert operator.deletion_protect is True + assert operator.kms_key == "test-kms-key" + assert operator.tags == {"env": "test"} + assert operator.import_options == {"custom-option": "value"} + assert operator.waiter_delay == 60 + assert operator.waiter_max_attempts == 100 + + operator.execute(None) + + # Verify the call includes all parameters + call_args = mock_conn.create_graph_using_import_task.call_args[1] + assert call_args["graphName"] == GRAPH_NAME + assert call_args["vectorSearchConfiguration"] == {"dimension": 128} + assert call_args["source"] == SOURCE_S3_URI + assert call_args["roleArn"] == ROLE_ARN + assert call_args["format"] == "csv" + assert call_args["minProvisionedMemory"] == 16 + assert call_args["maxProvisionedMemory"] == 32 + assert call_args["failOnError"] is True + assert call_args["replicaCount"] == 2 + assert call_args["publicConnectivity"] is True + assert call_args["deletionProtection"] is True + assert call_args["kmsKeyIdentifier"] == "test-kms-key" + assert call_args["tags"] == {"env": "test"} + # Check import options were merged + assert "neptune-analytics:blank-node-handling" in call_args["importOptions"] + assert call_args["importOptions"]["neptune-analytics:blank-node-handling"] == "convertToIri" + assert "neptune-analytics:parquet-type" in 
call_args["importOptions"] + assert call_args["importOptions"]["neptune-analytics:parquet-type"] == "COLUMNAR" + assert call_args["importOptions"]["custom-option"] == "value" + + @mock.patch.object(NeptuneAnalyticsHook, "conn") + def test_import_options_handling(self, mock_conn): + mock_conn.create_graph_using_import_task.return_value = { + "graphId": GRAPH_ID, + "status": "IMPORTING", + } + + operator = NeptuneCreateGraphWithImportOperator( + task_id="test_task", + graph_name=GRAPH_NAME, + vector_search_config={"dimension": 128}, + source=SOURCE_S3_URI, + role_arn=ROLE_ARN, + blank_node_handling="convertToIri", + import_options={"another-option": "test"}, + ) + + operator.execute(None) + + call_args = mock_conn.create_graph_using_import_task.call_args[1] + # Verify import options were properly merged + assert call_args["importOptions"]["neptune-analytics:blank-node-handling"] == "convertToIri" + assert call_args["importOptions"]["another-option"] == "test" + + @mock.patch.object(NeptuneAnalyticsHook, "conn") + @mock.patch.object(NeptuneAnalyticsHook, "get_waiter") + def test_create_graph_with_import_no_wait(self, mock_hook_get_waiter, mock_conn): + mock_conn.create_graph_using_import_task.return_value = { + "graphId": GRAPH_ID, + "status": "IMPORTING", + } + + operator = NeptuneCreateGraphWithImportOperator( + task_id="test_task", + graph_name=GRAPH_NAME, + vector_search_config={"dimension": 128}, + source=SOURCE_S3_URI, + role_arn=ROLE_ARN, + wait_for_completion=False, + ) + result = operator.execute(None) + + mock_hook_get_waiter.assert_not_called() + assert result == {"graph_id": GRAPH_ID} + + @mock.patch.object(NeptuneAnalyticsHook, "conn") + @mock.patch.object(NeptuneAnalyticsHook, "get_waiter") + def test_create_graph_with_import_wait_for_completion(self, mock_hook_get_waiter, mock_conn): + mock_conn.create_graph_using_import_task.return_value = { + "graphId": GRAPH_ID, + "status": "IMPORTING", + } + + operator = NeptuneCreateGraphWithImportOperator( + 
task_id="test_task", + graph_name=GRAPH_NAME, + vector_search_config={"dimension": 128}, + source=SOURCE_S3_URI, + role_arn=ROLE_ARN, + wait_for_completion=True, + ) + result = operator.execute(None) + + mock_hook_get_waiter.assert_called_once_with("graph_available") + assert result == {"graph_id": GRAPH_ID} + + @mock.patch.object(NeptuneAnalyticsHook, "conn") + def test_import_options_none_values_filtered(self, mock_conn): + mock_conn.create_graph_using_import_task.return_value = { + "graphId": GRAPH_ID, + "status": "IMPORTING", + } + + # Test that None values in blank_node_handling and parquet_type are filtered out + operator = NeptuneCreateGraphWithImportOperator( + task_id="test_task", + graph_name=GRAPH_NAME, + vector_search_config={"dimension": 128}, + source=SOURCE_S3_URI, + role_arn=ROLE_ARN, + blank_node_handling=None, + parquet_type=None, + ) + + operator.execute(None) + + call_args = mock_conn.create_graph_using_import_task.call_args[1] + # importOptions should not be in call_args if all values are None + assert "importOptions" not in call_args + + +TASK_ID = "import-task-id-12345" + + +class TestNeptuneStartImportTaskOperator: + @mock.patch.object(NeptuneAnalyticsHook, "conn") + def test_init_defaults(self, mock_conn): + mock_conn.start_import_task.return_value = { + "taskId": TASK_ID, + "graphId": GRAPH_ID, + "status": "IMPORTING", + } + + operator = NeptuneStartImportTaskOperator( + task_id="test_task", + graph_identifier=GRAPH_ID, + role_arn=ROLE_ARN, + source=SOURCE_S3_URI, + ) + + assert operator.graph_identifier == GRAPH_ID + assert operator.role_arn == ROLE_ARN + assert operator.source == SOURCE_S3_URI + assert operator.blank_node_handling == "convertToIri" + assert operator.fail_on_error is True + assert operator.format is None + assert operator.import_options is None + assert operator.parquet_type == "COLUMNAR" + assert operator.wait_for_completion is True + assert operator.waiter_delay == 30 + assert operator.waiter_max_attempts == 60 + + 
operator.execute(None) + + mock_conn.start_import_task.assert_called_once_with( + graphIdentifier=GRAPH_ID, + roleArn=ROLE_ARN, + source=SOURCE_S3_URI, + blankNodeHandling="convertToIri", + failOnError=True, + parquetType="COLUMNAR", + ) + + @mock.patch.object(NeptuneAnalyticsHook, "conn") + def test_init_custom_args(self, mock_conn): + mock_conn.start_import_task.return_value = { + "taskId": TASK_ID, + "graphId": GRAPH_ID, + "status": "IMPORTING", + } + + operator = NeptuneStartImportTaskOperator( + task_id="test_task", + graph_identifier=GRAPH_ID, + role_arn=ROLE_ARN, + source=SOURCE_S3_URI, + blank_node_handling=None, + fail_on_error=False, + format="CSV", + import_options={"neptune.csv.allowEmptyStrings": True}, + parquet_type=None, + waiter_delay=60, + waiter_max_attempts=100, + ) + + assert operator.blank_node_handling is None + assert operator.fail_on_error is False + assert operator.format == "CSV" + assert operator.import_options == {"neptune.csv.allowEmptyStrings": True} + assert operator.parquet_type is None + assert operator.waiter_delay == 60 + assert operator.waiter_max_attempts == 100 + + operator.execute(None) + + mock_conn.start_import_task.assert_called_once_with( + graphIdentifier=GRAPH_ID, + roleArn=ROLE_ARN, + source=SOURCE_S3_URI, + failOnError=False, + format="CSV", + importOptions={"neptune.csv.allowEmptyStrings": True}, + ) + + @mock.patch.object(NeptuneAnalyticsHook, "conn") + @mock.patch.object(NeptuneAnalyticsHook, "get_waiter") + def test_start_import_no_wait(self, mock_get_waiter, mock_conn): + mock_conn.start_import_task.return_value = { + "taskId": TASK_ID, + "graphId": GRAPH_ID, + "status": "IMPORTING", + } + + operator = NeptuneStartImportTaskOperator( + task_id="test_task", + graph_identifier=GRAPH_ID, + role_arn=ROLE_ARN, + source=SOURCE_S3_URI, + wait_for_completion=False, + ) + result = operator.execute(None) + + mock_get_waiter.assert_not_called() + assert result == {"task_id": TASK_ID, "graph_id": GRAPH_ID} + + 
@mock.patch.object(NeptuneAnalyticsHook, "conn") + @mock.patch.object(NeptuneAnalyticsHook, "get_waiter") + def test_start_import_wait_for_completion(self, mock_get_waiter, mock_conn): + mock_conn.start_import_task.return_value = { + "taskId": TASK_ID, + "graphId": GRAPH_ID, + "status": "IMPORTING", + } + + operator = NeptuneStartImportTaskOperator( + task_id="test_task", + graph_identifier=GRAPH_ID, + role_arn=ROLE_ARN, + source=SOURCE_S3_URI, + wait_for_completion=True, + ) + result = operator.execute(None) + + mock_get_waiter.assert_called_once_with("import_task_completed") + assert result == {"task_id": TASK_ID, "graph_id": GRAPH_ID} + + def test_execute_complete_success(self): + operator = NeptuneStartImportTaskOperator( + task_id="test_task", + graph_identifier=GRAPH_ID, + role_arn=ROLE_ARN, + source=SOURCE_S3_URI, + ) + + event = {"status": "success", "task_id": TASK_ID} + result = operator.execute_complete(None, event) + + assert result == {"graph_id": GRAPH_ID, "task_id": TASK_ID} + + def test_execute_complete_no_event(self): + operator = NeptuneStartImportTaskOperator( + task_id="test_task", + graph_identifier=GRAPH_ID, + role_arn=ROLE_ARN, + source=SOURCE_S3_URI, + ) + + result = operator.execute_complete(None, None) + + assert result == {"graph_id": GRAPH_ID, "task_id": ""} diff --git a/providers/amazon/tests/unit/amazon/aws/triggers/test_neptune_analytics.py b/providers/amazon/tests/unit/amazon/aws/triggers/test_neptune_analytics.py index 1637b0a15e8da..9c6a215583ed8 100644 --- a/providers/amazon/tests/unit/amazon/aws/triggers/test_neptune_analytics.py +++ b/providers/amazon/tests/unit/amazon/aws/triggers/test_neptune_analytics.py @@ -26,6 +26,7 @@ NeptuneGraphAvailableTrigger, NeptuneGraphPrivateEndpointAvailableTrigger, NeptuneGraphPrivateEndpointDeletedTrigger, + NeptuneImportTaskCompleteTrigger, ) from airflow.providers.common.compat.sdk import AirflowException from airflow.triggers.base import TriggerEvent @@ -33,6 +34,7 @@ GRAPH_ID = "test-graph" 
VPC_ID = "test-vpc" ENDPOINT_ID = "test-endpoint" +TASK_ID = "test-task-id" class TestNeptuneGraphAvailableTrigger: @@ -191,3 +193,49 @@ async def test_run_failure(self, mock_async_conn, mock_get_waiter): with pytest.raises(AirflowException): await trigger.run().asend(None) + + +class TestNeptuneImportTaskCompleteTrigger: + def test_serialization(self): + """ + Asserts that the NeptuneImportTaskCompleteTrigger correctly serializes its arguments + and classpath. + """ + trigger = NeptuneImportTaskCompleteTrigger(task_id=TASK_ID) + classpath, kwargs = trigger.serialize() + assert ( + classpath + == "airflow.providers.amazon.aws.triggers.neptune_analytics.NeptuneImportTaskCompleteTrigger" + ) + assert "task_id" in kwargs + assert kwargs["task_id"] == TASK_ID + + @pytest.mark.asyncio + @mock.patch("airflow.providers.amazon.aws.hooks.neptune_analytics.NeptuneAnalyticsHook.get_waiter") + @mock.patch("airflow.providers.amazon.aws.hooks.neptune_analytics.NeptuneAnalyticsHook.get_async_conn") + async def test_run_success(self, mock_async_conn, mock_get_waiter): + mock_async_conn.return_value.__aenter__.return_value = "COMPLETED" + mock_get_waiter().wait = AsyncMock() + trigger = NeptuneImportTaskCompleteTrigger(task_id=TASK_ID) + generator = trigger.run() + resp = await generator.asend(None) + + assert resp == TriggerEvent({"status": "success", "task_id": TASK_ID}) + assert mock_get_waiter().wait.call_count == 1 + + @pytest.mark.asyncio + @mock.patch("airflow.providers.amazon.aws.hooks.neptune_analytics.NeptuneAnalyticsHook.get_waiter") + @mock.patch("airflow.providers.amazon.aws.hooks.neptune_analytics.NeptuneAnalyticsHook.get_async_conn") + async def test_run_failure(self, mock_async_conn, mock_get_waiter): + wait_mock = AsyncMock() + wait_mock.side_effect = WaiterError( + name="import_task_successful", + reason='Waiter encountered a terminal failure state: For expression "status" we matched expected path: "FAILED"', + last_response={"status": "FAILED", "taskIdentifier": 
TASK_ID}, + ) + mock_get_waiter.return_value.wait = wait_mock + + trigger = NeptuneImportTaskCompleteTrigger(task_id=TASK_ID) + + with pytest.raises(AirflowException): + await trigger.run().asend(None) From d86498d1384fb3228e08f7516224972ff1bf99dd Mon Sep 17 00:00:00 2001 From: mse139 Date: Mon, 16 Mar 2026 06:27:00 -0400 Subject: [PATCH 09/28] Fixed prek findings --- .../providers/amazon/aws/operators/neptune_analytics.py | 4 ++-- .../providers/amazon/aws/triggers/neptune_analytics.py | 3 +-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/neptune_analytics.py b/providers/amazon/src/airflow/providers/amazon/aws/operators/neptune_analytics.py index a62e19175d34f..26c3603a98dbe 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/operators/neptune_analytics.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/neptune_analytics.py @@ -434,7 +434,7 @@ def __init__( self.waiter_delay = waiter_delay self.waiter_max_attempts = waiter_max_attempts - def execute(self, context: Context) -> dict: + def execute(self, context: Context): self.log.info("Deleting graph %s", self.graph_id) try: @@ -465,7 +465,7 @@ def execute(self, context: Context) -> dict: WaiterConfig={"Delay": self.waiter_delay, "MaxAttempts": self.waiter_max_attempts}, ) - def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> dict[str, Any]: + def execute_complete(self, context: Context, event: dict[str, Any] | None = None): if event: graph_id = event.get("graph_id", "Unknown") diff --git a/providers/amazon/src/airflow/providers/amazon/aws/triggers/neptune_analytics.py b/providers/amazon/src/airflow/providers/amazon/aws/triggers/neptune_analytics.py index 07a6ffe265d8c..cf286ed9501cc 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/triggers/neptune_analytics.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/triggers/neptune_analytics.py @@ -161,7 +161,7 @@ 
class NeptuneGraphDeletedTrigger(AwsBaseWaiterTrigger): """ Triggers when a Neptune Graph is deleted. - :param graph_id: Graph Id of the endpoint + :param graph_id: Graph Id to be deleted :param waiter_delay: The amount of time in seconds to wait between attempts. :param waiter_max_attempts: The maximum number of attempts to be made. :param aws_conn_id: The Airflow connection used for AWS credentials. @@ -172,7 +172,6 @@ def __init__( self, *, graph_id: str, - endpoint_id: str, waiter_delay: int = 30, waiter_max_attempts: int = 60, **kwargs, From 6b5c88fbb049af7133e4dd98eca5e98f226ca24a Mon Sep 17 00:00:00 2001 From: mse139 Date: Mon, 16 Mar 2026 10:39:39 -0400 Subject: [PATCH 10/28] Added NeptuneCancelImportTaskOperator --- .../amazon/aws/operators/neptune_analytics.py | 85 +++++++++++++++ .../amazon/aws/triggers/neptune_analytics.py | 42 +++++++ .../aws/operators/test_neptune_analytics.py | 103 ++++++++++++++++++ .../aws/triggers/test_neptune_analytics.py | 47 ++++++++ 4 files changed, 277 insertions(+) diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/neptune_analytics.py b/providers/amazon/src/airflow/providers/amazon/aws/operators/neptune_analytics.py index 26c3603a98dbe..8b9cc07a1c007 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/operators/neptune_analytics.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/neptune_analytics.py @@ -30,6 +30,7 @@ NeptuneGraphDeletedTrigger, NeptuneGraphPrivateEndpointAvailableTrigger, NeptuneGraphPrivateEndpointDeletedTrigger, + NeptuneImportTaskCancelledTrigger, NeptuneImportTaskCompleteTrigger, ) from airflow.providers.amazon.aws.utils.mixins import aws_template_fields @@ -762,3 +763,87 @@ def execute_complete(self, context: Context, event: dict[str, Any] | None = None task_id = event.get("task_id", "") return {"graph_id": self.graph_identifier, "task_id": task_id} + + +class NeptuneCancelImportTaskOperator(AwsBaseOperator[NeptuneAnalyticsHook]): + """ + Cancels an 
active Neptune Graph import task. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:NeptuneCancelImportTaskOperator` + + :param task_identifier: Neptune Graph import task id to cancel. + :param wait_for_completion: Whether to wait for the import task to be cancelled. (default: True) + :param deferrable: If True, the operator will wait asynchronously for the import task to be cancelled. + This implies waiting for completion. This mode requires aiobotocore module to be installed. + (default: False) + :param waiter_delay: Time in seconds to wait between status checks. + :param waiter_max_attempts: Maximum number of attempts to check for job completion. + :param aws_conn_id: The Airflow connection used for AWS credentials. + If this is ``None`` or empty then the default boto3 behaviour is used. If + running Airflow in a distributed manner and aws_conn_id is None or + empty, then default boto3 configuration would be used (and must be + maintained on each worker node). + :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used. + + :param botocore_config: Configuration dictionary (key-values) for botocore client. 
See: + https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html + :return: dictionary with Neptune graph id + + """ + + aws_hook_class = NeptuneAnalyticsHook + template_fields: Sequence[str] = aws_template_fields("task_identifier") + + def __init__( + self, + task_identifier: str, + wait_for_completion: bool = True, + waiter_delay: int = 30, + waiter_max_attempts: int = 60, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), + **kwargs, + ): + super().__init__(**kwargs) + self.import_task_id = task_identifier + self.wait_for_completion = wait_for_completion + self.waiter_delay = waiter_delay + self.waiter_max_attempts = waiter_max_attempts + self.deferrable = deferrable + + def execute(self, context: Context) -> dict: + self.log.info("Cancelling import task %s", self.import_task_id) + + response = self.hook.conn.cancel_import_task(taskIdentifier=self.import_task_id) + + self.log.info("Import task %s status is %s", self.import_task_id, response.get("status", "Unknown")) + + if self.deferrable: + self.log.info("Deferring until import task %s is cancelled", self.import_task_id) + self.defer( + trigger=NeptuneImportTaskCancelledTrigger( + task_identifier=self.import_task_id, + waiter_delay=self.waiter_delay, + waiter_max_attempts=self.waiter_max_attempts, + aws_conn_id=self.aws_conn_id, + ), + method_name="execute_complete", + ) + + if self.wait_for_completion: + self.log.info("Waiting for import task %s to be cancelled", self.import_task_id) + self.hook.get_waiter("import_task_cancelled").wait( + taskIdentifier=self.import_task_id, + WaiterConfig={"Delay": self.waiter_delay, "MaxAttempts": self.waiter_max_attempts}, + ) + + return {"task_identifier": self.import_task_id} + + def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> dict[str, Any]: + task_id = "" + if event: + task_id = event.get("task_identifier", "") + self.log.info("Import task %s cancelled", task_id) + + return 
{"task_identifier": task_id} diff --git a/providers/amazon/src/airflow/providers/amazon/aws/triggers/neptune_analytics.py b/providers/amazon/src/airflow/providers/amazon/aws/triggers/neptune_analytics.py index cf286ed9501cc..959ba5eee1987 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/triggers/neptune_analytics.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/triggers/neptune_analytics.py @@ -239,3 +239,45 @@ def hook(self) -> AwsGenericHook: verify=self.verify, config=self.botocore_config, ) + + +class NeptuneImportTaskCancelledTrigger(AwsBaseWaiterTrigger): + """ + Triggers when a Neptune import task is successfully cancelled. + + :param task_id: Import task id to monitor. + :param waiter_delay: The amount of time in seconds to wait between attempts. + :param waiter_max_attempts: The maximum number of attempts to be made. + :param aws_conn_id: The Airflow connection used for AWS credentials. + :param region_name: AWS region name (example: us-east-1) + """ + + def __init__( + self, + *, + task_identifier: str, + waiter_delay: int = 30, + waiter_max_attempts: int = 60, + **kwargs, + ) -> None: + super().__init__( + serialized_fields={"task_identifier": task_identifier}, + waiter_name="import_task_cancelled", + waiter_args={"taskIdentifier": task_identifier}, + failure_message="Import task cancellation failed", + status_message="Status of import task is", + status_queries=["status"], + return_key="task_identifier", + return_value=task_identifier, + waiter_delay=waiter_delay, + waiter_max_attempts=waiter_max_attempts, + **kwargs, + ) + + def hook(self) -> AwsGenericHook: + return NeptuneAnalyticsHook( + aws_conn_id=self.aws_conn_id, + region_name=self.region_name, + verify=self.verify, + config=self.botocore_config, + ) diff --git a/providers/amazon/tests/unit/amazon/aws/operators/test_neptune_analytics.py b/providers/amazon/tests/unit/amazon/aws/operators/test_neptune_analytics.py index f779a25e94ce8..0696fe90684e7 100644 --- 
a/providers/amazon/tests/unit/amazon/aws/operators/test_neptune_analytics.py +++ b/providers/amazon/tests/unit/amazon/aws/operators/test_neptune_analytics.py @@ -25,6 +25,7 @@ from airflow.providers.amazon.aws.hooks.neptune_analytics import NeptuneAnalyticsHook from airflow.providers.amazon.aws.operators.neptune_analytics import ( + NeptuneCancelImportTaskOperator, NeptuneCreateGraphOperator, NeptuneCreateGraphWithImportOperator, NeptuneCreatePrivateGraphEndpointOperator, @@ -936,3 +937,105 @@ def test_execute_complete_no_event(self): result = operator.execute_complete(None, None) assert result == {"graph_id": GRAPH_ID, "task_id": ""} + + +class TestNeptuneCancelImportTaskOperator: + @mock.patch.object(NeptuneAnalyticsHook, "conn") + def test_init_defaults(self, mock_conn): + mock_conn.cancel_import_task.return_value = { + "taskId": TASK_ID, + "graphId": GRAPH_ID, + "status": "CANCELLING", + } + + operator = NeptuneCancelImportTaskOperator( + task_id="test_task", + task_identifier=TASK_ID, + ) + + assert operator.import_task_id == TASK_ID + assert operator.wait_for_completion is True + assert operator.waiter_delay == 30 + assert operator.waiter_max_attempts == 60 + + operator.execute(None) + + mock_conn.cancel_import_task.assert_called_once_with(taskIdentifier=TASK_ID) + + @mock.patch.object(NeptuneAnalyticsHook, "conn") + def test_init_custom_args(self, mock_conn): + mock_conn.cancel_import_task.return_value = { + "taskId": TASK_ID, + "graphId": GRAPH_ID, + "status": "CANCELLING", + } + + operator = NeptuneCancelImportTaskOperator( + task_id="test_task", + task_identifier=TASK_ID, + waiter_delay=60, + waiter_max_attempts=100, + ) + + assert operator.import_task_id == TASK_ID + assert operator.waiter_delay == 60 + assert operator.waiter_max_attempts == 100 + + @mock.patch.object(NeptuneAnalyticsHook, "conn") + @mock.patch.object(NeptuneAnalyticsHook, "get_waiter") + def test_cancel_no_wait(self, mock_get_waiter, mock_conn): + 
mock_conn.cancel_import_task.return_value = { + "taskId": TASK_ID, + "graphId": GRAPH_ID, + "status": "CANCELLING", + } + + operator = NeptuneCancelImportTaskOperator( + task_id="test_task", + task_identifier=TASK_ID, + wait_for_completion=False, + ) + result = operator.execute(None) + + mock_get_waiter.assert_not_called() + assert result == {"task_identifier": TASK_ID} + + @mock.patch.object(NeptuneAnalyticsHook, "conn") + @mock.patch.object(NeptuneAnalyticsHook, "get_waiter") + def test_cancel_wait_for_completion(self, mock_get_waiter, mock_conn): + mock_conn.cancel_import_task.return_value = { + "taskId": TASK_ID, + "graphId": GRAPH_ID, + "status": "CANCELLING", + } + + operator = NeptuneCancelImportTaskOperator( + task_id="test_task", + task_identifier=TASK_ID, + wait_for_completion=True, + ) + result = operator.execute(None) + + mock_get_waiter.assert_called_once_with("import_task_cancelled") + assert result == {"task_identifier": TASK_ID} + + def test_execute_complete_success(self): + operator = NeptuneCancelImportTaskOperator( + task_id="test_task", + task_identifier=TASK_ID, + ) + + event = {"status": "success", "task_identifier": TASK_ID} + result = operator.execute_complete(None, event) + + assert result == {"task_identifier": TASK_ID} + + def test_execute_complete_no_event(self): + operator = NeptuneCancelImportTaskOperator( + task_id="test_task", + task_identifier=TASK_ID, + ) + + result = operator.execute_complete(None, None) + + assert result == {"task_identifier": ""} diff --git a/providers/amazon/tests/unit/amazon/aws/triggers/test_neptune_analytics.py b/providers/amazon/tests/unit/amazon/aws/triggers/test_neptune_analytics.py index 9c6a215583ed8..637dd23453bc3 100644 --- a/providers/amazon/tests/unit/amazon/aws/triggers/test_neptune_analytics.py +++ b/providers/amazon/tests/unit/amazon/aws/triggers/test_neptune_analytics.py @@ -26,6 +26,7 @@ NeptuneGraphAvailableTrigger, NeptuneGraphPrivateEndpointAvailableTrigger, 
NeptuneGraphPrivateEndpointDeletedTrigger, + NeptuneImportTaskCancelledTrigger, NeptuneImportTaskCompleteTrigger, ) from airflow.providers.common.compat.sdk import AirflowException @@ -239,3 +240,49 @@ async def test_run_failure(self, mock_async_conn, mock_get_waiter): with pytest.raises(AirflowException): await trigger.run().asend(None) + + +class TestNeptuneImportTaskCancelledTrigger: + def test_serialization(self): + """ + Asserts that the NeptuneImportTaskCancelledTrigger correctly serializes its arguments + and classpath. + """ + trigger = NeptuneImportTaskCancelledTrigger(task_identifier=TASK_ID) + classpath, kwargs = trigger.serialize() + assert ( + classpath + == "airflow.providers.amazon.aws.triggers.neptune_analytics.NeptuneImportTaskCancelledTrigger" + ) + assert "task_identifier" in kwargs + assert kwargs["task_identifier"] == TASK_ID + + @pytest.mark.asyncio + @mock.patch("airflow.providers.amazon.aws.hooks.neptune_analytics.NeptuneAnalyticsHook.get_waiter") + @mock.patch("airflow.providers.amazon.aws.hooks.neptune_analytics.NeptuneAnalyticsHook.get_async_conn") + async def test_run_success(self, mock_async_conn, mock_get_waiter): + mock_async_conn.return_value.__aenter__.return_value = "CANCELLED" + mock_get_waiter().wait = AsyncMock() + trigger = NeptuneImportTaskCancelledTrigger(task_identifier=TASK_ID) + generator = trigger.run() + resp = await generator.asend(None) + + assert resp == TriggerEvent({"status": "success", "task_identifier": TASK_ID}) + assert mock_get_waiter().wait.call_count == 1 + + @pytest.mark.asyncio + @mock.patch("airflow.providers.amazon.aws.hooks.neptune_analytics.NeptuneAnalyticsHook.get_waiter") + @mock.patch("airflow.providers.amazon.aws.hooks.neptune_analytics.NeptuneAnalyticsHook.get_async_conn") + async def test_run_failure(self, mock_async_conn, mock_get_waiter): + wait_mock = AsyncMock() + wait_mock.side_effect = WaiterError( + name="import_task_cancelled", + reason='Waiter encountered a terminal failure state: For 
expression "status" we matched expected path: "FAILED"', + last_response={"status": "FAILED", "taskIdentifier": TASK_ID}, + ) + mock_get_waiter.return_value.wait = wait_mock + + trigger = NeptuneImportTaskCancelledTrigger(task_identifier=TASK_ID) + + with pytest.raises(AirflowException): + await trigger.run().asend(None) From 398b4d398ab0d25f7393f2527053227b7bbdf04f Mon Sep 17 00:00:00 2001 From: Mike Ellis Date: Mon, 23 Mar 2026 15:56:46 +0000 Subject: [PATCH 11/28] Added system tests and fixed errors found during system testing --- .../amazon/aws/operators/neptune_analytics.py | 111 +++--- .../amazon/aws/triggers/neptune_analytics.py | 7 +- .../amazon/aws/example_neptune_analytics.py | 319 ++++++++++++++++++ .../aws/operators/test_neptune_analytics.py | 175 ++++++++-- 4 files changed, 540 insertions(+), 72 deletions(-) create mode 100644 providers/amazon/tests/system/amazon/aws/example_neptune_analytics.py diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/neptune_analytics.py b/providers/amazon/src/airflow/providers/amazon/aws/operators/neptune_analytics.py index 8b9cc07a1c007..757453b54ae38 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/operators/neptune_analytics.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/neptune_analytics.py @@ -214,7 +214,7 @@ def __init__( **kwargs, ): super().__init__(**kwargs) - self.graph_id = graph_identifier + self.graph_identifier = graph_identifier self.vpc_id = vpc_id self.subnet_ids = subnet_ids self.vpc_security_group_ids = vpc_security_group_ids @@ -224,10 +224,10 @@ def __init__( self.waiter_max_attempts = waiter_max_attempts def execute(self, context: Context) -> dict: - self.log.info("Creating private endpoint for graph %s", self.graph_id) + self.log.info("Creating private endpoint for graph %s", self.graph_identifier) create_params = { - "graphIdentifier": self.graph_id, + "graphIdentifier": self.graph_identifier, **{ k: v for k, v in { @@ -243,50 +243,52 @@ def 
execute(self, context: Context) -> dict: result = self.hook.conn.create_private_graph_endpoint(**create_params) status = result.get("status", "Unknown") - endpoint_id = result.get("vpcEndpointId", "Unknown") - self.log.info("Status of endpoint %s: %s", endpoint_id, status) + self.log.info("Status of endpoint: %s", status) if status in ["FAILED"]: - raise AirflowException(f"Private endpoint failed to create for graph {self.graph_id}") + raise AirflowException(f"Private endpoint failed to create for graph {self.graph_identifier}") - # if VPC not provided, use the one that is returned. Required for the waiter + # if VPC not provided, use the one that is returned, which is the default VPC. Required for the waiter self.vpc_id = result.get("vpcId", self.vpc_id) # TODO extra link to console if self.deferrable: - self.log.info("Deferring until endpoint %s is available", endpoint_id) + self.log.info("Deferring until endpoint is available") self.defer( trigger=NeptuneGraphPrivateEndpointAvailableTrigger( aws_conn_id=self.aws_conn_id, - graph_id=self.graph_id, + graph_id=self.graph_identifier, vpc_id=self.vpc_id, - endpoint_id=endpoint_id, waiter_delay=self.waiter_delay, waiter_max_attempts=self.waiter_max_attempts, ), method_name="execute_complete", ) - # TODO add test if self.wait_for_completion: - self.log.info("Waiting until endpoint %s is available", endpoint_id) + self.log.info("Waiting until endpoint is available") self.hook.get_waiter("private_graph_endpoint_available").wait( - graphIdentifier=self.graph_id, + graphIdentifier=self.graph_identifier, vpcId=self.vpc_id, WaiterConfig={"Delay": self.waiter_delay, "MaxAttempts": self.waiter_max_attempts}, ) - return {"vpc_endpoint_id": endpoint_id, "graph_id": self.graph_id, "vpc_id": self.vpc_id} + endpoint_id = self._get_graph_endpoint_id() + return {"vpc_endpoint_id": endpoint_id, "graph_id": self.graph_identifier, "vpc_id": self.vpc_id} - def execute_complete(self, context: Context, event: dict[str, Any] | None = None) 
-> dict[str, Any]: - vpc_endpoint_id = "" + def _get_graph_endpoint_id(self): + """Return the vpc endpoint id for this graph.""" + result = self.hook.conn.get_private_graph_endpoint( + graphIdentifier=self.graph_identifier, vpcId=self.vpc_id + ) + return result.get("vpcEndpointId") - if event and event.get("status") == "success": - vpc_endpoint_id = event.get("endpoint_id", "") + def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> dict[str, Any]: + vpc_endpoint_id = self._get_graph_endpoint_id() - return {"vpc_endpoint_id": vpc_endpoint_id, "graph_id": self.graph_id, "vpc_id": self.vpc_id} + return {"vpc_endpoint_id": vpc_endpoint_id, "graph_id": self.graph_identifier, "vpc_id": self.vpc_id} class NeptuneDeletePrivateGraphEndpointOperator(AwsBaseOperator[NeptuneAnalyticsHook]): @@ -334,7 +336,7 @@ def __init__( **kwargs, ): super().__init__(**kwargs) - self.graph_id = graph_identifier + self.graph_identifier = graph_identifier self.vpc_id = vpc_id self.wait_for_completion = wait_for_completion self.deferrable = deferrable @@ -342,10 +344,10 @@ def __init__( self.waiter_max_attempts = waiter_max_attempts def execute(self, context: Context) -> None: - self.log.info("Deleting private endpoint for graph %s", self.graph_id) + self.log.info("Deleting private endpoint for graph %s", self.graph_identifier) result = self.hook.conn.delete_private_graph_endpoint( - graphIdentifier=self.graph_id, vpcId=self.vpc_id + graphIdentifier=self.graph_identifier, vpcId=self.vpc_id ) status = result.get("status") @@ -359,7 +361,7 @@ def execute(self, context: Context) -> None: self.defer( trigger=NeptuneGraphPrivateEndpointDeletedTrigger( aws_conn_id=self.aws_conn_id, - graph_id=self.graph_id, + graph_id=self.graph_identifier, vpc_id=self.vpc_id, endpoint_id=endpoint_id, waiter_delay=self.waiter_delay, @@ -370,10 +372,11 @@ def execute(self, context: Context) -> None: if self.wait_for_completion: self.log.info("Waiting until endpoint %s is deleted", 
endpoint_id) self.hook.get_waiter("private_graph_endpoint_deleted").wait( - graphIdentifier=self.graph_id, + graphIdentifier=self.graph_identifier, vpcId=self.vpc_id, WaiterConfig={"Delay": self.waiter_delay, "MaxAttempts": self.waiter_max_attempts}, ) + self.log.info("Endpoint %s deleted", endpoint_id) def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None: vpc_endpoint_id = "" @@ -519,7 +522,13 @@ class NeptuneCreateGraphWithImportOperator(AwsBaseOperator[NeptuneAnalyticsHook] """ aws_hook_class = NeptuneAnalyticsHook - template_fields: Sequence[str] = aws_template_fields() + template_fields: Sequence[str] = aws_template_fields( + "graph_name", "vector_search_config", "source", "role_arn", "kms_key" + ) + + template_fields_renderers = { + "vector_search_config": "json", + } def __init__( self, @@ -610,9 +619,10 @@ def execute(self, context: Context) -> dict: self.log.info("Graph %s import task in status %s", self.graph_name, response.get("status", "Unknown")) self.graph_id = response.get("graphId", None) + import_task_id = response.get("taskId") # TODO build extra link to console - + # TODO - second defer for task completion. 
if self.deferrable: self.log.info("Deferring until graph %s is available", self.graph_id) self.defer( @@ -622,7 +632,8 @@ def execute(self, context: Context) -> dict: waiter_delay=self.waiter_delay, waiter_max_attempts=self.waiter_max_attempts, ), - method_name="execute_complete", + method_name="defer_wait_for_task", + kwargs={"import_task_id": import_task_id}, ) if self.wait_for_completion: @@ -631,11 +642,33 @@ def execute(self, context: Context) -> dict: graphIdentifier=self.graph_id, WaiterConfig={"Delay": self.waiter_delay, "MaxAttempts": self.waiter_max_attempts}, ) + # Once the graph is available, wait for the task to complete + + self.log.info("Waiting for import task %s", import_task_id) + self.hook.get_waiter("import_task_successful").wait( + taskIdentifier=import_task_id, + WaiterConfig={"Delay": self.waiter_delay, "MaxAttempts": self.waiter_max_attempts}, + ) return {"graph_id": self.graph_id} + def defer_wait_for_task( + self, import_task_id: str, context: Context, event: dict[str, Any] | None = None + ) -> None: + """Defers for import task completion.""" + self.log.info("Deferring for import task %s completion", import_task_id) + self.defer( + trigger=NeptuneImportTaskCompleteTrigger( + task_id=import_task_id, + waiter_delay=self.waiter_delay, + waiter_max_attempts=self.waiter_max_attempts, + aws_conn_id=self.aws_conn_id, + ), + method_name="execute_complete", + ) + def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> dict[str, Any]: - self.log.info("Trigger complete for graph %s", self.graph_id) + self.log.info("Import complete for graph %s", self.graph_id) return {"graph_id": self.graph_id} @@ -686,7 +719,7 @@ def __init__( graph_identifier: str, role_arn: str, source: str, - blank_node_handling: str | None = "convertToIri", + blank_node_handling: str | None = None, fail_on_error: bool = True, format: str | None = None, import_options: dict | None = None, @@ -750,7 +783,7 @@ def execute(self, context: Context) -> 
dict: if self.wait_for_completion: self.log.info("Waiting for import task %s to complete", task_id) - self.hook.get_waiter("import_task_completed").wait( + self.hook.get_waiter("import_task_successful").wait( taskIdentifier=task_id, WaiterConfig={"Delay": self.waiter_delay, "MaxAttempts": self.waiter_max_attempts}, ) @@ -806,24 +839,24 @@ def __init__( **kwargs, ): super().__init__(**kwargs) - self.import_task_id = task_identifier + self.task_identifier = task_identifier self.wait_for_completion = wait_for_completion self.waiter_delay = waiter_delay self.waiter_max_attempts = waiter_max_attempts self.deferrable = deferrable def execute(self, context: Context) -> dict: - self.log.info("Cancelling import task %s", self.import_task_id) + self.log.info("Cancelling import task %s", self.task_identifier) - response = self.hook.conn.cancel_import_task(taskIdentifier=self.import_task_id) + response = self.hook.conn.cancel_import_task(taskIdentifier=self.task_identifier) - self.log.info("Import task %s status is %s", self.import_task_id, response.get("status", "Unknown")) + self.log.info("Import task %s status is %s", self.task_identifier, response.get("status", "Unknown")) if self.deferrable: - self.log.info("Deferring until import task %s is cancelled", self.import_task_id) + self.log.info("Deferring until import task %s is cancelled", self.task_identifier) self.defer( trigger=NeptuneImportTaskCancelledTrigger( - task_identifier=self.import_task_id, + task_identifier=self.task_identifier, waiter_delay=self.waiter_delay, waiter_max_attempts=self.waiter_max_attempts, aws_conn_id=self.aws_conn_id, @@ -832,13 +865,13 @@ def execute(self, context: Context) -> dict: ) if self.wait_for_completion: - self.log.info("Waiting for import task %s to be cancelled", self.import_task_id) + self.log.info("Waiting for import task %s to be cancelled", self.task_identifier) self.hook.get_waiter("import_task_cancelled").wait( - taskIdentifier=self.import_task_id, + 
taskIdentifier=self.task_identifier, WaiterConfig={"Delay": self.waiter_delay, "MaxAttempts": self.waiter_max_attempts}, ) - return {"task_identifier": self.import_task_id} + return {"task_identifier": self.task_identifier} def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> dict[str, Any]: task_id = "" diff --git a/providers/amazon/src/airflow/providers/amazon/aws/triggers/neptune_analytics.py b/providers/amazon/src/airflow/providers/amazon/aws/triggers/neptune_analytics.py index 959ba5eee1987..8deeb2d63faa2 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/triggers/neptune_analytics.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/triggers/neptune_analytics.py @@ -84,20 +84,19 @@ def __init__( *, graph_id: str, vpc_id: str, - endpoint_id: str, waiter_delay: int = 30, waiter_max_attempts: int = 60, **kwargs, ) -> None: super().__init__( - serialized_fields={"graph_id": graph_id, "vpc_id": vpc_id, "endpoint_id": endpoint_id}, + serialized_fields={"graph_id": graph_id, "vpc_id": vpc_id}, waiter_name="private_graph_endpoint_available", waiter_args={"graphIdentifier": graph_id, "vpcId": vpc_id}, failure_message="Failed to create Neptune graph endpoint", status_message="Status of Neptune graph endpoint is", status_queries=["status"], - return_key="endpoint_id", - return_value=endpoint_id, + return_key="graph_id", + return_value=graph_id, waiter_delay=waiter_delay, waiter_max_attempts=waiter_max_attempts, **kwargs, diff --git a/providers/amazon/tests/system/amazon/aws/example_neptune_analytics.py b/providers/amazon/tests/system/amazon/aws/example_neptune_analytics.py new file mode 100644 index 0000000000000..2b5050c34b899 --- /dev/null +++ b/providers/amazon/tests/system/amazon/aws/example_neptune_analytics.py @@ -0,0 +1,319 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import json +import time +from datetime import datetime + +import boto3 + +from airflow.providers.amazon.aws.hooks.neptune_analytics import NeptuneAnalyticsHook +from airflow.providers.amazon.aws.operators.neptune_analytics import ( + NeptuneCreateGraphOperator, + NeptuneCreateGraphWithImportOperator, + NeptuneCreatePrivateGraphEndpointOperator, + NeptuneDeleteGraphOperator, + NeptuneDeletePrivateGraphEndpointOperator, + NeptuneStartImportTaskOperator, +) +from airflow.providers.amazon.aws.operators.s3 import ( + S3CreateBucketOperator, + S3CreateObjectOperator, + S3DeleteBucketOperator, +) +from airflow.providers.common.compat.sdk import DAG, chain, task + +try: + from airflow.sdk import TriggerRule +except ImportError: + # Compatibility for Airflow < 3.1 + from airflow.utils.trigger_rule import TriggerRule # type: ignore[no-redef,attr-defined] + +from system.amazon.aws.utils import SystemTestContextBuilder + +DAG_ID = "example_neptune_analytics" + +ROLE_ARN_KEY = "ROLE_ARN" + +sys_test_context_task = SystemTestContextBuilder().add_variable(ROLE_ARN_KEY).build() + +# Minimal OpenCypher CSV data for import testing. 
+NODES_CSV = """~id,~label,name:String +n1,Person,Alice +n2,Person,Bob +""" + +EDGES_CSV = """~id,~from,~to,~label +e1,n1,n2,KNOWS +""" + +NEPTUNE_ANALYTICS_TRUST_POLICY = json.dumps( + { + "Version": "2012-10-17", + "Statement": [ + { + "Effect": "Allow", + "Principal": {"Service": "neptune-graph.amazonaws.com"}, + "Action": "sts:AssumeRole", + } + ], + } +) + +S3_READ_POLICY_DOCUMENT = json.dumps( + { + "Version": "2012-10-17", + "Statement": [ + { + "Effect": "Allow", + "Action": ["s3:GetObject", "s3:ListBucket"], + "Resource": ["arn:aws:s3:::*", "arn:aws:s3:::*/*"], + } + ], + } +) + + +@task +def create_neptune_import_role(role_name: str) -> str: + iam_client = boto3.client("iam") + iam_client.create_role( + RoleName=role_name, + AssumeRolePolicyDocument=NEPTUNE_ANALYTICS_TRUST_POLICY, + Description="Role for Neptune Analytics import system test", + ) + iam_client.put_role_policy( + RoleName=role_name, + PolicyName="NeptuneAnalyticsS3Access", + PolicyDocument=S3_READ_POLICY_DOCUMENT, + ) + role = iam_client.get_role(RoleName=role_name) + time.sleep(60) # Wait for IAM eventual consistency (role + inline policy propagation) + return role["Role"]["Arn"] + + +@task(trigger_rule=TriggerRule.ALL_DONE) +def delete_neptune_import_role(role_name: str) -> None: + iam_client = boto3.client("iam") + try: + iam_client.delete_role_policy(RoleName=role_name, PolicyName="NeptuneAnalyticsS3Access") + except iam_client.exceptions.NoSuchEntityException: + pass + try: + iam_client.delete_role(RoleName=role_name) + except iam_client.exceptions.NoSuchEntityException: + pass + + +@task(trigger_rule=TriggerRule.ALL_DONE) +def delete_graph_if_exists(graph_name: str) -> None: + """Safety net to clean up the graph in case a previous task failed.""" + hook = NeptuneAnalyticsHook() + try: + # List graphs and find by name + paginator = hook.conn.get_paginator("list_graphs") + for page in paginator.paginate(): + for graph in page.get("graphs", []): + if graph.get("name") == graph_name: + 
graph_id = graph["id"] + # Disable deletion protection if enabled + try: + hook.conn.update_graph(graphIdentifier=graph_id, deletionProtection=False) + except Exception: + pass + hook.conn.delete_graph(graphIdentifier=graph_id, skipSnapshot=True) + hook.conn.get_waiter("graph_deleted").wait( + graphIdentifier=graph_id, + WaiterConfig={"Delay": 30, "MaxAttempts": 60}, + ) + return + except Exception: + pass + + +with DAG( + dag_id=DAG_ID, + schedule="@once", + start_date=datetime(2021, 1, 1), + catchup=False, +) as dag: + test_context = sys_test_context_task() + + env_id = test_context["ENV_ID"] + graph_name = f"{env_id}-graph" + import_graph_name = f"{env_id}-import-graph" + bucket_name = f"{env_id}-neptune-analytics" + import_role_name = f"{env_id}-neptune-import" + region = boto3.session.Session().region_name + + # --- TEST SETUP --- + + create_bucket = S3CreateBucketOperator( + task_id="create_bucket", + bucket_name=bucket_name, + ) + + upload_nodes = S3CreateObjectOperator( + task_id="upload_nodes", + s3_bucket=bucket_name, + s3_key="data/nodes.csv", + data=NODES_CSV, + replace=True, + ) + + upload_edges = S3CreateObjectOperator( + task_id="upload_edges", + s3_bucket=bucket_name, + s3_key="data/edges.csv", + data=EDGES_CSV, + replace=True, + ) + + create_role = create_neptune_import_role(import_role_name) + + # --- TEST BODY --- + + # [START howto_operator_neptune_analytics_create_graph] + create_graph = NeptuneCreateGraphOperator( + task_id="create_graph", + graph_name=graph_name, + vector_search_config={"dimension": 128}, + provisioned_memory=32, + public_connectivity=True, + replica_count=0, + deletion_protection=False, + wait_for_completion=True, + deferrable=False, + waiter_delay=30, + waiter_max_attempts=60, + ) + # [END howto_operator_neptune_analytics_create_graph] + + # [START howto_operator_neptune_analytics_create_private_endpoint] + create_endpoint = NeptuneCreatePrivateGraphEndpointOperator( + task_id="create_endpoint", + graph_identifier="{{ 
ti.xcom_pull(task_ids='create_graph')['graph_id']}}", + wait_for_completion=True, + ) + + # [END howto_operator_neptune_analytics_create_private_endpoint] + + # [START howto_operator_neptune_analytics_start_import_task] + start_import = NeptuneStartImportTaskOperator( + task_id="start_import", + graph_identifier="{{ ti.xcom_pull(task_ids='create_graph')['graph_id'] }}", + role_arn=create_role, + source=f"s3://{bucket_name}/data/", + format="CSV", + fail_on_error=True, + wait_for_completion=True, + deferrable=False, + waiter_delay=30, + waiter_max_attempts=60, + ) + # [END howto_operator_neptune_analytics_start_import_task] + + # [START howto_operator_neptune_analytics_delete_graph] + delete_graph = NeptuneDeleteGraphOperator( + task_id="delete_graph", + graph_id="{{ ti.xcom_pull(task_ids='create_graph')['graph_id'] }}", + skip_snapshot=True, + wait_for_completion=True, + deferrable=False, + waiter_delay=30, + waiter_max_attempts=60, + ) + # [END howto_operator_neptune_analytics_delete_graph] + + # [START howto_operator_neptune_analytics_create_graph_with_import] + create_graph_with_import = NeptuneCreateGraphWithImportOperator( + task_id="create_graph_with_import", + graph_name=import_graph_name, + vector_search_config={"dimension": 128}, + source=f"s3://{bucket_name}/data/", + role_arn=create_role, + format="CSV", + fail_on_error=True, + public_connectivity=True, + replica_count=0, + deletion_protection=False, + min_provisioned_memory=32, + max_provisioned_memory=32, + wait_for_completion=True, + deferrable=False, + waiter_delay=30, + waiter_max_attempts=60, + ) + # [END howto_operator_neptune_analytics_create_graph_with_import] + + # [START howto_operator_neptune_analytics_delete_import_graph] + delete_import_graph = NeptuneDeleteGraphOperator( + task_id="delete_import_graph", + graph_id="{{ ti.xcom_pull(task_ids='create_graph_with_import')['graph_id'] }}", + skip_snapshot=True, + wait_for_completion=True, + deferrable=False, +
trigger_rule=TriggerRule.ALL_DONE, + waiter_delay=30, + waiter_max_attempts=60, + ) + # [END howto_operator_neptune_analytics_delete_import_graph] + + # --- TEST TEARDOWN --- + + delete_bucket = S3DeleteBucketOperator( + task_id="delete_bucket", + trigger_rule=TriggerRule.ALL_DONE, + bucket_name=bucket_name, + force_delete=True, + ) + + delete_role = delete_neptune_import_role(import_role_name) + + cleanup_graph = delete_graph_if_exists(graph_name) + cleanup_import_graph = delete_graph_if_exists(import_graph_name) + + chain( + # TEST SETUP + test_context, + create_bucket, + [upload_nodes, upload_edges], + create_role, + # TEST BODY: Create graph, import data, then delete + create_graph, + start_import, + delete_graph, + # TEST BODY: Create graph with import, then delete + create_graph_with_import, + delete_import_graph, + # TEST TEARDOWN + [cleanup_graph, cleanup_import_graph], + delete_bucket, + delete_role, + ) + + from tests_common.test_utils.watcher import watcher + + # This test needs watcher in order to properly mark success/failure + # when "tearDown" task with trigger rule is part of the DAG + list(dag.tasks) >> watcher() + +from tests_common.test_utils.system_tests import get_test_run # noqa: E402 + +# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest) +test_run = get_test_run(dag) diff --git a/providers/amazon/tests/unit/amazon/aws/operators/test_neptune_analytics.py b/providers/amazon/tests/unit/amazon/aws/operators/test_neptune_analytics.py index 0696fe90684e7..b31605ad15c82 100644 --- a/providers/amazon/tests/unit/amazon/aws/operators/test_neptune_analytics.py +++ b/providers/amazon/tests/unit/amazon/aws/operators/test_neptune_analytics.py @@ -33,6 +33,7 @@ NeptuneDeletePrivateGraphEndpointOperator, NeptuneStartImportTaskOperator, ) +from airflow.providers.common.compat.sdk import TaskDeferred GRAPH_NAME = "test_graph" GRAPH_ID = "test-graph-id" @@ -149,13 +150,16 @@ def test_init_defaults(self, mock_conn): 
"vpcEndpointId": ENDPOINT_ID, "vpcId": VPC_ID, } + mock_conn.get_private_graph_endpoint.return_value = { + "vpcEndpointId": ENDPOINT_ID, + } operator = NeptuneCreatePrivateGraphEndpointOperator( task_id="test_task", graph_identifier=GRAPH_ID, ) - assert operator.graph_id == GRAPH_ID + assert operator.graph_identifier == GRAPH_ID assert operator.vpc_id is None assert operator.subnet_ids is None assert operator.vpc_security_group_ids is None @@ -168,6 +172,10 @@ def test_init_defaults(self, mock_conn): mock_conn.create_private_graph_endpoint.assert_called_once_with( graphIdentifier=GRAPH_ID, ) + mock_conn.get_private_graph_endpoint.assert_called_once_with( + graphIdentifier=GRAPH_ID, + vpcId=VPC_ID, + ) assert result is not None assert result["vpc_endpoint_id"] == ENDPOINT_ID @@ -181,6 +189,9 @@ def test_init_custom_args(self, mock_conn): "vpcEndpointId": ENDPOINT_ID, "vpcId": VPC_ID, } + mock_conn.get_private_graph_endpoint.return_value = { + "vpcEndpointId": ENDPOINT_ID, + } operator = NeptuneCreatePrivateGraphEndpointOperator( task_id="test_task", @@ -192,7 +203,7 @@ def test_init_custom_args(self, mock_conn): waiter_max_attempts=100, ) - assert operator.graph_id == GRAPH_ID + assert operator.graph_identifier == GRAPH_ID assert operator.vpc_id == VPC_ID assert operator.subnet_ids == SUBNET_IDS assert operator.vpc_security_group_ids == SECURITY_GROUP_IDS @@ -216,6 +227,9 @@ def test_create_endpoint_no_wait(self, mock_hook_get_waiter, mock_conn): "vpcEndpointId": ENDPOINT_ID, "vpcId": VPC_ID, } + mock_conn.get_private_graph_endpoint.return_value = { + "vpcEndpointId": ENDPOINT_ID, + } operator = NeptuneCreatePrivateGraphEndpointOperator( task_id="test_task", @@ -235,6 +249,9 @@ def test_create_endpoint_wait_for_completion(self, mock_hook_get_waiter, mock_co "vpcEndpointId": ENDPOINT_ID, "vpcId": VPC_ID, } + mock_conn.get_private_graph_endpoint.return_value = { + "vpcEndpointId": ENDPOINT_ID, + } operator = NeptuneCreatePrivateGraphEndpointOperator( 
task_id="test_task", @@ -242,11 +259,39 @@ def test_create_endpoint_wait_for_completion(self, mock_hook_get_waiter, mock_co vpc_id=VPC_ID, wait_for_completion=True, ) - operator.execute(None) + result = operator.execute(None) - # Note: The operator currently has 'pass' for wait_for_completion - # This test documents the current behavior - # When wait_for_completion is implemented, this test should verify the waiter is called + mock_hook_get_waiter.assert_called_once_with("private_graph_endpoint_available") + mock_hook_get_waiter.return_value.wait.assert_called_once_with( + graphIdentifier=GRAPH_ID, + vpcId=VPC_ID, + WaiterConfig={"Delay": 30, "MaxAttempts": 60}, + ) + assert result == {"vpc_endpoint_id": ENDPOINT_ID, "graph_id": GRAPH_ID, "vpc_id": VPC_ID} + + @mock.patch.object(NeptuneAnalyticsHook, "conn") + def test_create_endpoint_sets_vpc_id_from_response(self, mock_conn): + """When vpc_id is not provided, the operator should use the vpc_id from the API response.""" + mock_conn.create_private_graph_endpoint.return_value = { + "status": "CREATING", + "vpcEndpointId": ENDPOINT_ID, + "vpcId": VPC_ID, + } + mock_conn.get_private_graph_endpoint.return_value = { + "vpcEndpointId": ENDPOINT_ID, + } + + operator = NeptuneCreatePrivateGraphEndpointOperator( + task_id="test_task", + graph_identifier=GRAPH_ID, + ) + + assert operator.vpc_id is None + result = operator.execute(None) + + # vpc_id should be set from the create response + assert operator.vpc_id == VPC_ID + assert result["vpc_id"] == VPC_ID @mock.patch.object(NeptuneAnalyticsHook, "conn") def test_create_endpoint_failed_status(self, mock_conn): @@ -267,33 +312,41 @@ def test_create_endpoint_failed_status(self, mock_conn): with pytest.raises(AirflowException, match=f"Private endpoint failed to create for graph {GRAPH_ID}"): operator.execute(None) - def test_execute_complete_success(self): + @mock.patch.object(NeptuneAnalyticsHook, "conn") + def test_execute_complete(self, mock_conn): + 
mock_conn.get_private_graph_endpoint.return_value = { + "vpcEndpointId": ENDPOINT_ID, + } + operator = NeptuneCreatePrivateGraphEndpointOperator( task_id="test_task", graph_identifier=GRAPH_ID, vpc_id=VPC_ID ) - event = { - "status": "success", - "endpoint_id": ENDPOINT_ID, - } - - result = operator.execute_complete(None, event) + result = operator.execute_complete(None, {"status": "success"}) + mock_conn.get_private_graph_endpoint.assert_called_once_with( + graphIdentifier=GRAPH_ID, + vpcId=VPC_ID, + ) assert result == {"vpc_endpoint_id": ENDPOINT_ID, "graph_id": GRAPH_ID, "vpc_id": VPC_ID} - def test_execute_complete_failure_status(self): + @mock.patch.object(NeptuneAnalyticsHook, "conn") + def test_get_graph_endpoint_id(self, mock_conn): + mock_conn.get_private_graph_endpoint.return_value = { + "vpcEndpointId": ENDPOINT_ID, + } + operator = NeptuneCreatePrivateGraphEndpointOperator( task_id="test_task", graph_identifier=GRAPH_ID, vpc_id=VPC_ID ) - event = { - "status": "failure", - "endpoint_id": ENDPOINT_ID, - } + result = operator._get_graph_endpoint_id() - result = operator.execute_complete(None, event) - - assert result == {"vpc_endpoint_id": "", "graph_id": GRAPH_ID, "vpc_id": VPC_ID} + assert result == ENDPOINT_ID + mock_conn.get_private_graph_endpoint.assert_called_once_with( + graphIdentifier=GRAPH_ID, + vpcId=VPC_ID, + ) class TestNeptuneDeletePrivateGraphEndpointOperator: @@ -311,7 +364,7 @@ def test_init_defaults(self, mock_conn): vpc_id=VPC_ID, ) - assert operator.graph_id == GRAPH_ID + assert operator.graph_identifier == GRAPH_ID assert operator.vpc_id == VPC_ID assert operator.wait_for_completion is True assert operator.waiter_delay == 30 @@ -340,7 +393,7 @@ def test_init_custom_args(self, mock_conn): waiter_max_attempts=100, ) - assert operator.graph_id == GRAPH_ID + assert operator.graph_identifier == GRAPH_ID assert operator.vpc_id == VPC_ID assert operator.waiter_delay == 60 assert operator.waiter_max_attempts == 100 @@ -585,10 +638,13 @@ def 
test_delete_graph_other_client_error(self, mock_conn): class TestNeptuneCreateGraphWithImportOperator: + IMPORT_TASK_ID = "import-task-12345" + @mock.patch.object(NeptuneAnalyticsHook, "conn") def test_init_defaults(self, mock_conn): mock_conn.create_graph_using_import_task.return_value = { "graphId": GRAPH_ID, + "taskId": self.IMPORT_TASK_ID, "status": "IMPORTING", } @@ -633,6 +689,7 @@ def test_init_defaults(self, mock_conn): def test_init_with_all_optional_params(self, mock_conn): mock_conn.create_graph_using_import_task.return_value = { "graphId": GRAPH_ID, + "taskId": self.IMPORT_TASK_ID, "status": "IMPORTING", } @@ -701,6 +758,7 @@ def test_init_with_all_optional_params(self, mock_conn): def test_import_options_handling(self, mock_conn): mock_conn.create_graph_using_import_task.return_value = { "graphId": GRAPH_ID, + "taskId": self.IMPORT_TASK_ID, "status": "IMPORTING", } @@ -726,6 +784,7 @@ def test_import_options_handling(self, mock_conn): def test_create_graph_with_import_no_wait(self, mock_hook_get_waiter, mock_conn): mock_conn.create_graph_using_import_task.return_value = { "graphId": GRAPH_ID, + "taskId": self.IMPORT_TASK_ID, "status": "IMPORTING", } @@ -747,6 +806,7 @@ def test_create_graph_with_import_no_wait(self, mock_hook_get_waiter, mock_conn) def test_create_graph_with_import_wait_for_completion(self, mock_hook_get_waiter, mock_conn): mock_conn.create_graph_using_import_task.return_value = { "graphId": GRAPH_ID, + "taskId": self.IMPORT_TASK_ID, "status": "IMPORTING", } @@ -760,13 +820,17 @@ def test_create_graph_with_import_wait_for_completion(self, mock_hook_get_waiter ) result = operator.execute(None) - mock_hook_get_waiter.assert_called_once_with("graph_available") + # Should wait for both graph_available and import_task_successful + assert mock_hook_get_waiter.call_count == 2 + mock_hook_get_waiter.assert_any_call("graph_available") + mock_hook_get_waiter.assert_any_call("import_task_successful") assert result == {"graph_id": GRAPH_ID} 
@mock.patch.object(NeptuneAnalyticsHook, "conn") def test_import_options_none_values_filtered(self, mock_conn): mock_conn.create_graph_using_import_task.return_value = { "graphId": GRAPH_ID, + "taskId": self.IMPORT_TASK_ID, "status": "IMPORTING", } @@ -787,6 +851,60 @@ def test_import_options_none_values_filtered(self, mock_conn): # importOptions should not be in call_args if all values are None assert "importOptions" not in call_args + @mock.patch.object(NeptuneAnalyticsHook, "conn") + def test_defer_wait_for_task(self, mock_conn): + """Test that defer_wait_for_task defers with the import task trigger.""" + from airflow.providers.amazon.aws.triggers.neptune_analytics import NeptuneImportTaskCompleteTrigger + + operator = NeptuneCreateGraphWithImportOperator( + task_id="test_task", + graph_name=GRAPH_NAME, + vector_search_config={"dimension": 128}, + source=SOURCE_S3_URI, + role_arn=ROLE_ARN, + waiter_delay=30, + waiter_max_attempts=60, + ) + + with pytest.raises(TaskDeferred) as exc_info: + operator.defer_wait_for_task( + import_task_id=self.IMPORT_TASK_ID, + context=None, + event={"status": "success"}, + ) + + trigger = exc_info.value.trigger + assert isinstance(trigger, NeptuneImportTaskCompleteTrigger) + assert exc_info.value.method_name == "execute_complete" + + @mock.patch.object(NeptuneAnalyticsHook, "conn") + def test_deferrable_defers_with_graph_available_trigger(self, mock_conn): + """Test that execute defers with graph_available trigger and passes import_task_id.""" + from airflow.providers.amazon.aws.triggers.neptune_analytics import NeptuneGraphAvailableTrigger + + mock_conn.create_graph_using_import_task.return_value = { + "graphId": GRAPH_ID, + "taskId": self.IMPORT_TASK_ID, + "status": "IMPORTING", + } + + operator = NeptuneCreateGraphWithImportOperator( + task_id="test_task", + graph_name=GRAPH_NAME, + vector_search_config={"dimension": 128}, + source=SOURCE_S3_URI, + role_arn=ROLE_ARN, + deferrable=True, + ) + + with pytest.raises(TaskDeferred) as 
exc_info: + operator.execute(None) + + trigger = exc_info.value.trigger + assert isinstance(trigger, NeptuneGraphAvailableTrigger) + assert exc_info.value.method_name == "defer_wait_for_task" + assert exc_info.value.kwargs == {"import_task_id": self.IMPORT_TASK_ID} + TASK_ID = "import-task-id-12345" @@ -810,7 +928,7 @@ def test_init_defaults(self, mock_conn): assert operator.graph_identifier == GRAPH_ID assert operator.role_arn == ROLE_ARN assert operator.source == SOURCE_S3_URI - assert operator.blank_node_handling == "convertToIri" + assert operator.blank_node_handling is None assert operator.fail_on_error is True assert operator.format is None assert operator.import_options is None @@ -825,7 +943,6 @@ def test_init_defaults(self, mock_conn): graphIdentifier=GRAPH_ID, roleArn=ROLE_ARN, source=SOURCE_S3_URI, - blankNodeHandling="convertToIri", failOnError=True, parquetType="COLUMNAR", ) @@ -910,7 +1027,7 @@ def test_start_import_wait_for_completion(self, mock_get_waiter, mock_conn): ) result = operator.execute(None) - mock_get_waiter.assert_called_once_with("import_task_completed") + mock_get_waiter.assert_called_once_with("import_task_successful") assert result == {"task_id": TASK_ID, "graph_id": GRAPH_ID} def test_execute_complete_success(self): @@ -953,7 +1070,7 @@ def test_init_defaults(self, mock_conn): task_identifier=TASK_ID, ) - assert operator.import_task_id == TASK_ID + assert operator.task_identifier == TASK_ID assert operator.wait_for_completion is True assert operator.waiter_delay == 30 assert operator.waiter_max_attempts == 60 @@ -977,7 +1094,7 @@ def test_init_custom_args(self, mock_conn): waiter_max_attempts=100, ) - assert operator.import_task_id == TASK_ID + assert operator.task_identifier == TASK_ID assert operator.waiter_delay == 60 assert operator.waiter_max_attempts == 100 From a5540d801b2b422dbf8e5c0174fc16d80a7cb32d Mon Sep 17 00:00:00 2001 From: Mike Ellis Date: Tue, 24 Mar 2026 14:02:40 +0000 Subject: [PATCH 12/28] Added additional links 
for import job and vpce --- providers/amazon/provider.yaml | 3 + .../airflow/providers/amazon/aws/links/ec2.py | 11 ++ .../amazon/aws/links/neptune_analytics.py | 42 +++++ .../amazon/aws/operators/neptune_analytics.py | 80 ++++++-- .../providers/amazon/get_provider_info.py | 3 + .../tests/unit/amazon/aws/links/test_ec2.py | 29 ++- .../aws/links/test_neptune_analytics.py | 54 ++++++ .../aws/operators/test_neptune_analytics.py | 176 +++++++++++++++++- .../aws/triggers/test_neptune_analytics.py | 16 +- 9 files changed, 378 insertions(+), 36 deletions(-) create mode 100644 providers/amazon/src/airflow/providers/amazon/aws/links/neptune_analytics.py create mode 100644 providers/amazon/tests/unit/amazon/aws/links/test_neptune_analytics.py diff --git a/providers/amazon/provider.yaml b/providers/amazon/provider.yaml index aa74974d6d834..1197b763ef8fb 100644 --- a/providers/amazon/provider.yaml +++ b/providers/amazon/provider.yaml @@ -985,6 +985,9 @@ extra-links: - airflow.providers.amazon.aws.links.datasync.DataSyncTaskExecutionLink - airflow.providers.amazon.aws.links.ec2.EC2InstanceLink - airflow.providers.amazon.aws.links.ec2.EC2InstanceDashboardLink + - airflow.providers.amazon.aws.links.neptune_analytics.NeptuneGraphLink + - airflow.providers.amazon.aws.links.neptune_analytics.NeptuneImportTaskLink + - airflow.providers.amazon.aws.links.ec2.VpcEndpointLink connection-types: - hook-class-name: airflow.providers.amazon.aws.hooks.base_aws.AwsGenericHook diff --git a/providers/amazon/src/airflow/providers/amazon/aws/links/ec2.py b/providers/amazon/src/airflow/providers/amazon/aws/links/ec2.py index 38a23956cddbb..96fb03e9130d4 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/links/ec2.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/links/ec2.py @@ -44,3 +44,14 @@ class EC2InstanceDashboardLink(BaseAwsLink): @staticmethod def format_instance_id_filter(instance_ids: list[str]) -> str: return ",:".join(instance_ids) + + +class 
VpcEndpointLink(BaseAwsLink): + """Helper class for constructing a VPC Endpoint link.""" + + name = "VPC Endpoint" + key = "_vpc_endpoint" + format_str = ( + BASE_AWS_CONSOLE_LINK + + "/vpcconsole/home?region={region_name}#EndpointDetails:vpcEndpointId={endpoint_id}" + ) diff --git a/providers/amazon/src/airflow/providers/amazon/aws/links/neptune_analytics.py b/providers/amazon/src/airflow/providers/amazon/aws/links/neptune_analytics.py new file mode 100644 index 0000000000000..d3b13e48ad4fb --- /dev/null +++ b/providers/amazon/src/airflow/providers/amazon/aws/links/neptune_analytics.py @@ -0,0 +1,42 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +from airflow.providers.amazon.aws.links.base_aws import BASE_AWS_CONSOLE_LINK, BaseAwsLink + + +class NeptuneGraphLink(BaseAwsLink): + """Helper class for constructing an Amazon Neptune Analytics Graph Link.""" + + name = "Neptune Graph" + key = "_neptune_graph" + format_str = ( + BASE_AWS_CONSOLE_LINK + + "/neptune/home?region={region_name}#analytics-graph-details:id={graph_id}" + + ";tab=connectivity" + ) + + +class NeptuneImportTaskLink(BaseAwsLink): + """Helper class for constructing an Amazon Neptune Analytics import task link.""" + + name = "Neptune Import Task" + key = "_import_task" + format_str = ( + BASE_AWS_CONSOLE_LINK + + "/neptune/home?region={region_name}#analytics-import-task-details:id={import_task_id}" + ) diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/neptune_analytics.py b/providers/amazon/src/airflow/providers/amazon/aws/operators/neptune_analytics.py index 757453b54ae38..2a062cbb03ea4 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/operators/neptune_analytics.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/neptune_analytics.py @@ -24,6 +24,8 @@ from airflow.exceptions import AirflowException from airflow.providers.amazon.aws.hooks.neptune_analytics import NeptuneAnalyticsHook +from airflow.providers.amazon.aws.links.ec2 import VpcEndpointLink +from airflow.providers.amazon.aws.links.neptune_analytics import NeptuneGraphLink, NeptuneImportTaskLink from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator from airflow.providers.amazon.aws.triggers.neptune_analytics import ( NeptuneGraphAvailableTrigger, @@ -78,7 +80,15 @@ class NeptuneCreateGraphOperator(AwsBaseOperator[NeptuneAnalyticsHook]): """ aws_hook_class = NeptuneAnalyticsHook - template_fields: Sequence[str] = aws_template_fields() + template_fields: Sequence[str] = aws_template_fields( + "graph_name", "vector_search_config", "provisioned_memory" + ) + + 
template_fields_renderers = { + "vector_search_config": "json", + } + + operator_extra_links = (NeptuneGraphLink(),) def __init__( self, @@ -113,7 +123,6 @@ def __init__( def execute(self, context: Context) -> dict: self.log.info("Creating graph %s", self.graph_name) - # TODO perform check create_params = { "graphName": self.graph_name, "vectorSearchConfiguration": self.vector_search_config, @@ -132,11 +141,24 @@ def execute(self, context: Context) -> dict: } response = self.hook.conn.create_graph(**create_params) - # TODO: need to return vpc id from response in case it wasn't supplied earlier. + self.log.info("Graph %s in status %s", self.graph_name, response.get("status", "Unknown")) self.graph_id = response.get("id", None) - # TODO build extra link to console + graph_url = NeptuneGraphLink.format_str.format( + graph_id=self.graph_id, + aws_domain=NeptuneGraphLink.get_aws_domain(self.hook.conn_partition), + region_name=self.hook.conn_region_name, + ) + + NeptuneGraphLink.persist( + context=context, + operator=self, + region_name=self.hook.conn_region_name, + aws_partition=self.hook.conn_partition, + graph_id=self.graph_id, + ) + self.log.info("You can view this Neptune Graph at : %s", graph_url) if self.deferrable: self.log.info("Deferring until graph %s is available", self.graph_id) @@ -276,6 +298,22 @@ def execute(self, context: Context) -> dict: ) endpoint_id = self._get_graph_endpoint_id() + + endpoint_url = VpcEndpointLink.format_str.format( + endpoint_id=endpoint_id, + aws_domain=VpcEndpointLink.get_aws_domain(self.hook.conn_partition), + region_name=self.hook.conn_region_name, + ) + + VpcEndpointLink.persist( + context=context, + operator=self, + region_name=self.hook.conn_region_name, + aws_partition=self.hook.conn_partition, + endpoint_id=endpoint_id, + ) + self.log.info("You can view this private endpoint at : %s", endpoint_url) + return {"vpc_endpoint_id": endpoint_id, "graph_id": self.graph_identifier, "vpc_id": self.vpc_id} def 
_get_graph_endpoint_id(self): @@ -447,8 +485,8 @@ def execute(self, context: Context): # if not found, just exit because there is nothing to delete if e.response["Error"]["Code"] == "ResourceNotFoundException": self.log.info("Graph %s not found. Nothing to delete", self.graph_id) - else: - raise AirflowException(e.response["Error"]) + return + raise AirflowException(e.response["Error"]) if self.deferrable: self.log.info("Deferring until graph %s is deleted", self.graph_id) @@ -713,6 +751,7 @@ class NeptuneStartImportTaskOperator(AwsBaseOperator[NeptuneAnalyticsHook]): template_fields_renderers = { "import_options": "json", } + operator_extra_links = (NeptuneImportTaskLink(),) def __init__( self, @@ -765,15 +804,30 @@ def execute(self, context: Context) -> dict: } response = self.hook.conn.start_import_task(**create_params) + import_task_id = response.get("taskId") + self.log.info("Import task %s started for graph %s", import_task_id, self.graph_identifier) - self.log.info("Import task %s started for graph %s", response.get("taskId"), self.graph_identifier) - task_id = response.get("taskId") + # Create the console link + import_task_url = NeptuneImportTaskLink.format_str.format( + import_task_id=import_task_id, + aws_domain=NeptuneImportTaskLink.get_aws_domain(self.hook.conn_partition), + region_name=self.hook.conn_region_name, + ) + + NeptuneImportTaskLink.persist( + context=context, + operator=self, + region_name=self.hook.conn_region_name, + aws_partition=self.hook.conn_partition, + import_task_id=import_task_id, + ) + self.log.info("You can view this import task at : %s", import_task_url) if self.deferrable: - self.log.info("Deferring until import task %s completes", task_id) + self.log.info("Deferring until import task %s completes", import_task_id) self.defer( trigger=NeptuneImportTaskCompleteTrigger( - task_id=task_id, + task_id=import_task_id, waiter_delay=self.waiter_delay, waiter_max_attempts=self.waiter_max_attempts, aws_conn_id=self.aws_conn_id, @@ 
-782,13 +836,13 @@ def execute(self, context: Context) -> dict: ) if self.wait_for_completion: - self.log.info("Waiting for import task %s to complete", task_id) + self.log.info("Waiting for import task %s to complete", import_task_id) self.hook.get_waiter("import_task_successful").wait( - taskIdentifier=task_id, + taskIdentifier=import_task_id, WaiterConfig={"Delay": self.waiter_delay, "MaxAttempts": self.waiter_max_attempts}, ) - return {"task_id": task_id, "graph_id": self.graph_identifier} + return {"task_id": import_task_id, "graph_id": self.graph_identifier} def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> dict[str, Any]: task_id = "" diff --git a/providers/amazon/src/airflow/providers/amazon/get_provider_info.py b/providers/amazon/src/airflow/providers/amazon/get_provider_info.py index c4afec9768b88..3e2d59ef64f6c 100644 --- a/providers/amazon/src/airflow/providers/amazon/get_provider_info.py +++ b/providers/amazon/src/airflow/providers/amazon/get_provider_info.py @@ -1158,6 +1158,9 @@ def get_provider_info(): "airflow.providers.amazon.aws.links.datasync.DataSyncTaskExecutionLink", "airflow.providers.amazon.aws.links.ec2.EC2InstanceLink", "airflow.providers.amazon.aws.links.ec2.EC2InstanceDashboardLink", + "airflow.providers.amazon.aws.links.neptune_analytics.NeptuneGraphLink", + "airflow.providers.amazon.aws.links.neptune_analytics.NeptuneImportTaskLink", + "airflow.providers.amazon.aws.links.ec2.VpcEndpointLink", ], "connection-types": [ { diff --git a/providers/amazon/tests/unit/amazon/aws/links/test_ec2.py b/providers/amazon/tests/unit/amazon/aws/links/test_ec2.py index ff2f48e9be174..9aa8e4ce65904 100644 --- a/providers/amazon/tests/unit/amazon/aws/links/test_ec2.py +++ b/providers/amazon/tests/unit/amazon/aws/links/test_ec2.py @@ -16,7 +16,7 @@ # under the License. 
from __future__ import annotations -from airflow.providers.amazon.aws.links.ec2 import EC2InstanceDashboardLink, EC2InstanceLink +from airflow.providers.amazon.aws.links.ec2 import EC2InstanceDashboardLink, EC2InstanceLink, VpcEndpointLink from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS from unit.amazon.aws.links.test_base_aws import BaseAwsLinksTestCase @@ -85,3 +85,30 @@ def test_extra_link(self, mock_supervisor_comms): aws_partition="aws", instance_ids=EC2InstanceDashboardLink.format_instance_id_filter(self.INSTANCE_IDS), ) + + +class TestVpcEndpointLink(BaseAwsLinksTestCase): + link_class = VpcEndpointLink + + ENDPOINT_ID = "vpce-0123456789abcdef0" + + def test_extra_link(self, mock_supervisor_comms): + if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: + mock_supervisor_comms.send.return_value = XComResult( + key=self.link_class.key, + value={ + "region_name": "us-east-1", + "aws_domain": self.link_class.get_aws_domain("aws"), + "aws_partition": "aws", + "endpoint_id": self.ENDPOINT_ID, + }, + ) + self.assert_extra_link_url( + expected_url=( + "https://console.aws.amazon.com/vpcconsole/home" + f"?region=us-east-1#EndpointDetails:vpcEndpointId={self.ENDPOINT_ID}" + ), + region_name="us-east-1", + aws_partition="aws", + endpoint_id=self.ENDPOINT_ID, + ) diff --git a/providers/amazon/tests/unit/amazon/aws/links/test_neptune_analytics.py b/providers/amazon/tests/unit/amazon/aws/links/test_neptune_analytics.py new file mode 100644 index 0000000000000..d332a916125d5 --- /dev/null +++ b/providers/amazon/tests/unit/amazon/aws/links/test_neptune_analytics.py @@ -0,0 +1,54 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import pytest + +from airflow.providers.amazon.aws.links.neptune_analytics import NeptuneGraphLink + +from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS +from unit.amazon.aws.links.test_base_aws import BaseAwsLinksTestCase + +if AIRFLOW_V_3_0_PLUS: + from airflow.sdk.execution_time.comms import XComResult + +pytestmark = pytest.mark.db_test + + +class TestNeptuneGraphLink(BaseAwsLinksTestCase): + link_class = NeptuneGraphLink + + def test_extra_link(self, mock_supervisor_comms): + if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: + mock_supervisor_comms.send.return_value = XComResult( + key=self.link_class.key, + value={ + "region_name": "us-east-1", + "aws_domain": self.link_class.get_aws_domain("aws"), + "aws_partition": "aws", + "graph_id": "g-fake123456", + }, + ) + self.assert_extra_link_url( + expected_url=( + "https://console.aws.amazon.com/neptune/home?region=us-east-1" + "#analytics-graph-details:id=g-fake123456;tab=connectivity" + ), + region_name="us-east-1", + aws_partition="aws", + graph_id="g-fake123456", + ) diff --git a/providers/amazon/tests/unit/amazon/aws/operators/test_neptune_analytics.py b/providers/amazon/tests/unit/amazon/aws/operators/test_neptune_analytics.py index b31605ad15c82..3545f4e1a73c8 100644 --- a/providers/amazon/tests/unit/amazon/aws/operators/test_neptune_analytics.py +++ b/providers/amazon/tests/unit/amazon/aws/operators/test_neptune_analytics.py @@ -52,8 +52,27 @@ def hook() -> Generator[NeptuneAnalyticsHook, None, None]: class TestNeptuneCreateGraphOperator: + 
def test_template_fields(self): + # Verify template_fields includes the expected fields + fields = NeptuneCreateGraphOperator.template_fields + assert "graph_name" in fields + assert "vector_search_config" in fields + assert "provisioned_memory" in fields + + def test_template_fields_renderers(self): + assert NeptuneCreateGraphOperator.template_fields_renderers == {"vector_search_config": "json"} + + def test_operator_extra_links(self): + from airflow.providers.amazon.aws.links.neptune_analytics import NeptuneGraphLink + + assert len(NeptuneCreateGraphOperator.operator_extra_links) == 1 + assert isinstance(NeptuneCreateGraphOperator.operator_extra_links[0], NeptuneGraphLink) + + @mock.patch("airflow.providers.amazon.aws.operators.neptune_analytics.NeptuneGraphLink.persist") @mock.patch.object(NeptuneAnalyticsHook, "conn") - def test_init_defaults(self, mock_conn): + def test_init_defaults(self, mock_conn, mock_persist): + mock_conn.create_graph.return_value = {"id": GRAPH_ID, "status": "CREATING"} + operator = NeptuneCreateGraphOperator( task_id="test_task", graph_name=GRAPH_NAME, @@ -76,8 +95,11 @@ def test_init_defaults(self, mock_conn): deletionProtection=False, ) + @mock.patch("airflow.providers.amazon.aws.operators.neptune_analytics.NeptuneGraphLink.persist") @mock.patch.object(NeptuneAnalyticsHook, "conn") - def test_init_custom_args(self, mock_conn): + def test_init_custom_args(self, mock_conn, mock_persist): + mock_conn.create_graph.return_value = {"id": GRAPH_ID, "status": "CREATING"} + operator = NeptuneCreateGraphOperator( task_id="test_task", graph_name=GRAPH_NAME, @@ -109,9 +131,12 @@ def test_init_custom_args(self, mock_conn): tags={"key1": "test"}, ) + @mock.patch("airflow.providers.amazon.aws.operators.neptune_analytics.NeptuneGraphLink.persist") @mock.patch.object(NeptuneAnalyticsHook, "conn") @mock.patch.object(NeptuneAnalyticsHook, "get_waiter") - def test_create_graph(self, mock_hook_get_waiter, mock_conn): + def test_create_graph(self, 
mock_hook_get_waiter, mock_conn, mock_persist): + mock_conn.create_graph.return_value = {"id": GRAPH_ID, "status": "CREATING"} + operator = NeptuneCreateGraphOperator( task_id="test_task", graph_name=GRAPH_NAME, @@ -123,11 +148,14 @@ def test_create_graph(self, mock_hook_get_waiter, mock_conn): mock_hook_get_waiter.assert_not_called() assert "graph_id" in resp - assert resp["graph_id"] is not None + assert resp["graph_id"] == GRAPH_ID + @mock.patch("airflow.providers.amazon.aws.operators.neptune_analytics.NeptuneGraphLink.persist") @mock.patch.object(NeptuneAnalyticsHook, "conn") @mock.patch.object(NeptuneAnalyticsHook, "get_waiter") - def test_create_graph_wait_for_completion(self, mock_hook_get_waiter, mock_conn): + def test_create_graph_wait_for_completion(self, mock_hook_get_waiter, mock_conn, mock_persist): + mock_conn.create_graph.return_value = {"id": GRAPH_ID, "status": "CREATING"} + operator = NeptuneCreateGraphOperator( task_id="test_task", graph_name=GRAPH_NAME, @@ -139,7 +167,57 @@ def test_create_graph_wait_for_completion(self, mock_hook_get_waiter, mock_conn) mock_hook_get_waiter.assert_called_once_with("graph_available") assert "graph_id" in resp - assert resp["graph_id"] is not None + assert resp["graph_id"] == GRAPH_ID + + @mock.patch.object(NeptuneAnalyticsHook, "conn") + def test_persist_called_with_correct_args(self, mock_conn): + """Test that NeptuneGraphLink.persist is called with the correct arguments.""" + mock_conn.create_graph.return_value = {"id": GRAPH_ID, "status": "CREATING"} + + operator = NeptuneCreateGraphOperator( + task_id="test_task", + graph_name=GRAPH_NAME, + vector_search_config={"test": 123}, + provisioned_memory=16, + wait_for_completion=False, + ) + + mock_context = mock.MagicMock() + with mock.patch( + "airflow.providers.amazon.aws.operators.neptune_analytics.NeptuneGraphLink.persist" + ) as mock_persist: + operator.execute(mock_context) + + mock_persist.assert_called_once_with( + context=mock_context, + operator=operator, 
+ region_name=mock.ANY, + aws_partition=mock.ANY, + graph_id=GRAPH_ID, + ) + + @mock.patch("airflow.providers.amazon.aws.operators.neptune_analytics.NeptuneGraphLink.persist") + @mock.patch.object(NeptuneAnalyticsHook, "conn") + def test_deferrable_defers_with_graph_available_trigger(self, mock_conn, mock_persist): + """Test that deferrable mode defers with NeptuneGraphAvailableTrigger.""" + from airflow.providers.amazon.aws.triggers.neptune_analytics import NeptuneGraphAvailableTrigger + + mock_conn.create_graph.return_value = {"id": GRAPH_ID, "status": "CREATING"} + + operator = NeptuneCreateGraphOperator( + task_id="test_task", + graph_name=GRAPH_NAME, + vector_search_config={"test": 123}, + provisioned_memory=16, + deferrable=True, + ) + + with pytest.raises(TaskDeferred) as exc_info: + operator.execute(None) + + trigger = exc_info.value.trigger + assert isinstance(trigger, NeptuneGraphAvailableTrigger) + assert exc_info.value.method_name == "execute_complete" class TestNeptuneCreatePrivateGraphEndpointOperator: @@ -910,8 +988,25 @@ def test_deferrable_defers_with_graph_available_trigger(self, mock_conn): class TestNeptuneStartImportTaskOperator: + def test_template_fields(self): + fields = NeptuneStartImportTaskOperator.template_fields + assert "graph_identifier" in fields + assert "role_arn" in fields + assert "source" in fields + assert "import_options" in fields + + def test_template_fields_renderers(self): + assert NeptuneStartImportTaskOperator.template_fields_renderers == {"import_options": "json"} + + def test_operator_extra_links(self): + from airflow.providers.amazon.aws.links.neptune_analytics import NeptuneImportTaskLink + + assert len(NeptuneStartImportTaskOperator.operator_extra_links) == 1 + assert isinstance(NeptuneStartImportTaskOperator.operator_extra_links[0], NeptuneImportTaskLink) + + @mock.patch("airflow.providers.amazon.aws.operators.neptune_analytics.NeptuneImportTaskLink.persist") @mock.patch.object(NeptuneAnalyticsHook, "conn") - def 
test_init_defaults(self, mock_conn): + def test_init_defaults(self, mock_conn, mock_persist): mock_conn.start_import_task.return_value = { "taskId": TASK_ID, "graphId": GRAPH_ID, @@ -947,8 +1042,9 @@ def test_init_defaults(self, mock_conn): parquetType="COLUMNAR", ) + @mock.patch("airflow.providers.amazon.aws.operators.neptune_analytics.NeptuneImportTaskLink.persist") @mock.patch.object(NeptuneAnalyticsHook, "conn") - def test_init_custom_args(self, mock_conn): + def test_init_custom_args(self, mock_conn, mock_persist): mock_conn.start_import_task.return_value = { "taskId": TASK_ID, "graphId": GRAPH_ID, @@ -988,9 +1084,10 @@ def test_init_custom_args(self, mock_conn): importOptions={"neptune.csv.allowEmptyStrings": True}, ) + @mock.patch("airflow.providers.amazon.aws.operators.neptune_analytics.NeptuneImportTaskLink.persist") @mock.patch.object(NeptuneAnalyticsHook, "conn") @mock.patch.object(NeptuneAnalyticsHook, "get_waiter") - def test_start_import_no_wait(self, mock_get_waiter, mock_conn): + def test_start_import_no_wait(self, mock_get_waiter, mock_conn, mock_persist): mock_conn.start_import_task.return_value = { "taskId": TASK_ID, "graphId": GRAPH_ID, @@ -1009,9 +1106,10 @@ def test_start_import_no_wait(self, mock_get_waiter, mock_conn): mock_get_waiter.assert_not_called() assert result == {"task_id": TASK_ID, "graph_id": GRAPH_ID} + @mock.patch("airflow.providers.amazon.aws.operators.neptune_analytics.NeptuneImportTaskLink.persist") @mock.patch.object(NeptuneAnalyticsHook, "conn") @mock.patch.object(NeptuneAnalyticsHook, "get_waiter") - def test_start_import_wait_for_completion(self, mock_get_waiter, mock_conn): + def test_start_import_wait_for_completion(self, mock_get_waiter, mock_conn, mock_persist): mock_conn.start_import_task.return_value = { "taskId": TASK_ID, "graphId": GRAPH_ID, @@ -1030,6 +1128,64 @@ def test_start_import_wait_for_completion(self, mock_get_waiter, mock_conn): mock_get_waiter.assert_called_once_with("import_task_successful") assert 
result == {"task_id": TASK_ID, "graph_id": GRAPH_ID} + @mock.patch.object(NeptuneAnalyticsHook, "conn") + def test_persist_called_with_correct_args(self, mock_conn): + """Test that NeptuneImportTaskLink.persist is called with the correct arguments.""" + mock_conn.start_import_task.return_value = { + "taskId": TASK_ID, + "graphId": GRAPH_ID, + "status": "IMPORTING", + } + + operator = NeptuneStartImportTaskOperator( + task_id="test_task", + graph_identifier=GRAPH_ID, + role_arn=ROLE_ARN, + source=SOURCE_S3_URI, + wait_for_completion=False, + ) + + mock_context = mock.MagicMock() + with mock.patch( + "airflow.providers.amazon.aws.operators.neptune_analytics.NeptuneImportTaskLink.persist" + ) as mock_persist: + operator.execute(mock_context) + + mock_persist.assert_called_once_with( + context=mock_context, + operator=operator, + region_name=mock.ANY, + aws_partition=mock.ANY, + import_task_id=TASK_ID, + ) + + @mock.patch("airflow.providers.amazon.aws.operators.neptune_analytics.NeptuneImportTaskLink.persist") + @mock.patch.object(NeptuneAnalyticsHook, "conn") + def test_deferrable_defers_with_import_task_trigger(self, mock_conn, mock_persist): + """Test that deferrable mode defers with NeptuneImportTaskCompleteTrigger.""" + from airflow.providers.amazon.aws.triggers.neptune_analytics import NeptuneImportTaskCompleteTrigger + + mock_conn.start_import_task.return_value = { + "taskId": TASK_ID, + "graphId": GRAPH_ID, + "status": "IMPORTING", + } + + operator = NeptuneStartImportTaskOperator( + task_id="test_task", + graph_identifier=GRAPH_ID, + role_arn=ROLE_ARN, + source=SOURCE_S3_URI, + deferrable=True, + ) + + with pytest.raises(TaskDeferred) as exc_info: + operator.execute(None) + + trigger = exc_info.value.trigger + assert isinstance(trigger, NeptuneImportTaskCompleteTrigger) + assert exc_info.value.method_name == "execute_complete" + def test_execute_complete_success(self): operator = NeptuneStartImportTaskOperator( task_id="test_task", diff --git 
a/providers/amazon/tests/unit/amazon/aws/triggers/test_neptune_analytics.py b/providers/amazon/tests/unit/amazon/aws/triggers/test_neptune_analytics.py index 637dd23453bc3..04d6cd99eb844 100644 --- a/providers/amazon/tests/unit/amazon/aws/triggers/test_neptune_analytics.py +++ b/providers/amazon/tests/unit/amazon/aws/triggers/test_neptune_analytics.py @@ -90,9 +90,7 @@ def test_serialization(self): Asserts that the NeptuneGraphPrivateEndpointAvailableTrigger correctly serializes its arguments and classpath. """ - trigger = NeptuneGraphPrivateEndpointAvailableTrigger( - graph_id=GRAPH_ID, vpc_id=VPC_ID, endpoint_id=ENDPOINT_ID - ) + trigger = NeptuneGraphPrivateEndpointAvailableTrigger(graph_id=GRAPH_ID, vpc_id=VPC_ID) classpath, kwargs = trigger.serialize() assert ( classpath @@ -102,8 +100,6 @@ def test_serialization(self): assert kwargs["graph_id"] == GRAPH_ID assert "vpc_id" in kwargs assert kwargs["vpc_id"] == VPC_ID - assert "endpoint_id" in kwargs - assert kwargs["endpoint_id"] == ENDPOINT_ID @pytest.mark.asyncio @mock.patch("airflow.providers.amazon.aws.hooks.neptune_analytics.NeptuneAnalyticsHook.get_waiter") @@ -111,13 +107,11 @@ def test_serialization(self): async def test_run_success(self, mock_async_conn, mock_get_waiter): mock_async_conn.return_value.__aenter__.return_value = "AVAILABLE" mock_get_waiter().wait = AsyncMock() - trigger = NeptuneGraphPrivateEndpointAvailableTrigger( - graph_id=GRAPH_ID, vpc_id=VPC_ID, endpoint_id=ENDPOINT_ID - ) + trigger = NeptuneGraphPrivateEndpointAvailableTrigger(graph_id=GRAPH_ID, vpc_id=VPC_ID) generator = trigger.run() resp = await generator.asend(None) - assert resp == TriggerEvent({"status": "success", "endpoint_id": ENDPOINT_ID}) + assert resp == TriggerEvent({"status": "success", "graph_id": GRAPH_ID}) assert mock_get_waiter().wait.call_count == 1 @pytest.mark.asyncio @@ -132,9 +126,7 @@ async def test_run_failure(self, mock_async_conn, mock_get_waiter): ) mock_get_waiter.return_value.wait = wait_mock - trigger 
= NeptuneGraphPrivateEndpointAvailableTrigger( - graph_id=GRAPH_ID, vpc_id=VPC_ID, endpoint_id=ENDPOINT_ID - ) + trigger = NeptuneGraphPrivateEndpointAvailableTrigger(graph_id=GRAPH_ID, vpc_id=VPC_ID) with pytest.raises(AirflowException): await trigger.run().asend(None) From 9524958b323ad2d7874f6d83b086ffb1d6b4a3e4 Mon Sep 17 00:00:00 2001 From: Mike Ellis Date: Thu, 26 Mar 2026 13:20:29 +0000 Subject: [PATCH 13/28] Updated system and unit tests, and operator fixes --- .../amazon/aws/operators/neptune_analytics.py | 62 ++++++++----------- .../amazon/aws/triggers/neptune_analytics.py | 10 +-- .../amazon/aws/example_neptune_analytics.py | 10 +++ .../aws/operators/test_neptune_analytics.py | 36 +++++------ .../aws/triggers/test_neptune_analytics.py | 12 ++-- 5 files changed, 65 insertions(+), 65 deletions(-) diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/neptune_analytics.py b/providers/amazon/src/airflow/providers/amazon/aws/operators/neptune_analytics.py index 2a062cbb03ea4..07ba592b7c2d0 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/operators/neptune_analytics.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/neptune_analytics.py @@ -60,7 +60,7 @@ class NeptuneCreateGraphOperator(AwsBaseOperator[NeptuneAnalyticsHook]): :param deletion_protection: Indicates whether or not to enable deletion protection on the graph. The graph can't be deleted when deletion protection is enabled. :param kms_key_id: Specifies a KMS key to use to encrypt data in the new graph. - :param tags Specifies metadata tags to add to the graph. + :param tags: Specifies metadata tags to add to the graph. :param wait_for_completion: Whether to wait for the graph to start. (default: True) :param deferrable: If True, the operator will wait asynchronously for the graph to start. This implies waiting for completion. This mode requires aiobotocore module to be installed. 
@@ -73,7 +73,6 @@ class NeptuneCreateGraphOperator(AwsBaseOperator[NeptuneAnalyticsHook]): empty, then default boto3 configuration would be used (and must be maintained on each worker node). :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used. - :param botocore_config: Configuration dictionary (key-values) for botocore client. See: https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html :return: dictionary with Neptune graph id and vpc id @@ -199,7 +198,6 @@ class NeptuneCreatePrivateGraphEndpointOperator(AwsBaseOperator[NeptuneAnalytics :param vpc_id: VPC to create endpoint in :param subnet_ids: Subnets in which private graph endpoint ENIs are created :param vpc_security_group_ids: Security groups to be attached to the private graph endpoint - :param wait_for_completion: Whether to wait for the endpoint to be available. (default: True) :param deferrable: If True, the operator will wait asynchronously for the endpoint to become available. This implies waiting for completion. This mode requires aiobotocore module to be installed. @@ -212,7 +210,6 @@ class NeptuneCreatePrivateGraphEndpointOperator(AwsBaseOperator[NeptuneAnalytics empty, then default boto3 configuration would be used (and must be maintained on each worker node). :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used. - :param botocore_config: Configuration dictionary (key-values) for botocore client. See: https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html :return: dictionary with Neptune graph id @@ -274,8 +271,6 @@ def execute(self, context: Context) -> dict: # if VPC not provided, use the one that is returned, which is the default VPC. 
Required for the waiter self.vpc_id = result.get("vpcId", self.vpc_id) - # TODO extra link to console - if self.deferrable: self.log.info("Deferring until endpoint is available") self.defer( @@ -338,12 +333,9 @@ class NeptuneDeletePrivateGraphEndpointOperator(AwsBaseOperator[NeptuneAnalytics :ref:`howto/operator:NeptuneDeletePrivateGraphEndpointOperator` :param graph_identifier: Neptune Graph id - :param vpc_id: VPC to create endpoint in - :param subnet_ids: Subnets in which private graph endpoint ENIs are created - :param vpc_security_group_ids: Security groups to be attached to the private graph endpoint - - :param wait_for_completion: Whether to wait for the endpoint to be available. (default: True) - :param deferrable: If True, the operator will wait asynchronously for the endpoint to become available. + :param vpc_id: VPC where endpoint resides + :param wait_for_completion: Whether to wait for the endpoint to be deleted. (default: True) + :param deferrable: If True, the operator will wait asynchronously for the endpoint to be deleted. This implies waiting for completion. This mode requires aiobotocore module to be installed. (default: False) :param waiter_delay: Time in seconds to wait between status checks. @@ -433,9 +425,9 @@ class NeptuneDeleteGraphOperator(AwsBaseOperator[NeptuneAnalyticsHook]): .. seealso:: For more information on how to use this operator, take a look at the guide: - :ref:`howto/operator:NeptuneCreateGraphOperator` + :ref:`howto/operator:NeptuneDeleteGraphOperator` - :param graph_id: Name of Neptune graph to create + :param graph_id: Name of Neptune graph to delete :param skip_snapshot: Determines whether a final graph snapshot is created before the graph is deleted. If true is specified, no graph snapshot is created. If false is specified, a graph snapshot is created before the graph is deleted. :param wait_for_completion: Whether to wait for the graph to delete. 
(default: True) :param deferrable: If True, the operator will wait asynchronously for the graph to be deleted. @@ -697,7 +689,7 @@ def defer_wait_for_task( self.log.info("Deferring for import task %s completion", import_task_id) self.defer( trigger=NeptuneImportTaskCompleteTrigger( - task_id=import_task_id, + import_task_id=import_task_id, waiter_delay=self.waiter_delay, waiter_max_attempts=self.waiter_max_attempts, aws_conn_id=self.aws_conn_id, @@ -723,7 +715,7 @@ class NeptuneStartImportTaskOperator(AwsBaseOperator[NeptuneAnalyticsHook]): :param source: URL identifying the source data location. :param blank_node_handling: Method to handle blank nodes in dataset. :param fail_on_error: If set to true, the task halts when an import error is encountered. If set to false, the task skips the data that caused the error and continues if possible. - :param format: Specifies the format of the Amazon S3 data to be ipmorted. + :param format: Specifies the format of the Amazon S3 data to be imported. :param import_options: Options on how to perform an import :param parquet_type: Parquet type of import task :param wait_for_completion: Whether to wait for the endpoint to be available. (default: True) @@ -738,7 +730,6 @@ class NeptuneStartImportTaskOperator(AwsBaseOperator[NeptuneAnalyticsHook]): empty, then default boto3 configuration would be used (and must be maintained on each worker node). :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used. - :param botocore_config: Configuration dictionary (key-values) for botocore client. 
See: https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html :return: dictionary with Neptune graph id @@ -827,7 +818,7 @@ def execute(self, context: Context) -> dict: self.log.info("Deferring until import task %s completes", import_task_id) self.defer( trigger=NeptuneImportTaskCompleteTrigger( - task_id=import_task_id, + import_task_id=import_task_id, waiter_delay=self.waiter_delay, waiter_max_attempts=self.waiter_max_attempts, aws_conn_id=self.aws_conn_id, @@ -842,14 +833,14 @@ def execute(self, context: Context) -> dict: WaiterConfig={"Delay": self.waiter_delay, "MaxAttempts": self.waiter_max_attempts}, ) - return {"task_id": import_task_id, "graph_id": self.graph_identifier} + return {"import_task_id": import_task_id, "graph_id": self.graph_identifier} def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> dict[str, Any]: task_id = "" if event: - task_id = event.get("task_id", "") + task_id = event.get("import_task_id", "") - return {"graph_id": self.graph_identifier, "task_id": task_id} + return {"graph_id": self.graph_identifier, "import_task_id": task_id} class NeptuneCancelImportTaskOperator(AwsBaseOperator[NeptuneAnalyticsHook]): @@ -860,7 +851,7 @@ class NeptuneCancelImportTaskOperator(AwsBaseOperator[NeptuneAnalyticsHook]): For more information on how to use this operator, take a look at the guide: :ref:`howto/operator:NeptuneCancelImportTaskOperator - :param task_identifier: Neptune Graph import task id to cancel. + :param import_task_id: Neptune Graph import task id to cancel. :param wait_for_completion: Whether to wait for the endpoint to be available. (default: True) :param deferrable: If True, the operator will wait asynchronously for the endpoint to become available. This implies waiting for completion. This mode requires aiobotocore module to be installed. 
@@ -873,7 +864,6 @@ class NeptuneCancelImportTaskOperator(AwsBaseOperator[NeptuneAnalyticsHook]): empty, then default boto3 configuration would be used (and must be maintained on each worker node). :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used. - :param botocore_config: Configuration dictionary (key-values) for botocore client. See: https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html :return: dictionary with Neptune graph id @@ -881,11 +871,11 @@ class NeptuneCancelImportTaskOperator(AwsBaseOperator[NeptuneAnalyticsHook]): """ aws_hook_class = NeptuneAnalyticsHook - template_fields: Sequence[str] = aws_template_fields("task_identifier") + template_fields: Sequence[str] = aws_template_fields("import_task_id") def __init__( self, - task_identifier: str, + import_task_id: str, wait_for_completion: bool = True, waiter_delay: int = 30, waiter_max_attempts: int = 60, @@ -893,24 +883,24 @@ def __init__( **kwargs, ): super().__init__(**kwargs) - self.task_identifier = task_identifier + self.import_task_id = import_task_id self.wait_for_completion = wait_for_completion self.waiter_delay = waiter_delay self.waiter_max_attempts = waiter_max_attempts self.deferrable = deferrable def execute(self, context: Context) -> dict: - self.log.info("Cancelling import task %s", self.task_identifier) + self.log.info("Cancelling import task %s", self.import_task_id) - response = self.hook.conn.cancel_import_task(taskIdentifier=self.task_identifier) + response = self.hook.conn.cancel_import_task(taskIdentifier=self.import_task_id) - self.log.info("Import task %s status is %s", self.task_identifier, response.get("status", "Unknown")) + self.log.info("Import task %s status is %s", self.import_task_id, response.get("status", "Unknown")) if self.deferrable: - self.log.info("Deferring until import task %s is cancelled", self.task_identifier) + self.log.info("Deferring until import task %s is cancelled", 
self.import_task_id) self.defer( trigger=NeptuneImportTaskCancelledTrigger( - task_identifier=self.task_identifier, + import_task_id=self.import_task_id, waiter_delay=self.waiter_delay, waiter_max_attempts=self.waiter_max_attempts, aws_conn_id=self.aws_conn_id, @@ -919,18 +909,18 @@ def execute(self, context: Context) -> dict: ) if self.wait_for_completion: - self.log.info("Waiting for import task %s to be cancelled", self.task_identifier) + self.log.info("Waiting for import task %s to be cancelled", self.import_task_id) self.hook.get_waiter("import_task_cancelled").wait( - taskIdentifier=self.task_identifier, + taskIdentifier=self.import_task_id, WaiterConfig={"Delay": self.waiter_delay, "MaxAttempts": self.waiter_max_attempts}, ) - return {"task_identifier": self.task_identifier} + return {"import_task_id": self.import_task_id} def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> dict[str, Any]: task_id = "" if event: - task_id = event.get("task_identifier", "") + task_id = event.get("import_task_id", "") self.log.info("Import task %s cancelled", task_id) - return {"task_identifier": task_id} + return {"import_task_id": task_id} diff --git a/providers/amazon/src/airflow/providers/amazon/aws/triggers/neptune_analytics.py b/providers/amazon/src/airflow/providers/amazon/aws/triggers/neptune_analytics.py index 8deeb2d63faa2..7f0be3b7ae09c 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/triggers/neptune_analytics.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/triggers/neptune_analytics.py @@ -212,20 +212,20 @@ class NeptuneImportTaskCompleteTrigger(AwsBaseWaiterTrigger): def __init__( self, *, - task_id: str, + import_task_id: str, waiter_delay: int = 30, waiter_max_attempts: int = 60, **kwargs, ) -> None: super().__init__( - serialized_fields={"task_id": task_id}, + serialized_fields={"import_task_id": import_task_id}, waiter_name="import_task_successful", - waiter_args={"taskIdentifier": task_id}, + 
waiter_args={"taskIdentifier": import_task_id}, failure_message="Import task failed", status_message="Status of import task is", status_queries=["status"], - return_key="task_id", - return_value=task_id, + return_key="import_task_id", + return_value=import_task_id, waiter_delay=waiter_delay, waiter_max_attempts=waiter_max_attempts, **kwargs, diff --git a/providers/amazon/tests/system/amazon/aws/example_neptune_analytics.py b/providers/amazon/tests/system/amazon/aws/example_neptune_analytics.py index 2b5050c34b899..d55e3e0790418 100644 --- a/providers/amazon/tests/system/amazon/aws/example_neptune_analytics.py +++ b/providers/amazon/tests/system/amazon/aws/example_neptune_analytics.py @@ -24,6 +24,7 @@ from airflow.providers.amazon.aws.hooks.neptune_analytics import NeptuneAnalyticsHook from airflow.providers.amazon.aws.operators.neptune_analytics import ( + NeptuneCancelImportTaskOperator, NeptuneCreateGraphOperator, NeptuneCreateGraphWithImportOperator, NeptuneCreatePrivateGraphEndpointOperator, @@ -228,6 +229,15 @@ def delete_graph_if_exists(graph_name: str) -> None: ) # [END howto_operator_neptune_analytics_start_import_task] + # [START howto_operator_neptune_analytics_cancel_import_task] + cancel_import = NeptuneCancelImportTaskOperator( + task_id="cancel_import", + import_task_id="{{ ti.xcom_pull(task_ids='start_import')['import_task_id']}}", + wait_for_completion=True, + aws_conn_id="aws_default", + ) + # [END howto_operator_neptune_analytics_cancel_import_task] + # [START howto_operator_neptune_analytics_delete_graph] delete_graph = NeptuneDeletePrivateGraphEndpointOperator( task_id="delete_graph", diff --git a/providers/amazon/tests/unit/amazon/aws/operators/test_neptune_analytics.py b/providers/amazon/tests/unit/amazon/aws/operators/test_neptune_analytics.py index 3545f4e1a73c8..cb10d29d3dee6 100644 --- a/providers/amazon/tests/unit/amazon/aws/operators/test_neptune_analytics.py +++ 
b/providers/amazon/tests/unit/amazon/aws/operators/test_neptune_analytics.py @@ -1104,7 +1104,7 @@ def test_start_import_no_wait(self, mock_get_waiter, mock_conn, mock_persist): result = operator.execute(None) mock_get_waiter.assert_not_called() - assert result == {"task_id": TASK_ID, "graph_id": GRAPH_ID} + assert result == {"import_task_id": TASK_ID, "graph_id": GRAPH_ID} @mock.patch("airflow.providers.amazon.aws.operators.neptune_analytics.NeptuneImportTaskLink.persist") @mock.patch.object(NeptuneAnalyticsHook, "conn") @@ -1126,7 +1126,7 @@ def test_start_import_wait_for_completion(self, mock_get_waiter, mock_conn, mock result = operator.execute(None) mock_get_waiter.assert_called_once_with("import_task_successful") - assert result == {"task_id": TASK_ID, "graph_id": GRAPH_ID} + assert result == {"import_task_id": TASK_ID, "graph_id": GRAPH_ID} @mock.patch.object(NeptuneAnalyticsHook, "conn") def test_persist_called_with_correct_args(self, mock_conn): @@ -1194,10 +1194,10 @@ def test_execute_complete_success(self): source=SOURCE_S3_URI, ) - event = {"status": "success", "task_id": TASK_ID} + event = {"status": "success", "import_task_id": TASK_ID} result = operator.execute_complete(None, event) - assert result == {"graph_id": GRAPH_ID, "task_id": TASK_ID} + assert result == {"graph_id": GRAPH_ID, "import_task_id": TASK_ID} def test_execute_complete_no_event(self): operator = NeptuneStartImportTaskOperator( @@ -1209,7 +1209,7 @@ def test_execute_complete_no_event(self): result = operator.execute_complete(None, None) - assert result == {"graph_id": GRAPH_ID, "task_id": ""} + assert result == {"graph_id": GRAPH_ID, "import_task_id": ""} class TestNeptuneCancelImportTaskOperator: @@ -1223,10 +1223,10 @@ def test_init_defaults(self, mock_conn): operator = NeptuneCancelImportTaskOperator( task_id="test_task", - task_identifier=TASK_ID, + import_task_id=TASK_ID, ) - assert operator.task_identifier == TASK_ID + assert operator.import_task_id == TASK_ID assert 
operator.wait_for_completion is True assert operator.waiter_delay == 30 assert operator.waiter_max_attempts == 60 @@ -1245,12 +1245,12 @@ def test_init_custom_args(self, mock_conn): operator = NeptuneCancelImportTaskOperator( task_id="test_task", - task_identifier=TASK_ID, + import_task_id=TASK_ID, waiter_delay=60, waiter_max_attempts=100, ) - assert operator.task_identifier == TASK_ID + assert operator.import_task_id == TASK_ID assert operator.waiter_delay == 60 assert operator.waiter_max_attempts == 100 @@ -1265,13 +1265,13 @@ def test_cancel_no_wait(self, mock_get_waiter, mock_conn): operator = NeptuneCancelImportTaskOperator( task_id="test_task", - task_identifier=TASK_ID, + import_task_id=TASK_ID, wait_for_completion=False, ) result = operator.execute(None) mock_get_waiter.assert_not_called() - assert result == {"task_identifier": TASK_ID} + assert result == {"import_task_id": TASK_ID} @mock.patch.object(NeptuneAnalyticsHook, "conn") @mock.patch.object(NeptuneAnalyticsHook, "get_waiter") @@ -1284,31 +1284,31 @@ def test_cancel_wait_for_completion(self, mock_get_waiter, mock_conn): operator = NeptuneCancelImportTaskOperator( task_id="test_task", - task_identifier=TASK_ID, + import_task_id=TASK_ID, wait_for_completion=True, ) result = operator.execute(None) mock_get_waiter.assert_called_once_with("import_task_cancelled") - assert result == {"task_identifier": TASK_ID} + assert result == {"import_task_id": TASK_ID} def test_execute_complete_success(self): operator = NeptuneCancelImportTaskOperator( task_id="test_task", - task_identifier=TASK_ID, + import_task_id=TASK_ID, ) - event = {"status": "success", "task_identifier": TASK_ID} + event = {"status": "success", "import_task_id": TASK_ID} result = operator.execute_complete(None, event) - assert result == {"task_identifier": TASK_ID} + assert result == {"import_task_id": TASK_ID} def test_execute_complete_no_event(self): operator = NeptuneCancelImportTaskOperator( task_id="test_task", - task_identifier=TASK_ID, + 
import_task_id=TASK_ID, ) result = operator.execute_complete(None, None) - assert result == {"task_identifier": ""} + assert result == {"import_task_id": ""} diff --git a/providers/amazon/tests/unit/amazon/aws/triggers/test_neptune_analytics.py b/providers/amazon/tests/unit/amazon/aws/triggers/test_neptune_analytics.py index 04d6cd99eb844..ddb85c4ddef58 100644 --- a/providers/amazon/tests/unit/amazon/aws/triggers/test_neptune_analytics.py +++ b/providers/amazon/tests/unit/amazon/aws/triggers/test_neptune_analytics.py @@ -194,14 +194,14 @@ def test_serialization(self): Asserts that the NeptuneImportTaskCompleteTrigger correctly serializes its arguments and classpath. """ - trigger = NeptuneImportTaskCompleteTrigger(task_id=TASK_ID) + trigger = NeptuneImportTaskCompleteTrigger(import_task_id=TASK_ID) classpath, kwargs = trigger.serialize() assert ( classpath == "airflow.providers.amazon.aws.triggers.neptune_analytics.NeptuneImportTaskCompleteTrigger" ) - assert "task_id" in kwargs - assert kwargs["task_id"] == TASK_ID + assert "import_task_id" in kwargs + assert kwargs["import_task_id"] == TASK_ID @pytest.mark.asyncio @mock.patch("airflow.providers.amazon.aws.hooks.neptune_analytics.NeptuneAnalyticsHook.get_waiter") @@ -209,11 +209,11 @@ def test_serialization(self): async def test_run_success(self, mock_async_conn, mock_get_waiter): mock_async_conn.return_value.__aenter__.return_value = "COMPLETED" mock_get_waiter().wait = AsyncMock() - trigger = NeptuneImportTaskCompleteTrigger(task_id=TASK_ID) + trigger = NeptuneImportTaskCompleteTrigger(import_task_id=TASK_ID) generator = trigger.run() resp = await generator.asend(None) - assert resp == TriggerEvent({"status": "success", "task_id": TASK_ID}) + assert resp == TriggerEvent({"status": "success", "import_task_id": TASK_ID}) assert mock_get_waiter().wait.call_count == 1 @pytest.mark.asyncio @@ -228,7 +228,7 @@ async def test_run_failure(self, mock_async_conn, mock_get_waiter): ) mock_get_waiter.return_value.wait = 
wait_mock - trigger = NeptuneImportTaskCompleteTrigger(task_id=TASK_ID) + trigger = NeptuneImportTaskCompleteTrigger(import_task_id=TASK_ID) with pytest.raises(AirflowException): await trigger.run().asend(None) From 46fa89ff2a6999ee9960710b4484bd7bf0794264 Mon Sep 17 00:00:00 2001 From: Mike Ellis Date: Thu, 26 Mar 2026 23:31:46 +0000 Subject: [PATCH 14/28] Added Neptune Analytics docs --- docs/spelling_wordlist.txt | 8 + .../docs/operators/neptune_analytics.rst | 148 ++++++++++++++++++ .../amazon/aws/operators/neptune_analytics.py | 2 +- .../amazon/aws/example_neptune_analytics.py | 13 +- 4 files changed, 169 insertions(+), 2 deletions(-) create mode 100644 providers/amazon/docs/operators/neptune_analytics.rst diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index f98d6edde306e..e700fb5c408ba 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -304,6 +304,7 @@ ContainerPort contentUrl contextmgr contrib +convertToIri copyable CoreV coroutine @@ -439,6 +440,7 @@ deidentify DeidentifyTemplate del delim +deliverability deltalake denylist dep @@ -988,6 +990,7 @@ longblob lookups lshift lxml +m-NCUs machineTypes macOS mae @@ -1073,6 +1076,7 @@ nat natively nav navbar +NCUs nd ndjson nearText @@ -1101,9 +1105,11 @@ NotFound notificationChannels notin npm +nquads ns ntlm ntpd +ntriples Nullable nullable num @@ -1128,6 +1134,7 @@ Oozie OpenAI openai openapi +opencypher openfaas OpenID openlineage @@ -1321,6 +1328,7 @@ RaG RBAC rbac rc +rdfxml RDS rds readme diff --git a/providers/amazon/docs/operators/neptune_analytics.rst b/providers/amazon/docs/operators/neptune_analytics.rst new file mode 100644 index 0000000000000..5791dadbd631c --- /dev/null +++ b/providers/amazon/docs/operators/neptune_analytics.rst @@ -0,0 +1,148 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. 
See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +======================== +Amazon Neptune Analytics +======================== + +`Amazon Neptune Analytics `__ is a memory-optimized graph database engine for analytics. With Neptune Analytics, you can get insights and find trends by processing large amounts of graph data in seconds. + +Prerequisite Tasks +------------------ + +.. include:: ../_partials/prerequisite_tasks.rst + +Generic Parameters +------------------ + +.. include:: ../_partials/generic_parameters.rst + +Operators +--------- + +.. _howto/operator:NeptuneCreateGraphOperator: + +Create a new Neptune Graph +========================== + +To create a new Neptune Analytics Graph, you can use +:class:`~airflow.providers.amazon.aws.operators.neptune.NeptuneCreateGraphOperator`. +This operator can be run in deferrable mode by passing ``deferrable=True`` as a parameter. This requires +the aiobotocore module to be installed. + +.. exampleinclude:: /../../amazon/tests/system/amazon/aws/example_neptune_analytics.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_neptune_analytics_create_graph] + :end-before: [END howto_operator_neptune_analytics_create_graph] + + +.. 
_howto/operator:NeptuneDeleteGraphOperator: + +Delete a Neptune Graph +====================== + +To delete an existing Neptune Analytics Graph, you can use +:class:`~airflow.providers.amazon.aws.operators.neptune.NeptuneDeleteGraphOperator`. +This operator can be run in deferrable mode by passing ``deferrable=True`` as a parameter. This requires +the aiobotocore module to be installed. + +.. exampleinclude:: /../../amazon/tests/system/amazon/aws/example_neptune_analytics.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_neptune_analytics_delete_graph] + :end-before: [END howto_operator_neptune_analytics_delete_graph] + +.. _howto/operator:NeptuneCreatePrivateGraphEndpointOperator: + +Create a Neptune Graph private endpoint +======================================= + +To create a VPC Endpoint for connecting to an existing Neptune Graph, you can use +:class:`~airflow.providers.amazon.aws.operators.neptune.NeptuneCreatePrivateGraphEndpointOperator`. +This operator can be run in deferrable mode by passing ``deferrable=True`` as a parameter. This requires +the aiobotocore module to be installed. + +.. exampleinclude:: /../../amazon/tests/system/amazon/aws/example_neptune_analytics.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_neptune_analytics_create_private_endpoint] + :end-before: [END howto_operator_neptune_analytics_create_private_endpoint] + +.. _howto/operator:NeptuneDeletePrivateGraphEndpointOperator: + +Delete a Neptune Graph private endpoint +======================================= + +To delete a VPC Endpoint attached to an existing Neptune Graph, you can use +:class:`~airflow.providers.amazon.aws.operators.neptune.NeptuneDeletePrivateGraphEndpointOperator`. +This operator can be run in deferrable mode by passing ``deferrable=True`` as a parameter. This requires +the aiobotocore module to be installed. + +.. 
exampleinclude:: /../../amazon/tests/system/amazon/aws/example_neptune_analytics.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_neptune_analytics_delete_private_endpoint] + :end-before: [END howto_operator_neptune_analytics_delete_private_endpoint] + +.. _howto/operator:NeptuneCreateGraphWithImportOperator: + +Create a Neptune Graph with a data import task +============================================== + +To create a Neptune Analytics Graph and immediately import data, you can use +:class:`~airflow.providers.amazon.aws.operators.neptune.NeptuneCreateGraphWithImportOperator`. +This operator can be run in deferrable mode by passing ``deferrable=True`` as a parameter. This requires +the aiobotocore module to be installed. + +.. exampleinclude:: /../../amazon/tests/system/amazon/aws/example_neptune_analytics.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_neptune_analytics_create_graph_with_import] + :end-before: [END howto_operator_neptune_analytics_create_graph_with_import] + +.. _howto/operator:NeptuneStartImportTaskOperator: + +Import data into an existing Neptune Graph +========================================== + +To import data into an existing Neptune Analytics Graph, you can use +:class:`~airflow.providers.amazon.aws.operators.neptune.NeptuneStartImportTaskOperator`. +This operator can be run in deferrable mode by passing ``deferrable=True`` as a parameter. This requires +the aiobotocore module to be installed. + +.. exampleinclude:: /../../amazon/tests/system/amazon/aws/example_neptune_analytics.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_neptune_analytics_start_import_task] + :end-before: [END howto_operator_neptune_analytics_start_import_task] + +.. 
_howto/operator:NeptuneCancelImportTaskOperator: + +Cancel a running import task +============================ + +To cancel an existing import task, you can use +:class:`~airflow.providers.amazon.aws.operators.neptune.NeptuneCancelImportTaskOperator`. +This operator can be run in deferrable mode by passing ``deferrable=True`` as a parameter. This requires +the aiobotocore module to be installed. + +.. exampleinclude:: /../../amazon/tests/system/amazon/aws/example_neptune_analytics.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_neptune_analytics_cancel_import_task] + :end-before: [END howto_operator_neptune_analytics_cancel_import_task] diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/neptune_analytics.py b/providers/amazon/src/airflow/providers/amazon/aws/operators/neptune_analytics.py index 07ba592b7c2d0..bad7b1e35911d 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/operators/neptune_analytics.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/neptune_analytics.py @@ -849,7 +849,7 @@ class NeptuneCancelImportTaskOperator(AwsBaseOperator[NeptuneAnalyticsHook]): .. seealso:: For more information on how to use this operator, take a look at the guide: - :ref:`howto/operator:NeptuneCancelImportTaskOperator + :ref:`howto/operator:NeptuneCancelImportTaskOperator` :param import_task_id: Neptune Graph import task id to cancel. :param wait_for_completion: Whether to wait for the endpoint to be available. 
(default: True) diff --git a/providers/amazon/tests/system/amazon/aws/example_neptune_analytics.py b/providers/amazon/tests/system/amazon/aws/example_neptune_analytics.py index d55e3e0790418..1f26d4990ce92 100644 --- a/providers/amazon/tests/system/amazon/aws/example_neptune_analytics.py +++ b/providers/amazon/tests/system/amazon/aws/example_neptune_analytics.py @@ -211,9 +211,20 @@ def delete_graph_if_exists(graph_name: str) -> None: graph_identifier="{{ ti.xcom_pull(task_ids='create_graph)['graph_id']}}", wait_for_completion=True, ) - # [END howto_operator_neptune_analytics_create_private_endpoint] + # [START howto_operator_neptune_analytics_delete_private_endpoint] + delete_endpoint = NeptuneDeletePrivateGraphEndpointOperator( + task_id="delete_endpoint", + graph_identifier="{{ ti.xcom_pull(task_ids='create_graph')['graph_id'] }}", + vpc_id="{{ ti.xcom_pull(task_ids='create_endpoint')['vpc_id'] }}", + wait_for_completion=True, + deferrable=False, + waiter_delay=30, + waiter_max_attempts=60, + ) + # [END howto_operator_neptune_analytics_delete_private_endpoint] + # [START howto_operator_neptune_analytics_start_import_task] start_import = NeptuneStartImportTaskOperator( task_id="start_import", From b2deb20a602add62f544473c807218a78e19db13 Mon Sep 17 00:00:00 2001 From: Mike Ellis Date: Thu, 26 Mar 2026 23:40:35 +0000 Subject: [PATCH 15/28] added console links to NeptuneCreateGraphWithImportOperator --- .../amazon/aws/operators/neptune_analytics.py | 34 +++++++++++++++++-- 1 file changed, 32 insertions(+), 2 deletions(-) diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/neptune_analytics.py b/providers/amazon/src/airflow/providers/amazon/aws/operators/neptune_analytics.py index bad7b1e35911d..7df4ed689170a 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/operators/neptune_analytics.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/neptune_analytics.py @@ -651,8 +651,38 @@ def execute(self, context: Context) -> 
dict: self.graph_id = response.get("graphId", None) import_task_id = response.get("taskId") - # TODO build extra link to console - # TODO - second defer for task completion. + graph_url = NeptuneGraphLink.format_str.format( + graph_id=self.graph_id, + aws_domain=NeptuneGraphLink.get_aws_domain(self.hook.conn_partition), + region_name=self.hook.conn_region_name, + ) + + NeptuneGraphLink.persist( + context=context, + operator=self, + region_name=self.hook.conn_region_name, + aws_partition=self.hook.conn_partition, + graph_id=self.graph_id, + ) + + import_task_url = NeptuneImportTaskLink.format_str.format( + import_task_id=import_task_id, + aws_domain=NeptuneImportTaskLink.get_aws_domain(self.hook.conn_partition), + region_name=self.hook.conn_region_name, + ) + + NeptuneImportTaskLink.persist( + context=context, + operator=self, + region_name=self.hook.conn_region_name, + aws_partition=self.hook.conn_partition, + import_task_id=import_task_id, + ) + + self.log.info("You can view this import task at : %s", import_task_url) + + self.log.info("You can view this Neptune Graph at : %s", graph_url) + if self.deferrable: self.log.info("Deferring until graph %s is available", self.graph_id) self.defer( From 7f8c3a892ac36c8279d603a35f0eb70b5fba1fa8 Mon Sep 17 00:00:00 2001 From: Mike Ellis Date: Thu, 26 Mar 2026 23:43:29 +0000 Subject: [PATCH 16/28] Fixed missing operator_extra_links assignment --- .../providers/amazon/aws/operators/neptune_analytics.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/neptune_analytics.py b/providers/amazon/src/airflow/providers/amazon/aws/operators/neptune_analytics.py index 7df4ed689170a..4d48c4fabea35 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/operators/neptune_analytics.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/neptune_analytics.py @@ -560,6 +560,11 @@ class 
NeptuneCreateGraphWithImportOperator(AwsBaseOperator[NeptuneAnalyticsHook] "vector_search_config": "json", } + operator_extra_links = ( + NeptuneImportTaskLink(), + NeptuneGraphLink(), + ) + def __init__( self, graph_name: str, From 95c0b464cc00c12857d9993b4af38aefdabb469f Mon Sep 17 00:00:00 2001 From: ellisms <114107920+ellisms@users.noreply.github.com> Date: Mon, 6 Apr 2026 15:14:00 -0400 Subject: [PATCH 17/28] Update providers/amazon/tests/unit/amazon/aws/triggers/test_neptune_analytics.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- .../tests/unit/amazon/aws/triggers/test_neptune_analytics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/providers/amazon/tests/unit/amazon/aws/triggers/test_neptune_analytics.py b/providers/amazon/tests/unit/amazon/aws/triggers/test_neptune_analytics.py index ddb85c4ddef58..8002b95008332 100644 --- a/providers/amazon/tests/unit/amazon/aws/triggers/test_neptune_analytics.py +++ b/providers/amazon/tests/unit/amazon/aws/triggers/test_neptune_analytics.py @@ -41,7 +41,7 @@ class TestNeptuneGraphAvailableTrigger: def test_serialization(self): """ - Asserts that the TaskStateTrigger correctly serializes its arguments + Asserts that the NeptuneGraphAvailableTrigger correctly serializes its arguments and classpath. 
""" trigger = NeptuneGraphAvailableTrigger(graph_id=GRAPH_ID) From 71b9a27a25da75dc31a43a10df09d7cef909e166 Mon Sep 17 00:00:00 2001 From: ellisms <114107920+ellisms@users.noreply.github.com> Date: Mon, 6 Apr 2026 15:15:12 -0400 Subject: [PATCH 18/28] Update providers/amazon/src/airflow/providers/amazon/aws/operators/neptune_analytics.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- .../airflow/providers/amazon/aws/operators/neptune_analytics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/neptune_analytics.py b/providers/amazon/src/airflow/providers/amazon/aws/operators/neptune_analytics.py index 4d48c4fabea35..aa0902d157079 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/operators/neptune_analytics.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/neptune_analytics.py @@ -503,7 +503,7 @@ def execute_complete(self, context: Context, event: dict[str, Any] | None = None if event: graph_id = event.get("graph_id", "Unknown") - self.log.info("Neptune graph % deleted", graph_id) + self.log.info("Neptune graph %s deleted", graph_id) class NeptuneCreateGraphWithImportOperator(AwsBaseOperator[NeptuneAnalyticsHook]): From 4c184866009219c1e3965d1bd08ad6508b204055 Mon Sep 17 00:00:00 2001 From: Mike Ellis Date: Mon, 6 Apr 2026 19:57:52 +0000 Subject: [PATCH 19/28] Fixed issues found in CI --- .../amazon/aws/operators/neptune_analytics.py | 2 +- .../amazon/aws/example_neptune_analytics.py | 2 +- .../aws/operators/test_neptune_analytics.py | 27 +++++-------------- 3 files changed, 9 insertions(+), 22 deletions(-) diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/neptune_analytics.py b/providers/amazon/src/airflow/providers/amazon/aws/operators/neptune_analytics.py index aa0902d157079..9f360a33512ed 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/operators/neptune_analytics.py +++ 
b/providers/amazon/src/airflow/providers/amazon/aws/operators/neptune_analytics.py @@ -935,7 +935,7 @@ def execute(self, context: Context) -> dict: self.log.info("Deferring until import task %s is cancelled", self.import_task_id) self.defer( trigger=NeptuneImportTaskCancelledTrigger( - import_task_id=self.import_task_id, + task_identifier=self.import_task_id, waiter_delay=self.waiter_delay, waiter_max_attempts=self.waiter_max_attempts, aws_conn_id=self.aws_conn_id, diff --git a/providers/amazon/tests/system/amazon/aws/example_neptune_analytics.py b/providers/amazon/tests/system/amazon/aws/example_neptune_analytics.py index 1f26d4990ce92..1e257e844a49b 100644 --- a/providers/amazon/tests/system/amazon/aws/example_neptune_analytics.py +++ b/providers/amazon/tests/system/amazon/aws/example_neptune_analytics.py @@ -250,7 +250,7 @@ def delete_graph_if_exists(graph_name: str) -> None: # [END howto_operator_neptune_analytics_cancel_import_task] # [START howto_operator_neptune_analytics_delete_graph] - delete_graph = NeptuneDeletePrivateGraphEndpointOperator( + delete_graph = NeptuneDeleteGraphOperator( task_id="delete_graph", graph_id="{{ ti.xcom_pull(task_ids='create_graph')['graph_id'] }}", skip_snapshot=True, diff --git a/providers/amazon/tests/unit/amazon/aws/operators/test_neptune_analytics.py b/providers/amazon/tests/unit/amazon/aws/operators/test_neptune_analytics.py index cb10d29d3dee6..5f2d99c4877f2 100644 --- a/providers/amazon/tests/unit/amazon/aws/operators/test_neptune_analytics.py +++ b/providers/amazon/tests/unit/amazon/aws/operators/test_neptune_analytics.py @@ -17,13 +17,13 @@ # under the License. 
from __future__ import annotations -from collections.abc import Generator from unittest import mock import pytest -from moto import mock_aws +from botocore.exceptions import ClientError from airflow.providers.amazon.aws.hooks.neptune_analytics import NeptuneAnalyticsHook +from airflow.providers.amazon.aws.links.neptune_analytics import NeptuneGraphLink from airflow.providers.amazon.aws.operators.neptune_analytics import ( NeptuneCancelImportTaskOperator, NeptuneCreateGraphOperator, @@ -33,7 +33,11 @@ NeptuneDeletePrivateGraphEndpointOperator, NeptuneStartImportTaskOperator, ) -from airflow.providers.common.compat.sdk import TaskDeferred +from airflow.providers.amazon.aws.triggers.neptune_analytics import ( + NeptuneGraphAvailableTrigger, + NeptuneImportTaskCompleteTrigger, +) +from airflow.providers.common.compat.sdk import AirflowException, TaskDeferred GRAPH_NAME = "test_graph" GRAPH_ID = "test-graph-id" @@ -45,12 +49,6 @@ ROLE_ARN = "arn:aws:iam::123456789012:role/NeptuneImportRole" -@pytest.fixture -def hook() -> Generator[NeptuneAnalyticsHook, None, None]: - with mock_aws(): - yield NeptuneAnalyticsHook(aws_conn_id="aws_default") - - class TestNeptuneCreateGraphOperator: def test_template_fields(self): # Verify template_fields includes the expected fields @@ -63,7 +61,6 @@ def test_template_fields_renderers(self): assert NeptuneCreateGraphOperator.template_fields_renderers == {"vector_search_config": "json"} def test_operator_extra_links(self): - from airflow.providers.amazon.aws.links.neptune_analytics import NeptuneGraphLink assert len(NeptuneCreateGraphOperator.operator_extra_links) == 1 assert isinstance(NeptuneCreateGraphOperator.operator_extra_links[0], NeptuneGraphLink) @@ -200,8 +197,6 @@ def test_persist_called_with_correct_args(self, mock_conn): @mock.patch.object(NeptuneAnalyticsHook, "conn") def test_deferrable_defers_with_graph_available_trigger(self, mock_conn, mock_persist): """Test that deferrable mode defers with 
NeptuneGraphAvailableTrigger.""" - from airflow.providers.amazon.aws.triggers.neptune_analytics import NeptuneGraphAvailableTrigger - mock_conn.create_graph.return_value = {"id": GRAPH_ID, "status": "CREATING"} operator = NeptuneCreateGraphOperator( @@ -373,7 +368,6 @@ def test_create_endpoint_sets_vpc_id_from_response(self, mock_conn): @mock.patch.object(NeptuneAnalyticsHook, "conn") def test_create_endpoint_failed_status(self, mock_conn): - from airflow.providers.common.compat.sdk import AirflowException mock_conn.create_private_graph_endpoint.return_value = { "status": "FAILED", @@ -658,7 +652,6 @@ def test_delete_graph_wait_for_completion(self, mock_conn): @mock.patch.object(NeptuneAnalyticsHook, "conn") def test_delete_graph_resource_not_found(self, mock_conn): - from botocore.exceptions import ClientError # Simulate ResourceNotFoundException error_response = { @@ -688,9 +681,6 @@ def test_delete_graph_resource_not_found(self, mock_conn): @mock.patch.object(NeptuneAnalyticsHook, "conn") def test_delete_graph_other_client_error(self, mock_conn): - from botocore.exceptions import ClientError - - from airflow.providers.common.compat.sdk import AirflowException # Simulate other ClientError error_response = { @@ -932,7 +922,6 @@ def test_import_options_none_values_filtered(self, mock_conn): @mock.patch.object(NeptuneAnalyticsHook, "conn") def test_defer_wait_for_task(self, mock_conn): """Test that defer_wait_for_task defers with the import task trigger.""" - from airflow.providers.amazon.aws.triggers.neptune_analytics import NeptuneImportTaskCompleteTrigger operator = NeptuneCreateGraphWithImportOperator( task_id="test_task", @@ -958,7 +947,6 @@ def test_defer_wait_for_task(self, mock_conn): @mock.patch.object(NeptuneAnalyticsHook, "conn") def test_deferrable_defers_with_graph_available_trigger(self, mock_conn): """Test that execute defers with graph_available trigger and passes import_task_id.""" - from airflow.providers.amazon.aws.triggers.neptune_analytics 
import NeptuneGraphAvailableTrigger mock_conn.create_graph_using_import_task.return_value = { "graphId": GRAPH_ID, @@ -1163,7 +1151,6 @@ def test_persist_called_with_correct_args(self, mock_conn): @mock.patch.object(NeptuneAnalyticsHook, "conn") def test_deferrable_defers_with_import_task_trigger(self, mock_conn, mock_persist): """Test that deferrable mode defers with NeptuneImportTaskCompleteTrigger.""" - from airflow.providers.amazon.aws.triggers.neptune_analytics import NeptuneImportTaskCompleteTrigger mock_conn.start_import_task.return_value = { "taskId": TASK_ID, From ee4155f36da5871fc0214750b657fbe05a3a5ec2 Mon Sep 17 00:00:00 2001 From: Mike Ellis Date: Tue, 7 Apr 2026 13:08:05 +0000 Subject: [PATCH 20/28] Fixed broken test and CI failures --- providers/amazon/provider.yaml | 6 ++++ .../providers/amazon/get_provider_info.py | 7 ++++ .../aws/triggers/test_neptune_analytics.py | 36 +++++++++++++------ 3 files changed, 38 insertions(+), 11 deletions(-) diff --git a/providers/amazon/provider.yaml b/providers/amazon/provider.yaml index 1197b763ef8fb..2e103c0144e22 100644 --- a/providers/amazon/provider.yaml +++ b/providers/amazon/provider.yaml @@ -402,6 +402,12 @@ integrations: how-to-guide: - /docs/apache-airflow-providers-amazon/operators/mwaa.rst tags: [aws] + - integration-name: Amazon Neptune Analytics + external-doc-url: https://docs.aws.amazon.com/neptune-analytics/latest/userguide/gettingStarted.html + logo: /docs/integration-logos/Amazon-Neptune_64.png + how-to-guide: + - /docs/apache-airflow-providers-amazon/operators/neptune-analytics.rst + tags: [aws] - integration-name: Amazon S3 Vectors external-doc-url: https://docs.aws.amazon.com/AmazonS3/latest/userguide/s3-vectors.html logo: /docs/integration-logos/Amazon-Simple-Storage-Service-S3_light-bg@4x.png diff --git a/providers/amazon/src/airflow/providers/amazon/get_provider_info.py b/providers/amazon/src/airflow/providers/amazon/get_provider_info.py index 3e2d59ef64f6c..3d7d53b4f4111 100644 --- 
a/providers/amazon/src/airflow/providers/amazon/get_provider_info.py +++ b/providers/amazon/src/airflow/providers/amazon/get_provider_info.py @@ -379,6 +379,13 @@ def get_provider_info(): "how-to-guide": ["/docs/apache-airflow-providers-amazon/operators/s3_vectors.rst"], "tags": ["aws"], }, + { + "integration-name": "Amazon Neptune Analytics", + "external-doc-url": "https://docs.aws.amazon.com/neptune-analytics/latest/userguide/gettingStarted.html", + "logo": "/docs/integration-logos/Amazon-Neptune_64.png", + "how-to-guide": ["/docs/apache-airflow-providers-amazon/operators/neptune-analytics.rst"], + "tags": ["aws"], + }, ], "operators": [ { diff --git a/providers/amazon/tests/unit/amazon/aws/triggers/test_neptune_analytics.py b/providers/amazon/tests/unit/amazon/aws/triggers/test_neptune_analytics.py index 8002b95008332..f03693dd5a068 100644 --- a/providers/amazon/tests/unit/amazon/aws/triggers/test_neptune_analytics.py +++ b/providers/amazon/tests/unit/amazon/aws/triggers/test_neptune_analytics.py @@ -29,7 +29,6 @@ NeptuneImportTaskCancelledTrigger, NeptuneImportTaskCompleteTrigger, ) -from airflow.providers.common.compat.sdk import AirflowException from airflow.triggers.base import TriggerEvent GRAPH_ID = "test-graph" @@ -79,9 +78,12 @@ async def test_run_failure(self, mock_async_conn, mock_get_waiter): mock_get_waiter.return_value.wait = wait_mock trigger = NeptuneGraphAvailableTrigger(graph_id=GRAPH_ID) + generator = trigger.run() + resp = await generator.asend(None) - with pytest.raises(AirflowException): - await trigger.run().asend(None) + assert resp.payload["status"] == "error" + assert resp.payload["graph_id"] == GRAPH_ID + assert "Failed to create Neptune graph" in resp.payload["message"] class TestNeptuneGraphPrivateEndpointAvailableTrigger: @@ -127,9 +129,12 @@ async def test_run_failure(self, mock_async_conn, mock_get_waiter): mock_get_waiter.return_value.wait = wait_mock trigger = NeptuneGraphPrivateEndpointAvailableTrigger(graph_id=GRAPH_ID, 
vpc_id=VPC_ID) + generator = trigger.run() + resp = await generator.asend(None) - with pytest.raises(AirflowException): - await trigger.run().asend(None) + assert resp.payload["status"] == "error" + assert resp.payload["graph_id"] == GRAPH_ID + assert "Failed to create Neptune graph endpoint" in resp.payload["message"] class TestNeptuneGraphPrivateEndpointDeletedTrigger: @@ -183,9 +188,12 @@ async def test_run_failure(self, mock_async_conn, mock_get_waiter): trigger = NeptuneGraphPrivateEndpointDeletedTrigger( graph_id=GRAPH_ID, vpc_id=VPC_ID, endpoint_id=ENDPOINT_ID ) + generator = trigger.run() + resp = await generator.asend(None) - with pytest.raises(AirflowException): - await trigger.run().asend(None) + assert resp.payload["status"] == "error" + assert resp.payload["endpoint_id"] == ENDPOINT_ID + assert "Failed to delete Neptune graph endpoint" in resp.payload["message"] class TestNeptuneImportTaskCompleteTrigger: @@ -229,9 +237,12 @@ async def test_run_failure(self, mock_async_conn, mock_get_waiter): mock_get_waiter.return_value.wait = wait_mock trigger = NeptuneImportTaskCompleteTrigger(import_task_id=TASK_ID) + generator = trigger.run() + resp = await generator.asend(None) - with pytest.raises(AirflowException): - await trigger.run().asend(None) + assert resp.payload["status"] == "error" + assert resp.payload["import_task_id"] == TASK_ID + assert "Import task failed" in resp.payload["message"] class TestNeptuneImportTaskCancelledTrigger: @@ -275,6 +286,9 @@ async def test_run_failure(self, mock_async_conn, mock_get_waiter): mock_get_waiter.return_value.wait = wait_mock trigger = NeptuneImportTaskCancelledTrigger(task_identifier=TASK_ID) + generator = trigger.run() + resp = await generator.asend(None) - with pytest.raises(AirflowException): - await trigger.run().asend(None) + assert resp.payload["status"] == "error" + assert resp.payload["task_identifier"] == TASK_ID + assert "Import task cancellation failed" in resp.payload["message"] From 
933a1234e89df11bff3207ba025cb49bd6ccb4e1 Mon Sep 17 00:00:00 2001 From: Mike Ellis Date: Tue, 7 Apr 2026 17:06:57 +0000 Subject: [PATCH 21/28] Fixed url typo in provider file --- providers/amazon/provider.yaml | 2 +- .../amazon/src/airflow/providers/amazon/get_provider_info.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/providers/amazon/provider.yaml b/providers/amazon/provider.yaml index 2e103c0144e22..ae026bba5d72e 100644 --- a/providers/amazon/provider.yaml +++ b/providers/amazon/provider.yaml @@ -406,7 +406,7 @@ integrations: external-doc-url: https://docs.aws.amazon.com/neptune-analytics/latest/userguide/gettingStarted.html logo: /docs/integration-logos/Amazon-Neptune_64.png how-to-guide: - - /docs/apache-airflow-providers-amazon/operators/neptune-analytics.rst + - /docs/apache-airflow-providers-amazon/operators/neptune_analytics.rst tags: [aws] - integration-name: Amazon S3 Vectors external-doc-url: https://docs.aws.amazon.com/AmazonS3/latest/userguide/s3-vectors.html diff --git a/providers/amazon/src/airflow/providers/amazon/get_provider_info.py b/providers/amazon/src/airflow/providers/amazon/get_provider_info.py index 3d7d53b4f4111..117f602653078 100644 --- a/providers/amazon/src/airflow/providers/amazon/get_provider_info.py +++ b/providers/amazon/src/airflow/providers/amazon/get_provider_info.py @@ -383,7 +383,7 @@ def get_provider_info(): "integration-name": "Amazon Neptune Analytics", "external-doc-url": "https://docs.aws.amazon.com/neptune-analytics/latest/userguide/gettingStarted.html", "logo": "/docs/integration-logos/Amazon-Neptune_64.png", - "how-to-guide": ["/docs/apache-airflow-providers-amazon/operators/neptune-analytics.rst"], + "how-to-guide": ["/docs/apache-airflow-providers-amazon/operators/neptune_analytics.rst"], "tags": ["aws"], }, ], From e740ab6403ee754e861ed5636cee165431249dc9 Mon Sep 17 00:00:00 2001 From: mse139 Date: Sun, 12 Apr 2026 09:29:18 -0400 Subject: [PATCH 22/28] Requested PR changes --- 
.../amazon/aws/operators/neptune_analytics.py | 16 +++--- .../amazon/aws/triggers/neptune_analytics.py | 2 +- .../amazon/aws/example_neptune_analytics.py | 5 +- .../aws/triggers/test_neptune_analytics.py | 49 +++++++++++++++++++ 4 files changed, 64 insertions(+), 8 deletions(-) diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/neptune_analytics.py b/providers/amazon/src/airflow/providers/amazon/aws/operators/neptune_analytics.py index 9f360a33512ed..0b94d1d940f2c 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/operators/neptune_analytics.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/neptune_analytics.py @@ -22,7 +22,6 @@ from botocore.exceptions import ClientError -from airflow.exceptions import AirflowException from airflow.providers.amazon.aws.hooks.neptune_analytics import NeptuneAnalyticsHook from airflow.providers.amazon.aws.links.ec2 import VpcEndpointLink from airflow.providers.amazon.aws.links.neptune_analytics import NeptuneGraphLink, NeptuneImportTaskLink @@ -36,7 +35,7 @@ NeptuneImportTaskCompleteTrigger, ) from airflow.providers.amazon.aws.utils.mixins import aws_template_fields -from airflow.providers.common.compat.sdk import conf +from airflow.providers.common.compat.sdk import AirflowException, conf if TYPE_CHECKING: from airflow.sdk import Context @@ -75,7 +74,7 @@ class NeptuneCreateGraphOperator(AwsBaseOperator[NeptuneAnalyticsHook]): :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used. :param botocore_config: Configuration dictionary (key-values) for botocore client. 
See: https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html - :return: dictionary with Neptune graph id and vpc id + :return: dictionary with Neptune graph id """ aws_hook_class = NeptuneAnalyticsHook @@ -181,7 +180,7 @@ def execute(self, context: Context) -> dict: return {"graph_id": self.graph_id} def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> dict[str, Any]: - self.log.info("Neptune graph % complete", self.graph_id) + self.log.info("Neptune graph %s complete", self.graph_id) return {"graph_id": self.graph_id} @@ -320,6 +319,9 @@ def _get_graph_endpoint_id(self): def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> dict[str, Any]: vpc_endpoint_id = self._get_graph_endpoint_id() + status = event["status"] + if status.lower() != "success": + raise AirflowException("Endpoint failed to create") return {"vpc_endpoint_id": vpc_endpoint_id, "graph_id": self.graph_identifier, "vpc_id": self.vpc_id} @@ -411,10 +413,12 @@ def execute(self, context: Context) -> None: def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None: vpc_endpoint_id = "" - if event and event.get("status") == "success": + if event and event.get("status").lower() == "success": vpc_endpoint_id = event.get("endpoint_id", "Unknown") self.log.info("Endpoint id %s deleted", vpc_endpoint_id) + else: + raise AirflowException("Endpoint failed to delete.") class NeptuneDeleteGraphOperator(AwsBaseOperator[NeptuneAnalyticsHook]): @@ -718,7 +722,7 @@ def execute(self, context: Context) -> dict: return {"graph_id": self.graph_id} def defer_wait_for_task( - self, import_task_id: str, context: Context, event: dict[str, Any] | None = None + self, context: Context, event: dict[str, Any] | None = None, import_task_id: str | None = None ) -> None: """Defers for import task completion.""" self.log.info("Deferring for import task %s completion", import_task_id) diff --git 
a/providers/amazon/src/airflow/providers/amazon/aws/triggers/neptune_analytics.py b/providers/amazon/src/airflow/providers/amazon/aws/triggers/neptune_analytics.py index 7f0be3b7ae09c..50a4d1edc5abe 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/triggers/neptune_analytics.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/triggers/neptune_analytics.py @@ -266,7 +266,7 @@ def __init__( failure_message="Import task cancellation failed", status_message="Status of import task is", status_queries=["status"], - return_key="task_identifier", + return_key="import_task_id", return_value=task_identifier, waiter_delay=waiter_delay, waiter_max_attempts=waiter_max_attempts, diff --git a/providers/amazon/tests/system/amazon/aws/example_neptune_analytics.py b/providers/amazon/tests/system/amazon/aws/example_neptune_analytics.py index 1e257e844a49b..ca24c4e194ac3 100644 --- a/providers/amazon/tests/system/amazon/aws/example_neptune_analytics.py +++ b/providers/amazon/tests/system/amazon/aws/example_neptune_analytics.py @@ -208,7 +208,7 @@ def delete_graph_if_exists(graph_name: str) -> None: # [START howto_operator_neptune_analytics_create_private_endpoint] create_endpoint = NeptuneCreatePrivateGraphEndpointOperator( task_id="create_endpoint", - graph_identifier="{{ ti.xcom_pull(task_ids='create_graph)['graph_id']}}", + graph_identifier="{{ ti.xcom_pull(task_ids='create_graph')['graph_id']}}", wait_for_completion=True, ) # [END howto_operator_neptune_analytics_create_private_endpoint] @@ -317,7 +317,10 @@ def delete_graph_if_exists(graph_name: str) -> None: create_role, # TEST BODY: Create graph, import data, then delete create_graph, + create_endpoint, start_import, + cancel_import, + delete_endpoint, delete_graph, # TEST BODY: Create graph with import, then delete create_graph_with_import, diff --git a/providers/amazon/tests/unit/amazon/aws/triggers/test_neptune_analytics.py b/providers/amazon/tests/unit/amazon/aws/triggers/test_neptune_analytics.py index 
f03693dd5a068..cf23ebaa5a574 100644 --- a/providers/amazon/tests/unit/amazon/aws/triggers/test_neptune_analytics.py +++ b/providers/amazon/tests/unit/amazon/aws/triggers/test_neptune_analytics.py @@ -24,6 +24,7 @@ from airflow.providers.amazon.aws.triggers.neptune_analytics import ( NeptuneGraphAvailableTrigger, + NeptuneGraphDeletedTrigger, NeptuneGraphPrivateEndpointAvailableTrigger, NeptuneGraphPrivateEndpointDeletedTrigger, NeptuneImportTaskCancelledTrigger, @@ -292,3 +293,51 @@ async def test_run_failure(self, mock_async_conn, mock_get_waiter): assert resp.payload["status"] == "error" assert resp.payload["task_identifier"] == TASK_ID assert "Import task cancellation failed" in resp.payload["message"] + + +class TestNeptuneGraphDeletedTrigger: + def test_serialization(self): + """ + Asserts that the NeptuneGraphDeletedTrigger correctly serializes its arguments + and classpath. + """ + trigger = NeptuneGraphDeletedTrigger(graph_id=GRAPH_ID) + classpath, kwargs = trigger.serialize() + assert ( + classpath == "airflow.providers.amazon.aws.triggers.neptune_analytics.NeptuneGraphDeletedTrigger" + ) + assert "graph_id" in kwargs + assert kwargs["graph_id"] == GRAPH_ID + + @pytest.mark.asyncio + @mock.patch("airflow.providers.amazon.aws.hooks.neptune_analytics.NeptuneAnalyticsHook.get_waiter") + @mock.patch("airflow.providers.amazon.aws.hooks.neptune_analytics.NeptuneAnalyticsHook.get_async_conn") + async def test_run_success(self, mock_async_conn, mock_get_waiter): + mock_async_conn.return_value.__aenter__.return_value = "DELETED" + mock_get_waiter().wait = AsyncMock() + trigger = NeptuneGraphDeletedTrigger(graph_id=GRAPH_ID) + generator = trigger.run() + resp = await generator.asend(None) + + assert resp == TriggerEvent({"status": "success", "graph_id": GRAPH_ID}) + assert mock_get_waiter().wait.call_count == 1 + + @pytest.mark.asyncio + @mock.patch("airflow.providers.amazon.aws.hooks.neptune_analytics.NeptuneAnalyticsHook.get_waiter") + 
@mock.patch("airflow.providers.amazon.aws.hooks.neptune_analytics.NeptuneAnalyticsHook.get_async_conn") + async def test_run_failure(self, mock_async_conn, mock_get_waiter): + wait_mock = AsyncMock() + wait_mock.side_effect = WaiterError( + name="graph_deleted", + reason='Waiter encountered a terminal failure state: For expression "status" we matched expected path: "FAILED"', + last_response={"status": "FAILED", "graphIdentifier": GRAPH_ID}, + ) + mock_get_waiter.return_value.wait = wait_mock + + trigger = NeptuneGraphDeletedTrigger(graph_id=GRAPH_ID) + generator = trigger.run() + resp = await generator.asend(None) + + assert resp.payload["status"] == "error" + assert resp.payload["graph_id"] == GRAPH_ID + assert "Failed to delete Neptune graph" in resp.payload["message"] From 054e05b3c5c2baf354c0351b7c7ee00a264915c7 Mon Sep 17 00:00:00 2001 From: mse139 Date: Sat, 25 Apr 2026 09:39:16 -0400 Subject: [PATCH 23/28] Added Neptune Analytics exceptions and addressed PR comments --- .../docs/operators/neptune_analytics.rst | 14 ++--- .../providers/amazon/aws/exceptions.py | 20 ++++++ .../amazon/aws/operators/neptune_analytics.py | 63 +++++++++++++++---- .../aws/operators/test_neptune_analytics.py | 45 +++++++------ .../aws/triggers/test_neptune_analytics.py | 4 +- 5 files changed, 107 insertions(+), 39 deletions(-) diff --git a/providers/amazon/docs/operators/neptune_analytics.rst b/providers/amazon/docs/operators/neptune_analytics.rst index 5791dadbd631c..ceaefda13a2d0 100644 --- a/providers/amazon/docs/operators/neptune_analytics.rst +++ b/providers/amazon/docs/operators/neptune_analytics.rst @@ -40,7 +40,7 @@ Create a new Neptune Graph ========================== To create a new Neptune Analytics Graph, you can use -:class:`~airflow.providers.amazon.aws.operators.neptune.NeptuneCreateGraphOperator`. +:class:`~airflow.providers.amazon.aws.operators.neptune_analytics.NeptuneCreateGraphOperator`. 
This operator can be run in deferrable mode by passing ``deferrable=True`` as a parameter. This requires the aiobotocore module to be installed. @@ -57,7 +57,7 @@ Delete a Neptune Graph ====================== To delete an existing Neptune Analytics Graph, you can use -:class:`~airflow.providers.amazon.aws.operators.neptune.NeptuneDeleteGraphOperator`. +:class:`~airflow.providers.amazon.aws.operators.neptune_analytics.NeptuneDeleteGraphOperator`. This operator can be run in deferrable mode by passing ``deferrable=True`` as a parameter. This requires the aiobotocore module to be installed. @@ -73,7 +73,7 @@ Create a Neptune Graph private endpoint ======================================= To create a VPC Endpoint for connecting to an existing Neptune Graph, you can use -:class:`~airflow.providers.amazon.aws.operators.neptune.NeptuneCreatePrivateGraphEndpointOperator`. +:class:`~airflow.providers.amazon.aws.operators.neptune_analytics.NeptuneCreatePrivateGraphEndpointOperator`. This operator can be run in deferrable mode by passing ``deferrable=True`` as a parameter. This requires the aiobotocore module to be installed. @@ -89,7 +89,7 @@ Delete a Neptune Graph private endpoint ======================================= To delete a VPC Endpoint attached to an existing Neptune Graph, you can use -:class:`~airflow.providers.amazon.aws.operators.neptune.NeptuneDeletePrivateGraphEndpointOperator`. +:class:`~airflow.providers.amazon.aws.operators.neptune_analytics.NeptuneDeletePrivateGraphEndpointOperator`. This operator can be run in deferrable mode by passing ``deferrable=True`` as a parameter. This requires the aiobotocore module to be installed. @@ -105,7 +105,7 @@ Create a Neptune Graph with a data import task ============================================== To create a Neptune Analytics Graph and immediately import data, you can use -:class:`~airflow.providers.amazon.aws.operators.neptune.NeptuneCreateGraphWithImportOperator`. 
+:class:`~airflow.providers.amazon.aws.operators.neptune_analytics.NeptuneCreateGraphWithImportOperator`. This operator can be run in deferrable mode by passing ``deferrable=True`` as a parameter. This requires the aiobotocore module to be installed. @@ -121,7 +121,7 @@ Import data into an existing Neptune Graph ========================================== To import data into an existing Neptune Analytics Graph, you can use -:class:`~airflow.providers.amazon.aws.operators.neptune.NeptuneStartImportTaskOperator`. +:class:`~airflow.providers.amazon.aws.operators.neptune_analytics.NeptuneStartImportTaskOperator`. This operator can be run in deferrable mode by passing ``deferrable=True`` as a parameter. This requires the aiobotocore module to be installed. @@ -137,7 +137,7 @@ Cancel a running import task ============================ To cancel an existing import task, you can use -:class:`~airflow.providers.amazon.aws.operators.neptune.NeptuneCancelImportTaskOperator`. +:class:`~airflow.providers.amazon.aws.operators.neptune_analytics.NeptuneCancelImportTaskOperator`. This operator can be run in deferrable mode by passing ``deferrable=True`` as a parameter. This requires the aiobotocore module to be installed. 
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/exceptions.py b/providers/amazon/src/airflow/providers/amazon/aws/exceptions.py index ae1863bec9c04..8cb521edd9521 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/exceptions.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/exceptions.py @@ -50,3 +50,23 @@ def __reduce__(self): class S3HookUriParseFailure(AirflowException): """When parse_s3_url fails to parse URL, this error is thrown.""" + + +class NeptuneGraphCreationFailedError(AirflowException): + """Raised when a Neptune Analytics graph fails to reach the available state.""" + + +class NeptunePrivateEndpointCreationFailedError(AirflowException): + """Raised when a Neptune Analytics private graph endpoint fails to be created.""" + + +class NeptunePrivateEndpointDeletionFailedError(AirflowException): + """Raised when a Neptune Analytics private graph endpoint fails to be deleted.""" + + +class NeptuneGraphDeletionFailedError(AirflowException): + """Raised when a Neptune Analytics graph deletion encounters an unexpected AWS error.""" + + +class NeptuneImportTaskCancellationFailedError(AirflowException): + """Raised when a Neptune Analytics import task cancellation fails or returns an unexpected status.""" diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/neptune_analytics.py b/providers/amazon/src/airflow/providers/amazon/aws/operators/neptune_analytics.py index 0b94d1d940f2c..fa4495b4e21d5 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/operators/neptune_analytics.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/neptune_analytics.py @@ -35,11 +35,19 @@ NeptuneImportTaskCompleteTrigger, ) from airflow.providers.amazon.aws.utils.mixins import aws_template_fields -from airflow.providers.common.compat.sdk import AirflowException, conf +from airflow.providers.common.compat.sdk import conf if TYPE_CHECKING: from airflow.sdk import Context +from 
airflow.providers.amazon.aws.exceptions import ( + NeptuneGraphCreationFailedError, + NeptuneGraphDeletionFailedError, + NeptuneImportTaskCancellationFailedError, + NeptunePrivateEndpointCreationFailedError, + NeptunePrivateEndpointDeletionFailedError, +) + class NeptuneCreateGraphOperator(AwsBaseOperator[NeptuneAnalyticsHook]): """ @@ -180,6 +188,19 @@ def execute(self, context: Context) -> dict: return {"graph_id": self.graph_id} def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> dict[str, Any]: + if event is None: + raise NeptuneGraphCreationFailedError( + "No event received while waiting for Neptune graph creation to complete." + ) + + if event.get("status") != "success": + raise NeptuneGraphCreationFailedError( + event.get( + "message", + f"Neptune graph {self.graph_id} creation did not complete successfully: {event}", + ) + ) + self.log.info("Neptune graph %s complete", self.graph_id) return {"graph_id": self.graph_id} @@ -265,7 +286,9 @@ def execute(self, context: Context) -> dict: self.log.info("Status of endpoint: %s", status) if status in ["FAILED"]: - raise AirflowException(f"Private endpoint failed to create for graph {self.graph_identifier}") + raise NeptunePrivateEndpointCreationFailedError( + f"Private endpoint failed to create for graph {self.graph_identifier}" + ) # if VPC not provided, use the one that is returned, which is the default VPC. 
Required for the waiter self.vpc_id = result.get("vpcId", self.vpc_id) @@ -321,7 +344,7 @@ def execute_complete(self, context: Context, event: dict[str, Any] | None = None vpc_endpoint_id = self._get_graph_endpoint_id() status = event["status"] if status.lower() != "success": - raise AirflowException("Endpoint failed to create") + raise NeptunePrivateEndpointCreationFailedError("Endpoint failed to create") return {"vpc_endpoint_id": vpc_endpoint_id, "graph_id": self.graph_identifier, "vpc_id": self.vpc_id} @@ -386,7 +409,9 @@ def execute(self, context: Context) -> None: endpoint_id = result.get("vpcEndpointId") if status == "FAILED": - raise AirflowException(f"Failed to delete private endpoint {endpoint_id}") + raise NeptunePrivateEndpointDeletionFailedError( + f"Failed to delete private endpoint {endpoint_id}" + ) if self.deferrable: self.log.info("Deferring until endpoint %s is deleted", endpoint_id) @@ -418,7 +443,7 @@ def execute_complete(self, context: Context, event: dict[str, Any] | None = None self.log.info("Endpoint id %s deleted", vpc_endpoint_id) else: - raise AirflowException("Endpoint failed to delete.") + raise NeptunePrivateEndpointDeletionFailedError("Endpoint failed to delete.") class NeptuneDeleteGraphOperator(AwsBaseOperator[NeptuneAnalyticsHook]): @@ -482,7 +507,7 @@ def execute(self, context: Context): if e.response["Error"]["Code"] == "ResourceNotFoundException": self.log.info("Graph %s not found. 
Nothing to delete", self.graph_id) return - raise AirflowException(e.response["Error"]) + raise NeptuneGraphDeletionFailedError(e.response["Error"]) if self.deferrable: self.log.info("Deferring until graph %s is deleted", self.graph_id) @@ -498,7 +523,7 @@ def execute(self, context: Context): if self.wait_for_completion: self.log.info("Waiting to delete %s", self.graph_id) - self.hook.conn.get_waiter("graph_deleted").wait( + self.hook.get_waiter("graph_deleted").wait( graphIdentifier=self.graph_id, WaiterConfig={"Delay": self.waiter_delay, "MaxAttempts": self.waiter_max_attempts}, ) @@ -552,7 +577,7 @@ class NeptuneCreateGraphWithImportOperator(AwsBaseOperator[NeptuneAnalyticsHook] :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used. :param botocore_config: Configuration dictionary (key-values) for botocore client. See: https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html - :return: dictionary with Neptune graph id and vpc id + :return: dictionary with Neptune graph id """ aws_hook_class = NeptuneAnalyticsHook @@ -737,6 +762,11 @@ def defer_wait_for_task( ) def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> dict[str, Any]: + if event is None or event.get("status") != "success": + message = (event or {}).get( + "message", f"Neptune graph {self.graph_id} import did not complete successfully" + ) + raise NeptuneGraphCreationFailedError(message) self.log.info("Import complete for graph %s", self.graph_id) return {"graph_id": self.graph_id} @@ -957,9 +987,18 @@ def execute(self, context: Context) -> dict: return {"import_task_id": self.import_task_id} def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> dict[str, Any]: - task_id = "" - if event: - task_id = event.get("import_task_id", "") - self.log.info("Import task %s cancelled", task_id) + if event is None: + raise NeptuneImportTaskCancellationFailedError( + "No event received while 
waiting for Neptune import task cancellation." + ) + + status = str(event.get("status", "")).lower() + if status not in {"success", "cancelled", "canceled"}: + message = event.get("message") or event.get("error") or f"Unexpected trigger status: {status!r}" + raise NeptuneImportTaskCancellationFailedError( + f"Error while waiting for Neptune import task cancellation: {message}" + ) + task_id = event.get("import_task_id", "") + self.log.info("Import task %s cancelled", task_id) return {"import_task_id": task_id} diff --git a/providers/amazon/tests/unit/amazon/aws/operators/test_neptune_analytics.py b/providers/amazon/tests/unit/amazon/aws/operators/test_neptune_analytics.py index 5f2d99c4877f2..295b122863d6c 100644 --- a/providers/amazon/tests/unit/amazon/aws/operators/test_neptune_analytics.py +++ b/providers/amazon/tests/unit/amazon/aws/operators/test_neptune_analytics.py @@ -22,8 +22,14 @@ import pytest from botocore.exceptions import ClientError +from airflow.providers.amazon.aws.exceptions import ( + NeptuneGraphDeletionFailedError, + NeptuneImportTaskCancellationFailedError, + NeptunePrivateEndpointCreationFailedError, + NeptunePrivateEndpointDeletionFailedError, +) from airflow.providers.amazon.aws.hooks.neptune_analytics import NeptuneAnalyticsHook -from airflow.providers.amazon.aws.links.neptune_analytics import NeptuneGraphLink +from airflow.providers.amazon.aws.links.neptune_analytics import NeptuneGraphLink, NeptuneImportTaskLink from airflow.providers.amazon.aws.operators.neptune_analytics import ( NeptuneCancelImportTaskOperator, NeptuneCreateGraphOperator, @@ -37,7 +43,7 @@ NeptuneGraphAvailableTrigger, NeptuneImportTaskCompleteTrigger, ) -from airflow.providers.common.compat.sdk import AirflowException, TaskDeferred +from airflow.providers.common.compat.sdk import TaskDeferred GRAPH_NAME = "test_graph" GRAPH_ID = "test-graph-id" @@ -381,7 +387,10 @@ def test_create_endpoint_failed_status(self, mock_conn): vpc_id=VPC_ID, ) - with 
pytest.raises(AirflowException, match=f"Private endpoint failed to create for graph {GRAPH_ID}"): + with pytest.raises( + NeptunePrivateEndpointCreationFailedError, + match=f"Private endpoint failed to create for graph {GRAPH_ID}", + ): operator.execute(None) @mock.patch.object(NeptuneAnalyticsHook, "conn") @@ -517,8 +526,6 @@ def test_delete_endpoint_wait_for_completion(self, mock_hook_get_waiter, mock_co @mock.patch.object(NeptuneAnalyticsHook, "conn") def test_delete_endpoint_failed_status(self, mock_conn): - from airflow.providers.common.compat.sdk import AirflowException - mock_conn.delete_private_graph_endpoint.return_value = { "status": "FAILED", "vpcEndpointId": ENDPOINT_ID, @@ -531,7 +538,10 @@ def test_delete_endpoint_failed_status(self, mock_conn): vpc_id=VPC_ID, ) - with pytest.raises(AirflowException, match=f"Failed to delete private endpoint {ENDPOINT_ID}"): + with pytest.raises( + NeptunePrivateEndpointDeletionFailedError, + match=f"Failed to delete private endpoint {ENDPOINT_ID}", + ): operator.execute(None) def test_execute_complete_success(self): @@ -608,7 +618,8 @@ def test_init_custom_args(self, mock_conn): ) @mock.patch.object(NeptuneAnalyticsHook, "conn") - def test_delete_graph_no_wait(self, mock_conn): + @mock.patch.object(NeptuneAnalyticsHook, "get_waiter") + def test_delete_graph_no_wait(self, mock_get_waiter, mock_conn): mock_conn.delete_graph.return_value = { "id": GRAPH_ID, "name": GRAPH_NAME, @@ -624,17 +635,18 @@ def test_delete_graph_no_wait(self, mock_conn): operator.execute(None) mock_conn.delete_graph.assert_called_once() - mock_conn.get_waiter.assert_not_called() + mock_get_waiter.assert_not_called() @mock.patch.object(NeptuneAnalyticsHook, "conn") - def test_delete_graph_wait_for_completion(self, mock_conn): + @mock.patch.object(NeptuneAnalyticsHook, "get_waiter") + def test_delete_graph_wait_for_completion(self, mock_get_waiter, mock_conn): mock_conn.delete_graph.return_value = { "id": GRAPH_ID, "name": GRAPH_NAME, "status": 
"DELETING", } mock_waiter = mock.MagicMock() - mock_conn.get_waiter.return_value = mock_waiter + mock_get_waiter.return_value = mock_waiter operator = NeptuneDeleteGraphOperator( task_id="test_task", @@ -644,7 +656,7 @@ def test_delete_graph_wait_for_completion(self, mock_conn): ) operator.execute(None) - mock_conn.get_waiter.assert_called_once_with("graph_deleted") + mock_get_waiter.assert_called_once_with("graph_deleted") mock_waiter.wait.assert_called_once_with( graphIdentifier=GRAPH_ID, WaiterConfig={"Delay": 30, "MaxAttempts": 60}, @@ -700,8 +712,8 @@ def test_delete_graph_other_client_error(self, mock_conn): skip_snapshot=True, ) - # Should raise AirflowException for non-ResourceNotFoundException errors - with pytest.raises(AirflowException): + # Should raise NeptuneGraphDeletionFailedError for non-ResourceNotFoundException errors + with pytest.raises(NeptuneGraphDeletionFailedError): operator.execute(None) @@ -987,8 +999,6 @@ def test_template_fields_renderers(self): assert NeptuneStartImportTaskOperator.template_fields_renderers == {"import_options": "json"} def test_operator_extra_links(self): - from airflow.providers.amazon.aws.links.neptune_analytics import NeptuneImportTaskLink - assert len(NeptuneStartImportTaskOperator.operator_extra_links) == 1 assert isinstance(NeptuneStartImportTaskOperator.operator_extra_links[0], NeptuneImportTaskLink) @@ -1296,6 +1306,5 @@ def test_execute_complete_no_event(self): import_task_id=TASK_ID, ) - result = operator.execute_complete(None, None) - - assert result == {"import_task_id": ""} + with pytest.raises(NeptuneImportTaskCancellationFailedError): + operator.execute_complete(None, None) diff --git a/providers/amazon/tests/unit/amazon/aws/triggers/test_neptune_analytics.py b/providers/amazon/tests/unit/amazon/aws/triggers/test_neptune_analytics.py index cf23ebaa5a574..9232a9582322b 100644 --- a/providers/amazon/tests/unit/amazon/aws/triggers/test_neptune_analytics.py +++ 
b/providers/amazon/tests/unit/amazon/aws/triggers/test_neptune_analytics.py @@ -271,7 +271,7 @@ async def test_run_success(self, mock_async_conn, mock_get_waiter): generator = trigger.run() resp = await generator.asend(None) - assert resp == TriggerEvent({"status": "success", "task_identifier": TASK_ID}) + assert resp == TriggerEvent({"status": "success", "import_task_id": TASK_ID}) assert mock_get_waiter().wait.call_count == 1 @pytest.mark.asyncio @@ -291,7 +291,7 @@ async def test_run_failure(self, mock_async_conn, mock_get_waiter): resp = await generator.asend(None) assert resp.payload["status"] == "error" - assert resp.payload["task_identifier"] == TASK_ID + assert resp.payload["import_task_id"] == TASK_ID assert "Import task cancellation failed" in resp.payload["message"] From 8c466bb5790f5b7374093da8837bf84a74ab76a5 Mon Sep 17 00:00:00 2001 From: mse139 Date: Sun, 26 Apr 2026 13:41:30 -0400 Subject: [PATCH 24/28] Fixed mypy errors --- .../amazon/aws/operators/neptune_analytics.py | 28 ++++++++++--------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/neptune_analytics.py b/providers/amazon/src/airflow/providers/amazon/aws/operators/neptune_analytics.py index fa4495b4e21d5..cf42c16f617e7 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/operators/neptune_analytics.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/neptune_analytics.py @@ -342,9 +342,10 @@ def _get_graph_endpoint_id(self): def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> dict[str, Any]: vpc_endpoint_id = self._get_graph_endpoint_id() - status = event["status"] - if status.lower() != "success": - raise NeptunePrivateEndpointCreationFailedError("Endpoint failed to create") + if event: + status = event["status"] + if status.lower() != "success": + raise NeptunePrivateEndpointCreationFailedError("Endpoint failed to create") return {"vpc_endpoint_id": 
vpc_endpoint_id, "graph_id": self.graph_identifier, "vpc_id": self.vpc_id} @@ -750,16 +751,17 @@ def defer_wait_for_task( self, context: Context, event: dict[str, Any] | None = None, import_task_id: str | None = None ) -> None: """Defers for import task completion.""" - self.log.info("Deferring for import task %s completion", import_task_id) - self.defer( - trigger=NeptuneImportTaskCompleteTrigger( - import_task_id=import_task_id, - waiter_delay=self.waiter_delay, - waiter_max_attempts=self.waiter_max_attempts, - aws_conn_id=self.aws_conn_id, - ), - method_name="execute_complete", - ) + if import_task_id: + self.log.info("Deferring for import task %s completion", import_task_id) + self.defer( + trigger=NeptuneImportTaskCompleteTrigger( + import_task_id=import_task_id, + waiter_delay=self.waiter_delay, + waiter_max_attempts=self.waiter_max_attempts, + aws_conn_id=self.aws_conn_id, + ), + method_name="execute_complete", + ) def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> dict[str, Any]: if event is None or event.get("status") != "success": From 71c5db48d16f151ff3de54b9b7a5e85c03d7a8fe Mon Sep 17 00:00:00 2001 From: mse139 Date: Sat, 2 May 2026 06:09:39 -0400 Subject: [PATCH 25/28] provider update --- .../airflow/providers/amazon/get_provider_info.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/providers/amazon/src/airflow/providers/amazon/get_provider_info.py b/providers/amazon/src/airflow/providers/amazon/get_provider_info.py index 117f602653078..3d158a94523b0 100644 --- a/providers/amazon/src/airflow/providers/amazon/get_provider_info.py +++ b/providers/amazon/src/airflow/providers/amazon/get_provider_info.py @@ -372,13 +372,6 @@ def get_provider_info(): "how-to-guide": ["/docs/apache-airflow-providers-amazon/operators/mwaa.rst"], "tags": ["aws"], }, - { - "integration-name": "Amazon S3 Vectors", - "external-doc-url": "https://docs.aws.amazon.com/AmazonS3/latest/userguide/s3-vectors.html", - 
"logo": "/docs/integration-logos/Amazon-Simple-Storage-Service-S3_light-bg@4x.png", - "how-to-guide": ["/docs/apache-airflow-providers-amazon/operators/s3_vectors.rst"], - "tags": ["aws"], - }, { "integration-name": "Amazon Neptune Analytics", "external-doc-url": "https://docs.aws.amazon.com/neptune-analytics/latest/userguide/gettingStarted.html", @@ -386,6 +379,13 @@ def get_provider_info(): "how-to-guide": ["/docs/apache-airflow-providers-amazon/operators/neptune_analytics.rst"], "tags": ["aws"], }, + { + "integration-name": "Amazon S3 Vectors", + "external-doc-url": "https://docs.aws.amazon.com/AmazonS3/latest/userguide/s3-vectors.html", + "logo": "/docs/integration-logos/Amazon-Simple-Storage-Service-S3_light-bg@4x.png", + "how-to-guide": ["/docs/apache-airflow-providers-amazon/operators/s3_vectors.rst"], + "tags": ["aws"], + }, ], "operators": [ { From d2cf9b387a3ec7ac7f10cc4e7432d23ff52f51ed Mon Sep 17 00:00:00 2001 From: mse139 Date: Sat, 2 May 2026 07:50:58 -0400 Subject: [PATCH 26/28] Updated system test exception handling based on prek findings --- .../amazon/aws/example_neptune_analytics.py | 20 ++++++------------- 1 file changed, 6 insertions(+), 14 deletions(-) diff --git a/providers/amazon/tests/system/amazon/aws/example_neptune_analytics.py b/providers/amazon/tests/system/amazon/aws/example_neptune_analytics.py index ca24c4e194ac3..3768ec0f407c7 100644 --- a/providers/amazon/tests/system/amazon/aws/example_neptune_analytics.py +++ b/providers/amazon/tests/system/amazon/aws/example_neptune_analytics.py @@ -16,6 +16,7 @@ # under the License. 
from __future__ import annotations +import contextlib import json import time from datetime import datetime @@ -111,21 +112,15 @@ def create_neptune_import_role(role_name: str) -> str: @task(trigger_rule=TriggerRule.ALL_DONE) def delete_neptune_import_role(role_name: str) -> None: iam_client = boto3.client("iam") - try: + with contextlib.suppress(iam_client.exceptions.NoSuchEntityException): iam_client.delete_role_policy(RoleName=role_name, PolicyName="NeptuneAnalyticsS3Access") - except iam_client.exceptions.NoSuchEntityException: - pass - try: - iam_client.delete_role(RoleName=role_name) - except iam_client.exceptions.NoSuchEntityException: - pass @task(trigger_rule=TriggerRule.ALL_DONE) def delete_graph_if_exists(graph_name: str) -> None: """Safety net to clean up the graph in case a previous task failed.""" hook = NeptuneAnalyticsHook() - try: + with contextlib.suppress(Exception): # List graphs and find by name paginator = hook.conn.get_paginator("list_graphs") for page in paginator.paginate(): @@ -133,18 +128,15 @@ def delete_graph_if_exists(graph_name: str) -> None: if graph.get("name") == graph_name: graph_id = graph["id"] # Disable deletion protection if enabled - try: - hook.conn.update_graph(graphIdentifier=graph_id, deletionProtection=False) - except Exception: - pass + + hook.conn.update_graph(graphIdentifier=graph_id, deletionProtection=False) + hook.conn.delete_graph(graphIdentifier=graph_id, skipSnapshot=True) hook.conn.get_waiter("graph_deleted").wait( graphIdentifier=graph_id, WaiterConfig={"Delay": 30, "MaxAttempts": 60}, ) return - except Exception: - pass with DAG( From a4b2f3d8d9b6b668ec22fdf43add9a9b1605cfc6 Mon Sep 17 00:00:00 2001 From: mse139 Date: Sat, 2 May 2026 11:59:56 -0400 Subject: [PATCH 27/28] Fixed calling lower() on None --- .../airflow/providers/amazon/aws/operators/neptune_analytics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/providers/amazon/src/airflow/providers/amazon/aws/operators/neptune_analytics.py b/providers/amazon/src/airflow/providers/amazon/aws/operators/neptune_analytics.py index cf42c16f617e7..23cb6925798fd 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/operators/neptune_analytics.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/neptune_analytics.py @@ -439,7 +439,7 @@ def execute(self, context: Context) -> None: def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None: vpc_endpoint_id = "" - if event and event.get("status").lower() == "success": + if event and (event.get("status") or "").lower() == "success": vpc_endpoint_id = event.get("endpoint_id", "Unknown") self.log.info("Endpoint id %s deleted", vpc_endpoint_id) From 42cb361542da56dd64923dfaff56d7e258b29865 Mon Sep 17 00:00:00 2001 From: Mike Ellis Date: Fri, 22 May 2026 18:01:01 +0000 Subject: [PATCH 28/28] Added custom import waiter and addressed PR suggestions --- .../providers/amazon/aws/exceptions.py | 4 + .../amazon/aws/hooks/neptune_analytics.py | 5 + .../amazon/aws/operators/neptune_analytics.py | 159 ++++++++++-------- .../amazon/aws/waiters/neptune_analytics.json | 38 +++++ .../amazon/aws/example_neptune_analytics.py | 25 ++- .../aws/hooks/test_neptune_analytics.py | 28 ++- .../aws/operators/test_neptune_analytics.py | 67 ++------ 7 files changed, 193 insertions(+), 133 deletions(-) create mode 100644 providers/amazon/src/airflow/providers/amazon/aws/waiters/neptune_analytics.json diff --git a/providers/amazon/src/airflow/providers/amazon/aws/exceptions.py b/providers/amazon/src/airflow/providers/amazon/aws/exceptions.py index 8cb521edd9521..e289099d831a5 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/exceptions.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/exceptions.py @@ -70,3 +70,7 @@ class NeptuneGraphDeletionFailedError(AirflowException): class 
NeptuneImportTaskCancellationFailedError(AirflowException): """Raised when a Neptune Analytics import task cancellation fails or returns an unexpected status.""" + + +class NeptuneImportTaskFailedError(AirflowException): + """Raised when a Neptune Analytics import task fails to complete successfully.""" diff --git a/providers/amazon/src/airflow/providers/amazon/aws/hooks/neptune_analytics.py b/providers/amazon/src/airflow/providers/amazon/aws/hooks/neptune_analytics.py index 252878fa11b4b..8f079b9194b2c 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/hooks/neptune_analytics.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/hooks/neptune_analytics.py @@ -35,3 +35,8 @@ class NeptuneAnalyticsHook(AwsBaseHook): def __init__(self, *args, **kwargs): kwargs["client_type"] = "neptune-graph" super().__init__(*args, **kwargs) + + def _get_graph_endpoint_id(self, graph_id: str, vpc_id: str): + """Return the vpc endpoint id for this graph.""" + result = self.conn.get_private_graph_endpoint(graphIdentifier=graph_id, vpcId=vpc_id) + return result.get("vpcEndpointId") diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/neptune_analytics.py b/providers/amazon/src/airflow/providers/amazon/aws/operators/neptune_analytics.py index 23cb6925798fd..ecfe15d193eef 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/operators/neptune_analytics.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/neptune_analytics.py @@ -34,6 +34,7 @@ NeptuneImportTaskCancelledTrigger, NeptuneImportTaskCompleteTrigger, ) +from airflow.providers.amazon.aws.utils import validate_execute_complete_event from airflow.providers.amazon.aws.utils.mixins import aws_template_fields from airflow.providers.common.compat.sdk import conf @@ -44,6 +45,7 @@ NeptuneGraphCreationFailedError, NeptuneGraphDeletionFailedError, NeptuneImportTaskCancellationFailedError, + NeptuneImportTaskFailedError, NeptunePrivateEndpointCreationFailedError, 
NeptunePrivateEndpointDeletionFailedError, ) @@ -119,7 +121,7 @@ def __init__( self.provisioned_memory = provisioned_memory self.public_connectivity = public_connectivity self.deletion_protect = deletion_protection - self.kms_key = kms_key_id + self.kms_key_id = kms_key_id self.tags = tags self.wait_for_completion = wait_for_completion self.deferrable = deferrable @@ -139,7 +141,7 @@ def execute(self, context: Context) -> dict: "replicaCount": self.replica_count, "publicConnectivity": self.public_connectivity, "deletionProtection": self.deletion_protect, - "kmsKeyIdentifier": self.kms_key, + "kmsKeyIdentifier": self.kms_key_id, "tags": self.tags, }.items() if v is not None @@ -188,16 +190,14 @@ def execute(self, context: Context) -> dict: return {"graph_id": self.graph_id} def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> dict[str, Any]: - if event is None: - raise NeptuneGraphCreationFailedError( - "No event received while waiting for Neptune graph creation to complete." - ) - if event.get("status") != "success": + validated_event = validate_execute_complete_event(event) + + if validated_event.get("status") != "success": raise NeptuneGraphCreationFailedError( - event.get( + validated_event.get( "message", - f"Neptune graph {self.graph_id} creation did not complete successfully: {event}", + f"Neptune graph {validated_event.get('return_key')} creation did not complete successfully", ) ) @@ -293,6 +293,24 @@ def execute(self, context: Context) -> dict: # if VPC not provided, use the one that is returned, which is the default VPC. 
Required for the waiter self.vpc_id = result.get("vpcId", self.vpc_id) + # get the vpce id since it may not be returned immediately + endpoint_id = self.hook._get_graph_endpoint_id(graph_id=self.graph_identifier, vpc_id=self.vpc_id) + + endpoint_url = VpcEndpointLink.format_str.format( + endpoint_id=endpoint_id, + aws_domain=VpcEndpointLink.get_aws_domain(self.hook.conn_partition), + region_name=self.hook.conn_region_name, + ) + + VpcEndpointLink.persist( + context=context, + operator=self, + region_name=self.hook.conn_region_name, + aws_partition=self.hook.conn_partition, + endpoint_id=endpoint_id, + ) + self.log.info("You can view this private endpoint at : %s", endpoint_url) + if self.deferrable: self.log.info("Deferring until endpoint is available") self.defer( @@ -304,6 +322,7 @@ def execute(self, context: Context) -> dict: waiter_max_attempts=self.waiter_max_attempts, ), method_name="execute_complete", + kwargs={"vpc_id": self.vpc_id}, ) if self.wait_for_completion: @@ -314,40 +333,20 @@ def execute(self, context: Context) -> dict: WaiterConfig={"Delay": self.waiter_delay, "MaxAttempts": self.waiter_max_attempts}, ) - endpoint_id = self._get_graph_endpoint_id() - - endpoint_url = VpcEndpointLink.format_str.format( - endpoint_id=endpoint_id, - aws_domain=VpcEndpointLink.get_aws_domain(self.hook.conn_partition), - region_name=self.hook.conn_region_name, - ) - - VpcEndpointLink.persist( - context=context, - operator=self, - region_name=self.hook.conn_region_name, - aws_partition=self.hook.conn_partition, - endpoint_id=endpoint_id, - ) - self.log.info("You can view this private endpoint at : %s", endpoint_url) - return {"vpc_endpoint_id": endpoint_id, "graph_id": self.graph_identifier, "vpc_id": self.vpc_id} - def _get_graph_endpoint_id(self): - """Return the vpc endpoint id for this graph.""" - result = self.hook.conn.get_private_graph_endpoint( - graphIdentifier=self.graph_identifier, vpcId=self.vpc_id - ) - return result.get("vpcEndpointId") - def 
execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> dict[str, Any]: - vpc_endpoint_id = self._get_graph_endpoint_id() - if event: - status = event["status"] - if status.lower() != "success": - raise NeptunePrivateEndpointCreationFailedError("Endpoint failed to create") + validated_event = validate_execute_complete_event(event) + + if validated_event.get("status") != "success": + raise NeptunePrivateEndpointCreationFailedError( + validated_event.get("message", "Endpoint failed to create") + ) - return {"vpc_endpoint_id": vpc_endpoint_id, "graph_id": self.graph_identifier, "vpc_id": self.vpc_id} + graph_id = validated_event.get("value") + vpc_id = validated_event.get("vpc_id") + vpc_endpoint_id = self.hook._get_graph_endpoint_id(graph_id=graph_id, vpc_id=vpc_id) + return {"vpc_endpoint_id": vpc_endpoint_id, "graph_id": graph_id, "vpc_id": vpc_id} class NeptuneDeletePrivateGraphEndpointOperator(AwsBaseOperator[NeptuneAnalyticsHook]): @@ -437,14 +436,15 @@ def execute(self, context: Context) -> None: self.log.info("Endpoint %s deleted", endpoint_id) def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None: - vpc_endpoint_id = "" + validated_event = validate_execute_complete_event(event) - if event and (event.get("status") or "").lower() == "success": - vpc_endpoint_id = event.get("endpoint_id", "Unknown") + if validated_event.get("status") != "success": + raise NeptunePrivateEndpointDeletionFailedError( + validated_event.get("message", "Endpoint failed to delete.") + ) - self.log.info("Endpoint id %s deleted", vpc_endpoint_id) - else: - raise NeptunePrivateEndpointDeletionFailedError("Endpoint failed to delete.") + vpc_endpoint_id = validated_event.get("endpoint_id", "Unknown") + self.log.info("Endpoint id %s deleted", vpc_endpoint_id) class NeptuneDeleteGraphOperator(AwsBaseOperator[NeptuneAnalyticsHook]): @@ -530,10 +530,14 @@ def execute(self, context: Context): ) def execute_complete(self, context: 
Context, event: dict[str, Any] | None = None): - if event: - graph_id = event.get("graph_id", "Unknown") + validated_event = validate_execute_complete_event(event) + graph_id = validated_event.get("graph_id") + if validated_event.get("status") != "success": + raise NeptuneGraphDeletionFailedError( + validated_event.get("message", f"Neptune graph {graph_id} deletion failed") + ) - self.log.info("Neptune graph %s deleted", graph_id) + self.log.info("Neptune graph %s deleted", validated_event.get("graph_id", graph_id)) class NeptuneCreateGraphWithImportOperator(AwsBaseOperator[NeptuneAnalyticsHook]): @@ -583,7 +587,7 @@ class NeptuneCreateGraphWithImportOperator(AwsBaseOperator[NeptuneAnalyticsHook] aws_hook_class = NeptuneAnalyticsHook template_fields: Sequence[str] = aws_template_fields( - "graph_name", "vector_search_config", "source", "role_arn", "kms_key" + "graph_name", "vector_search_config", "source", "role_arn" ) template_fields_renderers = { @@ -633,7 +637,7 @@ def __init__( self.public_connectivity = public_connectivity self.replica_count = replica_count self.deletion_protect = deletion_protection - self.kms_key = kms_key_id + self.kms_key_id = kms_key_id self.tags = tags self.import_options = import_options self.wait_for_completion = wait_for_completion @@ -672,7 +676,7 @@ def execute(self, context: Context) -> dict: "replicaCount": self.replica_count, "publicConnectivity": self.public_connectivity, "deletionProtection": self.deletion_protect, - "kmsKeyIdentifier": self.kms_key, + "kmsKeyIdentifier": self.kms_key_id, "tags": self.tags, "importOptions": import_options if import_options else None, }.items() @@ -751,6 +755,14 @@ def defer_wait_for_task( self, context: Context, event: dict[str, Any] | None = None, import_task_id: str | None = None ) -> None: """Defers for import task completion.""" + validated_event = validate_execute_complete_event(event) + graph_id = validated_event.get("value") + + if validated_event.get("status") != "success": + raise 
NeptuneGraphCreationFailedError( + validated_event.get("message", f"Neptune graph {graph_id} did not become available") + ) + if import_task_id: self.log.info("Deferring for import task %s completion", import_task_id) self.defer( @@ -761,16 +773,23 @@ def defer_wait_for_task( aws_conn_id=self.aws_conn_id, ), method_name="execute_complete", + kwargs={"graph_id": graph_id}, ) - def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> dict[str, Any]: - if event is None or event.get("status") != "success": - message = (event or {}).get( - "message", f"Neptune graph {self.graph_id} import did not complete successfully" + def execute_complete( + self, context: Context, event: dict[str, Any] | None = None, graph_id: str | None = None + ) -> dict[str, Any]: + validated_event = validate_execute_complete_event(event) + + if validated_event.get("status") != "success": + raise NeptuneGraphCreationFailedError( + validated_event.get( + "message", f"Neptune graph {graph_id} import did not complete successfully" + ) ) - raise NeptuneGraphCreationFailedError(message) - self.log.info("Import complete for graph %s", self.graph_id) - return {"graph_id": self.graph_id} + + self.log.info("Import complete for graph %s", graph_id) + return {"graph_id": graph_id} class NeptuneStartImportTaskOperator(AwsBaseOperator[NeptuneAnalyticsHook]): @@ -907,10 +926,15 @@ def execute(self, context: Context) -> dict: return {"import_task_id": import_task_id, "graph_id": self.graph_identifier} def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> dict[str, Any]: - task_id = "" - if event: - task_id = event.get("import_task_id", "") + validated_event = validate_execute_complete_event(event) + if validated_event.get("status") != "success": + raise NeptuneImportTaskFailedError( + validated_event.get("message", "Import task did not complete successfully") + ) + + task_id = validated_event.get("import_task_id", "") + self.log.info("Import task %s 
completed", task_id) return {"graph_id": self.graph_identifier, "import_task_id": task_id} @@ -989,18 +1013,13 @@ def execute(self, context: Context) -> dict: return {"import_task_id": self.import_task_id} def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> dict[str, Any]: - if event is None: - raise NeptuneImportTaskCancellationFailedError( - "No event received while waiting for Neptune import task cancellation." - ) + validated_event = validate_execute_complete_event(event) - status = str(event.get("status", "")).lower() - if status not in {"success", "cancelled", "canceled"}: - message = event.get("message") or event.get("error") or f"Unexpected trigger status: {status!r}" + if validated_event.get("status") != "success": raise NeptuneImportTaskCancellationFailedError( - f"Error while waiting for Neptune import task cancellation: {message}" + validated_event.get("message", "Error while waiting for Neptune import task cancellation") ) - task_id = event.get("import_task_id", "") + task_id = validated_event.get("value", "") self.log.info("Import task %s cancelled", task_id) return {"import_task_id": task_id} diff --git a/providers/amazon/src/airflow/providers/amazon/aws/waiters/neptune_analytics.json b/providers/amazon/src/airflow/providers/amazon/aws/waiters/neptune_analytics.json new file mode 100644 index 0000000000000..e5a8e82712f3f --- /dev/null +++ b/providers/amazon/src/airflow/providers/amazon/aws/waiters/neptune_analytics.json @@ -0,0 +1,38 @@ +{ + "version": 2, + "waiters": { + + "import_task_cancelled":{ + "operation": "GetImportTask", + "delay": 30, + "maxAttempts": 60, + "acceptors": [ + { + "matcher": "path", + "argument": "status", + "expected": "SUCCEEDED", + "state": "success" + }, + { + "matcher": "path", + "argument": "status", + "expected": "CANCELLED", + "state": "success" + }, + { + "matcher": "path", + "argument": "status", + "expected": "ERROR_ENCOUNTERED", + "state": "error" + }, + { + "matcher": "path", + 
"argument": "status", + "expected": "FAILED", + "state": "success" + } + ] + } + + + }} diff --git a/providers/amazon/tests/system/amazon/aws/example_neptune_analytics.py b/providers/amazon/tests/system/amazon/aws/example_neptune_analytics.py index 3768ec0f407c7..79c64f951ce3d 100644 --- a/providers/amazon/tests/system/amazon/aws/example_neptune_analytics.py +++ b/providers/amazon/tests/system/amazon/aws/example_neptune_analytics.py @@ -50,9 +50,7 @@ DAG_ID = "example_neptune_analytics" -ROLE_ARN_KEY = "ROLE_ARN" - -sys_test_context_task = SystemTestContextBuilder().add_variable(ROLE_ARN_KEY).build() +sys_test_context_task = SystemTestContextBuilder().build() # Minimal OpenCypher CSV data for import testing. NODES_CSV = """~id,~label,name:String @@ -114,6 +112,7 @@ def delete_neptune_import_role(role_name: str) -> None: iam_client = boto3.client("iam") with contextlib.suppress(iam_client.exceptions.NoSuchEntityException): iam_client.delete_role_policy(RoleName=role_name, PolicyName="NeptuneAnalyticsS3Access") + iam_client.delete_role(RoleName=role_name) @task(trigger_rule=TriggerRule.ALL_DONE) @@ -127,8 +126,20 @@ def delete_graph_if_exists(graph_name: str) -> None: for graph in page.get("graphs", []): if graph.get("name") == graph_name: graph_id = graph["id"] - # Disable deletion protection if enabled + # Delete any attached private graph endpoints before deleting the graph + endpoints_paginator = hook.conn.get_paginator("list_private_graph_endpoints") + for ep_page in endpoints_paginator.paginate(graphIdentifier=graph_id): + for endpoint in ep_page.get("privateGraphEndpoints", []): + vpc_id = endpoint["vpcId"] + hook.conn.delete_private_graph_endpoint(graphIdentifier=graph_id, vpcId=vpc_id) + hook.conn.get_waiter("private_graph_endpoint_deleted").wait( + graphIdentifier=graph_id, + vpcId=vpc_id, + WaiterConfig={"Delay": 30, "MaxAttempts": 60}, + ) + + # Disable deletion protection if enabled hook.conn.update_graph(graphIdentifier=graph_id, 
deletionProtection=False) hook.conn.delete_graph(graphIdentifier=graph_id, skipSnapshot=True) @@ -225,7 +236,7 @@ def delete_graph_if_exists(graph_name: str) -> None: source=f"s3://{bucket_name}/data/", format="CSV", fail_on_error=True, - wait_for_completion=True, + wait_for_completion=False, deferrable=False, waiter_delay=30, waiter_max_attempts=60, @@ -298,8 +309,8 @@ def delete_graph_if_exists(graph_name: str) -> None: delete_role = delete_neptune_import_role(import_role_name) - cleanup_graph = delete_graph_if_exists(graph_name) - cleanup_import_graph = delete_graph_if_exists(import_graph_name) + cleanup_graph = delete_graph_if_exists.override(task_id="cleanup_graph")(graph_name) + cleanup_import_graph = delete_graph_if_exists.override(task_id="cleanup_import_graph")(import_graph_name) chain( # TEST SETUP diff --git a/providers/amazon/tests/unit/amazon/aws/hooks/test_neptune_analytics.py b/providers/amazon/tests/unit/amazon/aws/hooks/test_neptune_analytics.py index 23cf28044135d..0e22d6361bbab 100644 --- a/providers/amazon/tests/unit/amazon/aws/hooks/test_neptune_analytics.py +++ b/providers/amazon/tests/unit/amazon/aws/hooks/test_neptune_analytics.py @@ -34,13 +34,29 @@ def neptune_hook() -> Generator[NeptuneAnalyticsHook, None, None]: class TestNeptuneAnalyticsHook: - graph_id = "abc123" - def test_get_conn_returns_a_boto3_connection(self): hook = NeptuneAnalyticsHook(aws_conn_id="aws_default") assert hook.get_conn() is not None - @mock.patch.object(NeptuneAnalyticsHook, "get_waiter") - def test_wait_for_graph_availability(self, mock_get_waiter, neptune_hook: NeptuneAnalyticsHook): - waiter = mock_get_waiter("graph_available") - assert waiter is not None + @mock.patch.object(NeptuneAnalyticsHook, "conn") + def test_get_graph_endpoint_id(self, mock_conn): + mock_conn.get_private_graph_endpoint.return_value = { + "vpcEndpointId": "vpce-12345", + } + + hook = NeptuneAnalyticsHook(aws_conn_id="aws_default") + result = 
hook._get_graph_endpoint_id(graph_id="g-abc123", vpc_id="vpc-99999") + + mock_conn.get_private_graph_endpoint.assert_called_once_with( + graphIdentifier="g-abc123", vpcId="vpc-99999" + ) + assert result == "vpce-12345" + + @mock.patch.object(NeptuneAnalyticsHook, "conn") + def test_get_graph_endpoint_id_missing_key(self, mock_conn): + mock_conn.get_private_graph_endpoint.return_value = {} + + hook = NeptuneAnalyticsHook(aws_conn_id="aws_default") + result = hook._get_graph_endpoint_id(graph_id="g-abc123", vpc_id="vpc-99999") + + assert result is None diff --git a/providers/amazon/tests/unit/amazon/aws/operators/test_neptune_analytics.py b/providers/amazon/tests/unit/amazon/aws/operators/test_neptune_analytics.py index 295b122863d6c..fdc0eeec9b1a8 100644 --- a/providers/amazon/tests/unit/amazon/aws/operators/test_neptune_analytics.py +++ b/providers/amazon/tests/unit/amazon/aws/operators/test_neptune_analytics.py @@ -24,7 +24,6 @@ from airflow.providers.amazon.aws.exceptions import ( NeptuneGraphDeletionFailedError, - NeptuneImportTaskCancellationFailedError, NeptunePrivateEndpointCreationFailedError, NeptunePrivateEndpointDeletionFailedError, ) @@ -86,7 +85,7 @@ def test_init_defaults(self, mock_conn, mock_persist): assert operator.public_connectivity is None assert operator.replica_count is None assert operator.deletion_protect is False - assert operator.kms_key is None + assert operator.kms_key_id is None assert operator.tags is None operator.execute(None) @@ -118,7 +117,7 @@ def test_init_custom_args(self, mock_conn, mock_persist): assert operator.public_connectivity is True assert operator.replica_count == 3 assert operator.deletion_protect is True - assert operator.kms_key == "test-key" + assert operator.kms_key_id == "test-key" assert operator.tags == {"key1": "test"} operator.execute(None) @@ -394,40 +393,29 @@ def test_create_endpoint_failed_status(self, mock_conn): operator.execute(None) @mock.patch.object(NeptuneAnalyticsHook, "conn") - def 
test_execute_complete(self, mock_conn): + @mock.patch.object(NeptuneAnalyticsHook, "_get_graph_endpoint_id") + def test_execute_complete(self, mock_get_endpoint, mock_conn): + mock_get_endpoint.return_value = ENDPOINT_ID mock_conn.get_private_graph_endpoint.return_value = { "vpcEndpointId": ENDPOINT_ID, } - operator = NeptuneCreatePrivateGraphEndpointOperator( - task_id="test_task", graph_identifier=GRAPH_ID, vpc_id=VPC_ID - ) - - result = operator.execute_complete(None, {"status": "success"}) - - mock_conn.get_private_graph_endpoint.assert_called_once_with( - graphIdentifier=GRAPH_ID, - vpcId=VPC_ID, - ) - assert result == {"vpc_endpoint_id": ENDPOINT_ID, "graph_id": GRAPH_ID, "vpc_id": VPC_ID} - - @mock.patch.object(NeptuneAnalyticsHook, "conn") - def test_get_graph_endpoint_id(self, mock_conn): - mock_conn.get_private_graph_endpoint.return_value = { - "vpcEndpointId": ENDPOINT_ID, - } + mock_conn.create_private_graph_endpoint.return_value = {"vpcId": VPC_ID} operator = NeptuneCreatePrivateGraphEndpointOperator( task_id="test_task", graph_identifier=GRAPH_ID, vpc_id=VPC_ID ) - result = operator._get_graph_endpoint_id() + result = operator.execute_complete( + context=None, event={"status": "success", "value": GRAPH_ID, "vpc_id": VPC_ID} + ) - assert result == ENDPOINT_ID - mock_conn.get_private_graph_endpoint.assert_called_once_with( - graphIdentifier=GRAPH_ID, - vpcId=VPC_ID, + # mock_conn.get_private_graph_endpoint.assert_called_once_with( + mock_get_endpoint.assert_called_once_with( + graph_id=GRAPH_ID, + vpc_id=VPC_ID, ) + assert result == {"vpc_endpoint_id": ENDPOINT_ID, "graph_id": GRAPH_ID, "vpc_id": VPC_ID} class TestNeptuneDeletePrivateGraphEndpointOperator: @@ -749,7 +737,7 @@ def test_init_defaults(self, mock_conn): assert operator.public_connectivity is None assert operator.replica_count is None assert operator.deletion_protect is None - assert operator.kms_key is None + assert operator.kms_key_id is None assert operator.tags is None assert 
operator.import_options is None assert operator.wait_for_completion is True @@ -804,7 +792,7 @@ def test_init_with_all_optional_params(self, mock_conn): assert operator.public_connectivity is True assert operator.replica_count == 2 assert operator.deletion_protect is True - assert operator.kms_key == "test-kms-key" + assert operator.kms_key_id == "test-kms-key" assert operator.tags == {"env": "test"} assert operator.import_options == {"custom-option": "value"} assert operator.waiter_delay == 60 @@ -1196,18 +1184,6 @@ def test_execute_complete_success(self): assert result == {"graph_id": GRAPH_ID, "import_task_id": TASK_ID} - def test_execute_complete_no_event(self): - operator = NeptuneStartImportTaskOperator( - task_id="test_task", - graph_identifier=GRAPH_ID, - role_arn=ROLE_ARN, - source=SOURCE_S3_URI, - ) - - result = operator.execute_complete(None, None) - - assert result == {"graph_id": GRAPH_ID, "import_task_id": ""} - class TestNeptuneCancelImportTaskOperator: @mock.patch.object(NeptuneAnalyticsHook, "conn") @@ -1295,16 +1271,7 @@ def test_execute_complete_success(self): import_task_id=TASK_ID, ) - event = {"status": "success", "import_task_id": TASK_ID} + event = {"status": "success", "value": TASK_ID} result = operator.execute_complete(None, event) assert result == {"import_task_id": TASK_ID} - - def test_execute_complete_no_event(self): - operator = NeptuneCancelImportTaskOperator( - task_id="test_task", - import_task_id=TASK_ID, - ) - - with pytest.raises(NeptuneImportTaskCancellationFailedError): - operator.execute_complete(None, None)