fix: BigQueryCheckOperator skipped value and error check in deferrable mode (#38408)

Signed-off-by: Kacper Muda <mudakacper@gmail.com>
kacpermuda committed Apr 19, 2024
1 parent fd8a057 commit eee17f0
Showing 2 changed files with 86 additions and 21 deletions.
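
Before the diff itself, a quick orientation: judging from the commit title and the new tests, the bug was that when the BigQuery job finished before the operator ever deferred, the error check and record validation were skipped entirely; the change routes both the deferred and non-deferred paths through shared helpers. The snippet below is a minimal, standalone sketch (not the operator code itself) of the `next(iter(...), [])` idiom the fix uses to fetch the first result row; the row values are invented for illustration.

```python
# Minimal sketch of the "first row or empty" idiom used in the fix below.
# `fake_rows` stands in for the RowIterator returned by job.result(); values are invented.
fake_rows = iter([(1, 2, 3)])

first_row = next(iter(fake_rows), [])   # -> (1, 2, 3)

no_rows = iter([])                      # a query that returned zero rows
fallback = next(iter(no_rows), [])      # -> [] instead of raising StopIteration

print(first_row, fallback)
```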
30 changes: 21 additions & 9 deletions airflow/providers/google/cloud/operators/bigquery.py
@@ -322,8 +322,26 @@ def execute(self, context: Context):
                ),
                method_name="execute_complete",
            )
        self._handle_job_error(job)
        # job.result() returns a RowIterator. Mypy expects an instance of SupportsNext[Any] for
        # the next() call which the RowIterator does not resemble to. Hence, ignore the arg-type error.
        # Row passed to _validate_records is a collection of values only, without column names.
        self._validate_records(next(iter(job.result()), []))  # type: ignore[arg-type]
        self.log.info("Current state of job %s is %s", job.job_id, job.state)

    @staticmethod
    def _handle_job_error(job: BigQueryJob | UnknownJob) -> None:
        if job.error_result:
            raise AirflowException(f"BigQuery job {job.job_id} failed: {job.error_result}")

    def _validate_records(self, records) -> None:
        if not records:
            raise AirflowException(f"The following query returned zero rows: {self.sql}")
        elif not all(records):
            self._raise_exception(  # type: ignore[attr-defined]
                f"Test failed.\nQuery:\n{self.sql}\nResults:\n{records!s}"
            )

    def execute_complete(self, context: Context, event: dict[str, Any]) -> None:
        """Act as a callback for when the trigger fires.
@@ -333,13 +351,7 @@ def execute_complete(self, context: Context, event: dict[str, Any]) -> None:
        if event["status"] == "error":
            raise AirflowException(event["message"])

        records = event["records"]
        if not records:
            raise AirflowException("The query returned empty results")
        elif not all(records):
            self._raise_exception(  # type: ignore[attr-defined]
                f"Test failed.\nQuery:\n{self.sql}\nResults:\n{records!s}"
            )
        self._validate_records(event["records"])
        self.log.info("Record: %s", event["records"])
        self.log.info("Success.")

@@ -454,8 +466,8 @@ def execute(self, context: Context) -> None:  # type: ignore[override]
        self._handle_job_error(job)
        # job.result() returns a RowIterator. Mypy expects an instance of SupportsNext[Any] for
        # the next() call which the RowIterator does not resemble to. Hence, ignore the arg-type error.
        records = next(job.result())  # type: ignore[arg-type]
        self.check_value(records)  # type: ignore[attr-defined]
        # Row passed to check_value is a collection of values only, without column names.
        self.check_value(next(iter(job.result()), []))  # type: ignore[arg-type]
        self.log.info("Current state of job %s is %s", job.job_id, job.state)

    @staticmethod
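
A note on why `_validate_records` keeps two separate branches: in Python, `all()` of an empty collection is `True`, so a zero-row result would silently pass the truthiness check if it were not rejected first. A small illustration with invented values:

```python
# Illustrates the two branches of _validate_records (values are made up).
print(all(()))         # True  -- empty input passes all(), hence the explicit "zero rows" check
print(all((1, 2, 3)))  # True  -- every value truthy: the check passes
print(all((1, 0, 3)))  # False -- any falsy value (0, None, "", False) triggers the failure path
```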
77 changes: 65 additions & 12 deletions tests/providers/google/cloud/operators/test_bigquery.py
@@ -2046,17 +2046,19 @@ def test_bigquery_interval_check_operator_without_project_id(

class TestBigQueryCheckOperator:
    @pytest.mark.db_test
    @mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryCheckOperator.execute")
    @mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryCheckOperator._validate_records")
    @mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryCheckOperator.defer")
    @mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook")
    def test_bigquery_check_operator_async_finish_before_deferred(
        self, mock_hook, mock_defer, mock_execute, create_task_instance_of_operator
        self, mock_hook, mock_defer, mock_validate_records, create_task_instance_of_operator
    ):
        job_id = "123456"
        hash_ = "hash"
        real_job_id = f"{job_id}_{hash_}"

        mock_hook.return_value.insert_job.return_value = MagicMock(job_id=real_job_id, error_result=False)
        mocked_job = MagicMock(job_id=real_job_id, error_result=False)
        mocked_job.result.return_value = iter([(1, 2, 3)])  # mock rows generator
        mock_hook.return_value.insert_job.return_value = mocked_job
        mock_hook.return_value.insert_job.return_value.running.return_value = False

        ti = create_task_instance_of_operator(
@@ -2069,8 +2071,34 @@ def test_bigquery_check_operator_async_finish_before_deferred(
        )

        ti.task.execute(MagicMock())
        assert not mock_defer.called
        assert mock_execute.called
        mock_defer.assert_not_called()
        mock_validate_records.assert_called_once_with((1, 2, 3))

    @pytest.mark.db_test
    @mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook")
    def test_bigquery_check_operator_async_finish_with_error_before_deferred(
        self, mock_hook, create_task_instance_of_operator
    ):
        job_id = "123456"
        hash_ = "hash"
        real_job_id = f"{job_id}_{hash_}"

        mock_hook.return_value.insert_job.return_value = MagicMock(job_id=real_job_id, error_result=True)
        mock_hook.return_value.insert_job.return_value.running.return_value = False

        ti = create_task_instance_of_operator(
            BigQueryCheckOperator,
            dag_id="dag_id",
            task_id="bq_check_operator_job",
            sql="SELECT * FROM any",
            location=TEST_DATASET_LOCATION,
            deferrable=True,
        )

        with pytest.raises(AirflowException) as exc:
            ti.task.execute(MagicMock())

        assert str(exc.value) == f"BigQuery job {real_job_id} failed: True"

    @pytest.mark.db_test
    @mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook")
@@ -2124,13 +2152,9 @@ def test_bigquery_check_op_execute_complete_with_no_records(self):
            deferrable=True,
        )

        with pytest.raises(AirflowException) as exc:
        with pytest.raises(AirflowException, match="The following query returned zero rows:"):
            operator.execute_complete(context=None, event={"status": "success", "records": None})

        expected_exception_msg = "The query returned empty results"

        assert str(exc.value) == expected_exception_msg

    def test_bigquery_check_op_execute_complete_with_non_boolean_records(self):
        """Executing a sql which returns a non-boolean value should raise exception"""

@@ -2204,7 +2228,9 @@ def test_bigquery_value_check_operator_async_finish_before_deferred(
        hash_ = "hash"
        real_job_id = f"{job_id}_{hash_}"

        mock_hook.return_value.insert_job.return_value = MagicMock(job_id=real_job_id, error_result=False)
        mocked_job = MagicMock(job_id=real_job_id, error_result=False)
        mocked_job.result.return_value = iter([(1, 2, 3)])  # mock rows generator
        mock_hook.return_value.insert_job.return_value = mocked_job
        mock_hook.return_value.insert_job.return_value.running.return_value = False

        ti = create_task_instance_of_operator(
@@ -2219,7 +2245,34 @@

        ti.task.execute(MagicMock())
        assert not mock_defer.called
        assert mock_check_value.called
        mock_check_value.assert_called_once_with((1, 2, 3))

    @pytest.mark.db_test
    @mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook")
    def test_bigquery_value_check_operator_async_finish_with_error_before_deferred(
        self, mock_hook, create_task_instance_of_operator
    ):
        job_id = "123456"
        hash_ = "hash"
        real_job_id = f"{job_id}_{hash_}"

        mock_hook.return_value.insert_job.return_value = MagicMock(job_id=real_job_id, error_result=True)
        mock_hook.return_value.insert_job.return_value.running.return_value = False

        ti = create_task_instance_of_operator(
            BigQueryValueCheckOperator,
            dag_id="dag_id",
            task_id="check_value",
            sql="SELECT COUNT(*) FROM Any",
            pass_value=2,
            use_legacy_sql=True,
            deferrable=True,
        )

        with pytest.raises(AirflowException) as exc:
            ti.task.execute(MagicMock())

        assert str(exc.value) == f"BigQuery job {real_job_id} failed: True"

    @pytest.mark.parametrize(
        "kwargs, expected",
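
The "finish before deferred" tests above work because a `MagicMock`'s configured `result.return_value` is handed back unchanged, so an already-built iterator lets the operator's `next(iter(job.result()), [])` call pull the fake row. Likewise, the expected error message ends in `True` because the mock's `error_result` is the literal boolean interpolated into the exception f-string. A self-contained sketch of that mocking pattern (names and values are arbitrary):

```python
from unittest import mock

# Stand-in for the mocked BigQuery job used in the tests; values are arbitrary.
mocked_job = mock.MagicMock(job_id="123456_hash", error_result=False)
mocked_job.result.return_value = iter([(1, 2, 3)])  # fake rows generator

# The same expression the operator uses to grab the first row (or [] when there are none).
first_row = next(iter(mocked_job.result()), [])
assert first_row == (1, 2, 3)

# With error_result set, the message _handle_job_error would raise looks like this:
failed_job = mock.MagicMock(job_id="123456_hash", error_result=True)
message = f"BigQuery job {failed_job.job_id} failed: {failed_job.error_result}"
assert message == "BigQuery job 123456_hash failed: True"
```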