Skip to content

Commit

Permalink
improvement: introduce proejct_id in BigQueryIntervalCheckOperator (#…
Browse files Browse the repository at this point in the history
…34573)

* improvement: introduce proejct_id in BigQueryIntervalCheckOperator

* Update bigquery.py

---------

Co-authored-by: Hussein Awala <hussein@awala.fr>
  • Loading branch information
Zhenye-Na and hussein-awala committed Sep 24, 2023
1 parent 2c53345 commit 6a03870
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 1 deletion.
5 changes: 4 additions & 1 deletion airflow/providers/google/cloud/operators/bigquery.py
Expand Up @@ -505,6 +505,7 @@ class BigQueryIntervalCheckOperator(_BigQueryDbHookMixin, SQLIntervalCheckOperat
:param deferrable: Run operator in the deferrable mode
:param poll_interval: (Deferrable mode only) polling period in seconds to check for the status of job.
Defaults to 4 seconds.
:param project_id: a string represents the BigQuery projectId
"""

template_fields: Sequence[str] = (
Expand Down Expand Up @@ -532,6 +533,7 @@ def __init__(
labels: dict | None = None,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
poll_interval: float = 4.0,
project_id: str | None = None,
**kwargs,
) -> None:
super().__init__(
Expand All @@ -547,6 +549,7 @@ def __init__(
self.location = location
self.impersonation_chain = impersonation_chain
self.labels = labels
self.project_id = project_id
self.deferrable = deferrable
self.poll_interval = poll_interval

Expand All @@ -560,7 +563,7 @@ def _submit_job(
configuration = {"query": {"query": sql, "useLegacySql": self.use_legacy_sql}}
return hook.insert_job(
configuration=configuration,
project_id=hook.project_id,
project_id=self.project_id or hook.project_id,
location=self.location,
job_id=job_id,
nowait=True,
Expand Down
74 changes: 74 additions & 0 deletions tests/providers/google/cloud/operators/test_bigquery.py
Expand Up @@ -1768,6 +1768,80 @@ def test_bigquery_interval_check_operator_async(self, mock_hook, create_task_ins
exc.value.trigger, BigQueryIntervalCheckTrigger
), "Trigger is not a BigQueryIntervalCheckTrigger"

@mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook")
def test_bigquery_interval_check_operator_with_project_id(
self, mock_hook, create_task_instance_of_operator
):
"""
Test BigQueryIntervalCheckOperator with a specified project_id.
Ensure that the bq_project_id is passed correctly when submitting the job.
"""
job_id = "123456"
hash_ = "hash"
real_job_id = f"{job_id}_{hash_}"

project_id = "test-project-id"
ti = create_task_instance_of_operator(
BigQueryIntervalCheckOperator,
dag_id="dag_id",
task_id="bq_interval_check_operator_with_project_id",
table="test_table",
metrics_thresholds={"COUNT(*)": 1.5},
location=TEST_DATASET_LOCATION,
deferrable=True,
project_id=project_id,
)

mock_hook.return_value.insert_job.return_value = MagicMock(job_id=real_job_id, error_result=False)

with pytest.raises(TaskDeferred):
ti.task.execute(MagicMock())

mock_hook.return_value.insert_job.assert_called_with(
configuration=mock.ANY,
project_id=project_id,
location=TEST_DATASET_LOCATION,
job_id=mock.ANY,
nowait=True,
)

@mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook")
def test_bigquery_interval_check_operator_without_project_id(
self, mock_hook, create_task_instance_of_operator
):
"""
Test BigQueryIntervalCheckOperator without a specified project_id.
Ensure that the project_id falls back to the hook.project_id as previously implemented.
"""
job_id = "123456"
hash_ = "hash"
real_job_id = f"{job_id}_{hash_}"

project_id = "test-project-id"
ti = create_task_instance_of_operator(
BigQueryIntervalCheckOperator,
dag_id="dag_id",
task_id="bq_interval_check_operator_without_project_id",
table="test_table",
metrics_thresholds={"COUNT(*)": 1.5},
location=TEST_DATASET_LOCATION,
deferrable=True,
)

mock_hook.return_value.project_id = project_id
mock_hook.return_value.insert_job.return_value = MagicMock(job_id=real_job_id, error_result=False)

with pytest.raises(TaskDeferred):
ti.task.execute(MagicMock())

mock_hook.return_value.insert_job.assert_called_with(
configuration=mock.ANY,
project_id=mock_hook.return_value.project_id,
location=TEST_DATASET_LOCATION,
job_id=mock.ANY,
nowait=True,
)


class TestBigQueryCheckOperator:
@mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryCheckOperator.execute")
Expand Down

0 comments on commit 6a03870

Please sign in to comment.