Skip to content

Commit

Permalink
Add deferrable mode to PubsubPullSensor (#31284)
Browse files Browse the repository at this point in the history
Co-authored-by: Beata Kossakowska <bkossakowska@google.com>
  • Loading branch information
bkossakowska and Beata Kossakowska committed Jun 14, 2023
1 parent 9a2a5b0 commit a81ac70
Show file tree
Hide file tree
Showing 10 changed files with 591 additions and 12 deletions.
140 changes: 137 additions & 3 deletions airflow/providers/google/cloud/hooks/pubsub.py
Expand Up @@ -28,7 +28,7 @@
import warnings
from base64 import b64decode
from functools import cached_property
from typing import Sequence
from typing import Any, Sequence
from uuid import uuid4

from google.api_core.exceptions import AlreadyExists, GoogleAPICallError
Expand All @@ -45,11 +45,16 @@
ReceivedMessage,
RetryPolicy,
)
from google.pubsub_v1.services.subscriber.async_client import SubscriberAsyncClient
from googleapiclient.errors import HttpError

from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.providers.google.common.consts import CLIENT_INFO
from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID, GoogleBaseHook
from airflow.providers.google.common.hooks.base_google import (
PROVIDE_PROJECT_ID,
GoogleBaseAsyncHook,
GoogleBaseHook,
)
from airflow.version import version


Expand Down Expand Up @@ -496,7 +501,6 @@ def pull(

self.log.info("Pulling max %d messages from subscription (path) %s", max_messages, subscription_path)
try:

response = subscriber.pull(
request={
"subscription": subscription_path,
Expand Down Expand Up @@ -569,3 +573,133 @@ def acknowledge(
)

self.log.info("Acknowledged ack_ids from subscription (path) %s", subscription_path)


class PubSubAsyncHook(GoogleBaseAsyncHook):
    """
    Asynchronous hook for Google Cloud Pub/Sub.

    The hook lazily builds a single ``SubscriberAsyncClient`` (using the credentials
    of the synchronous ``PubSubHook``) and reuses it for every ``pull`` and
    ``acknowledge`` call made through this instance.

    :param project_id: Optional project ID stored on the hook as ``self.project_id``.
        NOTE(review): presumably this is what ``fallback_to_default_project_id``
        falls back to when a call omits ``project_id`` -- confirm against the decorator.
    :param kwargs: Forwarded unchanged to ``GoogleBaseAsyncHook.__init__``.
    """

    sync_hook_class = PubSubHook

    def __init__(self, project_id: str | None = None, **kwargs: Any):
        super().__init__(**kwargs)
        self.project_id = project_id
        # Created on first use by ``_get_subscriber_client`` and cached afterwards.
        self._client: SubscriberAsyncClient | None = None

    async def _get_subscriber_client(self) -> SubscriberAsyncClient:
        """
        Return the asynchronous Pub/Sub subscriber client, creating it on first use.

        :return: Google Pub/Sub asynchronous client.
        """
        if self._client is None:
            sync_hook = await self.get_sync_hook()
            self._client = SubscriberAsyncClient(
                credentials=sync_hook.get_credentials(),
                client_info=CLIENT_INFO,
            )
        return self._client

    @GoogleBaseHook.fallback_to_default_project_id
    async def acknowledge(
        self,
        subscription: str,
        project_id: str = PROVIDE_PROJECT_ID,
        ack_ids: list[str] | None = None,
        messages: list[ReceivedMessage] | None = None,
        retry: Retry | _MethodDefault = DEFAULT,
        timeout: float | None = None,
        metadata: Sequence[tuple[str, str]] = (),
    ) -> None:
        """
        Acknowledge messages previously pulled from a Pub/Sub subscription.

        Exactly one of ``ack_ids`` and ``messages`` must be provided.

        :param subscription: the Pub/Sub subscription name whose messages are acknowledged;
            do not include the 'projects/{project}/subscriptions/' prefix.
        :param project_id: Optional, the Google Cloud project ID where the subscription exists.
            If set to None or missing, the default project_id from the Google Cloud connection is used.
        :param ack_ids: List of ReceivedMessage ackIds from a previous pull response.
            Mutually exclusive with ``messages`` argument.
        :param messages: List of ReceivedMessage objects to acknowledge.
            Mutually exclusive with ``ack_ids`` argument.
        :param retry: (Optional) A retry object used to retry requests.
            If None is specified, requests will not be retried.
        :param timeout: (Optional) The amount of time, in seconds, to wait for the request
            to complete. Note that if retry is specified, the timeout applies to each
            individual attempt.
        :param metadata: (Optional) Additional metadata that is provided to the method.
        :raises ValueError: if both or neither of ``ack_ids`` and ``messages`` are given.
        :raises PubSubException: if the acknowledge request fails.
        """
        # Validate the arguments before touching the network / building a client.
        if (ack_ids is None) == (messages is None):
            raise ValueError("One and only one of 'ack_ids' and 'messages' arguments have to be provided")
        if ack_ids is None:
            # Only ``messages`` was supplied: extract the ack IDs from it.
            ack_ids = [message.ack_id for message in messages]

        subscriber = await self._get_subscriber_client()
        subscription_path = f"projects/{project_id}/subscriptions/{subscription}"
        self.log.info("Acknowledging %d ack_ids from subscription (path) %s", len(ack_ids), subscription_path)

        try:
            await subscriber.acknowledge(
                request={"subscription": subscription_path, "ack_ids": ack_ids},
                retry=retry,
                timeout=timeout,
                metadata=metadata,
            )
        except (HttpError, GoogleAPICallError) as e:
            # Keep the original error both as an argument (existing behaviour) and as the cause.
            raise PubSubException(
                f"Error acknowledging {len(ack_ids)} messages pulled from subscription {subscription_path}",
                e,
            ) from e
        self.log.info("Acknowledged ack_ids from subscription (path) %s", subscription_path)

    @GoogleBaseHook.fallback_to_default_project_id
    async def pull(
        self,
        subscription: str,
        max_messages: int,
        project_id: str = PROVIDE_PROJECT_ID,
        return_immediately: bool = False,
        retry: Retry | _MethodDefault = DEFAULT,
        timeout: float | None = None,
        metadata: Sequence[tuple[str, str]] = (),
    ) -> list[ReceivedMessage]:
        """
        Pull up to ``max_messages`` messages from a Pub/Sub subscription.

        :param subscription: the Pub/Sub subscription name to pull from; do not
            include the 'projects/{project}/subscriptions/' prefix.
        :param max_messages: The maximum number of messages to return from
            the Pub/Sub API.
        :param project_id: Optional, the Google Cloud project ID where the subscription exists.
            If set to None or missing, the default project_id from the Google Cloud connection is used.
        :param return_immediately: If set, the Pub/Sub API will immediately
            return if no messages are available. Otherwise, the request will
            block for an undisclosed, but bounded period of time.
        :param retry: (Optional) A retry object used to retry requests.
            If None is specified, requests will not be retried.
        :param timeout: (Optional) The amount of time, in seconds, to wait for the request
            to complete. Note that if retry is specified, the timeout applies to each
            individual attempt.
        :param metadata: (Optional) Additional metadata that is provided to the method.
        :return: The ``received_messages`` of the pull response (an empty list when the
            response has no such attribute). See
            https://cloud.google.com/pubsub/docs/reference/rpc/google.pubsub.v1#google.pubsub.v1.ReceivedMessage
        :raises PubSubException: if the pull request fails.
        """
        subscriber = await self._get_subscriber_client()
        subscription_path = f"projects/{project_id}/subscriptions/{subscription}"
        self.log.info("Pulling max %d messages from subscription (path) %s", max_messages, subscription_path)

        try:
            response = await subscriber.pull(
                request={
                    "subscription": subscription_path,
                    "max_messages": max_messages,
                    "return_immediately": return_immediately,
                },
                retry=retry,
                timeout=timeout,
                metadata=metadata,
            )
        except (HttpError, GoogleAPICallError) as e:
            raise PubSubException(f"Error pulling messages from subscription {subscription_path}", e) from e

        result = getattr(response, "received_messages", [])
        self.log.info("Pulled %d messages from subscription (path) %s", len(result), subscription_path)
        return result
49 changes: 43 additions & 6 deletions airflow/providers/google/cloud/sensors/pubsub.py
Expand Up @@ -18,11 +18,14 @@
"""This module contains a Google PubSub sensor."""
from __future__ import annotations

from datetime import timedelta
from typing import TYPE_CHECKING, Any, Callable, Sequence

from google.cloud.pubsub_v1.types import ReceivedMessage

from airflow.exceptions import AirflowException
from airflow.providers.google.cloud.hooks.pubsub import PubSubHook
from airflow.providers.google.cloud.triggers.pubsub import PubsubPullTrigger
from airflow.sensors.base import BaseSensorOperator

if TYPE_CHECKING:
Expand Down Expand Up @@ -79,6 +82,7 @@ class PubSubPullSensor(BaseSensorOperator):
If set as a sequence, the identities from the list must grant
Service Account Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account (templated).
:param deferrable: Run sensor in deferrable mode
"""

template_fields: Sequence[str] = (
Expand All @@ -98,6 +102,8 @@ def __init__(
gcp_conn_id: str = "google_cloud_default",
messages_callback: Callable[[list[ReceivedMessage], Context], Any] | None = None,
impersonation_chain: str | Sequence[str] | None = None,
poke_interval: float = 10.0,
deferrable: bool = False,
**kwargs,
) -> None:

Expand All @@ -109,14 +115,10 @@ def __init__(
self.ack_messages = ack_messages
self.messages_callback = messages_callback
self.impersonation_chain = impersonation_chain

self.deferrable = deferrable
self.poke_interval = poke_interval
self._return_value = None

def execute(self, context: Context) -> Any:
    """
    Overridden to allow messages to be passed.

    Runs the parent ``execute`` (its own return value is discarded) and then
    returns ``self._return_value``.
    NOTE(review): ``_return_value`` is initialised to ``None`` in ``__init__`` and is
    presumably populated by ``poke`` -- confirm, that part is not visible here.

    :param context: the task context, passed through to the parent ``execute``.
    :return: the current value of ``self._return_value``.
    """
    super().execute(context)
    return self._return_value

def poke(self, context: Context) -> bool:
hook = PubSubHook(
gcp_conn_id=self.gcp_conn_id,
Expand All @@ -143,6 +145,41 @@ def poke(self, context: Context) -> bool:

return bool(pulled_messages)

def execute(self, context: Context) -> Any:
    """
    Run the sensor either synchronously or in deferrable mode.

    When ``deferrable`` is False the regular sensor ``execute`` is run and
    ``self._return_value`` is returned. Otherwise the task is deferred to a
    ``PubsubPullTrigger`` and resumes in ``execute_complete``.

    :param context: the task context, passed through to the parent ``execute``.
    :return: ``self._return_value`` in non-deferrable mode; nothing is returned
        on the deferrable path.
    """
    if not self.deferrable:
        super().execute(context)
        return self._return_value

    # Deferrable path: hand the polling over to the trigger.
    self.defer(
        timeout=timedelta(seconds=self.timeout),
        trigger=PubsubPullTrigger(
            project_id=self.project_id,
            subscription=self.subscription,
            max_messages=self.max_messages,
            ack_messages=self.ack_messages,
            messages_callback=self.messages_callback,
            poke_interval=self.poke_interval,
            gcp_conn_id=self.gcp_conn_id,
            impersonation_chain=self.impersonation_chain,
        ),
        method_name="execute_complete",
    )

def execute_complete(self, context: dict[str, Any], event: dict[str, str | list[str]]) -> str | list[str]:
    """
    Resume the task once the trigger has fired.

    :param context: the task context (not used by this method).
    :param event: trigger payload; its ``"status"`` and ``"message"`` keys are read.
    :return: ``event["message"]`` when the status is ``"success"``.
    :raises AirflowException: carrying ``event["message"]`` for any other status.
    """
    status = event["status"]
    payload = event["message"]
    if status != "success":
        self.log.info("Sensor failed: %s", payload)
        raise AirflowException(payload)
    self.log.info("Sensor pulls messages: %s", payload)
    return payload

def _default_message_callback(
self,
pulled_messages: list[ReceivedMessage],
Expand Down
126 changes: 126 additions & 0 deletions airflow/providers/google/cloud/triggers/pubsub.py
@@ -0,0 +1,126 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""This module contains Google Cloud Pubsub triggers."""
from __future__ import annotations

import asyncio
from typing import TYPE_CHECKING, Any, AsyncIterator, Callable, Sequence

from google.cloud.pubsub_v1.types import ReceivedMessage

from airflow.providers.google.cloud.hooks.pubsub import PubSubAsyncHook
from airflow.triggers.base import BaseTrigger, TriggerEvent

if TYPE_CHECKING:
from airflow.utils.context import Context


class PubsubPullTrigger(BaseTrigger):
    """
    Trigger that polls a Pub/Sub subscription until at least one message is pulled.

    :param project_id: the Google Cloud project ID for the subscription (templated)
    :param subscription: the Pub/Sub subscription name. Do not include the full subscription path.
    :param max_messages: The maximum number of messages to retrieve per
        PubSub pull request
    :param ack_messages: If True, each message will be acknowledged
        immediately rather than by any downstream tasks
    :param gcp_conn_id: Reference to google cloud connection id
    :param messages_callback: (Optional) Callback to process received messages.
        It is stored and serialized by this class but not invoked by it.
        NOTE(review): a callable may not survive trigger serialization -- confirm.
    :param poke_interval: polling period in seconds to check for the status
    :param impersonation_chain: Optional service account to impersonate using short-term
        credentials, or chained list of accounts required to get the access_token
        of the last account in the list, which will be impersonated in the request.
        If set as a string, the account must grant the originating account
        the Service Account Token Creator IAM role.
        If set as a sequence, the identities from the list must grant
        Service Account Token Creator IAM role to the directly preceding identity, with first
        account from the list granting this role to the originating account (templated).
    """

    def __init__(
        self,
        project_id: str,
        subscription: str,
        max_messages: int,
        ack_messages: bool,
        gcp_conn_id: str,
        messages_callback: Callable[[list[ReceivedMessage], Context], Any] | None = None,
        poke_interval: float = 10.0,
        impersonation_chain: str | Sequence[str] | None = None,
    ):
        super().__init__()
        self.project_id = project_id
        self.subscription = subscription
        self.max_messages = max_messages
        self.ack_messages = ack_messages
        self.messages_callback = messages_callback
        self.poke_interval = poke_interval
        self.gcp_conn_id = gcp_conn_id
        self.impersonation_chain = impersonation_chain
        # Build the hook with the configured connection; a bare ``PubSubAsyncHook()``
        # would ignore ``gcp_conn_id`` and ``impersonation_chain`` entirely.
        self.hook = PubSubAsyncHook(
            gcp_conn_id=self.gcp_conn_id,
            impersonation_chain=self.impersonation_chain,
            project_id=self.project_id,
        )

    def serialize(self) -> tuple[str, dict[str, Any]]:
        """Serializes PubsubPullTrigger arguments and classpath."""
        return (
            "airflow.providers.google.cloud.triggers.pubsub.PubsubPullTrigger",
            {
                "project_id": self.project_id,
                "subscription": self.subscription,
                "max_messages": self.max_messages,
                "ack_messages": self.ack_messages,
                "messages_callback": self.messages_callback,
                "poke_interval": self.poke_interval,
                "gcp_conn_id": self.gcp_conn_id,
                "impersonation_chain": self.impersonation_chain,
            },
        )

    async def run(self) -> AsyncIterator[TriggerEvent]:  # type: ignore[override]
        """
        Poll the subscription and emit exactly one event.

        Yields a ``{"status": "success", "message": <pulled messages>}`` event as soon as
        a pull returns messages (acknowledging them first when ``ack_messages`` is set),
        or a ``{"status": "error", "message": <str(exception)>}`` event if anything fails.
        The generator ends after its single event.
        """
        try:
            while True:
                pulled_messages = await self.hook.pull(
                    project_id=self.project_id,
                    subscription=self.subscription,
                    max_messages=self.max_messages,
                    return_immediately=True,
                )
                if pulled_messages:
                    if self.ack_messages:
                        await self.message_acknowledgement(pulled_messages)
                    # NOTE(review): the payload holds ReceivedMessage objects as-is;
                    # confirm that the event payload can be serialized downstream.
                    yield TriggerEvent({"status": "success", "message": pulled_messages})
                    # Stop here: looping again would re-acknowledge and re-emit the
                    # same messages without any pause.
                    return
                self.log.info("Sleeping for %s seconds.", self.poke_interval)
                await asyncio.sleep(self.poke_interval)
        except Exception as e:
            # Boundary of the trigger: report the failure as an event instead of crashing.
            yield TriggerEvent({"status": "error", "message": str(e)})
            return

    async def message_acknowledgement(self, pulled_messages: list[ReceivedMessage]) -> None:
        """
        Acknowledge the given messages on the configured subscription.

        :param pulled_messages: messages returned by a previous pull.
        """
        await self.hook.acknowledge(
            project_id=self.project_id,
            subscription=self.subscription,
            messages=pulled_messages,
        )
        self.log.info("Acknowledged ack_ids from subscription %s", self.subscription)
self.log.info("Acknowledged ack_ids from subscription %s", self.subscription)
3 changes: 3 additions & 0 deletions airflow/providers/google/provider.yaml
Expand Up @@ -864,6 +864,9 @@ triggers:
- integration-name: Google Machine Learning Engine
python-modules:
- airflow.providers.google.cloud.triggers.mlengine
- integration-name: Google Cloud Pub/Sub
python-modules:
- airflow.providers.google.cloud.triggers.pubsub

transfers:
- source-integration-name: Presto
Expand Down

0 comments on commit a81ac70

Please sign in to comment.