diff --git a/airflow/providers/google/cloud/sensors/bigquery.py b/airflow/providers/google/cloud/sensors/bigquery.py index fba30a5fca99d..e690fec56d07a 100644 --- a/airflow/providers/google/cloud/sensors/bigquery.py +++ b/airflow/providers/google/cloud/sensors/bigquery.py @@ -110,20 +110,24 @@ def poke(self, context: Context) -> bool: def execute(self, context: Context) -> None: """Airflow runs this method on the worker and defers using the trigger.""" - self.defer( - timeout=timedelta(seconds=self.timeout), - trigger=BigQueryTableExistenceTrigger( - dataset_id=self.dataset_id, - table_id=self.table_id, - project_id=self.project_id, - poll_interval=self.poke_interval, - gcp_conn_id=self.gcp_conn_id, - hook_params={ - "impersonation_chain": self.impersonation_chain, - }, - ), - method_name="execute_complete", - ) + if not self.deferrable: + super().execute(context) + else: + if not self.poke(context=context): + self.defer( + timeout=timedelta(seconds=self.timeout), + trigger=BigQueryTableExistenceTrigger( + dataset_id=self.dataset_id, + table_id=self.table_id, + project_id=self.project_id, + poll_interval=self.poke_interval, + gcp_conn_id=self.gcp_conn_id, + hook_params={ + "impersonation_chain": self.impersonation_chain, + }, + ), + method_name="execute_complete", + ) def execute_complete(self, context: dict[str, Any], event: dict[str, str] | None = None) -> str: """ @@ -218,21 +222,22 @@ def execute(self, context: Context) -> None: if not self.deferrable: super().execute(context) else: - self.defer( - timeout=timedelta(seconds=self.timeout), - trigger=BigQueryTablePartitionExistenceTrigger( - dataset_id=self.dataset_id, - table_id=self.table_id, - project_id=self.project_id, - partition_id=self.partition_id, - poll_interval=self.poke_interval, - gcp_conn_id=self.gcp_conn_id, - hook_params={ - "impersonation_chain": self.impersonation_chain, - }, - ), - method_name="execute_complete", - ) + if not self.poke(context=context): + self.defer( + timeout=timedelta(seconds=self.timeout), + trigger=BigQueryTablePartitionExistenceTrigger( + dataset_id=self.dataset_id, + table_id=self.table_id, + project_id=self.project_id, + partition_id=self.partition_id, + poll_interval=self.poke_interval, + gcp_conn_id=self.gcp_conn_id, + hook_params={ + "impersonation_chain": self.impersonation_chain, + }, + ), + method_name="execute_complete", + ) def execute_complete(self, context: dict[str, Any], event: dict[str, str] | None = None) -> str: """ diff --git a/tests/providers/google/cloud/sensors/test_bigquery.py b/tests/providers/google/cloud/sensors/test_bigquery.py index b699f7579773f..4805c6f3c1333 100644 --- a/tests/providers/google/cloud/sensors/test_bigquery.py +++ b/tests/providers/google/cloud/sensors/test_bigquery.py @@ -64,7 +64,24 @@ def test_passing_arguments_to_hook(self, mock_hook): project_id=TEST_PROJECT_ID, dataset_id=TEST_DATASET_ID, table_id=TEST_TABLE_ID ) - def test_execute_defered(self): + @mock.patch("airflow.providers.google.cloud.sensors.bigquery.BigQueryHook") + @mock.patch("airflow.providers.google.cloud.sensors.bigquery.BigQueryTableExistenceSensor.defer") + def test_table_existence_sensor_finish_before_deferred(self, mock_defer, mock_hook): + task = BigQueryTableExistenceSensor( + task_id="task-id", + project_id=TEST_PROJECT_ID, + dataset_id=TEST_DATASET_ID, + table_id=TEST_TABLE_ID, + gcp_conn_id=TEST_GCP_CONN_ID, + impersonation_chain=TEST_IMPERSONATION_CHAIN, + deferrable=True, + ) + mock_hook.return_value.table_exists.return_value = True + task.execute(mock.MagicMock()) + assert not mock_defer.called + + @mock.patch("airflow.providers.google.cloud.sensors.bigquery.BigQueryHook") + def test_execute_deferred(self, mock_hook): """ Asserts that a task is deferred and a BigQueryTableExistenceTrigger will be fired when the BigQueryTableExistenceAsyncSensor is executed. @@ -76,13 +93,14 @@ def test_execute_defered(self): table_id=TEST_TABLE_ID, deferrable=True, ) + mock_hook.return_value.table_exists.return_value = False with pytest.raises(TaskDeferred) as exc: - task.execute(context={}) + task.execute(mock.MagicMock()) assert isinstance( exc.value.trigger, BigQueryTableExistenceTrigger ), "Trigger is not a BigQueryTableExistenceTrigger" - def test_excute_defered_failure(self): + def test_execute_deferred_failure(self): """Tests that an AirflowException is raised in case of error event""" task = BigQueryTableExistenceSensor( task_id="task-id", @@ -148,7 +166,9 @@ def test_passing_arguments_to_hook(self, mock_hook): partition_id=TEST_PARTITION_ID, ) - def test_execute_with_deferrable_mode(self): + @mock.patch("airflow.providers.google.cloud.sensors.bigquery.BigQueryHook") + @mock.patch("airflow.providers.google.cloud.sensors.bigquery.BigQueryTablePartitionExistenceSensor.defer") + def test_table_partition_existence_sensor_finish_before_deferred(self, mock_defer, mock_hook): """ Asserts that a task is deferred and a BigQueryTablePartitionExistenceTrigger will be fired when the BigQueryTablePartitionExistenceSensor is executed and deferrable is set to True. @@ -161,6 +181,25 @@ def test_execute_with_deferrable_mode(self): partition_id=TEST_PARTITION_ID, deferrable=True, ) + mock_hook.return_value.table_partition_exists.return_value = True + task.execute(mock.MagicMock()) + assert not mock_defer.called + + @mock.patch("airflow.providers.google.cloud.sensors.bigquery.BigQueryHook") + def test_execute_with_deferrable_mode(self, mock_hook): + """ + Asserts that a task is deferred and a BigQueryTablePartitionExistenceTrigger will be fired + when the BigQueryTablePartitionExistenceSensor is executed and deferrable is set to True. + """ + task = BigQueryTablePartitionExistenceSensor( + task_id="test_task_id", + project_id=TEST_PROJECT_ID, + dataset_id=TEST_DATASET_ID, + table_id=TEST_TABLE_ID, + partition_id=TEST_PARTITION_ID, + deferrable=True, + ) + mock_hook.return_value.table_partition_exists.return_value = False with pytest.raises(TaskDeferred) as exc: task.execute(context={}) assert isinstance( @@ -228,7 +267,8 @@ class TestBigQueryTableExistenceAsyncSensor: "set `deferrable` attribute to `True` instead" ) - def test_big_query_table_existence_sensor_async(self): + @mock.patch("airflow.providers.google.cloud.sensors.bigquery.BigQueryHook") + def test_big_query_table_existence_sensor_async(self, mock_hook): """ Asserts that a task is deferred and a BigQueryTableExistenceTrigger will be fired when the BigQueryTableExistenceAsyncSensor is executed. @@ -240,6 +280,7 @@ def test_big_query_table_existence_sensor_async(self): dataset_id=TEST_DATASET_ID, table_id=TEST_TABLE_ID, ) + mock_hook.return_value.table_exists.return_value = False with pytest.raises(TaskDeferred) as exc: task.execute(context={}) assert isinstance( @@ -293,7 +334,8 @@ class TestBigQueryTableExistencePartitionAsyncSensor: "set `deferrable` attribute to `True` instead" ) - def test_big_query_table_existence_partition_sensor_async(self): + @mock.patch("airflow.providers.google.cloud.sensors.bigquery.BigQueryHook") + def test_big_query_table_existence_partition_sensor_async(self, mock_hook): """ Asserts that a task is deferred and a BigQueryTablePartitionExistenceTrigger will be fired when the BigQueryTableExistencePartitionAsyncSensor is executed. @@ -306,8 +348,9 @@ def test_big_query_table_existence_partition_sensor_async(self): table_id=TEST_TABLE_ID, partition_id=TEST_PARTITION_ID, ) + mock_hook.return_value.table_partition_exists.return_value = False with pytest.raises(TaskDeferred) as exc: - task.execute(context={}) + task.execute(mock.MagicMock()) assert isinstance( exc.value.trigger, BigQueryTablePartitionExistenceTrigger ), "Trigger is not a BigQueryTablePartitionExistenceTrigger"