diff --git a/airflow/contrib/example_dags/example_gcp_dataproc_create_cluster.py b/airflow/contrib/example_dags/example_gcp_dataproc_create_cluster.py new file mode 100644 index 0000000000000..444416618d11b --- /dev/null +++ b/airflow/contrib/example_dags/example_gcp_dataproc_create_cluster.py @@ -0,0 +1,59 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +import airflow +from airflow import models +from airflow.contrib.operators.dataproc_operator import DataprocClusterCreateOperator + +default_args = {"start_date": airflow.utils.dates.days_ago(1)} + +cluster_name = "testcluster-800" +project_id = "" +region = "" +zone = "" + +# When using autoscaling the number of primary workers +# must be within autoscaler min/max range +num_workers = 2 + +# Autoscaling policy +# The policy must be in the same project and Cloud Dataproc region +# but the zone doesn't have to be the same +scaling_policy = 'test-policy' +policy_uri = 'projects/{p}/locations/{r}/autoscalingPolicies/{id}'.format( + p=project_id, + r=region, + id=scaling_policy +) + +with models.DAG( + "example_dataproc_create_cluster", + default_args=default_args, + schedule_interval=None, +) as dag: + create_task = DataprocClusterCreateOperator( + task_id="run_example_cluster_script", + cluster_name=cluster_name, + project_id=project_id, + num_workers=num_workers, + region=region, + zone=zone, + autoscaling_policy=policy_uri + ) diff --git a/airflow/contrib/operators/dataproc_operator.py b/airflow/contrib/operators/dataproc_operator.py index 9fb52895523ea..1a1f0655000a9 100644 --- a/airflow/contrib/operators/dataproc_operator.py +++ b/airflow/contrib/operators/dataproc_operator.py @@ -99,6 +99,10 @@ class DataprocClusterCreateOperator(DataprocOperationBaseOperator): :param custom_image: custom Dataproc image for more info see https://cloud.google.com/dataproc/docs/guides/dataproc-images :type custom_image: str + :param autoscaling_policy: The autoscaling policy used by the cluster. Only resource names + including projectid and location (region) are valid. Example: + projects/[projectId]/locations/[dataproc_region]/autoscalingPolicies/[policy_id] + :type autoscaling_policy: str :param properties: dict of properties to set on config files (e.g. spark-defaults.conf), see https://cloud.google.com/dataproc/docs/reference/rest/v1/projects.regions.clusters#SoftwareConfig @@ -184,6 +188,7 @@ def __init__(self, metadata=None, custom_image=None, image_version=None, + autoscaling_policy=None, properties=None, master_machine_type='n1-standard-4', master_disk_type='pd-standard', @@ -217,6 +222,7 @@ def __init__(self, self.master_machine_type = master_machine_type self.master_disk_type = master_disk_type self.master_disk_size = master_disk_size + self.autoscaling_policy = autoscaling_policy self.worker_machine_type = worker_machine_type self.worker_disk_type = worker_disk_type self.worker_disk_size = worker_disk_size @@ -293,7 +299,8 @@ def _build_cluster_data(self): 'secondaryWorkerConfig': {}, 'softwareConfig': {}, 'lifecycleConfig': {}, - 'encryptionConfig': {} + 'encryptionConfig': {}, + 'autoscalingConfig': {}, } } if self.num_preemptible_workers > 0: @@ -377,6 +384,9 @@ def _build_cluster_data(self): if self.customer_managed_key: cluster_data['config']['encryptionConfig'] =\ {'gcePdKmsKeyName': self.customer_managed_key} + if self.autoscaling_policy: + cluster_data['config']['autoscalingConfig'] = {'policyUri': self.autoscaling_policy} + return cluster_data def start(self): @@ -723,7 +733,7 @@ def execute(self, context): self.job_template.add_query(self.query) self.job_template.add_variables(self.variables) - self.execute(context) + super().execute(context) class DataProcHiveOperator(DataProcJobBaseOperator): diff --git a/tests/contrib/operators/test_dataproc_operator.py b/tests/contrib/operators/test_dataproc_operator.py index e5be30ad46bad..5445fcc58cce7 100644 --- a/tests/contrib/operators/test_dataproc_operator.py +++ b/tests/contrib/operators/test_dataproc_operator.py @@ -31,6 +31,7 @@ DataprocClusterDeleteOperator, \ DataProcHadoopOperator, \ DataProcHiveOperator, \ + DataProcPigOperator, \ DataProcPySparkOperator, \ DataProcSparkOperator, \ DataprocWorkflowTemplateInstantiateInlineOperator, \ @@ -51,6 +52,7 @@ GCP_PROJECT_ID = 'test-project-id' NUM_WORKERS = 123 GCE_ZONE = 'us-central1-a' +SCALING_POLICY = 'test-scaling-policy' NETWORK_URI = '/projects/project_id/regions/global/net' SUBNETWORK_URI = '/projects/project_id/regions/global/subnet' INTERNAL_IP_ONLY = True @@ -119,6 +121,7 @@ def setUp(self): project_id=GCP_PROJECT_ID, num_workers=NUM_WORKERS, zone=GCE_ZONE, + autoscaling_policy=SCALING_POLICY, network_uri=NETWORK_URI, subnetwork_uri=SUBNETWORK_URI, internal_ip_only=INTERNAL_IP_ONLY, @@ -174,6 +177,7 @@ def test_init(self): self.assertEqual(dataproc_operator.idle_delete_ttl, IDLE_DELETE_TTL) self.assertEqual(dataproc_operator.auto_delete_time, AUTO_DELETE_TIME) self.assertEqual(dataproc_operator.auto_delete_ttl, AUTO_DELETE_TTL) + self.assertEqual(dataproc_operator.autoscaling_policy, SCALING_POLICY) def test_get_init_action_timeout(self): for suffix, dataproc_operator in enumerate(self.dataproc_operators): @@ -208,6 +212,8 @@ def test_build_cluster_data(self): "321s") self.assertEqual(cluster_data['config']['lifecycleConfig']['autoDeleteTime'], "2017-06-07T00:00:00.000000Z") + self.assertEqual(cluster_data['config']['autoscalingConfig']['policyUri'], + SCALING_POLICY) # test whether the default airflow-version label has been properly # set to the dataproc operator. merged_labels = {} @@ -394,7 +400,9 @@ def test_create_cluster(self): 'secondaryWorkerConfig': {}, 'softwareConfig': {}, 'lifecycleConfig': {}, - 'encryptionConfig': {}}, + 'encryptionConfig': {}, + 'autoscalingConfig': {}, + }, 'labels': {'airflow-version': mock.ANY}}) hook.wait.assert_called_once_with(self.operation) @@ -615,6 +623,32 @@ def test_dataproc_job_id_is_set(): _assert_dataproc_job_id(mock_hook, dataproc_task) +class DataProcPigOperatorTest(unittest.TestCase): + @staticmethod + def test_hook_correct_region(): + with patch(HOOK) as mock_hook: + dataproc_task = DataProcPigOperator( + task_id=TASK_ID, + cluster_name=CLUSTER_NAME, + region=GCP_REGION + ) + + dataproc_task.execute(None) + mock_hook.return_value.submit.assert_called_once_with(mock.ANY, mock.ANY, + GCP_REGION, mock.ANY) + + @staticmethod + def test_dataproc_job_id_is_set(): + with patch(HOOK) as mock_hook: + dataproc_task = DataProcPigOperator( + task_id=TASK_ID, + cluster_name=CLUSTER_NAME, + region=GCP_REGION + ) + + _assert_dataproc_job_id(mock_hook, dataproc_task) + + class DataProcPySparkOperatorTest(unittest.TestCase): # Unit test for the DataProcPySparkOperator @staticmethod