Skip to content

Commit

Permalink
Optimize deferred execution mode (#30946)
Browse files Browse the repository at this point in the history
  • Loading branch information
phanikumv committed Apr 29, 2023
1 parent f3e82b2 commit b0a40bb
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 36 deletions.
63 changes: 34 additions & 29 deletions airflow/providers/google/cloud/sensors/bigquery.py
Expand Up @@ -110,20 +110,24 @@ def poke(self, context: Context) -> bool:

def execute(self, context: Context) -> None:
"""Airflow runs this method on the worker and defers using the trigger."""
self.defer(
timeout=timedelta(seconds=self.timeout),
trigger=BigQueryTableExistenceTrigger(
dataset_id=self.dataset_id,
table_id=self.table_id,
project_id=self.project_id,
poll_interval=self.poke_interval,
gcp_conn_id=self.gcp_conn_id,
hook_params={
"impersonation_chain": self.impersonation_chain,
},
),
method_name="execute_complete",
)
if not self.deferrable:
super().execute(context)
else:
if not self.poke(context=context):
self.defer(
timeout=timedelta(seconds=self.timeout),
trigger=BigQueryTableExistenceTrigger(
dataset_id=self.dataset_id,
table_id=self.table_id,
project_id=self.project_id,
poll_interval=self.poke_interval,
gcp_conn_id=self.gcp_conn_id,
hook_params={
"impersonation_chain": self.impersonation_chain,
},
),
method_name="execute_complete",
)

def execute_complete(self, context: dict[str, Any], event: dict[str, str] | None = None) -> str:
"""
Expand Down Expand Up @@ -218,21 +222,22 @@ def execute(self, context: Context) -> None:
if not self.deferrable:
super().execute(context)
else:
self.defer(
timeout=timedelta(seconds=self.timeout),
trigger=BigQueryTablePartitionExistenceTrigger(
dataset_id=self.dataset_id,
table_id=self.table_id,
project_id=self.project_id,
partition_id=self.partition_id,
poll_interval=self.poke_interval,
gcp_conn_id=self.gcp_conn_id,
hook_params={
"impersonation_chain": self.impersonation_chain,
},
),
method_name="execute_complete",
)
if not self.poke(context=context):
self.defer(
timeout=timedelta(seconds=self.timeout),
trigger=BigQueryTablePartitionExistenceTrigger(
dataset_id=self.dataset_id,
table_id=self.table_id,
project_id=self.project_id,
partition_id=self.partition_id,
poll_interval=self.poke_interval,
gcp_conn_id=self.gcp_conn_id,
hook_params={
"impersonation_chain": self.impersonation_chain,
},
),
method_name="execute_complete",
)

def execute_complete(self, context: dict[str, Any], event: dict[str, str] | None = None) -> str:
"""
Expand Down
57 changes: 50 additions & 7 deletions tests/providers/google/cloud/sensors/test_bigquery.py
Expand Up @@ -64,7 +64,24 @@ def test_passing_arguments_to_hook(self, mock_hook):
project_id=TEST_PROJECT_ID, dataset_id=TEST_DATASET_ID, table_id=TEST_TABLE_ID
)

def test_execute_defered(self):
@mock.patch("airflow.providers.google.cloud.sensors.bigquery.BigQueryHook")
@mock.patch("airflow.providers.google.cloud.sensors.bigquery.BigQueryTableExistenceSensor.defer")
def test_table_existence_sensor_finish_before_deferred(self, mock_defer, mock_hook):
task = BigQueryTableExistenceSensor(
task_id="task-id",
project_id=TEST_PROJECT_ID,
dataset_id=TEST_DATASET_ID,
table_id=TEST_TABLE_ID,
gcp_conn_id=TEST_GCP_CONN_ID,
impersonation_chain=TEST_IMPERSONATION_CHAIN,
deferrable=True,
)
mock_hook.return_value.table_exists.return_value = True
task.execute(mock.MagicMock())
assert not mock_defer.called

@mock.patch("airflow.providers.google.cloud.sensors.bigquery.BigQueryHook")
def test_execute_deferred(self, mock_hook):
"""
Asserts that a task is deferred and a BigQueryTableExistenceTrigger will be fired
when the BigQueryTableExistenceAsyncSensor is executed.
Expand All @@ -76,13 +93,14 @@ def test_execute_defered(self):
table_id=TEST_TABLE_ID,
deferrable=True,
)
mock_hook.return_value.table_exists.return_value = False
with pytest.raises(TaskDeferred) as exc:
task.execute(context={})
task.execute(mock.MagicMock())
assert isinstance(
exc.value.trigger, BigQueryTableExistenceTrigger
), "Trigger is not a BigQueryTableExistenceTrigger"

def test_excute_defered_failure(self):
def test_execute_deferred_failure(self):
"""Tests that an AirflowException is raised in case of error event"""
task = BigQueryTableExistenceSensor(
task_id="task-id",
Expand Down Expand Up @@ -148,7 +166,9 @@ def test_passing_arguments_to_hook(self, mock_hook):
partition_id=TEST_PARTITION_ID,
)

def test_execute_with_deferrable_mode(self):
@mock.patch("airflow.providers.google.cloud.sensors.bigquery.BigQueryHook")
@mock.patch("airflow.providers.google.cloud.sensors.bigquery.BigQueryTablePartitionExistenceSensor.defer")
def test_table_partition_existence_sensor_finish_before_deferred(self, mock_defer, mock_hook):
"""
Asserts that a task is deferred and a BigQueryTablePartitionExistenceTrigger will be fired
when the BigQueryTablePartitionExistenceSensor is executed and deferrable is set to True.
Expand All @@ -161,6 +181,25 @@ def test_execute_with_deferrable_mode(self):
partition_id=TEST_PARTITION_ID,
deferrable=True,
)
mock_hook.return_value.table_partition_exists.return_value = True
task.execute(mock.MagicMock())
assert not mock_defer.called

@mock.patch("airflow.providers.google.cloud.sensors.bigquery.BigQueryHook")
def test_execute_with_deferrable_mode(self, mock_hook):
"""
Asserts that a task is deferred and a BigQueryTablePartitionExistenceTrigger will be fired
when the BigQueryTablePartitionExistenceSensor is executed and deferrable is set to True.
"""
task = BigQueryTablePartitionExistenceSensor(
task_id="test_task_id",
project_id=TEST_PROJECT_ID,
dataset_id=TEST_DATASET_ID,
table_id=TEST_TABLE_ID,
partition_id=TEST_PARTITION_ID,
deferrable=True,
)
mock_hook.return_value.table_partition_exists.return_value = False
with pytest.raises(TaskDeferred) as exc:
task.execute(context={})
assert isinstance(
Expand Down Expand Up @@ -228,7 +267,8 @@ class TestBigQueryTableExistenceAsyncSensor:
"set `deferrable` attribute to `True` instead"
)

def test_big_query_table_existence_sensor_async(self):
@mock.patch("airflow.providers.google.cloud.sensors.bigquery.BigQueryHook")
def test_big_query_table_existence_sensor_async(self, mock_hook):
"""
Asserts that a task is deferred and a BigQueryTableExistenceTrigger will be fired
when the BigQueryTableExistenceAsyncSensor is executed.
Expand All @@ -240,6 +280,7 @@ def test_big_query_table_existence_sensor_async(self):
dataset_id=TEST_DATASET_ID,
table_id=TEST_TABLE_ID,
)
mock_hook.return_value.table_exists.return_value = False
with pytest.raises(TaskDeferred) as exc:
task.execute(context={})
assert isinstance(
Expand Down Expand Up @@ -293,7 +334,8 @@ class TestBigQueryTableExistencePartitionAsyncSensor:
"set `deferrable` attribute to `True` instead"
)

def test_big_query_table_existence_partition_sensor_async(self):
@mock.patch("airflow.providers.google.cloud.sensors.bigquery.BigQueryHook")
def test_big_query_table_existence_partition_sensor_async(self, mock_hook):
"""
Asserts that a task is deferred and a BigQueryTablePartitionExistenceTrigger will be fired
when the BigQueryTableExistencePartitionAsyncSensor is executed.
Expand All @@ -306,8 +348,9 @@ def test_big_query_table_existence_partition_sensor_async(self):
table_id=TEST_TABLE_ID,
partition_id=TEST_PARTITION_ID,
)
mock_hook.return_value.table_partition_exists.return_value = False
with pytest.raises(TaskDeferred) as exc:
task.execute(context={})
task.execute(mock.MagicMock())
assert isinstance(
exc.value.trigger, BigQueryTablePartitionExistenceTrigger
), "Trigger is not a BigQueryTablePartitionExistenceTrigger"
Expand Down

0 comments on commit b0a40bb

Please sign in to comment.