Skip to content

Commit

Permalink
Fix BigQueryGetDataOperator where project_id is not being respected i…
Browse files Browse the repository at this point in the history
…n deferrable mode (#32488)

* fixing BigQueryGetDataOperator to respect project_id as compute project. A new parameter table_project_id will be used for specifying table storage project.
  • Loading branch information
avinashpandeshwar committed Jul 20, 2023
1 parent ac52482 commit 3c14753
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 11 deletions.
34 changes: 27 additions & 7 deletions airflow/providers/google/cloud/operators/bigquery.py
Expand Up @@ -810,7 +810,11 @@ class BigQueryGetDataOperator(GoogleCloudBaseOperator):
:param dataset_id: The dataset ID of the requested table. (templated)
:param table_id: The table ID of the requested table. (templated)
:param project_id: (Optional) The name of the project where the data
:param table_project_id: (Optional) The project ID of the requested table.
If None, it will be derived from the hook's project ID. (templated)
:param job_project_id: (Optional) Google Cloud Project where the job is running.
If None, it will be derived from the hook's project ID. (templated)
:param project_id: (Deprecated) (Optional) The name of the project where the data
will be returned from. If None, it will be derived from the hook's project ID. (templated)
:param max_results: The maximum number of records (rows) to be fetched
from the table. (templated)
Expand All @@ -837,6 +841,8 @@ class BigQueryGetDataOperator(GoogleCloudBaseOperator):
template_fields: Sequence[str] = (
"dataset_id",
"table_id",
"table_project_id",
"job_project_id",
"project_id",
"max_results",
"selected_fields",
Expand All @@ -849,6 +855,8 @@ def __init__(
*,
dataset_id: str,
table_id: str,
table_project_id: str | None = None,
job_project_id: str | None = None,
project_id: str | None = None,
max_results: int = 100,
selected_fields: str | None = None,
Expand All @@ -863,8 +871,10 @@ def __init__(
) -> None:
super().__init__(**kwargs)

self.table_project_id = table_project_id
self.dataset_id = dataset_id
self.table_id = table_id
self.job_project_id = job_project_id
self.max_results = int(max_results)
self.selected_fields = selected_fields
self.gcp_conn_id = gcp_conn_id
Expand All @@ -887,7 +897,7 @@ def _submit_job(
return hook.insert_job(
configuration=configuration,
location=self.location,
project_id=hook.project_id,
project_id=self.job_project_id or hook.project_id,
job_id=job_id,
nowait=True,
)
Expand All @@ -900,12 +910,22 @@ def generate_query(self, hook: BigQueryHook) -> str:
else:
query += "*"
query += (
f" from `{self.project_id or hook.project_id}.{self.dataset_id}"
f" from `{self.table_project_id or hook.project_id}.{self.dataset_id}"
f".{self.table_id}` limit {self.max_results}"
)
return query

def execute(self, context: Context):
if self.project_id:
self.log.warning(
"The project_id parameter is deprecated, and will be removed in a future release."
" Please use table_project_id instead.",
)
if not self.table_project_id:
self.table_project_id = self.project_id
else:
self.log.info("Ignoring project_id parameter, as table_project_id is found.")

hook = BigQueryHook(
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
Expand All @@ -915,7 +935,7 @@ def execute(self, context: Context):
if not self.deferrable:
self.log.info(
"Fetching Data from %s.%s.%s max results: %s",
self.project_id or hook.project_id,
self.table_project_id or hook.project_id,
self.dataset_id,
self.table_id,
self.max_results,
Expand All @@ -924,7 +944,7 @@ def execute(self, context: Context):
schema: dict[str, list] = hook.get_schema(
dataset_id=self.dataset_id,
table_id=self.table_id,
project_id=self.project_id,
project_id=self.table_project_id or hook.project_id,
)
if "fields" in schema:
self.selected_fields = ",".join([field["name"] for field in schema["fields"]])
Expand All @@ -935,7 +955,7 @@ def execute(self, context: Context):
max_results=self.max_results,
selected_fields=self.selected_fields,
location=self.location,
project_id=self.project_id,
project_id=self.table_project_id or hook.project_id,
)

if isinstance(rows, RowIterator):
Expand All @@ -961,7 +981,7 @@ def execute(self, context: Context):
job_id=job.job_id,
dataset_id=self.dataset_id,
table_id=self.table_id,
project_id=hook.project_id,
project_id=self.job_project_id or hook.project_id,
poll_interval=self.poll_interval,
as_dict=self.as_dict,
),
Expand Down
13 changes: 9 additions & 4 deletions tests/providers/google/cloud/operators/test_bigquery.py
Expand Up @@ -64,6 +64,7 @@
TEST_DATASET = "test-dataset"
TEST_DATASET_LOCATION = "EU"
TEST_GCP_PROJECT_ID = "test-project"
TEST_JOB_PROJECT_ID = "test-job-project"
TEST_DELETE_CONTENTS = True
TEST_TABLE_ID = "test-table-id"
TEST_GCS_BUCKET = "test-bucket"
Expand Down Expand Up @@ -804,7 +805,7 @@ def test_execute(self, mock_hook, as_dict):
task_id=TASK_ID,
dataset_id=TEST_DATASET,
table_id=TEST_TABLE_ID,
project_id=TEST_GCP_PROJECT_ID,
table_project_id=TEST_GCP_PROJECT_ID,
max_results=max_results,
selected_fields=selected_fields,
location=TEST_DATASET_LOCATION,
Expand All @@ -823,13 +824,13 @@ def test_execute(self, mock_hook, as_dict):
)

@mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook")
def test_generate_query__with_project_id(self, mock_hook):
def test_generate_query__with_table_project_id(self, mock_hook):
operator = BigQueryGetDataOperator(
gcp_conn_id=GCP_CONN_ID,
task_id=TASK_ID,
dataset_id=TEST_DATASET,
table_id=TEST_TABLE_ID,
project_id=TEST_GCP_PROJECT_ID,
table_project_id=TEST_GCP_PROJECT_ID,
max_results=100,
use_legacy_sql=False,
)
Expand All @@ -839,7 +840,7 @@ def test_generate_query__with_project_id(self, mock_hook):
)

@mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook")
def test_generate_query__without_project_id(self, mock_hook):
def test_generate_query__without_table_project_id(self, mock_hook):
hook_project_id = mock_hook.project_id
operator = BigQueryGetDataOperator(
gcp_conn_id=GCP_CONN_ID,
Expand Down Expand Up @@ -868,6 +869,7 @@ def test_bigquery_get_data_operator_async_with_selected_fields(
task_id="get_data_from_bq",
dataset_id=TEST_DATASET,
table_id=TEST_TABLE_ID,
job_project_id=TEST_JOB_PROJECT_ID,
max_results=100,
selected_fields="value,name",
deferrable=True,
Expand Down Expand Up @@ -896,6 +898,7 @@ def test_bigquery_get_data_operator_async_without_selected_fields(
task_id="get_data_from_bq",
dataset_id=TEST_DATASET,
table_id=TEST_TABLE_ID,
job_project_id=TEST_JOB_PROJECT_ID,
max_results=100,
deferrable=True,
as_dict=as_dict,
Expand All @@ -917,6 +920,7 @@ def test_bigquery_get_data_operator_execute_failure(self, as_dict):
task_id="get_data_from_bq",
dataset_id=TEST_DATASET,
table_id="any",
job_project_id=TEST_JOB_PROJECT_ID,
max_results=100,
deferrable=True,
as_dict=as_dict,
Expand All @@ -936,6 +940,7 @@ def test_bigquery_get_data_op_execute_complete_with_records(self, as_dict):
task_id="get_data_from_bq",
dataset_id=TEST_DATASET,
table_id="any",
job_project_id=TEST_JOB_PROJECT_ID,
max_results=100,
deferrable=True,
as_dict=as_dict,
Expand Down

0 comments on commit 3c14753

Please sign in to comment.