diff --git a/airflow/providers/google/cloud/hooks/dataproc.py b/airflow/providers/google/cloud/hooks/dataproc.py index dae5535e40b3c..4551b24384c7f 100644 --- a/airflow/providers/google/cloud/hooks/dataproc.py +++ b/airflow/providers/google/cloud/hooks/dataproc.py @@ -583,6 +583,94 @@ def update_cluster( ) return operation + @GoogleBaseHook.fallback_to_default_project_id + def start_cluster( + self, + region: str, + project_id: str, + cluster_name: str, + cluster_uuid: str | None = None, + request_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ) -> Operation: + """Start a cluster in a project. + + :param region: Cloud Dataproc region to handle the request. + :param project_id: Google Cloud project ID that the cluster belongs to. + :param cluster_name: The cluster name. + :param cluster_uuid: The cluster UUID. + :param request_id: A unique id used to identify the request. If the + server receives two *StartClusterRequest* requests with the same + ID, the second request will be ignored, and an operation created + for the first one and stored in the backend is returned. + :param retry: A retry object used to retry requests. If *None*, requests + will not be retried. + :param timeout: The amount of time, in seconds, to wait for the request + to complete. If *retry* is specified, the timeout applies to each + individual attempt. + :param metadata: Additional metadata that is provided to the method. + :return: An instance of ``google.api_core.operation.Operation`` + """ + client = self.get_cluster_client(region=region) + return client.start_cluster( + request={ + "project_id": project_id, + "region": region, + "cluster_name": cluster_name, + "cluster_uuid": cluster_uuid, + "request_id": request_id, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + @GoogleBaseHook.fallback_to_default_project_id + def stop_cluster( + self, + region: str, + project_id: str, + cluster_name: str, + cluster_uuid: str | None = None, + request_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ) -> Operation: + """Stop a cluster in a project. + + :param region: Cloud Dataproc region to handle the request. + :param project_id: Google Cloud project ID that the cluster belongs to. + :param cluster_name: The cluster name. + :param cluster_uuid: The cluster UUID. + :param request_id: A unique id used to identify the request. If the + server receives two *StopClusterRequest* requests with the same + ID, the second request will be ignored, and an operation created + for the first one and stored in the backend is returned. + :param retry: A retry object used to retry requests. If *None*, requests + will not be retried. + :param timeout: The amount of time, in seconds, to wait for the request + to complete. If *retry* is specified, the timeout applies to each + individual attempt. + :param metadata: Additional metadata that is provided to the method.
+ :return: An instance of ``google.api_core.operation.Operation`` + """ + client = self.get_cluster_client(region=region) + return client.stop_cluster( + request={ + "project_id": project_id, + "region": region, + "cluster_name": cluster_name, + "cluster_uuid": cluster_uuid, + "request_id": request_id, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + @GoogleBaseHook.fallback_to_default_project_id + def create_workflow_template( + self, diff --git a/airflow/providers/google/cloud/operators/dataproc.py b/airflow/providers/google/cloud/operators/dataproc.py index 7f3fcd5d01565..aacc1adb24770 100644 --- a/airflow/providers/google/cloud/operators/dataproc.py +++ b/airflow/providers/google/cloud/operators/dataproc.py @@ -724,6 +724,17 @@ def _wait_for_cluster_in_creating_state(self, hook: DataprocHook) -> Cluster: cluster = self._get_cluster(hook) return cluster + def _start_cluster(self, hook: DataprocHook): + op: operation.Operation = hook.start_cluster( + region=self.region, + project_id=self.project_id, + cluster_name=self.cluster_name, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + return hook.wait_for_operation(timeout=self.timeout, result_retry=self.retry, operation=op) + def execute(self, context: Context) -> dict: self.log.info("Creating cluster: %s", self.cluster_name) hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain) @@ -801,6 +812,9 @@ def execute(self, context: Context) -> dict: # Create new cluster cluster = self._create_cluster(hook) self._handle_error_state(hook, cluster) + elif cluster.status.state == cluster.status.State.STOPPED: + # if the cluster exists and is already stopped, then start the cluster + self._start_cluster(hook) return Cluster.to_dict(cluster) @@ -1082,6 +1096,189 @@ def _delete_cluster(self, hook: DataprocHook): ) + +class _DataprocStartStopClusterBaseOperator(GoogleCloudBaseOperator): + """Base class to start or stop a cluster in a project. + + :param cluster_name: Required. Name of the cluster to start or stop. + :param region: Required. The specified region where the Dataproc cluster is located. + :param project_id: Optional. The ID of the Google Cloud project the cluster belongs to. + :param cluster_uuid: Optional. Specifying the ``cluster_uuid`` means the RPC should fail + if cluster with specified UUID does not exist. + :param request_id: Optional. A unique id used to identify the request. If the server receives two + requests with the same id, then the second request will be ignored and the + first ``google.longrunning.Operation`` created and stored in the backend is returned. + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :param metadata: Additional metadata that is provided to the method. + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role.
+ If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + """ + + template_fields = ( + "cluster_name", + "region", + "project_id", + "request_id", + "impersonation_chain", + ) + + def __init__( + self, + *, + cluster_name: str, + region: str, + project_id: str | None = None, + cluster_uuid: str | None = None, + request_id: str | None = None, + retry: AsyncRetry | _MethodDefault = DEFAULT, + timeout: float = 1 * 60 * 60, + metadata: Sequence[tuple[str, str]] = (), + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.project_id = project_id + self.region = region + self.cluster_name = cluster_name + self.cluster_uuid = cluster_uuid + self.request_id = request_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + self._hook: DataprocHook | None = None + + @property + def hook(self): + if self._hook is None: + self._hook = DataprocHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + return self._hook + + def _get_project_id(self) -> str: + return self.project_id or self.hook.project_id + + def _get_cluster(self) -> Cluster: + """Retrieve the cluster information. + + :return: Instance of ``google.cloud.dataproc_v1.Cluster`` class + """ + return self.hook.get_cluster( + project_id=self._get_project_id(), + region=self.region, + cluster_name=self.cluster_name, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + + def _check_desired_cluster_state(self, cluster: Cluster) -> tuple[bool, str | None]: + """Implement this method in child class to return whether the cluster is in the desired state or not. + + If the cluster is in the desired state, you can return a log message content as a second value + for the return tuple. + + :param cluster: Required. Instance of ``google.cloud.dataproc_v1.Cluster`` + class to interact with Dataproc API + :return: Tuple of (Boolean, Optional[str]) The first value of the tuple is whether the cluster is + in the desired state or not. The second value of the tuple will be used if you want to log something when + the cluster is already in the desired state. + """ + raise NotImplementedError + + def _get_operation(self) -> operation.Operation: + """Implement this method in child class to call the related hook method and return its result.
+ + :return: The ``google.api_core.operation.Operation`` returned by the related hook method + """ + raise NotImplementedError + + def execute(self, context: Context) -> dict | None: + cluster: Cluster = self._get_cluster() + is_already_desired_state, log_str = self._check_desired_cluster_state(cluster) + if is_already_desired_state: + self.log.info(log_str) + return None + + op: operation.Operation = self._get_operation() + result = self.hook.wait_for_operation(timeout=self.timeout, result_retry=self.retry, operation=op) + return Cluster.to_dict(result) + + +class DataprocStartClusterOperator(_DataprocStartStopClusterBaseOperator): + """Start a cluster in a project.""" + + operator_extra_links = (DataprocClusterLink(),) + + def execute(self, context: Context) -> dict | None: + self.log.info("Starting the cluster: %s", self.cluster_name) + cluster = super().execute(context) + DataprocClusterLink.persist( + context=context, + operator=self, + cluster_id=self.cluster_name, + project_id=self._get_project_id(), + region=self.region, + ) + self.log.info("Cluster started") + return cluster + + def _check_desired_cluster_state(self, cluster: Cluster) -> tuple[bool, str | None]: + if cluster.status.state == cluster.status.State.RUNNING: + return True, f'The cluster "{self.cluster_name}" is already running!' + return False, None + + def _get_operation(self) -> operation.Operation: + return self.hook.start_cluster( + region=self.region, + project_id=self._get_project_id(), + cluster_name=self.cluster_name, + cluster_uuid=self.cluster_uuid, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + + +class DataprocStopClusterOperator(_DataprocStartStopClusterBaseOperator): + """Stop a cluster in a project.""" + + def execute(self, context: Context) -> dict | None: + self.log.info("Stopping the cluster: %s", self.cluster_name) + cluster = super().execute(context) + self.log.info("Cluster stopped") + return cluster + + def _check_desired_cluster_state(self, cluster: Cluster) -> tuple[bool, str | None]: + if cluster.status.state in [cluster.status.State.STOPPED, cluster.status.State.STOPPING]: + return True, f'The cluster "{self.cluster_name}" is already stopped!' + return False, None + + def _get_operation(self) -> operation.Operation: + return self.hook.stop_cluster( + region=self.region, + project_id=self._get_project_id(), + cluster_name=self.cluster_name, + cluster_uuid=self.cluster_uuid, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + + class DataprocJobBaseOperator(GoogleCloudBaseOperator): """Base class for operators that launch job on DataProc. diff --git a/docs/apache-airflow-providers-google/operators/cloud/dataproc.rst b/docs/apache-airflow-providers-google/operators/cloud/dataproc.rst index 67c2831a1aaac..6277f94e05287 100644 --- a/docs/apache-airflow-providers-google/operators/cloud/dataproc.rst +++ b/docs/apache-airflow-providers-google/operators/cloud/dataproc.rst @@ -201,6 +201,30 @@ You can use deferrable mode for this action in order to run the operator asynchr :start-after: [START how_to_cloud_dataproc_update_cluster_operator_async] :end-before: [END how_to_cloud_dataproc_update_cluster_operator_async] +Starting a cluster +------------------ + +To start a cluster, you can use the +:class:`~airflow.providers.google.cloud.operators.dataproc.DataprocStartClusterOperator`: + +.. exampleinclude:: /../../tests/system/providers/google/cloud/dataproc/example_dataproc_cluster_start_stop.py
+ :language: python + :dedent: 4 + :start-after: [START how_to_cloud_dataproc_start_cluster_operator] + :end-before: [END how_to_cloud_dataproc_start_cluster_operator] + +Stopping a cluster +------------------ + +To stop a cluster, you can use the +:class:`~airflow.providers.google.cloud.operators.dataproc.DataprocStopClusterOperator`: + +.. exampleinclude:: /../../tests/system/providers/google/cloud/dataproc/example_dataproc_cluster_start_stop.py + :language: python + :dedent: 4 + :start-after: [START how_to_cloud_dataproc_stop_cluster_operator] + :end-before: [END how_to_cloud_dataproc_stop_cluster_operator] + Deleting a cluster ------------------ diff --git a/tests/always/test_project_structure.py b/tests/always/test_project_structure.py index db026aa6bfcde..bab56abead2a1 100644 --- a/tests/always/test_project_structure.py +++ b/tests/always/test_project_structure.py @@ -403,6 +403,7 @@ class TestGoogleProviderProjectStructure(ExampleCoverageTest, AssetsCoverageTest "airflow.providers.google.cloud.transfers.bigquery_to_sql.BigQueryToSqlBaseOperator", "airflow.providers.google.cloud.operators.cloud_sql.CloudSQLBaseOperator", "airflow.providers.google.cloud.operators.dataproc.DataprocJobBaseOperator", + "airflow.providers.google.cloud.operators.dataproc._DataprocStartStopClusterBaseOperator", "airflow.providers.google.cloud.operators.vertex_ai.custom_job.CustomTrainingJobBaseOperator", "airflow.providers.google.cloud.operators.cloud_base.GoogleCloudBaseOperator", } diff --git a/tests/providers/google/cloud/hooks/test_dataproc.py b/tests/providers/google/cloud/hooks/test_dataproc.py index 1a82fc8a1cb0b..131f5a342b504 100644 --- a/tests/providers/google/cloud/hooks/test_dataproc.py +++ b/tests/providers/google/cloud/hooks/test_dataproc.py @@ -287,6 +287,48 @@ def test_update_cluster_missing_region(self, mock_client): update_mask="update-mask", ) + @mock.patch(DATAPROC_STRING.format("DataprocHook.get_cluster_client")) + def test_start_cluster(self, mock_client): + self.hook.start_cluster( + region=GCP_LOCATION, + project_id=GCP_PROJECT, + cluster_name=CLUSTER_NAME, + ) + mock_client.assert_called_once_with(region=GCP_LOCATION) + mock_client.return_value.start_cluster.assert_called_once_with( + request=dict( + project_id=GCP_PROJECT, + region=GCP_LOCATION, + cluster_name=CLUSTER_NAME, + cluster_uuid=None, + request_id=None, + ), + metadata=(), + retry=DEFAULT, + timeout=None, + ) + + @mock.patch(DATAPROC_STRING.format("DataprocHook.get_cluster_client")) + def test_stop_cluster(self, mock_client): + self.hook.stop_cluster( + region=GCP_LOCATION, + project_id=GCP_PROJECT, + cluster_name=CLUSTER_NAME, + ) + mock_client.assert_called_once_with(region=GCP_LOCATION) + mock_client.return_value.stop_cluster.assert_called_once_with( + request=dict( + project_id=GCP_PROJECT, + region=GCP_LOCATION, + cluster_name=CLUSTER_NAME, + cluster_uuid=None, + request_id=None, + ), + metadata=(), + retry=DEFAULT, + timeout=None, + ) + @mock.patch(DATAPROC_STRING.format("DataprocHook.get_template_client")) def test_create_workflow_template(self, mock_client): template = {"test": "test"} diff --git a/tests/providers/google/cloud/operators/test_dataproc.py b/tests/providers/google/cloud/operators/test_dataproc.py index d0b04a6fa95fb..44e20489a2406 100644 --- a/tests/providers/google/cloud/operators/test_dataproc.py +++ b/tests/providers/google/cloud/operators/test_dataproc.py @@ -54,6 +54,8 @@
DataprocLink, DataprocListBatchesOperator, DataprocScaleClusterOperator, + DataprocStartClusterOperator, + DataprocStopClusterOperator, DataprocSubmitHadoopJobOperator, DataprocSubmitHiveJobOperator, DataprocSubmitJobOperator, @@ -1683,6 +1685,90 @@ def test_update_cluster_operator_extra_links(dag_maker, create_task_instance_of_ assert ti.task.get_extra_links(ti, DataprocClusterLink.name) == DATAPROC_CLUSTER_LINK_EXPECTED +class TestDataprocStartClusterOperator(DataprocClusterTestBase): + @mock.patch(DATAPROC_PATH.format("Cluster.to_dict")) + @mock.patch(DATAPROC_PATH.format("DataprocHook")) + def test_execute(self, mock_hook, mock_to_dict): + cluster = MagicMock() + cluster.status.State.RUNNING = 3 + cluster.status.state = 0 + mock_hook.return_value.get_cluster.return_value = cluster + + op = DataprocStartClusterOperator( + task_id=TASK_ID, + cluster_name=CLUSTER_NAME, + region=GCP_REGION, + project_id=GCP_PROJECT, + request_id=REQUEST_ID, + retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + ) + op.execute(context=self.mock_context) + + mock_hook.return_value.get_cluster.assert_called_with( + region=GCP_REGION, + project_id=GCP_PROJECT, + cluster_name=CLUSTER_NAME, + retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + ) + mock_hook.return_value.start_cluster.assert_called_once_with( + cluster_name=CLUSTER_NAME, + region=GCP_REGION, + project_id=GCP_PROJECT, + cluster_uuid=None, + retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + ) + + +class TestDataprocStopClusterOperator(DataprocClusterTestBase): + @mock.patch(DATAPROC_PATH.format("Cluster.to_dict")) + @mock.patch(DATAPROC_PATH.format("DataprocHook")) + def test_execute(self, mock_hook, mock_to_dict): + cluster = MagicMock() + cluster.status.State.STOPPED = 4 + cluster.status.state = 0 + mock_hook.return_value.get_cluster.return_value = cluster + + op = DataprocStopClusterOperator( + task_id=TASK_ID, + cluster_name=CLUSTER_NAME, + region=GCP_REGION, + project_id=GCP_PROJECT, + request_id=REQUEST_ID, + retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + ) + op.execute(context=self.mock_context) + + mock_hook.return_value.get_cluster.assert_called_with( + region=GCP_REGION, + project_id=GCP_PROJECT, + cluster_name=CLUSTER_NAME, + retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + ) + mock_hook.return_value.stop_cluster.assert_called_once_with( + cluster_name=CLUSTER_NAME, + region=GCP_REGION, + project_id=GCP_PROJECT, + cluster_uuid=None, + retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + ) + + class TestDataprocInstantiateWorkflowTemplateOperator: @mock.patch(DATAPROC_PATH.format("DataprocHook")) def test_execute(self, mock_hook): diff --git a/tests/system/providers/google/cloud/dataproc/example_dataproc_cluster_create_existing_stopped_cluster.py b/tests/system/providers/google/cloud/dataproc/example_dataproc_cluster_create_existing_stopped_cluster.py new file mode 100644 index 0000000000000..6a77a146844c5 --- /dev/null +++ b/tests/system/providers/google/cloud/dataproc/example_dataproc_cluster_create_existing_stopped_cluster.py @@ -0,0 +1,120 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Example Airflow DAG for DataprocCreateClusterOperator when the cluster already exists and is stopped. +""" +from __future__ import annotations + +import os +from datetime import datetime + +from airflow.models.dag import DAG +from airflow.providers.google.cloud.operators.dataproc import ( + DataprocCreateClusterOperator, + DataprocDeleteClusterOperator, + DataprocStartClusterOperator, + DataprocStopClusterOperator, +) +from airflow.utils.trigger_rule import TriggerRule + +DAG_ID = "example_dataproc_cluster_create_existing_stopped_cluster" + +ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID") +PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT") + +CLUSTER_NAME = f"cluster-{ENV_ID}-{DAG_ID}".replace("_", "-") +REGION = "europe-west1" + +# Cluster definition +CLUSTER_CONFIG = { + "master_config": { + "num_instances": 1, + "machine_type_uri": "n1-standard-4", + "disk_config": {"boot_disk_type": "pd-standard", "boot_disk_size_gb": 32}, + }, + "worker_config": { + "num_instances": 2, + "machine_type_uri": "n1-standard-4", + "disk_config": {"boot_disk_type": "pd-standard", "boot_disk_size_gb": 32}, + }, +} + +with DAG( + DAG_ID, schedule="@once", start_date=datetime(2024, 1, 1), catchup=False, tags=["dataproc", "example"] +) as dag: + create_cluster = DataprocCreateClusterOperator( + task_id="create_cluster", + project_id=PROJECT_ID, + cluster_config=CLUSTER_CONFIG, + region=REGION, + cluster_name=CLUSTER_NAME, + use_if_exists=True, + ) + + start_cluster = DataprocStartClusterOperator( + task_id="start_cluster", + project_id=PROJECT_ID, + region=REGION, + cluster_name=CLUSTER_NAME, + ) + + stop_cluster = DataprocStopClusterOperator( + task_id="stop_cluster", + project_id=PROJECT_ID, + region=REGION, + cluster_name=CLUSTER_NAME, + ) + + create_cluster_for_stopped_cluster = DataprocCreateClusterOperator( + task_id="create_cluster_for_stopped_cluster", + project_id=PROJECT_ID, + cluster_config=CLUSTER_CONFIG, + region=REGION, + cluster_name=CLUSTER_NAME, + use_if_exists=True, + ) + + delete_cluster = DataprocDeleteClusterOperator( + task_id="delete_cluster", + project_id=PROJECT_ID, + cluster_name=CLUSTER_NAME, + region=REGION, + trigger_rule=TriggerRule.ALL_DONE, + ) + + ( + # TEST SETUP + create_cluster + >> start_cluster + >> stop_cluster + # TEST BODY + >> create_cluster_for_stopped_cluster + # TEST TEARDOWN + >> delete_cluster + ) + + from tests.system.utils.watcher import watcher + + # This test needs watcher in order to properly mark success/failure + # when "teardown" task with trigger rule is part of the DAG + list(dag.tasks) >> watcher() + +from tests.system.utils import get_test_run # noqa: E402 + +# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest) +test_run = get_test_run(dag) diff --git a/tests/system/providers/google/cloud/dataproc/example_dataproc_cluster_start_stop.py b/tests/system/providers/google/cloud/dataproc/example_dataproc_cluster_start_stop.py new file
mode 100644 index 0000000000000..7dcb127cd62c2 --- /dev/null +++ b/tests/system/providers/google/cloud/dataproc/example_dataproc_cluster_start_stop.py @@ -0,0 +1,114 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Example Airflow DAG for DataprocStartClusterOperator and DataprocStopClusterOperator. +""" +from __future__ import annotations + +import os +from datetime import datetime + +from airflow.models.dag import DAG +from airflow.providers.google.cloud.operators.dataproc import ( + DataprocCreateClusterOperator, + DataprocDeleteClusterOperator, + DataprocStartClusterOperator, + DataprocStopClusterOperator, +) +from airflow.utils.trigger_rule import TriggerRule + +DAG_ID = "dataproc_cluster_start_stop" + +ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID") +PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT") + +CLUSTER_NAME = f"cluster-{ENV_ID}-{DAG_ID}".replace("_", "-") +REGION = "europe-west1" + +# Cluster definition +CLUSTER_CONFIG = { + "master_config": { + "num_instances": 1, + "machine_type_uri": "n1-standard-4", + "disk_config": {"boot_disk_type": "pd-standard", "boot_disk_size_gb": 32}, + }, + "worker_config": { + "num_instances": 2, + "machine_type_uri": "n1-standard-4", + "disk_config": {"boot_disk_type": "pd-standard", "boot_disk_size_gb": 32}, + }, +} + +with DAG( + DAG_ID, schedule="@once", start_date=datetime(2024, 1, 1), catchup=False, tags=["dataproc", "example"] +) as dag: + create_cluster = DataprocCreateClusterOperator( + task_id="create_cluster", + project_id=PROJECT_ID, + cluster_config=CLUSTER_CONFIG, + region=REGION, + cluster_name=CLUSTER_NAME, + use_if_exists=True, + ) + + # [START how_to_cloud_dataproc_start_cluster_operator] + start_cluster = DataprocStartClusterOperator( + task_id="start_cluster", + project_id=PROJECT_ID, + region=REGION, + cluster_name=CLUSTER_NAME, + ) + # [END how_to_cloud_dataproc_start_cluster_operator] + + # [START how_to_cloud_dataproc_stop_cluster_operator] + stop_cluster = DataprocStopClusterOperator( + task_id="stop_cluster", + project_id=PROJECT_ID, + region=REGION, + cluster_name=CLUSTER_NAME, + ) + # [END how_to_cloud_dataproc_stop_cluster_operator] + + delete_cluster = DataprocDeleteClusterOperator( + task_id="delete_cluster", + project_id=PROJECT_ID, + cluster_name=CLUSTER_NAME, + region=REGION, + trigger_rule=TriggerRule.ALL_DONE, + ) + + ( + # TEST SETUP + create_cluster + # TEST BODY + >> stop_cluster + >> start_cluster + # TEST TEARDOWN + >> delete_cluster + ) + + from tests.system.utils.watcher import watcher + + # This test needs watcher in order to properly mark success/failure + # when "teardown" task with trigger rule is part of the DAG + list(dag.tasks) >> watcher() + +from tests.system.utils import get_test_run # noqa: E402 + +# Needed to run the example DAG with pytest
(see: tests/system/README.md#run_via_pytest) +test_run = get_test_run(dag)
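
For quick reference, here is a minimal sketch (not part of the patch) of how the new operators are expected to be wired into a user DAG. The DAG id, project ID, region, and cluster name below are placeholder values, and the operators fall back to the default "google_cloud_default" connection:

from __future__ import annotations

from datetime import datetime

from airflow.models.dag import DAG
from airflow.providers.google.cloud.operators.dataproc import (
    DataprocStartClusterOperator,
    DataprocStopClusterOperator,
)

with DAG(
    dag_id="dataproc_start_stop_sketch",  # placeholder DAG id
    schedule=None,
    start_date=datetime(2024, 1, 1),
    catchup=False,
) as dag:
    # Start an existing cluster; skipped with a log message if it is already RUNNING.
    start_cluster = DataprocStartClusterOperator(
        task_id="start_cluster",
        project_id="my-project",  # placeholder
        region="europe-west1",  # placeholder
        cluster_name="my-cluster",  # placeholder
    )

    # Stop the cluster again; skipped if it is already STOPPED or STOPPING.
    stop_cluster = DataprocStopClusterOperator(
        task_id="stop_cluster",
        project_id="my-project",
        region="europe-west1",
        cluster_name="my-cluster",
    )

    start_cluster >> stop_cluster

Per the _check_desired_cluster_state implementations above, starting an already-running cluster or stopping an already-stopped cluster only logs a message and returns None instead of calling the API.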