diff --git a/airflow/providers/google/cloud/hooks/bigquery.py b/airflow/providers/google/cloud/hooks/bigquery.py
index 729d629c313b6..2b252f3c3f43d 100644
--- a/airflow/providers/google/cloud/hooks/bigquery.py
+++ b/airflow/providers/google/cloud/hooks/bigquery.py
@@ -20,6 +20,7 @@
 
 from __future__ import annotations
 
+import asyncio
 import json
 import logging
 import re
@@ -3242,16 +3243,58 @@ async def get_job_instance(
             session=cast(Session, session),
         )
 
-    async def get_job_status(self, job_id: str | None, project_id: str | None = None) -> dict[str, str]:
-        async with ClientSession() as s:
-            job_client = await self.get_job_instance(project_id, job_id, s)
-            job = await job_client.get_job()
-            status = job.get("status", {})
-            if status["state"] == "DONE":
-                if "errorResult" in status:
-                    return {"status": "error", "message": status["errorResult"]["message"]}
-                return {"status": "success", "message": "Job completed"}
-            return {"status": status["state"].lower(), "message": "Job running"}
+    async def _get_job(
+        self, job_id: str | None, project_id: str | None = None, location: str | None = None
+    ) -> CopyJob | QueryJob | LoadJob | ExtractJob | UnknownJob:
+        """
+        Get a BigQuery job by its ID, project ID and location.
+
+        WARNING:
+        This is a temporary workaround for the issues below, and it is not intended to be used elsewhere!
+        https://github.com/apache/airflow/issues/35833
+        https://github.com/talkiq/gcloud-aio/issues/584
+
+        This method exists because neither the `google-cloud-bigquery` nor the `gcloud-aio-bigquery`
+        library provides asynchronous access to BigQuery jobs with a location parameter. That is why this
+        method wraps the synchronous client call with the event loop's run_in_executor() method.
+
+        This workaround must be deleted along with the method _get_job_sync() and replaced by a more robust
+        and cleaner solution in one of two cases:
+        1. The `google-cloud-bigquery` library provides an async client with a get_job method that supports
+           the optional parameter `location`
+        2. The `gcloud-aio-bigquery` library supports the `location` parameter in its get_job() method.
+        """
+        loop = asyncio.get_event_loop()
+        job = await loop.run_in_executor(None, self._get_job_sync, job_id, project_id, location)
+        return job
+
+    def _get_job_sync(self, job_id, project_id, location):
+        """
+        Get a BigQuery job by its ID, project ID and location, synchronously.
+
+        WARNING:
+        This is a temporary workaround for the issues below, and it is not intended to be used elsewhere!
+        https://github.com/apache/airflow/issues/35833
+        https://github.com/talkiq/gcloud-aio/issues/584
+
+        This workaround must be deleted along with the method _get_job() and replaced by a more robust
+        and cleaner solution in one of two cases:
+        1. The `google-cloud-bigquery` library provides an async client with a get_job method that supports
+           the optional parameter `location`
+        2. The `gcloud-aio-bigquery` library supports the `location` parameter in its get_job() method.
+        """
+        hook = BigQueryHook(**self._hook_kwargs)
+        return hook.get_job(job_id=job_id, project_id=project_id, location=location)
+
+    async def get_job_status(
+        self, job_id: str | None, project_id: str | None = None, location: str | None = None
+    ) -> dict[str, str]:
+        job = await self._get_job(job_id=job_id, project_id=project_id, location=location)
+        if job.state == "DONE":
+            if job.error_result:
+                return {"status": "error", "message": job.error_result["message"]}
+            return {"status": "success", "message": "Job completed"}
+        return {"status": str(job.state).lower(), "message": "Job running"}
 
     async def get_job_output(
         self,
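The docstring above describes the heart of the workaround: a blocking `google-cloud-bigquery` call is handed to the event loop's default executor so the async hook never blocks the triggerer's loop. A minimal, self-contained sketch of that pattern follows; `fetch_job_sync` is a hypothetical stand-in for the blocking `BigQueryHook.get_job()` call, not part of this patch:

```python
import asyncio


def fetch_job_sync(job_id: str, project_id: str | None, location: str | None) -> dict:
    # Hypothetical stand-in for the blocking BigQueryHook.get_job() call.
    return {"job_id": job_id, "project_id": project_id, "location": location, "state": "DONE"}


async def fetch_job(job_id: str, project_id: str | None = None, location: str | None = None) -> dict:
    loop = asyncio.get_event_loop()
    # None selects the loop's default ThreadPoolExecutor; the callable's args follow positionally.
    return await loop.run_in_executor(None, fetch_job_sync, job_id, project_id, location)


if __name__ == "__main__":
    print(asyncio.run(fetch_job("job-123", "my-project", "europe-west3")))
```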
+ """ + hook = BigQueryHook(**self._hook_kwargs) + return hook.get_job(job_id=job_id, project_id=project_id, location=location) + + async def get_job_status( + self, job_id: str | None, project_id: str | None = None, location: str | None = None + ) -> dict[str, str]: + job = await self._get_job(job_id=job_id, project_id=project_id, location=location) + if job.state == "DONE": + if job.error_result: + return {"status": "error", "message": job.error_result["message"]} + return {"status": "success", "message": "Job completed"} + return {"status": str(job.state).lower(), "message": "Job running"} async def get_job_output( self, diff --git a/airflow/providers/google/cloud/operators/bigquery.py b/airflow/providers/google/cloud/operators/bigquery.py index fc90bbeed92ba..b391a1508c764 100644 --- a/airflow/providers/google/cloud/operators/bigquery.py +++ b/airflow/providers/google/cloud/operators/bigquery.py @@ -313,6 +313,7 @@ def execute(self, context: Context): conn_id=self.gcp_conn_id, job_id=job.job_id, project_id=hook.project_id, + location=self.location or hook.location, poll_interval=self.poll_interval, impersonation_chain=self.impersonation_chain, ), @@ -438,6 +439,7 @@ def execute(self, context: Context) -> None: # type: ignore[override] conn_id=self.gcp_conn_id, job_id=job.job_id, project_id=hook.project_id, + location=self.location or hook.location, sql=self.sql, pass_value=self.pass_value, tolerance=self.tol, @@ -594,6 +596,7 @@ def execute(self, context: Context): second_job_id=job_2.job_id, project_id=hook.project_id, table=self.table, + location=self.location or hook.location, metrics_thresholds=self.metrics_thresholds, date_filter_column=self.date_filter_column, days_back=self.days_back, @@ -1068,6 +1071,7 @@ def execute(self, context: Context): dataset_id=self.dataset_id, table_id=self.table_id, project_id=self.job_project_id or hook.project_id, + location=self.location or hook.location, poll_interval=self.poll_interval, as_dict=self.as_dict, impersonation_chain=self.impersonation_chain, @@ -2876,6 +2880,7 @@ def execute(self, context: Any): conn_id=self.gcp_conn_id, job_id=self.job_id, project_id=self.project_id, + location=self.location or hook.location, poll_interval=self.poll_interval, impersonation_chain=self.impersonation_chain, ), diff --git a/airflow/providers/google/cloud/transfers/bigquery_to_gcs.py b/airflow/providers/google/cloud/transfers/bigquery_to_gcs.py index 3ede4db32ff81..aeb7f46f6e9b0 100644 --- a/airflow/providers/google/cloud/transfers/bigquery_to_gcs.py +++ b/airflow/providers/google/cloud/transfers/bigquery_to_gcs.py @@ -261,6 +261,7 @@ def execute(self, context: Context): conn_id=self.gcp_conn_id, job_id=self._job_id, project_id=self.project_id or self.hook.project_id, + location=self.location or self.hook.location, impersonation_chain=self.impersonation_chain, ), method_name="execute_complete", diff --git a/airflow/providers/google/cloud/transfers/gcs_to_bigquery.py b/airflow/providers/google/cloud/transfers/gcs_to_bigquery.py index 9d8ce53f4c1b2..4e1d6e091975f 100644 --- a/airflow/providers/google/cloud/transfers/gcs_to_bigquery.py +++ b/airflow/providers/google/cloud/transfers/gcs_to_bigquery.py @@ -435,6 +435,7 @@ def execute(self, context: Context): conn_id=self.gcp_conn_id, job_id=self.job_id, project_id=self.project_id or self.hook.project_id, + location=self.location or self.hook.location, impersonation_chain=self.impersonation_chain, ), method_name="execute_complete", diff --git a/airflow/providers/google/cloud/triggers/bigquery.py 
diff --git a/airflow/providers/google/cloud/triggers/bigquery.py b/airflow/providers/google/cloud/triggers/bigquery.py
index 3d73e91e989e8..bc9e812d1b28c 100644
--- a/airflow/providers/google/cloud/triggers/bigquery.py
+++ b/airflow/providers/google/cloud/triggers/bigquery.py
@@ -33,6 +33,7 @@ class BigQueryInsertJobTrigger(BaseTrigger):
     :param conn_id: Reference to google cloud connection id
     :param job_id: The ID of the job. It will be suffixed with hash of job configuration
     :param project_id: Google Cloud Project where the job is running
+    :param location: The dataset location.
     :param dataset_id: The dataset ID of the requested table. (templated)
     :param table_id: The table ID of the requested table. (templated)
     :param poll_interval: polling period in seconds to check for the status. (templated)
@@ -51,6 +52,7 @@ def __init__(
         conn_id: str,
         job_id: str | None,
         project_id: str | None,
+        location: str | None,
         dataset_id: str | None = None,
         table_id: str | None = None,
         poll_interval: float = 4.0,
@@ -63,6 +65,7 @@ def __init__(
         self._job_conn = None
         self.dataset_id = dataset_id
         self.project_id = project_id
+        self.location = location
         self.table_id = table_id
         self.poll_interval = poll_interval
         self.impersonation_chain = impersonation_chain
@@ -76,6 +79,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
                 "job_id": self.job_id,
                 "dataset_id": self.dataset_id,
                 "project_id": self.project_id,
+                "location": self.location,
                 "table_id": self.table_id,
                 "poll_interval": self.poll_interval,
                 "impersonation_chain": self.impersonation_chain,
@@ -87,7 +91,9 @@ async def run(self) -> AsyncIterator[TriggerEvent]:  # type: ignore[override]
         hook = self._get_async_hook()
         try:
             while True:
-                job_status = await hook.get_job_status(job_id=self.job_id, project_id=self.project_id)
+                job_status = await hook.get_job_status(
+                    job_id=self.job_id, project_id=self.project_id, location=self.location
+                )
                 if job_status["status"] == "success":
                     yield TriggerEvent(
                         {
@@ -127,6 +133,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
                 "job_id": self.job_id,
                 "dataset_id": self.dataset_id,
                 "project_id": self.project_id,
+                "location": self.location,
                 "table_id": self.table_id,
                 "poll_interval": self.poll_interval,
                 "impersonation_chain": self.impersonation_chain,
@@ -201,6 +208,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
                 "job_id": self.job_id,
                 "dataset_id": self.dataset_id,
                 "project_id": self.project_id,
+                "location": self.location,
                 "table_id": self.table_id,
                 "poll_interval": self.poll_interval,
                 "impersonation_chain": self.impersonation_chain,
@@ -253,6 +261,7 @@ class BigQueryIntervalCheckTrigger(BigQueryInsertJobTrigger):
     :param dataset_id: The dataset ID of the requested table. (templated)
     :param table: table name
     :param metrics_thresholds: dictionary of ratios indexed by metrics
+    :param location: The dataset location.
     :param date_filter_column: column name. (templated)
     :param days_back: number of days between ds and the ds we want to check against. (templated)
     :param ratio_formula: ration formula. (templated)
@@ -277,6 +286,7 @@ def __init__(
         project_id: str | None,
         table: str,
         metrics_thresholds: dict[str, int],
+        location: str | None = None,
         date_filter_column: str | None = "ds",
         days_back: SupportsAbs[int] = -7,
         ratio_formula: str = "max_over_min",
@@ -290,6 +300,7 @@ def __init__(
             conn_id=conn_id,
             job_id=first_job_id,
             project_id=project_id,
+            location=location,
             dataset_id=dataset_id,
             table_id=table_id,
             poll_interval=poll_interval,
@@ -317,6 +328,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
                 "project_id": self.project_id,
                 "table": self.table,
                 "metrics_thresholds": self.metrics_thresholds,
+                "location": self.location,
                 "date_filter_column": self.date_filter_column,
                 "days_back": self.days_back,
                 "ratio_formula": self.ratio_formula,
@@ -414,6 +426,7 @@ class BigQueryValueCheckTrigger(BigQueryInsertJobTrigger):
     :param tolerance: certain metrics for tolerance. (templated)
     :param dataset_id: The dataset ID of the requested table. (templated)
     :param table_id: The table ID of the requested table. (templated)
+    :param location: The dataset location.
     :param poll_interval: polling period in seconds to check for the status. (templated)
     :param impersonation_chain: Optional service account to impersonate using short-term
         credentials, or chained list of accounts required to get the access_token
@@ -435,6 +448,7 @@ def __init__(
         tolerance: Any = None,
         dataset_id: str | None = None,
         table_id: str | None = None,
+        location: str | None = None,
         poll_interval: float = 4.0,
         impersonation_chain: str | Sequence[str] | None = None,
     ):
@@ -444,6 +458,7 @@ def __init__(
             project_id=project_id,
             dataset_id=dataset_id,
             table_id=table_id,
+            location=location,
             poll_interval=poll_interval,
             impersonation_chain=impersonation_chain,
         )
@@ -464,6 +479,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
                 "sql": self.sql,
                 "table_id": self.table_id,
                 "tolerance": self.tolerance,
+                "location": self.location,
                 "poll_interval": self.poll_interval,
                 "impersonation_chain": self.impersonation_chain,
             },
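The new `location` field has to appear in every `serialize()` dict above because the triggerer rebuilds each trigger purely from those kwargs; any attribute omitted there is silently lost after deferral. A small round-trip sketch with assumed values:

```python
from airflow.providers.google.cloud.triggers.bigquery import BigQueryInsertJobTrigger

trigger = BigQueryInsertJobTrigger(
    conn_id="google_cloud_default",
    job_id="job-123",
    project_id="my-project",
    location="europe-west3",
)
classpath, kwargs = trigger.serialize()
# The triggerer effectively does this with the stored classpath and kwargs:
rehydrated = BigQueryInsertJobTrigger(**kwargs)
assert rehydrated.location == "europe-west3"
```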
diff --git a/tests/providers/google/cloud/hooks/test_bigquery.py b/tests/providers/google/cloud/hooks/test_bigquery.py
index 27db487b61d2c..5ca34b276f21c 100644
--- a/tests/providers/google/cloud/hooks/test_bigquery.py
+++ b/tests/providers/google/cloud/hooks/test_bigquery.py
@@ -2155,23 +2155,18 @@ async def test_get_job_instance(self, mock_session, mock_auth_default):
         assert isinstance(result, Job)
 
     @pytest.mark.parametrize(
-        "job_status, expected",
+        "job_state, error_result, expected",
         [
-            ({"status": {"state": "DONE"}}, {"status": "success", "message": "Job completed"}),
-            (
-                {"status": {"state": "DONE", "errorResult": {"message": "Timeout"}}},
-                {"status": "error", "message": "Timeout"},
-            ),
-            ({"status": {"state": "running"}}, {"status": "running", "message": "Job running"}),
+            ("DONE", None, {"status": "success", "message": "Job completed"}),
+            ("DONE", {"message": "Timeout"}, {"status": "error", "message": "Timeout"}),
+            ("RUNNING", None, {"status": "running", "message": "Job running"}),
         ],
     )
     @pytest.mark.asyncio
-    @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_instance")
-    async def test_get_job_status(self, mock_job_instance, job_status, expected):
+    @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook._get_job")
+    async def test_get_job_status(self, mock_get_job, job_state, error_result, expected):
         hook = BigQueryAsyncHook()
-        mock_job_client = AsyncMock(Job)
-        mock_job_instance.return_value = mock_job_client
-        mock_job_instance.return_value.get_job.return_value = job_status
+        mock_get_job.return_value = mock.MagicMock(state=job_state, error_result=error_result)
 
         resp = await hook.get_job_status(job_id=JOB_ID, project_id=PROJECT_ID)
         assert resp == expected
diff --git a/tests/providers/google/cloud/triggers/test_bigquery.py b/tests/providers/google/cloud/triggers/test_bigquery.py
index ea2e478d04cfb..ed4861ca76e52 100644
--- a/tests/providers/google/cloud/triggers/test_bigquery.py
+++ b/tests/providers/google/cloud/triggers/test_bigquery.py
@@ -24,7 +24,7 @@
 
 import pytest
 from aiohttp import ClientResponseError, RequestInfo
-from gcloud.aio.bigquery import Job, Table
+from gcloud.aio.bigquery import Table
 from multidict import CIMultiDict
 from yarl import URL
@@ -48,6 +48,7 @@
 TEST_GCP_PROJECT_ID = "test-project"
 TEST_DATASET_ID = "bq_dataset"
 TEST_TABLE_ID = "bq_table"
+TEST_LOCATION = "US"
 POLLING_PERIOD_SECONDS = 4.0
 TEST_SQL_QUERY = "SELECT count(*) from Any"
 TEST_PASS_VALUE = 2
@@ -73,6 +74,7 @@ def insert_job_trigger():
         project_id=TEST_GCP_PROJECT_ID,
         dataset_id=TEST_DATASET_ID,
         table_id=TEST_TABLE_ID,
+        location=TEST_LOCATION,
         poll_interval=POLLING_PERIOD_SECONDS,
         impersonation_chain=TEST_IMPERSONATION_CHAIN,
     )
@@ -86,6 +88,7 @@ def get_data_trigger():
         project_id=TEST_GCP_PROJECT_ID,
         dataset_id=TEST_DATASET_ID,
         table_id=TEST_TABLE_ID,
+        location=None,
         poll_interval=POLLING_PERIOD_SECONDS,
         impersonation_chain=TEST_IMPERSONATION_CHAIN,
     )
@@ -132,6 +135,7 @@ def check_trigger():
         project_id=TEST_GCP_PROJECT_ID,
         dataset_id=TEST_DATASET_ID,
         table_id=TEST_TABLE_ID,
+        location=None,
         poll_interval=POLLING_PERIOD_SECONDS,
         impersonation_chain=TEST_IMPERSONATION_CHAIN,
     )
@@ -166,6 +170,7 @@ def test_serialization(self, insert_job_trigger):
             "project_id": TEST_GCP_PROJECT_ID,
             "dataset_id": TEST_DATASET_ID,
             "table_id": TEST_TABLE_ID,
+            "location": TEST_LOCATION,
             "poll_interval": POLLING_PERIOD_SECONDS,
             "impersonation_chain": TEST_IMPERSONATION_CHAIN,
         }
@@ -185,13 +190,11 @@ async def test_bigquery_insert_job_op_trigger_success(self, mock_job_status, ins
         )
 
     @pytest.mark.asyncio
-    @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_instance")
-    async def test_bigquery_insert_job_trigger_running(self, mock_job_instance, caplog, insert_job_trigger):
+    @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook._get_job")
+    async def test_bigquery_insert_job_trigger_running(self, mock_get_job, caplog, insert_job_trigger):
         """Test that BigQuery Triggers do not fire while a query is still running."""
-        mock_job_client = AsyncMock(Job)
-        mock_job_instance.return_value = mock_job_client
-        mock_job_instance.return_value.get_job.return_value = {"status": {"state": "running"}}
+        mock_get_job.return_value = mock.MagicMock(state="RUNNING")
         caplog.set_level(logging.INFO)
 
         task = asyncio.create_task(insert_job_trigger.run().__anext__())
@@ -245,17 +248,16 @@ def test_bigquery_get_data_trigger_serialization(self, get_data_trigger):
             "dataset_id": TEST_DATASET_ID,
             "project_id": TEST_GCP_PROJECT_ID,
             "table_id": TEST_TABLE_ID,
+            "location": None,
             "poll_interval": POLLING_PERIOD_SECONDS,
         }
 
     @pytest.mark.asyncio
-    @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_instance")
-    async def test_bigquery_get_data_trigger_running(self, mock_job_instance, caplog, get_data_trigger):
+    @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook._get_job")
+    async def test_bigquery_get_data_trigger_running(self, mock_get_job, caplog, get_data_trigger):
         """Test that BigQuery Triggers do not fire while a query is still running."""
-        mock_job_client = AsyncMock(Job)
-        mock_job_instance.return_value = mock_job_client
-        mock_job_instance.return_value.get_job.return_value = {"status": {"state": "running"}}
+        mock_get_job.return_value = mock.MagicMock(state="running")
         caplog.set_level(logging.INFO)
 
         task = asyncio.create_task(get_data_trigger.run().__anext__())
@@ -348,13 +350,11 @@ async def test_bigquery_get_data_trigger_success_with_data(
 
 class TestBigQueryCheckTrigger:
     @pytest.mark.asyncio
-    @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_instance")
-    async def test_bigquery_check_trigger_running(self, mock_job_instance, caplog, check_trigger):
+    @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook._get_job")
+    async def test_bigquery_check_trigger_running(self, mock_get_job, caplog, check_trigger):
         """Test that BigQuery Triggers do not fire while a query is still running."""
-        mock_job_client = AsyncMock(Job)
-        mock_job_instance.return_value = mock_job_client
-        mock_job_instance.return_value.get_job.return_value = {"status": {"state": "running"}}
+        mock_get_job.return_value = mock.MagicMock(state="running")
 
         task = asyncio.create_task(check_trigger.run().__anext__())
         await asyncio.sleep(0.5)
@@ -406,6 +406,7 @@ def test_check_trigger_serialization(self, check_trigger):
             "dataset_id": TEST_DATASET_ID,
             "project_id": TEST_GCP_PROJECT_ID,
             "table_id": TEST_TABLE_ID,
+            "location": None,
             "poll_interval": POLLING_PERIOD_SECONDS,
         }
@@ -487,6 +488,7 @@ def test_interval_check_trigger_serialization(self, interval_check_trigger):
             "second_job_id": TEST_SECOND_JOB_ID,
             "project_id": TEST_GCP_PROJECT_ID,
             "table": TEST_TABLE_ID,
+            "location": None,
             "metrics_thresholds": TEST_METRIC_THRESHOLDS,
             "date_filter_column": TEST_DATE_FILTER_COLUMN,
             "days_back": TEST_DAYS_BACK,
@@ -578,6 +580,7 @@ def test_bigquery_value_check_op_trigger_serialization(self, value_check_trigger
             "job_id": TEST_JOB_ID,
             "dataset_id": TEST_DATASET_ID,
             "project_id": TEST_GCP_PROJECT_ID,
+            "location": None,
             "sql": TEST_SQL_QUERY,
             "table_id": TEST_TABLE_ID,
             "tolerance": TEST_TOLERANCE,