From 3c14753b03872b259ce2248eda92f7fb6f4d751b Mon Sep 17 00:00:00 2001
From: Avinash Holla Pandeshwar
Date: Thu, 20 Jul 2023 23:52:45 +0530
Subject: [PATCH] Fix BigQueryGetDataOperator where project_id is not being
 respected in deferrable mode (#32488)

* Fix BigQueryGetDataOperator to respect project_id as the compute project. A
  new parameter, table_project_id, is used to specify the table storage
  project.
---
 .../google/cloud/operators/bigquery.py      | 34 +++++++++++++++----
 .../google/cloud/operators/test_bigquery.py | 13 ++++---
 2 files changed, 36 insertions(+), 11 deletions(-)

diff --git a/airflow/providers/google/cloud/operators/bigquery.py b/airflow/providers/google/cloud/operators/bigquery.py
index 70ab30d61a232..f5e5a9634f72d 100644
--- a/airflow/providers/google/cloud/operators/bigquery.py
+++ b/airflow/providers/google/cloud/operators/bigquery.py
@@ -810,7 +810,11 @@ class BigQueryGetDataOperator(GoogleCloudBaseOperator):

     :param dataset_id: The dataset ID of the requested table. (templated)
     :param table_id: The table ID of the requested table. (templated)
-    :param project_id: (Optional) The name of the project where the data
+    :param table_project_id: (Optional) The project ID of the requested table.
+        If None, it will be derived from the hook's project ID. (templated)
+    :param job_project_id: (Optional) Google Cloud Project where the job is running.
+        If None, it will be derived from the hook's project ID. (templated)
+    :param project_id: (Deprecated) (Optional) The name of the project where the data
         will be returned from. If None, it will be derived from the hook's project ID. (templated)
     :param max_results: The maximum number of records (rows) to be fetched from the table. (templated)
@@ -837,6 +841,8 @@ class BigQueryGetDataOperator(GoogleCloudBaseOperator):
     template_fields: Sequence[str] = (
         "dataset_id",
         "table_id",
+        "table_project_id",
+        "job_project_id",
         "project_id",
         "max_results",
         "selected_fields",
@@ -849,6 +855,8 @@ def __init__(
         *,
         dataset_id: str,
         table_id: str,
+        table_project_id: str | None = None,
+        job_project_id: str | None = None,
         project_id: str | None = None,
         max_results: int = 100,
         selected_fields: str | None = None,
@@ -863,8 +871,10 @@ def __init__(
     ) -> None:
         super().__init__(**kwargs)

+        self.table_project_id = table_project_id
         self.dataset_id = dataset_id
         self.table_id = table_id
+        self.job_project_id = job_project_id
         self.max_results = int(max_results)
         self.selected_fields = selected_fields
         self.gcp_conn_id = gcp_conn_id
@@ -887,7 +897,7 @@ def _submit_job(
         return hook.insert_job(
             configuration=configuration,
             location=self.location,
-            project_id=hook.project_id,
+            project_id=self.job_project_id or hook.project_id,
             job_id=job_id,
             nowait=True,
         )
@@ -900,12 +910,22 @@ def generate_query(self, hook: BigQueryHook) -> str:
         else:
             query += "*"
         query += (
-            f" from `{self.project_id or hook.project_id}.{self.dataset_id}"
+            f" from `{self.table_project_id or hook.project_id}.{self.dataset_id}"
             f".{self.table_id}` limit {self.max_results}"
         )
         return query

     def execute(self, context: Context):
+        if self.project_id:
+            self.log.warning(
+                "The project_id parameter is deprecated, and will be removed in a future release."
+                " Please use table_project_id instead.",
+            )
+            if not self.table_project_id:
+                self.table_project_id = self.project_id
+            else:
+                self.log.info("Ignoring project_id parameter, as table_project_id is found.")
+
         hook = BigQueryHook(
             gcp_conn_id=self.gcp_conn_id,
             impersonation_chain=self.impersonation_chain,
@@ -915,7 +935,7 @@ def execute(self, context: Context):
         if not self.deferrable:
             self.log.info(
                 "Fetching Data from %s.%s.%s max results: %s",
-                self.project_id or hook.project_id,
+                self.table_project_id or hook.project_id,
                 self.dataset_id,
                 self.table_id,
                 self.max_results,
@@ -924,7 +944,7 @@ def execute(self, context: Context):
             schema: dict[str, list] = hook.get_schema(
                 dataset_id=self.dataset_id,
                 table_id=self.table_id,
-                project_id=self.project_id,
+                project_id=self.table_project_id or hook.project_id,
             )
             if "fields" in schema:
                 self.selected_fields = ",".join([field["name"] for field in schema["fields"]])
@@ -935,7 +955,7 @@ def execute(self, context: Context):
                 max_results=self.max_results,
                 selected_fields=self.selected_fields,
                 location=self.location,
-                project_id=self.project_id,
+                project_id=self.table_project_id or hook.project_id,
             )

             if isinstance(rows, RowIterator):
@@ -961,7 +981,7 @@ def execute(self, context: Context):
                 job_id=job.job_id,
                 dataset_id=self.dataset_id,
                 table_id=self.table_id,
-                project_id=hook.project_id,
+                project_id=self.job_project_id or hook.project_id,
                 poll_interval=self.poll_interval,
                 as_dict=self.as_dict,
             ),
diff --git a/tests/providers/google/cloud/operators/test_bigquery.py b/tests/providers/google/cloud/operators/test_bigquery.py
index 226d1e7095954..4026b4ba45074 100644
--- a/tests/providers/google/cloud/operators/test_bigquery.py
+++ b/tests/providers/google/cloud/operators/test_bigquery.py
@@ -64,6 +64,7 @@
 TEST_DATASET = "test-dataset"
 TEST_DATASET_LOCATION = "EU"
 TEST_GCP_PROJECT_ID = "test-project"
+TEST_JOB_PROJECT_ID = "test-job-project"
 TEST_DELETE_CONTENTS = True
 TEST_TABLE_ID = "test-table-id"
 TEST_GCS_BUCKET = "test-bucket"
@@ -804,7 +805,7 @@ def test_execute(self, mock_hook, as_dict):
             task_id=TASK_ID,
             dataset_id=TEST_DATASET,
             table_id=TEST_TABLE_ID,
-            project_id=TEST_GCP_PROJECT_ID,
+            table_project_id=TEST_GCP_PROJECT_ID,
             max_results=max_results,
             selected_fields=selected_fields,
             location=TEST_DATASET_LOCATION,
@@ -823,13 +824,13 @@ def test_execute(self, mock_hook, as_dict):
         )

     @mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook")
-    def test_generate_query__with_project_id(self, mock_hook):
+    def test_generate_query__with_table_project_id(self, mock_hook):
         operator = BigQueryGetDataOperator(
             gcp_conn_id=GCP_CONN_ID,
             task_id=TASK_ID,
             dataset_id=TEST_DATASET,
             table_id=TEST_TABLE_ID,
-            project_id=TEST_GCP_PROJECT_ID,
+            table_project_id=TEST_GCP_PROJECT_ID,
             max_results=100,
             use_legacy_sql=False,
         )
@@ -839,7 +840,7 @@ def test_generate_query__with_project_id(self, mock_hook):
         )

     @mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook")
-    def test_generate_query__without_project_id(self, mock_hook):
+    def test_generate_query__without_table_project_id(self, mock_hook):
         hook_project_id = mock_hook.project_id
         operator = BigQueryGetDataOperator(
             gcp_conn_id=GCP_CONN_ID,
@@ -868,6 +869,7 @@ def test_bigquery_get_data_operator_async_with_selected_fields(
             task_id="get_data_from_bq",
             dataset_id=TEST_DATASET,
             table_id=TEST_TABLE_ID,
+            job_project_id=TEST_JOB_PROJECT_ID,
             max_results=100,
             selected_fields="value,name",
             deferrable=True,
@@ -896,6 +898,7 @@ def test_bigquery_get_data_operator_async_without_selected_fields(
             task_id="get_data_from_bq",
             dataset_id=TEST_DATASET,
             table_id=TEST_TABLE_ID,
+            job_project_id=TEST_JOB_PROJECT_ID,
             max_results=100,
             deferrable=True,
             as_dict=as_dict,
@@ -917,6 +920,7 @@ def test_bigquery_get_data_operator_execute_failure(self, as_dict):
             task_id="get_data_from_bq",
             dataset_id=TEST_DATASET,
             table_id="any",
+            job_project_id=TEST_JOB_PROJECT_ID,
             max_results=100,
             deferrable=True,
             as_dict=as_dict,
@@ -936,6 +940,7 @@ def test_bigquery_get_data_op_execute_complete_with_records(self, as_dict):
             task_id="get_data_from_bq",
             dataset_id=TEST_DATASET,
             table_id="any",
+            job_project_id=TEST_JOB_PROJECT_ID,
             max_results=100,
             deferrable=True,
             as_dict=as_dict,
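
Usage sketch (illustrative, not part of the patch): a minimal DAG showing how the
two new parameters might be combined once this change is in place. The DAG id,
project IDs, dataset, and table names below are placeholders, not values from
the patch.

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.google.cloud.operators.bigquery import BigQueryGetDataOperator

    with DAG(
        dag_id="example_bigquery_get_data",  # placeholder DAG id
        start_date=datetime(2023, 7, 1),
        schedule=None,
        catchup=False,
    ) as dag:
        get_data = BigQueryGetDataOperator(
            task_id="get_data",
            dataset_id="my_dataset",             # placeholder dataset
            table_id="my_table",                 # placeholder table
            table_project_id="storage-project",  # project that stores the table
            job_project_id="compute-project",    # project that runs (and is billed for) the job
            max_results=10,
            deferrable=True,                     # the deferrable path is what this patch fixes
        )

With this split, generate_query() reads from table_project_id while _submit_job()
and the trigger run the job under job_project_id; passing the legacy project_id
still works but logs a deprecation warning and is treated as table_project_id.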