Add deferrable mode to GCSUploadSessionCompleteSensor (#31081)
* add deferrable mode to GCSUploadSessionCompleteSensor

* Add tests

* Fix tests

* Apply review suggestions

* Apply review suggestion

* Add docs

* Apply review suggestions
phanikumv committed May 25, 2023
1 parent 28f2e70 commit 5ae9728
Showing 6 changed files with 465 additions and 1 deletion.
40 changes: 40 additions & 0 deletions airflow/providers/google/cloud/sensors/gcs.py
@@ -33,6 +33,7 @@
    GCSBlobTrigger,
    GCSCheckBlobUpdateTimeTrigger,
    GCSPrefixBlobTrigger,
    GCSUploadSessionTrigger,
)
from airflow.sensors.base import BaseSensorOperator, poke_mode_only

@@ -390,6 +391,7 @@ class GCSUploadSessionCompleteSensor(BaseSensorOperator):
        If set as a sequence, the identities from the list must grant
        Service Account Token Creator IAM role to the directly preceding identity, with first
        account from the list granting this role to the originating account (templated).
    :param deferrable: Run sensor in deferrable mode
    """

    template_fields: Sequence[str] = (
@@ -409,6 +411,7 @@ def __init__(
        allow_delete: bool = True,
        google_cloud_conn_id: str = "google_cloud_default",
        impersonation_chain: str | Sequence[str] | None = None,
        deferrable: bool = False,
        **kwargs,
    ) -> None:

@@ -427,6 +430,7 @@ def __init__(
        self.last_activity_time = None
        self.impersonation_chain = impersonation_chain
        self.hook: GCSHook | None = None
        self.deferrable = deferrable

    def _get_gcs_hook(self) -> GCSHook | None:
        if not self.hook:
@@ -514,3 +518,39 @@ def poke(self, context: Context) -> bool:
        return self.is_bucket_updated(
            set(self._get_gcs_hook().list(self.bucket, prefix=self.prefix))  # type: ignore[union-attr]
        )

    def execute(self, context: Context) -> None:
        """Airflow runs this method on the worker and defers using the trigger."""
        hook_params = {"impersonation_chain": self.impersonation_chain}

        if not self.deferrable:
            return super().execute(context)

        if not self.poke(context=context):
            self.defer(
                timeout=timedelta(seconds=self.timeout),
                trigger=GCSUploadSessionTrigger(
                    bucket=self.bucket,
                    prefix=self.prefix,
                    poke_interval=self.poke_interval,
                    google_cloud_conn_id=self.google_cloud_conn_id,
                    inactivity_period=self.inactivity_period,
                    min_objects=self.min_objects,
                    previous_objects=self.previous_objects,
                    allow_delete=self.allow_delete,
                    hook_params=hook_params,
                ),
                method_name="execute_complete",
            )

    def execute_complete(self, context: dict[str, Any], event: dict[str, str] | None = None) -> str:
        """
        Callback for when the trigger fires; returns immediately.

        Relies on the trigger to raise an exception; otherwise it assumes execution was
        successful.
        """
        if event:
            if event["status"] == "success":
                return event["message"]
            raise AirflowException(event["message"])
        raise AirflowException("No event received in trigger callback")
158 changes: 158 additions & 0 deletions airflow/providers/google/cloud/triggers/gcs.py
@@ -18,6 +18,7 @@
from __future__ import annotations

import asyncio
import os
from datetime import datetime
from typing import Any, AsyncIterator

@@ -281,3 +282,160 @@ async def _list_blobs_with_prefix(self, hook: GCSAsyncHook, bucket_name: str, pr
        bucket = client.get_bucket(bucket_name)
        object_response = await bucket.list_blobs(prefix=prefix)
        return object_response


class GCSUploadSessionTrigger(GCSPrefixBlobTrigger):
    """
    Checks for changes in the number of objects at a prefix in a Google Cloud Storage
    bucket and returns a TriggerEvent if the inactivity period has passed with no
    increase in the number of objects.

    :param bucket: The Google Cloud Storage bucket where the objects are expected.
    :param prefix: The name of the prefix to check in the Google Cloud Storage bucket.
    :param poke_interval: polling period in seconds to check
    :param inactivity_period: The total seconds of inactivity needed to consider
        an upload session over. Note, this mechanism is not real time and
        the trigger may not fire until an interval after this period
        has passed with no additional objects sensed.
    :param min_objects: The minimum number of objects needed for the upload session
        to be considered valid.
    :param previous_objects: The set of object ids found during the last poke.
    :param allow_delete: Should the sensor consider objects being deleted
        between intervals valid behavior. If true, a warning message will be logged
        when this happens. If false, an error will be raised.
    :param google_cloud_conn_id: The connection ID to use when connecting
        to Google Cloud Storage.
    """

    def __init__(
        self,
        bucket: str,
        prefix: str,
        poke_interval: float,
        google_cloud_conn_id: str,
        hook_params: dict[str, Any],
        inactivity_period: float = 60 * 60,
        min_objects: int = 1,
        previous_objects: set[str] | None = None,
        allow_delete: bool = True,
    ):
        super().__init__(
            bucket=bucket,
            prefix=prefix,
            poke_interval=poke_interval,
            google_cloud_conn_id=google_cloud_conn_id,
            hook_params=hook_params,
        )
        self.inactivity_period = inactivity_period
        self.min_objects = min_objects
        self.previous_objects = previous_objects if previous_objects else set()
        self.inactivity_seconds = 0.0
        self.allow_delete = allow_delete
        self.last_activity_time: datetime | None = None

    def serialize(self) -> tuple[str, dict[str, Any]]:
        """Serializes GCSUploadSessionTrigger arguments and classpath."""
        return (
            "airflow.providers.google.cloud.triggers.gcs.GCSUploadSessionTrigger",
            {
                "bucket": self.bucket,
                "prefix": self.prefix,
                "poke_interval": self.poke_interval,
                "google_cloud_conn_id": self.google_cloud_conn_id,
                "hook_params": self.hook_params,
                "inactivity_period": self.inactivity_period,
                "min_objects": self.min_objects,
                "previous_objects": self.previous_objects,
                "allow_delete": self.allow_delete,
            },
        )

    async def run(self) -> AsyncIterator[TriggerEvent]:
        """
        Poll the prefix until no objects are added or deleted for the inactivity_period,
        then yield a TriggerEvent with the outcome.
        """
        try:
            hook = self._get_async_hook()
            while True:
                list_blobs = await self._list_blobs_with_prefix(
                    hook=hook, bucket_name=self.bucket, prefix=self.prefix
                )
                res = self._is_bucket_updated(set(list_blobs))
                if res["status"] in ("success", "error"):
                    yield TriggerEvent(res)
                await asyncio.sleep(self.poke_interval)
        except Exception as e:
            yield TriggerEvent({"status": "error", "message": str(e)})
            return

    def _get_time(self) -> datetime:
        """
        Wrapper around datetime.datetime.now to simplify mocking in the unit tests.
        """
        return datetime.now()

    def _is_bucket_updated(self, current_objects: set[str]) -> dict[str, str]:
        """
        Checks whether new objects have been uploaded and the inactivity_period
        has passed, and updates the state of the trigger accordingly.

        :param current_objects: set of object ids in bucket during last check.
        """
        current_num_objects = len(current_objects)
        if current_objects > self.previous_objects:
            # When new objects arrived, reset the inactivity_seconds
            # and update previous_objects for the next check interval.
            self.log.info(
                "New objects found at %s resetting last_activity_time.",
                os.path.join(self.bucket, self.prefix),
            )
            self.log.debug("New objects: %s", "\n".join(current_objects - self.previous_objects))
            self.last_activity_time = self._get_time()
            self.inactivity_seconds = 0
            self.previous_objects = current_objects
            return {"status": "pending"}

        if self.previous_objects - current_objects:
            # During the last check interval, objects were deleted.
            if self.allow_delete:
                self.previous_objects = current_objects
                self.last_activity_time = self._get_time()
                self.log.warning(
                    "%s Objects were deleted during the last interval."
                    " Updating the file counter and resetting last_activity_time.",
                    self.previous_objects - current_objects,
                )
                return {"status": "pending"}
            return {
                "status": "error",
                "message": "Illegal behavior: objects were deleted in between check intervals",
            }
        if self.last_activity_time:
            self.inactivity_seconds = (self._get_time() - self.last_activity_time).total_seconds()
        else:
            # Handles the first check where last inactivity time is None.
            self.last_activity_time = self._get_time()
            self.inactivity_seconds = 0

        if self.inactivity_seconds >= self.inactivity_period:
            path = os.path.join(self.bucket, self.prefix)

            if current_num_objects >= self.min_objects:
                success_message = (
                    "SUCCESS: Sensor found %s objects at %s. Waited at least %s "
                    "seconds, with no new objects dropped."
                )
                self.log.info(success_message, current_num_objects, path, self.inactivity_seconds)
                return {
                    "status": "success",
                    "message": success_message % (current_num_objects, path, self.inactivity_seconds),
                }

            error_message = "FAILURE: Inactivity Period passed, not enough objects found in %s"
            self.log.error(error_message, path)
            return {"status": "error", "message": error_message % path}
        return {"status": "pending"}
9 changes: 9 additions & 0 deletions docs/apache-airflow-providers-google/operators/cloud/gcs.rst
@@ -224,6 +224,15 @@ Use the :class:`~airflow.providers.google.cloud.sensors.gcs.GCSUploadSessionComp
    :start-after: [START howto_sensor_gcs_upload_session_complete_task]
    :end-before: [END howto_sensor_gcs_upload_session_complete_task]

You can set the parameter ``deferrable`` to ``True`` if you want the worker slots to be freed up while the sensor is running.


.. exampleinclude:: /../../tests/system/providers/google/cloud/gcs/example_gcs_sensor.py
    :language: python
    :dedent: 4
    :start-after: [START howto_sensor_gcs_upload_session_async_task]
    :end-before: [END howto_sensor_gcs_upload_session_async_task]

.. _howto/sensor:GCSObjectUpdateSensor:

GCSObjectUpdateSensor
45 changes: 45 additions & 0 deletions tests/providers/google/cloud/sensors/test_gcs.py
@@ -38,6 +38,7 @@
    GCSBlobTrigger,
    GCSCheckBlobUpdateTimeTrigger,
    GCSPrefixBlobTrigger,
    GCSUploadSessionTrigger,
)

TEST_BUCKET = "TEST_BUCKET"
@@ -56,6 +57,10 @@

MOCK_DATE_ARRAY = [datetime(2019, 2, 24, 12, 0, 0) - i * timedelta(seconds=10) for i in range(25)]

TEST_INACTIVITY_PERIOD = 5

TEST_MIN_OBJECTS = 1


@pytest.fixture()
def context():
@@ -518,3 +523,43 @@ def test_not_enough_objects(self):
        self.sensor.is_bucket_updated(set())
        assert self.sensor.inactivity_seconds == 10
        assert not self.sensor.is_bucket_updated(set())


class TestGCSUploadSessionCompleteSensorAsync:
    OPERATOR = GCSUploadSessionCompleteSensor(
        task_id="gcs-obj-session",
        bucket=TEST_BUCKET,
        google_cloud_conn_id=TEST_GCP_CONN_ID,
        prefix=TEST_OBJECT,
        inactivity_period=TEST_INACTIVITY_PERIOD,
        min_objects=TEST_MIN_OBJECTS,
        deferrable=True,
    )

    @mock.patch("airflow.providers.google.cloud.sensors.gcs.GCSHook")
    def test_gcs_upload_session_complete_sensor_async(self, mock_hook):
        """
        Asserts that a task is deferred and a GCSUploadSessionTrigger will be fired
        when the GCSUploadSessionCompleteSensor is executed in deferrable mode.
        """
        mock_hook.return_value.is_bucket_updated.return_value = False
        with pytest.raises(TaskDeferred) as exc:
            self.OPERATOR.execute(mock.MagicMock())
        assert isinstance(
            exc.value.trigger, GCSUploadSessionTrigger
        ), "Trigger is not a GCSUploadSessionTrigger"

    def test_gcs_upload_session_complete_sensor_execute_failure(self, context):
        """Tests that an AirflowException is raised in case of an error event."""
        with pytest.raises(AirflowException):
            self.OPERATOR.execute_complete(
                context=context, event={"status": "error", "message": "test failure message"}
            )

    def test_gcs_upload_session_complete_sensor_async_execute_complete(self, context):
        """Asserts that execute_complete returns the expected value on a success event."""
        assert self.OPERATOR.execute_complete(
            context=context, event={"status": "success", "message": "success"}
        )
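As a hedged pointer for reviewers, assuming a working Airflow development environment with the google provider test dependencies installed, the new tests can be selected by class name, for example:

import pytest

# Runs only the deferrable-sensor tests added above; the -k expression matches the new test class.
pytest.main(["tests/providers/google/cloud/sensors/test_gcs.py", "-k", "GCSUploadSessionCompleteSensorAsync", "-q"])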
