From 08d15d06ba8675d70fcbd19f0500d67fc5f310cd Mon Sep 17 00:00:00 2001 From: Ahzaz Hingora <5833893+ahzaz@users.noreply.github.com> Date: Thu, 16 Nov 2023 16:51:52 +0530 Subject: [PATCH] Add support for driver pool, instance flexibility policy, and min_num_instances for Dataproc (#34172) --- .../google/cloud/operators/dataproc.py | 85 +++++++++++ airflow/providers/google/provider.yaml | 2 +- docs/spelling_wordlist.txt | 2 + .../google/cloud/operators/test_dataproc.py | 140 ++++++++++++++++++ 4 files changed, 228 insertions(+), 1 deletion(-) diff --git a/airflow/providers/google/cloud/operators/dataproc.py b/airflow/providers/google/cloud/operators/dataproc.py index 8d3387a700fa2..b489a79dc8ee8 100644 --- a/airflow/providers/google/cloud/operators/dataproc.py +++ b/airflow/providers/google/cloud/operators/dataproc.py @@ -25,6 +25,7 @@ import time import uuid import warnings +from dataclasses import dataclass from datetime import datetime, timedelta from enum import Enum from typing import TYPE_CHECKING, Any, Sequence @@ -77,6 +78,38 @@ class PreemptibilityType(Enum): NON_PREEMPTIBLE = "NON_PREEMPTIBLE" +@dataclass +class InstanceSelection: + """Defines machines types and a rank to which the machines types belong. + + Representation for + google.cloud.dataproc.v1#google.cloud.dataproc.v1.InstanceFlexibilityPolicy.InstanceSelection. + + :param machine_types: Full machine-type names, e.g. "n1-standard-16". + :param rank: Preference of this instance selection. Lower number means higher preference. + Dataproc will first try to create a VM based on the machine-type with priority rank and fallback + to next rank based on availability. Machine types and instance selections with the same priority have + the same preference. + """ + + machine_types: list[str] + rank: int = 0 + + +@dataclass +class InstanceFlexibilityPolicy: + """ + Instance flexibility Policy allowing a mixture of VM shapes and provisioning models. + + Representation for google.cloud.dataproc.v1#google.cloud.dataproc.v1.InstanceFlexibilityPolicy. + + :param instance_selection_list: List of instance selection options that the group will use when + creating new VMs. + """ + + instance_selection_list: list[InstanceSelection] + + class ClusterGenerator: """Create a new Dataproc Cluster. @@ -85,6 +118,11 @@ class ClusterGenerator: to create the cluster. (templated) :param num_workers: The # of workers to spin up. If set to zero will spin up cluster in a single node mode + :param min_num_workers: The minimum number of primary worker instances to create. + If more than ``min_num_workers`` VMs are created out of ``num_workers``, the failed VMs will be + deleted, cluster is resized to available VMs and set to RUNNING. + If created VMs are less than ``min_num_workers``, the cluster is placed in ERROR state. The failed + VMs are not deleted. :param storage_bucket: The storage bucket to use, setting to None lets dataproc generate a custom one for you :param init_actions_uris: List of GCS uri's containing @@ -153,12 +191,18 @@ class ClusterGenerator: ``projects/[PROJECT_STORING_KEYS]/locations/[LOCATION]/keyRings/[KEY_RING_NAME]/cryptoKeys/[KEY_NAME]`` # noqa :param enable_component_gateway: Provides access to the web interfaces of default and selected optional components on the cluster. + :param driver_pool_size: The number of driver nodes in the node group. + :param driver_pool_id: The ID for the driver pool. Must be unique within the cluster. Use this ID to + identify the driver group in future operations, such as resizing the node group. + :param secondary_worker_instance_flexibility_policy: Instance flexibility Policy allowing a mixture of VM + shapes and provisioning models. """ def __init__( self, project_id: str, num_workers: int | None = None, + min_num_workers: int | None = None, zone: str | None = None, network_uri: str | None = None, subnetwork_uri: str | None = None, @@ -191,11 +235,15 @@ def __init__( auto_delete_ttl: int | None = None, customer_managed_key: str | None = None, enable_component_gateway: bool | None = False, + driver_pool_size: int = 0, + driver_pool_id: str | None = None, + secondary_worker_instance_flexibility_policy: InstanceFlexibilityPolicy | None = None, **kwargs, ) -> None: self.project_id = project_id self.num_masters = num_masters self.num_workers = num_workers + self.min_num_workers = min_num_workers self.num_preemptible_workers = num_preemptible_workers self.preemptibility = self._set_preemptibility_type(preemptibility) self.storage_bucket = storage_bucket @@ -228,6 +276,9 @@ def __init__( self.customer_managed_key = customer_managed_key self.enable_component_gateway = enable_component_gateway self.single_node = num_workers == 0 + self.driver_pool_size = driver_pool_size + self.driver_pool_id = driver_pool_id + self.secondary_worker_instance_flexibility_policy = secondary_worker_instance_flexibility_policy if self.custom_image and self.image_version: raise ValueError("The custom_image and image_version can't be both set") @@ -241,6 +292,15 @@ def __init__( if self.single_node and self.num_preemptible_workers > 0: raise ValueError("Single node cannot have preemptible workers.") + if self.min_num_workers: + if not self.num_workers: + raise ValueError("Must specify num_workers when min_num_workers are provided.") + if self.min_num_workers > self.num_workers: + raise ValueError( + "The value of min_num_workers must be less than or equal to num_workers. " + f"Provided {self.min_num_workers}(min_num_workers) and {self.num_workers}(num_workers)." + ) + def _set_preemptibility_type(self, preemptibility: str): return PreemptibilityType(preemptibility.upper()) @@ -307,6 +367,17 @@ def _build_lifecycle_config(self, cluster_data): return cluster_data + def _build_driver_pool(self): + driver_pool = { + "node_group": { + "roles": ["DRIVER"], + "node_group_config": {"num_instances": self.driver_pool_size}, + }, + } + if self.driver_pool_id: + driver_pool["node_group_id"] = self.driver_pool_id + return driver_pool + def _build_cluster_data(self): if self.zone: master_type_uri = ( @@ -344,6 +415,10 @@ def _build_cluster_data(self): "autoscaling_config": {}, "endpoint_config": {}, } + + if self.min_num_workers: + cluster_data["worker_config"]["min_num_instances"] = self.min_num_workers + if self.num_preemptible_workers > 0: cluster_data["secondary_worker_config"] = { "num_instances": self.num_preemptible_workers, @@ -355,6 +430,13 @@ def _build_cluster_data(self): "is_preemptible": True, "preemptibility": self.preemptibility.value, } + if self.secondary_worker_instance_flexibility_policy: + cluster_data["secondary_worker_config"]["instance_flexibility_policy"] = { + "instance_selection_list": [ + vars(s) + for s in self.secondary_worker_instance_flexibility_policy.instance_selection_list + ] + } if self.storage_bucket: cluster_data["config_bucket"] = self.storage_bucket @@ -382,6 +464,9 @@ def _build_cluster_data(self): if not self.single_node: cluster_data["worker_config"]["image_uri"] = custom_image_url + if self.driver_pool_size > 0: + cluster_data["auxiliary_node_groups"] = [self._build_driver_pool()] + cluster_data = self._build_gce_cluster_config(cluster_data) if self.single_node: diff --git a/airflow/providers/google/provider.yaml b/airflow/providers/google/provider.yaml index 0500304a0d641..4287801d2b204 100644 --- a/airflow/providers/google/provider.yaml +++ b/airflow/providers/google/provider.yaml @@ -102,7 +102,7 @@ dependencies: - google-cloud-dataflow-client>=0.8.2 - google-cloud-dataform>=0.5.0 - google-cloud-dataplex>=1.4.2 - - google-cloud-dataproc>=5.4.0 + - google-cloud-dataproc>=5.5.0 - google-cloud-dataproc-metastore>=1.12.0 - google-cloud-dlp>=3.12.0 - google-cloud-kms>=2.15.0 diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index b787191fd0462..c56f81ebaf75a 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -792,7 +792,9 @@ InspectContentResponse InspectTemplate instafail installable +InstanceFlexibilityPolicy InstanceGroupConfig +InstanceSelection instanceTemplates instantiation integrations diff --git a/tests/providers/google/cloud/operators/test_dataproc.py b/tests/providers/google/cloud/operators/test_dataproc.py index 39361c5a9840c..59a9c1008c4a3 100644 --- a/tests/providers/google/cloud/operators/test_dataproc.py +++ b/tests/providers/google/cloud/operators/test_dataproc.py @@ -60,6 +60,8 @@ DataprocSubmitSparkJobOperator, DataprocSubmitSparkSqlJobOperator, DataprocUpdateClusterOperator, + InstanceFlexibilityPolicy, + InstanceSelection, ) from airflow.providers.google.cloud.triggers.dataproc import ( DataprocBatchTrigger, @@ -112,6 +114,7 @@ "disk_config": {"boot_disk_type": "worker_disk_type", "boot_disk_size_gb": 256}, "image_uri": "https://www.googleapis.com/compute/beta/projects/" "custom_image_project_id/global/images/custom_image", + "min_num_instances": 1, }, "secondary_worker_config": { "num_instances": 4, @@ -132,6 +135,17 @@ {"executable_file": "init_actions_uris", "execution_timeout": {"seconds": 600}} ], "endpoint_config": {}, + "auxiliary_node_groups": [ + { + "node_group": { + "roles": ["DRIVER"], + "node_group_config": { + "num_instances": 2, + }, + }, + "node_group_id": "cluster_driver_pool", + } + ], } VIRTUAL_CLUSTER_CONFIG = { "kubernetes_cluster_config": { @@ -197,6 +211,64 @@ }, } +CONFIG_WITH_FLEX_MIG = { + "gce_cluster_config": { + "zone_uri": "https://www.googleapis.com/compute/v1/projects/project_id/zones/zone", + "metadata": {"metadata": "data"}, + "network_uri": "network_uri", + "subnetwork_uri": "subnetwork_uri", + "internal_ip_only": True, + "tags": ["tags"], + "service_account": "service_account", + "service_account_scopes": ["service_account_scopes"], + }, + "master_config": { + "num_instances": 2, + "machine_type_uri": "projects/project_id/zones/zone/machineTypes/master_machine_type", + "disk_config": {"boot_disk_type": "master_disk_type", "boot_disk_size_gb": 128}, + "image_uri": "https://www.googleapis.com/compute/beta/projects/" + "custom_image_project_id/global/images/custom_image", + }, + "worker_config": { + "num_instances": 2, + "machine_type_uri": "projects/project_id/zones/zone/machineTypes/worker_machine_type", + "disk_config": {"boot_disk_type": "worker_disk_type", "boot_disk_size_gb": 256}, + "image_uri": "https://www.googleapis.com/compute/beta/projects/" + "custom_image_project_id/global/images/custom_image", + }, + "secondary_worker_config": { + "num_instances": 4, + "machine_type_uri": "projects/project_id/zones/zone/machineTypes/worker_machine_type", + "disk_config": {"boot_disk_type": "worker_disk_type", "boot_disk_size_gb": 256}, + "is_preemptible": True, + "preemptibility": "SPOT", + "instance_flexibility_policy": { + "instance_selection_list": [ + { + "machine_types": [ + "projects/project_id/zones/zone/machineTypes/machine1", + "projects/project_id/zones/zone/machineTypes/machine2", + ], + "rank": 0, + }, + {"machine_types": ["projects/project_id/zones/zone/machineTypes/machine3"], "rank": 1}, + ], + }, + }, + "software_config": {"properties": {"properties": "data"}, "optional_components": ["optional_components"]}, + "lifecycle_config": { + "idle_delete_ttl": {"seconds": 60}, + "auto_delete_time": "2019-09-12T00:00:00.000000Z", + }, + "encryption_config": {"gce_pd_kms_key_name": "customer_managed_key"}, + "autoscaling_config": {"policy_uri": "autoscaling_policy"}, + "config_bucket": "storage_bucket", + "initialization_actions": [ + {"executable_file": "init_actions_uris", "execution_timeout": {"seconds": 600}} + ], + "endpoint_config": {}, +} + LABELS = {"labels": "data", "airflow-version": AIRFLOW_VERSION} LABELS.update({"airflow-version": "v" + airflow_version.replace(".", "-").replace("+", "-")}) @@ -361,10 +433,26 @@ def test_nodes_number(self): ) assert "num_workers == 0 means single" in str(ctx.value) + def test_min_num_workers_less_than_num_workers(self): + with pytest.raises(ValueError) as ctx: + ClusterGenerator( + num_workers=3, min_num_workers=4, project_id=GCP_PROJECT, cluster_name=CLUSTER_NAME + ) + assert ( + "The value of min_num_workers must be less than or equal to num_workers. " + "Provided 4(min_num_workers) and 3(num_workers)." in str(ctx.value) + ) + + def test_min_num_workers_without_num_workers(self): + with pytest.raises(ValueError) as ctx: + ClusterGenerator(min_num_workers=4, project_id=GCP_PROJECT, cluster_name=CLUSTER_NAME) + assert "Must specify num_workers when min_num_workers are provided." in str(ctx.value) + def test_build(self): generator = ClusterGenerator( project_id="project_id", num_workers=2, + min_num_workers=1, zone="zone", network_uri="network_uri", subnetwork_uri="subnetwork_uri", @@ -395,6 +483,8 @@ def test_build(self): auto_delete_time=datetime(2019, 9, 12), auto_delete_ttl=250, customer_managed_key="customer_managed_key", + driver_pool_id="cluster_driver_pool", + driver_pool_size=2, ) cluster = generator.make() assert CONFIG == cluster @@ -438,6 +528,56 @@ def test_build_with_custom_image_family(self): cluster = generator.make() assert CONFIG_WITH_CUSTOM_IMAGE_FAMILY == cluster + def test_build_with_flex_migs(self): + generator = ClusterGenerator( + project_id="project_id", + num_workers=2, + zone="zone", + network_uri="network_uri", + subnetwork_uri="subnetwork_uri", + internal_ip_only=True, + tags=["tags"], + storage_bucket="storage_bucket", + init_actions_uris=["init_actions_uris"], + init_action_timeout="10m", + metadata={"metadata": "data"}, + custom_image="custom_image", + custom_image_project_id="custom_image_project_id", + autoscaling_policy="autoscaling_policy", + properties={"properties": "data"}, + optional_components=["optional_components"], + num_masters=2, + master_machine_type="master_machine_type", + master_disk_type="master_disk_type", + master_disk_size=128, + worker_machine_type="worker_machine_type", + worker_disk_type="worker_disk_type", + worker_disk_size=256, + num_preemptible_workers=4, + preemptibility="Spot", + region="region", + service_account="service_account", + service_account_scopes=["service_account_scopes"], + idle_delete_ttl=60, + auto_delete_time=datetime(2019, 9, 12), + auto_delete_ttl=250, + customer_managed_key="customer_managed_key", + secondary_worker_instance_flexibility_policy=InstanceFlexibilityPolicy( + [ + InstanceSelection( + [ + "projects/project_id/zones/zone/machineTypes/machine1", + "projects/project_id/zones/zone/machineTypes/machine2", + ], + 0, + ), + InstanceSelection(["projects/project_id/zones/zone/machineTypes/machine3"], 1), + ] + ), + ) + cluster = generator.make() + assert CONFIG_WITH_FLEX_MIG == cluster + class TestDataprocClusterCreateOperator(DataprocClusterTestBase): def test_deprecation_warning(self):