From f0de0000bda9861f4058ac41fc564d876f97b1c7 Mon Sep 17 00:00:00 2001 From: ryanyuan Date: Mon, 10 Jun 2019 23:17:08 +1000 Subject: [PATCH] [AIRFLOW-4746] Implement GCP Cloud Tasks' Hook and Operators Implement GCP Cloud Tasks' Hook and Operators --- .../contrib/example_dags/example_gcp_tasks.py | 89 ++ airflow/contrib/hooks/gcp_tasks_hook.py | 699 +++++++++++++ .../contrib/operators/gcp_tasks_operator.py | 945 ++++++++++++++++++ docs/integration.rst | 45 + setup.py | 1 + tests/contrib/hooks/test_gcp_tasks_hook.py | 291 ++++++ .../operators/test_gcp_tasks_operator.py | 309 ++++++ .../test_gcp_tasks_operator_system.py | 37 + tests/contrib/utils/gcp_authenticator.py | 2 +- 9 files changed, 2417 insertions(+), 1 deletion(-) create mode 100644 airflow/contrib/example_dags/example_gcp_tasks.py create mode 100644 airflow/contrib/hooks/gcp_tasks_hook.py create mode 100644 airflow/contrib/operators/gcp_tasks_operator.py create mode 100644 tests/contrib/hooks/test_gcp_tasks_hook.py create mode 100644 tests/contrib/operators/test_gcp_tasks_operator.py create mode 100644 tests/contrib/operators/test_gcp_tasks_operator_system.py diff --git a/airflow/contrib/example_dags/example_gcp_tasks.py b/airflow/contrib/example_dags/example_gcp_tasks.py new file mode 100644 index 0000000000000..718a86274b8fe --- /dev/null +++ b/airflow/contrib/example_dags/example_gcp_tasks.py @@ -0,0 +1,89 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Example Airflow DAG that creates, gets, lists, updates, purges, pauses, resumes +and deletes Queues and creates, gets, lists, runs and deletes Tasks in the Google +Cloud Tasks service in the Google Cloud Platform. +""" + + +from datetime import datetime, timedelta + +from google.api_core.retry import Retry +from google.cloud.tasks_v2.types import Queue +from google.protobuf import timestamp_pb2 + +import airflow +from airflow.contrib.operators.gcp_tasks_operator import ( + CloudTasksQueueCreateOperator, + CloudTasksTaskCreateOperator, + CloudTasksTaskRunOperator, +) +from airflow.models import DAG + +default_args = {"start_date": airflow.utils.dates.days_ago(1)} +timestamp = timestamp_pb2.Timestamp() +timestamp.FromDatetime(datetime.now() + timedelta(hours=12)) # pylint: disable=no-member + +LOCATION = "asia-east2" +QUEUE_ID = "cloud-tasks-queue" +TASK_NAME = "task-to-run" + + +TASK = { + "app_engine_http_request": { # Specify the type of request. 
+ "http_method": "POST", + "relative_uri": "/example_task_handler", + "body": "Hello".encode(), + }, + "schedule_time": timestamp, +} + +with DAG("example_gcp_tasks", default_args=default_args, schedule_interval=None) as dag: + + create_queue = CloudTasksQueueCreateOperator( + location=LOCATION, + task_queue=Queue(), + queue_name=QUEUE_ID, + retry=Retry(maximum=10.0), + timeout=5, + task_id="create_queue", + ) + + create_task_to_run = CloudTasksTaskCreateOperator( + location=LOCATION, + queue_name=QUEUE_ID, + task=TASK, + task_name=TASK_NAME, + retry=Retry(maximum=10.0), + timeout=5, + task_id="create_task_to_run", + ) + + run_task = CloudTasksTaskRunOperator( + location=LOCATION, + queue_name=QUEUE_ID, + task_name=TASK_NAME, + retry=Retry(maximum=10.0), + timeout=5, + task_id="run_task", + ) + + create_queue >> create_task_to_run >> run_task diff --git a/airflow/contrib/hooks/gcp_tasks_hook.py b/airflow/contrib/hooks/gcp_tasks_hook.py new file mode 100644 index 0000000000000..99f2ee85ce580 --- /dev/null +++ b/airflow/contrib/hooks/gcp_tasks_hook.py @@ -0,0 +1,699 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +""" +This module contains a CloudTasksHook +which allows you to connect to GCP Cloud Tasks service, +performing actions to queues or tasks. +""" + +from google.cloud.tasks_v2 import CloudTasksClient +from google.cloud.tasks_v2.types import Queue, Task + +from airflow import AirflowException +from airflow.contrib.hooks.gcp_api_base_hook import GoogleCloudBaseHook + + +class CloudTasksHook(GoogleCloudBaseHook): + """ + Hook for Google Cloud Tasks APIs. Cloud Tasks allows developers to manage + the execution of background work in their applications. + + All the methods in the hook where project_id is used must be called with + keyword arguments rather than positional. + + :param gcp_conn_id: The connection ID to use when fetching connection info. + :type gcp_conn_id: str + :param delegate_to: The account to impersonate, if any. + For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + """ + + def __init__(self, gcp_conn_id="google_cloud_default", delegate_to=None): + super().__init__(gcp_conn_id, delegate_to) + self._client = None + + def get_conn(self): + """ + Provides a client for interacting with the Cloud Tasks API. + + :return: GCP Cloud Tasks API Client + :rtype: google.cloud.tasks_v2.CloudTasksClient + """ + if not self._client: + self._client = CloudTasksClient( + credentials=self._get_credentials(), + client_info=self.client_info + ) + return self._client + + @GoogleCloudBaseHook.catch_http_exception + @GoogleCloudBaseHook.fallback_to_default_project_id + def create_queue( + self, + location, + task_queue, + project_id=None, + queue_name=None, + retry=None, + timeout=None, + metadata=None, + ): + """ + Creates a queue in Cloud Tasks. + + :param location: The location name in which the queue will be created. + :type location: str + :param task_queue: The task queue to create. + Queue's name cannot be the same as an existing queue. 
+ If a dict is provided, it must be of the same form as the protobuf message Queue. + :type task_queue: dict or class google.cloud.tasks_v2.types.Queue + :param project_id: (Optional) The ID of the GCP project that owns the Cloud Tasks. + If set to None or missing, the default project_id from the GCP connection is used. + :type project_id: str + :param queue_name: (Optional) The queue's name. + If provided, it will be used to construct the full queue path. + :type queue_name: str + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: sequence[tuple[str, str]]] + :rtype: google.cloud.tasks_v2.types.Queue + """ + + client = self.get_conn() + + if queue_name: + full_queue_name = CloudTasksClient.queue_path(project_id, location, queue_name) + if isinstance(task_queue, Queue): + task_queue.name = full_queue_name + elif isinstance(task_queue, dict): + task_queue['name'] = full_queue_name + else: + raise AirflowException('Unable to set queue_name.') + full_location_path = CloudTasksClient.location_path(project_id, location) + return client.create_queue( + parent=full_location_path, + queue=task_queue, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + @GoogleCloudBaseHook.catch_http_exception + @GoogleCloudBaseHook.fallback_to_default_project_id + def update_queue( + self, + task_queue, + project_id=None, + location=None, + queue_name=None, + update_mask=None, + retry=None, + timeout=None, + metadata=None, + ): + """ + Updates a queue in Cloud Tasks. + + :param task_queue: The task queue to update. 
+ This method creates the queue if it does not exist and updates the queue if + it does exist. The queue's name must be specified. + :type task_queue: dict or class google.cloud.tasks_v2.types.Queue + :param project_id: (Optional) The ID of the GCP project that owns the Cloud Tasks. + If set to None or missing, the default project_id from the GCP connection is used. + :type project_id: str + :param location: (Optional) The location name in which the queue will be updated. + If provided, it will be used to construct the full queue path. + :type location: str + :param queue_name: (Optional) The queue's name. + If provided, it will be used to construct the full queue path. + :type queue_name: str + :param update_mask: A mask used to specify which fields of the queue are being updated. + If empty, then all fields will be updated. + If a dict is provided, it must be of the same form as the protobuf message. + :type update_mask: dict or class google.cloud.tasks_v2.types.FieldMask + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. 
+ :type metadata: sequence[tuple[str, str]]] + :rtype: google.cloud.tasks_v2.types.Queue + """ + + client = self.get_conn() + + if queue_name and location: + full_queue_name = CloudTasksClient.queue_path(project_id, location, queue_name) + if isinstance(task_queue, Queue): + task_queue.name = full_queue_name + elif isinstance(task_queue, dict): + task_queue['name'] = full_queue_name + else: + raise AirflowException('Unable to set queue_name.') + return client.update_queue( + queue=task_queue, + update_mask=update_mask, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + @GoogleCloudBaseHook.catch_http_exception + @GoogleCloudBaseHook.fallback_to_default_project_id + def get_queue( + self, + location, + queue_name, + project_id=None, + retry=None, + timeout=None, + metadata=None, + ): + """ + Gets a queue from Cloud Tasks. + + :param location: The location name in which the queue was created. + :type location: str + :param queue_name: The queue's name. + :type queue_name: str + :param project_id: (Optional) The ID of the GCP project that owns the Cloud Tasks. + If set to None or missing, the default project_id from the GCP connection is used. + :type project_id: str + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. 
+ :type metadata: sequence[tuple[str, str]]] + :rtype: google.cloud.tasks_v2.types.Queue + """ + + client = self.get_conn() + + full_queue_name = CloudTasksClient.queue_path(project_id, location, queue_name) + return client.get_queue( + name=full_queue_name, retry=retry, timeout=timeout, metadata=metadata + ) + + @GoogleCloudBaseHook.catch_http_exception + @GoogleCloudBaseHook.fallback_to_default_project_id + def list_queues( + self, + location, + project_id=None, + results_filter=None, + page_size=None, + retry=None, + timeout=None, + metadata=None, + ): + """ + Lists queues from Cloud Tasks. + + :param location: The location name in which the queues were created. + :type location: str + :param project_id: (Optional) The ID of the GCP project that owns the Cloud Tasks. + If set to None or missing, the default project_id from the GCP connection is used. + :type project_id: str + :param results_filter: (Optional) Filter used to specify a subset of queues. + :type results_filter: str + :param page_size: (Optional) The maximum number of resources contained in the + underlying API response. + :type page_size: int + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. 
+ :type metadata: sequence[tuple[str, str]]] + :rtype: list[google.cloud.tasks_v2.types.Queue] + """ + + client = self.get_conn() + + full_location_path = CloudTasksClient.location_path(project_id, location) + queues = client.list_queues( + parent=full_location_path, + filter_=results_filter, + page_size=page_size, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return list(queues) + + @GoogleCloudBaseHook.catch_http_exception + @GoogleCloudBaseHook.fallback_to_default_project_id + def delete_queue( + self, + location, + queue_name, + project_id=None, + retry=None, + timeout=None, + metadata=None, + ): + """ + Deletes a queue from Cloud Tasks, even if it has tasks in it. + + :param location: The location name in which the queue will be deleted. + :type location: str + :param queue_name: The queue's name. + :type queue_name: str + :param project_id: (Optional) The ID of the GCP project that owns the Cloud Tasks. + If set to None or missing, the default project_id from the GCP connection is used. + :type project_id: str + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. 
+ :type metadata: sequence[tuple[str, str]]] + """ + + client = self.get_conn() + + full_queue_name = CloudTasksClient.queue_path(project_id, location, queue_name) + client.delete_queue( + name=full_queue_name, retry=retry, timeout=timeout, metadata=metadata + ) + + @GoogleCloudBaseHook.catch_http_exception + @GoogleCloudBaseHook.fallback_to_default_project_id + def purge_queue( + self, + location, + queue_name, + project_id=None, + retry=None, + timeout=None, + metadata=None, + ): + """ + Purges a queue by deleting all of its tasks from Cloud Tasks. + + :param location: The location name in which the queue will be purged. + :type location: str + :param queue_name: The queue's name. + :type queue_name: str + :param project_id: (Optional) The ID of the GCP project that owns the Cloud Tasks. + If set to None or missing, the default project_id from the GCP connection is used. + :type project_id: str + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: sequence[tuple[str, str]]] + :rtype: list[google.cloud.tasks_v2.types.Queue] + """ + + client = self.get_conn() + + full_queue_name = CloudTasksClient.queue_path(project_id, location, queue_name) + return client.purge_queue( + name=full_queue_name, retry=retry, timeout=timeout, metadata=metadata + ) + + @GoogleCloudBaseHook.catch_http_exception + @GoogleCloudBaseHook.fallback_to_default_project_id + def pause_queue( + self, + location, + queue_name, + project_id=None, + retry=None, + timeout=None, + metadata=None, + ): + """ + Pauses a queue in Cloud Tasks. 
+ + :param location: The location name in which the queue will be paused. + :type location: str + :param queue_name: The queue's name. + :type queue_name: str + :param project_id: (Optional) The ID of the GCP project that owns the Cloud Tasks. + If set to None or missing, the default project_id from the GCP connection is used. + :type project_id: str + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: sequence[tuple[str, str]]] + :rtype: list[google.cloud.tasks_v2.types.Queue] + """ + + client = self.get_conn() + + full_queue_name = CloudTasksClient.queue_path(project_id, location, queue_name) + return client.pause_queue( + name=full_queue_name, retry=retry, timeout=timeout, metadata=metadata + ) + + @GoogleCloudBaseHook.catch_http_exception + @GoogleCloudBaseHook.fallback_to_default_project_id + def resume_queue( + self, + location, + queue_name, + project_id=None, + retry=None, + timeout=None, + metadata=None, + ): + """ + Resumes a queue in Cloud Tasks. + + :param location: The location name in which the queue will be resumed. + :type location: str + :param queue_name: The queue's name. + :type queue_name: str + :param project_id: (Optional) The ID of the GCP project that owns the Cloud Tasks. + If set to None or missing, the default project_id from the GCP connection is used. + :type project_id: str + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. 
+ :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: sequence[tuple[str, str]]] + :rtype: list[google.cloud.tasks_v2.types.Queue] + """ + + client = self.get_conn() + + full_queue_name = CloudTasksClient.queue_path(project_id, location, queue_name) + return client.resume_queue( + name=full_queue_name, retry=retry, timeout=timeout, metadata=metadata + ) + + @GoogleCloudBaseHook.catch_http_exception + @GoogleCloudBaseHook.fallback_to_default_project_id + def create_task( + self, + location, + queue_name, + task, + project_id=None, + task_name=None, + response_view=None, + retry=None, + timeout=None, + metadata=None, + ): + """ + Creates a task in Cloud Tasks. + + :param location: The location name in which the task will be created. + :type location: str + :param queue_name: The queue's name. + :type queue_name: str + :param task: The task to add. + If a dict is provided, it must be of the same form as the protobuf message Task. + :type task: dict or class google.cloud.tasks_v2.types.Task + :param project_id: (Optional) The ID of the GCP project that owns the Cloud Tasks. + If set to None or missing, the default project_id from the GCP connection is used. + :type project_id: str + :param task_name: (Optional) The task's name. + If provided, it will be used to construct the full task path. + :type task_name: str + :param response_view: (Optional) This field specifies which subset of the Task will + be returned. + :type response_view: google.cloud.tasks_v2.types.Task.View + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. 
+ :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: sequence[tuple[str, str]]] + :rtype: google.cloud.tasks_v2.types.Task + """ + + client = self.get_conn() + + if task_name: + full_task_name = CloudTasksClient.task_path( + project_id, location, queue_name, task_name + ) + if isinstance(task, Task): + task.name = full_task_name + elif isinstance(task, dict): + task['name'] = full_task_name + else: + raise AirflowException('Unable to set task_name.') + full_queue_name = CloudTasksClient.queue_path(project_id, location, queue_name) + return client.create_task( + parent=full_queue_name, + task=task, + response_view=response_view, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + @GoogleCloudBaseHook.catch_http_exception + @GoogleCloudBaseHook.fallback_to_default_project_id + def get_task( + self, + location, + queue_name, + task_name, + project_id=None, + response_view=None, + retry=None, + timeout=None, + metadata=None, + ): + """ + Gets a task from Cloud Tasks. + + :param location: The location name in which the task was created. + :type location: str + :param queue_name: The queue's name. + :type queue_name: str + :param task_name: The task's name. + :type task_name: str + :param project_id: (Optional) The ID of the GCP project that owns the Cloud Tasks. + If set to None or missing, the default project_id from the GCP connection is used. + :type project_id: str + :param response_view: (Optional) This field specifies which subset of the Task will + be returned. + :type response_view: google.cloud.tasks_v2.types.Task.View + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. 
+ :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: sequence[tuple[str, str]]] + :rtype: google.cloud.tasks_v2.types.Task + """ + + client = self.get_conn() + + full_task_name = CloudTasksClient.task_path(project_id, location, queue_name, task_name) + return client.get_task( + name=full_task_name, + response_view=response_view, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + @GoogleCloudBaseHook.catch_http_exception + @GoogleCloudBaseHook.fallback_to_default_project_id + def list_tasks( + self, + location, + queue_name, + project_id=None, + response_view=None, + page_size=None, + retry=None, + timeout=None, + metadata=None, + ): + """ + Lists the tasks in Cloud Tasks. + + :param location: The location name in which the tasks were created. + :type location: str + :param queue_name: The queue's name. + :type queue_name: str + :param project_id: (Optional) The ID of the GCP project that owns the Cloud Tasks. + If set to None or missing, the default project_id from the GCP connection is used. + :type project_id: str + :param response_view: (Optional) This field specifies which subset of the Task will + be returned. + :type response_view: google.cloud.tasks_v2.types.Task.View + :param page_size: (Optional) The maximum number of resources contained in the + underlying API response. + :type page_size: int + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. 
+ :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: sequence[tuple[str, str]]] + :rtype: list[google.cloud.tasks_v2.types.Task] + """ + + client = self.get_conn() + full_queue_name = CloudTasksClient.queue_path(project_id, location, queue_name) + tasks = client.list_tasks( + parent=full_queue_name, + response_view=response_view, + page_size=page_size, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return list(tasks) + + @GoogleCloudBaseHook.catch_http_exception + @GoogleCloudBaseHook.fallback_to_default_project_id + def delete_task( + self, + location, + queue_name, + task_name, + project_id=None, + retry=None, + timeout=None, + metadata=None, + ): + """ + Deletes a task from Cloud Tasks. + + :param location: The location name in which the task will be deleted. + :type location: str + :param queue_name: The queue's name. + :type queue_name: str + :param task_name: The task's name. + :type task_name: str + :param project_id: (Optional) The ID of the GCP project that owns the Cloud Tasks. + If set to None or missing, the default project_id from the GCP connection is used. + :type project_id: str + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. 
+ :type metadata: sequence[tuple[str, str]]] + """ + + client = self.get_conn() + + full_task_name = CloudTasksClient.task_path(project_id, location, queue_name, task_name) + client.delete_task( + name=full_task_name, retry=retry, timeout=timeout, metadata=metadata + ) + + @GoogleCloudBaseHook.catch_http_exception + @GoogleCloudBaseHook.fallback_to_default_project_id + def run_task( + self, + location, + queue_name, + task_name, + project_id=None, + response_view=None, + retry=None, + timeout=None, + metadata=None, + ): + """ + Forces to run a task in Cloud Tasks. + + :param location: The location name in which the task was created. + :type location: str + :param queue_name: The queue's name. + :type queue_name: str + :param task_name: The task's name. + :type task_name: str + :param project_id: (Optional) The ID of the GCP project that owns the Cloud Tasks. + If set to None or missing, the default project_id from the GCP connection is used. + :type project_id: str + :param response_view: (Optional) This field specifies which subset of the Task will + be returned. + :type response_view: google.cloud.tasks_v2.types.Task.View + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. 
+ :type metadata: sequence[tuple[str, str]]] + :rtype: google.cloud.tasks_v2.types.Task + """ + + client = self.get_conn() + + full_task_name = CloudTasksClient.task_path(project_id, location, queue_name, task_name) + return client.run_task( + name=full_task_name, + response_view=response_view, + retry=retry, + timeout=timeout, + metadata=metadata, + ) diff --git a/airflow/contrib/operators/gcp_tasks_operator.py b/airflow/contrib/operators/gcp_tasks_operator.py new file mode 100644 index 0000000000000..a8b86a6ee70b3 --- /dev/null +++ b/airflow/contrib/operators/gcp_tasks_operator.py @@ -0,0 +1,945 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +This module contains various GCP Cloud Tasks operators +which allow you to perform basic operations using +Cloud Tasks queues/tasks. +""" + +from airflow.contrib.hooks.gcp_tasks_hook import CloudTasksHook +from airflow.models import BaseOperator +from airflow.utils.decorators import apply_defaults + + +class CloudTasksQueueCreateOperator(BaseOperator): + """ + Creates a queue in Cloud Tasks. + + :param location: The location name in which the queue will be created. + :type location: str + :param task_queue: The task queue to create. 
+ Queue's name cannot be the same as an existing queue. + If a dict is provided, it must be of the same form as the protobuf message Queue. + :type task_queue: dict or class google.cloud.tasks_v2.types.Queue + :param project_id: (Optional) The ID of the GCP project that owns the Cloud Tasks. + If set to None or missing, the default project_id from the GCP connection is used. + :type project_id: str + :param queue_name: (Optional) The queue's name. + If provided, it will be used to construct the full queue path. + :type queue_name: str + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: sequence[tuple[str, str]] + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud Platform. 
+ :type gcp_conn_id: str + :rtype: google.cloud.tasks_v2.types.Queue + """ + + template_fields = ( + "task_queue", + "project_id", + "location", + "queue_name", + "gcp_conn_id", + ) + + @apply_defaults + def __init__( + self, + location, + task_queue, + project_id=None, + queue_name=None, + retry=None, + timeout=None, + metadata=None, + gcp_conn_id="google_cloud_default", + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.location = location + self.task_queue = task_queue + self.project_id = project_id + self.queue_name = queue_name + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + + def execute(self, context): + hook = CloudTasksHook(gcp_conn_id=self.gcp_conn_id) + return hook.create_queue( + location=self.location, + task_queue=self.task_queue, + project_id=self.project_id, + queue_name=self.queue_name, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + + +class CloudTasksQueueUpdateOperator(BaseOperator): + """ + Updates a queue in Cloud Tasks. + + :param task_queue: The task queue to update. + This method creates the queue if it does not exist and updates the queue if + it does exist. The queue's name must be specified. + :type task_queue: dict or class google.cloud.tasks_v2.types.Queue + :param project_id: (Optional) The ID of the GCP project that owns the Cloud Tasks. + If set to None or missing, the default project_id from the GCP connection is used. + :type project_id: str + :param location: (Optional) The location name in which the queue will be updated. + If provided, it will be used to construct the full queue path. + :type location: str + :param queue_name: (Optional) The queue's name. + If provided, it will be used to construct the full queue path. + :type queue_name: str + :param update_mask: A mast used to specify which fields of the queue are being updated. + If empty, then all fields will be updated. 
+ If a dict is provided, it must be of the same form as the protobuf message. + :type update_mask: dict or class google.cloud.tasks_v2.types.FieldMask + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: sequence[tuple[str, str]]] + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud Platform. + :type gcp_conn_id: str + :rtype: google.cloud.tasks_v2.types.Queue + """ + + template_fields = ( + "task_queue", + "project_id", + "location", + "queue_name", + "update_mask", + "gcp_conn_id", + ) + + @apply_defaults + def __init__( + self, + task_queue, + project_id=None, + location=None, + queue_name=None, + update_mask=None, + retry=None, + timeout=None, + metadata=None, + gcp_conn_id="google_cloud_default", + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.task_queue = task_queue + self.project_id = project_id + self.location = location + self.queue_name = queue_name + self.update_mask = update_mask + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + + def execute(self, context): + hook = CloudTasksHook(gcp_conn_id=self.gcp_conn_id) + return hook.update_queue( + task_queue=self.task_queue, + project_id=self.project_id, + location=self.location, + queue_name=self.queue_name, + update_mask=self.update_mask, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + + +class CloudTasksQueueGetOperator(BaseOperator): + """ + Gets a queue from Cloud Tasks. + + :param location: The location name in which the queue was created. 
+ :type location: str + :param queue_name: The queue's name. + :type queue_name: str + :param project_id: (Optional) The ID of the GCP project that owns the Cloud Tasks. + If set to None or missing, the default project_id from the GCP connection is used. + :type project_id: str + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: sequence[tuple[str, str]]] + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud Platform. + :type gcp_conn_id: str + :rtype: google.cloud.tasks_v2.types.Queue + """ + + template_fields = ("location", "queue_name", "project_id", "gcp_conn_id") + + @apply_defaults + def __init__( + self, + location, + queue_name, + project_id=None, + retry=None, + timeout=None, + metadata=None, + gcp_conn_id="google_cloud_default", + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.location = location + self.queue_name = queue_name + self.project_id = project_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + + def execute(self, context): + hook = CloudTasksHook(gcp_conn_id=self.gcp_conn_id) + return hook.get_queue( + location=self.location, + queue_name=self.queue_name, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + + +class CloudTasksQueuesListOperator(BaseOperator): + + """ + Lists queues from Cloud Tasks. + + :param location: The location name in which the queues were created. + :type location: str + :param project_id: (Optional) The ID of the GCP project that owns the Cloud Tasks. 
+ If set to None or missing, the default project_id from the GCP connection is used. + :type project_id: str + :param results_filter: (Optional) Filter used to specify a subset of queues. + :type results_filter: str + :param page_size: (Optional) The maximum number of resources contained in the + underlying API response. + :type page_size: int + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: sequence[tuple[str, str]]] + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud Platform. + :type gcp_conn_id: str + :rtype: list[google.cloud.tasks_v2.types.Queue] + """ + + template_fields = ("location", "project_id", "gcp_conn_id") + + @apply_defaults + def __init__( + self, + location, + project_id=None, + results_filter=None, + page_size=None, + retry=None, + timeout=None, + metadata=None, + gcp_conn_id="google_cloud_default", + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.location = location + self.project_id = project_id + self.results_filter = results_filter + self.page_size = page_size + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + + def execute(self, context): + hook = CloudTasksHook(gcp_conn_id=self.gcp_conn_id) + return hook.list_queues( + location=self.location, + project_id=self.project_id, + results_filter=self.results_filter, + page_size=self.page_size, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + + +class CloudTasksQueueDeleteOperator(BaseOperator): + """ + Deletes a queue from Cloud Tasks, even if it has 
tasks in it. + + :param location: The location name in which the queue will be deleted. + :type location: str + :param queue_name: The queue's name. + :type queue_name: str + :param project_id: (Optional) The ID of the GCP project that owns the Cloud Tasks. + If set to None or missing, the default project_id from the GCP connection is used. + :type project_id: str + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: sequence[tuple[str, str]]] + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud Platform. + :type gcp_conn_id: str + """ + + template_fields = ("location", "queue_name", "project_id", "gcp_conn_id") + + @apply_defaults + def __init__( + self, + location, + queue_name, + project_id=None, + retry=None, + timeout=None, + metadata=None, + gcp_conn_id="google_cloud_default", + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.location = location + self.queue_name = queue_name + self.project_id = project_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + + def execute(self, context): + hook = CloudTasksHook(gcp_conn_id=self.gcp_conn_id) + hook.delete_queue( + location=self.location, + queue_name=self.queue_name, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + + +class CloudTasksQueuePurgeOperator(BaseOperator): + """ + Purges a queue by deleting all of its tasks from Cloud Tasks. + + :param location: The location name in which the queue will be purged. 
+ :type location: str + :param queue_name: The queue's name. + :type queue_name: str + :param project_id: (Optional) The ID of the GCP project that owns the Cloud Tasks. + If set to None or missing, the default project_id from the GCP connection is used. + :type project_id: str + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: sequence[tuple[str, str]] + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud Platform. + :type gcp_conn_id: str + :rtype: google.cloud.tasks_v2.types.Queue + """ + + template_fields = ("location", "queue_name", "project_id", "gcp_conn_id") + + @apply_defaults + def __init__( + self, + location, + queue_name, + project_id=None, + retry=None, + timeout=None, + metadata=None, + gcp_conn_id="google_cloud_default", + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.location = location + self.queue_name = queue_name + self.project_id = project_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + + def execute(self, context): + hook = CloudTasksHook(gcp_conn_id=self.gcp_conn_id) + return hook.purge_queue( + location=self.location, + queue_name=self.queue_name, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + + +class CloudTasksQueuePauseOperator(BaseOperator): + """ + Pauses a queue in Cloud Tasks. + + :param location: The location name in which the queue will be paused. + :type location: str + :param queue_name: The queue's name. 
+ :type queue_name: str + :param project_id: (Optional) The ID of the GCP project that owns the Cloud Tasks. + If set to None or missing, the default project_id from the GCP connection is used. + :type project_id: str + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: sequence[tuple[str, str]] + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud Platform. + :type gcp_conn_id: str + :rtype: google.cloud.tasks_v2.types.Queue + """ + + template_fields = ("location", "queue_name", "project_id", "gcp_conn_id") + + @apply_defaults + def __init__( + self, + location, + queue_name, + project_id=None, + retry=None, + timeout=None, + metadata=None, + gcp_conn_id="google_cloud_default", + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.location = location + self.queue_name = queue_name + self.project_id = project_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + + def execute(self, context): + hook = CloudTasksHook(gcp_conn_id=self.gcp_conn_id) + return hook.pause_queue( + location=self.location, + queue_name=self.queue_name, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + + +class CloudTasksQueueResumeOperator(BaseOperator): + """ + Resumes a queue in Cloud Tasks. + + :param location: The location name in which the queue will be resumed. + :type location: str + :param queue_name: The queue's name. 
+ :type queue_name: str + :param project_id: (Optional) The ID of the GCP project that owns the Cloud Tasks. + If set to None or missing, the default project_id from the GCP connection is used. + :type project_id: str + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: sequence[tuple[str, str]] + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud Platform. + :type gcp_conn_id: str + :rtype: google.cloud.tasks_v2.types.Queue + """ + + template_fields = ("location", "queue_name", "project_id", "gcp_conn_id") + + @apply_defaults + def __init__( + self, + location, + queue_name, + project_id=None, + retry=None, + timeout=None, + metadata=None, + gcp_conn_id="google_cloud_default", + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.location = location + self.queue_name = queue_name + self.project_id = project_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + + def execute(self, context): + hook = CloudTasksHook(gcp_conn_id=self.gcp_conn_id) + return hook.resume_queue( + location=self.location, + queue_name=self.queue_name, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + + +class CloudTasksTaskCreateOperator(BaseOperator): + """ + Creates a task in Cloud Tasks. + + :param location: The location name in which the task will be created. + :type location: str + :param queue_name: The queue's name. + :type queue_name: str + :param task: The task to add. 
+ If a dict is provided, it must be of the same form as the protobuf message Task. + :type task: dict or class google.cloud.tasks_v2.types.Task + :param project_id: (Optional) The ID of the GCP project that owns the Cloud Tasks. + If set to None or missing, the default project_id from the GCP connection is used. + :type project_id: str + :param task_name: (Optional) The task's name. + If provided, it will be used to construct the full task path. + :type task_name: str + :param response_view: (Optional) This field specifies which subset of the Task will + be returned. + :type response_view: google.cloud.tasks_v2.types.Task.View + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: sequence[tuple[str, str]]] + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud Platform. 
+ :type gcp_conn_id: str + :rtype: google.cloud.tasks_v2.types.Task + """ + + template_fields = ( + "task", + "project_id", + "location", + "queue_name", + "task_name", + "gcp_conn_id", + ) + + @apply_defaults + def __init__( + self, + location, + queue_name, + task, + project_id=None, + task_name=None, + response_view=None, + retry=None, + timeout=None, + metadata=None, + gcp_conn_id="google_cloud_default", + *args, + **kwargs + ): # pylint: disable=too-many-arguments + super().__init__(*args, **kwargs) + self.location = location + self.queue_name = queue_name + self.task = task + self.project_id = project_id + self.task_name = task_name + self.response_view = response_view + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + + def execute(self, context): + hook = CloudTasksHook(gcp_conn_id=self.gcp_conn_id) + return hook.create_task( + location=self.location, + queue_name=self.queue_name, + task=self.task, + project_id=self.project_id, + task_name=self.task_name, + response_view=self.response_view, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + + +class CloudTasksTaskGetOperator(BaseOperator): + """ + Gets a task from Cloud Tasks. + + :param location: The location name in which the task was created. + :type location: str + :param queue_name: The queue's name. + :type queue_name: str + :param task_name: The task's name. + :type task_name: str + :param project_id: (Optional) The ID of the GCP project that owns the Cloud Tasks. + If set to None or missing, the default project_id from the GCP connection is used. + :type project_id: str + :param response_view: (Optional) This field specifies which subset of the Task will + be returned. + :type response_view: google.cloud.tasks_v2.types.Task.View + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. 
+ :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: sequence[tuple[str, str]]] + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud Platform. + :type gcp_conn_id: str + :rtype: google.cloud.tasks_v2.types.Task + """ + + template_fields = ( + "location", + "queue_name", + "task_name", + "project_id", + "gcp_conn_id", + ) + + @apply_defaults + def __init__( + self, + location, + queue_name, + task_name, + project_id=None, + response_view=None, + retry=None, + timeout=None, + metadata=None, + gcp_conn_id="google_cloud_default", + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.location = location + self.queue_name = queue_name + self.task_name = task_name + self.project_id = project_id + self.response_view = response_view + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + + def execute(self, context): + hook = CloudTasksHook(gcp_conn_id=self.gcp_conn_id) + return hook.get_task( + location=self.location, + queue_name=self.queue_name, + task_name=self.task_name, + project_id=self.project_id, + response_view=self.response_view, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + + +class CloudTasksTasksListOperator(BaseOperator): + """ + Lists the tasks in Cloud Tasks. + + :param location: The location name in which the tasks were created. + :type location: str + :param queue_name: The queue's name. + :type queue_name: str + :param project_id: (Optional) The ID of the GCP project that owns the Cloud Tasks. + If set to None or missing, the default project_id from the GCP connection is used. 
+ :type project_id: str + :param response_view: (Optional) This field specifies which subset of the Task will + be returned. + :type response_view: google.cloud.tasks_v2.types.Task.View + :param page_size: (Optional) The maximum number of resources contained in the + underlying API response. + :type page_size: int + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: sequence[tuple[str, str]]] + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud Platform. + :type gcp_conn_id: str + :rtype: list[google.cloud.tasks_v2.types.Task] + """ + + template_fields = ("location", "queue_name", "project_id", "gcp_conn_id") + + @apply_defaults + def __init__( + self, + location, + queue_name, + project_id=None, + response_view=None, + page_size=None, + retry=None, + timeout=None, + metadata=None, + gcp_conn_id="google_cloud_default", + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.location = location + self.queue_name = queue_name + self.project_id = project_id + self.response_view = response_view + self.page_size = page_size + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + + def execute(self, context): + hook = CloudTasksHook(gcp_conn_id=self.gcp_conn_id) + return hook.list_tasks( + location=self.location, + queue_name=self.queue_name, + project_id=self.project_id, + response_view=self.response_view, + page_size=self.page_size, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + + +class CloudTasksTaskDeleteOperator(BaseOperator): + """ + 
Deletes a task from Cloud Tasks. + + :param location: The location name in which the task will be deleted. + :type location: str + :param queue_name: The queue's name. + :type queue_name: str + :param task_name: The task's name. + :type task_name: str + :param project_id: (Optional) The ID of the GCP project that owns the Cloud Tasks. + If set to None or missing, the default project_id from the GCP connection is used. + :type project_id: str + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: sequence[tuple[str, str]]] + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud Platform. 
+ :type gcp_conn_id: str + """ + + template_fields = ( + "location", + "queue_name", + "task_name", + "project_id", + "gcp_conn_id", + ) + + @apply_defaults + def __init__( + self, + location, + queue_name, + task_name, + project_id=None, + retry=None, + timeout=None, + metadata=None, + gcp_conn_id="google_cloud_default", + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.location = location + self.queue_name = queue_name + self.task_name = task_name + self.project_id = project_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + + def execute(self, context): + hook = CloudTasksHook(gcp_conn_id=self.gcp_conn_id) + hook.delete_task( + location=self.location, + queue_name=self.queue_name, + task_name=self.task_name, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + + +class CloudTasksTaskRunOperator(BaseOperator): + """ + Forces to run a task in Cloud Tasks. + + :param location: The location name in which the task was created. + :type location: str + :param queue_name: The queue's name. + :type queue_name: str + :param task_name: The task's name. + :type task_name: str + :param project_id: (Optional) The ID of the GCP project that owns the Cloud Tasks. + If set to None or missing, the default project_id from the GCP connection is used. + :type project_id: str + :param response_view: (Optional) This field specifies which subset of the Task will + be returned. + :type response_view: google.cloud.tasks_v2.types.Task.View + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. 
+ :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: sequence[tuple[str, str]]] + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud Platform. + :type gcp_conn_id: str + :rtype: google.cloud.tasks_v2.types.Task + """ + + template_fields = ( + "location", + "queue_name", + "task_name", + "project_id", + "gcp_conn_id", + ) + + @apply_defaults + def __init__( + self, + location, + queue_name, + task_name, + project_id=None, + response_view=None, + retry=None, + timeout=None, + metadata=None, + gcp_conn_id="google_cloud_default", + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.location = location + self.queue_name = queue_name + self.task_name = task_name + self.project_id = project_id + self.response_view = response_view + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + + def execute(self, context): + hook = CloudTasksHook(gcp_conn_id=self.gcp_conn_id) + return hook.run_task( + location=self.location, + queue_name=self.queue_name, + task_name=self.task_name, + project_id=self.project_id, + response_view=self.response_view, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) diff --git a/docs/integration.rst b/docs/integration.rst index 172964f465640..ab68d2ead0e77 100644 --- a/docs/integration.rst +++ b/docs/integration.rst @@ -880,6 +880,51 @@ Google Cloud Data Loss Prevention (DLP) They also use :class:`airflow.contrib.hooks.gcp_dlp_hook.CloudDLPHook` to communicate with Google Cloud Platform. +Google Cloud Tasks +'''''''''''''''''' + +:class:`airflow.contrib.operators.gcp_tasks_operator.CloudTasksQueueCreateOperator` + Creates a queue in Cloud Tasks. + +:class:`airflow.contrib.operators.gcp_tasks_operator.CloudTasksQueueUpdateOperator` + Updates a queue in Cloud Tasks. 
+ +:class:`airflow.contrib.operators.gcp_tasks_operator.CloudTasksQueueGetOperator` + Gets a queue from Cloud Tasks. + +:class:`airflow.contrib.operators.gcp_tasks_operator.CloudTasksQueuesListOperator` + Lists queues from Cloud Tasks. + +:class:`airflow.contrib.operators.gcp_tasks_operator.CloudTasksQueueDeleteOperator` + Deletes a queue from Cloud Tasks, even if it has tasks in it. + +:class:`airflow.contrib.operators.gcp_tasks_operator.CloudTasksQueuePurgeOperator` + Purges a queue by deleting all of its tasks from Cloud Tasks. + +:class:`airflow.contrib.operators.gcp_tasks_operator.CloudTasksQueuePauseOperator` + Pauses a queue in Cloud Tasks. + +:class:`airflow.contrib.operators.gcp_tasks_operator.CloudTasksQueueResumeOperator` + Resumes a queue in Cloud Tasks. + +:class:`airflow.contrib.operators.gcp_tasks_operator.CloudTasksTaskCreateOperator` + Creates a task in Cloud Tasks. + +:class:`airflow.contrib.operators.gcp_tasks_operator.CloudTasksTaskGetOperator` + Gets a task from Cloud Tasks. + +:class:`airflow.contrib.operators.gcp_tasks_operator.CloudTasksTasksListOperator` + Lists the tasks in Cloud Tasks. + +:class:`airflow.contrib.operators.gcp_tasks_operator.CloudTasksTaskDeleteOperator` + Deletes a task from Cloud Tasks. + +:class:`airflow.contrib.operators.gcp_tasks_operator.CloudTasksTaskRunOperator` + Forces to run a task in Cloud Tasks. + +They also use :class:`airflow.contrib.hooks.gcp_tasks_hook.CloudTasksHook` to communicate with Google Cloud Platform. + + .. 
_Qubole: Qubole diff --git a/setup.py b/setup.py index 1a21a3a3234ca..8ca57fe73aadf 100644 --- a/setup.py +++ b/setup.py @@ -200,6 +200,7 @@ def write_version(filename: str = os.path.join(*["airflow", "git_version"])): 'google-cloud-translate>=1.5.0', 'google-cloud-videointelligence>=1.7.0', 'google-cloud-vision>=0.35.2', + 'google-cloud-tasks==1.1.0', 'google-cloud-texttospeech>=0.4.0', 'google-cloud-speech>=0.36.3', 'grpcio-gcp>=0.2.2', diff --git a/tests/contrib/hooks/test_gcp_tasks_hook.py b/tests/contrib/hooks/test_gcp_tasks_hook.py new file mode 100644 index 0000000000000..835473d71116d --- /dev/null +++ b/tests/contrib/hooks/test_gcp_tasks_hook.py @@ -0,0 +1,291 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +import unittest +from typing import Dict, Any + +from google.cloud.tasks_v2.types import Queue, Task + +from airflow.contrib.hooks.gcp_tasks_hook import CloudTasksHook +from tests.compat import mock +from tests.contrib.utils.base_gcp_mock import mock_base_gcp_hook_no_default_project_id + + +API_RESPONSE = {} # type: Dict[Any, Any] +PROJECT_ID = "test-project" +LOCATION = "asia-east2" +FULL_LOCATION_PATH = "projects/test-project/locations/asia-east2" +QUEUE_ID = "test-queue" +FULL_QUEUE_PATH = "projects/test-project/locations/asia-east2/queues/test-queue" +TASK_NAME = "test-task" +FULL_TASK_PATH = ( + "projects/test-project/locations/asia-east2/queues/test-queue/tasks/test-task" +) + + +class TestCloudTasksHook(unittest.TestCase): + def setUp(self): + with mock.patch( + "airflow.contrib.hooks." "gcp_api_base_hook.GoogleCloudBaseHook.__init__", + new=mock_base_gcp_hook_no_default_project_id, + ): + self.hook = CloudTasksHook(gcp_conn_id="test") + + @mock.patch( # type: ignore + "airflow.contrib.hooks.gcp_tasks_hook.CloudTasksHook.get_conn", + **{"return_value.create_queue.return_value": API_RESPONSE}, # type: ignore + ) + def test_create_queue(self, get_conn): + result = self.hook.create_queue( + location=LOCATION, + task_queue=Queue(), + queue_name=QUEUE_ID, + project_id=PROJECT_ID, + ) + + self.assertIs(result, API_RESPONSE) + + get_conn.return_value.create_queue.assert_called_once_with( + parent=FULL_LOCATION_PATH, + queue=Queue(name=FULL_QUEUE_PATH), + retry=None, + timeout=None, + metadata=None, + ) + + @mock.patch( # type: ignore + "airflow.contrib.hooks.gcp_tasks_hook.CloudTasksHook.get_conn", + **{"return_value.update_queue.return_value": API_RESPONSE}, # type: ignore + ) + def test_update_queue(self, get_conn): + result = self.hook.update_queue( + task_queue=Queue(state=3), + location=LOCATION, + queue_name=QUEUE_ID, + project_id=PROJECT_ID, + ) + + self.assertIs(result, API_RESPONSE) + + get_conn.return_value.update_queue.assert_called_once_with( + 
queue=Queue(name=FULL_QUEUE_PATH, state=3), + update_mask=None, + retry=None, + timeout=None, + metadata=None, + ) + + @mock.patch( # type: ignore + "airflow.contrib.hooks.gcp_tasks_hook.CloudTasksHook.get_conn", + **{"return_value.get_queue.return_value": API_RESPONSE}, # type: ignore + ) + def test_get_queue(self, get_conn): + result = self.hook.get_queue( + location=LOCATION, queue_name=QUEUE_ID, project_id=PROJECT_ID + ) + + self.assertIs(result, API_RESPONSE) + + get_conn.return_value.get_queue.assert_called_once_with( + name=FULL_QUEUE_PATH, retry=None, timeout=None, metadata=None + ) + + @mock.patch( # type: ignore + "airflow.contrib.hooks.gcp_tasks_hook.CloudTasksHook.get_conn", + **{"return_value.list_queues.return_value": API_RESPONSE}, # type: ignore + ) + def test_list_queues(self, get_conn): + result = self.hook.list_queues(location=LOCATION, project_id=PROJECT_ID) + + self.assertEqual(result, list(API_RESPONSE)) + + get_conn.return_value.list_queues.assert_called_once_with( + parent=FULL_LOCATION_PATH, + filter_=None, + page_size=None, + retry=None, + timeout=None, + metadata=None, + ) + + @mock.patch( # type: ignore + "airflow.contrib.hooks.gcp_tasks_hook.CloudTasksHook.get_conn", + **{"return_value.delete_queue.return_value": API_RESPONSE}, # type: ignore + ) + def test_delete_queue(self, get_conn): + result = self.hook.delete_queue( + location=LOCATION, queue_name=QUEUE_ID, project_id=PROJECT_ID + ) + + self.assertEqual(result, None) + + get_conn.return_value.delete_queue.assert_called_once_with( + name=FULL_QUEUE_PATH, retry=None, timeout=None, metadata=None + ) + + @mock.patch( # type: ignore + "airflow.contrib.hooks.gcp_tasks_hook.CloudTasksHook.get_conn", + **{"return_value.purge_queue.return_value": API_RESPONSE}, # type: ignore + ) + def test_purge_queue(self, get_conn): + result = self.hook.purge_queue( + location=LOCATION, queue_name=QUEUE_ID, project_id=PROJECT_ID + ) + + self.assertEqual(result, API_RESPONSE) + + 
get_conn.return_value.purge_queue.assert_called_once_with( + name=FULL_QUEUE_PATH, retry=None, timeout=None, metadata=None + ) + + @mock.patch( # type: ignore + "airflow.contrib.hooks.gcp_tasks_hook.CloudTasksHook.get_conn", + **{"return_value.pause_queue.return_value": API_RESPONSE}, # type: ignore + ) + def test_pause_queue(self, get_conn): + result = self.hook.pause_queue( + location=LOCATION, queue_name=QUEUE_ID, project_id=PROJECT_ID + ) + + self.assertEqual(result, API_RESPONSE) + + get_conn.return_value.pause_queue.assert_called_once_with( + name=FULL_QUEUE_PATH, retry=None, timeout=None, metadata=None + ) + + @mock.patch( # type: ignore + "airflow.contrib.hooks.gcp_tasks_hook.CloudTasksHook.get_conn", + **{"return_value.resume_queue.return_value": API_RESPONSE}, # type: ignore + ) + def test_resume_queue(self, get_conn): + result = self.hook.resume_queue( + location=LOCATION, queue_name=QUEUE_ID, project_id=PROJECT_ID + ) + + self.assertEqual(result, API_RESPONSE) + + get_conn.return_value.resume_queue.assert_called_once_with( + name=FULL_QUEUE_PATH, retry=None, timeout=None, metadata=None + ) + + @mock.patch( # type: ignore + "airflow.contrib.hooks.gcp_tasks_hook.CloudTasksHook.get_conn", + **{"return_value.create_task.return_value": API_RESPONSE}, # type: ignore + ) + def test_create_task(self, get_conn): + result = self.hook.create_task( + location=LOCATION, + queue_name=QUEUE_ID, + task=Task(), + project_id=PROJECT_ID, + task_name=TASK_NAME, + ) + + self.assertEqual(result, API_RESPONSE) + + get_conn.return_value.create_task.assert_called_once_with( + parent=FULL_QUEUE_PATH, + task=Task(name=FULL_TASK_PATH), + response_view=None, + retry=None, + timeout=None, + metadata=None, + ) + + @mock.patch( # type: ignore + "airflow.contrib.hooks.gcp_tasks_hook.CloudTasksHook.get_conn", + **{"return_value.get_task.return_value": API_RESPONSE}, # type: ignore + ) + def test_get_task(self, get_conn): + result = self.hook.get_task( + location=LOCATION, + 
queue_name=QUEUE_ID, + task_name=TASK_NAME, + project_id=PROJECT_ID, + ) + + self.assertEqual(result, API_RESPONSE) + + get_conn.return_value.get_task.assert_called_once_with( + name=FULL_TASK_PATH, + response_view=None, + retry=None, + timeout=None, + metadata=None, + ) + + @mock.patch( # type: ignore + "airflow.contrib.hooks.gcp_tasks_hook.CloudTasksHook.get_conn", + **{"return_value.list_tasks.return_value": API_RESPONSE}, # type: ignore + ) + def test_list_tasks(self, get_conn): + result = self.hook.list_tasks( + location=LOCATION, queue_name=QUEUE_ID, project_id=PROJECT_ID + ) + + self.assertEqual(result, list(API_RESPONSE)) + + get_conn.return_value.list_tasks.assert_called_once_with( + parent=FULL_QUEUE_PATH, + response_view=None, + page_size=None, + retry=None, + timeout=None, + metadata=None, + ) + + @mock.patch( # type: ignore + "airflow.contrib.hooks.gcp_tasks_hook.CloudTasksHook.get_conn", + **{"return_value.delete_task.return_value": API_RESPONSE}, # type: ignore + ) + def test_delete_task(self, get_conn): + result = self.hook.delete_task( + location=LOCATION, + queue_name=QUEUE_ID, + task_name=TASK_NAME, + project_id=PROJECT_ID, + ) + + self.assertIsNone(result) + + get_conn.return_value.delete_task.assert_called_once_with( + name=FULL_TASK_PATH, retry=None, timeout=None, metadata=None + ) + + @mock.patch( # type: ignore + "airflow.contrib.hooks.gcp_tasks_hook.CloudTasksHook.get_conn", + **{"return_value.run_task.return_value": API_RESPONSE}, # type: ignore + ) + def test_run_task(self, get_conn): + result = self.hook.run_task( + location=LOCATION, + queue_name=QUEUE_ID, + task_name=TASK_NAME, + project_id=PROJECT_ID, + ) + + self.assertEqual(result, API_RESPONSE) + + get_conn.return_value.run_task.assert_called_once_with( + name=FULL_TASK_PATH, + response_view=None, + retry=None, + timeout=None, + metadata=None, + ) diff --git a/tests/contrib/operators/test_gcp_tasks_operator.py b/tests/contrib/operators/test_gcp_tasks_operator.py new file mode
100644 index 0000000000000..b9ebf96a60085 --- /dev/null +++ b/tests/contrib/operators/test_gcp_tasks_operator.py @@ -0,0 +1,309 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import unittest + +from google.cloud.tasks_v2.types import Queue, Task + +from airflow.contrib.operators.gcp_tasks_operator import ( + CloudTasksQueueCreateOperator, + CloudTasksTaskCreateOperator, + CloudTasksQueueDeleteOperator, + CloudTasksTaskDeleteOperator, + CloudTasksQueueGetOperator, + CloudTasksTaskGetOperator, + CloudTasksQueuesListOperator, + CloudTasksTasksListOperator, + CloudTasksQueuePauseOperator, + CloudTasksQueuePurgeOperator, + CloudTasksQueueResumeOperator, + CloudTasksTaskRunOperator, + CloudTasksQueueUpdateOperator, +) +from tests.compat import mock + +GCP_CONN_ID = "google_cloud_default" +PROJECT_ID = "test-project" +LOCATION = "asia-east2" +FULL_LOCATION_PATH = "projects/test-project/locations/asia-east2" +QUEUE_ID = "test-queue" +FULL_QUEUE_PATH = "projects/test-project/locations/asia-east2/queues/test-queue" +TASK_NAME = "test-task" +FULL_TASK_PATH = ( + "projects/test-project/locations/asia-east2/queues/test-queue/tasks/test-task" +) + + +class CloudTasksQueueCreateTest(unittest.TestCase): + 
@mock.patch("airflow.contrib.operators.gcp_tasks_operator.CloudTasksHook") + def test_create_queue(self, mock_hook): + mock_hook.return_value.create_queue.return_value = {} + operator = CloudTasksQueueCreateOperator( + location=LOCATION, task_queue=Queue(), task_id="id" + ) + operator.execute(context=None) + mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID) + mock_hook.return_value.create_queue.assert_called_once_with( + location=LOCATION, + task_queue=Queue(), + project_id=None, + queue_name=None, + retry=None, + timeout=None, + metadata=None, + ) + + +class CloudTasksQueueUpdateTest(unittest.TestCase): + @mock.patch("airflow.contrib.operators.gcp_tasks_operator.CloudTasksHook") + def test_update_queue(self, mock_hook): + mock_hook.return_value.update_queue.return_value = {} + operator = CloudTasksQueueUpdateOperator( + task_queue=Queue(name=FULL_QUEUE_PATH), task_id="id" + ) + operator.execute(context=None) + mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID) + mock_hook.return_value.update_queue.assert_called_once_with( + task_queue=Queue(name=FULL_QUEUE_PATH), + project_id=None, + location=None, + queue_name=None, + update_mask=None, + retry=None, + timeout=None, + metadata=None, + ) + + +class CloudTasksQueueGetTest(unittest.TestCase): + @mock.patch("airflow.contrib.operators.gcp_tasks_operator.CloudTasksHook") + def test_get_queue(self, mock_hook): + mock_hook.return_value.get_queue.return_value = {} + operator = CloudTasksQueueGetOperator( + location=LOCATION, queue_name=QUEUE_ID, task_id="id" + ) + operator.execute(context=None) + mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID) + mock_hook.return_value.get_queue.assert_called_once_with( + location=LOCATION, + queue_name=QUEUE_ID, + project_id=None, + retry=None, + timeout=None, + metadata=None, + ) + + +class CloudTasksQueuesListTest(unittest.TestCase): + @mock.patch("airflow.contrib.operators.gcp_tasks_operator.CloudTasksHook") + def test_list_queues(self, mock_hook): + 
mock_hook.return_value.list_queues.return_value = {} + operator = CloudTasksQueuesListOperator(location=LOCATION, task_id="id") + operator.execute(context=None) + mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID) + mock_hook.return_value.list_queues.assert_called_once_with( + location=LOCATION, + project_id=None, + results_filter=None, + page_size=None, + retry=None, + timeout=None, + metadata=None, + ) + + +class CloudTasksQueueDeleteTest(unittest.TestCase): + @mock.patch("airflow.contrib.operators.gcp_tasks_operator.CloudTasksHook") + def test_delete_queue(self, mock_hook): + mock_hook.return_value.delete_queue.return_value = {} + operator = CloudTasksQueueDeleteOperator( + location=LOCATION, queue_name=QUEUE_ID, task_id="id" + ) + operator.execute(context=None) + mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID) + mock_hook.return_value.delete_queue.assert_called_once_with( + location=LOCATION, + queue_name=QUEUE_ID, + project_id=None, + retry=None, + timeout=None, + metadata=None, + ) + + +class CloudTasksQueuePurgeTest(unittest.TestCase): + @mock.patch("airflow.contrib.operators.gcp_tasks_operator.CloudTasksHook") + def test_purge_queue(self, mock_hook): + mock_hook.return_value.purge_queue.return_value = {} + operator = CloudTasksQueuePurgeOperator( + location=LOCATION, queue_name=QUEUE_ID, task_id="id" + ) + operator.execute(context=None) + mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID) + mock_hook.return_value.purge_queue.assert_called_once_with( + location=LOCATION, + queue_name=QUEUE_ID, + project_id=None, + retry=None, + timeout=None, + metadata=None, + ) + + +class CloudTasksQueuePauseTest(unittest.TestCase): + @mock.patch("airflow.contrib.operators.gcp_tasks_operator.CloudTasksHook") + def test_pause_queue(self, mock_hook): + mock_hook.return_value.pause_queue.return_value = {} + operator = CloudTasksQueuePauseOperator( + location=LOCATION, queue_name=QUEUE_ID, task_id="id" + ) + operator.execute(context=None) +
mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID) + mock_hook.return_value.pause_queue.assert_called_once_with( + location=LOCATION, + queue_name=QUEUE_ID, + project_id=None, + retry=None, + timeout=None, + metadata=None, + ) + + +class CloudTasksQueueResumeTest(unittest.TestCase): + @mock.patch("airflow.contrib.operators.gcp_tasks_operator.CloudTasksHook") + def test_resume_queue(self, mock_hook): + mock_hook.return_value.resume_queue.return_value = {} + operator = CloudTasksQueueResumeOperator( + location=LOCATION, queue_name=QUEUE_ID, task_id="id" + ) + operator.execute(context=None) + mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID) + mock_hook.return_value.resume_queue.assert_called_once_with( + location=LOCATION, + queue_name=QUEUE_ID, + project_id=None, + retry=None, + timeout=None, + metadata=None, + ) + + +class CloudTasksTaskCreateTest(unittest.TestCase): + @mock.patch("airflow.contrib.operators.gcp_tasks_operator.CloudTasksHook") + def test_create_task(self, mock_hook): + mock_hook.return_value.create_task.return_value = {} + operator = CloudTasksTaskCreateOperator( + location=LOCATION, queue_name=QUEUE_ID, task=Task(), task_id="id" + ) + operator.execute(context=None) + mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID) + mock_hook.return_value.create_task.assert_called_once_with( + location=LOCATION, + queue_name=QUEUE_ID, + task=Task(), + project_id=None, + task_name=None, + response_view=None, + retry=None, + timeout=None, + metadata=None, + ) + + +class CloudTasksTaskGetTest(unittest.TestCase): + @mock.patch("airflow.contrib.operators.gcp_tasks_operator.CloudTasksHook") + def test_get_task(self, mock_hook): + mock_hook.return_value.get_task.return_value = {} + operator = CloudTasksTaskGetOperator( + location=LOCATION, queue_name=QUEUE_ID, task_name=TASK_NAME, task_id="id" + ) + operator.execute(context=None) + mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID) + 
mock_hook.return_value.get_task.assert_called_once_with( + location=LOCATION, + queue_name=QUEUE_ID, + task_name=TASK_NAME, + project_id=None, + response_view=None, + retry=None, + timeout=None, + metadata=None, + ) + + +class CloudTasksTasksListTest(unittest.TestCase): + @mock.patch("airflow.contrib.operators.gcp_tasks_operator.CloudTasksHook") + def test_list_tasks(self, mock_hook): + mock_hook.return_value.list_tasks.return_value = {} + operator = CloudTasksTasksListOperator( + location=LOCATION, queue_name=QUEUE_ID, task_id="id" + ) + operator.execute(context=None) + mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID) + mock_hook.return_value.list_tasks.assert_called_once_with( + location=LOCATION, + queue_name=QUEUE_ID, + project_id=None, + response_view=None, + page_size=None, + retry=None, + timeout=None, + metadata=None, + ) + + +class CloudTasksTaskDeleteTest(unittest.TestCase): + @mock.patch("airflow.contrib.operators.gcp_tasks_operator.CloudTasksHook") + def test_delete_task(self, mock_hook): + mock_hook.return_value.delete_task.return_value = {} + operator = CloudTasksTaskDeleteOperator( + location=LOCATION, queue_name=QUEUE_ID, task_name=TASK_NAME, task_id="id" + ) + operator.execute(context=None) + mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID) + mock_hook.return_value.delete_task.assert_called_once_with( + location=LOCATION, + queue_name=QUEUE_ID, + task_name=TASK_NAME, + project_id=None, + retry=None, + timeout=None, + metadata=None, + ) + + +class CloudTasksTaskRunTest(unittest.TestCase): + @mock.patch("airflow.contrib.operators.gcp_tasks_operator.CloudTasksHook") + def test_run_task(self, mock_hook): + mock_hook.return_value.run_task.return_value = {} + operator = CloudTasksTaskRunOperator( + location=LOCATION, queue_name=QUEUE_ID, task_name=TASK_NAME, task_id="id" + ) + operator.execute(context=None) + mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID) + mock_hook.return_value.run_task.assert_called_once_with( + 
location=LOCATION, + queue_name=QUEUE_ID, + task_name=TASK_NAME, + project_id=None, + response_view=None, + retry=None, + timeout=None, + metadata=None, + ) diff --git a/tests/contrib/operators/test_gcp_tasks_operator_system.py b/tests/contrib/operators/test_gcp_tasks_operator_system.py new file mode 100644 index 0000000000000..f07c5a3419262 --- /dev/null +++ b/tests/contrib/operators/test_gcp_tasks_operator_system.py @@ -0,0 +1,37 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import unittest + +from tests.contrib.utils.base_gcp_system_test_case import \ + SKIP_TEST_WARNING, DagGcpSystemTestCase +from tests.contrib.utils.gcp_authenticator import GCP_TASKS_KEY + + +@unittest.skipIf( + DagGcpSystemTestCase.skip_check(GCP_TASKS_KEY), SKIP_TEST_WARNING) +class GcpTasksExampleDagsSystemTest(DagGcpSystemTestCase): + def __init__(self, method_name='runTest'): + super().__init__( + method_name, + dag_id='example_gcp_tasks', + gcp_key=GCP_TASKS_KEY) + + def test_run_example_dag_function(self): + self._run_dag() diff --git a/tests/contrib/utils/gcp_authenticator.py b/tests/contrib/utils/gcp_authenticator.py index 796e1b8c3c462..b04db143c121d 100644 --- a/tests/contrib/utils/gcp_authenticator.py +++ b/tests/contrib/utils/gcp_authenticator.py @@ -37,7 +37,7 @@ GCP_GCS_KEY = 'gcp_gcs.json' GCP_GCS_TRANSFER_KEY = 'gcp_gcs_transfer.json' GCP_SPANNER_KEY = 'gcp_spanner.json' - +GCP_TASKS_KEY = 'gcp_tasks.json' KEYPATH_EXTRA = 'extra__google_cloud_platform__key_path' KEYFILE_DICT_EXTRA = 'extra__google_cloud_platform__keyfile_dict'