Skip to content

Commit

Permalink
Add use_legacy_sql param to BigQueryGetDataOperator (#31190)
Browse files Browse the repository at this point in the history
  • Loading branch information
shahar1 committed May 12, 2023
1 parent 2453231 commit d1fe671
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 15 deletions.
14 changes: 11 additions & 3 deletions airflow/providers/google/cloud/operators/bigquery.py
Expand Up @@ -819,6 +819,7 @@ class BigQueryGetDataOperator(GoogleCloudBaseOperator):
Defaults to 4 seconds.
:param as_dict: if True returns the result as a list of dictionaries, otherwise as list of lists
(default: False).
:param use_legacy_sql: Whether to use legacy SQL (true) or standard SQL (false).
"""

template_fields: Sequence[str] = (
Expand All @@ -845,6 +846,7 @@ def __init__(
deferrable: bool = False,
poll_interval: float = 4.0,
as_dict: bool = False,
use_legacy_sql: bool = True,
**kwargs,
) -> None:
super().__init__(**kwargs)
Expand All @@ -860,14 +862,15 @@ def __init__(
self.deferrable = deferrable
self.poll_interval = poll_interval
self.as_dict = as_dict
self.use_legacy_sql = use_legacy_sql

def _submit_job(
self,
hook: BigQueryHook,
job_id: str,
) -> BigQueryJob:
get_query = self.generate_query()
configuration = {"query": {"query": get_query}}
configuration = {"query": {"query": get_query, "useLegacySql": self.use_legacy_sql}}
"""Submit a new job and get the job id for polling the status using Triggerer."""
return hook.insert_job(
configuration=configuration,
Expand All @@ -887,18 +890,23 @@ def generate_query(self) -> str:
query += self.selected_fields
else:
query += "*"
query += f" from {self.dataset_id}.{self.table_id} limit {self.max_results}"
query += f" from `{self.project_id}.{self.dataset_id}.{self.table_id}` limit {self.max_results}"
return query

def execute(self, context: Context):
hook = BigQueryHook(
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
use_legacy_sql=self.use_legacy_sql,
)

if not self.deferrable:
self.log.info(
"Fetching Data from %s.%s max results: %s", self.dataset_id, self.table_id, self.max_results
"Fetching Data from %s.%s.%s max results: %s",
self.project_id,
self.dataset_id,
self.table_id,
self.max_results,
)
if not self.selected_fields:
schema: dict[str, list] = hook.get_schema(
Expand Down
20 changes: 8 additions & 12 deletions tests/providers/google/cloud/operators/test_bigquery.py
Expand Up @@ -82,6 +82,7 @@
"refreshIntervalMs": 2000000,
}
TEST_TABLE = "test-table"
GCP_CONN_ID = "google_cloud_default"


class TestBigQueryCreateEmptyTableOperator:
Expand Down Expand Up @@ -791,6 +792,7 @@ def test_execute(self, mock_hook, as_dict):
max_results = 100
selected_fields = "DATE"
operator = BigQueryGetDataOperator(
gcp_conn_id=GCP_CONN_ID,
task_id=TASK_ID,
dataset_id=TEST_DATASET,
table_id=TEST_TABLE_ID,
Expand All @@ -799,8 +801,10 @@ def test_execute(self, mock_hook, as_dict):
selected_fields=selected_fields,
location=TEST_DATASET_LOCATION,
as_dict=as_dict,
use_legacy_sql=False,
)
operator.execute(None)
mock_hook.assert_called_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=None, use_legacy_sql=False)
mock_hook.return_value.list_rows.assert_called_once_with(
dataset_id=TEST_DATASET,
table_id=TEST_TABLE_ID,
Expand All @@ -818,12 +822,6 @@ def test_bigquery_get_data_operator_async_with_selected_fields(
Asserts that a task is deferred and a BigQuerygetDataTrigger will be fired
when the BigQueryGetDataOperator is executed with deferrable=True.
"""
job_id = "123456"
hash_ = "hash"
real_job_id = f"{job_id}_{hash_}"

mock_hook.return_value.insert_job.return_value = MagicMock(job_id=real_job_id, error_result=False)

ti = create_task_instance_of_operator(
BigQueryGetDataOperator,
dag_id="dag_id",
Expand All @@ -833,6 +831,7 @@ def test_bigquery_get_data_operator_async_with_selected_fields(
max_results=100,
selected_fields="value,name",
deferrable=True,
use_legacy_sql=False,
)

with pytest.raises(TaskDeferred) as exc:
Expand All @@ -851,12 +850,6 @@ def test_bigquery_get_data_operator_async_without_selected_fields(
Asserts that a task is deferred and a BigQueryGetDataTrigger will be fired
when the BigQueryGetDataOperator is executed with deferrable=True.
"""
job_id = "123456"
hash_ = "hash"
real_job_id = f"{job_id}_{hash_}"

mock_hook.return_value.insert_job.return_value = MagicMock(job_id=real_job_id, error_result=False)

ti = create_task_instance_of_operator(
BigQueryGetDataOperator,
dag_id="dag_id",
Expand All @@ -866,6 +859,7 @@ def test_bigquery_get_data_operator_async_without_selected_fields(
max_results=100,
deferrable=True,
as_dict=as_dict,
use_legacy_sql=False,
)

with pytest.raises(TaskDeferred) as exc:
Expand All @@ -886,6 +880,7 @@ def test_bigquery_get_data_operator_execute_failure(self, as_dict):
max_results=100,
deferrable=True,
as_dict=as_dict,
use_legacy_sql=False,
)

with pytest.raises(AirflowException):
Expand All @@ -904,6 +899,7 @@ def test_bigquery_get_data_op_execute_complete_with_records(self, as_dict):
max_results=100,
deferrable=True,
as_dict=as_dict,
use_legacy_sql=False,
)

with mock.patch.object(operator.log, "info") as mock_log_info:
Expand Down

0 comments on commit d1fe671

Please sign in to comment.