Skip to content

Commit

Permalink
Optimize deferrable mode execution (#30920)
Browse files Browse the repository at this point in the history
  • Loading branch information
phanikumv committed Apr 28, 2023
1 parent 0c28ed0 commit a3741e0
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 18 deletions.
29 changes: 15 additions & 14 deletions airflow/providers/google/cloud/sensors/gcs.py
Expand Up @@ -231,20 +231,21 @@ def execute(self, context: Context) -> None:
if self.deferrable is False:
super().execute(context)
else:
self.defer(
timeout=timedelta(seconds=self.timeout),
trigger=GCSCheckBlobUpdateTimeTrigger(
bucket=self.bucket,
object_name=self.object,
target_date=self.ts_func(context),
poke_interval=self.poke_interval,
google_cloud_conn_id=self.google_cloud_conn_id,
hook_params={
"impersonation_chain": self.impersonation_chain,
},
),
method_name="execute_complete",
)
if not self.poke(context=context):
self.defer(
timeout=timedelta(seconds=self.timeout),
trigger=GCSCheckBlobUpdateTimeTrigger(
bucket=self.bucket,
object_name=self.object,
target_date=self.ts_func(context),
poke_interval=self.poke_interval,
google_cloud_conn_id=self.google_cloud_conn_id,
hook_params={
"impersonation_chain": self.impersonation_chain,
},
),
method_name="execute_complete",
)

def execute_complete(self, context: dict[str, Any], event: dict[str, str] | None = None) -> str:
"""Callback for when the trigger fires."""
Expand Down
23 changes: 19 additions & 4 deletions tests/providers/google/cloud/sensors/test_gcs.py
Expand Up @@ -39,7 +39,6 @@
GCSCheckBlobUpdateTimeTrigger,
GCSPrefixBlobTrigger,
)
from tests.providers.google.cloud.utils.airflow_util import create_context

TEST_BUCKET = "TEST_BUCKET"

Expand Down Expand Up @@ -247,6 +246,21 @@ def test_should_pass_argument_to_hook(self, mock_hook):
mock_hook.return_value.is_updated_after.assert_called_once_with(TEST_BUCKET, TEST_OBJECT, mock.ANY)
assert result is True

@mock.patch("airflow.providers.google.cloud.sensors.gcs.GCSHook")
@mock.patch("airflow.providers.google.cloud.sensors.gcs.GCSObjectUpdateSensor.defer")
def test_gcs_object_update_sensor_finish_before_deferred(self, mock_defer, mock_hook):
    """Check that a deferrable sensor does not call ``defer`` when the hook reports the object as updated."""
    sensor = GCSObjectUpdateSensor(
        task_id="task-id",
        bucket=TEST_BUCKET,
        object=TEST_OBJECT,
        google_cloud_conn_id=TEST_GCP_CONN_ID,
        impersonation_chain=TEST_IMPERSONATION_CHAIN,
        deferrable=True,
    )
    # Make the mocked hook report that the object has already been updated.
    mock_hook.return_value.is_updated_after.return_value = True

    fake_context = mock.MagicMock()
    sensor.execute(fake_context)

    # ``defer`` is patched out, so any attempt to defer would be recorded here.
    assert mock_defer.called is False


class TestGCSObjectUpdateSensorAsync:
OPERATOR = GCSObjectUpdateSensor(
Expand All @@ -257,14 +271,15 @@ class TestGCSObjectUpdateSensorAsync:
deferrable=True,
)

def test_gcs_object_update_sensor_async(self, context):
@mock.patch("airflow.providers.google.cloud.sensors.gcs.GCSHook")
def test_gcs_object_update_sensor_async(self, mock_hook):
"""
Asserts that the task is deferred with a GCSCheckBlobUpdateTimeTrigger
when a deferrable GCSObjectUpdateSensor is executed and the object has not yet been updated.
"""

mock_hook.return_value.is_updated_after.return_value = False
with pytest.raises(TaskDeferred) as exc:
self.OPERATOR.execute(create_context(self.OPERATOR))
self.OPERATOR.execute(mock.MagicMock())
assert isinstance(
exc.value.trigger, GCSCheckBlobUpdateTimeTrigger
), "Trigger is not a GCSCheckBlobUpdateTimeTrigger"
Expand Down

0 comments on commit a3741e0

Please sign in to comment.