Skip to content

Commit

Permalink
Fix BigQueryGetDataOperator's query job bugs in deferrable mode (#3…
Browse files Browse the repository at this point in the history
  • Loading branch information
shahar1 committed May 22, 2023
1 parent 0e8bff9 commit 0d6e626
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 5 deletions.
14 changes: 9 additions & 5 deletions airflow/providers/google/cloud/operators/bigquery.py
Expand Up @@ -802,7 +802,7 @@ class BigQueryGetDataOperator(GoogleCloudBaseOperator):
:param dataset_id: The dataset ID of the requested table. (templated)
:param table_id: The table ID of the requested table. (templated)
:param project_id: (Optional) The name of the project where the data
will be returned from. (templated)
will be returned from. If None, it will be derived from the hook's project ID. (templated)
:param max_results: The maximum number of records (rows) to be fetched
from the table. (templated)
:param selected_fields: List of fields to return (comma-separated). If
Expand Down Expand Up @@ -872,7 +872,7 @@ def _submit_job(
hook: BigQueryHook,
job_id: str,
) -> BigQueryJob:
get_query = self.generate_query()
get_query = self.generate_query(hook=hook)
configuration = {"query": {"query": get_query, "useLegacySql": self.use_legacy_sql}}
"""Submit a new job and get the job id for polling the status using Triggerer."""
return hook.insert_job(
Expand All @@ -883,17 +883,21 @@ def _submit_job(
nowait=True,
)

def generate_query(self) -> str:
def generate_query(self, hook: BigQueryHook) -> str:
"""
Generate a select query if selected fields are given or with *
for the given dataset and table id
:param hook BigQuery Hook
"""
query = "select "
if self.selected_fields:
query += self.selected_fields
else:
query += "*"
query += f" from `{self.project_id}.{self.dataset_id}.{self.table_id}` limit {self.max_results}"
query += (
f" from `{self.project_id or hook.project_id}.{self.dataset_id}"
f".{self.table_id}` limit {self.max_results}"
)
return query

def execute(self, context: Context):
Expand All @@ -906,7 +910,7 @@ def execute(self, context: Context):
if not self.deferrable:
self.log.info(
"Fetching Data from %s.%s.%s max results: %s",
self.project_id,
self.project_id or hook.project_id,
self.dataset_id,
self.table_id,
self.max_results,
Expand Down
1 change: 1 addition & 0 deletions airflow/providers/google/cloud/triggers/bigquery.py
Expand Up @@ -187,6 +187,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
"project_id": self.project_id,
"table_id": self.table_id,
"poll_interval": self.poll_interval,
"as_dict": self.as_dict,
},
)

Expand Down
32 changes: 32 additions & 0 deletions tests/providers/google/cloud/operators/test_bigquery.py
Expand Up @@ -814,6 +814,38 @@ def test_execute(self, mock_hook, as_dict):
location=TEST_DATASET_LOCATION,
)

@mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook")
def test_generate_query__with_project_id(self, mock_hook):
operator = BigQueryGetDataOperator(
gcp_conn_id=GCP_CONN_ID,
task_id=TASK_ID,
dataset_id=TEST_DATASET,
table_id=TEST_TABLE_ID,
project_id=TEST_GCP_PROJECT_ID,
max_results=100,
use_legacy_sql=False,
)
assert (
operator.generate_query(hook=mock_hook) == f"select * from `{TEST_GCP_PROJECT_ID}."
f"{TEST_DATASET}.{TEST_TABLE_ID}` limit 100"
)

@mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook")
def test_generate_query__without_project_id(self, mock_hook):
hook_project_id = mock_hook.project_id
operator = BigQueryGetDataOperator(
gcp_conn_id=GCP_CONN_ID,
task_id=TASK_ID,
dataset_id=TEST_DATASET,
table_id=TEST_TABLE_ID,
max_results=100,
use_legacy_sql=False,
)
assert (
operator.generate_query(hook=mock_hook) == f"select * from `{hook_project_id}."
f"{TEST_DATASET}.{TEST_TABLE_ID}` limit 100"
)

@mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook")
def test_bigquery_get_data_operator_async_with_selected_fields(
self, mock_hook, create_task_instance_of_operator
Expand Down
1 change: 1 addition & 0 deletions tests/providers/google/cloud/triggers/test_bigquery.py
Expand Up @@ -224,6 +224,7 @@ def test_bigquery_get_data_trigger_serialization(self, get_data_trigger):
classpath, kwargs = get_data_trigger.serialize()
assert classpath == "airflow.providers.google.cloud.triggers.bigquery.BigQueryGetDataTrigger"
assert kwargs == {
"as_dict": False,
"conn_id": TEST_CONN_ID,
"job_id": TEST_JOB_ID,
"dataset_id": TEST_DATASET_ID,
Expand Down

0 comments on commit 0d6e626

Please sign in to comment.