Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions airflow/providers/google/ADDITIONAL_INFO.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ Details are covered in the UPDATING.md files for each library, but there are som
| [``google-cloud-os-login``](https://pypi.org/project/google-cloud-os-login/) | ``>=1.0.0,<2.0.0`` | ``>=2.0.0,<3.0.0`` | [`UPGRADING.md`](https://github.com/googleapis/python-oslogin/blob/master/UPGRADING.md) |
| [``google-cloud-pubsub``](https://pypi.org/project/google-cloud-pubsub/) | ``>=1.0.0,<2.0.0`` | ``>=2.0.0,<3.0.0`` | [`UPGRADING.md`](https://github.com/googleapis/python-pubsub/blob/master/UPGRADING.md) |
| [``google-cloud-kms``](https://pypi.org/project/google-cloud-os-login/) | ``>=1.2.1,<2.0.0`` | ``>=2.0.0,<3.0.0`` | [`UPGRADING.md`](https://github.com/googleapis/python-kms/blob/master/UPGRADING.md) |
| [``google-cloud-tasks``](https://pypi.org/project/google-cloud-tasks/) | ``>=1.2.1,<2.0.0`` | ``>=2.0.0,<3.0.0`` | [`UPGRADING.md`](https://github.com/googleapis/python-tasks/blob/master/UPGRADING.md) |


### The field names use the snake_case convention
Expand Down
118 changes: 62 additions & 56 deletions airflow/providers/google/cloud/hooks/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,13 @@
which allows you to connect to Google Cloud Tasks service,
performing actions to queues or tasks.
"""

from typing import Dict, List, Optional, Sequence, Tuple, Union

from google.api_core.retry import Retry
from google.cloud.tasks_v2 import CloudTasksClient, enums
from google.cloud.tasks_v2.types import FieldMask, Queue, Task
from google.cloud.tasks_v2 import CloudTasksClient
from google.cloud.tasks_v2.types import Queue, Task
from google.protobuf.field_mask_pb2 import FieldMask

from airflow.exceptions import AirflowException
from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
Expand Down Expand Up @@ -120,20 +122,19 @@ def create_queue(
client = self.get_conn()

if queue_name:
full_queue_name = CloudTasksClient.queue_path(project_id, location, queue_name)
full_queue_name = f"projects/{project_id}/locations/{location}/queues/{queue_name}"
if isinstance(task_queue, Queue):
task_queue.name = full_queue_name
elif isinstance(task_queue, dict):
task_queue['name'] = full_queue_name
else:
raise AirflowException('Unable to set queue_name.')
full_location_path = CloudTasksClient.location_path(project_id, location)
full_location_path = f"projects/{project_id}/locations/{location}"
return client.create_queue(
parent=full_location_path,
queue=task_queue,
request={'parent': full_location_path, 'queue': task_queue},
retry=retry,
timeout=timeout,
metadata=metadata,
metadata=metadata or (),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not a strong opinion, but I think doing this before function invocation may increase readability, WDYT?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure about that. For full readability this would require creating a new variable name, as modifying the parameter content may be ambiguous as some languages have special behavior. Luckily not in Python, but this can still arouse mixed feelings among polyglots, including mine. The new variable in this case would complicate the code unnecessarily.

)

@GoogleBaseHook.fallback_to_default_project_id
Expand Down Expand Up @@ -167,7 +168,7 @@ def update_queue(
:param update_mask: A mast used to specify which fields of the queue are being updated.
If empty, then all fields will be updated.
If a dict is provided, it must be of the same form as the protobuf message.
:type update_mask: dict or google.cloud.tasks_v2.types.FieldMask
:type update_mask: dict or google.protobuf.field_mask_pb2.FieldMask
:param retry: (Optional) A retry object used to retry requests.
If None is specified, requests will not be retried.
:type retry: google.api_core.retry.Retry
Expand All @@ -182,19 +183,18 @@ def update_queue(
client = self.get_conn()

if queue_name and location:
full_queue_name = CloudTasksClient.queue_path(project_id, location, queue_name)
full_queue_name = f"projects/{project_id}/locations/{location}/queues/{queue_name}"
if isinstance(task_queue, Queue):
task_queue.name = full_queue_name
elif isinstance(task_queue, dict):
task_queue['name'] = full_queue_name
else:
raise AirflowException('Unable to set queue_name.')
return client.update_queue(
queue=task_queue,
update_mask=update_mask,
request={'queue': task_queue, 'update_mask': update_mask},
retry=retry,
timeout=timeout,
metadata=metadata,
metadata=metadata or (),
)

@GoogleBaseHook.fallback_to_default_project_id
Expand Down Expand Up @@ -230,8 +230,10 @@ def get_queue(
"""
client = self.get_conn()

full_queue_name = CloudTasksClient.queue_path(project_id, location, queue_name)
return client.get_queue(name=full_queue_name, retry=retry, timeout=timeout, metadata=metadata)
full_queue_name = f"projects/{project_id}/locations/{location}/queues/{queue_name}"
return client.get_queue(
request={'name': full_queue_name}, retry=retry, timeout=timeout, metadata=metadata or ()
)

@GoogleBaseHook.fallback_to_default_project_id
def list_queues(
Expand Down Expand Up @@ -270,14 +272,12 @@ def list_queues(
"""
client = self.get_conn()

full_location_path = CloudTasksClient.location_path(project_id, location)
full_location_path = f"projects/{project_id}/locations/{location}"
queues = client.list_queues(
parent=full_location_path,
filter_=results_filter,
page_size=page_size,
request={'parent': full_location_path, 'filter': results_filter, 'page_size': page_size},
retry=retry,
timeout=timeout,
metadata=metadata,
metadata=metadata or (),
)
return list(queues)

Expand Down Expand Up @@ -313,8 +313,10 @@ def delete_queue(
"""
client = self.get_conn()

full_queue_name = CloudTasksClient.queue_path(project_id, location, queue_name)
client.delete_queue(name=full_queue_name, retry=retry, timeout=timeout, metadata=metadata)
full_queue_name = f"projects/{project_id}/locations/{location}/queues/{queue_name}"
client.delete_queue(
request={'name': full_queue_name}, retry=retry, timeout=timeout, metadata=metadata or ()
)

@GoogleBaseHook.fallback_to_default_project_id
def purge_queue(
Expand Down Expand Up @@ -349,8 +351,10 @@ def purge_queue(
"""
client = self.get_conn()

full_queue_name = CloudTasksClient.queue_path(project_id, location, queue_name)
return client.purge_queue(name=full_queue_name, retry=retry, timeout=timeout, metadata=metadata)
full_queue_name = f"projects/{project_id}/locations/{location}/queues/{queue_name}"
return client.purge_queue(
request={'name': full_queue_name}, retry=retry, timeout=timeout, metadata=metadata or ()
)

@GoogleBaseHook.fallback_to_default_project_id
def pause_queue(
Expand Down Expand Up @@ -385,8 +389,10 @@ def pause_queue(
"""
client = self.get_conn()

full_queue_name = CloudTasksClient.queue_path(project_id, location, queue_name)
return client.pause_queue(name=full_queue_name, retry=retry, timeout=timeout, metadata=metadata)
full_queue_name = f"projects/{project_id}/locations/{location}/queues/{queue_name}"
return client.pause_queue(
request={'name': full_queue_name}, retry=retry, timeout=timeout, metadata=metadata or ()
)

@GoogleBaseHook.fallback_to_default_project_id
def resume_queue(
Expand Down Expand Up @@ -421,8 +427,10 @@ def resume_queue(
"""
client = self.get_conn()

full_queue_name = CloudTasksClient.queue_path(project_id, location, queue_name)
return client.resume_queue(name=full_queue_name, retry=retry, timeout=timeout, metadata=metadata)
full_queue_name = f"projects/{project_id}/locations/{location}/queues/{queue_name}"
return client.resume_queue(
request={'name': full_queue_name}, retry=retry, timeout=timeout, metadata=metadata or ()
)

@GoogleBaseHook.fallback_to_default_project_id
def create_task(
Expand All @@ -432,7 +440,7 @@ def create_task(
task: Union[Dict, Task],
project_id: str,
task_name: Optional[str] = None,
response_view: Optional[enums.Task.View] = None,
response_view: Optional = None,
retry: Optional[Retry] = None,
timeout: Optional[float] = None,
metadata: Optional[Sequence[Tuple[str, str]]] = None,
Expand All @@ -455,7 +463,7 @@ def create_task(
:type task_name: str
:param response_view: (Optional) This field specifies which subset of the Task will
be returned.
:type response_view: google.cloud.tasks_v2.enums.Task.View
:type response_view: google.cloud.tasks_v2.Task.View
:param retry: (Optional) A retry object used to retry requests.
If None is specified, requests will not be retried.
:type retry: google.api_core.retry.Retry
Expand All @@ -470,21 +478,21 @@ def create_task(
client = self.get_conn()

if task_name:
full_task_name = CloudTasksClient.task_path(project_id, location, queue_name, task_name)
full_task_name = (
f"projects/{project_id}/locations/{location}/queues/{queue_name}/tasks/{task_name}"
)
if isinstance(task, Task):
task.name = full_task_name
elif isinstance(task, dict):
task['name'] = full_task_name
else:
raise AirflowException('Unable to set task_name.')
full_queue_name = CloudTasksClient.queue_path(project_id, location, queue_name)
full_queue_name = f"projects/{project_id}/locations/{location}/queues/{queue_name}"
return client.create_task(
parent=full_queue_name,
task=task,
response_view=response_view,
request={'parent': full_queue_name, 'task': task, 'response_view': response_view},
retry=retry,
timeout=timeout,
metadata=metadata,
metadata=metadata or (),
)

@GoogleBaseHook.fallback_to_default_project_id
Expand All @@ -494,7 +502,7 @@ def get_task(
queue_name: str,
task_name: str,
project_id: str,
response_view: Optional[enums.Task.View] = None,
response_view: Optional = None,
retry: Optional[Retry] = None,
timeout: Optional[float] = None,
metadata: Optional[Sequence[Tuple[str, str]]] = None,
Expand All @@ -513,7 +521,7 @@ def get_task(
:type project_id: str
:param response_view: (Optional) This field specifies which subset of the Task will
be returned.
:type response_view: google.cloud.tasks_v2.enums.Task.View
:type response_view: google.cloud.tasks_v2.Task.View
:param retry: (Optional) A retry object used to retry requests.
If None is specified, requests will not be retried.
:type retry: google.api_core.retry.Retry
Expand All @@ -527,13 +535,12 @@ def get_task(
"""
client = self.get_conn()

full_task_name = CloudTasksClient.task_path(project_id, location, queue_name, task_name)
full_task_name = f"projects/{project_id}/locations/{location}/queues/{queue_name}/tasks/{task_name}"
return client.get_task(
name=full_task_name,
response_view=response_view,
request={'name': full_task_name, 'response_view': response_view},
retry=retry,
timeout=timeout,
metadata=metadata,
metadata=metadata or (),
)

@GoogleBaseHook.fallback_to_default_project_id
Expand All @@ -542,7 +549,7 @@ def list_tasks(
location: str,
queue_name: str,
project_id: str,
response_view: Optional[enums.Task.View] = None,
response_view: Optional = None,
page_size: Optional[int] = None,
retry: Optional[Retry] = None,
timeout: Optional[float] = None,
Expand All @@ -560,7 +567,7 @@ def list_tasks(
:type project_id: str
:param response_view: (Optional) This field specifies which subset of the Task will
be returned.
:type response_view: google.cloud.tasks_v2.enums.Task.View
:type response_view: google.cloud.tasks_v2.Task.View
:param page_size: (Optional) The maximum number of resources contained in the
underlying API response.
:type page_size: int
Expand All @@ -576,14 +583,12 @@ def list_tasks(
:rtype: list[google.cloud.tasks_v2.types.Task]
"""
client = self.get_conn()
full_queue_name = CloudTasksClient.queue_path(project_id, location, queue_name)
full_queue_name = f"projects/{project_id}/locations/{location}/queues/{queue_name}"
tasks = client.list_tasks(
parent=full_queue_name,
response_view=response_view,
page_size=page_size,
request={'parent': full_queue_name, 'response_view': response_view, 'page_size': page_size},
retry=retry,
timeout=timeout,
metadata=metadata,
metadata=metadata or (),
)
return list(tasks)

Expand Down Expand Up @@ -622,8 +627,10 @@ def delete_task(
"""
client = self.get_conn()

full_task_name = CloudTasksClient.task_path(project_id, location, queue_name, task_name)
client.delete_task(name=full_task_name, retry=retry, timeout=timeout, metadata=metadata)
full_task_name = f"projects/{project_id}/locations/{location}/queues/{queue_name}/tasks/{task_name}"
client.delete_task(
request={'name': full_task_name}, retry=retry, timeout=timeout, metadata=metadata or ()
)

@GoogleBaseHook.fallback_to_default_project_id
def run_task(
Expand All @@ -632,7 +639,7 @@ def run_task(
queue_name: str,
task_name: str,
project_id: str,
response_view: Optional[enums.Task.View] = None,
response_view: Optional = None,
retry: Optional[Retry] = None,
timeout: Optional[float] = None,
metadata: Optional[Sequence[Tuple[str, str]]] = None,
Expand All @@ -651,7 +658,7 @@ def run_task(
:type project_id: str
:param response_view: (Optional) This field specifies which subset of the Task will
be returned.
:type response_view: google.cloud.tasks_v2.enums.Task.View
:type response_view: google.cloud.tasks_v2.Task.View
:param retry: (Optional) A retry object used to retry requests.
If None is specified, requests will not be retried.
:type retry: google.api_core.retry.Retry
Expand All @@ -665,11 +672,10 @@ def run_task(
"""
client = self.get_conn()

full_task_name = CloudTasksClient.task_path(project_id, location, queue_name, task_name)
full_task_name = f"projects/{project_id}/locations/{location}/queues/{queue_name}/tasks/{task_name}"
return client.run_task(
name=full_task_name,
response_view=response_view,
request={'name': full_task_name, 'response_view': response_view},
retry=retry,
timeout=timeout,
metadata=metadata,
metadata=metadata or (),
)
Loading