Add optional 'location' parameter to the BigQueryInsertJobTrigger (#3…
moiseenkov committed Feb 12, 2024
1 parent 028fbdf commit d43c804
Showing 7 changed files with 103 additions and 39 deletions.
63 changes: 53 additions & 10 deletions airflow/providers/google/cloud/hooks/bigquery.py
@@ -20,6 +20,7 @@

from __future__ import annotations

import asyncio
import json
import logging
import re
@@ -3242,16 +3243,58 @@ async def get_job_instance(
session=cast(Session, session),
)

async def get_job_status(self, job_id: str | None, project_id: str | None = None) -> dict[str, str]:
async with ClientSession() as s:
job_client = await self.get_job_instance(project_id, job_id, s)
job = await job_client.get_job()
status = job.get("status", {})
if status["state"] == "DONE":
if "errorResult" in status:
return {"status": "error", "message": status["errorResult"]["message"]}
return {"status": "success", "message": "Job completed"}
return {"status": status["state"].lower(), "message": "Job running"}
async def _get_job(
self, job_id: str | None, project_id: str | None = None, location: str | None = None
) -> CopyJob | QueryJob | LoadJob | ExtractJob | UnknownJob:
"""
Get BigQuery job by its ID, project ID and location.
WARNING:
This is a temporary workaround for the issues below and is not intended to be used elsewhere!
https://github.com/apache/airflow/issues/35833
https://github.com/talkiq/gcloud-aio/issues/584
This method exists because neither `google-cloud-bigquery` nor `gcloud-aio-bigquery`
provides asynchronous access to BigQuery jobs with a location parameter. That's why this method
wraps a synchronous client call in the event loop's run_in_executor() method.
This workaround must be deleted along with the method _get_job_sync() and replaced by a more robust
and cleaner solution in one of two cases:
1. The `google-cloud-bigquery` library provides an async client whose get_job method supports
the optional `location` parameter.
2. The `gcloud-aio-bigquery` library supports the `location` parameter in its get_job() method.
"""
loop = asyncio.get_event_loop()
job = await loop.run_in_executor(None, self._get_job_sync, job_id, project_id, location)
return job

def _get_job_sync(self, job_id, project_id, location):
"""
Get BigQuery job by its ID, project ID and location synchronously.
WARNING:
This is a temporary workaround for the issues below and is not intended to be used elsewhere!
https://github.com/apache/airflow/issues/35833
https://github.com/talkiq/gcloud-aio/issues/584
This workaround must be deleted along with the method _get_job() and replaced by a more robust
and cleaner solution in one of two cases:
1. The `google-cloud-bigquery` library provides an async client whose get_job method supports
the optional `location` parameter.
2. The `gcloud-aio-bigquery` library supports the `location` parameter in its get_job() method.
"""
hook = BigQueryHook(**self._hook_kwargs)
return hook.get_job(job_id=job_id, project_id=project_id, location=location)

async def get_job_status(
self, job_id: str | None, project_id: str | None = None, location: str | None = None
) -> dict[str, str]:
job = await self._get_job(job_id=job_id, project_id=project_id, location=location)
if job.state == "DONE":
if job.error_result:
return {"status": "error", "message": job.error_result["message"]}
return {"status": "success", "message": "Job completed"}
return {"status": str(job.state).lower(), "message": "Job running"}

async def get_job_output(
self,
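
For readers following the workaround described in the docstrings above, here is a minimal, self-contained sketch of the pattern: a blocking call is pushed onto the default thread-pool executor so the async caller never blocks the event loop. Names here are illustrative; `fetch_job_sync` merely stands in for a blocking client call such as `google.cloud.bigquery.Client.get_job(job_id, project=..., location=...)`.

from __future__ import annotations

import asyncio

def fetch_job_sync(job_id: str, project_id: str | None, location: str | None) -> dict:
    # Stand-in for a blocking BigQuery client call; returns a fake job payload.
    return {"job_id": job_id, "project_id": project_id, "location": location, "state": "DONE"}

async def fetch_job(job_id: str, project_id: str | None = None, location: str | None = None) -> dict:
    # Offload the blocking call to the default thread-pool executor so the
    # event loop keeps servicing other coroutines (e.g. other triggers).
    loop = asyncio.get_event_loop()
    return await loop.run_in_executor(None, fetch_job_sync, job_id, project_id, location)

print(asyncio.run(fetch_job("job_abc", "my-project", "europe-west1")))
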
5 changes: 5 additions & 0 deletions airflow/providers/google/cloud/operators/bigquery.py
@@ -313,6 +313,7 @@ def execute(self, context: Context):
conn_id=self.gcp_conn_id,
job_id=job.job_id,
project_id=hook.project_id,
location=self.location or hook.location,
poll_interval=self.poll_interval,
impersonation_chain=self.impersonation_chain,
),
@@ -438,6 +439,7 @@ def execute(self, context: Context) -> None:  # type: ignore[override]
conn_id=self.gcp_conn_id,
job_id=job.job_id,
project_id=hook.project_id,
location=self.location or hook.location,
sql=self.sql,
pass_value=self.pass_value,
tolerance=self.tol,
@@ -594,6 +596,7 @@ def execute(self, context: Context):
second_job_id=job_2.job_id,
project_id=hook.project_id,
table=self.table,
location=self.location or hook.location,
metrics_thresholds=self.metrics_thresholds,
date_filter_column=self.date_filter_column,
days_back=self.days_back,
@@ -1068,6 +1071,7 @@ def execute(self, context: Context):
dataset_id=self.dataset_id,
table_id=self.table_id,
project_id=self.job_project_id or hook.project_id,
location=self.location or hook.location,
poll_interval=self.poll_interval,
as_dict=self.as_dict,
impersonation_chain=self.impersonation_chain,
@@ -2876,6 +2880,7 @@ def execute(self, context: Any):
conn_id=self.gcp_conn_id,
job_id=self.job_id,
project_id=self.project_id,
location=self.location or hook.location,
poll_interval=self.poll_interval,
impersonation_chain=self.impersonation_chain,
),
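
Each operator hunk above resolves the region the same way before deferring: an explicit operator-level `location` wins, otherwise the hook's (connection-level) location is used. A toy illustration of that fallback idiom, with made-up class and attribute names:

from __future__ import annotations

class _LocationFallbackDemo:
    """Illustrative only; mirrors the `self.location or hook.location` idiom."""

    def __init__(self, operator_location: str | None, hook_location: str | None):
        self.operator_location = operator_location
        self.hook_location = hook_location

    def resolve(self) -> str | None:
        # An explicit location on the operator takes precedence; otherwise
        # fall back to whatever the hook/connection is configured with.
        return self.operator_location or self.hook_location

assert _LocationFallbackDemo(None, "EU").resolve() == "EU"
assert _LocationFallbackDemo("us-east1", "EU").resolve() == "us-east1"
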
@@ -261,6 +261,7 @@ def execute(self, context: Context):
conn_id=self.gcp_conn_id,
job_id=self._job_id,
project_id=self.project_id or self.hook.project_id,
location=self.location or self.hook.location,
impersonation_chain=self.impersonation_chain,
),
method_name="execute_complete",
@@ -435,6 +435,7 @@ def execute(self, context: Context):
conn_id=self.gcp_conn_id,
job_id=self.job_id,
project_id=self.project_id or self.hook.project_id,
location=self.location or self.hook.location,
impersonation_chain=self.impersonation_chain,
),
method_name="execute_complete",
18 changes: 17 additions & 1 deletion airflow/providers/google/cloud/triggers/bigquery.py
@@ -33,6 +33,7 @@ class BigQueryInsertJobTrigger(BaseTrigger):
:param conn_id: Reference to google cloud connection id
:param job_id: The ID of the job. It will be suffixed with hash of job configuration
:param project_id: Google Cloud Project where the job is running
:param location: The dataset location.
:param dataset_id: The dataset ID of the requested table. (templated)
:param table_id: The table ID of the requested table. (templated)
:param poll_interval: polling period in seconds to check for the status. (templated)
@@ -51,6 +52,7 @@ def __init__(
conn_id: str,
job_id: str | None,
project_id: str | None,
location: str | None,
dataset_id: str | None = None,
table_id: str | None = None,
poll_interval: float = 4.0,
@@ -63,6 +65,7 @@
self._job_conn = None
self.dataset_id = dataset_id
self.project_id = project_id
self.location = location
self.table_id = table_id
self.poll_interval = poll_interval
self.impersonation_chain = impersonation_chain
@@ -76,6 +79,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
"job_id": self.job_id,
"dataset_id": self.dataset_id,
"project_id": self.project_id,
"location": self.location,
"table_id": self.table_id,
"poll_interval": self.poll_interval,
"impersonation_chain": self.impersonation_chain,
@@ -87,7 +91,9 @@ async def run(self) -> AsyncIterator[TriggerEvent]:  # type: ignore[override]
hook = self._get_async_hook()
try:
while True:
job_status = await hook.get_job_status(job_id=self.job_id, project_id=self.project_id)
job_status = await hook.get_job_status(
job_id=self.job_id, project_id=self.project_id, location=self.location
)
if job_status["status"] == "success":
yield TriggerEvent(
{
@@ -127,6 +133,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
"job_id": self.job_id,
"dataset_id": self.dataset_id,
"project_id": self.project_id,
"location": self.location,
"table_id": self.table_id,
"poll_interval": self.poll_interval,
"impersonation_chain": self.impersonation_chain,
@@ -201,6 +208,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
"job_id": self.job_id,
"dataset_id": self.dataset_id,
"project_id": self.project_id,
"location": self.location,
"table_id": self.table_id,
"poll_interval": self.poll_interval,
"impersonation_chain": self.impersonation_chain,
@@ -253,6 +261,7 @@ class BigQueryIntervalCheckTrigger(BigQueryInsertJobTrigger):
:param dataset_id: The dataset ID of the requested table. (templated)
:param table: table name
:param metrics_thresholds: dictionary of ratios indexed by metrics
:param location: The dataset location.
:param date_filter_column: column name. (templated)
:param days_back: number of days between ds and the ds we want to check against. (templated)
:param ratio_formula: ratio formula. (templated)
@@ -277,6 +286,7 @@ def __init__(
project_id: str | None,
table: str,
metrics_thresholds: dict[str, int],
location: str | None = None,
date_filter_column: str | None = "ds",
days_back: SupportsAbs[int] = -7,
ratio_formula: str = "max_over_min",
@@ -290,6 +300,7 @@
conn_id=conn_id,
job_id=first_job_id,
project_id=project_id,
location=location,
dataset_id=dataset_id,
table_id=table_id,
poll_interval=poll_interval,
@@ -317,6 +328,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
"project_id": self.project_id,
"table": self.table,
"metrics_thresholds": self.metrics_thresholds,
"location": self.location,
"date_filter_column": self.date_filter_column,
"days_back": self.days_back,
"ratio_formula": self.ratio_formula,
@@ -414,6 +426,7 @@ class BigQueryValueCheckTrigger(BigQueryInsertJobTrigger):
:param tolerance: certain metrics for tolerance. (templated)
:param dataset_id: The dataset ID of the requested table. (templated)
:param table_id: The table ID of the requested table. (templated)
:param location: The dataset location.
:param poll_interval: polling period in seconds to check for the status. (templated)
:param impersonation_chain: Optional service account to impersonate using short-term
credentials, or chained list of accounts required to get the access_token
@@ -435,6 +448,7 @@ def __init__(
tolerance: Any = None,
dataset_id: str | None = None,
table_id: str | None = None,
location: str | None = None,
poll_interval: float = 4.0,
impersonation_chain: str | Sequence[str] | None = None,
):
@@ -444,6 +458,7 @@
project_id=project_id,
dataset_id=dataset_id,
table_id=table_id,
location=location,
poll_interval=poll_interval,
impersonation_chain=impersonation_chain,
)
@@ -464,6 +479,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
"sql": self.sql,
"table_id": self.table_id,
"tolerance": self.tolerance,
"location": self.location,
"poll_interval": self.poll_interval,
"impersonation_chain": self.impersonation_chain,
},
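
Since triggers are serialized into the database and re-instantiated by the triggerer process, the new field only helps if it survives a round trip through serialize(). A quick sanity check, assuming the provider at this commit is installed (connection ID and job/project values below are placeholders):

from airflow.providers.google.cloud.triggers.bigquery import BigQueryInsertJobTrigger

trigger = BigQueryInsertJobTrigger(
    conn_id="google_cloud_default",
    job_id="job_abc",
    project_id="my-project",
    location="europe-west1",
)
classpath, kwargs = trigger.serialize()
# The triggerer rebuilds the trigger from these kwargs, so `location`
# must be present for the right region to be polled after deferral.
assert kwargs["location"] == "europe-west1"
assert BigQueryInsertJobTrigger(**kwargs).location == "europe-west1"
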
19 changes: 7 additions & 12 deletions tests/providers/google/cloud/hooks/test_bigquery.py
@@ -2155,23 +2155,18 @@ async def test_get_job_instance(self, mock_session, mock_auth_default):
assert isinstance(result, Job)

@pytest.mark.parametrize(
"job_status, expected",
"job_state, error_result, expected",
[
({"status": {"state": "DONE"}}, {"status": "success", "message": "Job completed"}),
(
{"status": {"state": "DONE", "errorResult": {"message": "Timeout"}}},
{"status": "error", "message": "Timeout"},
),
({"status": {"state": "running"}}, {"status": "running", "message": "Job running"}),
("DONE", None, {"status": "success", "message": "Job completed"}),
("DONE", {"message": "Timeout"}, {"status": "error", "message": "Timeout"}),
("RUNNING", None, {"status": "running", "message": "Job running"}),
],
)
@pytest.mark.asyncio
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_instance")
async def test_get_job_status(self, mock_job_instance, job_status, expected):
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook._get_job")
async def test_get_job_status(self, mock_get_job, job_state, error_result, expected):
hook = BigQueryAsyncHook()
mock_job_client = AsyncMock(Job)
mock_job_instance.return_value = mock_job_client
mock_job_instance.return_value.get_job.return_value = job_status
mock_get_job.return_value = mock.MagicMock(state=job_state, error_result=error_result)
resp = await hook.get_job_status(job_id=JOB_ID, project_id=PROJECT_ID)
assert resp == expected

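
The updated test no longer fakes the aio client's JSON payload; it stubs `_get_job` with an object exposing `state` and `error_result`, matching the attributes of `google-cloud-bigquery` job classes. The status mapping under test can be reproduced standalone — a sketch of the same logic, not the provider's code:

from __future__ import annotations

from dataclasses import dataclass

@dataclass
class FakeJob:
    state: str
    error_result: dict | None = None

def map_status(job: FakeJob) -> dict[str, str]:
    # Same mapping as BigQueryAsyncHook.get_job_status in the hunk above.
    if job.state == "DONE":
        if job.error_result:
            return {"status": "error", "message": job.error_result["message"]}
        return {"status": "success", "message": "Job completed"}
    return {"status": job.state.lower(), "message": "Job running"}

assert map_status(FakeJob("DONE")) == {"status": "success", "message": "Job completed"}
assert map_status(FakeJob("DONE", {"message": "Timeout"})) == {"status": "error", "message": "Timeout"}
assert map_status(FakeJob("RUNNING")) == {"status": "running", "message": "Job running"}
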
