diff --git a/airflow/providers/google/cloud/sensors/dataplex.py b/airflow/providers/google/cloud/sensors/dataplex.py index 887c7e47c6852..e1581fe1dac97 100644 --- a/airflow/providers/google/cloud/sensors/dataplex.py +++ b/airflow/providers/google/cloud/sensors/dataplex.py @@ -167,7 +167,7 @@ def __init__( impersonation_chain: str | Sequence[str] | None = None, fail_on_dq_failure: bool = False, result_timeout: float = 60.0 * 10, - start_sensor_time: float = time.monotonic(), + start_sensor_time: float | None = None, *args, **kwargs, ) -> None: @@ -185,10 +185,9 @@ def __init__( self.result_timeout = result_timeout self.start_sensor_time = start_sensor_time - def execute(self, context: Context) -> None: - super().execute(context) - def _duration(self): + if not self.start_sensor_time: + self.start_sensor_time = time.monotonic() return time.monotonic() - self.start_sensor_time def poke(self, context: Context) -> bool: diff --git a/tests/providers/google/cloud/sensors/test_dataplex.py b/tests/providers/google/cloud/sensors/test_dataplex.py index 3b59fd1ac6361..c9ae358efaf35 100644 --- a/tests/providers/google/cloud/sensors/test_dataplex.py +++ b/tests/providers/google/cloud/sensors/test_dataplex.py @@ -23,6 +23,7 @@ from google.cloud.dataplex_v1.types import DataScanJob from airflow import AirflowException +from airflow.providers.google.cloud.hooks.dataplex import AirflowDataQualityScanResultTimeoutException from airflow.providers.google.cloud.sensors.dataplex import ( DataplexDataQualityJobStatusSensor, DataplexTaskStateSensor, @@ -144,3 +145,45 @@ def test_done(self, mock_hook): ) assert result + + def test_start_sensor_time(self): + sensor = DataplexDataQualityJobStatusSensor( + task_id=TASK_ID, + project_id=PROJECT_ID, + job_id=TEST_JOB_ID, + data_scan_id=TEST_DATA_SCAN_ID, + region=REGION, + api_version=API_VERSION, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + timeout=TIMEOUT, + ) + + assert sensor.start_sensor_time is None + + duration_1 = sensor._duration() + duration_2 = sensor._duration() + + assert sensor.start_sensor_time + assert 0 < duration_1 < duration_2 + + @mock.patch.object(DataplexDataQualityJobStatusSensor, "_duration") + def test_start_sensor_time_timeout(self, mock_duration): + result_timeout = 100 + mock_duration.return_value = result_timeout + 1 + + sensor = DataplexDataQualityJobStatusSensor( + task_id=TASK_ID, + project_id=PROJECT_ID, + job_id=TEST_JOB_ID, + data_scan_id=TEST_DATA_SCAN_ID, + region=REGION, + api_version=API_VERSION, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + timeout=TIMEOUT, + result_timeout=result_timeout, + ) + + with pytest.raises(AirflowDataQualityScanResultTimeoutException): + sensor.poke(context={})